Commit 3c4c3e80 authored by Jan Kiene's avatar Jan Kiene
Browse files

add constant module to experiments package

parent 7b90aa4d
Loading
Loading
Loading
Loading
+5 −0
Original line number Diff line number Diff line
from pathlib import Path

EXPERIMENTS_P800 = [f"P800-{i}" for i in range(1, 10)]
EXPERIMENTS_BS1534 = [f"BS1534-{i}{x}" for i in range(1, 8) for x in ["a", "b"]]
LAB_IDS = ["a", "b", "c", "d"]
 No newline at end of file
+4 −6
Original line number Diff line number Diff line
@@ -6,12 +6,9 @@ import sys
HERE = Path(__file__).parent.absolute().resolve()
sys.path.append(str(HERE.parent))
from ivas_processing_scripts import config
from .constants import EXPERIMENTS_P800, EXPERIMENTS_BS1534, LAB_IDS


EXPERIMENTS_P800 = [f"P800-{i}" for i in range(1, 10)]
EXPERIMENTS_BS1534 = [f"BS1534-{i}{x}" for i in range(1, 8) for x in ["a", "b"]]
EXPERIMENTS = EXPERIMENTS_P800 + EXPERIMENTS_BS1534
LABS = ["a", "b", "c", "d"]
# TODO: this is a placeholder for later, currently everything is FOA
IN_FMT_FOR_MASA_EXPS = {
    "P800-8": dict(zip([f"cat{i}" for i in range(1, 7)], ["FOA"] * 6)),
@@ -22,7 +19,8 @@ IN_FMT_FOR_MASA_EXPS = {


def _get_seed(exp, lab):
    return 101 + EXPERIMENTS.index(exp) * 4 + LABS.index(lab)
    experiments = EXPERIMENTS_P800 + EXPERIMENTS_BS1534
    return 101 + experiments.index(exp) * 4 + LAB_IDS.index(lab)


def create_experiment_setup(experiment, lab) -> list[Path]:
@@ -82,7 +80,7 @@ def create_experiment_setup(experiment, lab) -> list[Path]:
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("experiment", type=str, choices=EXPERIMENTS_BS1534+EXPERIMENTS_P800)
    parser.add_argument("lab", type=str, choices=LABS)
    parser.add_argument("lab", type=str, choices=LAB_IDS)

    args = parser.parse_args()

+3 −4
Original line number Diff line number Diff line
@@ -8,8 +8,7 @@ sys.path.append(str(HERE.parent))
from ivas_processing_scripts import main as generate_test
from ivas_processing_scripts.utils import apply_func_parallel

P800_TESTS = [f"P800-{i}" for i in range(1, 8)]
LABS = ["a", "b", "c", "d"]
from .constants import EXPERIMENTS_P800, LAB_IDS


class Arguments:
@@ -31,8 +30,8 @@ def create_items(experiment, lab):
# if is necessary here so that multiprocessing does not crash
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("experiment", choices=P800_TESTS)
    parser.add_argument("lab", choices=LABS)
    parser.add_argument("experiment", choices=EXPERIMENTS_P800)
    parser.add_argument("lab", choices=LAB_IDS)
    args = parser.parse_args()

    create_items(args.experiment, args.lab)
+2 −2
Original line number Diff line number Diff line
@@ -49,7 +49,7 @@ from tests.constants import (
    LAB_IDS_FOR_EXPERIMENTS,
    TESTS_DIR,
)
from experiments import create_experiment_config, create_items_p800
from experiments import create_experiment_config, create_items

BG_NOISE_NAME = "background_noise_cat.wav"

@@ -125,7 +125,7 @@ def test_categories(exp_name, lab_id):
        config = TestConfig(cfg)
        setup_input_files_for_config(config)

    create_items_p800.create_items(exp_name, lab_id)
    create_items.create_items(exp_name, lab_id)


@pytest.mark.parametrize("exp_name", [f"P800-{i}" for i in range(1, 10)])