Commit 27befb67 authored by Jan Kiene's avatar Jan Kiene
Browse files

switch to thread-based parallelism

the previous process-based implementation can crash due to no pooling
parent d6c508d9
Loading
Loading
Loading
Loading
Loading
+18 −29
Original line number Diff line number Diff line
@@ -29,6 +29,7 @@ submitted to and settled by the final, binding jurisdiction of the courts of Mun
accordance with the laws of the Federal Republic of Germany excluding its conflict of law rules and
the United Nations Convention on Contracts on the International Sales of Goods.
"""

import argparse
import os
import platform
@@ -37,8 +38,9 @@ import numpy as np
import subprocess
import tempfile
import sys
from typing import Optional
from multiprocessing import Process, Value
from multiprocessing import Value
from concurrent.futures import ThreadPoolExecutor
from itertools import repeat
import shutil

sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), ".."))
@@ -178,9 +180,9 @@ class MLDConformance:
        print(
            f"Mapped decoder tests for {len(self.EncoderToDecoderCmdMap)} encoder tests out of {len(self.Commands['ENC'])} tests"
        )
        assert len(self.EncoderToDecoderCmdMap) == len(
            self.Commands["ENC"]
        ), "Failed to Map Encoder Commands to Decoder Commands"
        assert len(self.EncoderToDecoderCmdMap) == len(self.Commands["ENC"]), (
            "Failed to Map Encoder Commands to Decoder Commands"
        )

    def genEncoderReferences(self, command: str, encCommandIdx: int):
        # RUN ENCODER COMMAND LINE WITH REFERENCE ENCODER
@@ -209,18 +211,12 @@ class MLDConformance:
        self.stats()

    def runReferenceGeneration(self):
        processes = list()  # Multiprocess list
        commands = conformance.Commands["ENC"]
        self.totalTests = len(commands)
        if not self.args.no_multi_processing:
            for commandIdx, command in enumerate(commands):
                p = Process(
                    target=self.genEncoderReferences, args=(command, commandIdx)
                )
                processes.append(p)
                p.start()
            for p in processes:
                p.join()
            command_ids = range(len(commands))
            with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
                list(executor.map(self.genEncoderReferences, commands, command_ids))
        else:
            for commandIdx, command in enumerate(commands):
                conformance.genEncoderReferences(command, commandIdx)
@@ -415,15 +411,8 @@ class MLDConformance:
            f"Executing tests for {tag}  {'Filter=' + self.filter if self.filter else ''} ({self.totalTests} tests)"
        )
        if not self.args.no_multi_processing:
            for command in commands:
                p = Process(
                    target=self.runOneCommand,
                    args=(tag, command),
                )
                processes.append(p)
                p.start()
            for p in processes:
                p.join()
            with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor:
                list(executor.map(self.runOneCommand, repeat(tag), commands))
        else:
            for command in commands:
                self.runOneCommand(tag, command)
@@ -470,9 +459,9 @@ class MLDConformance:
        with tempfile.TemporaryDirectory() as tmpdir:
            refSamples, fsR = readfile(refFile, outdtype="float")
            dutSamples, fsD = readfile(dutFile, outdtype="float")
            assert (
                refSamples.shape[1] == dutSamples.shape[1]
            ), "No of channels mismatch if ref vs cut"
            assert refSamples.shape[1] == dutSamples.shape[1], (
                "No of channels mismatch if ref vs cut"
            )
            maxDiff, rmsdB, beSamplesPercent = self.getSampleStats(
                refSamples, dutSamples
            )