Commit 6ee8b46a authored by Jan Kiene's avatar Jan Kiene
Browse files

add argument for scope of short testvector creation

parent edea7dee
Loading
Loading
Loading
Loading
+23 −11
Original line number Diff line number Diff line
#!/usr/bin/env python3

__copyright__ = \
"""
__copyright__ = """
(C) 2022-2023 IVAS codec Public Collaboration with portions copyright Dolby International AB, Ericsson AB,
Fraunhofer-Gesellschaft zur Foerderung der angewandten Forschung e.V., Huawei Technologies Co. LTD.,
Koninklijke Philips N.V., Nippon Telegraph and Telephone Corporation, Nokia Technologies Oy, Orange,
@@ -31,12 +30,12 @@ accordance with the laws of the Federal Republic of Germany excluding its confli
the United Nations Convention on Contracts on the International Sales of Goods.
"""

__doc__ = \
"""
__doc__ = """
Create short (5sec) testvectors.
"""

import sys
import argparse
from pathlib import Path
from cut_pcm import cut_samples

@@ -47,17 +46,27 @@ NUM_CHANNELS = "4" # currently only FOA
CUT_FROM = "0.0"
CUT_LEN = "5.0"

FILE_IDS = ["stvFOA"]
FILE_IDS = ["stvFOA", "stv20A", "stv3OA", "stv51MC", "stv71MC", "stv512MC", "stv514MC", "stv714MC", "ISM", "MASA"]
GAINS = ["1.0", "16.0", ".004"]


def collect_files():
    files = [f.absolute() for f in TEST_VECTOR_DIR.iterdir() if f.suffix == ".wav" and any([f.name.startswith(id) for id in FILE_IDS])]
def collect_files(file_ids):
    files = [
        f.absolute()
        for f in TEST_VECTOR_DIR.iterdir()
        if f.suffix == ".wav" and any([id in f.name for id in file_ids]) and not "_cut" in f.name
    ]
    return files


def create_short_testvectors():
    for f in collect_files():
def create_short_testvectors(which="foa"):
    file_ids = []
    if which == "all":
        file_ids = FILE_IDS
    elif which == "foa":
        file_ids = FILE_IDS[:1]

    for f in collect_files(file_ids):
        for g in GAINS:
            suffix = "_cut"
            if g != "1.0":
@@ -68,4 +77,7 @@ def create_short_testvectors():


if __name__ == "__main__":
    sys.exit(create_short_testvectors())
    parser = argparse.ArgumentParser()
    parser.add_argument("which", choices=["foa", "all"])
    args = parser.parse_args()
    sys.exit(create_short_testvectors(args.which))
+2 −5
Original line number Diff line number Diff line
@@ -87,11 +87,8 @@ def cut_samples(in_file, out_file, num_channels, start, duration, gain="1.0", sa
    num_in_samples = s.shape[0]
    num_samples_to_skip = int(start_sec * fs)
    dur_samples = int(dur_sec * fs)
    if num_samples_to_skip + dur_samples > num_in_samples:
        sys.exit(
            f"requested too many samples ({num_samples_to_skip}+{dur_samples})"
            + f" - input is too short ({num_in_samples})"
        )
    if num_samples_to_skip > dur_samples:
        raise ValueError(f"Requested to skip {num_samples_to_skip}, but file only has {dur_samples} samples")

    s_out = s[num_samples_to_skip:num_samples_to_skip + dur_samples, :] * gain_f