Commit eb5c8b5c authored by Ripinder Singh's avatar Ripinder Singh
Browse files

Replace multiprocessing.Process with multiprocessing.Pool API



* Limits the active process to number of CPUs
* On linux this prevents too many open file error
* Have to revert the progress stats reporting as shared variable
  usage in pool is very different vs process due to difference
  of fork() vs fixed processes working with messages.

Signed-off-by: default avatarRipinder Singh <ripinder.singh@dolby.com>
parent dfa517be
Loading
Loading
Loading
Loading
Loading
+8 −32
Original line number Diff line number Diff line
@@ -38,7 +38,7 @@ import subprocess
import tempfile
import sys
from typing import Tuple
from multiprocessing import Process, Value
from multiprocessing import Pool
import shutil
import scipy.io.wavfile as wav
import warnings
@@ -255,8 +255,6 @@ class MLDConformance:
        self.testvecDir = args.testvecDir
        self.toolsdir = os.path.join(self.scriptsDir, "tools")
        self.testvDir = os.path.join(self.testvecDir, "testv")
        self.executedTests = Value("i", 0)
        self.failedTests = Value("i", 0)
        self.setup()

    def accumulateCommands(self):
@@ -396,22 +394,15 @@ class MLDConformance:
            )

        self.process(command=" ".join(refDecCmd))
        self.executedTests.value += 1
        self.stats()

    def runReferenceGeneration(self, encTag="ENC"):
        processes = list()  # Multiprocess list
        commands = conformance.Commands[encTag]
        self.totalTests = len(commands)
        if not self.args.no_multi_processing:
            for commandIdx, command in enumerate(commands):
                p = Process(
                    target=self.genEncoderReferences, args=(command, commandIdx, encTag)
                )
                processes.append(p)
                p.start()
            for p in processes:
                p.join()
            with Pool() as pool:
                args = [(command, commandIdx, encTag) for commandIdx, command in enumerate(commands)]
                pool.starmap(self.genEncoderReferences, args)
        else:
            for commandIdx, command in enumerate(commands):
                conformance.genEncoderReferences(command, commandIdx, encTag)
@@ -631,18 +622,13 @@ class MLDConformance:
            self.runOneIsarDecoderTest(command)
        else:
            assert False, f"Un-implemented Tag {tag}"
        self.executedTests.value += 1
        self.stats()

    def runTag(self, tag: str):
        self.executedTests.value = 0
        self.failedTests.value = 0
        # reset MLD, Sample Stats
        open(self.mldcsv[tag], "w").close()
        with open(self.sampleStats[tag], "w") as f:
            f.write(f"PYTESTTAG, MAXDIFF, RMSdB, BEFRAMES_PERCENT, MAX_MLD\n")

        processes = list()  # Multiprocess list
        commands = list()
        if self.filter:
            for command in self.Commands[tag]:
@@ -656,15 +642,9 @@ class MLDConformance:
            f"Executing tests for {tag}  {'Filter='+self.filter if self.filter else ''} ({self.totalTests} tests)"
        )
        if not self.args.no_multi_processing:
            for command in commands:
                p = Process(
                    target=self.runOneCommand,
                    args=(tag, command),
                )
                processes.append(p)
                p.start()
            for p in processes:
                p.join()
            with Pool() as pool:
                args = [(tag, command) for command in commands]
                pool.starmap(self.runOneCommand, args)
        else:
            for command in commands:
                self.runOneCommand(tag, command)
@@ -683,15 +663,11 @@ class MLDConformance:
                if c.returncode:
                    with open(self.failedCmdsFile, "a") as f:
                        f.write(command + "\n")
                    self.failedTests.value += 1
                    # c.check_returncode()
        return 0

    def stats(self):
        print(
            f"Executed: {self.executedTests.value} / {self.totalTests} Failed: {self.failedTests.value}",
            end="\r",
        )
        pass

    def getSampleStats(self, refSamples: np.ndarray, dutSamples: np.ndarray):
        nSamples = min(refSamples.shape[0], dutSamples.shape[0])