Commit b01c431c authored by TYAGIRIS's avatar TYAGIRIS
Browse files

fix short test vectors to generate cut samples for all sba inputs

parent 792ad1da
Loading
Loading
Loading
Loading
Loading
+4 −3
Original line number Diff line number Diff line
@@ -454,18 +454,19 @@ def encode(
def pre_proc_input(testv_file, fs):
    cut_from = "0.0"
    cut_len = "5.0"
    cut_gain = "0.004"
    cut_gain = ".004"
    if "stvFOA" in testv_file:
        num_channel = "4"
    elif "stv2OA" in testv_file:
        num_channel = "9"
    elif "stv3OA" in testv_file:
        num_channel = "16"
    cut_file = testv_file.replace(".wav", num_channel + "chn_" + cut_gain + ".wav")
    cut_file = testv_file.replace(".wav", "_cut_" + cut_gain + ".wav")
    cut_file_pre_exist = 1;
    if not os.path.exists(cut_file):
        tmpf = tempfile.TemporaryFile()
        cut_file = tmpf.name + num_channel + "chn_" + cut_gain + ".wav"
        cut_file = tmpf.name
        cut_file += "_cut_" + cut_gain + ".wav"
        cut_file_pre_exist = 0;
    
    cut_samples(testv_file, cut_file, num_channel, cut_from, cut_len, cut_gain)
+4 −4
Original line number Diff line number Diff line
@@ -72,12 +72,12 @@ def collect_files(file_ids):
    return files


def create_short_testvectors(which="foa", cut_len=5.0):
def create_short_testvectors(which="sba", cut_len=5.0):
    file_ids = []
    if which == "all":
        file_ids = FILE_IDS
    elif which == "foa":
        file_ids = FILE_IDS[:1]
    elif which == "sba":
        file_ids = FILE_IDS[:3]

    for f in collect_files(file_ids):
        for g in GAINS:
@@ -91,7 +91,7 @@ def create_short_testvectors(which="foa", cut_len=5.0):

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--which", choices=["foa", "all"], default="foa")
    parser.add_argument("--which", choices=["sba", "all"], default="sba")

    def positive_float(x: str) -> float:
        x = float(x)