Commit 1d7dd442 authored by Jan Kiene's avatar Jan Kiene
Browse files

parse config directly instead of patching text

parent 1329f02a
Loading
Loading
Loading
Loading
+25 −35
Original line number Diff line number Diff line
#! /usr/bin/env python3
import argparse
from pathlib import Path
import re
from ivas_processing_scripts import config


EXPERIMENTS_P800 = [f"P800-{i}" for i in range(1, 10)]
@@ -9,24 +9,20 @@ EXPERIMENTS_BS1534 = [f"BS1534-{i}{x}" for i in range(1, 8) for x in ["a", "b"]]
EXPERIMENTS = EXPERIMENTS_P800 + EXPERIMENTS_BS1534
LABS = ["a", "b", "c", "d"]
HERE = Path(__file__).parent.absolute().resolve()
# TODO: this is a placeholder for later, currently everything is FOA
IN_FMT_FOR_MASA_EXPS = {
    "P800-8": dict(zip([f"cat{i}" for i in range(1, 7)], ["FOA"] * 6)),
    "P800-9": dict(zip([f"cat{i}" for i in range(1, 7)], ["FOA"] * 6)),
}


def _get_seed(exp, lab):
    return 101 + EXPERIMENTS.index(exp) * 4 + LABS.index(lab)


def _patch_value(line: str, value) -> str:
    line_split = line.split(':')
    line_split[-1] = f" {value}\n"
    return ':'.join(line_split)


def create_experiment_setup(experiment, lab) -> list[Path]:
    default_cfg_path = HERE.joinpath(f"selection/{experiment}/config/{experiment}.yml")

    with open(default_cfg_path) as f:
        cfg_lines = f.readlines()

    categories = [f"cat{i}" for i in range(1, 7)] if experiment in EXPERIMENTS_P800 else [""]
    seed = _get_seed(experiment, lab)
    base_path = Path(HERE.name).joinpath(f"selection/{experiment}")
@@ -37,38 +33,32 @@ def create_experiment_setup(experiment, lab) -> list[Path]:
        output_path = base_path.joinpath("proc_output").joinpath(cat)
        bg_noise_path = base_path.joinpath("background_noise").joinpath(f"background_noise_{cat}.wav")
        cfg_path = default_cfg_path.parent.joinpath(f"{experiment}{cat}-lab_{lab}.yml")
        cfgs.append(cfg_path)

        cat_cfg = []
        for line in cfg_lines:
            new_line = line
        # set new lab- and category-dependent values
        cfg = config.TestConfig(default_cfg_path)
        cfg.name = f"{experiment}{cat}-lab_{lab}"
        cfg.prerun_seed = seed
        cfg.input_path = str(input_path)
        cfg.output_path = str(output_path)
        cfg.preprocessing_2["background_noise"]["background_noise_path"] = str(bg_noise_path)

        # bg noise SNR only differs from default config for some experiments
        cat_num = int(cat[-1])
            patch_snr = experiment in ["P800-5", "P800-9"] and cat_num >= 3

            line_stripped = line.strip()
            if line_stripped.startswith("name:"):
                new_line = _patch_value(line, f"{experiment}{cat}-lab_{lab}")
            elif line_stripped.startswith("prerun_seed"):
                new_line = _patch_value(line, seed)
            elif line_stripped.startswith("input_path"):
                new_line = _patch_value(line, f'"{str(input_path)}"')
            elif line_stripped.startswith("output_path"):
                new_line = _patch_value(line, f'"{str(output_path)}"')
            elif line_stripped.startswith("snr") and patch_snr:
                new_line = _patch_value(line, 15)
            elif line_stripped.startswith("background_noise_path"):
                new_line = _patch_value(line, f'"{str(bg_noise_path)}"')

            cat_cfg.append(new_line)

        with open(cfg_path, "w") as f:
            f.writelines(cat_cfg)
        if experiment in ["P800-5", "P800-9"] and cat_num >= 3:
            cfg.preprocessing_2["background_noise"]["snr"] = 15

        # for MASA, the input format can differ between categories
        if (fmt_for_category := IN_FMT_FOR_MASA_EXPS.get(experiment, None)) is not None:
            cfg.input["fmt"] = fmt_for_category[cat]

        # ensure that necessary directories are there
        input_path.mkdir(parents=True, exist_ok=True)
        output_path.mkdir(parents=True, exist_ok=True)
        bg_noise_path.parent.mkdir(parents=True, exist_ok=True)

        cfgs.append(cfg_path)
        # write out config
        cfg.to_file(cfg_path)

    # Return the list of configs that were generated. Not strictly necessary, but makes testing easier.
    return cfgs
+1 −2
Original line number Diff line number Diff line
@@ -191,5 +191,4 @@ def main(args):
            rename_generated_conditions(cfg.output_path)

    # copy configuration to output directory
    with open(cfg.output_path.joinpath(f"{cfg.name}.yml"), "w") as f:
        yaml.safe_dump(cfg._yaml_dump, f)
    cfg.to_file(cfg.output_path.joinpath(f"{cfg.name}.yml"))
+4 −0
Original line number Diff line number Diff line
@@ -135,6 +135,10 @@ class TestConfig:

        return cfg

    def to_file(self, outfile: str|Path):
        with open(outfile, "w") as f:
            yaml.safe_dump(self._yaml_dump, f)

    def _validate_file_cfg(self, cfg: dict, use_windows_codec_binaries: bool):
        """ensure configuration contains required keys"""
        MISSING_KEYS = []