From 2739ea707980515b22e7105b1535c8f1c5290c72 Mon Sep 17 00:00:00 2001 From: Heberto Mayorquin Date: Tue, 17 Dec 2024 11:03:16 -0600 Subject: [PATCH] add time series extractor from nwb --- .../extractors/extractorlist.py | 10 +- .../extractors/nwbextractors.py | 382 +++++++++++++++++- .../extractors/tests/test_nwbextractors.py | 192 ++++++++- 3 files changed, 578 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/extractors/extractorlist.py b/src/spikeinterface/extractors/extractorlist.py index bd35180a7e..9daec58d74 100644 --- a/src/spikeinterface/extractors/extractorlist.py +++ b/src/spikeinterface/extractors/extractorlist.py @@ -26,7 +26,15 @@ from .neoextractors import NeuroScopeSortingExtractor, MaxwellEventExtractor # NWB sorting/recording/event -from .nwbextractors import NwbRecordingExtractor, NwbSortingExtractor, read_nwb, read_nwb_recording, read_nwb_sorting +from .nwbextractors import ( + NwbRecordingExtractor, + NwbSortingExtractor, + NwbTimeSeriesExtractor, + read_nwb, + read_nwb_recording, + read_nwb_sorting, + read_nwb_timeseries, +) from .cbin_ibl import CompressedBinaryIblExtractor, read_cbin_ibl from .iblextractors import IblRecordingExtractor, IblSortingExtractor, read_ibl_recording, read_ibl_sorting diff --git a/src/spikeinterface/extractors/nwbextractors.py b/src/spikeinterface/extractors/nwbextractors.py index 171992f6b1..3ef60234cc 100644 --- a/src/spikeinterface/extractors/nwbextractors.py +++ b/src/spikeinterface/extractors/nwbextractors.py @@ -585,6 +585,7 @@ def __init__( segment_data, times_kwargs, ) = self._fetch_recording_segment_info_backend(file, cache, load_time_vector, samples_for_rate_estimation) + BaseRecording.__init__(self, channel_ids=channel_ids, sampling_frequency=sampling_frequency, dtype=dtype) recording_segment = NwbRecordingSegment( electrical_series_data=segment_data, @@ -595,12 +596,9 @@ def __init__( # fetch and add main recording properties if use_pynwb: gains, offsets, locations, groups = self._fetch_main_properties_pynwb() - self.extra_requirements.append("pynwb") else: gains, offsets, locations, groups = self._fetch_main_properties_backend() - self.extra_requirements.append("h5py") - if stream_mode is not None: - self.extra_requirements.append(stream_mode) + self.set_channel_gains(gains) self.set_channel_offsets(offsets) if locations is not None: @@ -659,6 +657,19 @@ def __init__( "file": file, } + # Set extra requirements for the extractor, so they can be installed from docker + if use_pynwb: + self.extra_requirements.append("pynwb") + else: + if self.backend == "hdf5": + self.extra_requirements.append("h5py") + if self.backend == "zarr": + self.extra_requirements.append("zarr") + if self.stream_mode == "fsspec": + self.extra_requirements.append("fsspec") + if self.stream_mode == "remfile": + self.extra_requirements.append("remfile") + def _fetch_recording_segment_info_pynwb(self, file, cache, load_time_vector, samples_for_rate_estimation): self._nwbfile = read_nwbfile( backend=self.backend, @@ -1364,8 +1375,371 @@ def get_unit_spike_train( return frames[start_index:end_index].astype("int64", copy=False) +def _find_timeseries_from_backend(group, path="", result=None, backend="hdf5"): + """ + Recursively searches for groups with TimeSeries neurodata_type in hdf5 or zarr object, + and returns a list with their paths. 
+ """ + if backend == "hdf5": + import h5py + + group_class = h5py.Group + else: + import zarr + + group_class = zarr.Group + + if result is None: + result = [] + + for name, value in group.items(): + if isinstance(value, group_class): + current_path = f"{path}/{name}" if path else name + if value.attrs.get("neurodata_type") == "TimeSeries": + result.append(current_path) + _find_timeseries_from_backend(value, current_path, result, backend) + return result + + +class NwbTimeSeriesExtractor(BaseRecording, _BaseNWBExtractor): + """Load a TimeSeries from an NWBFile as a RecordingExtractor. + + Parameters + ---------- + file_path : str | Path | None + Path to NWB file or an s3 URL. Use this parameter to specify the file location + if not using the `file` parameter. + timeseries_path : str | None + The path to the TimeSeries object within the NWB file. This parameter is required + when the NWB file contains multiple TimeSeries objects. The path corresponds to + the location within the NWB file hierarchy, e.g. 'acquisition/MyTimeSeries'. + load_time_vector : bool, default: False + If True, the time vector is loaded into the recording object. Useful when + precise timing information is needed. + samples_for_rate_estimation : int, default: 1000 + The number of timestamp samples used for estimating the sampling rate when + timestamps are used instead of a fixed rate. + stream_mode : Literal["fsspec", "remfile", "zarr"] | None, default: None + Determines the streaming mode for reading the file. + file : BinaryIO | None, default: None + A file-like object representing the NWB file. Use this parameter if you have + an in-memory representation of the NWB file instead of a file path. + cache : bool, default: False + If True, the file is cached locally when using streaming. + stream_cache_path : str | Path | None, default: None + Local path for caching. Only used if cache is True. + storage_options : dict | None, default: None + Additional kwargs (e.g. AWS credentials) passed to zarr.open. Only used with + "zarr" stream_mode. + use_pynwb : bool, default: False + If True, uses pynwb library to read the NWB file. Default False uses h5py/zarr + directly for better performance. + + Returns + ------- + recording : NwbTimeSeriesExtractor + A recording extractor containing the TimeSeries data. 
+ """ + + installation_mesg = "To use the Nwb extractors, install pynwb: \n\n pip install pynwb\n\n" + + def __init__( + self, + file_path: str | Path | None = None, + timeseries_path: str | None = None, + load_time_vector: bool = False, + samples_for_rate_estimation: int = 1_000, + stream_mode: Optional[Literal["fsspec", "remfile", "zarr"]] = None, + stream_cache_path: str | Path | None = None, + *, + file: BinaryIO | None = None, + cache: bool = False, + storage_options: dict | None = None, + use_pynwb: bool = False, + ): + if file_path is not None and file is not None: + raise ValueError("Provide either file_path or file, not both") + if file_path is None and file is None: + raise ValueError("Provide either file_path or file") + + self.file_path = file_path + self.stream_mode = stream_mode + self.stream_cache_path = stream_cache_path + self.storage_options = storage_options + self.timeseries_path = timeseries_path + + if self.stream_mode is None and file is None: + self.backend = _get_backend_from_local_file(file_path) + else: + self.backend = "zarr" if self.stream_mode == "zarr" else "hdf5" + + if use_pynwb: + try: + import pynwb + except ImportError: + raise ImportError(self.installation_mesg) + + channel_ids, sampling_frequency, dtype, segment_data, times_kwargs = self._fetch_recording_segment_info( + file, cache, load_time_vector, samples_for_rate_estimation + ) + else: + channel_ids, sampling_frequency, dtype, segment_data, times_kwargs = ( + self._fetch_recording_segment_info_backend(file, cache, load_time_vector, samples_for_rate_estimation) + ) + + BaseRecording.__init__(self, channel_ids=channel_ids, sampling_frequency=sampling_frequency, dtype=dtype) + recording_segment = NwbTimeSeriesSegment( + timeseries_data=segment_data, + times_kwargs=times_kwargs, + ) + self.add_recording_segment(recording_segment) + + if storage_options is not None and stream_mode == "zarr": + warnings.warn( + "The `storage_options` parameter will not be propagated to JSON or pickle files " + "for security reasons, so the extractor will not be JSON/pickle serializable." 
+            )
+            self._serializability["json"] = False
+            self._serializability["pickle"] = False
+
+        self._kwargs = {
+            "file_path": file_path,
+            "timeseries_path": self.timeseries_path,
+            "load_time_vector": load_time_vector,
+            "samples_for_rate_estimation": samples_for_rate_estimation,
+            "stream_mode": stream_mode,
+            "storage_options": storage_options,
+            "cache": cache,
+            "stream_cache_path": stream_cache_path,
+            "file": file,
+        }
+
+        if use_pynwb:
+            self.extra_requirements.append("pynwb")
+        else:
+            if self.backend == "hdf5":
+                self.extra_requirements.append("h5py")
+            if self.backend == "zarr":
+                self.extra_requirements.append("zarr")
+        if self.stream_mode == "fsspec":
+            self.extra_requirements.append("fsspec")
+        if self.stream_mode == "remfile":
+            self.extra_requirements.append("remfile")
+
+    def _fetch_recording_segment_info(self, file, cache, load_time_vector, samples_for_rate_estimation):
+        self._nwbfile = read_nwbfile(
+            backend=self.backend,
+            file_path=self.file_path,
+            file=file,
+            stream_mode=self.stream_mode,
+            cache=cache,
+            stream_cache_path=self.stream_cache_path,
+            storage_options=self.storage_options,
+        )
+
+        from pynwb.base import TimeSeries
+
+        time_series_dict: dict[str, TimeSeries] = {}
+
+        for item in self._nwbfile.all_children():
+            if isinstance(item, TimeSeries):
+                # e.g. "/acquisition/MyTimeSeries/data" -> "acquisition/MyTimeSeries"
+                time_series_dict[item.data.name.replace("/data", "")[1:]] = item
+
+        if self.timeseries_path is not None:
+            if self.timeseries_path not in time_series_dict:
+                raise ValueError(f"TimeSeries {self.timeseries_path} not found in file")
+        else:
+            if len(time_series_dict) == 1:
+                self.timeseries_path = list(time_series_dict.keys())[0]
+            else:
+                raise ValueError(
+                    f"Multiple TimeSeries found! Specify 'timeseries_path'. Options: {list(time_series_dict.keys())}"
+                )
+
+        timeseries = time_series_dict[self.timeseries_path]
+
+        # Get sampling frequency and timing info
+        if timeseries.rate is not None:
+            sampling_frequency = timeseries.rate
+            t_start = timeseries.starting_time if timeseries.starting_time is not None else 0.0
+            timestamps = None
+        elif timeseries.timestamps is not None:
+            timestamps = timeseries.timestamps
+            sampling_frequency = 1.0 / np.median(np.diff(timestamps[:samples_for_rate_estimation]))
+            t_start = timestamps[0]
+        else:
+            raise ValueError("TimeSeries must have either a rate or timestamps")
+
+        if load_time_vector and timestamps is not None:
+            times_kwargs = dict(time_vector=timestamps)
+        else:
+            times_kwargs = dict(sampling_frequency=sampling_frequency, t_start=t_start)
+
+        # Create channel IDs based on data shape
+        data = timeseries.data
+        if data.ndim == 1:
+            num_channels = 1
+        else:
+            num_channels = data.shape[1]
+        channel_ids = np.arange(num_channels)
+        dtype = data.dtype
+
+        return channel_ids, sampling_frequency, dtype, data, times_kwargs
+
+    def _fetch_recording_segment_info_backend(self, file, cache, load_time_vector, samples_for_rate_estimation):
+        open_file = read_file_from_backend(
+            file_path=self.file_path,
+            file=file,
+            stream_mode=self.stream_mode,
+            cache=cache,
+            stream_cache_path=self.stream_cache_path,
+            storage_options=self.storage_options,
+        )
+
+        # If timeseries_path not provided, find all TimeSeries objects
+        if self.timeseries_path is None:
+            available_timeseries = _find_timeseries_from_backend(open_file, backend=self.backend)
+            if len(available_timeseries) == 1:
+                self.timeseries_path = available_timeseries[0]
+            else:
+                raise ValueError(
+                    f"Multiple TimeSeries found! Specify 'timeseries_path'. Options: {available_timeseries}"
+                )
+
+        # Get TimeSeries object
+        try:
+            timeseries = open_file[self.timeseries_path]
+        except KeyError:
+            available_timeseries = _find_timeseries_from_backend(open_file, backend=self.backend)
+            raise ValueError(f"{self.timeseries_path} not found! Available options: {available_timeseries}")
+
+        # Get timing information
+        if "starting_time" in timeseries:
+            t_start = timeseries["starting_time"][()]
+            sampling_frequency = timeseries["starting_time"].attrs["rate"]
+            timestamps = None
+        elif "timestamps" in timeseries:
+            timestamps = timeseries["timestamps"][:]
+            sampling_frequency = 1.0 / np.median(np.diff(timestamps[:samples_for_rate_estimation]))
+            t_start = timestamps[0]
+        else:
+            raise ValueError("TimeSeries must have either starting_time or timestamps")
+
+        if load_time_vector and timestamps is not None:
+            times_kwargs = dict(time_vector=timestamps)
+        else:
+            times_kwargs = dict(sampling_frequency=sampling_frequency, t_start=t_start)
+
+        # Create channel IDs based on data shape
+        data = timeseries["data"]
+        if data.ndim == 1:
+            num_channels = 1
+        else:
+            num_channels = data.shape[1]
+        channel_ids = np.arange(num_channels)
+        dtype = data.dtype
+
+        # Store for later use
+        self.timeseries = timeseries
+        self._file = open_file
+
+        return channel_ids, sampling_frequency, dtype, data, times_kwargs
+
+    @staticmethod
+    def fetch_available_timeseries_paths(
+        file_path: str | Path,
+        stream_mode: Optional[Literal["fsspec", "remfile", "zarr"]] = None,
+        storage_options: dict | None = None,
+    ) -> list[str]:
+        """
+        Get paths to all TimeSeries objects in an NWB file.
+
+        Parameters
+        ----------
+        file_path : str | Path
+            Path to the NWB file.
+        stream_mode : str | None
+            Streaming mode for reading remote files.
+        storage_options : dict | None
+            Additional options for zarr storage.
+
+        Returns
+        -------
+        list[str]
+            List of paths to TimeSeries objects.
+        """
+        if stream_mode is None:
+            backend = _get_backend_from_local_file(file_path)
+        else:
+            backend = "zarr" if stream_mode == "zarr" else "hdf5"
+
+        file_handle = read_file_from_backend(
+            file_path=file_path,
+            stream_mode=stream_mode,
+            storage_options=storage_options,
+        )
+
+        timeseries_paths = _find_timeseries_from_backend(
+            file_handle,
+            backend=backend,
+        )
+        return timeseries_paths
+
+
+class NwbTimeSeriesSegment(BaseRecordingSegment):
+    """Segment class for NwbTimeSeriesExtractor."""
+
+    def __init__(self, timeseries_data, times_kwargs):
+        BaseRecordingSegment.__init__(self, **times_kwargs)
+        self.timeseries_data = timeseries_data
+        self._num_samples = self.timeseries_data.shape[0]
+
+    def get_num_samples(self):
+        """Returns the number of samples in this signal block."""
+        return self._num_samples
+
+    def get_traces(self, start_frame, end_frame, channel_indices):
+        """
+        Extract traces from the TimeSeries between start_frame and end_frame for specified channels.
+
+        Parameters
+        ----------
+        start_frame : int
+            Start frame of the slice to extract.
+        end_frame : int
+            End frame of the slice to extract.
+        channel_indices : array-like
+            Channel indices to extract.
+
+        Returns
+        -------
+        traces : np.ndarray
+            Extracted traces of shape (num_frames, num_channels)
+        """
+        if self.timeseries_data.ndim == 1:
+            traces = self.timeseries_data[start_frame:end_frame][:, np.newaxis]
+        elif isinstance(channel_indices, slice):
+            traces = self.timeseries_data[start_frame:end_frame, channel_indices]
+        else:
+            # channel_indices is np.ndarray
+            if np.array(channel_indices).size > 1 and np.any(np.diff(channel_indices) < 0):
+                # get around h5py constraint that it does not allow datasets
+                # to be indexed out of order
+                sorted_channel_indices = np.sort(channel_indices)
+                resorted_indices = np.array([list(sorted_channel_indices).index(ch) for ch in channel_indices])
+                recordings = self.timeseries_data[start_frame:end_frame, sorted_channel_indices]
+                traces = recordings[:, resorted_indices]
+            else:
+                traces = self.timeseries_data[start_frame:end_frame, channel_indices]
+
+        return traces
+
+
+# Create the reading function
+
+
 read_nwb_recording = define_function_from_class(source_class=NwbRecordingExtractor, name="read_nwb_recording")
 read_nwb_sorting = define_function_from_class(source_class=NwbSortingExtractor, name="read_nwb_sorting")
+read_nwb_timeseries = define_function_from_class(source_class=NwbTimeSeriesExtractor, name="read_nwb_timeseries")
 
 
 def read_nwb(file_path, load_recording=True, load_sorting=False, electrical_series_path=None):
diff --git a/src/spikeinterface/extractors/tests/test_nwbextractors.py b/src/spikeinterface/extractors/tests/test_nwbextractors.py
index b698f7dfe1..4bfc43dd69 100644
--- a/src/spikeinterface/extractors/tests/test_nwbextractors.py
+++ b/src/spikeinterface/extractors/tests/test_nwbextractors.py
@@ -5,7 +5,7 @@
 import pytest
 import numpy as np
 
-from spikeinterface.extractors import NwbRecordingExtractor, NwbSortingExtractor
+from spikeinterface.extractors import NwbRecordingExtractor, NwbSortingExtractor, NwbTimeSeriesExtractor
 from spikeinterface.extractors.tests.common_tests import RecordingCommonTestSuite, SortingCommonTestSuite
 from spikeinterface.core.testing import check_recordings_equal
 
@@ -606,6 +606,196 @@ def test_multiple_unit_tables(tmp_path, use_pynwb):
     assert "a_second_property" in sorting_extractor_processing.get_property_keys()
 
 
+def nwbfile_with_timeseries():
+    from pynwb.testing.mock.file import mock_NWBFile
+    from pynwb.base import TimeSeries
+
+    nwbfile = mock_NWBFile()
+
+    # Add regular TimeSeries with rate
+    num_frames = 10
+    num_channels = 5
+    rng = np.random.default_rng(0)
+    data = rng.random(size=(num_frames, num_channels))
+    rate = 30_000.0
+    starting_time = 0.0
+
+    timeseries = TimeSeries(name="TimeSeries", data=data, rate=rate, starting_time=starting_time, unit="volts")
+    nwbfile.add_acquisition(timeseries)
+
+    # Add TimeSeries with timestamps
+    timestamps = np.arange(num_frames) / rate
+    # introduce one irregular timestamp; the extractor estimates the rate from the median diff
+    timestamps[2] = 0
+    timeseries_with_timestamps = TimeSeries(
+        name="TimeSeriesWithTimestamps", data=data, timestamps=timestamps, unit="volts"
+    )
+    nwbfile.add_acquisition(timeseries_with_timestamps)
+
+    # Add single channel TimeSeries
+    single_channel_data = rng.random(size=(num_frames,))
+    single_channel_series = TimeSeries(name="SingleChannelSeries", data=single_channel_data, rate=rate, unit="volts")
+    nwbfile.add_acquisition(single_channel_series)
+
+    # Add TimeSeries in processing module
+    processing = nwbfile.create_processing_module(name="test_module", description="test module")
+    proc_timeseries = TimeSeries(name="ProcessingTimeSeries", data=data, rate=rate, unit="volts")
+    processing.add(proc_timeseries)
+
+    return nwbfile
+
+
+def _generate_nwbfile_with_time_series(backend, file_path):
+    from pynwb import NWBHDF5IO
+    from hdmf_zarr import NWBZarrIO
+
+    nwbfile = nwbfile_with_timeseries()
+    if backend == "hdf5":
+        io_class = NWBHDF5IO
+    elif backend == "zarr":
+        io_class = NWBZarrIO
+    with io_class(str(file_path), mode="w") as io:
+        io.write(nwbfile)
+    return file_path, nwbfile
+
+
+@pytest.fixture(scope="module", params=["hdf5", "zarr"])
+def generate_nwbfile_with_time_series(request, tmp_path_factory):
+    backend = request.param
+    nwbfile_path = tmp_path_factory.mktemp("nwb_tests_directory") / "test.nwb"
+    nwbfile_path, nwbfile = _generate_nwbfile_with_time_series(backend, nwbfile_path)
+    return nwbfile_path, nwbfile
+
+
+@pytest.mark.parametrize("use_pynwb", [True, False])
+def test_timeseries_basic_functionality(generate_nwbfile_with_time_series, use_pynwb):
+    """Test basic functionality with a regular TimeSeries."""
+    path_to_nwbfile, nwbfile = generate_nwbfile_with_time_series
+
+    recording = NwbTimeSeriesExtractor(path_to_nwbfile, timeseries_path="acquisition/TimeSeries", use_pynwb=use_pynwb)
+
+    timeseries = nwbfile.acquisition["TimeSeries"]
+
+    # Check data matches
+    assert np.array_equal(recording.get_traces(), timeseries.data[:])
+
+    # Check sampling frequency matches
+    assert recording.get_sampling_frequency() == timeseries.rate
+
+    # Check number of channels matches
+    assert recording.get_num_channels() == timeseries.data.shape[1]
+
+
+@pytest.mark.parametrize("use_pynwb", [True, False])
+def test_timeseries_with_timestamps(generate_nwbfile_with_time_series, use_pynwb):
+    """Test functionality with a TimeSeries using timestamps."""
+    path_to_nwbfile, nwbfile = generate_nwbfile_with_time_series
+
+    recording = NwbTimeSeriesExtractor(
+        path_to_nwbfile, timeseries_path="acquisition/TimeSeriesWithTimestamps", use_pynwb=use_pynwb
+    )
+
+    timeseries = nwbfile.acquisition["TimeSeriesWithTimestamps"]
+
+    # Check data matches
+    assert np.array_equal(recording.get_traces(), timeseries.data[:])
+
+    # Check sampling frequency is correctly estimated
+    expected_sampling_frequency = 1.0 / np.median(np.diff(timeseries.timestamps[:1000]))
+    assert np.isclose(recording.get_sampling_frequency(), expected_sampling_frequency)
+
+
+@pytest.mark.parametrize("use_pynwb", [True, False])
+def test_time_series_load_time_vector(generate_nwbfile_with_time_series, use_pynwb):
+    """Test loading time vector from TimeSeries with timestamps."""
+    path_to_nwbfile, nwbfile = generate_nwbfile_with_time_series
+
+    recording = NwbTimeSeriesExtractor(
+        path_to_nwbfile,
+        timeseries_path="acquisition/TimeSeriesWithTimestamps",
+        load_time_vector=True,
+        use_pynwb=use_pynwb,
+    )
+
+    timeseries = nwbfile.acquisition["TimeSeriesWithTimestamps"]
+
+    times = recording.get_times()
+
+    np.testing.assert_almost_equal(times, timeseries.timestamps[:])
+
+
+@pytest.mark.parametrize("use_pynwb", [True, False])
+def test_single_channel_timeseries(generate_nwbfile_with_time_series, use_pynwb):
+    """Test functionality with a single channel TimeSeries."""
+    path_to_nwbfile, nwbfile = generate_nwbfile_with_time_series
+
+    recording = NwbTimeSeriesExtractor(
+        path_to_nwbfile, timeseries_path="acquisition/SingleChannelSeries", use_pynwb=use_pynwb
+    )
+
+    timeseries = nwbfile.acquisition["SingleChannelSeries"]
+
+    # Check data matches
+    assert np.array_equal(recording.get_traces().squeeze(), timeseries.data[:])
+
+    # Check it's treated as a single channel
+    assert recording.get_num_channels() == 1
+
+
+@pytest.mark.parametrize("use_pynwb", [True, False]) +def test_processing_module_timeseries(generate_nwbfile_with_time_series, use_pynwb): + """Test accessing TimeSeries from a processing module.""" + path_to_nwbfile, nwbfile = generate_nwbfile_with_time_series + + recording = NwbTimeSeriesExtractor( + path_to_nwbfile, timeseries_path="processing/test_module/ProcessingTimeSeries", use_pynwb=use_pynwb + ) + + timeseries = nwbfile.processing["test_module"]["ProcessingTimeSeries"] + + # Check data matches + assert np.array_equal(recording.get_traces(), timeseries.data[:]) + + +def test_fetch_available_timeseries_paths(generate_nwbfile_with_time_series): + """Test the fetch_available_timeseries_paths static method.""" + path_to_nwbfile, _ = generate_nwbfile_with_time_series + + available_timeseries = NwbTimeSeriesExtractor.fetch_available_timeseries_paths(file_path=path_to_nwbfile) + + expected_paths = [ + "acquisition/TimeSeries", + "acquisition/TimeSeriesWithTimestamps", + "acquisition/SingleChannelSeries", + "processing/test_module/ProcessingTimeSeries", + ] + + assert sorted(available_timeseries) == sorted(expected_paths) + + +@pytest.mark.parametrize("use_pynwb", [True, False]) +def test_error_with_wrong_timeseries_path(generate_nwbfile_with_time_series, use_pynwb): + """Test that appropriate error is raised for non-existent TimeSeries.""" + path_to_nwbfile, _ = generate_nwbfile_with_time_series + + with pytest.raises(ValueError): + _ = NwbTimeSeriesExtractor( + path_to_nwbfile, timeseries_path="acquisition/NonExistentTimeSeries", use_pynwb=use_pynwb + ) + + +def test_time_series_recording_equality_with_pynwb_and_backend(generate_nwbfile_with_time_series): + """Test that pynwb and backend (h5py/zarr) modes produce identical results.""" + path_to_nwbfile, _ = generate_nwbfile_with_time_series + + recording_backend = NwbTimeSeriesExtractor( + path_to_nwbfile, timeseries_path="acquisition/TimeSeries", use_pynwb=False + ) + + recording_pynwb = NwbTimeSeriesExtractor(path_to_nwbfile, timeseries_path="acquisition/TimeSeries", use_pynwb=True) + + check_recordings_equal(recording_backend, recording_pynwb) + + if __name__ == "__main__": tmp_path = Path("tmp") if tmp_path.is_dir():
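
-- 
Reviewer note: a minimal usage sketch of the API this patch adds, assuming a
hypothetical local NWB file "my_file.nwb" that contains more than one TimeSeries:

    from spikeinterface.extractors import NwbTimeSeriesExtractor

    # Discover every TimeSeries path stored in the file
    paths = NwbTimeSeriesExtractor.fetch_available_timeseries_paths(file_path="my_file.nwb")
    print(paths)  # e.g. ["acquisition/MyTimeSeries", "processing/my_module/MySeries"]

    # timeseries_path is required whenever the file holds several TimeSeries
    recording = NwbTimeSeriesExtractor(
        file_path="my_file.nwb",
        timeseries_path="acquisition/MyTimeSeries",
    )
    traces = recording.get_traces(start_frame=0, end_frame=100)

The same extractor is also exposed as read_nwb_timeseries via define_function_from_class.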