Commit 989053fa authored by Jan Kiene's avatar Jan Kiene
Browse files

remove input file padding and move validation after file collection

parent 74016ebc
Loading
Loading
Loading
Loading
+0 −4
Original line number Diff line number Diff line
@@ -41,7 +41,6 @@ from ivas_processing_scripts.constants import (
)
from ivas_processing_scripts.processing import chains, config
from ivas_processing_scripts.processing.processing import (
    multiple_of_frame_size,
    preprocess,
    preprocess_2,
    preprocess_background_noise,
@@ -127,9 +126,6 @@ def main(args):

        cfg.metadata_path = metadata

        # checking if audio is a multiple of frame size
        multiple_of_frame_size(cfg)

        # run preprocessing only once
        if hasattr(cfg, "preprocessing"):
            # save process info for background noise
+51 −1
Original line number Diff line number Diff line
@@ -43,6 +43,8 @@ from ivas_processing_scripts.processing.processing_splitting_scaling import (
    Processing_splitting_scaling,
)
from ivas_processing_scripts.utils import get_abs_path, list_audio
from ivas_processing_scripts.audiotools import audio
from ivas_processing_scripts.audiotools.audiofile import read


def init_processing_chains(cfg: TestConfig) -> None:
@@ -95,11 +97,14 @@ def init_processing_chains(cfg: TestConfig) -> None:
    cfg.items_list = list_audio(
        cfg.input_path, select_list=getattr(cfg, "input_select", None)
    )
    if not cfg.items_list:
    if len(cfg.items_list) == 0:
        raise SystemExit(
            f"Directory {cfg.input_path} does not exist, contains no audio files or all files were filtered out."
        )

    # validate input files for correct format and sampling rate
    validate_input_files(cfg)

    # assemble a list of output and temporary directories to create
    for chain in cfg.proc_chains:
        cfg.out_dirs.append(cfg.output_path.joinpath(chain["name"]))
@@ -508,3 +513,48 @@ def get_processing_chain(
    )

    return chain


def validate_input_files(cfg: TestConfig):
    """
    Go through list of input files and check whether they match the sampling rate and format
    (by checking number of channels) specified in the config and are aligned to the given
    input block size.
    """
    input_format = cfg.input["fmt"]
    num_chan_expected = audio.fromtype(input_format).num_channels
    check_block_aligned_to = 20
    for item in cfg.items_list:
        if "fs" in cfg.input:
            sampling_rate = cfg.input["fs"]
            x, fs = read(item, nchannels=num_chan_expected, fs=sampling_rate)
        elif item.suffix == ".pcm" or item.suffix == ".raw":
            raise ValueError("Sampling rate must be specified for headerless files!")
        elif item.suffix == ".wav":
            x, fs = read(item)
            sampling_rate = fs
        else:
            raise ValueError(f"Unsupported input file type {item.suffix}")
        n_samples_x, n_chan_x = x.shape

        # check for number of channels and sampling rate
        if fs != sampling_rate:
            raise ValueError(
                f"Sampling rate of the file ({fs}) does NOT match with that ({sampling_rate}) specified in the config yaml."
            )
        if n_chan_x != num_chan_expected:
            raise ValueError(
                f"The number of channels in the file ({n_chan_x}) do NOT match with those of format ({num_chan_expected}, {input_format}) specified in the config yaml."
            )

        if check_block_aligned_to > 0:
            frame_length_samples = (check_block_aligned_to / 1000) * fs
            if n_samples_x % frame_length_samples != 0:
                if input_format.startswith("ISM") or input_format.startswith("MASA"):
                    raise ValueError(
                        f"The length ({n_samples_x} samples) of audio ({item.name}) is not a multiple of frame length (20 ms) - not allowed for input formats with metadata."
                    )
                else:
                    warn(
                        f"The length ({n_samples_x} samples) of audio ({item.name}) is not a multiple of frame length (20 ms)."
                    )
 No newline at end of file
+1 −102
Original line number Diff line number Diff line
@@ -460,104 +460,3 @@ def preprocess_background_noise(cfg):
    ] = output_audio

    return
 No newline at end of file


def multiple_of_frame_size(
    cfg: TestConfig,
    frame_size_in_ms: Optional[int] = 20,
) -> np.ndarray:
    """
    This function checks if the list of multi channel audio files is a multiple of frame size.
    If the file isn't a multiple then the function pads it to the next integer of frame size and writes the file to an output directory.
    It also copies the already aligned files to the output directory.

    Parameters
    ----------
    cfg: TestConfig
        Input configuration
    frame_size_in_ms: Optional[int]
        Frame size in milliseconds; default = 20
    """
    # get the number of channels from the input format
    input_format = cfg.input["fmt"]
    num_channels = audio.fromtype(input_format).num_channels

    # Create output directory
    output_dir = cfg.output_path / "20ms_aligned_files"
    try:
        output_dir.mkdir(exist_ok=False)
    except FileExistsError:
        raise ValueError(
            "Folder for 20ms aligned files already exists. Please move or delete folder"
        )

    # iterate over input files
    for i, item in enumerate(cfg.items_list):
        # read the audio file
        if "fs" in cfg.input:
            sampling_rate = cfg.input["fs"]
            x, fs = read(item, nchannels=num_channels, fs=sampling_rate)
        elif item.suffix == ".pcm" or item.suffix == ".raw":
            raise ValueError("Sampling rate must be specified for headerless files!")
        elif item.suffix == ".wav":
            x, fs = read(item)
            sampling_rate = fs
        else:
            raise ValueError(f"Unsupported input file type {item.suffix}")
        n_samples_x, n_chan_x = x.shape

        # check for number of channels and sampling rate
        if fs != sampling_rate:
            raise ValueError(
                f"Sampling rate of the file ({fs}) does NOT match with that ({sampling_rate}) specified in the config yaml."
            )
        if n_chan_x != num_channels:
            raise ValueError(
                f"The number of channels in the file ({n_chan_x}) do NOT match with those of format ({num_channels}, {input_format}) specified in the config yaml."
            )

        # warn if audio length not a multiple of frame length
        frame_length_samples = (frame_size_in_ms / 1000) * fs
        remainder = n_samples_x % frame_length_samples
        if remainder != 0:
            # Calculate number of samples needed for padding
            padding_samples = int(frame_length_samples - remainder)

            if input_format.startswith("ISM") or input_format.startswith("MASA"):
                raise ValueError(
                    f"The length ({n_samples_x} samples) of audio ({item.name}) is not a multiple of frame length (20 ms)."
                )
            else:
                warn(
                    f"The length ({n_samples_x} samples) of audio ({item.name}) is not a multiple of frame length (20 ms). Padding to the nearest integer multiple."
                )

                # Create and append zeros
                padded_data = trim(x, sampling_rate, (0, -padding_samples), pad_noise=True, samples=True)
                # Write padded data to output directory
                write(output_dir / item.name, padded_data, fs)
        else:
            copyfile(item, output_dir / item.name)

        # Update audio file path in list
        cfg.items_list[i] = output_dir / item.name

        # Copy metadata and update path
        if input_format.startswith("ISM"):
            for j in range(int(cfg.input["fmt"][3])):
                copyfile(
                    cfg.metadata_path[i][j], output_dir / cfg.metadata_path[i][j].name
                )
                cfg.metadata_path[i][j] = output_dir / cfg.metadata_path[i][j].name
        elif input_format.startswith("MASA"):
            raise ValueError("MASA as input format not implemented yet")

    # Check if all files are present in output directory
    all_files_present = all(
        [(output_dir / audio_file.name).exists() for audio_file in cfg.items_list]
    )
    if not all_files_present:
        raise Exception("Not all files are present in the output directory")

    # Make the output path as the new input path
    cfg.input_path = output_dir