Commit a0c18bdd authored by Archit Tamarapu's avatar Archit Tamarapu
Browse files

Merge branch 'fix-gain-parsing-and-deletion' into 'main'

bugfixes for gain parsing and output folder deletion

See merge request !166
parents 6dbf8e26 3ff4e17a
Loading
Loading
Loading
Loading
+3 −25
Original line number Diff line number Diff line
@@ -34,8 +34,6 @@ import logging
import sys
from itertools import product
from multiprocessing import Pool
from pathlib import Path
from shutil import rmtree
from time import sleep

from ivas_processing_scripts.audiotools.metadata import (
@@ -100,33 +98,13 @@ def main(args):
    # set up processing chains
    chains.init_processing_chains(cfg)

    # set up logging
    logger = logging_init(args, cfg)

    if cfg.delete_output:
        deletion_list = [d for d in [*cfg.out_dirs, *cfg.tmp_dirs] if Path(d).exists()]
        if deletion_list:
            logger.warning(
                "\nWARNING! The configuration key to delete output directories was specified!"
            )
            logger.warning(
                f"The following directories will be REMOVED from {cfg.output_path}:\n {', '.join([d.name for d in deletion_list])}\n"
            )
            confirm = input(
                "Are you sure you want to delete these? Type 'YES' in capitals to confirm deletion: "
            )
            if confirm == "YES":
                for dir in deletion_list:
                    rmtree(dir)
            else:
                logger.warning(
                    "Deletion was canceled. Please remove the output directories manually."
                )

    # context manager to create output directories and clean up temporary directories
    with DirManager(
        cfg.out_dirs + cfg.tmp_dirs, cfg.tmp_dirs if cfg.delete_tmp else []
    ):
        # set up logging
        logger = logging_init(args, cfg)

        # Re-ordering items based on concatenation order
        if hasattr(cfg, "preprocessing_2"):
            if (
+1 −18
Original line number Diff line number Diff line
@@ -39,7 +39,7 @@ from ivas_processing_scripts.audiotools.constants import (
    BINAURAL_LFE_GAIN,
)
from ivas_processing_scripts.audiotools.convert import convert_file
from ivas_processing_scripts.utils import apply_func_parallel
from ivas_processing_scripts.utils import apply_func_parallel, parse_gain


def add_processing_args(group, input=True):
@@ -51,23 +51,6 @@ def add_processing_args(group, input=True):
        p = "out"
        ps = "o"

    # validation function(s)
    def parse_gain(g: str) -> float:
        g = g.strip()
        try:
            if g.lower().endswith("db"):
                g = float(g[:-2].strip())
                g = 10 ** (g / 20)
            else:
                g = float(g)

        except ValueError:
            raise argparse.ArgumentTypeError(
                f"Invalid gain value '{g}' specified. Must be a number or a number suffixed with dB"
            )

        return g

    group.add_argument(
        f"-{ps}",
        f"--{p}",
+44 −10
Original line number Diff line number Diff line
@@ -30,7 +30,8 @@
#  the United Nations Convention on Contracts on the International Sales of Goods.
#

from shutil import copyfile
from pathlib import Path
from shutil import copyfile, rmtree
from typing import Optional
from warnings import warn

@@ -46,7 +47,7 @@ from ivas_processing_scripts.processing.preprocessing_2 import Preprocessing2
from ivas_processing_scripts.processing.processing_splitting_scaling import (
    Processing_splitting_scaling,
)
from ivas_processing_scripts.utils import get_abs_path, list_audio
from ivas_processing_scripts.utils import get_abs_path, list_audio, parse_gain


def init_processing_chains(cfg: TestConfig) -> None:
@@ -104,14 +105,17 @@ def init_processing_chains(cfg: TestConfig) -> None:
            f"Directory {cfg.input_path} does not exist, contains no audio files or all files were filtered out."
        )

    # validate input files for correct format and sampling rate
    validate_input_files(cfg)

    # assemble a list of output and temporary directories to create
    for chain in cfg.proc_chains:
        cfg.out_dirs.append(cfg.output_path.joinpath(chain["name"]))
        cfg.tmp_dirs.append(cfg.output_path.joinpath(f"tmp_{chain['name']}"))

    # delete output files if requested
    clean_outputs(cfg)

    # validate input files for correct format and sampling rate
    validate_input_files(cfg)


def get_preprocessing(cfg: TestConfig) -> dict:
    """Mapping from test configuration to preprocessing keyword arguments"""
@@ -140,8 +144,8 @@ def get_preprocessing(cfg: TestConfig) -> dict:
                "in_loudness": pre_cfg.get("loudness"),
                "in_loudness_fmt": pre_cfg.get("loudness_fmt", post_fmt),
                "in_mask": pre_cfg.get("mask", None),
                "in_gain_pre": pre_cfg.get("gain_pre"),
                "out_gain_post": pre_cfg.get("gain_post"),
                "in_gain_pre": parse_gain(pre_cfg.get("gain_pre")),
                "out_gain_post": parse_gain(pre_cfg.get("gain_post")),
                "multiprocessing": cfg.multiprocessing,
            }
        )
@@ -567,8 +571,8 @@ def get_processing_chain(
            {
                "in_fs": tmp_in_fs,
                "in_fmt": tmp_in_fmt,
                "in_gain_pre": post_cfg.get("gain_pre"),
                "out_gain_post": post_cfg.get("gain_post"),
                "in_gain_pre": parse_gain(post_cfg.get("gain_pre")),
                "out_gain_post": parse_gain(post_cfg.get("gain_post")),
                "out_fs": post_cfg.get("fs"),
                "out_fmt": post_fmt,
                "out_cutoff": tmp_lp_cutoff,
@@ -622,7 +626,7 @@ def validate_input_files(cfg: TestConfig):
    if input_format.startswith("ISM") or input_format.startswith("MASA"):
        frame_alignment = "error"

    if cfg.input["frame_alignment"] == "padding":
    if frame_alignment == "padding":
        # Create new input directory for padded files
        output_dir = cfg.output_path / "20ms_aligned_files"
        try:
@@ -703,3 +707,33 @@ def validate_input_files(cfg: TestConfig):
    if frame_alignment == "padding":
        # Make the output path as the new input path
        cfg.input_path = output_dir


def clean_outputs(cfg: TestConfig) -> None:
    if cfg.delete_output:
        deletion_list = [
            d
            for d in [
                *cfg.out_dirs,
                *cfg.tmp_dirs,
                cfg.output_path.joinpath("20ms_aligned_files"),
            ]
            if Path(d).exists()
        ]
        if deletion_list:
            warn(
                "\nWARNING! The configuration key to delete output directories was specified!"
            )
            warn(
                f"The following directories will be REMOVED from {cfg.output_path}:\n {', '.join([d.name for d in deletion_list])}\n"
            )
            confirm = input(
                "Are you sure you want to delete these? Type 'YES' in capitals to confirm deletion: "
            )
            if confirm == "YES":
                for dir in deletion_list:
                    rmtree(dir)
            else:
                print(
                    "Deletion was canceled. Please remove the output directories manually."
                )
+20 −0
Original line number Diff line number Diff line
@@ -320,3 +320,23 @@ def get_abs_path(rel_path):
    else:
        abs_path = None
    return abs_path


def parse_gain(g: str) -> float:
    if g is None:
        return None

    g = g.strip()
    try:
        if g.lower().endswith("db"):
            g = float(g[:-2].strip())
            g = 10 ** (g / 20)
        else:
            g = float(g)

    except ValueError:
        raise ValueError(
            f"Invalid gain value '{g}' specified. Must be a number or a number suffixed with dB"
        )

    return g