Source code for neuroconv.datainterfaces.ecephys.baserecordingextractorinterface
import json
from typing import Any, Dict, List, Literal, Optional, Union
import numpy as np
from pynwb import NWBFile
from pynwb.device import Device
from pynwb.ecephys import ElectricalSeries, ElectrodeGroup
from ...baseextractorinterface import BaseExtractorInterface
from ...utils import (
DeepDict,
NWBMetaDataEncoder,
get_base_schema,
get_schema_from_hdmf_class,
)
class BaseRecordingExtractorInterface(BaseExtractorInterface):
    """Common base for every data interface that wraps a recording extractor."""

    # The parent's keywords, extended with terms specific to extracellular recordings.
    keywords = [*BaseExtractorInterface.keywords, "extracellular electrophysiology", "voltage", "recording"]
    # NOTE(review): presumably the module from which `get_extractor()` resolves the extractor
    # class -- confirm against BaseExtractorInterface.
    ExtractorModuleName = "spikeinterface.extractors"
def __init__(self, verbose: bool = True, es_key: str = "ElectricalSeries", **source_data):
    """
    Instantiate the recording extractor from the source data and store the interface options.

    Parameters
    ----------
    verbose : bool, default: True
        If True, will print out additional information.
    es_key : str, default: "ElectricalSeries"
        The key of this ElectricalSeries in the metadata dictionary.
    source_data : dict
        The key-value pairs of extractor-specific arguments.
        They are forwarded both to the parent initializer and to the extractor class.
    """
    super().__init__(**source_data)

    extractor_class = self.get_extractor()
    self.recording_extractor = extractor_class(**source_data)
    # Cached once so the timestamp helpers can branch on single- vs multi-segment recordings.
    self._number_of_segments = self.recording_extractor.get_num_segments()

    self.es_key = es_key
    self.verbose = verbose
    self.subset_channels = None
def get_metadata_schema(self) -> dict:
    """
    Compile metadata schema for the RecordingExtractor.

    The parent schema is extended with an "Ecephys" section that requires "Device" and
    "ElectrodeGroup" arrays, allows an optional "Electrodes" array, and stores the item
    definitions those arrays reference under "definitions". When `es_key` is set, a schema
    for the ElectricalSeries is added under that key as well.

    Returns
    -------
    dict
        The metadata schema including the "Ecephys" section.
    """
    schema = super().get_metadata_schema()
    schema["properties"]["Ecephys"] = get_base_schema(tag="Ecephys")
    ecephys_schema = schema["properties"]["Ecephys"]
    ecephys_schema["required"] = ["Device", "ElectrodeGroup"]

    ecephys_schema["properties"] = {
        "Device": {
            "type": "array",
            "minItems": 1,
            "items": {"$ref": "#/properties/Ecephys/properties/definitions/Device"},
        },
        "ElectrodeGroup": {
            "type": "array",
            "minItems": 1,
            "items": {"$ref": "#/properties/Ecephys/properties/definitions/ElectrodeGroup"},
        },
        "Electrodes": {
            "type": "array",
            "minItems": 0,
            "renderForm": False,
            "items": {"$ref": "#/properties/Ecephys/properties/definitions/Electrodes"},
        },
    }
    ecephys_properties = ecephys_schema["properties"]

    # Item schemas that the "$ref" entries of the arrays above point to.
    electrodes_column_schema = {
        "type": "object",
        "additionalProperties": False,
        "required": ["name"],
        "properties": {
            "name": {"type": "string", "description": "name of this electrodes column"},
            "description": {"type": "string", "description": "description of this electrodes column"},
        },
    }
    ecephys_properties["definitions"] = {
        "Device": get_schema_from_hdmf_class(Device),
        "ElectrodeGroup": get_schema_from_hdmf_class(ElectrodeGroup),
        "Electrodes": electrodes_column_schema,
    }

    if self.es_key is not None:
        ecephys_properties[self.es_key] = get_schema_from_hdmf_class(ElectricalSeries)

    return schema
def get_metadata(self) -> DeepDict:
    """
    Build the default metadata for this recording.

    The parent metadata is extended with an "Ecephys" entry holding one default device, one
    electrode group per distinct channel group of the recording extractor (or a single group
    named "ElectrodeGroup" when the extractor reports no groups) and, when `es_key` is set,
    a default entry for the ElectricalSeries.

    Returns
    -------
    DeepDict
        The metadata including the "Ecephys" entry.
    """
    metadata = super().get_metadata()

    channel_groups_array = self.recording_extractor.get_channel_groups()
    if channel_groups_array is None:
        unique_channel_groups = ["ElectrodeGroup"]
    else:
        # De-duplicate while keeping first-seen order. A plain set() has no defined order,
        # which made the order of the generated electrode groups vary between runs.
        unique_channel_groups = list(dict.fromkeys(channel_groups_array))

    electrode_metadata = [
        dict(name=str(group_id), description="no description", location="unknown", device="DeviceEcephys")
        for group_id in unique_channel_groups
    ]

    metadata["Ecephys"] = dict(
        Device=[dict(name="DeviceEcephys", description="no description")],
        ElectrodeGroup=electrode_metadata,
    )

    if self.es_key is not None:
        metadata["Ecephys"][self.es_key] = dict(
            name=self.es_key, description=f"Acquisition traces for the {self.es_key}."
        )

    return metadata
def get_electrode_table_json(self) -> List[Dict[str, Any]]:
    """
    Collect the property values of the underlying recording extractor, one entry per electrode.

    The result follows the structure of the Handsontable component of the NWB GUIDE: a list
    with one dict per channel id, mapping each property name to that channel's value. The
    table is round-tripped through JSON with `NWBMetaDataEncoder` before being returned.

    Returns
    -------
    list of dict
        The JSON-decoded electrode table.
    """
    recording = self.recording_extractor
    skipped_properties = {
        "contact_vector",  # TODO: add consideration for contact vector (probeinterface) info
        "location",  # testing
    }
    property_names = set(recording.get_property_keys()) - skipped_properties

    rows = [
        {
            # The first axis is always electrodes in SI; a single id is requested per call,
            # so the only element is taken with a trivial zero-index.
            property_name: recording.get_property(key=property_name, ids=[channel_id])[0]
            for property_name in property_names
        }
        for channel_id in recording.get_channel_ids()
    ]

    return json.loads(json.dumps(rows, cls=NWBMetaDataEncoder))
def get_original_timestamps(self) -> Union[np.ndarray, List[np.ndarray]]:
    """
    Retrieve the original unaltered timestamps for the data in this interface.

    A fresh extractor is built from `source_data` (without the interface-only "verbose" and
    "es_key" entries), so timestamps set on `self.recording_extractor` do not affect the result.

    Returns
    -------
    timestamps: numpy.ndarray or list of numpy.ndarray
        The timestamps for the data stream; if the recording has multiple segments, then a
        list of timestamps is returned.
    """
    extractor_class = self.get_extractor()
    interface_only_keywords = ["verbose", "es_key"]
    extractor_kwargs = {
        keyword: value for keyword, value in self.source_data.items() if keyword not in interface_only_keywords
    }
    fresh_recording = extractor_class(**extractor_kwargs)

    if self._number_of_segments != 1:
        return [
            fresh_recording.get_times(segment_index=segment_index)
            for segment_index in range(self._number_of_segments)
        ]
    return fresh_recording.get_times()
def get_timestamps(self) -> Union[np.ndarray, List[np.ndarray]]:
    """
    Retrieve the timestamps currently held by the recording extractor of this interface.

    Returns
    -------
    timestamps: numpy.ndarray or list of numpy.ndarray
        The timestamps for the data stream; if the recording has multiple segments, then a
        list of timestamps is returned.
    """
    recording = self.recording_extractor
    if self._number_of_segments != 1:
        return [recording.get_times(segment_index=segment_index) for segment_index in range(self._number_of_segments)]
    return recording.get_times()
def set_aligned_timestamps(self, aligned_timestamps: np.ndarray):
    """
    Replace the timestamps of this single-segment recording with the given aligned timestamps.

    Parameters
    ----------
    aligned_timestamps : numpy.ndarray
        The timestamps to set on the recording extractor.

    Raises
    ------
    AssertionError
        If the recording has more than one segment; use 'set_aligned_segment_timestamps' then.
    """
    # The message previously pointed to a non-existent 'align_segment_timestamps' method.
    assert (
        self._number_of_segments == 1
    ), "This recording has multiple segments; please use 'set_aligned_segment_timestamps' instead."

    self.recording_extractor.set_times(times=aligned_timestamps)
def set_aligned_segment_timestamps(self, aligned_segment_timestamps: List[np.ndarray]):
    """
    Replace the timestamps of every segment with those aligned to the common session start time.

    Must be in units seconds relative to the common 'session_start_time'.

    Parameters
    ----------
    aligned_segment_timestamps : list of numpy.ndarray
        The synchronized timestamps, one vector per segment of data in this interface.

    Raises
    ------
    AssertionError
        If the argument is not a list, or its length differs from the number of segments.
    """
    assert isinstance(
        aligned_segment_timestamps, list
    ), "Recording has multiple segment! Please pass a list of timestamps to align each segment."
    assert (
        len(aligned_segment_timestamps) == self._number_of_segments
    ), f"The number of timestamp vectors ({len(aligned_segment_timestamps)}) does not match the number of segments ({self._number_of_segments})!"

    recording = self.recording_extractor
    for segment_index, segment_timestamps in enumerate(aligned_segment_timestamps):
        recording.set_times(times=segment_timestamps, segment_index=segment_index)
def set_aligned_starting_time(self, aligned_starting_time: float):
    """
    Shift the current timestamps of this interface by `aligned_starting_time`.

    The value is added to the timestamps returned by `get_timestamps` (to every segment when
    the recording has several) and the shifted timestamps are set back on the interface.

    Parameters
    ----------
    aligned_starting_time : float
        The offset added to the current timestamps.
    """
    if self._number_of_segments != 1:
        shifted_segments = [
            segment_timestamps + aligned_starting_time for segment_timestamps in self.get_timestamps()
        ]
        self.set_aligned_segment_timestamps(aligned_segment_timestamps=shifted_segments)
        return

    shifted_timestamps = self.get_timestamps() + aligned_starting_time
    self.set_aligned_timestamps(aligned_timestamps=shifted_timestamps)
def set_aligned_segment_starting_times(self, aligned_segment_starting_times: List[float]):
    """
    Align the starting time for each segment in this interface relative to the common session start time.

    Must be in units seconds relative to the common 'session_start_time'. Each value is added
    to the current timestamps of the corresponding segment.

    Parameters
    ----------
    aligned_segment_starting_times : list of floats
        The starting time for each segment of data in this interface.

    Raises
    ------
    AssertionError
        If the number of starting times differs from the number of segments.
    """
    # Both halves of the message must be f-strings; the second half used to lack the prefix,
    # so the segment count was printed as a literal "{self._number_of_segments}".
    assert len(aligned_segment_starting_times) == self._number_of_segments, (
        f"The length of the starting_times ({len(aligned_segment_starting_times)}) does not match the "
        f"number of segments ({self._number_of_segments})!"
    )

    if self._number_of_segments == 1:
        self.set_aligned_starting_time(aligned_starting_time=aligned_segment_starting_times[0])
    else:
        aligned_segment_timestamps = [
            segment_timestamps + aligned_segment_starting_time
            for segment_timestamps, aligned_segment_starting_time in zip(
                self.get_timestamps(), aligned_segment_starting_times
            )
        ]
        self.set_aligned_segment_timestamps(aligned_segment_timestamps=aligned_segment_timestamps)
def set_probe(self, probe, group_mode: Literal["by_shank", "by_probe"]):
    """
    Attach probe information, given as a ProbeInterface object, to the recording extractor.

    Parameters
    ----------
    probe : probeinterface.Probe
        The probe object.
    group_mode : {'by_shank', 'by_probe'}
        How to group the channels. If 'by_shank', channels are grouped by the shank_id column.
        If 'by_probe', channels are grouped by the probe_id column.
        This is a required parameter to avoid the pitfall of using the wrong mode.
    """
    recording = self.recording_extractor

    # The probe is attached to the existing extractor in place.
    recording.set_probe(probe, in_place=True, group_mode=group_mode)

    # Spike interface sets the "group" property,
    # but neuroconv allows a "group_name" property to override the spike interface "group" value,
    # so the groups are mirrored there as strings.
    group_names = recording.get_property("group").astype(str)
    recording.set_property("group_name", group_names)
def has_probe(self) -> bool:
    """
    Report whether the recording extractor carries probe information.

    Returns
    -------
    has_probe : bool
        The value returned by the recording extractor's own `has_probe`.
    """
    recording = self.recording_extractor
    return recording.has_probe()
def align_by_interpolation(
    self,
    unaligned_timestamps: np.ndarray,
    aligned_timestamps: np.ndarray,
):
    """
    Align the timestamps of this single-segment recording by linear interpolation.

    The current timestamps are mapped through `numpy.interp`, using `unaligned_timestamps` as
    the sample points and `aligned_timestamps` as their values, and the result is set on the
    interface.

    Parameters
    ----------
    unaligned_timestamps : numpy.ndarray
        Sample points expressed in the same time basis as the current timestamps.
    aligned_timestamps : numpy.ndarray
        The aligned values corresponding to `unaligned_timestamps`.

    Raises
    ------
    NotImplementedError
        If the recording has more than one segment.
    """
    if self._number_of_segments != 1:
        raise NotImplementedError("Multi-segment support for aligning by interpolation has not been added yet.")

    interpolated_timestamps = np.interp(x=self.get_timestamps(), xp=unaligned_timestamps, fp=aligned_timestamps)
    self.set_aligned_timestamps(aligned_timestamps=interpolated_timestamps)
def subset_recording(self, stub_test: bool = False):
    """
    Subset a recording extractor according to the stub option.

    Parameters
    ----------
    stub_test : bool, default: False
        If True, every segment is truncated to at most 100 frames and the truncated segments
        are combined with `ConcatenateSegmentRecording`.
        If False, the recording extractor is returned unchanged.

    Returns
    -------
    recording_extractor
        The stubbed recording when `stub_test` is True, otherwise `self.recording_extractor`.
    """
    # Honour the flag: previously it was ignored and the data were always truncated, even
    # when this method was reached only because `subset_channels` had been set.
    # NOTE(review): `subset_channels` itself is not applied here -- confirm whether a channel
    # selection should be added.
    if not stub_test:
        return self.recording_extractor

    # Imported only on the stubbing path, which is the only one that needs it.
    from spikeinterface.core.segmentutils import ConcatenateSegmentRecording

    max_frames = 100

    recording_extractor = self.recording_extractor
    number_of_segments = recording_extractor.get_num_segments()
    recording_segments = [recording_extractor.select_segments([index]) for index in range(number_of_segments)]
    # Segments shorter than `max_frames` keep their full length.
    end_frame_list = [min(max_frames, segment.get_num_frames()) for segment in recording_segments]
    recording_segments_stubbed = [
        segment.frame_slice(start_frame=0, end_frame=end_frame)
        for segment, end_frame in zip(recording_segments, end_frame_list)
    ]

    return ConcatenateSegmentRecording(recording_segments_stubbed)
def add_to_nwbfile(
    self,
    nwbfile: NWBFile,
    metadata: Optional[dict] = None,
    stub_test: bool = False,
    starting_time: Optional[float] = None,
    write_as: Literal["raw", "lfp", "processed"] = "raw",
    write_electrical_series: bool = True,
    compression: Optional[str] = "gzip",
    compression_opts: Optional[int] = None,
    iterator_type: str = "v2",
    iterator_opts: Optional[dict] = None,
):
    """
    Add the recording of this interface to an NWBFile by delegating to `add_recording`.

    The recording passed on is the result of `subset_recording` when `stub_test` is True or
    `subset_channels` is set; otherwise it is the recording extractor itself.

    Parameters
    ----------
    nwbfile : NWBFile
        NWBFile to which the recording information is to be added
    metadata : dict, optional
        metadata info for constructing the NWB file.
        Should be of the format::

            metadata['Ecephys']['ElectricalSeries'] = dict(name=my_name, description=my_description)
    stub_test : bool, default: False
        If True, will truncate the data to run the conversion faster and take up less memory.
    starting_time : float, optional
        Sets the starting time of the ElectricalSeries to a manually set value.
    write_as : {'raw', 'lfp', 'processed'}
        Forwarded unchanged to `add_recording`.
    write_electrical_series : bool, default: True
        Electrical series are written in acquisition. If False, only device, electrode_groups,
        and electrodes are written to NWB.
    compression : {'gzip', 'lzf', None}
        Type of compression to use.
        Set to None to disable all compression.
    compression_opts : int, optional
        Only applies to compression="gzip". Controls the level of the GZIP.
        The default passed on by this method is None.
        NOTE(review): earlier documentation stated a default of 4 -- presumably the level the
        backend applies when None is given; confirm.
    iterator_type : {'v2', 'v1'}
        The type of DataChunkIterator to use.
        'v1' is the original DataChunkIterator of the hdmf data_utils.
        'v2' is the locally developed RecordingExtractorDataChunkIterator, which offers full control over chunking.
    iterator_opts : dict, optional
        Dictionary of options for the RecordingExtractorDataChunkIterator (iterator_type='v2').
        Valid options are

        buffer_gb : float, default: 1.0
            In units of GB. Recommended to be as much free RAM as available. Automatically calculates suitable
            buffer shape.
        buffer_shape : tuple, optional
            Manual specification of buffer shape to return on each iteration.
            Must be a multiple of chunk_shape along each axis.
            Cannot be set if `buffer_gb` is specified.
        chunk_mb : float, default: 1.0
            Should be below 1 MB. Automatically calculates suitable chunk shape.
        chunk_shape : tuple, optional
            Manual specification of the internal chunk shape for the HDF5 dataset.
            Cannot be set if `chunk_mb` is also specified.
        display_progress : bool, default: False
            Display a progress bar with iteration rate and estimated completion time.
        progress_bar_options : dict, optional
            Dictionary of keyword arguments to be passed directly to tqdm.
            See https://github.com/tqdm/tqdm#parameters for options.
    """
    from ...tools.spikeinterface import add_recording

    needs_subset = stub_test or self.subset_channels is not None
    recording = self.subset_recording(stub_test=stub_test) if needs_subset else self.recording_extractor

    writer_options = dict(
        recording=recording,
        nwbfile=nwbfile,
        metadata=metadata,
        starting_time=starting_time,
        write_as=write_as,
        write_electrical_series=write_electrical_series,
        es_key=self.es_key,
        compression=compression,
        compression_opts=compression_opts,
        iterator_type=iterator_type,
        iterator_opts=iterator_opts,
    )
    add_recording(**writer_options)