Commit 949e07ed authored by Jan Kiene's avatar Jan Kiene
Browse files

use context manager for getting the bitstream path

parent 896b9372
Loading
Loading
Loading
Loading
+52 −37
Original line number Diff line number Diff line
@@ -31,6 +31,10 @@ the United Nations Convention on Contracts on the International Sales of Goods.
import pytest
from . import is_be_to_ref, get_bitstream_path, get_testv_path, REF_PATH, DUT_PATH
from .constants import *
from pathlib import Path
import subprocess
from contextlib import contextmanager
from tempfile import TemporaryDirectory


### --------------- Helper functions ---------------
@@ -43,6 +47,23 @@ def get_output_path(bitstream_path, output_format, output_sampling_rate):
    return DUT_PATH.joinpath(output_name)


@contextmanager
def get_bitstream(testv_name, encoder_format, bitrate, sampling_rate, dtx, suffix="", processing=None):
    """
    Utility to get either the stored reference bitstream or the processed version as a temporary file
    """
    with TemporaryDirectory() as tmp_dir:
        bitstream = get_bitstream_path(REF_PATH, testv_name, encoder_format, bitrate, sampling_rate, dtx, suffix)
        if processing == "FER_15":
            bitstream_out = Path(tmp_dir).joinpath(bitstream.stem + f".{processing}.192")
            ep_path = REF_PATH.joinpath("ltv_ep_015.192")
            cmd = ["eid-xor", "-fer", "-vbr", str(bitstream), str(ep_path), str(bitstream_out)]
            subprocess.run(cmd)
            bitstream = bitstream_out

        yield bitstream


def run_check(
    ref_bitstream,
    output_format,
@@ -77,9 +98,7 @@ def test_decoder_clean_channel_channelbased_and_masa(
    update_ref,
):
    testv_name = get_testv_path(input_format, input_sampling_rate).stem
    ref_bitstream = get_bitstream_path(
        REF_PATH, testv_name, input_format, bitrate, input_sampling_rate, dtx
    )
    with get_bitstream(testv_name, input_format, bitrate, input_sampling_rate, dtx) as ref_bitstream:
        dut_output = get_output_path(ref_bitstream, output_format, output_sampling_rate)

        run_check(
@@ -112,9 +131,7 @@ def test_decoder_clean_channel_objectbased(
        suffix = "_ext_MD"
    elif md_type == ISM_MD_NULL:
        suffix = "_null_MD"
    ref_bitstream = get_bitstream_path(
        REF_PATH, testv_name, input_format, bitrate, input_sampling_rate, dtx, suffix=suffix
    )
    with get_bitstream(testv_name, input_format, bitrate, input_sampling_rate, dtx, suffix=suffix) as ref_bitstream:
        dut_output = get_output_path(ref_bitstream, output_format, output_sampling_rate)

        run_check(
@@ -145,9 +162,7 @@ def test_decoder_clean_channel_scenebased(
    suffix = ""
    if pca == SBA_FOA_PCA_ON:
        suffix="-pca"
    ref_bitstream = get_bitstream_path(
        REF_PATH, testv_name, input_format, bitrate, input_sampling_rate, dtx, suffix=suffix
    )
    with get_bitstream(testv_name, input_format, bitrate, input_sampling_rate, dtx, suffix=suffix) as ref_bitstream:
        dut_output = get_output_path(ref_bitstream, output_format, output_sampling_rate)

        run_check(