Commit eb50ffa6 authored by Jan Kiene's avatar Jan Kiene
Browse files

apply thresholding on both reference and test signal

parent 33f7971a
Loading
Loading
Loading
Loading
+30 −3
Original line number Diff line number Diff line
@@ -236,6 +236,7 @@ def compare(
    get_ssnr: bool = False,
    ssnr_thresh_low: float = -np.inf,
    ssnr_thresh_high: float = np.inf,
    apply_thresholds_to_ref_only: bool = False,
) -> dict:
    """Compare two audio arrays

@@ -251,6 +252,18 @@ def compare(
        Compute difference per frame (default True)
    get_mld: bool
        Run MLD tool if there is a difference between the signals (default False)
    get_ssnr: bool
        Compute Segmental SNR between signals
    ssnr_thresh_low: float
        Low threshold for including a segment in the SSNR computation. Per default, both
        reference and test signal power are compared to this threshold, see below
    ssnr_thresh_high: float
        High threshold for including a segment in the SSNR computation. Per default, both
        reference and test signal power are compared to this threshold, see below
    apply_thresholds_to_ref_only: bool
        Set to True to compare only the reference signal power against the thresholds
        when deciding whether to include a segment in the SSNR computation. Use this
        to align behaviour with the MPEG-D conformance specification.

    Returns
    -------
@@ -367,6 +380,7 @@ def compare(
                len_seg,
                thresh_low=ssnr_thresh_low,
                thresh_high=ssnr_thresh_high,
                apply_thresholds_to_ref_only=apply_thresholds_to_ref_only,
            )

    return result
@@ -547,6 +561,7 @@ def ssnr(
    len_seg: int,
    thresh_low: float = -200,
    thresh_high: float = 0,
    apply_thresholds_to_ref_only: bool = False,
) -> np.ndarray:
    """
    Calculate Segmental SNR for test_sig to ref_sig as defined in ISO/IEC 14496-4
@@ -565,9 +580,12 @@ def ssnr(

    denom_add = 10**-13 * len_seg
    segment_counter = np.zeros(ref_sig.shape[1])
    for ref_seg, diff_seg in zip(

    # iterate over test signal too to allow power comparison to threshold
    for ref_seg, diff_seg, test_seg in zip(
        get_framewise(ref_sig_norm, len_seg, zero_pad=True),
        get_framewise(diff_sig_norm, len_seg, zero_pad=True),
        get_framewise(test_sig_norm, len_seg, zero_pad=True),
    ):
        nrg_ref = np.sum(ref_seg**2, axis=0)
        nrg_diff = np.sum(diff_seg**2, axis=0)
@@ -579,6 +597,15 @@ def ssnr(
        ref_power = 10 * np.log10((nrg_ref + 10**-7) / len_seg)
        zero_mask = np.logical_or(ref_power < thresh_low, ref_power > thresh_high)

        # create same mask for test signal
        if not apply_thresholds_to_ref_only:
            nrg_test = np.sum(test_seg**2, axis=0)
            test_power = 10 * np.log10((nrg_test + 10**-7) / len_seg)
            zero_mask_test = np.logical_or(
                test_power < thresh_low, test_power > thresh_high
            )
            zero_mask = np.logical_or(zero_mask, zero_mask_test)

        ss_seg[zero_mask] = 0
        # increase segment counter only for channels that were not zeroed
        segment_counter += np.logical_not(zero_mask)
@@ -587,12 +614,12 @@ def ssnr(

    # if the reference signal was outside the thresholds for all segments in a channel, segment_counter is zero
    # for that channel and the division here would trigger a warning. We suppress the warning and later
    # set the SSNR for those channels to -inf manually instead (overwriting later is simply easier than adding ifs here)
    # set the SSNR for those channels to nan manually instead (overwriting later is simply easier than adding ifs here)
    with warnings.catch_warnings():
        ssnr = np.round(
            10 * np.log10(10 ** (np.sum(ss, axis=0) / segment_counter) - 1), 2
        )
    ssnr[segment_counter == 0] = -np.inf
    ssnr[segment_counter == 0] = np.nan

    # this prevents all-zero channels in both signals from being reported as -inf
    ssnr[channels_identical_idx] = np.inf
+17 −1
Original line number Diff line number Diff line
@@ -15,7 +15,12 @@ def main(args):
    len_seg = int(20 * fs_ref / 1000)
    print(len_seg, ref_sig.shape, test_sig.shape)
    ssnr = audioarray.ssnr(
        ref_sig, test_sig, len_seg, args.thresh_low, args.thresh_high
        ref_sig,
        test_sig,
        len_seg,
        args.thresh_low,
        args.thresh_high,
        args.apply_thresholds_on_ref_only,
    )

    for i, s in enumerate(ssnr, start=1):
@@ -34,11 +39,22 @@ if __name__ == "__main__":
        "--thresh_low",
        type=float,
        default="-inf",
        help="Low threshold for signal power in a segment to be used in the SSNR calculation (default: -inf).\n"
        "Applied to both signals per default (see apply_thresholds_on_ref_only argument).",
    )
    parser.add_argument(
        "--thresh_high",
        type=float,
        default="inf",
        help="High threshold for signal power in a segment to be used in the SSNR calculation (default: +inf).\n"
        "Applied to both signals per default (see apply_thresholds_on_ref_only argument).",
    )
    parser.add_argument(
        "--apply_thresholds_on_ref_only",
        action="store_true",
        default=False,
        help="Use this to apply the thresholding on signal power to the reference signal only.\n"
        "This makes the implementation behaviour conform to the MPEG-D conformance spec.",
    )

    args = parser.parse_args()