Commit 2c0e6187 authored by Jan Kiene's avatar Jan Kiene
Browse files

add writeout of encoder dmx to pytest

parent ad6652bf
Loading
Loading
Loading
Loading
+7 −4
Original line number Diff line number Diff line
@@ -205,8 +205,8 @@ def test_param_file_tests(
    get_enc_stats,
    get_odg,
    compare_to_input,
    compare_enc_dmx,
):

    enc_opts, dec_opts, sim_opts, eid_opts = param_file_test_dict[test_tag]

    run_test(
@@ -236,6 +236,7 @@ def test_param_file_tests(
        get_enc_stats,
        get_odg,
        compare_to_input,
        compare_enc_dmx,
    )


@@ -266,8 +267,8 @@ def run_test(
    get_enc_stats,
    get_odg,
    compare_to_input,
    compare_enc_dmx,
):

    # If compare_to_input is set, only run pass-through test cases
    if compare_to_input:
        passthrough = [
@@ -328,6 +329,7 @@ def run_test(
            enc_split,
            update_ref,
            get_enc_stats,
            compare_enc_dmx,
        )

        # compare binary files extracted from the encoder
@@ -489,7 +491,6 @@ def run_test(
    )

    if update_ref in [0, 2]:

        # Output file names for comparison
        dut_output_file = f"{dut_base_path}/param_file/dec/{output_file}"
        ref_output_file = f"{reference_path}/param_file/dec/{output_file}"
@@ -499,7 +500,6 @@ def run_test(
        odg_test = None
        odg_ref = None
        if get_odg:

            # Find input format
            in_fmt = [(a, b) for (a, b) in INPUT_FMT if re.search(a, enc_opts)][0][1]

@@ -661,6 +661,7 @@ def encode(
    enc_opts_list,
    update_ref,
    get_enc_stats=False,
    compare_enc_dmx=False,
):
    """
    Call REF and/or DUT encoder.
@@ -691,6 +692,7 @@ def encode(
            ref_out_file,
            add_option_list=enc_opts_list,
            stats_file=ref_stats_file,
            compare_enc_dmx=compare_enc_dmx,
        )

    if update_ref in [0, 2]:
@@ -704,6 +706,7 @@ def encode(
            dut_out_file,
            add_option_list=enc_opts_list,
            stats_file=dut_stats_file,
            compare_enc_dmx=compare_enc_dmx,
        )


+17 −4
Original line number Diff line number Diff line
@@ -33,7 +33,6 @@ __doc__ = """
    The outputs are compared with C generated references.
    """

import errno
import os
import pytest
from cut_bs import cut_from_start
@@ -83,6 +82,7 @@ ivas_br_plc = ["13200", "16400", "32000", "64000", "96000", "256000"]
# SBA order to IVAS_rend format table
SBA_FORMAT = {1: "FOA", 2: "HOA2", 3: "HOA3"}


@pytest.mark.parametrize("tag", tag_list)
@pytest.mark.parametrize("sampling_rate", sample_rate_list)
def test_pca_enc(
@@ -109,6 +109,7 @@ def test_pca_enc(
    get_odg,
    get_enc_stats,
    compare_to_input,
    compare_enc_dmx,
):
    pca = True
    bitrate = "256000"
@@ -156,6 +157,7 @@ def test_pca_enc(
            pca=pca,
            plc_pattern=plc_pattern,
            get_enc_stats=get_enc_stats,
            compare_enc_dmx=compare_enc_dmx,
        )

    if not encoder_only:
@@ -226,8 +228,8 @@ def test_sba_enc_system(
    get_odg,
    get_enc_stats,
    compare_to_input,
    compare_enc_dmx,
):

    plc_pattern = None
    pca = False
    max_bw = "FB"
@@ -297,6 +299,7 @@ def test_sba_enc_system(
            pca=pca,
            plc_pattern=plc_pattern,
            get_enc_stats=get_enc_stats,
            compare_enc_dmx=compare_enc_dmx,
        )

        if update_ref == 0 and get_enc_stats:
@@ -394,6 +397,7 @@ def test_spar_hoa2_enc_system(
    get_odg,
    get_enc_stats,
    compare_to_input,
    compare_enc_dmx,
):
    sampling_rate = "48"
    pca = False
@@ -439,6 +443,7 @@ def test_spar_hoa2_enc_system(
            pca=pca,
            plc_pattern=plc_pattern,
            get_enc_stats=get_enc_stats,
            compare_enc_dmx=compare_enc_dmx,
        )

        if update_ref == 0 and get_enc_stats:
@@ -536,6 +541,7 @@ def test_spar_hoa3_enc_system(
    get_odg,
    get_enc_stats,
    compare_to_input,
    compare_enc_dmx,
):
    sampling_rate = "48"
    pca = False
@@ -581,6 +587,7 @@ def test_spar_hoa3_enc_system(
            pca=pca,
            plc_pattern=plc_pattern,
            get_enc_stats=get_enc_stats,
            compare_enc_dmx=compare_enc_dmx,
        )

        if get_enc_stats:
@@ -676,6 +683,7 @@ def test_sba_enc_BWforce_system(
    get_odg,
    get_enc_stats,
    compare_to_input,
    compare_enc_dmx,
):
    sid = 0
    plc_pattern = None
@@ -731,6 +739,7 @@ def test_sba_enc_BWforce_system(
            pca=pca,
            plc_pattern=plc_pattern,
            get_enc_stats=get_enc_stats,
            compare_enc_dmx=compare_enc_dmx,
        )

        if update_ref == 0 and get_enc_stats:
@@ -836,6 +845,7 @@ def test_sba_plc_system(
    get_odg,
    get_enc_stats,
    compare_to_input,
    compare_enc_dmx,
):
    sid = 0
    pca = False
@@ -903,6 +913,7 @@ def test_sba_plc_system(
            pca=pca,
            plc_pattern=plc_pattern,
            get_enc_stats=get_enc_stats,
            compare_enc_dmx=compare_enc_dmx,
        )

    if not encoder_only:
@@ -964,8 +975,8 @@ def sba_enc(
    pca=False,
    plc_pattern=None,
    get_enc_stats=False,
    compare_enc_dmx=False,
):

    input_path = f"{test_vector_path}/{tag}.wav"
    dtx_mode = dtx == "1"

@@ -1028,6 +1039,7 @@ def sba_enc(
            pca=pca,
            dtx_mode=dtx_mode,
            stats_file=ref_stats_file,
            compare_enc_dmx=compare_enc_dmx,
        )

    if update_ref == 0:
@@ -1042,6 +1054,7 @@ def sba_enc(
            pca=pca,
            dtx_mode=dtx_mode,
            stats_file=dut_stats_file,
            compare_enc_dmx=compare_enc_dmx,
        )

    if sid == 1:
+42 −14
Original line number Diff line number Diff line
@@ -275,6 +275,13 @@ def pytest_addoption(parser):
        default=False,
    )

    parser.addoption(
        "--compare_enc_dmx",
        action="store_true",
        help="Trigger comparison of dmx signals written out from the encoder. If --update_ref is 1, dmx is written without comparison.",
        default=False,
    )


@pytest.fixture(scope="session", autouse=True)
def update_ref(request):
@@ -360,6 +367,11 @@ def compare_bitstream(request) -> bool:
    return request.config.option.compare_bitstream


@pytest.fixture(scope="session")
def compare_enc_dmx(request) -> bool:
    return request.config.option.compare_enc_dmx


@pytest.fixture(scope="session")
def dut_encoder_path(request) -> str:
    """
@@ -387,11 +399,13 @@ def dut_encoder_path(request) -> str:

    return path


# fixture returns test information, enabling per-testcase SNR
@pytest.fixture
def test_info(request):
    return request


class EncoderFrontend:
    def __init__(self, path, enc_type, record_property, timeout=None) -> None:
        self._path = Path(path).absolute()
@@ -492,6 +506,7 @@ class EncoderFrontend:
        add_option_list: Optional[list] = None,
        run_dir: Optional[Path] = None,
        stats_file: Optional[Path] = None,
        compare_enc_dmx: Optional[bool] = False,
    ) -> None:
        command = [str(self._path)]

@@ -525,13 +540,13 @@ class EncoderFrontend:
        cmd_str = textwrap.indent(" ".join(command), prefix="\t")
        log_dbg_msg(f"{self._type} encoder command:\n{cmd_str}")

        try:
        with tempfile.TemporaryDirectory() as tmp_dir:
            if run_dir is None:
                cwd = Path(tmp_dir).absolute()
            else:
                cwd = Path(run_dir).absolute()

            try:
                result = run(
                    command,
                    capture_output=True,
@@ -539,14 +554,28 @@ class EncoderFrontend:
                    timeout=self.timeout,
                    cwd=cwd,
                )
            except TimeoutExpired:
                pytest.fail(
                    f"{self._type} encoder run timed out after {self.timeout}s."
                )

            if stats_file is not None:
                self.extract_enc_stats(
                    cwd.joinpath("res"), stats_file, input_sampling_rate
                )

        except TimeoutExpired:
            pytest.fail(f"{self._type} encoder run timed out after {self.timeout}s.")
            if compare_enc_dmx:
                for dmx_file in cwd.glob("res/ivas_input_dmx.id*.pcm"):
                    id_match = re.search(r"id(\d+).pcm", dmx_file.name)
                    if id_match is None:
                        pytest.fail(
                            "No dmx signal files found - did you build with DEBUG_MODE_INFO?"
                        )
                    assert id_match is not None
                    id = id_match.group(1)
                    dmx_file.rename(
                        Path(output_bitstream_path).with_suffix(f".dmx.ch{id}.pcm")
                    )

        self.returncode = result.returncode
        self.stderr = result.stderr.decode("ascii")
@@ -1052,7 +1081,6 @@ def compare_to_input(request) -> bool:
    return request.config.getoption("--compare_to_input")



def pytest_configure(config):
    config.addinivalue_line("markers", "serial: mark test to run only in serial")
    if config.option.param_file: