Commit 24b215da authored by sagnowski's avatar sagnowski
Browse files

Extend isar_bstool.py with a command to patch file header (for testing purposes)

Type hints were also improved.
parent 44372bba
Loading
Loading
Loading
Loading
+172 −81
Original line number Diff line number Diff line
@@ -31,23 +31,28 @@
"""

from __future__ import annotations

import argparse
import io
import math
import sys
from collections.abc import Callable
from pathlib import Path
from typing import Protocol, cast, final


class IsarBstoolError(Exception):
    pass


@final
class IsarBitstream:
    def __init__(self, file_path: Path) -> None:
        self.file_path = file_path

        with open(file_path, "rb") as reader:
            self.header = IsarFileHeader(reader)
            self.frames = []
            self.frames: list[IsarFileFrame] = []

            while reader.peek(1):
                self.frames.append(IsarFileFrame(reader))
@@ -138,10 +143,10 @@ class IsarBitstream:
        self.file_path = file_path

        with open(file_path, "wb") as writer:
            writer.write(self.header.as_bytes)
            self.header.write(writer)

            for frame in self.frames:
                writer.write(frame.as_bytes)
                frame.write(writer)

    def trim(self, start_time_s: float, length_s: float | None = None) -> IsarBitstream:
        if length_s is None:
@@ -182,58 +187,64 @@ class IsarBitstream:
        return self.header == other.header and self.frames == other.frames


class _AsBytes:
    def __init__(self) -> None:
        self.as_bytes = bytearray()

    def _read(self, reader, num_bytes):
        bytes_ = reader.read(num_bytes)
        self.as_bytes.extend(bytes_)
        return bytes_

    def __eq__(self, value: object, /) -> bool:
        if not isinstance(value, _AsBytes):
            return False
        return self.as_bytes == value.as_bytes

@final
class IsarFileHeader:
    FILE_HEADER = b"MAIN_SPLITH"

class IsarFileHeader(_AsBytes):
    def __init__(self, reader) -> None:
    def __init__(self, reader: io.BufferedReader) -> None:
        super().__init__()

        FILE_HEADER = b"MAIN_SPLITH"
        file_header_top = self._read(reader, len(FILE_HEADER))
        if file_header_top != FILE_HEADER:
        file_header_top = reader.read(len(self.FILE_HEADER))
        if file_header_top != self.FILE_HEADER:
            raise IsarBstoolError(f"Not a valid ISAR file: {reader.name}")

        self.delay_ns = _int_from_bytes(self._read(reader, 4))
        self.codec = _codec_from_bytes(self._read(reader, 4))
        self.pose_correction = _pose_corr_from_bytes(self._read(reader, 4))
        self.codec_frame_size_ms = _int_from_bytes(self._read(reader, 2))
        self.isar_frame_size_ms = _int_from_bytes(self._read(reader, 2))
        self.sample_rate = _int_from_bytes(self._read(reader, 4))
        self.lc3plus_hires = bool(_int_from_bytes(self._read(reader, 2)))

        self.delay_ns = _int_from_bytes(reader.read(4))
        self.codec = _codec_from_bytes(reader.read(4))
        self.pose_correction = _pose_corr_from_bytes(reader.read(4))
        self.codec_frame_size_ms = _int_from_bytes(reader.read(2))
        self.isar_frame_size_ms = _int_from_bytes(reader.read(2))
        self.sample_rate = _int_from_bytes(reader.read(4))
        self.lc3plus_hires = bool(_int_from_bytes(reader.read(2)))

    def write(self, writer: io.BufferedWriter) -> None:
        _write_exact(writer, self.FILE_HEADER)
        _write_exact(writer, _int_to_bytes(self.delay_ns, 4))
        _write_exact(writer, _codec_to_bytes(self.codec))
        _write_exact(writer, _pose_corr_to_bytes(self.pose_correction))
        _write_exact(writer, _int_to_bytes(self.codec_frame_size_ms, 2))
        _write_exact(writer, _int_to_bytes(self.isar_frame_size_ms, 2))
        _write_exact(writer, _int_to_bytes(self.sample_rate, 4))
        _write_exact(writer, _int_to_bytes(int(self.lc3plus_hires), 2))


@final
class IsarFileFrame:
    FRAME_HEADER = b"SPLIT_FRAME"
    VERSION = 0

class IsarFileFrame(_AsBytes):
    def __init__(self, reader) -> None:
    def __init__(self, reader: io.BufferedReader) -> None:
        super().__init__()

        FRAME_HEADER = b"SPLIT_FRAME"
        frame_header = self._read(reader, len(FRAME_HEADER))
        if frame_header != FRAME_HEADER:
        frame_header = reader.read(len(self.FRAME_HEADER))
        if frame_header != self.FRAME_HEADER:
            raise IsarBstoolError(f"Not a valid ISAR file: {reader.name}")

        version = _int_from_bytes(self._read(reader, 1))
        if version != 0:
        version = _int_from_bytes(reader.read(1))
        if version != self.VERSION:
            raise IsarBstoolError(
                f"Unupported version of ISAR file format: {reader.name}"
            )

        self.num_bits = _int_from_bytes(self._read(reader, 4))
        self.num_bits = _int_from_bytes(reader.read(4))

        payload_size = math.ceil(self.num_bits / 8)
        self.payload = self._read(reader, payload_size)
        self.payload = reader.read(payload_size)

    def write(self, writer: io.BufferedWriter) -> None:
        _write_exact(writer, self.FRAME_HEADER)
        _write_exact(writer, _int_to_bytes(self.VERSION, 1))
        _write_exact(writer, _int_to_bytes(self.num_bits, 4))
        _write_exact(writer, self.payload)


######################################################################################
@@ -241,11 +252,24 @@ class IsarFileFrame(_AsBytes):
######################################################################################


def _int_from_bytes(bytes_):
def _write_exact(writer: io.BufferedWriter, data: bytes) -> None:
    num_written = writer.write(data)
    if num_written != len(data):
        file_name = getattr(writer, "name", "<stream>")
        raise IsarBstoolError(
            f"Failed to write to {file_name}: wrote {num_written} of {len(data)} bytes"
        )


def _int_from_bytes(bytes_: bytes) -> int:
    return int.from_bytes(bytes_, byteorder="little")


def _codec_from_bytes(bytes_):
def _int_to_bytes(x: int, num_bytes: int) -> bytes:
    return x.to_bytes(num_bytes, byteorder="little")


def _codec_from_bytes(bytes_: bytes) -> str:
    # Refer to ISAR_SPLIT_REND_CODEC enum in C code
    CODECS = ["LCLD", "LC3PLUS", "DEFAULT", "NONE"]
    x = _int_from_bytes(bytes_)
@@ -256,7 +280,13 @@ def _codec_from_bytes(bytes_):
    return "UNKNOWN"


def _pose_corr_from_bytes(bytes_):
def _codec_to_bytes(codec: str) -> bytes:
    # Refer to ISAR_SPLIT_REND_CODEC enum in C code
    CODECS = {"LCLD": 0, "LC3PLUS": 1, "DEFAULT": 2, "NONE": 3}
    return CODECS.get(codec, 255).to_bytes(4, byteorder="little")


def _pose_corr_from_bytes(bytes_: bytes) -> str:
    # Refer to ISAR_SPLIT_REND_POSE_CORRECTION_MODE enum in C code
    POSE_CORR_MODES = ["NONE", "CLDFB"]
    x = _int_from_bytes(bytes_)
@@ -267,12 +297,23 @@ def _pose_corr_from_bytes(bytes_):
    return "UNKNOWN"


def _pose_corr_to_bytes(pose_corr: str) -> bytes:
    # Refer to ISAR_SPLIT_REND_POSE_CORRECTION_MODE enum in C code
    POSE_CORR_MODES = {"NONE": 0, "CLDFB": 1}
    return POSE_CORR_MODES.get(pose_corr, 255).to_bytes(4, byteorder="little")


######################################################################################
#                             subcommand functions
######################################################################################


def _subcmd_info(args):
class _InfoArgs(Protocol):
    file_in: Path
    only: str | None


def _subcmd_info(args: _InfoArgs) -> None:
    bs = IsarBitstream(args.file_in)

    match args.only:
@@ -314,9 +355,37 @@ def _subcmd_info(args):
            raise IsarBstoolError(f"Not a valid parameter value: '{args.only}'")


def _subcmd_trim(args):
class _TrimArgs(Protocol):
    file_in: Path
    file_out: Path
    start_time: str
    length: str | None


def _subcmd_trim(args: _TrimArgs) -> None:
    bs = IsarBitstream(args.file_in)
    bs.trim(float(args.start_time), float(args.length) if args.length else None)
    _ = bs.trim(float(args.start_time), float(args.length) if args.length else None)
    bs.write(args.file_out)


class _PatchCodecFrameSizeArgs(Protocol):
    file_in: Path
    file_out: Path
    codec_frame_size_ms: str


def _subcmd_patch_codec_frame_size(args: _PatchCodecFrameSizeArgs) -> None:
    bs = IsarBitstream(args.file_in)
    codec_frame_size_ms = int(args.codec_frame_size_ms)

    match codec_frame_size_ms:
        case 5 | 10 | 20:
            bs.header.codec_frame_size_ms = codec_frame_size_ms
        case _:
            raise IsarBstoolError(
                f"Invalid codec frame size (in ms): {args.codec_frame_size_ms}. Valid values are 5 and 10."
            )

    bs.write(args.file_out)


@@ -324,17 +393,18 @@ def _subcmd_trim(args):
#                                     main
######################################################################################


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        prog="isar_bstool",
        description="Utility for inspecting and modifying ISAR bitstreams",
    )
    parser.set_defaults(func=lambda _: parser.print_help())
    parser.set_defaults(func=lambda: parser.print_help())
    subparsers = parser.add_subparsers(title="Commands")

    info = subparsers.add_parser("info", help="Print information about a bitstream")
    info.add_argument("file_in", help="Path to input file")
    info.add_argument(
    _ = info.add_argument("file_in", type=Path, help="Path to input file")
    _ = info.add_argument(
        "--only",
        help="Print only a specific parameter",
        default=None,
@@ -362,23 +432,44 @@ if __name__ == "__main__":
    trim = subparsers.add_parser(
        "trim", help="Remove initial frames from a bitstream file"
    )
    trim.add_argument("file_in", help="Path to input file")
    trim.add_argument("file_out", help="Path to output file")
    trim.add_argument(
    _ = trim.add_argument("file_in", type=Path, help="Path to input file")
    _ = trim.add_argument("file_out", type=Path, help="Path to output file")
    _ = trim.add_argument(
        "start_time",
        help="Start point (in s) from which content should be copied to the output.",
    )
    trim.add_argument(
    _ = trim.add_argument(
        "--length",
        help="Amount of time (in s) to copy to the output. If not given, content is copied until the end of the input is reached.",
        default=None,
    )
    trim.set_defaults(func=_subcmd_trim)

    patch_codec_frame_size = subparsers.add_parser(
        "patch_codec_frame_size",
        help="Overwrite the codec frame size field in the header of an ISAR bitstream file",
        description=(
            "Note that this doesn't modify the actual frames in the file. The purpose of this command "
            "is to inject incorrect frame size info into the header for testing LC3plus reconfiguration handling."
        ),
    )
    _ = patch_codec_frame_size.add_argument(
        "file_in", type=Path, help="Path to input file"
    )
    _ = patch_codec_frame_size.add_argument(
        "file_out", type=Path, help="Path to output file"
    )
    _ = patch_codec_frame_size.add_argument(
        "codec_frame_size_ms",
        help="Codec frame size (in ms) to write into the header of the output file",
    )
    patch_codec_frame_size.set_defaults(func=_subcmd_patch_codec_frame_size)

    args = parser.parse_args()
    func = cast(Callable[[argparse.Namespace], None], args.func)

    try:
        args.func(args)
        func(args)
    except (FileNotFoundError, PermissionError, IsarBstoolError) as e:
        print(e, file=sys.stderr)
        sys.exit(1)