Commit 18a46e06 authored by Jan Kiene's avatar Jan Kiene
Browse files

write out tmp files for wav-diff in 24bit format

parent 61075657
Loading
Loading
Loading
Loading
Loading
+126 −80
Original line number Diff line number Diff line
@@ -38,6 +38,9 @@ import platform
import shutil
import subprocess
import tempfile
import wave
import struct
import itertools
from pathlib import Path
from enum import Enum
from typing import Callable, Iterable, Optional, Tuple, Union
@@ -46,6 +49,8 @@ import numpy as np
import scipy.io.wavfile as wavfile
import scipy.signal as sig

from . import audiofile

main_logger = logging.getLogger("__main__")
logger = main_logger.getChild(__name__)
logger.setLevel(logging.DEBUG)
@@ -329,83 +334,16 @@ def compare(
        result["nframes_diff"] = 0
        result["nframes_diff_percentage"] = 0.0

    # MLD (wav-diff) tool is run first, since it uses the input signals without length difference check for JBM test cases.
    if get_mld:

        def parse_wav_diff(proc: subprocess.CompletedProcess) -> float:
            if proc.returncode:
                raise ChildProcessError(f"{proc.stderr}\n{proc.stdout}")
            line = proc.stdout.splitlines()[-1].strip()
            start = line.find(">") + 1
            stop = line.rfind("<")
            mld = float(line[start:stop].strip())

            return mld

        mld_max = 0
        toolsdir = Path(__file__).parent.parent.joinpath("tools")

        curr_platform = platform.system()
        if curr_platform not in {"Windows", "Linux", "Darwin"}:
            raise NotImplementedError(
                f"wav-diff tool not available for {curr_platform}"
            )

        search_path = toolsdir.joinpath(curr_platform.replace("Windows", "Win32"))
        wdiff = search_path.joinpath("wav-diff").with_suffix(
            ".exe" if curr_platform == "Windows" else ""
        )

        if not wdiff.exists():
            wdiff = shutil.which("wav-diff")
            if wdiff is None:
                raise FileNotFoundError(
                    f"wav-diff tool not found in {search_path} or PATH!"
                )

        with tempfile.TemporaryDirectory() as tmpdir:
            tmpfile_ref = Path(tmpdir).joinpath("ref.wav")
            tmpfile_test = Path(tmpdir).joinpath("test.wav")

            ### need to resample to 48kHz for MLD computation to be correct
            ### write out and delete tmp variables to reduce memory usage
            if fs != 48000:
                ref_tmp = np.clip(
                    resample(ref.astype(float), fs, 48000), -32768, 32767
                ).astype(np.int16)
                wavfile.write(str(tmpfile_ref), 48000, ref_tmp)
                del ref_tmp
                test_tmp = np.clip(
                    resample(test.astype(float), fs, 48000), -32768, 32767
                ).astype(np.int16)
                wavfile.write(str(tmpfile_test), 48000, test_tmp)
                del test_tmp
    assert ref.shape == test.shape
    if ref.ndim == 1:
        nsamples_total = ref.shape
        nchannels = 1
    else:
                wavfile.write(str(tmpfile_ref), 48000, ref.astype(np.int16))
                wavfile.write(str(tmpfile_test), 48000, test.astype(np.int16))

            cmd = [
                str(wdiff),
                "--print-ctest-measurement",
                # wav-diff return code is 1 if differences are found which
                # would cause parse_wav_diff to raise an Exception on these cases
                "--no-fail",
                str(tmpfile_ref),
                str(tmpfile_test),
            ]
            if ref_jbm_tf and test_jbm_tf:
                cmd.extend(
                    [
                        "--ref-jbm-trace",
                        str(ref_jbm_tf),
                        "--cut-jbm-trace",
                        str(test_jbm_tf),
                    ]
                )
            proc = subprocess.run(cmd, capture_output=True, text=True)
            mld_max = parse_wav_diff(proc)
        nsamples_total, nchannels = ref.shape

        result["MLD"] = mld_max
    # MLD (wav-diff) tool is run first, since it uses the input signals without length difference check for JBM test cases.
    if get_mld:
        result["MLD"] = run_wavdiff(ref, test, fs, nchannels, ref_jbm_tf, test_jbm_tf)

    # Run remaining tests after checking if the length differs

@@ -440,11 +378,6 @@ def compare(
    max_diff = int(diff.max())

    if max_diff != 0:
        if diff.ndim == 1:
            nsamples_total = diff.shape
            nchannels = 1
        else:
            nsamples_total, nchannels = diff.shape
        max_diff_pos = np.nonzero(diff == max_diff)
        max_diff_pos = [
            max_diff_pos[0][0],
@@ -499,6 +432,119 @@ def compare(
    return result


def parse_wav_diff(proc: subprocess.CompletedProcess) -> float:
    """Extract the MLD value from the output of a finished wav-diff run.

    The value is read from the last non-blank line of ``proc.stdout`` and is
    the text between the first ``>`` and the last ``<`` on that line
    (presumably a ctest measurement tag, as requested by the
    ``--print-ctest-measurement`` option -- confirm against the wav-diff tool).
    If the line contains no ``<``, everything after the first ``>`` (or the
    whole line, if there is no ``>`` either) is parsed.

    Args:
        proc: Completed wav-diff process, run with captured text output.

    Returns:
        The parsed MLD value.

    Raises:
        ChildProcessError: If the process exited with a non-zero return code.
        ValueError: If stdout is empty/blank or the value is not a number.
    """
    if proc.returncode:
        raise ChildProcessError(f"{proc.stderr}\n{proc.stdout}")

    # Ignore blank lines so that stray trailing newlines after the
    # measurement do not break the parsing.
    lines = [ln.strip() for ln in proc.stdout.splitlines() if ln.strip()]
    if not lines:
        raise ValueError("wav-diff produced no output to parse an MLD value from")

    line = lines[-1]
    start = line.find(">") + 1
    stop = line.rfind("<")
    if stop == -1:
        # No closing tag: without this, the slice below would silently drop
        # the last character of the value.
        stop = len(line)

    return float(line[start:stop].strip())


def _write_wav_24bit(path, samples: np.ndarray, fs: int, nchannels: int) -> None:
    """Write ``samples`` to ``path`` as a 24-bit little-endian PCM WAV file."""
    # Saturate at the 24-bit limits so that out-of-range values clip instead
    # of wrapping around when the top byte of each 32-bit word is dropped.
    pcm = np.clip(samples, -(2**23), 2**23 - 1).astype("<i4", order="C")

    # Keep only the three least-significant bytes of every little-endian
    # 32-bit sample; the C-ordered layout keeps the channels interleaved.
    data_bytes = pcm.view(np.uint8).reshape(-1, 4)[:, :3].tobytes()

    with wave.open(str(path), mode="wb") as wav_out:
        wav_out.setnchannels(nchannels)
        wav_out.setsampwidth(3)
        wav_out.setframerate(fs)
        wav_out.writeframes(data_bytes)


def run_wavdiff(
    ref: np.ndarray, test: np.ndarray, fs, nchannels, ref_jbm_tf, test_jbm_tf
) -> float:
    """Run the external wav-diff tool on two signals and return the MLD value.

    Both signals are resampled to 48 kHz if necessary (needed for the MLD
    computation to be correct), written as temporary 24-bit WAV files and
    passed to wav-diff. 24 bit is used so that scaling or resampling results
    outside of the 16-bit range are kept; values are only saturated at the
    24-bit limits.

    Args:
        ref: Reference signal.
        test: Signal under test.
        fs: Sampling rate of ``ref`` and ``test`` in Hz.
        nchannels: Number of channels written to the WAV headers.
        ref_jbm_tf: JBM trace file for the reference signal, or a falsy value.
        test_jbm_tf: JBM trace file for the test signal, or a falsy value.
            The trace files are only passed on if both are given.

    Returns:
        The MLD value parsed from the wav-diff output.

    Raises:
        NotImplementedError: If the current platform is not supported.
        FileNotFoundError: If the wav-diff tool cannot be located.
        ChildProcessError: If wav-diff exits with a non-zero return code.
    """
    toolsdir = Path(__file__).parent.parent.joinpath("tools")

    curr_platform = platform.system()
    if curr_platform not in {"Windows", "Linux", "Darwin"}:
        raise NotImplementedError(f"wav-diff tool not available for {curr_platform}")

    search_path = toolsdir.joinpath(curr_platform.replace("Windows", "Win32"))
    wdiff = search_path.joinpath("wav-diff").with_suffix(
        ".exe" if curr_platform == "Windows" else ""
    )

    # Fall back to an executable on PATH if the bundled one is missing
    if not wdiff.exists():
        wdiff = shutil.which("wav-diff")
        if wdiff is None:
            raise FileNotFoundError(
                f"wav-diff tool not found in {search_path} or PATH!"
            )

    target_fs = 48000

    with tempfile.TemporaryDirectory() as tmpdir:
        tmpfile_ref = Path(tmpdir).joinpath("ref.wav")
        tmpfile_test = Path(tmpdir).joinpath("test.wav")

        # Convert and write one signal at a time to reduce peak memory usage.
        for samples, tmpfile in ((ref, tmpfile_ref), (test, tmpfile_test)):
            if fs != target_fs:
                samples = resample(samples.astype(float), fs, target_fs)
            # The header must carry the rate of the data actually written,
            # which is always 48 kHz at this point.
            _write_wav_24bit(tmpfile, samples, target_fs, nchannels)
        del samples

        cmd = [
            str(wdiff),
            "--print-ctest-measurement",
            # wav-diff return code is 1 if differences are found which
            # would cause parse_wav_diff to raise an Exception on these cases
            "--no-fail",
            str(tmpfile_ref),
            str(tmpfile_test),
        ]
        if ref_jbm_tf and test_jbm_tf:
            cmd.extend(
                [
                    "--ref-jbm-trace",
                    str(ref_jbm_tf),
                    "--cut-jbm-trace",
                    str(test_jbm_tf),
                ]
            )
        proc = subprocess.run(cmd, capture_output=True, text=True)
        return parse_wav_diff(proc)


def getdelay(x: np.ndarray, y: np.ndarray) -> int:
    """Get the delay between two audio signals

+4 −3
Original line number Diff line number Diff line
@@ -96,8 +96,9 @@ def cmp_pcm(

    for s1, s2 in zip(np.split(s1, split_idx), np.split(s2, split_idx)):
        # Apply scalefac if specified. Useful in case scaling has been applied on the input, and the inverse scaling is supplied in scalefac.
        if scalefac != 1:
            s1 = np.round(s1 * scalefac, 0)  # Need rounding for max abs diff search
        # Need rounding for max abs diff search
        # This has the side-effect of changing the dtype of the arrays to float
        s1 = np.round(s1 * scalefac, 0)
        s2 = np.round(s2 * scalefac, 0)

        cmp_result = pyaudio3dtools.audioarray.compare(