Commit 2ac91783 authored by Jan Kiene's avatar Jan Kiene
Browse files

reduce memory usage of ssnr function

parent 4be705f6
Loading
Loading
Loading
Loading
Loading
+19 −21
Original line number Diff line number Diff line
@@ -623,7 +623,9 @@ def limiter(x: np.ndarray, fs: int):
        fr_sig[idx_min] = -32768


def get_framewise(x: np.ndarray, chunk_size: int, zero_pad=False) -> np.ndarray:
def get_framewise(
    x: np.ndarray, chunk_size: int, zero_pad=False, scale_fac=1.0
) -> np.ndarray:
    """Generator to yield a signal frame by frame
        If array size is not a multiple of chunk_size, last frame contains the remainder

@@ -635,6 +637,8 @@ def get_framewise(x: np.ndarray, chunk_size: int, zero_pad=False) -> np.ndarray:
        Size of frames to yield
    zero_pad: bool
        Whether to zero pad the last chunk if there are not enough samples
    scale_fac: float
        scale returned chunks with this factor

    Yields
    -------
@@ -643,9 +647,9 @@ def get_framewise(x: np.ndarray, chunk_size: int, zero_pad=False) -> np.ndarray:
    """
    n_frames = x.shape[0] // chunk_size
    for i in range(n_frames):
        yield x[i * chunk_size : (i + 1) * chunk_size, :]
        yield x[i * chunk_size : (i + 1) * chunk_size, :] * scale_fac
    if x.shape[0] % chunk_size:
        last_chunk = x[n_frames * chunk_size :, :]
        last_chunk = x[n_frames * chunk_size :, :] * scale_fac
        if zero_pad:
            yield np.pad(
                last_chunk, [[0, chunk_size - (x.shape[0] % chunk_size)], [0, 0]]
@@ -684,29 +688,23 @@ def ssnr(
    if signals_equal:
        return np.asarray([np.inf] * ref_sig.shape[1])

    ss = list()

    # allocation here
    ref_sig_norm = ref_sig / -np.iinfo(np.int16).min
    # allocation here
    test_sig_norm = test_sig / -np.iinfo(np.int16).min
    # allocation here
    diff_sig_norm = ref_sig_norm - test_sig_norm

    # allocation here
    channels_identical_idx = np.sum(np.abs(diff_sig_norm), axis=0) == 0
    n_channels = ref_sig.shape[1]
    channels_identical_idx = np.asarray(
        [(ref_sig[:, c] == test_sig[:, c]).all() for c in range(n_channels)]
    )

    # iterate over test signal too to allow power comparison to threshold
    ss = list()
    denom_add = 10**-13 * len_seg
    segment_counter = np.zeros(ref_sig.shape[1])

    # iterate over test signal too to allow power comparison to threshold
    for ref_seg, diff_seg, test_seg in zip(
        get_framewise(ref_sig_norm, len_seg, zero_pad=True),
        get_framewise(diff_sig_norm, len_seg, zero_pad=True),
        get_framewise(test_sig_norm, len_seg, zero_pad=True),
    # apply normalization factor on the chunks to avoid big reallocation of the whole signal
    norm_fac = 1 / -np.iinfo(np.int16).min
    for ref_seg, test_seg in zip(
        get_framewise(ref_sig, len_seg, zero_pad=True, scale_fac=norm_fac),
        get_framewise(test_sig, len_seg, zero_pad=True, scale_fac=norm_fac),
    ):
        nrg_ref = np.sum(ref_seg**2, axis=0)
        nrg_diff = np.sum(diff_seg**2, axis=0)
        nrg_diff = np.sum((test_seg - ref_seg) ** 2, axis=0)

        ss_seg = np.log10(1 + nrg_ref / (denom_add + nrg_diff))