Commit fb531a80 authored by Archit Tamarapu's avatar Archit Tamarapu
Browse files

[fix] add a progressbar and dispatch all jobs to the same pool

parent 2548b402
Loading
Loading
Loading
Loading
+3 −8
Original line number Diff line number Diff line
@@ -31,11 +31,11 @@
#

import argparse
from multiprocessing import Pool
from pathlib import Path

from ivas_processing_scripts import config
from ivas_processing_scripts import main as generate_test
from ivas_processing_scripts.utils import apply_func_parallel

HERE = Path(__file__).parent.absolute()
EXPERIMENTS_P800 = [f"P800-{i}" for i in range(1, 10)]
@@ -64,12 +64,7 @@ def generate_tests(exp_lab_pairs, run_parallel=True, create_cfg_only=False):
        return

    args = [Arguments(str(cfg)) for cfg in cfgs]
    if run_parallel:
        with Pool() as p:
            p.starmap(generate_test, zip(args), chunksize=8)
    else:
        map(generate_test, args)
    # apply_func_parallel(generate_test, zip(args), type="mp" if run_parallel else None)
    apply_func_parallel(generate_test, zip(args), None)


class Arguments:
@@ -77,7 +72,7 @@ class Arguments:
        self.config = config
        self.debug = False
        # used to overwrite the multiprocessing key in the configs to rather parallelize on category level
        self.multiprocessing = False
        self.multiprocessing = True


def create_experiment_setup(experiment, lab) -> list[Path]:
+40 −16
Original line number Diff line number Diff line
@@ -31,8 +31,9 @@
#

import logging
from itertools import repeat
from itertools import product
from multiprocessing import Pool
from time import sleep

from ivas_processing_scripts.audiotools.metadata import check_ISM_metadata
from ivas_processing_scripts.constants import (
@@ -45,10 +46,14 @@ from ivas_processing_scripts.processing.processing import (
    preprocess,
    preprocess_2,
    preprocess_background_noise,
    process_condition,
    process_item,
    reorder_items_list,
)
from ivas_processing_scripts.utils import DirManager
from ivas_processing_scripts.utils import (
    DirManager,
    apply_func_parallel,
    progressbar_update,
)


def logging_init(args, cfg):
@@ -148,20 +153,39 @@ def main(args):
            # preprocess 2
            preprocess_2(cfg, logger)

        # attempt to parallelise either items or conditions based on test setup
        # parallelise_items = len(cfg.items_list) > len(cfg.proc_chains)
        with Pool() as p:
            p.starmap(
                process_condition,
                zip(
                    repeat(cfg),
                    cfg.proc_chains,
                    cfg.tmp_dirs,
                    cfg.out_dirs,
                    repeat(logger),
                ),
                chunksize=8,
        # assemble a list of all item and condition combinations
        item_args = list()
        for (chain, tmp_dir, out_dir), (item, metadata) in product(
            zip(cfg.proc_chains, cfg.tmp_dirs, cfg.out_dirs),
            zip(cfg.items_list, cfg.metadata_path),
        ):
            item_args.append(
                (item, tmp_dir, out_dir, chain["processes"], logger, metadata)
            )

        if cfg.multiprocessing:
            p = Pool()
            chunksize = 8
            results = p.starmap_async(
                process_item,
                item_args,
                chunksize,
            )
            width = 80
            count = len(item_args)

            progressbar_update(0, count, width)
            while not results.ready():
                progressbar_update(
                    count - (results._number_left * chunksize), count, width
                )
                sleep(0.1)

            p.close()
            p.join()

        else:
            apply_func_parallel(process_item, item_args, None, None, True)

    # copy configuration to output directory
    cfg.to_file(cfg.output_path.joinpath(f"{cfg.name}.yml"))
+2 −2
Original line number Diff line number Diff line
@@ -543,7 +543,7 @@ def binaural_fftconv_framewise(
        ),
        None,
        None,
        False
        False,
    )

    y = np.stack(result, axis=1)
@@ -608,7 +608,7 @@ def render_ear(
        ),
        None,
        None,
        False
        False,
    )

    return np.hstack(result)
+1 −1
Original line number Diff line number Diff line
@@ -136,7 +136,7 @@ def render_oba_to_binaural(
            ),
            None,
            None,
            False
            False,
        )

        # sum results over all objects
+12 −11
Original line number Diff line number Diff line
@@ -268,11 +268,7 @@ def pairwise(iter):
    return zip(a, b)


def progressbar(iter: Iterable, width=80):
    """simple unicode progressbar"""
    count = len(iter)

    def update(progress):
def progressbar_update(progress, count, width):
    fill = int(width * progress / count)
    print(
        f"{int(progress/count*100):3d}%{'|'}{'='*fill}{(' '*(width-fill))}{'|'}{progress}/{count}",
@@ -281,10 +277,15 @@ def progressbar(iter: Iterable, width=80):
        flush=True,
    )

    update(0)

def progressbar(iter: Iterable, width=80):
    """simple unicode progressbar"""
    count = len(iter)

    progressbar_update(0, count, width)
    for i, item in enumerate(iter):
        yield item
        update(i + 1)
        progressbar_update(i + 1, count, width)
    print("\n", flush=True, file=sys.stdout)