diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 33d384b2be..729830f3a5 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -13,7 +13,7 @@ from spikeinterface.preprocessing import common_reference, whiten, bandpass_filter, correct_motion from spikeinterface.sortingcomponents.tools import ( cache_preprocessing, - get_prototype_and_waveforms, + get_prototype_and_waveforms_from_recording, get_shuffled_recording_slices, ) from spikeinterface.core.basesorting import minimum_spike_dtype @@ -206,7 +206,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): np.save(clustering_folder / "noise_levels.npy", noise_levels) if params["matched_filtering"]: - prototype, waveforms = get_prototype_and_waveforms( + prototype, waveforms, _ = get_prototype_and_waveforms_from_recording( recording_w, ms_before=ms_before, ms_after=ms_after, @@ -222,11 +222,13 @@ np.save(clustering_folder / "prototype.npy", prototype) if skip_peaks: detection_params["skip_after_n_peaks"] = n_peaks + detection_params["recording_slices"] = get_shuffled_recording_slices(recording_w, seed=params["seed"], **job_kwargs) peaks = detect_peaks(recording_w, "matched_filtering", **detection_params, **job_kwargs) else: waveforms = None if skip_peaks: detection_params["skip_after_n_peaks"] = n_peaks + detection_params["recording_slices"] = get_shuffled_recording_slices(recording_w, seed=params["seed"], **job_kwargs) peaks = detect_peaks(recording_w, "locally_exclusive", **detection_params, **job_kwargs) if verbose: diff --git a/src/spikeinterface/sortingcomponents/peak_localization.py b/src/spikeinterface/sortingcomponents/peak_localization.py index 85ffab1bf5..4024ce81d3 100644 --- a/src/spikeinterface/sortingcomponents/peak_localization.py +++ 
b/src/spikeinterface/sortingcomponents/peak_localization.py @@ -33,7 +33,7 @@ get_grid_convolution_templates_and_weights, ) -from .tools import get_prototype_and_waveforms +from .tools import get_prototype_and_waveforms_from_peaks def get_localization_pipeline_nodes( @@ -73,7 +73,7 @@ def get_localization_pipeline_nodes( assert isinstance(peak_source, (PeakRetriever, SpikeRetriever)) # extract prototypes silently job_kwargs["progress_bar"] = False - method_kwargs["prototype"] = get_prototype_and_waveforms( + method_kwargs["prototype"], _, _ = get_prototype_and_waveforms_from_peaks( recording, peaks=peak_source.peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs ) extract_dense_waveforms = ExtractDenseWaveforms( diff --git a/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py b/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py index b6ef240f80..284bbb9ac0 100644 --- a/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py +++ b/src/spikeinterface/sortingcomponents/tests/test_peak_detection.py @@ -22,7 +22,7 @@ ) from spikeinterface.core.node_pipeline import run_node_pipeline -from spikeinterface.sortingcomponents.tools import get_prototype_and_waveforms +from spikeinterface.sortingcomponents.tools import get_prototype_and_waveforms_from_peaks from spikeinterface.sortingcomponents.tests.common import make_dataset @@ -314,7 +314,7 @@ def test_detect_peaks_locally_exclusive_matched_filtering(recording, job_kwargs) ms_before = 1.0 ms_after = 1.0 - prototype = get_prototype_and_waveforms( + prototype, _, _ = get_prototype_and_waveforms_from_peaks( recording, peaks=peaks_by_channel_np, ms_before=ms_before, ms_after=ms_after, **job_kwargs ) diff --git a/src/spikeinterface/sortingcomponents/tools.py b/src/spikeinterface/sortingcomponents/tools.py index 3c968f5da4..fb3dae669d 100644 --- a/src/spikeinterface/sortingcomponents/tools.py +++ b/src/spikeinterface/sortingcomponents/tools.py @@ -69,29 +69,24 @@ def 
extract_waveform_at_max_channel(rec, peaks, ms_before=0.5, ms_after=1.5, **j return all_wfs -def get_prototype_and_waveforms( - recording, n_peaks=5000, peaks=None, ms_before=0.5, ms_after=0.5, seed=None, return_waveforms=False, **all_kwargs -): +def get_prototype_and_waveforms_from_peaks(recording, peaks, n_peaks=5000, ms_before=0.5, ms_after=0.5, seed=None, **all_kwargs): """ - Function to extract a prototype waveform from a peak list or from a peak detection. Note that in case - of a peak detection, the detection stops as soon as n_peaks are detected. + Function to extract a prototype waveform from peaks. Parameters ---------- recording : Recording The recording object containing the data. - n_peaks : int, optional - Number of peaks to consider, by default 5000. peaks : numpy.array, optional Array of peaks, if None, peaks will be detected, by default None. + n_peaks : int, optional + Number of peaks to consider, by default 5000. ms_before : float, optional Time in milliseconds before the peak to extract the waveform, by default 0.5. ms_after : float, optional Time in milliseconds after the peak to extract the waveform, by default 0.5. seed : int or None, optional Seed for random number generator, by default None. - return_waveforms : bool, optional - Whether to return the waveforms along with the prototype, by default False. **all_kwargs : dict Additional keyword arguments for peak detection and job kwargs. @@ -99,60 +94,134 @@ def get_prototype_and_waveforms( ------- prototype : numpy.array The prototype waveform. - waveforms : numpy.array, optional - The extracted waveforms, returned if return_waveforms is True. + waveforms : numpy.array + The extracted waveforms for the selected peaks. + peaks : numpy.array + The selected peaks used to extract waveforms. 
""" + from spikeinterface.sortingcomponents.peak_selection import select_peaks + _, job_kwargs = split_job_kwargs(all_kwargs) - seed = seed if seed else None - rng = np.random.default_rng(seed=seed) + nbefore = int(ms_before * recording.sampling_frequency / 1000.0) + nafter = int(ms_after * recording.sampling_frequency / 1000.0) + + few_peaks = select_peaks( + peaks, recording=recording, method="uniform", n_peaks=n_peaks, margin=(nbefore, nafter), seed=seed + ) + waveforms = extract_waveform_at_max_channel( + recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs + ) + with np.errstate(divide="ignore", invalid="ignore"): + prototype = np.nanmedian(waveforms[:, :, 0] / (np.abs(waveforms[:, nbefore, 0][:, np.newaxis])), axis=0) + + return prototype, waveforms[:, :, 0], few_peaks + + +def get_prototype_and_waveforms_from_recording(recording, n_peaks=5000, ms_before=0.5, ms_after=0.5, seed=None, **all_kwargs): + """ + Function to extract a prototype waveform from peaks detected on the fly. + + Parameters + ---------- + recording : Recording + The recording object containing the data. + n_peaks : int, optional + Number of peaks to consider, by default 5000. + ms_before : float, optional + Time in milliseconds before the peak to extract the waveform, by default 0.5. + ms_after : float, optional + Time in milliseconds after the peak to extract the waveform, by default 0.5. + seed : int or None, optional + Seed for random number generator, by default None. + **all_kwargs : dict + Additional keyword arguments for peak detection and job kwargs. + + Returns + ------- + prototype : numpy.array + The prototype waveform. + waveforms : numpy.array + The extracted waveforms for the selected peaks. + peaks : numpy.array + The selected peaks used to extract waveforms. 
+ """ + from spikeinterface.sortingcomponents.peak_detection import detect_peaks + from spikeinterface.core.node_pipeline import ExtractSparseWaveforms detection_kwargs, job_kwargs = split_job_kwargs(all_kwargs) + nbefore = int(ms_before * recording.sampling_frequency / 1000.0) - nafter = int(ms_after * recording.sampling_frequency / 1000.0) + node = ExtractSparseWaveforms( + recording, + parents=None, + return_output=True, + ms_before=ms_before, + ms_after=ms_after, + radius_um=0, + ) + + pipeline_nodes = [node] - if peaks is None: - from spikeinterface.sortingcomponents.peak_detection import detect_peaks - from spikeinterface.core.node_pipeline import ExtractSparseWaveforms - - node = ExtractSparseWaveforms( - recording, - parents=None, - return_output=True, - ms_before=ms_before, - ms_after=ms_after, - radius_um=0, - ) - pipeline_nodes = [node] - - recording_slices = get_shuffled_recording_slices(recording, seed=seed, **job_kwargs) - - res = detect_peaks( - recording, - pipeline_nodes=pipeline_nodes, - skip_after_n_peaks=n_peaks, - recording_slices=recording_slices, - **detection_kwargs, - **job_kwargs, - ) - waveforms = res[1] - else: - from spikeinterface.sortingcomponents.peak_selection import select_peaks + recording_slices = get_shuffled_recording_slices(recording, seed=seed, **job_kwargs) - few_peaks = select_peaks( - peaks, recording=recording, method="uniform", n_peaks=n_peaks, margin=(nbefore, nafter), seed=seed - ) - waveforms = extract_waveform_at_max_channel( - recording, few_peaks, ms_before=ms_before, ms_after=ms_after, **job_kwargs - ) + res = detect_peaks( + recording, + pipeline_nodes=pipeline_nodes, + skip_after_n_peaks=n_peaks, + recording_slices=recording_slices, + **detection_kwargs, + **job_kwargs, + ) + + rng = np.random.RandomState(seed) + indices = rng.permutation(np.arange(len(res[0]))) + + few_peaks = res[0][indices[:n_peaks]] + waveforms = res[1][indices[:n_peaks]] with np.errstate(divide="ignore", invalid="ignore"): prototype = 
np.nanmedian(waveforms[:, :, 0] / (np.abs(waveforms[:, nbefore, 0][:, np.newaxis])), axis=0) - if not return_waveforms: - return prototype - else: - return prototype, waveforms[:, :, 0] + return prototype, waveforms[:, :, 0], few_peaks + + +def get_prototype_and_waveforms( + recording, n_peaks=5000, peaks=None, ms_before=0.5, ms_after=0.5, seed=None, **all_kwargs +): + """ + Function to extract a prototype waveform either from peaks or from a peak detection. Note that in case + of a peak detection, the detection stops as soon as n_peaks are detected. + Parameters + ---------- + recording : Recording + The recording object containing the data. + n_peaks : int, optional + Number of peaks to consider, by default 5000. + peaks : numpy.array, optional + Array of peaks, if None, peaks will be detected, by default None. + ms_before : float, optional + Time in milliseconds before the peak to extract the waveform, by default 0.5. + ms_after : float, optional + Time in milliseconds after the peak to extract the waveform, by default 0.5. + seed : int or None, optional + Seed for random number generator, by default None. + (The former ``return_waveforms`` flag has been removed; the waveforms + and the selected peaks are now always returned.) + **all_kwargs : dict + Additional keyword arguments for peak detection and job kwargs. + + Returns + ------- + prototype : numpy.array + The prototype waveform. + waveforms : numpy.array + The extracted waveforms; the selected peaks are returned as a third output. + """ + if peaks is None: + return get_prototype_and_waveforms_from_recording(recording, n_peaks, ms_before=ms_before, ms_after=ms_after, seed=seed, **all_kwargs) + else: + return get_prototype_and_waveforms_from_peaks(recording, peaks, n_peaks, ms_before=ms_before, ms_after=ms_after, seed=seed, **all_kwargs) def check_probe_for_drift_correction(recording, dist_x_max=60): num_channels = recording.get_num_channels()