Commit 1110f7fd authored by Jan Kiene's avatar Jan Kiene
Browse files

cleanup in create_short_testvectors.py

parent 8ee4df33
Loading
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -1177,7 +1177,7 @@ check-bitexactness-hrtf-rom-and-file:
    - *print-common-info
    - cmake .
    - make -j
    - python3 tests/create_short_testvectors.py --which all --cut_len 1.0
    - python3 tests/create_short_testvectors.py --cut_len 1.0
    - python3 -m pytest tests/hrtf_binary_loading --html=report.html --junit-xml=report-junit.xml --self-contained-html
  artifacts:
    paths:
+15 −20
Original line number Diff line number Diff line
@@ -42,10 +42,13 @@ from cut_pcm import cut_samples

HERE = Path(__file__).parent.resolve()
TEST_VECTOR_DIR = HERE.joinpath("../scripts/testv").resolve()
SCRIPTS_DIR = HERE.joinpath("../scripts").resolve()

NUM_CHANNELS = "4"  # currently only FOA
CUT_FROM = "0.0"
sys.path.append(SCRIPTS_DIR)
from pyaudio3dtools import audiofile

CUT_FROM = "0.0"
GAIN = "1.0"
FILE_IDS = [
    "stv51MC",
    "stv71MC",
@@ -55,38 +58,31 @@ FILE_IDS = [
    "ISM",
    "MASA",
]
GAINS = ["1.0", "16.0", ".004"]


def collect_files(file_ids):
def collect_files():
    files = [
        f.absolute()
        for f in TEST_VECTOR_DIR.iterdir()
        if f.suffix == ".wav"
        and any([id in f.name for id in file_ids])
        and any([id in f.name for id in FILE_IDS])
        and "_cut" not in f.name
    ]
    return files


def create_short_testvectors(which="sba", cut_len=5.0):
    file_ids = []
    if which == "all":
        file_ids = FILE_IDS
def create_short_testvectors(cut_len=5.0):
    files = collect_files()

    for f in collect_files(file_ids):
        for g in GAINS:
    for f in files:
        suffix = "_cut"
            if g != "1.0":
                suffix += f"_{g}"

        out_file = f.parent.joinpath(f.stem + suffix + f.suffix)
            cut_samples(f, out_file, NUM_CHANNELS, CUT_FROM, f"{cut_len}", g)
        num_channels = audiofile.get_wav_file_info(f)["channels"]
        cut_samples(f, out_file, num_channels, CUT_FROM, f"{cut_len}", GAIN)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--which", choices=["sba", "all"], default="sba")

    def positive_float(x: str) -> float:
        x = float(x)
@@ -96,6 +92,5 @@ if __name__ == "__main__":

    parser.add_argument("--cut_len", type=positive_float, default=5.0)
    args = parser.parse_args()
    which = args.which
    cut_len = args.cut_len
    sys.exit(create_short_testvectors(which=args.which, cut_len=cut_len))
    sys.exit(create_short_testvectors(cut_len=cut_len))