Commit 522f8de8 authored by Jan Kiene's avatar Jan Kiene
Browse files

Merge branch 'lower-memory-usage-of-ssnr-implementation' into 'main'

[BASOP-CI]Lower memory usage of ssnr implementation

See merge request !2119
parents 70ff4fea 34c9cab8
Loading
Loading
Loading
Loading
Loading
+26 −22
Original line number Original line Diff line number Diff line
@@ -352,7 +352,9 @@ def compare(
            )
            )


        search_path = toolsdir.joinpath(curr_platform.replace("Windows", "Win32"))
        search_path = toolsdir.joinpath(curr_platform.replace("Windows", "Win32"))
        wdiff = search_path.joinpath("wav-diff").with_suffix(".exe" if curr_platform == "Windows" else "")
        wdiff = search_path.joinpath("wav-diff").with_suffix(
            ".exe" if curr_platform == "Windows" else ""
        )


        if not wdiff.exists():
        if not wdiff.exists():
            wdiff = shutil.which("wav-diff")
            wdiff = shutil.which("wav-diff")
@@ -405,7 +407,6 @@ def compare(


        result["MLD"] = mld_max
        result["MLD"] = mld_max



    # Run remanining tests after checking if the lenght differs
    # Run remanining tests after checking if the lenght differs


    lengths_differ = ref.shape[0] != test.shape[0]
    lengths_differ = ref.shape[0] != test.shape[0]
@@ -622,7 +623,9 @@ def limiter(x: np.ndarray, fs: int):
        fr_sig[idx_min] = -32768
        fr_sig[idx_min] = -32768




def get_framewise(x: np.ndarray, chunk_size: int, zero_pad=False) -> np.ndarray:
def get_framewise(
    x: np.ndarray, chunk_size: int, zero_pad=False, scale_fac=1.0
) -> np.ndarray:
    """Generator to yield a signal frame by frame
    """Generator to yield a signal frame by frame
        If array size is not a multiple of chunk_size, last frame contains the remainder
        If array size is not a multiple of chunk_size, last frame contains the remainder


@@ -634,6 +637,8 @@ def get_framewise(x: np.ndarray, chunk_size: int, zero_pad=False) -> np.ndarray:
        Size of frames to yield
        Size of frames to yield
    zero_pad: bool
    zero_pad: bool
        Whether to zero pad the last chunk if there are not enough samples
        Whether to zero pad the last chunk if there are not enough samples
    scale_fac: float
        scale returned chunks with this factor


    Yields
    Yields
    -------
    -------
@@ -642,9 +647,9 @@ def get_framewise(x: np.ndarray, chunk_size: int, zero_pad=False) -> np.ndarray:
    """
    """
    n_frames = x.shape[0] // chunk_size
    n_frames = x.shape[0] // chunk_size
    for i in range(n_frames):
    for i in range(n_frames):
        yield x[i * chunk_size : (i + 1) * chunk_size, :]
        yield x[i * chunk_size : (i + 1) * chunk_size, :] * scale_fac
    if x.shape[0] % chunk_size:
    if x.shape[0] % chunk_size:
        last_chunk = x[n_frames * chunk_size :, :]
        last_chunk = x[n_frames * chunk_size :, :] * scale_fac
        if zero_pad:
        if zero_pad:
            yield np.pad(
            yield np.pad(
                last_chunk, [[0, chunk_size - (x.shape[0] % chunk_size)], [0, 0]]
                last_chunk, [[0, chunk_size - (x.shape[0] % chunk_size)], [0, 0]]
@@ -678,29 +683,28 @@ def ssnr(
    """
    """
    Calculate Segmental SNR for test_sig to ref_sig as defined in ISO/IEC 14496-4
    Calculate Segmental SNR for test_sig to ref_sig as defined in ISO/IEC 14496-4
    """
    """
    ss = list()

    ref_sig_norm = ref_sig / -np.iinfo(np.int16).min
    test_sig_norm = test_sig / -np.iinfo(np.int16).min

    # check if diff of signal is zero already, then SNR is infinite, since no noise
    # check if diff of signal is zero already, then SNR is infinite, since no noise
    diff_sig_norm = ref_sig_norm - test_sig_norm
    signals_equal = (ref_sig == test_sig).all()
    if np.all(diff_sig_norm == 0):
    if signals_equal:
        return np.asarray([np.inf] * ref_sig_norm.shape[1])
        return np.asarray([np.inf] * ref_sig.shape[1])


    channels_identical_idx = np.sum(np.abs(diff_sig_norm), axis=0) == 0
    n_channels = ref_sig.shape[1]
    channels_identical_idx = np.asarray(
        [(ref_sig[:, c] == test_sig[:, c]).all() for c in range(n_channels)]
    )


    # iterate over test signal too to allow power comparison to threshold
    ss = list()
    denom_add = 10**-13 * len_seg
    denom_add = 10**-13 * len_seg
    segment_counter = np.zeros(ref_sig.shape[1])
    segment_counter = np.zeros(ref_sig.shape[1])

    # apply normalization factor on the chunks to avoid big reallocation of the whole signal
    # iterate over test signal too to allow power comparison to threshold
    norm_fac = 1 / -np.iinfo(np.int16).min
    for ref_seg, diff_seg, test_seg in zip(
    for ref_seg, test_seg in zip(
        get_framewise(ref_sig_norm, len_seg, zero_pad=True),
        get_framewise(ref_sig, len_seg, zero_pad=True, scale_fac=norm_fac),
        get_framewise(diff_sig_norm, len_seg, zero_pad=True),
        get_framewise(test_sig, len_seg, zero_pad=True, scale_fac=norm_fac),
        get_framewise(test_sig_norm, len_seg, zero_pad=True),
    ):
    ):
        nrg_ref = np.sum(ref_seg**2, axis=0)
        nrg_ref = np.sum(ref_seg**2, axis=0)
        nrg_diff = np.sum(diff_seg**2, axis=0)
        nrg_diff = np.sum((test_seg - ref_seg) ** 2, axis=0)


        ss_seg = np.log10(1 + nrg_ref / (denom_add + nrg_diff))
        ss_seg = np.log10(1 + nrg_ref / (denom_add + nrg_diff))