Commit 6124103f authored by Archit Tamarapu's avatar Archit Tamarapu
Browse files

[fix] try to remove inner forks and switch to multiprocessing.Pool.starmap()...

[fix] try to remove inner forks and switch to multiprocessing.Pool.starmap() instead of ProcessPoolExecutor.submit()
parent e0a531a2
Loading
Loading
Loading
Loading
+16 −16
Original line number Diff line number Diff line
@@ -32,6 +32,7 @@

import logging
from itertools import repeat
from multiprocessing import Pool

from ivas_processing_scripts.audiotools.metadata import check_ISM_metadata
from ivas_processing_scripts.constants import (
@@ -47,7 +48,7 @@ from ivas_processing_scripts.processing.processing import (
    process_condition,
    reorder_items_list,
)
from ivas_processing_scripts.utils import DirManager, apply_func_parallel
from ivas_processing_scripts.utils import DirManager


def logging_init(args, cfg):
@@ -117,7 +118,7 @@ def main(args):
                metadata_str = []
                for o in range(len(metadata[i])):
                    metadata_str.append(str(metadata[i][o]))
                logger.info(
                logger.debug(
                    f"  ISM metadata files item {cfg.items_list[i]}: {', '.join(metadata_str)}"
                )

@@ -148,8 +149,9 @@ def main(args):
            preprocess_2(cfg, logger)

        # attempt to parallelise either items or conditions based on test setup
        parallelise_items = len(cfg.items_list) > len(cfg.proc_chains)
        apply_func_parallel(
        # parallelise_items = len(cfg.items_list) > len(cfg.proc_chains)
        with Pool() as p:
            p.starmap(
                process_condition,
                zip(
                    repeat(cfg),
@@ -157,10 +159,8 @@ def main(args):
                    cfg.tmp_dirs,
                    cfg.out_dirs,
                    repeat(logger),
                repeat(parallelise_items),
                ),
            None,
            "mp" if (not parallelise_items and cfg.multiprocessing) else None,
                chunksize=8,
            )

    # copy configuration to output directory
+4 −4
Original line number Diff line number Diff line
@@ -542,8 +542,8 @@ def binaural_fftconv_framewise(
            repeat(indices_HRIR),
        ),
        None,
        "mp",
        False,
        None,
        False
    )

    y = np.stack(result, axis=1)
@@ -607,8 +607,8 @@ def render_ear(
            repeat(N_HRIR_taps),
        ),
        None,
        "mt",
        False,
        None,
        False
    )

    return np.hstack(result)
+2 −2
Original line number Diff line number Diff line
@@ -135,8 +135,8 @@ def render_oba_to_binaural(
                repeat(SourcePosition),
            ),
            None,
            "mt",
            False,
            None,
            False
        )

        # sum results over all objects
+32 −42
Original line number Diff line number Diff line
@@ -33,6 +33,7 @@
import logging
from abc import ABC, abstractmethod
from itertools import repeat
from multiprocessing import Pool
from pathlib import Path
from shutil import copyfile
from typing import Iterable, Union
@@ -49,7 +50,7 @@ from ivas_processing_scripts.audiotools.metadata import (
)
from ivas_processing_scripts.constants import LOGGER_DATEFMT, LOGGER_FORMAT
from ivas_processing_scripts.processing.config import TestConfig
from ivas_processing_scripts.utils import apply_func_parallel, list_audio, pairwise
from ivas_processing_scripts.utils import list_audio, pairwise


class Processing(ABC):
@@ -235,7 +236,8 @@ def preprocess(cfg, logger):
    logger.info(f"  Generating condition: {preprocessing['name']}")

    # run preprocessing
    apply_func_parallel(
    with Pool() as p:
        p.starmap(
            process_item,
            zip(
                cfg.items_list,
@@ -245,8 +247,7 @@ def preprocess(cfg, logger):
                repeat(logger),
                cfg.metadata_path,
            ),
        None,
        "mp" if cfg.multiprocessing else None,
            chunksize=8,
        )

    # update the configuration to use preprocessing outputs as new inputs
@@ -288,7 +289,8 @@ def preprocess_2(cfg, logger):
        concat_setup(cfg, chain, logger)

    # run preprocessing 2
    apply_func_parallel(
    with Pool() as p:
        p.starmap(
            process_item,
            zip(
                cfg.items_list,
@@ -298,8 +300,7 @@ def preprocess_2(cfg, logger):
                repeat(logger),
                cfg.metadata_path,
            ),
        None,
        "mp" if cfg.multiprocessing else None,
            chunksize=8,
        )

    # update the configuration to use preprocessing 2 outputs as new inputs
@@ -408,14 +409,14 @@ def process_condition(
    tmp_dir: Union[str, Path],
    out_dir: Union[str, Path],
    logger: logging.Logger,
    parallelise_items: bool = False,
    # parallelise_items: bool = False,
):
    chain = condition["processes"]

    logger.info(f"  Generating condition: {condition['name']}")

    if parallelise_items:
        apply_func_parallel(
    with Pool() as p:
        p.starmap(
            process_item,
            zip(
                cfg.items_list,
@@ -425,18 +426,7 @@ def process_condition(
                repeat(logger),
                cfg.metadata_path,
            ),
            None,
            "mp" if cfg.multiprocessing else None,
        )
    else:
        for item, metadata in zip(cfg.items_list, cfg.metadata_path):
            process_item(
                item,
                tmp_dir,
                out_dir,
                chain,
                logger,
                metadata,
            chunksize=8,
        )