Commit 46783e26 authored by Jan Kiene's avatar Jan Kiene
Browse files

add ssnr script

parent e433c88c
Loading
Loading
Loading
Loading

tests/ssnr.py

0 → 100644
+80 −0
Original line number Diff line number Diff line
import argparse
import sys
import pathlib
import numpy as np

HERE = pathlib.Path(__file__).parent.absolute()
SCRIPT_DIR = HERE.parent.joinpath("scripts")

sys.path.append(str(SCRIPT_DIR))
from pyaudio3dtools import audiofile, audioarray


THRESH_LOW = -50
THRESH_HIGH = -15


def get_ssnr(
    ref_sig: np.ndarray,
    test_sig: np.ndarray,
    len_seg: int,
    thresh_low: float = -200,
    thresh_high: float = 0,
):
    """
    Calculate Segmental SNR for test_sig to ref_sig as defined in ISO/IEC 14496-4
    """
    ss = list()

    denom_add = 10**-13 * len_seg
    n = 0
    for ref_seg, test_seg in zip(
        audioarray.get_framewise(ref_sig, len_seg, zero_pad=True),
        audioarray.get_framewise(test_sig, len_seg, zero_pad=True),
    ):
        nrg_ref = np.sum(ref_seg**2)

        ref_power = 10 * np.log10(nrg_ref / len_seg)
        if ref_power < thresh_low or ref_power > thresh_high:
            continue

        diff_seg = ref_seg - test_seg
        nrg_diff = np.sum(diff_seg**2)

        ss_seg = np.log10(1 + nrg_ref / (denom_add + nrg_diff))
        ss.append(ss_seg)
        n += 1

    ssnr = 10 * np.log10(10 ** (np.sum(ss) / n) - 1)
    return ssnr


def main(args):
    ref_sig, fs_ref = audiofile.readfile(args.ref_file)
    test_sig, fs_test = audiofile.readfile(args.test_file)

    if fs_ref != fs_test:
        print("Files need to have same sampling rate!")
        return -1

    # normalize 16Bit wav signals to range of [-1, 1]
    ref_sig /= -np.iinfo(np.int16).min
    test_sig /= -np.iinfo(np.int16).min

    len_seg = int(20 * fs_ref / 1000)
    ssnr = get_ssnr(
        ref_sig, test_sig, len_seg, thresh_low=THRESH_LOW, thresh_high=THRESH_HIGH
    )
    print(ssnr)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("ref_file", type=pathlib.Path, help="Reference signal wav file")
    parser.add_argument(
        "test_file", type=pathlib.Path, help="Signal under test wav file"
    )

    args = parser.parse_args()

    main(args)