Commit 95ee32d8 authored by Archit Tamarapu's avatar Archit Tamarapu
Browse files

bugfixes for gain parsing and output folder deletion

parent 6dbf8e26
Loading
Loading
Loading
Loading
Loading
+2 −24
Original line number Diff line number Diff line
@@ -34,8 +34,6 @@ import logging
import sys
from itertools import product
from multiprocessing import Pool
from pathlib import Path
from shutil import rmtree
from time import sleep

from ivas_processing_scripts.audiotools.metadata import (
@@ -97,31 +95,11 @@ def main(args):
    if hasattr(args, "multiprocessing"):
        cfg.multiprocessing = args.multiprocessing

    # set up processing chains
    chains.init_processing_chains(cfg)

    # set up logging
    logger = logging_init(args, cfg)

    if cfg.delete_output:
        deletion_list = [d for d in [*cfg.out_dirs, *cfg.tmp_dirs] if Path(d).exists()]
        if deletion_list:
            logger.warning(
                "\nWARNING! The configuration key to delete output directories was specified!"
            )
            logger.warning(
                f"The following directories will be REMOVED from {cfg.output_path}:\n {', '.join([d.name for d in deletion_list])}\n"
            )
            confirm = input(
                "Are you sure you want to delete these? Type 'YES' in capitals to confirm deletion: "
            )
            if confirm == "YES":
                for dir in deletion_list:
                    rmtree(dir)
            else:
                logger.warning(
                    "Deletion was canceled. Please remove the output directories manually."
                )
    # set up processing chains
    chains.init_processing_chains(cfg, logger)

    # context manager to create output directories and clean up temporary directories
    with DirManager(
+1 −18
Original line number Diff line number Diff line
@@ -39,7 +39,7 @@ from ivas_processing_scripts.audiotools.constants import (
    BINAURAL_LFE_GAIN,
)
from ivas_processing_scripts.audiotools.convert import convert_file
from ivas_processing_scripts.utils import apply_func_parallel
from ivas_processing_scripts.utils import apply_func_parallel, parse_gain


def add_processing_args(group, input=True):
@@ -51,23 +51,6 @@ def add_processing_args(group, input=True):
        p = "out"
        ps = "o"

    # validation function(s)
    def parse_gain(g: str) -> float:
        g = g.strip()
        try:
            if g.lower().endswith("db"):
                g = float(g[:-2].strip())
                g = 10 ** (g / 20)
            else:
                g = float(g)

        except ValueError:
            raise argparse.ArgumentTypeError(
                f"Invalid gain value '{g}' specified. Must be a number or a number suffixed with dB"
            )

        return g

    group.add_argument(
        f"-{ps}",
        f"--{p}",
+46 −11
Original line number Diff line number Diff line
@@ -30,7 +30,9 @@
#  the United Nations Convention on Contracts on the International Sales of Goods.
#

from shutil import copyfile
import logging
from pathlib import Path
from shutil import copyfile, rmtree
from typing import Optional
from warnings import warn

@@ -46,10 +48,10 @@ from ivas_processing_scripts.processing.preprocessing_2 import Preprocessing2
from ivas_processing_scripts.processing.processing_splitting_scaling import (
    Processing_splitting_scaling,
)
from ivas_processing_scripts.utils import get_abs_path, list_audio
from ivas_processing_scripts.utils import get_abs_path, list_audio, parse_gain


def init_processing_chains(cfg: TestConfig) -> None:
def init_processing_chains(cfg: TestConfig, logger: logging.Logger) -> None:
    """initialise processing chains for each condition and list items to process"""
    # pre-processing - special case; should only be run once
    if hasattr(cfg, "preprocessing"):
@@ -104,14 +106,17 @@ def init_processing_chains(cfg: TestConfig) -> None:
            f"Directory {cfg.input_path} does not exist, contains no audio files or all files were filtered out."
        )

    # validate input files for correct format and sampling rate
    validate_input_files(cfg)

    # assemble a list of output and temporary directories to create
    for chain in cfg.proc_chains:
        cfg.out_dirs.append(cfg.output_path.joinpath(chain["name"]))
        cfg.tmp_dirs.append(cfg.output_path.joinpath(f"tmp_{chain['name']}"))

    # delete output files if requested
    clean_outputs(cfg, logger)

    # validate input files for correct format and sampling rate
    validate_input_files(cfg)


def get_preprocessing(cfg: TestConfig) -> dict:
    """Mapping from test configuration to preprocessing keyword arguments"""
@@ -140,8 +145,8 @@ def get_preprocessing(cfg: TestConfig) -> dict:
                "in_loudness": pre_cfg.get("loudness"),
                "in_loudness_fmt": pre_cfg.get("loudness_fmt", post_fmt),
                "in_mask": pre_cfg.get("mask", None),
                "in_gain_pre": pre_cfg.get("gain_pre"),
                "out_gain_post": pre_cfg.get("gain_post"),
                "in_gain_pre": parse_gain(pre_cfg.get("gain_pre")),
                "out_gain_post": parse_gain(pre_cfg.get("gain_post")),
                "multiprocessing": cfg.multiprocessing,
            }
        )
@@ -567,8 +572,8 @@ def get_processing_chain(
            {
                "in_fs": tmp_in_fs,
                "in_fmt": tmp_in_fmt,
                "in_gain_pre": post_cfg.get("gain_pre"),
                "out_gain_post": post_cfg.get("gain_post"),
                "in_gain_pre": parse_gain(post_cfg.get("gain_pre")),
                "out_gain_post": parse_gain(post_cfg.get("gain_post")),
                "out_fs": post_cfg.get("fs"),
                "out_fmt": post_fmt,
                "out_cutoff": tmp_lp_cutoff,
@@ -622,7 +627,7 @@ def validate_input_files(cfg: TestConfig):
    if input_format.startswith("ISM") or input_format.startswith("MASA"):
        frame_alignment = "error"

    if cfg.input["frame_alignment"] == "padding":
    if frame_alignment == "padding":
        # Create new input directory for padded files
        output_dir = cfg.output_path / "20ms_aligned_files"
        try:
@@ -703,3 +708,33 @@ def validate_input_files(cfg: TestConfig):
    if frame_alignment == "padding":
        # Make the output path as the new input path
        cfg.input_path = output_dir


def clean_outputs(cfg: TestConfig, logger: logging.Logger) -> None:
    if cfg.delete_output:
        deletion_list = [
            d
            for d in [
                *cfg.out_dirs,
                *cfg.tmp_dirs,
                cfg.output_path.joinpath("20ms_aligned_files"),
            ]
            if Path(d).exists()
        ]
        if deletion_list:
            logger.warning(
                "\nWARNING! The configuration key to delete output directories was specified!"
            )
            logger.warning(
                f"The following directories will be REMOVED from {cfg.output_path}:\n {', '.join([d.name for d in deletion_list])}\n"
            )
            confirm = input(
                "Are you sure you want to delete these? Type 'YES' in capitals to confirm deletion: "
            )
            if confirm == "YES":
                for dir in deletion_list:
                    rmtree(dir)
            else:
                logger.warning(
                    "Deletion was canceled. Please remove the output directories manually."
                )
+20 −0
Original line number Diff line number Diff line
@@ -320,3 +320,23 @@ def get_abs_path(rel_path):
    else:
        abs_path = None
    return abs_path


def parse_gain(g: str) -> float:
    if g is None:
        return None

    g = g.strip()
    try:
        if g.lower().endswith("db"):
            g = float(g[:-2].strip())
            g = 10 ** (g / 20)
        else:
            g = float(g)

    except ValueError:
        raise ValueError(
            f"Invalid gain value '{g}' specified. Must be a number or a number suffixed with dB"
        )

    return g