Commit ad177e0e authored by Archit Tamarapu's avatar Archit Tamarapu
Browse files

[opt] execute ITU STL filter in parallel for all channels instead of sequentially

parent e9484713
Loading
Loading
Loading
Loading
Loading
+50 −27
Original line number Diff line number Diff line
@@ -32,6 +32,7 @@

import re
from copy import deepcopy
from itertools import repeat
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional
@@ -43,11 +44,39 @@ from ivas_processing_scripts.audiotools.audio import Audio, ChannelBasedAudio
from ivas_processing_scripts.audiotools.audioarray import delay_compensation, pad_delay
from ivas_processing_scripts.audiotools.audiofile import read, write
from ivas_processing_scripts.constants import DEFAULT_CONFIG_BINARIES
from ivas_processing_scripts.utils import find_binary, run
from ivas_processing_scripts.utils import apply_func_parallel, find_binary, run

FILTER_TYPES_REGEX = r"[\n][\s]{3}[A-Z0-9]\w+\s+"


def run_filter(
    cmd_base: list[str],
    audio: np.ndarray,
    fs: int,
    chan: int,
    skip_channel: list[int],
    block_size: Optional[int] = None,
):
    if chan in skip_channel:
        return audio

    with TemporaryDirectory() as tmp_dir:
        tmp_dir = Path(tmp_dir)
        tmp_in = tmp_dir.joinpath(f"tmpFilterOut_{chan}.pcm")
        tmp_out = tmp_dir.joinpath(f"tmpFilterIn_{chan}.pcm")

        cmd = cmd_base.copy()
        cmd.append(str(tmp_in))
        cmd.append(str(tmp_out))
        if block_size:
            cmd.append(str(block_size))

        write(tmp_in, audio, fs)
        run(cmd)
        out, _ = read(tmp_out, nchannels=1, fs=fs)
    return out


def filter_itu(
    input: Audio,
    flt_type: str,
@@ -155,32 +184,26 @@ def filter_itu(
        # normal filtering -> size remains
        output = np.zeros_like(input.audio)

    with TemporaryDirectory() as tmp_dir:
        tmp_dir = Path(tmp_dir)

        # process channels separately
        for channel in range(input.num_channels):
            if skip_channel and channel in skip_channel:
                output[:, channel] = input.audio[:, channel]
                continue

            cmd_in_out = cmd.copy()

            tmp_in = tmp_dir.joinpath(f"tmp_filterIn{channel}.pcm")
            tmp_out = tmp_dir.joinpath(f"tmp_filterOut{channel}.pcm")

            cmd_in_out.append(str(tmp_in))
            cmd_in_out.append(str(tmp_out))

            if block_size:
                cmd_in_out.append(str(block_size))

            write(tmp_in, input.audio[:, channel], input.fs)
    # make sure this is an empty list
    skip_channel = skip_channel or []

    filtered = apply_func_parallel(
        run_filter,
        zip(
            repeat(cmd),
            [input.audio[:, ch] for ch in range(input.num_channels)],
            repeat(input.fs),
            range(input.num_channels),
            repeat(skip_channel),
            repeat(block_size),
        ),
        show_progress=False,
    )

            run(cmd_in_out)
    assert len(filtered) == input.audio.shape[1] - len(skip_channel)

            a, _ = read(tmp_out, nchannels=1, fs=input.fs)
            output[:, channel][:, None] = a
    for ch, filt in enumerate(filtered):
        output[:, ch][:, None] = filt

    return output