Commit 66b1c883 authored by Archit Tamarapu's avatar Archit Tamarapu
Browse files

add a new generic function for multiprocessing and cleanup audio3dtools.py

parent 34c6b03f
Loading
Loading
Loading
Loading
+109 −109
Original line number Diff line number Diff line
@@ -46,7 +46,114 @@ logger = main_logger.getChild(__name__)
logger.setLevel(logging.DEBUG)


def main():
def main(args):
    # Set up logging handlers
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(logging.Formatter("%(message)s"))

    # Configure loggers
    LOGGER_FORMAT = "%(asctime)s | %(name)-12s | %(levelname)-8s | %(message)s"
    LOGGER_DATEFMT = "%m-%d %H:%M"
    logging.basicConfig(
        format=LOGGER_FORMAT,
        datefmt=LOGGER_DATEFMT,
        level=logging.INFO,
        handlers=[console_handler],
    )
    logger.info("Audio3DTools")

    if args.list is True or args.long is True:
        logger.info("===Supported spatial audio formats===")
        spatialaudioformat.Format.list_all(args.long)

    elif args.infiles is not None:
        logger.info("===Convert spatial audio file===")
        # Input folder can be a path, a file or a list of files
        if os.path.isdir(args.infiles):
            path = args.infiles
            audio_list = [
                os.path.join(path, f) for f in os.listdir(path) if f.endswith((".wav"))
            ]
        else:
            audio_list = [args.infiles]

        outdir = args.outdir
        _, output_ext = os.path.splitext(os.path.basename(outdir))
        if (len(audio_list) == 1) and (
            (output_ext.lower() == ".wav") or (output_ext.lower() == ".pcm")
        ):
            outfile = outdir
        else:
            outfile = None
            if not os.path.exists(outdir):
                os.makedirs(outdir)

        for infile in audio_list:
            logger.info(f"  process {infile}")

            _, input_ext = os.path.splitext(os.path.basename(infile))

            if outfile is None:
                outfile = os.path.basename(infile)
                if not args.dont_rename:
                    if args.outformat is not None:
                        outfile = outfile.replace(input_ext, f"_{args.outformat}.wav")
                    else:
                        outfile = outfile.replace(input_ext, ".out.wav")
                outfile = os.path.join(outdir, outfile)

            spatialaudioconvert.spatial_audio_convert(
                infile,
                outfile,
                in_format=args.informat,
                in_fs=args.infs,
                in_nchans=args.inchan,
                in_meta_files=args.metadata,
                in_ls_layout_file=args.layoutfile,
                out_format=args.outformat,
                out_fs=args.outfs,
                out_fc=args.outfc,
                output_loudness=args.normalize,
                loudness_tool=args.loudness_tool,
                trajectory=args.trajectory,
                binaural_dataset=args.binaural_dataset,
            )

            logger.info(f"  Output {outfile}")

            if args.binaural:
                if args.outformat.startswith("BINAURAL"):
                    raise SystemExit(
                        "BINAURAL output format can not be binauralized again!"
                    )

                _, output_ext = os.path.splitext(os.path.basename(outfile))
                outfile_bin = outfile.replace(output_ext, "_BINAURAL.wav")
                logger.info(f"  Output binaural {outfile_bin}")

                spatialaudioconvert.spatial_audio_convert(
                    in_file=outfile,
                    out_file=outfile_bin,
                    in_format=args.outformat,
                    in_fs=args.outfs,
                    in_meta_files=args.metadata,
                    in_ls_layout_file=args.layoutfile,
                    out_format="BINAURAL",
                    output_loudness=args.normalize,
                    loudness_tool=args.loudness_tool,
                    trajectory=args.trajectory,
                    binaural_dataset=args.binaural_dataset,
                )

            outfile = None
    else:
        raise Exception(
            "Input file must be provided for conversion and audio manipulation."
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Audio3DTools: Convert/Manipulate spatial audio files."
    )
@@ -186,111 +293,4 @@ def main():
    )
    args = parser.parse_args()

    # Set up logging handlers
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(logging.Formatter("%(message)s"))

    # Configure loggers
    LOGGER_FORMAT = "%(asctime)s | %(name)-12s | %(levelname)-8s | %(message)s"
    LOGGER_DATEFMT = "%m-%d %H:%M"
    logging.basicConfig(
        format=LOGGER_FORMAT,
        datefmt=LOGGER_DATEFMT,
        level=logging.INFO,
        handlers=[console_handler],
    )
    logger.info("Audio3DTools")

    if args.list is True or args.long is True:
        logger.info("===Supported spatial audio formats===")
        spatialaudioformat.Format.list_all(args.long)

    elif args.infiles is not None:
        logger.info("===Convert spatial audio file===")
        # Input folder can be a path, a file or a list of files
        if os.path.isdir(args.infiles):
            path = args.infiles
            audio_list = [
                os.path.join(path, f) for f in os.listdir(path) if f.endswith((".wav"))
            ]
        else:
            audio_list = [args.infiles]

        outdir = args.outdir
        _, output_ext = os.path.splitext(os.path.basename(outdir))
        if (len(audio_list) == 1) and (
            (output_ext.lower() == ".wav") or (output_ext.lower() == ".pcm")
        ):
            outfile = outdir
        else:
            outfile = None
            if not os.path.exists(outdir):
                os.makedirs(outdir)

        for infile in audio_list:
            logger.info(f"  process {infile}")

            _, input_ext = os.path.splitext(os.path.basename(infile))

            if outfile is None:
                outfile = os.path.basename(infile)
                if not args.dont_rename:
                    if args.outformat is not None:
                        outfile = outfile.replace(input_ext, f"_{args.outformat}.wav")
                    else:
                        outfile = outfile.replace(input_ext, ".out.wav")
                outfile = os.path.join(outdir, outfile)

            spatialaudioconvert.spatial_audio_convert(
                infile,
                outfile,
                in_format=args.informat,
                in_fs=args.infs,
                in_nchans=args.inchan,
                in_meta_files=args.metadata,
                in_ls_layout_file=args.layoutfile,
                out_format=args.outformat,
                out_fs=args.outfs,
                out_fc=args.outfc,
                output_loudness=args.normalize,
                loudness_tool=args.loudness_tool,
                trajectory=args.trajectory,
                binaural_dataset=args.binaural_dataset,
            )

            logger.info(f"  Output {outfile}")

            if args.binaural:
                if args.outformat.startswith("BINAURAL"):
                    raise SystemExit(
                        "BINAURAL output format can not be binauralized again!"
                    )

                _, output_ext = os.path.splitext(os.path.basename(outfile))
                outfile_bin = outfile.replace(output_ext, "_BINAURAL.wav")
                logger.info(f"  Output binaural {outfile_bin}")

                spatialaudioconvert.spatial_audio_convert(
                    in_file=outfile,
                    out_file=outfile_bin,
                    in_format=args.outformat,
                    in_fs=args.outfs,
                    in_meta_files=args.metadata,
                    in_ls_layout_file=args.layoutfile,
                    out_format="BINAURAL",
                    output_loudness=args.normalize,
                    loudness_tool=args.loudness_tool,
                    trajectory=args.trajectory,
                    binaural_dataset=args.binaural_dataset,
                )

            outfile = None
    else:
        raise Exception(
            "Input file must be provided for conversion and audio manipulation."
        )


if __name__ == "__main__":
    main()
    main(args)
+16 −1
Original line number Diff line number Diff line
@@ -32,9 +32,10 @@

import logging
import math
from typing import Optional, Tuple
from typing import Callable, Iterable, Optional, Tuple

import numpy as np
import multiprocessing as mp
import scipy.signal as sig

main_logger = logging.getLogger("__main__")
@@ -430,3 +431,17 @@ def get_framewise(x: np.ndarray, chunk_size: int) -> np.ndarray:
        yield x[i * chunk_size : (i + 1) * chunk_size, :]
    if x.shape[0] % chunk_size:
        yield x[n_frames * chunk_size :, :]


def process_async(files: Iterable, func: Callable, **kwargs):
    """Applies a function asynchronously to an array of audio files/filenames using a multiprocessing pool"""

    p = mp.pool(mp.cpu_count())
    results = []
    for f in files:
        results.append(p.apply_async(func, args=(f, kwargs)))
    p.close()
    p.join()
    for r in results:
        r.get()
    return results