diff --git a/scripts/ivas_conformance/runConformance.py b/scripts/ivas_conformance/runConformance.py index 15f9fbcd534704c4202dd7660a7deb3010488883..b9a246749ecd6264756b8ef1401a490cb588ab6c 100644 --- a/scripts/ivas_conformance/runConformance.py +++ b/scripts/ivas_conformance/runConformance.py @@ -29,6 +29,7 @@ submitted to and settled by the final, binding jurisdiction of the courts of Mun accordance with the laws of the Federal Republic of Germany excluding its conflict of law rules and the United Nations Convention on Contracts on the International Sales of Goods. """ + import argparse import os import platform @@ -36,9 +37,10 @@ import re import numpy as np import subprocess import tempfile +import threading import sys -from typing import Optional -from multiprocessing import Process, Value +from concurrent.futures import ThreadPoolExecutor +from itertools import repeat import shutil sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), "..")) @@ -113,10 +115,13 @@ class MLDConformance: self.testvecDir = args.testvecDir self.toolsdir = os.path.join(self.scriptsDir, "tools") self.testvDir = os.path.join(self.testvecDir, "testv") - self.executedTests = Value("i", 0) - self.failedTests = Value("i", 0) + self.executedTests = 0 + self.failedTests = 0 self.setup() + # synchronize all writes to state variables + self.lock = threading.Lock() + def accumulateCommands(self): for root, _, files in os.walk(self.testvecDir): for file_name in files: @@ -178,9 +183,9 @@ class MLDConformance: print( f"Mapped decoder tests for {len(self.EncoderToDecoderCmdMap)} encoder tests out of {len(self.Commands['ENC'])} tests" ) - assert len(self.EncoderToDecoderCmdMap) == len( - self.Commands["ENC"] - ), "Failed to Map Encoder Commands to Decoder Commands" + assert len(self.EncoderToDecoderCmdMap) == len(self.Commands["ENC"]), ( + "Failed to Map Encoder Commands to Decoder Commands" + ) def genEncoderReferences(self, command: str, encCommandIdx: int): # RUN ENCODER COMMAND LINE WITH REFERENCE ENCODER @@ -205,22 +210,17 @@ class MLDConformance: + [refEncOutput, refDecOutputFile] ) self.process(command=" ".join(refDecCmd)) - self.executedTests.value += 1 + with self.lock: + self.executedTests += 1 self.stats() def runReferenceGeneration(self): - processes = list() # Multiprocess list commands = conformance.Commands["ENC"] self.totalTests = len(commands) if not self.args.no_multi_processing: - for commandIdx, command in enumerate(commands): - p = Process( - target=self.genEncoderReferences, args=(command, commandIdx) - ) - processes.append(p) - p.start() - for p in processes: - p.join() + command_ids = range(len(commands)) + with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor: + list(executor.map(self.genEncoderReferences, commands, command_ids)) else: for commandIdx, command in enumerate(commands): conformance.genEncoderReferences(command, commandIdx) @@ -356,9 +356,9 @@ class MLDConformance: "$CUT_PATH/ref/sba_bs/", f"{self.testvDir}/ref/sba_bs/" ) else: - #command = command.replace( + # command = command.replace( # "$CUT_PATH/dut/sba_bs/pkt/", f"{self.outputDir}/dut/enc/" - #) + # ) command = command.replace( "$CUT_PATH/ref/param_file/enc/", f"{self.outputDir}/dut/enc/" ) @@ -390,7 +390,8 @@ class MLDConformance: self.runOneRendererTest(tag, command) else: assert False, f"Un-implemented Tag {tag}" - self.executedTests.value += 1 + with self.lock: + self.executedTests += 1 self.stats() def runTag(self, tag: str): @@ -401,7 +402,6 @@ class MLDConformance: with open(self.sampleStats[tag], "w") as f: f.write(f"PYTESTTAG, MAXDIFF, RMSdB, BEFRAMES_PERCENT, MAX_MLD\n") - processes = list() # Multiprocess list commands = list() if self.filter: for command in self.Commands[tag]: @@ -412,18 +412,11 @@ class MLDConformance: self.totalTests = len(commands) print( - f"Executing tests for {tag} {'Filter='+self.filter if self.filter else ''} ({self.totalTests} tests)" + f"Executing tests for {tag} {'Filter=' + self.filter if self.filter else ''} ({self.totalTests} tests)" ) if not self.args.no_multi_processing: - for command in commands: - p = Process( - target=self.runOneCommand, - args=(tag, command), - ) - processes.append(p) - p.start() - for p in processes: - p.join() + with ThreadPoolExecutor(max_workers=os.cpu_count()) as executor: + list(executor.map(self.runOneCommand, repeat(tag), commands)) else: for command in commands: self.runOneCommand(tag, command) @@ -442,13 +435,14 @@ class MLDConformance: if c.returncode: with open(self.failedCmdsFile, "a") as f: f.write(command + "\n") - self.failedTests.value += 1 + with self.lock: + self.failedTests += 1 # c.check_returncode() return 0 def stats(self): print( - f"Executed: {self.executedTests.value} / {self.totalTests} Failed: {self.failedTests.value}", + f"Executed: {self.executedTests} / {self.totalTests} Failed: {self.failedTests}", end="\r", ) @@ -470,9 +464,9 @@ class MLDConformance: with tempfile.TemporaryDirectory() as tmpdir: refSamples, fsR = readfile(refFile, outdtype="float") dutSamples, fsD = readfile(dutFile, outdtype="float") - assert ( - refSamples.shape[1] == dutSamples.shape[1] - ), "No of channels mismatch if ref vs cut" + assert refSamples.shape[1] == dutSamples.shape[1], ( + "No of channels mismatch if ref vs cut" + ) maxDiff, rmsdB, beSamplesPercent = self.getSampleStats( refSamples, dutSamples ) @@ -616,15 +610,21 @@ if __name__ == "__main__": if args.regenerate_enc_refs: conformance.runReferenceGeneration() - sys.exit(0) - - testTags = ( - MLDConformance.IVAS_Bins.keys() if args.test_mode == "ALL" else [args.test_mode] - ) - for tag in testTags: - if tag == "ISAR": - # Not implemented yet - continue - if not args.analyse_only: - conformance.runTag(tag) - conformance.doAnalysis(selectTag=tag) + else: + testTags = ( + MLDConformance.IVAS_Bins.keys() if args.test_mode == "ALL" else [args.test_mode] + ) + for tag in testTags: + if tag == "ISAR": + # Not implemented yet + continue + if not args.analyse_only: + conformance.runTag(tag) + conformance.doAnalysis(selectTag=tag) + + # final \n makes sure that the output from stat() is readable after the command terminates + print() + + if conformance.failedTests != 0: + print(f"Error: {conformance.failedTests} tests failed") + sys.exit(1)