Commit b9053e97 authored by Jan Kiene's avatar Jan Kiene
Browse files

use length-adjusted files for ssnr too

parent 491494f0
Loading
Loading
Loading
Loading
Loading
+14 −10
Original line number Diff line number Diff line
@@ -307,6 +307,10 @@ def compare(
    framesize = fs // 50

    lengths_differ = ref.shape[0] != test.shape[0]

    test_orig = test.copy()
    ref_orig = ref.copy()

    if lengths_differ:
        if handle_differing_lengths == "fail":
            raise RuntimeError(
@@ -314,23 +318,23 @@ def compare(
            )
        elif handle_differing_lengths == "cut":
            min_len = min(ref.shape[0], test.shape[0])
            diff = abs(test[:min_len, :] - ref[:min_len, :])
            ref = ref[:min_len, :]
            test = test[:min_len, :]
        elif handle_differing_lengths == "pad":
            max_len = max(ref.shape[0], test.shape[0])
            ref_pad = np.pad(
            ref = np.pad(
                ref,
                ((0, max_len - ref.shape[0]), (0, 0)),
                mode="constant",
                constant_values=0,
            )
            test_pad = np.pad(
            test = np.pad(
                test,
                ((0, max_len - test.shape[0]), (0, 0)),
                mode="constant",
                constant_values=0,
            )
            diff = abs(test_pad - ref_pad)
    else:

    diff = abs(test - ref)

    max_diff = int(diff.max())
@@ -439,14 +443,14 @@ def compare(
                ### need to resample to 48kHz for MLD computation to be correct
                if fs != 48000:
                    ref_tmp = np.clip(
                        resample(ref.astype(float), fs, 48000), -32768, 32767
                        resample(ref_orig.astype(float), fs, 48000), -32768, 32767
                    )
                    test_tmp = np.clip(
                        resample(test.astype(float), fs, 48000), -32768, 32767
                        resample(test_orig.astype(float), fs, 48000), -32768, 32767
                    )
                else:
                    ref_tmp = ref.copy()
                    test_tmp = test.copy()
                    ref_tmp = ref_orig.copy()
                    test_tmp = test_orig.copy()

                wavfile.write(str(tmpfile_ref), 48000, ref_tmp.astype(np.int16))
                wavfile.write(str(tmpfile_test), 48000, test_tmp.astype(np.int16))