Commit c6f081c2 authored by Anika Treffehn's avatar Anika Treffehn
Browse files

Merge branch 'warn-and-pad-if-audio-length-not-a-multiple-of-frame-size' into 'main'

Warn and pad if audio length not a multiple of frame size

See merge request !37
parents 9f5f7bd5 50ef8834
Loading
Loading
Loading
Loading
Loading
+4 −0
Original line number Diff line number Diff line
@@ -43,6 +43,7 @@ from ivas_processing_scripts.constants import (
)
from ivas_processing_scripts.processing import chains, config
from ivas_processing_scripts.processing.processing import (
    multiple_of_frame_size,
    preprocess,
    preprocess_2,
    preprocess_background_noise,
@@ -126,6 +127,9 @@ def main(args):

        cfg.metadata_path = metadata

        # checking if audio is a multiple of frame size
        multiple_of_frame_size(cfg)

        # run preprocessing only once
        if hasattr(cfg, "preprocessing"):
            # save process info for background noise
+2 −2
Original line number Diff line number Diff line
@@ -30,10 +30,10 @@
#  the United Nations Convention on Contracts on the International Sales of Goods.
#

import warnings
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional, Union
from warnings import warn

import numpy as np

@@ -274,7 +274,7 @@ class ObjectBasedAudio(Audio):
                    obj.metadata_files.append(file_name_meta)
                else:
                    raise ValueError(f"Metadata file {file_name_meta} not found.")
            warnings.warn(
            warn(
                f"No metadata files specified: The following files were found and used: \n {*obj.metadata_files,}"
            )

+2 −2
Original line number Diff line number Diff line
@@ -31,8 +31,8 @@
#

import logging
import warnings
from typing import Iterator, Optional, Tuple, Union
from warnings import warn

import numpy as np
import scipy.signal as sig
@@ -342,7 +342,7 @@ def limiter(
        fr_sig[idx_min] = -32768

    if limited:
        warnings.warn("Limiting had to be applied")
        warn("Limiting had to be applied")
    return x


+4 −4
Original line number Diff line number Diff line
@@ -30,9 +30,9 @@
#  the United Nations Convention on Contracts on the International Sales of Goods.
#

import warnings
from pathlib import Path
from typing import Optional, Tuple, Union
from warnings import warn

import numpy as np
from scipy.io import loadmat
@@ -149,7 +149,7 @@ def load_ir(
                    )
                ).is_file():
                    dataset_suffix = "SBA3"
                    warnings.warn("No SBA1 dataset found -> use truncated SBA3 dataset")
                    warn("No SBA1 dataset found -> use truncated SBA3 dataset")
            elif in_fmt.endswith("2"):
                dataset_suffix = "SBA2"
                # Use truncated SBA3 dataset if no SBA1 or 2 dataset exists
@@ -159,7 +159,7 @@ def load_ir(
                    )
                ).is_file():
                    dataset_suffix = "SBA3"
                    warnings.warn("No SBA2 dataset found -> use truncated SBA3 dataset")
                    warn("No SBA2 dataset found -> use truncated SBA3 dataset")
            else:
                dataset_suffix = "SBA3"

@@ -172,7 +172,7 @@ def load_ir(
        latency_smp = latency_s
    else:
        latency_smp = int(np.min(np.argmax(np.sum(np.abs(IR), axis=1), axis=0)))
        warnings.warn(
        warn(
            f"No latency of HRTF dataset specified in {path_dataset} file -> computed latency: {latency_smp} sample(s)"
        )

+65 −1
Original line number Diff line number Diff line
@@ -35,7 +35,7 @@ from abc import ABC, abstractmethod
from itertools import repeat
from pathlib import Path
from shutil import copyfile
from typing import Iterable, Union
from typing import Iterable, Optional, Union
from warnings import warn

import numpy as np
@@ -74,6 +74,21 @@ class Processing(ABC):


def reorder_items_list(items_list: list, concatenation_order: list) -> list:
    """
    Reorder input items list based on conactenation order

    Parameters
    ----------
    items_list: list
        List of input items
    concatenation_order: list
        Concatenation order

    Returns
    -------
    ordered_full_files: list
        Re-ordered list of input items
    """
    name_to_full = {Path(full_file).name: full_file for full_file in items_list}
    ordered_full_files = [
        name_to_full[name] for name in concatenation_order if name in name_to_full
@@ -496,3 +511,52 @@ def preprocess_background_noise(cfg):
    ] = output_audio

    return


def multiple_of_frame_size(
    cfg: TestConfig,
    frame_size_in_ms: Optional[int] = 20,
) -> np.ndarray:
    """
    Warn/Exit if audio if it isn't a multiple of frame size

    Parameters
    ----------
    cfg: TestConfig
        Input configuration
    frame_size_in_ms: Optional[int]
        Frame size in milliseconds; default = 20
    """
    # get the number of channels from the input format
    input_format = cfg.input["fmt"]
    num_channels = audio.fromtype(input_format).num_channels
    for item in cfg.items_list:
        # read the audio file
        if "fs" in cfg.input:
            sampling_rate = cfg.input["fs"]
            x, fs = read(item, nchannels=num_channels, fs=sampling_rate)
        elif item.suffix == ".pcm" or item.suffix == ".raw":
            raise ValueError("Sampling rate must be specified for headerless files!")
        elif item.suffix == ".wav":
            x, fs = read(item)
            sampling_rate = fs
        n_samples_x, n_chan_x = x.shape
        if fs != sampling_rate:
            raise ValueError(
                f"Sampling rate of the file ({fs}) does NOT match with that ({sampling_rate}) specified in the config yaml."
            )
        if n_chan_x != num_channels:
            raise ValueError(
                f"The number of channels in the file ({n_chan_x}) do NOT match with those of format ({num_channels}, {input_format}) specified in the config yaml."
            )
        # warn if audio length not a multiple of frame length
        frame_length_samples = (frame_size_in_ms / 1000) * fs
        if n_samples_x % frame_length_samples != 0:
            if input_format.startswith("ISM") or input_format.startswith("MASA"):
                raise ValueError(
                    f"The length ({n_samples_x} samples) of audio ({item.name}) is not a multiple of frame length (20 ms)."
                )
            else:
                warn(
                    f"The length ({n_samples_x} samples) of audio ({item.name}) is not a multiple of frame length (20 ms)."
                )