Commit 1835ea32 authored by Anika Treffehn's avatar Anika Treffehn
Browse files

added from file functions to OMASA class

parent 1b61b072
Loading
Loading
Loading
Loading
+87 −88
Original line number Diff line number Diff line
@@ -380,7 +380,6 @@ class SceneBasedAudio(Audio):

class OMASAAudio(Audio):
    """Sub-class for combined OMASA format"""
    # TODO treffehn: write class
    def __init__(self, name: str):
        super().__init__(name)
        try:
@@ -390,92 +389,92 @@ class OMASAAudio(Audio):
        self.object_pos = []
        self.metadata_files = []  # first ISM metadata followed by masa metadata

    # @classmethod
    # def _from_file(
    #     cls,
    #     name: str,
    #     filename: Union[str, Path],
    #     metadata_files_ism: list[Union[str, Path]],
    #     fs: Optional[int] = None,
    # ) -> "ObjectBasedAudio":
    #     obj = super()._from_file(name, filename, fs)
    #     if metadata_files is not None:
    #         obj.metadata_files = [Path(f) for f in metadata_files]
    #     else:
    #         # search for metadata with naming scheme: name.(wav, pcm).(0-3).csv
    #         for obj_idx in range(obj.num_channels):
    #             file_name_meta = filename.with_suffix(
    #                 f"{filename.suffix}.{obj_idx}.csv"
    #             )
    #             if file_name_meta.is_file():
    #                 obj.metadata_files.append(file_name_meta)
    #             else:
    #                 raise ValueError(f"Metadata file {file_name_meta} not found.")
    #         warn(
    #             f"No metadata files specified: The following files were found and used: \n {*obj.metadata_files,}"
    #         )
    #
    #     obj.init_metadata()
    #     return obj
    #
    # @classmethod
    # def _from_filelist(
    #     cls,
    #     name: str,
    #     filename: Path,
    #     metadata_files: list[Union[str, Path]],
    #     fs: Optional[int] = None,
    # ) -> "ObjectBasedAudio":
    #     obj = super()._from_filelist(name, filename, fs)
    #     obj.metadata_files = [Path(f) for f in metadata_files]
    #     obj.init_metadata()
    #     return obj
    #
    # def init_metadata(self):
    #     # check if number of metadata files matches format
    #     if self.audio.shape[1] != len(self.metadata_files):
    #         raise ValueError(
    #             f"Mismatch between number of channels in file [{self.audio.shape[1]}], and metadata [{len(self.metadata_files)}]"
    #         )
    #
    #     self.object_pos = []
    #     for i, f in enumerate(self.metadata_files):
    #         pos = np.genfromtxt(f, delimiter=",")
    #
    #         # check if metadata has right number of columns
    #         num_columns = pos.shape[1]
    #         if num_columns < 2:
    #             raise ValueError(
    #                 "Metadata incomplete. Columns are missing. Azimuth and elevation are mandatory."
    #             )
    #         elif num_columns > NUMBER_COLUMNS_ISM_METADATA:
    #             raise ValueError("Too many columns in metadata")
    #
    #         # pad metadata to max number of columns
    #         if num_columns < NUMBER_COLUMNS_ISM_METADATA:
    #             pos = np.hstack(
    #                 [pos, np.array(pos.shape[0] * [DEFAULT_ISM_METADATA[num_columns:]])]
    #             )
    #
    #         # check if metadata is longer than file -> cut off
    #         num_frames = int(
    #             np.ceil(self.audio.shape[0] / (self.fs * IVAS_FRAME_LEN_MS / 1000))
    #         )
    #         if num_frames < pos.shape[0]:
    #             pos = pos[:num_frames]
    #         # check if metadata is shorter than file -> loop
    #         elif num_frames > pos.shape[0]:
    #             pos_loop = np.zeros((num_frames, pos.shape[1]))
    #             pos_loop[: pos.shape[0]] = pos
    #             for idx in range(pos.shape[0], num_frames):
    #                 pos_loop[idx, :2] = pos[idx % pos.shape[0], :2]
    #             pos = pos_loop
    #
    #         # wrap metadata to target value range
    #         for j in range(num_frames):
    #             pos[j, 0], pos[j, 1] = wrap_angles(pos[j, 0], pos[j, 1], clip_ele=True)
    #
    #         self.object_pos.append(pos)
    @classmethod
    def _from_file(
        cls,
        name: str,
        filename: Union[str, Path],
        metadata_files: list[Union[str, Path]],
        fs: Optional[int] = None,
    ) -> "OMASAAudio":
        obj = super()._from_file(name, filename, fs)
        if metadata_files is not None:
            obj.metadata_files = [Path(f) for f in metadata_files]
        else:
            # search for metadata with naming scheme: name.(wav, pcm).(0-3).csv
            for obj_idx in range(obj.num_ism_channels):
                file_name_meta = filename.with_suffix(
                    f"{filename.suffix}.{obj_idx}.csv"
                )
                if file_name_meta.is_file():
                    obj.metadata_files.append(file_name_meta)
                else:
                    raise ValueError(f"Metadata file {file_name_meta} not found.")
            warn(
                f"No metadata files specified: The following files were found and used: \n {*obj.metadata_files,}"
            )

        obj.init_metadata()
        return obj

    @classmethod
    def _from_filelist(
        cls,
        name: str,
        filename: Path,
        metadata_files: list[Union[str, Path]],
        fs: Optional[int] = None,
    ) -> "OMASAAudio":
        obj = super()._from_filelist(name, filename, fs)
        obj.metadata_files = [Path(f) for f in metadata_files]
        obj.init_metadata()
        return obj

    def init_metadata(self):
        # check if number of metadata files matches format
        if self.num_ism_channels != len(self.metadata_files):
            raise ValueError(
                f"Mismatch between number of channels in file [{self.audio.shape[1]}], and metadata [{len(self.metadata_files)}]"
            )

        self.object_pos = []
        for i, f in enumerate(self.metadata_files):
            pos = np.genfromtxt(f, delimiter=",")

            # check if metadata has right number of columns
            num_columns = pos.shape[1]
            if num_columns < 2:
                raise ValueError(
                    "Metadata incomplete. Columns are missing. Azimuth and elevation are mandatory."
                )
            elif num_columns > NUMBER_COLUMNS_ISM_METADATA:
                raise ValueError("Too many columns in metadata")

            # pad metadata to max number of columns
            if num_columns < NUMBER_COLUMNS_ISM_METADATA:
                pos = np.hstack(
                    [pos, np.array(pos.shape[0] * [DEFAULT_ISM_METADATA[num_columns:]])]
                )

            # check if metadata is longer than file -> cut off
            num_frames = int(
                np.ceil(self.audio.shape[0] / (self.fs * IVAS_FRAME_LEN_MS / 1000))
            )
            if num_frames < pos.shape[0]:
                pos = pos[:num_frames]
            # check if metadata is shorter than file -> loop
            elif num_frames > pos.shape[0]:
                pos_loop = np.zeros((num_frames, pos.shape[1]))
                pos_loop[: pos.shape[0]] = pos
                for idx in range(pos.shape[0], num_frames):
                    pos_loop[idx, :2] = pos[idx % pos.shape[0], :2]
                pos = pos_loop

            # wrap metadata to target value range
            for j in range(num_frames):
                pos[j, 0], pos[j, 1] = wrap_angles(pos[j, 0], pos[j, 1], clip_ele=True)

            self.object_pos.append(pos)


class OSBAAudio(Audio):
@@ -629,7 +628,7 @@ def fromfile(
    """Create an Audio object of the specified format from the given file"""
    filename = Path(filename)
    fmt_cls = _get_audio_class(fmt)
    if fmt_cls is ObjectBasedAudio or fmt_cls is MetadataAssistedSpatialAudio:
    if fmt_cls is ObjectBasedAudio or fmt_cls is MetadataAssistedSpatialAudio or fmt_cls is OMASAAudio or fmt_cls is OSBAAudio:
        return fmt_cls._from_file(fmt, filename, in_meta, fs)
    else:
        return fmt_cls._from_file(fmt, filename, fs)