Source code for neuroconv.tools.testing.mock_interfaces

from datetime import datetime
from typing import List, Optional

import numpy as np
from pynwb import NWBFile
from pynwb.base import DynamicTable

from .mock_ttl_signals import generate_mock_ttl_signal
from ...basetemporalalignmentinterface import BaseTemporalAlignmentInterface
from ...datainterfaces import SpikeGLXNIDQInterface
from ...datainterfaces.ecephys.baserecordingextractorinterface import (
    BaseRecordingExtractorInterface,
)
from ...datainterfaces.ophys.baseimagingextractorinterface import (
    BaseImagingExtractorInterface,
)
from ...utils import ArrayType, get_schema_from_method_signature


class MockBehaviorEventInterface(BaseTemporalAlignmentInterface):
    @classmethod
    def get_source_schema(cls) -> dict:
        source_schema = get_schema_from_method_signature(method=cls.__init__, exclude=["event_times"])
        source_schema["additionalProperties"] = True
        return source_schema

    def __init__(self, event_times: Optional[ArrayType] = None):
        """
        Define event times for some behavior.

        Parameters
        ----------
        event_times : list of floats, optional
            The event times to set as timestamps for this interface.
            The default is the array [1.2, 2.3, 3.4] for similarity to the timescale of the MockSpikeGLXNIDQInterface.
        """
        event_times = event_times or [1.2, 2.3, 3.4]
        self.event_times = np.array(event_times)
        self.original_event_times = np.array(event_times)  # Make a copy of the initially loaded timestamps

    def get_original_timestamps(self) -> np.ndarray:
        return self.original_event_times

    def get_timestamps(self) -> np.ndarray:
        return self.event_times

    def set_aligned_timestamps(self, aligned_timestamps: np.ndarray):
        self.event_times = aligned_timestamps

    def add_to_nwbfile(self, nwbfile: NWBFile, metadata: dict):
        table = DynamicTable(name="BehaviorEvents", description="Times of various classified behaviors.")
        table.add_column(name="event_time", description="Time of each event.")
        for timestamp in self.get_timestamps():  # Adding data by column gives an error
            table.add_row(event_time=timestamp)
        nwbfile.add_acquisition(table)
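
# Illustrative usage sketch (not part of the module): exercise the temporal alignment API of
# MockBehaviorEventInterface and write the events to an in-memory NWBFile. Only names imported
# at the top of this module are used; the session description and identifier are placeholder
# values, and `metadata={}` is passed because this mock's `add_to_nwbfile` ignores metadata.
#
#   interface = MockBehaviorEventInterface(event_times=[1.2, 2.3, 3.4])
#   interface.set_aligned_timestamps(aligned_timestamps=interface.get_timestamps() + 0.5)
#   assert interface.get_original_timestamps()[0] == 1.2  # the original timestamps are preserved
#
#   nwbfile = NWBFile(
#       session_description="Mock session.",
#       identifier="mock-identifier",
#       session_start_time=datetime.now().astimezone(),
#   )
#   interface.add_to_nwbfile(nwbfile=nwbfile, metadata={})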
class MockSpikeGLXNIDQInterface(SpikeGLXNIDQInterface):
    ExtractorName = "NumpyRecording"

    @classmethod
    def get_source_schema(cls) -> dict:
        source_schema = get_schema_from_method_signature(method=cls.__init__, exclude=["ttl_times"])
        source_schema["additionalProperties"] = True
        return source_schema

    def __init__(
        self, signal_duration: float = 7.0, ttl_times: Optional[List[List[float]]] = None, ttl_duration: float = 1.0
    ):
        """
        Define a mock SpikeGLXNIDQInterface by overriding the recording extractor to be a mock TTL signal.

        Parameters
        ----------
        signal_duration : float, default: 7.0
            The number of seconds to simulate.
        ttl_times : list of lists of floats, optional
            The times within the `signal_duration` to trigger the TTL pulse for each channel.
            The outer list is over channels, while each inner list is the set of TTL times for that channel.
            The default generates 8 channels with periodic on/off cycles (starting in the 'off' state),
            each of length `ttl_duration`, with a 0.1 second offset per channel.
        ttl_duration : float, default: 1.0
            How long each TTL pulse stays in the 'on' state when triggered, in seconds.
        """
        from spikeinterface.extractors import NumpyRecording

        if ttl_times is None:
            # Begin in the 'off' state
            number_of_periods = int(np.ceil((signal_duration - ttl_duration) / (ttl_duration * 2)))
            default_periodic_ttl_times = [ttl_duration * (1 + 2 * period) for period in range(number_of_periods)]
            ttl_times = [[ttl_time + 0.1 * channel for ttl_time in default_periodic_ttl_times] for channel in range(8)]
        number_of_channels = len(ttl_times)
        channel_ids = [f"nidq#XA{channel_index}" for channel_index in range(number_of_channels)]  # NIDQ channel IDs
        channel_groups = ["NIDQChannelGroup"] * number_of_channels

        sampling_frequency = 25_000.0  # NIDQ sampling rate
        number_of_frames = int(signal_duration * sampling_frequency)
        traces = np.empty(shape=(number_of_frames, number_of_channels), dtype="int16")
        for channel_index in range(number_of_channels):
            traces[:, channel_index] = generate_mock_ttl_signal(
                signal_duration=signal_duration,
                ttl_times=ttl_times[channel_index],
                ttl_duration=ttl_duration,
                sampling_frequency_hz=sampling_frequency,
            )

        self.recording_extractor = NumpyRecording(
            traces_list=traces, sampling_frequency=sampling_frequency, channel_ids=channel_ids
        )
        # NIDQ channel gains
        self.recording_extractor.set_channel_gains(gains=[61.03515625] * self.recording_extractor.get_num_channels())
        self.recording_extractor.set_property(key="group_name", values=channel_groups)

        # Minimal meta so `get_metadata` works similarly to the real NIDQ header
        self.meta = {"acqMnMaXaDw": "0,0,8,1", "fileCreateTime": "2020-11-03T10:35:10", "niDev1ProductName": "PCI-6259"}
        self.subset_channels = None
        self.verbose = None
        self.es_key = "ElectricalSeriesNIDQ"
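
# Illustrative usage sketch (not part of the module): the mock NIDQ interface behaves like the
# real SpikeGLXNIDQInterface but is backed by synthetic TTL traces, so TTL-based alignment code
# can be exercised without SpikeGLX files on disk. The channel name "nidq#XA0" is one of the 8
# default mock channels; the spikeinterface calls below are standard BaseRecording methods.
#
#   interface = MockSpikeGLXNIDQInterface(signal_duration=7.0, ttl_duration=1.0)
#   print(interface.recording_extractor.get_sampling_frequency())  # 25000.0
#   traces = interface.recording_extractor.get_traces(channel_ids=["nidq#XA0"])
#
#   # If your neuroconv version exposes TTL parsing on the NIDQ interface (an assumption here),
#   # the pulse times can be recovered directly:
#   # ttl_times = interface.get_event_times_from_ttl(channel_name="nidq#XA0")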
class MockRecordingInterface(BaseRecordingExtractorInterface):
    """An interface with a spikeinterface recording object for testing purposes."""

    ExtractorModuleName = "spikeinterface.core.generate"
    ExtractorName = "generate_recording"

    def __init__(
        self,
        num_channels: int = 4,
        sampling_frequency: float = 30_000.0,
        durations: List[float] = [1.0],
        seed: int = 0,
        verbose: bool = True,
        es_key: str = "ElectricalSeries",
    ):
        super().__init__(
            num_channels=num_channels,
            sampling_frequency=sampling_frequency,
            durations=durations,
            seed=seed,
            verbose=verbose,
            es_key=es_key,
        )

    def get_metadata(self) -> dict:
        metadata = super().get_metadata()
        session_start_time = datetime.now().astimezone()
        metadata["NWBFile"]["session_start_time"] = session_start_time
        return metadata
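
# Illustrative usage sketch (not part of the module): run the standard recording-interface
# conversion path against the generated recording. The output file path is a placeholder, and
# `run_conversion` is the generic neuroconv data-interface entry point.
#
#   interface = MockRecordingInterface(num_channels=4, durations=[1.0])
#   metadata = interface.get_metadata()  # includes a session_start_time set to "now"
#   interface.run_conversion(nwbfile_path="mock_recording.nwb", metadata=metadata, overwrite=True)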
class MockImagingInterface(BaseImagingExtractorInterface):
    def __init__(
        self,
        num_frames: int = 30,
        num_rows: int = 10,
        num_columns: int = 10,
        sampling_frequency: float = 30,
        dtype: str = "uint16",
        verbose: bool = True,
    ):
        from roiextractors.testing import generate_dummy_imaging_extractor

        self.imaging_extractor = generate_dummy_imaging_extractor(
            num_frames=num_frames,
            num_rows=num_rows,
            num_columns=num_columns,
            sampling_frequency=sampling_frequency,
            dtype=dtype,
        )
        self.verbose = verbose

    def get_metadata(self) -> dict:
        session_start_time = datetime.now().astimezone()
        metadata = super().get_metadata()
        metadata["NWBFile"]["session_start_time"] = session_start_time
        return metadata
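
# Illustrative usage sketch (not part of the module): the mock imaging interface wraps a dummy
# roiextractors ImagingExtractor, so ophys metadata and conversion code paths can be tested
# without real microscopy data. The output file path is a placeholder.
#
#   interface = MockImagingInterface(num_frames=30, num_rows=10, num_columns=10)
#   metadata = interface.get_metadata()
#   interface.run_conversion(nwbfile_path="mock_imaging.nwb", metadata=metadata, overwrite=True)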