Commit a7b7c462 authored by Jan Kiene's avatar Jan Kiene
Browse files

move ssnr script and functionality to scripts folder

parent 1ec14d37
Loading
Loading
Loading
Loading
+40 −0
Original line number Diff line number Diff line
@@ -513,3 +513,43 @@ def process_async(files: Iterable, func: Callable, **kwargs):
    for r in results:
        r.get()
    return results


def ssnr(
    ref_sig: np.ndarray,
    test_sig: np.ndarray,
    len_seg: int,
    thresh_low: float = -200,
    thresh_high: float = 0,
) -> np.ndarray:
    """
    Compute the segmental SNR of test_sig relative to ref_sig
    (as defined in ISO/IEC 14496-4).

    Both signals are cut into segments of len_seg samples via get_framewise
    (with zero_pad=True). For every segment and every column of the signal,
    log10(1 + E_ref / (eps + E_diff)) is computed, where E_ref is the energy
    of the reference segment and E_diff the energy of the difference
    ref - test. Only segments whose reference power (in dB, energy divided
    by len_seg) lies within [thresh_low, thresh_high] contribute; the
    contributions are averaged per column and mapped back through
    10 * log10(10**mean - 1).

    Parameters
    ----------
    ref_sig : np.ndarray
        Reference signal. Must be 2-D, since shape[1] sizes the per-column
        counter (presumably samples x channels - confirm with callers).
    test_sig : np.ndarray
        Signal under test; compared segment by segment with ref_sig.
    len_seg : int
        Segment length in samples.
    thresh_low : float
        Segments with reference power below this value (dB) are not counted.
    thresh_high : float
        Segments with reference power above this value (dB) are not counted.

    Returns
    -------
    np.ndarray
        One segmental SNR value per column of ref_sig. A column for which no
        segment was counted results in a 0/0 division, i.e. NaN.
    """
    # small bias on the difference energy so the ratio stays finite for
    # identical segments
    eps_diff = 10**-13 * len_seg
    n_counted = np.zeros(ref_sig.shape[1])
    log_terms = []

    ref_frames = get_framewise(ref_sig, len_seg, zero_pad=True)
    test_frames = get_framewise(test_sig, len_seg, zero_pad=True)
    for ref_frame, test_frame in zip(ref_frames, test_frames):
        ref_energy = np.sum(ref_frame**2, axis=0)
        err_energy = np.sum((ref_frame - test_frame) ** 2, axis=0)

        frame_term = np.log10(1 + ref_energy / (eps_diff + err_energy))

        # reference power of this segment in dB; the small offset keeps
        # numpy from emitting RuntimeWarnings for all-zero segments
        ref_power_db = 10 * np.log10((ref_energy + 10**-7) / len_seg)
        outside = (ref_power_db < thresh_low) | (ref_power_db > thresh_high)

        # segments outside the thresholds contribute nothing ...
        frame_term[outside] = 0
        # ... and are not counted for the affected channels
        n_counted += ~outside

        log_terms.append(frame_term)

    mean_log = np.sum(log_terms, axis=0) / n_counted
    return 10 * np.log10(10**mean_log - 1)
+44 −0
Original line number Diff line number Diff line
@@ -2,11 +2,6 @@ import argparse
import sys
import pathlib
import numpy as np

HERE = pathlib.Path(__file__).parent.absolute()
SCRIPT_DIR = HERE.parent.joinpath("scripts")

sys.path.append(str(SCRIPT_DIR))
from pyaudio3dtools import audiofile, audioarray


@@ -14,46 +9,6 @@ THRESH_LOW = -50
THRESH_HIGH = -15


def get_ssnr(
    ref_sig: np.ndarray,
    test_sig: np.ndarray,
    len_seg: int,
    thresh_low: float = -200,
    thresh_high: float = 0,
) -> np.ndarray:
    """
    Compute the segmental SNR of test_sig relative to ref_sig
    (as defined in ISO/IEC 14496-4).

    The signals are split into segments of len_seg samples with
    audioarray.get_framewise (zero_pad=True). Per segment and per column,
    log10(1 + E_ref / (eps + E_diff)) is evaluated; segments whose reference
    power in dB is below thresh_low or above thresh_high are excluded. The
    remaining terms are averaged per column and converted back with
    10 * log10(10**mean - 1).

    ref_sig must be 2-D because shape[1] determines the number of per-column
    results (presumably samples x channels - confirm with callers). A column
    without any counted segment yields NaN (0/0 division).
    """
    # small bias on the difference energy so the ratio stays finite for
    # identical segments
    eps_diff = 10**-13 * len_seg
    n_counted = np.zeros(ref_sig.shape[1])
    log_terms = []

    frame_pairs = zip(
        audioarray.get_framewise(ref_sig, len_seg, zero_pad=True),
        audioarray.get_framewise(test_sig, len_seg, zero_pad=True),
    )
    for ref_frame, test_frame in frame_pairs:
        ref_energy = np.sum(ref_frame**2, axis=0)

        err_frame = ref_frame - test_frame
        err_energy = np.sum(err_frame**2, axis=0)

        frame_term = np.log10(1 + ref_energy / (eps_diff + err_energy))

        # reference power in dB; the small offset keeps numpy from emitting
        # RuntimeWarnings for all-zero segments
        ref_power_db = 10 * np.log10((ref_energy + 10**-7) / len_seg)
        too_quiet = ref_power_db < thresh_low
        too_loud = ref_power_db > thresh_high
        excluded = np.logical_or(too_quiet, too_loud)

        # excluded segments add nothing to the sum and are not counted
        frame_term[excluded] = 0
        n_counted += np.logical_not(excluded)

        log_terms.append(frame_term)

    mean_log = np.sum(log_terms, axis=0) / n_counted
    return 10 * np.log10(10**mean_log - 1)


def main(args):
    ref_sig, fs_ref = audiofile.readfile(args.ref_file)
    test_sig, fs_test = audiofile.readfile(args.test_file)
@@ -67,7 +22,7 @@ def main(args):
    test_sig /= -np.iinfo(np.int16).min

    len_seg = int(20 * fs_ref / 1000)
    ssnr = get_ssnr(
    ssnr = audioarray.ssnr(
        ref_sig, test_sig, len_seg, thresh_low=THRESH_LOW, thresh_high=THRESH_HIGH
    )