Commit 9229ad86 authored by Jan Kiene's avatar Jan Kiene
Browse files

Merge branch 'fix_concatenation_order_in_experiments' into 'main'

Fix concatenation order in experiments

See merge request !112
parents dabab8c0 4b56fad3
Loading
Loading
Loading
Loading
+12 −4
Original line number Diff line number Diff line
@@ -42,8 +42,8 @@ EXPERIMENTS_P800 = [f"P800-{i}" for i in range(1, 10)]
EXPERIMENTS_BS1534 = [f"BS1534-{i}{x}" for i in range(1, 8) for x in ["a", "b"]]
LAB_IDS = ["a", "b", "c", "d"]
IN_FMT_FOR_MASA_EXPS = {
    "P800-8": dict(zip([f"cat{i}" for i in range(1, 7)], ["FOA"] * 6)),
    "P800-9": dict(zip([f"cat{i}" for i in range(1, 7)], ["FOA"] * 6)),
    "P800-8": {"cat1": "FOA", "cat2": "FOA", "cat3": "FOA", "cat4": "FOA", "cat5": "FOA", "cat6": "FOA"},
    "P800-9": {"cat1": "FOA", "cat2": "FOA", "cat3": "FOA", "cat4": "FOA", "cat5": "FOA", "cat6": "FOA"},
    "BS1534-7a": {"cat1": "FOA", "cat2": "HOA2"},
    "BS1534-7b": {"cat1": "FOA", "cat2": "HOA2"},
}
@@ -102,7 +102,7 @@ def create_experiment_setup(experiment, lab) -> list[Path]:
        input_path = base_path.joinpath("proc_input").joinpath(cat)
        output_path = base_path.joinpath("proc_output").joinpath(suffix)
        bg_noise_path = base_path.joinpath("background_noise").joinpath(
            f"background_noise_{suffix}.wav"
            f"background_noise_{cat}.wav"
        )
        cfg_path = default_cfg_path.parent.joinpath(f"{experiment}{cat}-lab_{lab}.yml")
        cfgs.append(cfg_path)
@@ -113,15 +113,18 @@ def create_experiment_setup(experiment, lab) -> list[Path]:
        cfg.prerun_seed = seed
        cfg.input_path = str(input_path)
        cfg.output_path = str(output_path)

        cat_num = int(cat[-1])
        if (
            bg_noise_pre_proc_2 := cfg.preprocessing_2.get("background_noise", None)
        ) is not None:
            bg_noise_pre_proc_2["background_noise_path"] = str(bg_noise_path)

            # bg noise SNR only differs from default config for some experiments
            cat_num = int(cat[-1])
            if experiment in ["P800-5", "P800-9"] and cat_num >= 3:
                bg_noise_pre_proc_2["snr"] = 15
        if cfg.preprocessing_2.get("concatenate_input", None) is not None:
            cfg.preprocessing_2["concatenation_order"] = concatenation_order(lab, experiment, cat_num)

        # for MASA, the input format can differ between categories
        if (fmt_for_category := IN_FMT_FOR_MASA_EXPS.get(experiment, None)) is not None:
@@ -158,6 +161,11 @@ def exp_lab_pair(arg):
    return exp, lab


def concatenation_order(lab_id, experiment, category):
    exp_id = f"p0{experiment[-1]}"
    return [f"{lab_id}{exp_id}a{category}s0{i}.wav" for i in range(1, 8)]


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate config files and process files for selecton experiments. Experiment names and lab ids must be given as comma-separated pairs (e.g. 'P800-5,b BS1534-4a,d ...')"
+4 −0
Original line number Diff line number Diff line
@@ -83,6 +83,10 @@ def reorder_items_list(items_list: list, concatenation_order: list) -> list:
        Re-ordered list of input items
    """
    name_to_full = {Path(full_file).name: full_file for full_file in items_list}

    if set(name_to_full.keys()) != set(concatenation_order):
        raise ValueError(f"Items given in concatenation_order {concatenation_order} are not identical to what was found in the input folder {name_to_full.keys()}")

    ordered_full_files = [
        name_to_full[name] for name in concatenation_order if name in name_to_full
    ]
+3 −0
Original line number Diff line number Diff line
@@ -127,6 +127,9 @@ def list_audio(path: str, select_list: list = None) -> list:
            f for f in audio_list if any([pattern in f.stem for pattern in select_set])
        ]

    # sort file list alphanumerically by filenames
    audio_list = sorted(audio_list, key=lambda p: p.name)

    return audio_list


+11 −1
Original line number Diff line number Diff line
@@ -77,6 +77,7 @@ def setup_input_files_for_config(config):
    dummy_md_files = FORMAT_TO_METADATA_FILES.get(input_fmt, list())

    # copy input files
    files_copied = list()
    for f in dummy_input_files:
        f_out = input_path.joinpath(f.name).resolve().absolute()
        # need at least 2s of input files for gen-patt to be happy (can not keep the tolerance for 50 frames only)
@@ -86,6 +87,8 @@ def setup_input_files_for_config(config):
            md_f_out = ".".join([str(f_out), suffix])
            shutil.copy(md_f, md_f_out)

        files_copied.append(f_out.name)

    # create background noise files with white noise
    if "background_noise" in config.preprocessing_2:
        # always set the same seed to have reproducible test noises
@@ -98,6 +101,8 @@ def setup_input_files_for_config(config):
        ).absolute()
        write(bg_noise_path, noise)

    return files_copied


def all_lengths_equal(cfg):
    output_folder = cfg.output_path
@@ -132,7 +137,12 @@ def test_generate_test_items(exp_lab_pair):
    args = Arguments(str(cfg))
    config = TestConfig(cfg)

    setup_input_files_for_config(config)
    input_filenames = setup_input_files_for_config(config)
    # patch concatenation order
    if config.preprocessing_2.get("concatenate_input", None) is not None:
        config.preprocessing_2["concatenation_order"] = sorted(input_filenames)
        config.to_file(cfg)

    generate_test(args)

    if not all_lengths_equal(config):