Commit 49984457 authored by norvell's avatar norvell
Browse files

Move MLD check first since it uses files without length alignment

parent 7081fa6d
Loading
Loading
Loading
Loading
Loading
+79 −75
Original line number Diff line number Diff line
@@ -306,6 +306,85 @@ def compare(

    framesize = fs // 50

    # MLD (wav-diff) tool is run first, since it uses the input signals without length difference check for JBM test cases.
    if get_mld:

        def parse_wav_diff(proc: subprocess.CompletedProcess) -> float:
            if proc.returncode:
                raise ChildProcessError(f"{proc.stderr}\n{proc.stdout}")
            line = proc.stdout.splitlines()[-1].strip()
            start = line.find(">") + 1
            stop = line.rfind("<")
            mld = float(line[start:stop].strip())

            return mld

        mld_max = 0
        toolsdir = Path(__file__).parent.parent.joinpath("tools")

        curr_platform = platform.system()
        if curr_platform not in {"Windows", "Linux", "Darwin"}:
            raise NotImplementedError(
                f"wav-diff tool not available for {curr_platform}"
            )
        
        search_path = toolsdir.joinpath(curr_platform.replace("Windows", "Win32"))
        wdiff = search_path.joinpath("wav-diff").with_suffix(".exe" if curr_platform == "Windows" else "")

        if not wdiff.exists():
            wdiff = shutil.which("wav-diff")
            if wdiff is None:
                raise FileNotFoundError(
                    f"wav-diff tool not found in {search_path} or PATH!"
                )

        with tempfile.TemporaryDirectory() as tmpdir:
            tmpfile_ref = Path(tmpdir).joinpath("ref.wav")
            tmpfile_test = Path(tmpdir).joinpath("test.wav")

            ### need to resample to 48kHz for MLD computation to be correct
            ### write out and delete tmp variables to reduce memory usage
            if fs != 48000:
                ref_tmp = np.clip(
                    resample(ref.astype(float), fs, 48000), -32768, 32767
                ).astype(np.int16)
                wavfile.write(str(tmpfile_ref), 48000, ref_tmp)
                del ref_tmp
                test_tmp = np.clip(
                    resample(test.astype(float), fs, 48000), -32768, 32767
                ).astype(np.int16)
                wavfile.write(str(tmpfile_test), 48000, test_tmp)
                del test_tmp
            else:
                wavfile.write(str(tmpfile_ref), 48000, ref)
                wavfile.write(str(tmpfile_test), 48000, test)

            cmd = [
                str(wdiff),
                "--print-ctest-measurement",
                # wav-diff return code is 1 if differences are found which
                # would cause parse_wav_diff to raise an Exception on these cases
                "--no-fail",
                str(tmpfile_ref),
                str(tmpfile_test),
            ]
            if ref_jbm_tf and test_jbm_tf:
                cmd.extend(
                    [
                        "--ref-jbm-trace",
                        str(ref_jbm_tf),
                        "--cut-jbm-trace",
                        str(test_jbm_tf),
                    ]
                )
            proc = subprocess.run(cmd, capture_output=True, text=True)
            mld_max = parse_wav_diff(proc)

        result["MLD"] = mld_max


    # Run remanining tests after checking if the lenght differs

    lengths_differ = ref.shape[0] != test.shape[0]

    if lengths_differ:
@@ -403,81 +482,6 @@ def compare(
            result["nframes_diff"] = nframes_diff
            result["nframes_diff_percentage"] = nframes_diff_percentage

        if get_mld:

            def parse_wav_diff(proc: subprocess.CompletedProcess) -> float:
                if proc.returncode:
                    raise ChildProcessError(f"{proc.stderr}\n{proc.stdout}")
                line = proc.stdout.splitlines()[-1].strip()
                start = line.find(">") + 1
                stop = line.rfind("<")
                mld = float(line[start:stop].strip())

                return mld

            mld_max = 0
            toolsdir = Path(__file__).parent.parent.joinpath("tools")

            curr_platform = platform.system()
            if curr_platform not in {"Windows", "Linux", "Darwin"}:
                raise NotImplementedError(
                    f"wav-diff tool not available for {curr_platform}"
                )
            
            search_path = toolsdir.joinpath(curr_platform.replace("Windows", "Win32"))
            wdiff = search_path.joinpath("wav-diff").with_suffix(".exe" if curr_platform == "Windows" else "")

            if not wdiff.exists():
                wdiff = shutil.which("wav-diff")
                if wdiff is None:
                    raise FileNotFoundError(
                        f"wav-diff tool not found in {search_path} or PATH!"
                    )

            with tempfile.TemporaryDirectory() as tmpdir:
                tmpfile_ref = Path(tmpdir).joinpath("ref.wav")
                tmpfile_test = Path(tmpdir).joinpath("test.wav")

                ### need to resample to 48kHz for MLD computation to be correct
                ### write out and delete tmp variables to reduce memory usage
                if fs != 48000:
                    ref_tmp = np.clip(
                        resample(ref.astype(float), fs, 48000), -32768, 32767
                    ).astype(np.int16)
                    wavfile.write(str(tmpfile_ref), 48000, ref_tmp)
                    del ref_tmp
                    test_tmp = np.clip(
                        resample(test.astype(float), fs, 48000), -32768, 32767
                    ).astype(np.int16)
                    wavfile.write(str(tmpfile_test), 48000, test_tmp)
                    del test_tmp
                else:
                    wavfile.write(str(tmpfile_ref), 48000, ref)
                    wavfile.write(str(tmpfile_test), 48000, test)

                cmd = [
                    str(wdiff),
                    "--print-ctest-measurement",
                    # wav-diff return code is 1 if differences are found which
                    # would cause parse_wav_diff to raise an Exception on these cases
                    "--no-fail",
                    str(tmpfile_ref),
                    str(tmpfile_test),
                ]
                if ref_jbm_tf and test_jbm_tf:
                    cmd.extend(
                        [
                            "--ref-jbm-trace",
                            str(ref_jbm_tf),
                            "--cut-jbm-trace",
                            str(test_jbm_tf),
                        ]
                    )
                proc = subprocess.run(cmd, capture_output=True, text=True)
                mld_max = parse_wav_diff(proc)

            result["MLD"] = mld_max

        if get_ssnr:
            # length of segment is always 20ms
            len_seg = int(0.02 * fs)