Commit d7d223a0 authored by Jan Kiene's avatar Jan Kiene
Browse files

extend for more than one channel

parent 46783e26
Loading
Loading
Loading
Loading
+20 −12
Original line number Diff line number Diff line
@@ -20,32 +20,36 @@ def get_ssnr(
    len_seg: int,
    thresh_low: float = -200,
    thresh_high: float = 0,
):
) -> np.ndarray:
    """
    Calculate Segmental SNR for test_sig to ref_sig as defined in ISO/IEC 14496-4
    """
    ss = list()

    denom_add = 10**-13 * len_seg
    n = 0
    segment_counter = np.zeros(ref_sig.shape[1])
    for ref_seg, test_seg in zip(
        audioarray.get_framewise(ref_sig, len_seg, zero_pad=True),
        audioarray.get_framewise(test_sig, len_seg, zero_pad=True),
    ):
        nrg_ref = np.sum(ref_seg**2)

        ref_power = 10 * np.log10(nrg_ref / len_seg)
        if ref_power < thresh_low or ref_power > thresh_high:
            continue
        nrg_ref = np.sum(ref_seg**2, axis=0)

        diff_seg = ref_seg - test_seg
        nrg_diff = np.sum(diff_seg**2)
        nrg_diff = np.sum(diff_seg**2, axis=0)

        ss_seg = np.log10(1 + nrg_ref / (denom_add + nrg_diff))

        # only sum up segments that fall inside the thresholds
        ref_power = 10 * np.log10(nrg_ref / len_seg)
        zero_mask = np.logical_or(ref_power < thresh_low, ref_power > thresh_high)

        ss_seg[zero_mask] = 0
        # increase segment counter only for channels that were not zeroed
        segment_counter += np.logical_not(zero_mask)

        ss.append(ss_seg)
        n += 1

    ssnr = 10 * np.log10(10 ** (np.sum(ss) / n) - 1)
    ssnr = 10 * np.log10(10 ** (np.sum(ss, axis=0) / segment_counter) - 1)
    return ssnr


@@ -65,7 +69,11 @@ def main(args):
    ssnr = get_ssnr(
        ref_sig, test_sig, len_seg, thresh_low=THRESH_LOW, thresh_high=THRESH_HIGH
    )
    print(ssnr)

    for i, s in enumerate(ssnr, start=1):
        print(f"Channel {i}: {s}")

    return 0


if __name__ == "__main__":
@@ -77,4 +85,4 @@ if __name__ == "__main__":

    args = parser.parse_args()

    main(args)
    sys.exit(main(args))