Commit 9d43d536 authored by Jan Kiene's avatar Jan Kiene
Browse files

fix for no diff, but out of threshold cases

parent 5ee6c325
Loading
Loading
Loading
Loading
Loading
+9 −2
Original line number Diff line number Diff line
@@ -234,6 +234,8 @@ def compare(
    per_frame: bool = True,
    get_mld: bool = False,
    get_ssnr: bool = False,
    ssnr_thresh_low: float = -np.inf,
    ssnr_thresh_high: float = np.inf,
) -> dict:
    """Compare two audio arrays

@@ -353,7 +355,7 @@ def compare(
            # length of segment is always 20ms
            len_seg = int(0.02 * fs)
            print(len_seg, ref.shape, test.shape)
            result["SSNR"] = ssnr(ref, test, len_seg, thresh_low=-50, thresh_high=-15)
            result["SSNR"] = ssnr(ref, test, len_seg, thresh_low=ssnr_thresh_low, thresh_high=ssnr_thresh_high)

    return result

@@ -542,6 +544,7 @@ def ssnr(
    ref_sig_norm = ref_sig / -np.iinfo(np.int16).min
    test_sig_norm = test_sig / -np.iinfo(np.int16).min

    # check if diff of signal is zero already, then SNR is infinite, since no noise
    diff_norm = ref_sig_norm - test_sig_norm
    if np.all(diff_norm == 0):
        return np.asarray([np.inf] * ref_sig_norm.shape[1])
@@ -572,9 +575,13 @@ def ssnr(

    # if the reference signal was outside the thresholds for all segments in a channel, segment_counter is zero
    # for that channel and the division here would trigger a warning. We supress the warning and later
    # set the SSNr for those channels to -inf manually instead (overwriting later is simply easier than adding ifs here)
    # set the SSNR for those channels to -inf manually instead (overwriting later is simply easier than adding ifs here)
    with warnings.catch_warnings():
        ssnr = np.round(10 * np.log10(10 ** (np.sum(ss, axis=0) / segment_counter) - 1), 2)
    ssnr[segment_counter == 0] = -np.inf

    # set to zero for channels with no diff (this handles e.g. the corner-case of an all-zero channel in both ref and dut)
    zero_diff_mask = np.asarray([np.all(diff_norm[:, c] == 0) for c in range(ref_sig.shape[1])])
    ssnr[zero_diff_mask] = 0

    return ssnr
+7 −1
Original line number Diff line number Diff line
@@ -14,7 +14,7 @@ def main(args):

    len_seg = int(20 * fs_ref / 1000)
    print(len_seg, ref_sig.shape, test_sig.shape)
    ssnr = audioarray.ssnr(ref_sig, test_sig, len_seg, thresh_low=-50, thresh_high=-15)
    ssnr = audioarray.ssnr(ref_sig, test_sig, len_seg, args.thresh_low, args.thresh_high)

    for i, s in enumerate(ssnr, start=1):
        print(f"Channel {i}: {s}")
@@ -28,6 +28,12 @@ if __name__ == "__main__":
    parser.add_argument(
        "test_file", type=pathlib.Path, help="Signal under test wav file"
    )
    parser.add_argument(
            "--thresh_low", type=float, default="-inf",
    )
    parser.add_argument(
            "--thresh_high", type=float, default="inf",
            )

    args = parser.parse_args()