Commit 38e796fe authored by Jan Kiene
Browse files

Handle differing lengths; update so that wav-diff uses the original (unmodified) files

parent 778ee967
Loading
Loading
Loading
Loading
Loading
+48 −7
Original line number Diff line number Diff line
@@ -39,7 +39,8 @@ import shutil
import subprocess
import tempfile
from pathlib import Path
from typing import Callable, Iterable, Optional, Tuple
from enum import Enum
from typing import Callable, Iterable, Optional, Tuple, Union

import numpy as np
import scipy.io.wavfile as wavfile
@@ -228,6 +229,12 @@ def cut(x: np.ndarray, limits: Tuple[int, int]) -> np.ndarray:
    return y


class HandleDifferingLengths(str, Enum):
    """Strategy for comparing two signals whose lengths differ.

    Members also subclass ``str``, so each one compares equal to its plain
    string value (e.g. ``HandleDifferingLengths.PAD == "pad"``); ``compare``
    relies on this when it checks the option against string literals.
    """

    # raise an error when the lengths differ
    FAIL = "fail"
    # pad the shorter signal with zeros up to the length of the longer one
    PAD = "pad"
    # cut the longer signal down to the length of the shorter one
    CUT = "cut"


def compare(
    ref: np.ndarray,
    test: np.ndarray,
@@ -239,8 +246,9 @@ def compare(
    ssnr_thresh_high: float = np.inf,
    apply_thresholds_to_ref_only: bool = False,
    test_start_offset_ms: int = 0,
    ref_jbm_tf: Optional[Path] = None,
    test_jbm_tf: Optional[Path] = None,
    ref_jbm_tf: Optional[Union[Path, str]] = None,
    test_jbm_tf: Optional[Union[Path, str]] = None,
    handle_differing_lengths: HandleDifferingLengths = "fail",
) -> dict:
    """Compare two audio arrays

@@ -271,6 +279,17 @@ def compare(
    test_start_offset_ms: (non-negative) int
        offset in milliseconds for test signal. If > 0, the corresponding number of samples
        will be removed from the test array like so: test = test[sample_offset:, :].
    ref_jbm_tf: str|Path
        tracefile for ref signal for wav-diff based MLD comparison of JBM output
    test_jbm_tf: str|Path
        tracefile for test signal for wav-diff based MLD comparison of JBM output
    handle_differing_lengths: one of "fail", "pad", "cut"
        how to handle differing lengths in the input signals
        "fail" - raise error
        "pad" - pad shorter file with zeros
        "cut" - cut longer file to length of shorter one

        Note that external tools such as wav-diff (for mld) always use the unmodified files

    Returns
    -------
@@ -286,11 +305,34 @@ def compare(
    test = test[test_start_offset_samples:, :]

    framesize = fs // 50
    if ref.shape[0] != test.shape[0]:

    lengths_differ = ref.shape[0] != test.shape[0]
    if lengths_differ:
        if handle_differing_lengths == "fail":
            raise RuntimeError(
                f"Input signals have different lengths: ref - {ref.shape[0]}, test - {test.shape[0]}"
            )
        elif handle_differing_lengths == "cut":
            min_len = min(ref.shape[0], test.shape[0])
            diff = abs(test[:min_len, :] - ref[:min_len, :])
        elif handle_differing_lengths == "pad":
            max_len = max(ref.shape[0], test.shape[0])
            ref_pad = np.pad(
                ref,
                ((0, max_len - ref.shape[0]), (0, 0)),
                mode="constant",
                constant_values=0,
            )
            test_pad = np.pad(
                test,
                ((0, max_len - test.shape[0]), (0, 0)),
                mode="constant",
                constant_values=0,
            )
            diff = abs(test_pad - ref_pad)
    else:
        diff = abs(test - ref)

    max_diff = int(diff.max())
    result = {
        "bitexact": True,
@@ -397,7 +439,6 @@ def compare(
                tmpfile_ref = Path(tmpdir).joinpath("ref.wav")
                tmpfile_test = Path(tmpdir).joinpath("test.wav")


                ### need to resample to 48kHz for MLD computation to be correct
                if fs != 48000:
                    ref_tmp = np.clip(
+3 −8
Original line number Diff line number Diff line
@@ -78,15 +78,9 @@ def cmp_pcm(
        reason = "FAIL: Number of channels differ."
        return 1, reason

    handle_differing_lengths = "fail"
    if allow_differing_lengths:
        # to allow for MLD comparison, pad shorter file
        max_len = max(s1.shape[0], s2.shape[0])
        s1 = np.pad(
            s1, ((0, max_len - s1.shape[0]), (0, 0)), mode="constant", constant_values=0
        )
        s2 = np.pad(
            s2, ((0, max_len - s2.shape[0]), (0, 0)), mode="constant", constant_values=0
        )
        handle_differing_lengths = "pad"
    elif s1.shape != s2.shape:
        print(
            f"file size in samples: file 1 = {s1.shape[0]},",
@@ -108,6 +102,7 @@ def cmp_pcm(
        ssnr_thresh_low=-50,
        ref_jbm_tf=ref_jbm_tf,
        test_jbm_tf=cut_jbm_tf,
        handle_differing_lengths=handle_differing_lengths,
    )

    output_differs = 0