import json
from pathlib import Path
from typing import List, Literal, Optional
import numpy as np
import scipy
from pynwb import NWBFile, TimeSeries
from ....basetemporalalignmentinterface import BaseTemporalAlignmentInterface
from ....tools.audio import add_acoustic_waveform_series
from ....tools.nwb_helpers import make_or_load_nwbfile
from ....utils import (
DeepDict,
FilePathType,
get_base_schema,
get_schema_from_hdmf_class,
)
def _check_audio_names_are_unique(metadata: dict):
neurodata_names = [neurodata["name"] for neurodata in metadata]
neurodata_names_are_unique = len(set(neurodata_names)) == len(neurodata_names)
assert neurodata_names_are_unique, f"Some of the names for Audio metadata are not unique."
class AudioInterface(BaseTemporalAlignmentInterface):
    """Data interface for writing .wav audio recordings to an NWB file."""

    help = "Interface for writing audio recordings to an NWB file."
    display_name = "Wav Audio"

    def __init__(self, file_paths: list, verbose: bool = False):
        """
        Data interface for writing acoustic recordings to an NWB file.

        Writes acoustic recordings as an ``AcousticWaveformSeries`` from the ndx_sound extension.

        Parameters
        ----------
        file_paths : list of FilePathTypes
            The file paths to the audio recordings in sorted, consecutive order.
            A single path (``str`` or ``Path``) is also accepted and treated as a one-element list.
            We recommend using ``natsort`` to ensure the files are in consecutive order.
                >>> from natsort import natsorted
                >>> natsorted(file_paths)
        verbose : bool, default: False
            Stored on the instance as ``self.verbose``.

        Raises
        ------
        ValueError
            If any file does not end with a ``.wav`` extension (case-insensitive).
        """
        # A bare string/Path would otherwise be iterated character by character.
        if isinstance(file_paths, (str, Path)):
            file_paths = [file_paths]

        # Only the final suffix identifies the format; ``Path.suffixes`` would also pick up dots
        # that are part of the file name itself (e.g. "session.2023.wav" -> ".2023").
        unsupported_file_paths = [
            str(file_path) for file_path in file_paths if Path(file_path).suffix.lower() != ".wav"
        ]  # TODO: add support for more formats
        if unsupported_file_paths:
            raise ValueError(
                "The currently supported file format for audio is WAV file. "
                f"Some of the provided files do not match this format: {unsupported_file_paths}."
            )

        self._number_of_audio_files = len(file_paths)
        self.verbose = verbose
        super().__init__(file_paths=file_paths)
        # Starting time (in seconds) of each audio file; None until explicitly aligned.
        self._segment_starting_times = None

    def get_original_timestamps(self) -> np.ndarray:
        raise NotImplementedError("The AudioInterface does not yet support timestamps.")

    def get_timestamps(self) -> Optional[np.ndarray]:
        raise NotImplementedError("The AudioInterface does not yet support timestamps.")

    def set_aligned_timestamps(self, aligned_timestamps: List[np.ndarray]):
        raise NotImplementedError("The AudioInterface does not yet support timestamps.")

    def set_aligned_starting_time(self, aligned_starting_time: float):
        """
        Align all starting times for all audio files in this interface relative to the common session start time.

        Must be in units seconds relative to the common 'session_start_time'.

        Parameters
        ----------
        aligned_starting_time : float
            The common starting time for all temporal data in this interface.
            Applies to all segments if there are multiple file paths used by the interface.
            Existing segment starting times are shifted by this amount. A single audio file
            without an explicit starting time is treated as starting at 0.0 s.

        Raises
        ------
        ValueError
            If there are multiple audio files (or none) and their segment starting times have not
            been set with 'set_aligned_segment_starting_times'.
        """
        if self._segment_starting_times is None:
            if self._number_of_audio_files != 1:
                raise ValueError(
                    "There are no segment starting times to shift by a common value! "
                    "Please set them using 'set_aligned_segment_starting_times'."
                )
            # A lone file implicitly starts at 0.0 s, so shifting it is equivalent to setting it.
            self._segment_starting_times = [0.0]

        self._segment_starting_times = [
            segment_starting_time + aligned_starting_time for segment_starting_time in self._segment_starting_times
        ]

    def set_aligned_segment_starting_times(self, aligned_segment_starting_times: List[float]):
        """
        Align the individual starting time for each audio file in this interface relative to the common session start time.

        Must be in units seconds relative to the common 'session_start_time'.

        Parameters
        ----------
        aligned_segment_starting_times : list of floats
            The relative starting times of each audio file (segment). Integer values are accepted
            and stored as floats.

        Raises
        ------
        AssertionError
            If the argument is not a list of real numbers, or its length differs from the number
            of audio files.
        """
        # Validate the type before calling len() so non-list input gets the intended message.
        is_list_of_numbers = isinstance(aligned_segment_starting_times, list) and all(
            isinstance(x, (int, float, np.integer, np.floating)) and not isinstance(x, bool)
            for x in aligned_segment_starting_times
        )
        assert is_list_of_numbers, "Argument 'aligned_segment_starting_times' must be a list of floats."

        aligned_segment_starting_times_length = len(aligned_segment_starting_times)
        assert aligned_segment_starting_times_length == self._number_of_audio_files, (
            f"The number of entries in 'aligned_segment_starting_times' ({aligned_segment_starting_times_length}) "
            f"must be equal to the number of audio file paths ({self._number_of_audio_files})."
        )
        # Store a float copy so later mutation of the caller's list cannot affect this interface.
        self._segment_starting_times = [float(x) for x in aligned_segment_starting_times]

    def align_by_interpolation(self, unaligned_timestamps: np.ndarray, aligned_timestamps: np.ndarray):
        raise NotImplementedError("The AudioInterface does not yet support timestamps.")

    def add_to_nwbfile(
        self,
        nwbfile: NWBFile,
        metadata: Optional[dict] = None,
        stub_test: bool = False,
        stub_frames: int = 1000,
        write_as: Literal["stimulus", "acquisition"] = "stimulus",
        iterator_options: Optional[dict] = None,
        compression_options: Optional[dict] = None,
        overwrite: bool = False,
        verbose: bool = True,
    ):
        """
        Write each audio file as an acoustic waveform series into ``nwbfile``.

        Parameters
        ----------
        nwbfile : NWBFile
            Append to this NWBFile object
        metadata : dict, optional
            Must contain ``metadata["Behavior"]["Audio"]``: one entry (with a unique ``"name"``) per
            file path. Defaults to ``self.get_metadata()``.
        stub_test : bool, default: False
            If True, only the first ``stub_frames`` samples of each file are written.
        stub_frames : int, default: 1000
        write_as : {'stimulus', 'acquisition'}
            The acoustic waveform series can be added to the NWB file either as
            "stimulus" or as "acquisition".
        iterator_options : dict, optional
            Dictionary of options for the SliceableDataChunkIterator.
        compression_options : dict, optional
            Dictionary of options for compressing the data for H5DataIO.
        overwrite : bool, default: False
            Not used by this method.
        verbose : bool, default: True
            Not used by this method.

        Returns
        -------
        NWBFile

        Raises
        ------
        AssertionError
            If Audio metadata names are not unique or the number of entries differs from the
            number of files.
        ValueError
            If there are multiple files and their segment starting times were not set.
        """
        # Import the submodule explicitly: older SciPy releases do not load ``scipy.io`` on a bare
        # ``import scipy``.
        from scipy.io import wavfile

        if metadata is None:
            # NOTE(review): assumes get_metadata() provides metadata["Behavior"]["Audio"] — confirm.
            metadata = self.get_metadata()

        file_paths = self.source_data["file_paths"]
        audio_metadata = metadata["Behavior"]["Audio"]
        # Raises on duplicated names, which made the former separate duplicate check unreachable.
        _check_audio_names_are_unique(metadata=audio_metadata)
        assert len(audio_metadata) == self._number_of_audio_files, (
            f"The Audio metadata is incomplete ({len(audio_metadata)} entry)! "
            f"Expected {self._number_of_audio_files} (one for each entry of 'file_paths')."
        )

        if self._number_of_audio_files > 1 and self._segment_starting_times is None:
            raise ValueError(
                "If you have multiple audio files, then you must specify each starting time by calling "
                "'.set_aligned_segment_starting_times(...)'!"
            )
        # A single, unaligned file starts at 0.0 s.
        starting_times = self._segment_starting_times or [0.0]

        for acoustic_waveform_series_metadata, file_path, starting_time in zip(
            audio_metadata, file_paths, starting_times
        ):
            # mmap=True avoids loading the whole recording into memory up front.
            sampling_rate, acoustic_series = wavfile.read(filename=file_path, mmap=True)
            if stub_test:
                acoustic_series = acoustic_series[:stub_frames]

            add_acoustic_waveform_series(
                acoustic_series=acoustic_series,
                nwbfile=nwbfile,
                rate=sampling_rate,
                metadata=acoustic_waveform_series_metadata,
                write_as=write_as,
                starting_time=starting_time,
                iterator_options=iterator_options,
                compression_options=compression_options,
            )

        return nwbfile