Commit eec016ff authored by Jan Kiene's avatar Jan Kiene
Browse files

fix length check

parent 065b41da
Loading
Loading
Loading
Loading
Loading
+26 −1
Original line number Diff line number Diff line
@@ -39,7 +39,7 @@ from numpy.random import random, seed

from ivas_processing_scripts import main as generate_test
from ivas_processing_scripts.audiotools import audio
from ivas_processing_scripts.audiotools.audiofile import concat, write
from ivas_processing_scripts.audiotools.audiofile import concat, write, read
from ivas_processing_scripts.processing.config import TestConfig
from tests.constants import (
    FORMAT_TO_METADATA_FILES,
@@ -98,6 +98,25 @@ def setup_input_files_for_config(config):
        write(bg_noise_path, noise)


def all_lengths_equal(cfg):
    input_files = cfg.items_list
    output_folder = cfg.output_path

    all_lengths_equal = True
    for condition in cfg.conditions_to_generate.keys():
        output_condition_folder = output_folder.joinpath(condition)
        for input_file in input_files:
            output_file = output_condition_folder.joinpath(input_file.name)
            in_signal = read(input_file)
            out_signal = read(output_file)
            shapes_equal = in_signal.shape == out_signal.shape
            if not shapes_equal:
                print("Unequal file length for {input_file.name} in condition {condition}")
                all_lengths_equal = False

    return all_lengths_equal


@pytest.mark.parametrize(
    "exp_lab_pair", zip(INPUT_EXPERIMENT_NAMES, LAB_IDS_FOR_EXPERIMENTS)
)
@@ -106,7 +125,13 @@ def test_generate_test_items(exp_lab_pair):
    cfgs = create_experiment_setup(exp_name, lab_id)
    cfg = cfgs[0]

    # patch key to make checking for same input and output file length easier
    cfg.condition_in_output_filename = False

    args = Arguments(str(cfg))
    config = TestConfig(cfg)
    setup_input_files_for_config(config)
    generate_test(args)

    if not all_lengths_equal(cfg):
        raise RuntimeError("Unequal lengths between input and output files detected")