Commit cb832412 authored by Jan Kiene's avatar Jan Kiene
Browse files

add ctest measurement output for ssnr script

parent b5bb50a1
Loading
Loading
Loading
Loading
Loading
+34 −26
Original line number Diff line number Diff line
@@ -4,31 +4,6 @@ import pathlib
from pyaudio3dtools import audiofile, audioarray


def main(args):
    ref_sig, fs_ref = audiofile.readfile(args.ref_file)
    test_sig, fs_test = audiofile.readfile(args.test_file)

    if fs_ref != fs_test:
        print("Files need to have same sampling rate!")
        return -1

    len_seg = int(20 * fs_ref / 1000)
    print(len_seg, ref_sig.shape, test_sig.shape)
    ssnr = audioarray.ssnr(
        ref_sig,
        test_sig,
        len_seg,
        args.thresh_low,
        args.thresh_high,
        args.apply_thresholds_on_ref_only,
    )

    for i, s in enumerate(ssnr, start=1):
        print(f"Channel {i}: {s}")

    return 0


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("ref_file", type=pathlib.Path, help="Reference signal wav file")
@@ -56,7 +31,40 @@ if __name__ == "__main__":
        help="Use this to apply the thresholding on signal power to the reference signal only.\n"
        "This makes the implementation behaviour conform to the MPEG-D conformance spec.",
    )
    parser.add_argument(
        "--print-ctest-measurement",
        action="store_true",
        default=False,
        help="Print easy to parse single SSNR value",
    )

    args = parser.parse_args()

    sys.exit(main(args))
    ref_sig, fs_ref = audiofile.readfile(args.ref_file)
    test_sig, fs_test = audiofile.readfile(args.test_file)

    if fs_ref != fs_test:
        print("Files need to have same sampling rate!")
        sys.exit(1)

    len_seg = int(20 * fs_ref / 1000)
    print(len_seg, ref_sig.shape, test_sig.shape)
    ssnr = audioarray.ssnr(
        ref_sig,
        test_sig,
        len_seg,
        args.thresh_low,
        args.thresh_high,
        args.apply_thresholds_on_ref_only,
    )

    for i, s in enumerate(ssnr, start=1):
        print(f"Channel {i}: {s}")

    if args.print_ctest_measurement:
        min_ssnr = ssnr.min()
        print(
            f'<CTestMeasurement type="numeric/double" name="SSNR">{min_ssnr}</CTestMeasurement>'
        )

    sys.exit(0)