Commit d78a847d authored by Archit Tamarapu
Browse files

add a script for a sanity check of conditions

parent cf2396d2
Loading
Loading
Loading
Loading
Loading
+6 −2
Original line number Diff line number Diff line
@@ -120,13 +120,17 @@ def create_experiment_setup(experiment, lab) -> list[Path]:
        if experiment in EXPERIMENTS_P800:
            input_path = input_path.joinpath(cat)
            output_path = output_path.joinpath(cat)
            cfg_path = default_cfg_path.parent.joinpath(f"{experiment}-{cat}-lab_{lab}.yml")
            cfg_path = default_cfg_path.parent.joinpath(
                f"{experiment}-{cat}-lab_{lab}.yml"
            )
        # this is for catching the Mushra MASA tests
        elif experiment in IN_FMT_FOR_MASA_EXPS:
            fmt = fmt_for_category[cat]
            input_path = input_path.joinpath(fmt)
            output_path = output_path.joinpath(fmt)
            cfg_path = default_cfg_path.parent.joinpath(f"{experiment}-{fmt}-lab_{lab}.yml")
            cfg_path = default_cfg_path.parent.joinpath(
                f"{experiment}-{fmt}-lab_{lab}.yml"
            )
        elif experiment in EXPERIMENTS_BS1534:
            cfg_path = default_cfg_path.parent.joinpath(f"{experiment}-lab_{lab}.yml")

+180 −0
Original line number Diff line number Diff line
#!/usr/bin/python3
import argparse
import multiprocessing as mp
import re
import sys
from pathlib import Path
from time import sleep
from typing import Tuple
from warnings import catch_warnings, warn

import numpy as np

sys.path.append(str(Path(__file__).parent.parent))
from ivas_processing_scripts.audiotools.audioarray import getdelay
from ivas_processing_scripts.audiotools.audiofile import read
from ivas_processing_scripts.utils import progressbar_update, spinner


def compare_audio_arrays(
    left: np.ndarray, left_fs: int, right: np.ndarray, right_fs: int
) -> Tuple[float, float, float]:
    """Compare two audio signals and return similarity metrics.

    Both arrays are indexed as ``[sample, channel]``. If the channel counts or
    the durations differ, a warning is issued and only the common leading
    channels / samples are compared.

    Args:
        left: Signal "A" (the reference in the calling code).
        left_fs: Sampling rate of ``left`` in Hz.
        right: Signal "B".
        right_fs: Sampling rate of ``right`` in Hz.

    Returns:
        A tuple ``(snr, gain_b, max_diff)``:

        * ``snr`` - ``10 * log10(1 / (1 - r**2))`` in dB, with ``r`` the
          normalised cross-correlation of the two signals. ``inf`` for
          identical (or perfectly correlated) signals, ``-inf`` when either
          signal has zero energy and the correlation is undefined.
        * ``gain_b`` - cross term divided by the energy of ``right``;
          ``nan`` when ``right`` has zero energy.
        * ``max_diff`` - largest absolute sample difference.

    Raises:
        ValueError: If the two sampling rates differ.
    """
    if left_fs != right_fs:
        # must be raised, not returned: callers unpack the result as a 3-tuple
        raise ValueError(f"Differing samplerates: {left_fs} vs {right_fs}!")

    if left.shape[1] != right.shape[1]:
        cmp_ch = min(left.shape[1], right.shape[1])
        warn(
            f"Differing number of channels: {left.shape[1]} vs {right.shape[1]}! Comparing first {cmp_ch} channel(s)",
        )
        left = left[:, :cmp_ch]
        right = right[:, :cmp_ch]

    if left.shape[0] != right.shape[0]:
        cmp_smp = min(left.shape[0], right.shape[0])
        warn(
            f"Different durations: {left.shape[0] / left_fs:.2f}s vs {right.shape[0] / right_fs:.2f}s! Comparing first {cmp_smp} sample(s)",
        )
        left = left[:cmp_smp, :]
        right = right[:cmp_smp, :]

    # bit-exact signals need no further analysis
    if np.array_equal(left, right):
        return np.inf, 1, 0

    # NOTE(review): getdelay is a project helper; presumably it returns the
    # lag between the two signals in samples - confirm its sign convention.
    delay = getdelay(left, right)
    delay_abs = np.abs(delay)
    # getdelay can return large values if signals are quite different
    # limit any delay compensation to 20 ms
    if delay_abs > 1 and (delay_abs < left_fs / 50):
        warn(
            f"File B is delayed by {delay} samples ({delay*1000 / left_fs : .2f}ms)!",
        )

        # shift array (np.roll returns a copy, the caller's data is untouched)
        left = np.roll(left, delay, axis=0)

        # zero the samples that wrapped around during the shift
        if delay < 0:
            left[-np.abs(delay) :, :] = 0
        elif delay > 0:
            left[: np.abs(delay), :] = 0

    # Do the arithmetic in float64: products and squares of integer PCM
    # samples would otherwise overflow the integer dtype.
    left = left.astype(np.float64)
    right = right.astype(np.float64)

    # Metrics as defined by AFsp CompAudio:
    # http://www-mmsp.ece.mcgill.ca/Documents/Software/Packages/AFsp/AFsp/CompAudio.html
    num = np.sum(left * right)
    energy_left = np.sum(left**2)
    energy_right = np.sum(right**2)
    den = np.sqrt(energy_left * energy_right)

    if den > 0:
        r_squared = (num / den) ** 2
        if r_squared < 1:
            snr = 10 * np.log10(1 / (1 - r_squared))
        else:
            # perfectly correlated; rounding may even push r**2 slightly
            # above 1, which must not turn into a NaN
            snr = np.inf
    else:
        # at least one signal is all-zero, correlation is undefined
        snr = -np.inf

    gain_b = num / energy_right if energy_right > 0 else np.nan
    # largest deviation in either direction, hence abs() before max()
    max_diff = np.max(np.abs(left - right))

    return snr, gain_b, max_diff


def compare_audio_arrays_wrap(ref_file: Path, cut_file: Path):
    """Read two audio files, compare them and print any comparison warnings.

    Args:
        ref_file: Path of the reference file.
        cut_file: Path of the file under test; its stem prefixes each
            printed warning.

    Returns:
        The ``(snr, gain_b, max_diff)`` tuple from ``compare_audio_arrays``.

    Raises:
        ValueError: If the SNR is NaN or the gain is zero.
    """
    ref, ref_fs = read(ref_file)
    cut, cut_fs = read(cut_file)

    # collect warnings instead of letting them interleave with the progress bar
    with catch_warnings(record=True) as warnings_list:
        snr, gain_b, max_diff = compare_audio_arrays(ref, ref_fs, cut, cut_fs)

    # Report first, validate second: the warnings are the most useful
    # diagnostics precisely when the check below fails.
    for w in warnings_list:
        print(f"\r{cut_file.stem} : {w.message}", flush=True)

    if np.isnan(snr) or gain_b == 0:
        raise ValueError(f"Invalid signals! Check {ref_file} and {cut_file}!")

    return snr, gain_b, max_diff


def _index_by_item_name(files, suffix: str) -> dict:
    """Map each file's stem, minus a trailing ``suffix`` if present, to its path."""
    index = {}
    for f in files:
        name = f.stem
        # only strip at the very end; the same text may legitimately occur
        # elsewhere in the name
        if name.endswith(suffix):
            name = name[: -len(suffix)]
        index[name] = f
    return index


def get_common_files(ref_dir: Path, cut_dir: Path):
    """Pair up the .wav files that exist in both condition directories.

    Files are matched by their stem after removing the trailing
    ``.<directory name>`` suffix (e.g. ``item.c01.wav`` in ``c01`` and
    ``item.c02.wav`` in ``c02`` both map to ``item``).

    Args:
        ref_dir: Directory with the reference files.
        cut_dir: Directory with the files under test.

    Returns:
        Two equally long lists ``(ref_paths, cut_paths)``, sorted by item
        name, where entries at the same index belong together. Every
        returned path is one that was actually found on disk.

    Raises:
        FileNotFoundError: If either directory has no ``*.wav`` files or the
            two directories share no item.
    """
    # list REF files
    ref_files = sorted(ref_dir.glob("*.wav"))
    if not ref_files:
        raise FileNotFoundError(
            f"Reference directory {ref_dir} contains no .WAV files!"
        )

    # list CUT files
    cut_files = sorted(cut_dir.glob("*.wav"))
    if not cut_files:
        raise FileNotFoundError(
            f"Condition directory {cut_dir} contains no .WAV files!"
        )

    # strip .cXX suffix, remembering the real path of each item
    ref_by_name = _index_by_item_name(ref_files, f".{ref_dir.name}")
    cut_by_name = _index_by_item_name(cut_files, f".{cut_dir.name}")

    common_files = ref_by_name.keys() & cut_by_name.keys()
    diff_files = ref_by_name.keys() ^ cut_by_name.keys()

    if diff_files:
        warn(f"Directories differ! Unique files found: {diff_files}")

    if not common_files:
        raise FileNotFoundError("No common .WAV files found!")

    # sorted so the processing order is reproducible between runs
    ordered = sorted(common_files)
    common_files_ref = [ref_by_name[name] for name in ordered]
    common_files_cut = [cut_by_name[name] for name in ordered]

    return common_files_ref, common_files_cut


def compare_dirs(ref_dir: Path, cut_dir: Path):
    """Compare every common .wav file of two condition directories in parallel.

    Prints a header, then runs ``compare_audio_arrays_wrap`` for each file
    pair in a process pool while showing a progress bar.

    Args:
        ref_dir: Directory with the reference files.
        cut_dir: Directory with the files under test.

    Raises:
        Whatever ``get_common_files`` or a worker raises; a worker exception
        is re-raised here once all comparisons have finished.
    """
    print(80 * "-")
    print(f"Comparing REF {ref_dir.name} with CUT {cut_dir.name}")
    print(80 * "-")
    print()

    ref_files, cut_files = get_common_files(ref_dir, cut_dir)

    count = len(ref_files)
    width = 80

    with mp.Pool() as p:
        # One task per file pair. This keeps progress tracking on the public
        # AsyncResult.ready() API; the private MapResult._number_left counts
        # chunks rather than items and misreports progress once the pool
        # groups several items into one chunk.
        pending = [
            p.apply_async(compare_audio_arrays_wrap, pair)
            for pair in zip(ref_files, cut_files)
        ]

        progressbar_update(0, count, width)
        while True:
            done = sum(1 for task in pending if task.ready())
            if done == count:
                break
            progressbar_update(done, count, width)
            spinner()
            sleep(0.1)
        progressbar_update(count, count, width)
        print("\n", flush=True, file=sys.stdout)

        # get() re-raises the exception of a failed comparison, if any
        for task in pending:
            task.get()


def main(args):
    """Compare the first cXX condition directory against all the others.

    Args:
        args: Parsed command-line arguments; ``args.test_dir`` is the
            directory that contains the ``cXX`` sub-directories.

    Raises:
        FileNotFoundError: If ``args.test_dir`` has no sub-directory named
            ``c`` followed by exactly two digits.
    """
    # Match on the directory name itself rather than on the full path string,
    # so the check does not depend on the platform's path separator.
    condition_dirs = sorted(
        p
        for p in args.test_dir.iterdir()
        if p.is_dir() and re.fullmatch(r"c\d\d", p.name)
    )

    if not condition_dirs:
        raise FileNotFoundError(
            f"No condition directories with the cXX prefix found in {args.test_dir}!"
        )

    # the lowest-numbered condition serves as the reference for all others
    ref_dir, *cut_dirs = condition_dirs
    for cut_dir in cut_dirs:
        compare_dirs(ref_dir, cut_dir)


if __name__ == "__main__":
    # Command-line entry point: takes the test directory as its only
    # positional argument and hands the parsed arguments to main().
    parser = argparse.ArgumentParser(
        description="Script perform sanity checks on a listening test output directory",
    )
    parser.add_argument(
        "test_dir",
        type=Path,
        help="Test directory with cXX directories",
    )
    args = parser.parse_args()
    main(args)
+1 −1

File changed.

Contains only whitespace changes.