Loading tests/ssnr.py +20 −12 Original line number Diff line number Diff line Loading @@ -20,32 +20,36 @@ def get_ssnr( len_seg: int, thresh_low: float = -200, thresh_high: float = 0, ): ) -> np.ndarray: """ Calculate Segmental SNR for test_sig to ref_sig as defined in ISO/IEC 14496-4 """ ss = list() denom_add = 10**-13 * len_seg n = 0 segment_counter = np.zeros(ref_sig.shape[1]) for ref_seg, test_seg in zip( audioarray.get_framewise(ref_sig, len_seg, zero_pad=True), audioarray.get_framewise(test_sig, len_seg, zero_pad=True), ): nrg_ref = np.sum(ref_seg**2) ref_power = 10 * np.log10(nrg_ref / len_seg) if ref_power < thresh_low or ref_power > thresh_high: continue nrg_ref = np.sum(ref_seg**2, axis=0) diff_seg = ref_seg - test_seg nrg_diff = np.sum(diff_seg**2) nrg_diff = np.sum(diff_seg**2, axis=0) ss_seg = np.log10(1 + nrg_ref / (denom_add + nrg_diff)) # only sum up segments that fall inside the thresholds ref_power = 10 * np.log10(nrg_ref / len_seg) zero_mask = np.logical_or(ref_power < thresh_low, ref_power > thresh_high) ss_seg[zero_mask] = 0 # increase segment counter only for channels that were not zeroed segment_counter += np.logical_not(zero_mask) ss.append(ss_seg) n += 1 ssnr = 10 * np.log10(10 ** (np.sum(ss) / n) - 1) ssnr = 10 * np.log10(10 ** (np.sum(ss, axis=0) / segment_counter) - 1) return ssnr Loading @@ -65,7 +69,11 @@ def main(args): ssnr = get_ssnr( ref_sig, test_sig, len_seg, thresh_low=THRESH_LOW, thresh_high=THRESH_HIGH ) print(ssnr) for i, s in enumerate(ssnr, start=1): print(f"Channel {i}: {s}") return 0 if __name__ == "__main__": Loading @@ -77,4 +85,4 @@ if __name__ == "__main__": args = parser.parse_args() main(args) sys.exit(main(args)) Loading
tests/ssnr.py +20 −12 Original line number Diff line number Diff line Loading @@ -20,32 +20,36 @@ def get_ssnr( len_seg: int, thresh_low: float = -200, thresh_high: float = 0, ): ) -> np.ndarray: """ Calculate Segmental SNR for test_sig to ref_sig as defined in ISO/IEC 14496-4 """ ss = list() denom_add = 10**-13 * len_seg n = 0 segment_counter = np.zeros(ref_sig.shape[1]) for ref_seg, test_seg in zip( audioarray.get_framewise(ref_sig, len_seg, zero_pad=True), audioarray.get_framewise(test_sig, len_seg, zero_pad=True), ): nrg_ref = np.sum(ref_seg**2) ref_power = 10 * np.log10(nrg_ref / len_seg) if ref_power < thresh_low or ref_power > thresh_high: continue nrg_ref = np.sum(ref_seg**2, axis=0) diff_seg = ref_seg - test_seg nrg_diff = np.sum(diff_seg**2) nrg_diff = np.sum(diff_seg**2, axis=0) ss_seg = np.log10(1 + nrg_ref / (denom_add + nrg_diff)) # only sum up segments that fall inside the thresholds ref_power = 10 * np.log10(nrg_ref / len_seg) zero_mask = np.logical_or(ref_power < thresh_low, ref_power > thresh_high) ss_seg[zero_mask] = 0 # increase segment counter only for channels that were not zeroed segment_counter += np.logical_not(zero_mask) ss.append(ss_seg) n += 1 ssnr = 10 * np.log10(10 ** (np.sum(ss) / n) - 1) ssnr = 10 * np.log10(10 ** (np.sum(ss, axis=0) / segment_counter) - 1) return ssnr Loading @@ -65,7 +69,11 @@ def main(args): ssnr = get_ssnr( ref_sig, test_sig, len_seg, thresh_low=THRESH_LOW, thresh_high=THRESH_HIGH ) print(ssnr) for i, s in enumerate(ssnr, start=1): print(f"Channel {i}: {s}") return 0 if __name__ == "__main__": Loading @@ -77,4 +85,4 @@ if __name__ == "__main__": args = parser.parse_args() main(args) sys.exit(main(args))