diff --git a/scripts/prepare_combined_format_inputs.py b/scripts/prepare_combined_format_inputs.py index fe80c426820d477f385d6e0c3d8c06b87649f632..f5bca8ed066cc4cfacb93fde998267b4b7d0f0d3 100755 --- a/scripts/prepare_combined_format_inputs.py +++ b/scripts/prepare_combined_format_inputs.py @@ -38,102 +38,106 @@ import numpy as np from pyaudio3dtools import audiofile, audioarray -FS = [48, 32, 16] - -# scripts/testv/ path -input_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'testv') -output_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'testv') - -print(f'Writing new files into {output_dir}') - -# # prepare combined input for OMASA tests -force_overwrite = False # overwrite existing files - -# define MASA test item files -masa_alts = ({'masa_meta_file': 'stv2MASA2TC{}c.met', - 'masa_audio_file': 'stv2MASA2TC{}c.wav', - 'masa_tag': '2MASA2TC'}, - {'masa_meta_file': 'stv2MASA1TC{}c.met', - 'masa_audio_file': 'stv2MASA1TC{}c.wav', - 'masa_tag': '2MASA1TC'}, - {'masa_meta_file': 'stv1MASA2TC{}c.met', - 'masa_audio_file': 'stv1MASA2TC{}c.wav', - 'masa_tag': '1MASA2TC'}, - {'masa_meta_file': 'stv1MASA1TC{}c.met', - 'masa_audio_file': 'stv1MASA1TC{}c.wav', - 'masa_tag': '1MASA1TC'}) - -sba_alts = ({'sba_audio_file': 'stvFOA{}c.wav', - 'sba_tag': 'FOA'}, - {'sba_audio_file': 'stv2OA{}c.wav', - 'sba_tag': '2OA'}, - {'sba_audio_file': 'stv3OA{}c.wav', - 'sba_tag': '3OA'}) - -# files containing 1-4 ISMs as channels -ism_files = ('stv1ISM48s.wav', 'stv2ISM48s.wav', 'stv3ISM48s.wav', 'stv4ISM48s.wav') -# per-object metadata -ism_meta_files = ('stvISM1.csv', 'stvISM2.csv', 'stvISM3.csv', 'stvISM4.csv') - -wrote_files = [] -for fs in FS: - for enum_idx, (ism_audio_file, ism_meta_file) in enumerate(zip(ism_files, ism_meta_files)): - n_isms = enum_idx + 1 - ism_audio, ism_fs = audiofile.readfile(filename=os.path.join(input_dir, ism_audio_file)) - - # no stv MASA files in other sampling rates available currently - if fs == 48: - for masa_item in masa_alts: - masa_tag = masa_item['masa_tag'] - masa_audio_file = masa_item['masa_audio_file'].format(fs) - meta_file = masa_item['masa_meta_file'].format(fs) - - omasa_file_body = f'stvOMASA_{n_isms}ISM_{masa_tag}{fs}c' - omasa_file = os.path.join(output_dir, f'{omasa_file_body}.wav') - - if not os.path.exists(omasa_file) or force_overwrite: +def main(): + FS = [48, 32, 16] + + # scripts/testv/ path + input_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'testv') + output_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'testv') + + print(f'Writing new files into {output_dir}') + + # # prepare combined input for OMASA tests + force_overwrite = False # overwrite existing files + + # define MASA test item files + masa_alts = ({'masa_meta_file': 'stv2MASA2TC{}c.met', + 'masa_audio_file': 'stv2MASA2TC{}c.wav', + 'masa_tag': '2MASA2TC'}, + {'masa_meta_file': 'stv2MASA1TC{}c.met', + 'masa_audio_file': 'stv2MASA1TC{}c.wav', + 'masa_tag': '2MASA1TC'}, + {'masa_meta_file': 'stv1MASA2TC{}c.met', + 'masa_audio_file': 'stv1MASA2TC{}c.wav', + 'masa_tag': '1MASA2TC'}, + {'masa_meta_file': 'stv1MASA1TC{}c.met', + 'masa_audio_file': 'stv1MASA1TC{}c.wav', + 'masa_tag': '1MASA1TC'}) + + sba_alts = ({'sba_audio_file': 'stvFOA{}c.wav', + 'sba_tag': 'FOA'}, + {'sba_audio_file': 'stv2OA{}c.wav', + 'sba_tag': '2OA'}, + {'sba_audio_file': 'stv3OA{}c.wav', + 'sba_tag': '3OA'}) + + # files containing 1-4 ISMs as channels + ism_files = ('stv1ISM48s.wav', 'stv2ISM48s.wav', 'stv3ISM48s.wav', 'stv4ISM48s.wav') + # per-object metadata + ism_meta_files = ('stvISM1.csv', 'stvISM2.csv', 'stvISM3.csv', 'stvISM4.csv') + + wrote_files = [] + for fs in FS: + for enum_idx, (ism_audio_file, ism_meta_file) in enumerate(zip(ism_files, ism_meta_files)): + n_isms = enum_idx + 1 + ism_audio, ism_fs = audiofile.readfile(filename=os.path.join(input_dir, ism_audio_file)) + + # no stv MASA files in other sampling rates available currently + if fs == 48: + for masa_item in masa_alts: + masa_tag = masa_item['masa_tag'] + masa_audio_file = masa_item['masa_audio_file'].format(fs) + meta_file = masa_item['masa_meta_file'].format(fs) + + omasa_file_body = f'stvOMASA_{n_isms}ISM_{masa_tag}{fs}c' + omasa_file = os.path.join(output_dir, f'{omasa_file_body}.wav') + + if not os.path.exists(omasa_file) or force_overwrite: + audiofile.combinefiles(in_filenames=[os.path.join(input_dir, ism_audio_file), + os.path.join(input_dir, masa_audio_file)], + out_file=omasa_file, + in_fs=fs * 1000) + wrote_files.append(omasa_file) + + # copy ISM metadata files under names matching the combined file + for ism_idx in range(n_isms): + ism_file_name = os.path.join(output_dir, f'{omasa_file_body}_ISM{ism_idx+1}{os.path.splitext(ism_meta_file)[1]}') + if not os.path.exists(ism_file_name) or force_overwrite: + shutil.copyfile(os.path.join(input_dir, ism_meta_file), ism_file_name) + wrote_files.append(ism_file_name) + + # copy MASA metadata file under a matching name + masa_meta_name = os.path.join(output_dir, f'{omasa_file_body}{os.path.splitext(meta_file)[1]}') + + if not os.path.exists(masa_meta_name) or force_overwrite: + shutil.copyfile(os.path.join(input_dir, meta_file), masa_meta_name) + wrote_files.append(masa_meta_name) + + for sba_item in sba_alts: + sba_tag = sba_item['sba_tag'] + sba_audio_file = sba_item['sba_audio_file'].format(fs) + + osba_file_body = f'stvOSBA_{n_isms}ISM_{sba_tag}{fs}c' + osba_file = os.path.join(output_dir, f'{osba_file_body}.wav') + + if not os.path.exists(osba_file) or force_overwrite: audiofile.combinefiles(in_filenames=[os.path.join(input_dir, ism_audio_file), - os.path.join(input_dir, masa_audio_file)], - out_file=omasa_file, + os.path.join(input_dir, sba_audio_file)], + out_file=osba_file, in_fs=fs * 1000) - wrote_files.append(omasa_file) + wrote_files.append(osba_file) # copy ISM metadata files under names matching the combined file for ism_idx in range(n_isms): - ism_file_name = os.path.join(output_dir, f'{omasa_file_body}_ISM{ism_idx+1}{os.path.splitext(ism_meta_file)[1]}') + ism_file_name = os.path.join(output_dir, f'{osba_file_body}_ISM{ism_idx+1}{os.path.splitext(ism_meta_file)[1]}') if not os.path.exists(ism_file_name) or force_overwrite: shutil.copyfile(os.path.join(input_dir, ism_meta_file), ism_file_name) wrote_files.append(ism_file_name) - # copy MASA metadata file under a matching name - masa_meta_name = os.path.join(output_dir, f'{omasa_file_body}{os.path.splitext(meta_file)[1]}') - if not os.path.exists(masa_meta_name) or force_overwrite: - shutil.copyfile(os.path.join(input_dir, meta_file), masa_meta_name) - wrote_files.append(masa_meta_name) + # info print. helps setting up .gitignore + if len(wrote_files) > 0: + print('New files written: {}'.format('\n'.join(wrote_files))) - for sba_item in sba_alts: - sba_tag = sba_item['sba_tag'] - sba_audio_file = sba_item['sba_audio_file'].format(fs) - - osba_file_body = f'stvOSBA_{n_isms}ISM_{sba_tag}{fs}c' - osba_file = os.path.join(output_dir, f'{osba_file_body}.wav') - - if not os.path.exists(osba_file) or force_overwrite: - audiofile.combinefiles(in_filenames=[os.path.join(input_dir, ism_audio_file), - os.path.join(input_dir, sba_audio_file)], - out_file=osba_file, - in_fs=fs * 1000) - wrote_files.append(osba_file) - - # copy ISM metadata files under names matching the combined file - for ism_idx in range(n_isms): - ism_file_name = os.path.join(output_dir, f'{osba_file_body}_ISM{ism_idx+1}{os.path.splitext(ism_meta_file)[1]}') - if not os.path.exists(ism_file_name) or force_overwrite: - shutil.copyfile(os.path.join(input_dir, ism_meta_file), ism_file_name) - wrote_files.append(ism_file_name) - - -# info print. helps setting up .gitignore -if len(wrote_files) > 0: - print('New files written: {}'.format('\n'.join(wrote_files))) +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/scripts/self_test.py b/scripts/self_test.py index 418f19b27dec31f7f29172e4f05152bb0b52651f..c6853c1eb04da908a0f69c06de6a7d08af345421 100755 --- a/scripts/self_test.py +++ b/scripts/self_test.py @@ -50,6 +50,8 @@ import multiprocessing import tempfile import urllib.parse import shutil +import prepare_combined_format_inputs +import errno BW_TO_SR = {"nb": 8, "wb": 16, "swb": 32, "fb": 48} @@ -993,6 +995,18 @@ class SelfTest(IvasScriptsCommon.IvasScript): {oc: list(map(self.test_for_file, dec_cmd))} ) + # check if we need to produce missing combined formats + if not os.path.exists(in_file): + if is_osba_or_omasa: + self.logger.info("Creating missing combined formats test vectors...") + prepare_combined_format_inputs.main() + if not os.path.exists(in_file): + self.logger.error(f"Test vector {in_file} does not exist!") + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), in_file) + else: + self.logger.error(f"Test vector {in_file} does not exist!") + raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), in_file) + # handle intermediate processing steps, e.g. networksimulator, eid-xor, ... in_file = bs_enc_file proc_cmds = [] @@ -1008,7 +1022,7 @@ class SelfTest(IvasScriptsCommon.IvasScript): proc_cmd[0] = os.path.join(TOOLS_DIR, proc_cmd[0]) proc_cmd = [ - "{in_file}" if x == in_file else self.test_for_file(x) for x in proc_cmd + "{in_file}" if x == in_file else x for x in proc_cmd ] if mode[1]: for cmdline_arg in proc_cmd[1:]: diff --git a/tests/conftest.py b/tests/conftest.py index 3b2d5944db1e4cc1f465f86039e2c937d03eb185..14f78f2c3bb8c760e67304002857f2e19acca293 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -46,6 +46,12 @@ import tempfile logger = logging.getLogger(__name__) USE_LOGGER_FOR_DBG = False # current tests do not make use of the logger feature +HERE = Path(__file__).parent +SCRIPTS_DIR = str(HERE.parent.joinpath("scripts").absolute()) +import sys +sys.path.append(SCRIPTS_DIR) +import prepare_combined_format_inputs + def log_dbg_msg(message): """ @@ -175,6 +181,13 @@ def update_ref(request): """ return int(request.config.getoption("--update_ref")) +@pytest.fixture(scope="session", autouse=True) +def create_combined_formats_testvectors(request): + """ + Create input stv files for the combined formats if missing + """ + + prepare_combined_format_inputs.main() @pytest.fixture(scope="session", autouse=True) def get_mld(request):