diff --git a/doc/get_started/install_sorters.rst b/doc/get_started/install_sorters.rst index 12233784fd..61abc0f129 100644 --- a/doc/get_started/install_sorters.rst +++ b/doc/get_started/install_sorters.rst @@ -43,9 +43,10 @@ Herdingspikes2 * Python + C++ * Url: https://github.com/mhhennig/hs2 -* Authors: Matthias Hennig, Jano Horvath,Cole Hurwitz, Oliver Muthmann, Albert Puente Encinas, Martino Sorbaro, Cesar Juarez Ramirez, Raimon Wintzer: GUI and visualisation +* Authors: Matthias Hennig, Jano Horvath, Cole Hurwitz, Rickey K. Liang, Oliver Muthmann, Albert Puente Encinas, Martino Sorbaro, Cesar Juarez Ramirez, Raimon Wintzer * Installation:: + pip install cython numpy pip install herdingspikes diff --git a/src/spikeinterface/comparison/multicomparisons.py b/src/spikeinterface/comparison/multicomparisons.py index 35c298d4ac..499004e32e 100644 --- a/src/spikeinterface/comparison/multicomparisons.py +++ b/src/spikeinterface/comparison/multicomparisons.py @@ -186,7 +186,7 @@ def save_to_folder(self, save_folder): warnings.warn( "save_to_folder() is deprecated. " "You should save and load the multi sorting comparison object using pickle." - "\n>>> pickle.dump(mcmp, open('mcmp.pkl', 'wb')))))\n>>> mcmp_loaded = pickle.load(open('mcmp.pkl', 'rb'))", + "\n>>> pickle.dump(mcmp, open('mcmp.pkl', 'wb'))\n>>> mcmp_loaded = pickle.load(open('mcmp.pkl', 'rb'))", DeprecationWarning, stacklevel=2, ) @@ -218,7 +218,7 @@ def load_from_folder(folder_path): warnings.warn( "load_from_folder() is deprecated. " "You should save and load the multi sorting comparison object using pickle." - "\n>>> pickle.dump(mcmp, open('mcmp.pkl', 'wb')))))\n>>> mcmp_loaded = pickle.load(open('mcmp.pkl', 'rb'))", + "\n>>> pickle.dump(mcmp, open('mcmp.pkl', 'wb'))\n>>> mcmp_loaded = pickle.load(open('mcmp.pkl', 'rb'))", DeprecationWarning, stacklevel=2, ) diff --git a/src/spikeinterface/sorters/external/herdingspikes.py b/src/spikeinterface/sorters/external/herdingspikes.py index 9b9b751534..a84c05c240 100644 --- a/src/spikeinterface/sorters/external/herdingspikes.py +++ b/src/spikeinterface/sorters/external/herdingspikes.py @@ -1,13 +1,10 @@ from __future__ import annotations from pathlib import Path -import copy from packaging import version from ..basesorter import BaseSorter -from spikeinterface.core.old_api_utils import NewToOldRecording -from spikeinterface.core import load_extractor from spikeinterface.extractors import HerdingspikesSortingExtractor @@ -19,90 +16,72 @@ class HerdingspikesSorter(BaseSorter): requires_locations = True compatible_with_parallel = {"loky": True, "multiprocessing": True, "threading": False} _default_params = { - # core params - "clustering_bandwidth": 5.5, # 5.0, - "clustering_alpha": 5.5, # 5.0, + "chunk_size": None, + "rescale": True, + "rescale_value": -1280.0, + "common_reference": "median", + "spike_duration": 1.0, + "amp_avg_duration": 0.4, + "threshold": 8.0, + "min_avg_amp": 1.0, + "AHP_thr": 0.0, + "neighbor_radius": 90.0, + "inner_radius": 70.0, + "peak_jitter": 0.25, + "rise_duration": 0.26, + "decay_filtering": False, + "decay_ratio": 1.0, + "localize": True, + "save_shape": True, + "out_file": "HS2_detected", + "left_cutout_time": 0.3, + "right_cutout_time": 1.8, + "verbose": True, + "clustering_bandwidth": 4.0, + "clustering_alpha": 4.5, "clustering_n_jobs": -1, "clustering_bin_seeding": True, - "clustering_min_bin_freq": 16, # 10, + "clustering_min_bin_freq": 4, "clustering_subset": None, - "left_cutout_time": 0.3, # 0.2, - "right_cutout_time": 1.8, # 0.8, - "detect_threshold": 
20, # 24, #15,
-        # extra probe params
-        "probe_masked_channels": [],
-        "probe_inner_radius": 70,
-        "probe_neighbor_radius": 90,
-        "probe_event_length": 0.26,
-        "probe_peak_jitter": 0.2,
-        # extra detection params
-        "t_inc": 100000,
-        "num_com_centers": 1,
-        "maa": 12,
-        "ahpthr": 11,
-        "out_file_name": "HS2_detected",
-        "decay_filtering": False,
-        "save_all": False,
-        "amp_evaluation_time": 0.4, # 0.14,
-        "spk_evaluation_time": 1.0,
-        # extra pca params
         "pca_ncomponents": 2,
         "pca_whiten": True,
-        # bandpass filter
-        "freq_min": 300.0,
-        "freq_max": 6000.0,
-        "filter": True,
-        # rescale traces
-        "pre_scale": True,
-        "pre_scale_value": 20.0,
-        # remove duplicates (based on spk_evaluation_time)
-        "filter_duplicates": True,
     }

     _params_description = {
-        # core params
-        "clustering_bandwidth": "Meanshift bandwidth, average spatial extent of spike clusters (um)",
-        "clustering_alpha": "Scalar for the waveform PC features when clustering.",
-        "clustering_n_jobs": "Number of cores to use for clustering.",
-        "clustering_bin_seeding": "Enable clustering bin seeding.",
-        "clustering_min_bin_freq": "Minimum spikes per bin for bin seeding.",
-        "clustering_subset": "Number of spikes used to build clusters. All by default.",
-        "left_cutout_time": "Cutout size before peak (ms).",
-        "right_cutout_time": "Cutout size after peak (ms).",
-        "detect_threshold": "Detection threshold",
-        # extra probe params
-        "probe_masked_channels": "Masked channels",
-        "probe_inner_radius": "Radius of area around probe channel for localization",
-        "probe_neighbor_radius": "Radius of area around probe channel for neighbor classification.",
-        "probe_event_length": "Duration of a spike event (ms)",
-        "probe_peak_jitter": "Maximum peak misalignment for synchronous spike (ms)",
-        # extra detection params
-        "t_inc": "Number of samples per chunk during detection.",
-        "num_com_centers": "Number of centroids to average when localizing.",
-        "maa": "Minimum summed spike amplitude for spike acceptance.",
-        "ahpthr": "Requires magnitude of spike rebound for acceptance",
-        "out_file_name": "File name for storage of unclustered detected spikes",
-        "decay_filtering": "Experimental: Set to True at your risk",
-        "save_all": "Save all working files after sorting (slow)",
-        "amp_evaluation_time": "Amplitude evaluation time (ms)",
-        "spk_evaluation_time": "Spike evaluation time (ms)",
-        # extra pca params
-        "pca_ncomponents": "Number of principal components to use when clustering",
-        "pca_whiten": "If true, whiten data for pca",
-        # bandpass filter
-        "freq_min": "High-pass filter cutoff frequency",
-        "freq_max": "Low-pass filter cutoff frequency",
-        "filter": "Enable or disable filter",
-        # rescale traces
-        "pre_scale": "Scales recording traces to optimize HerdingSpikes performance",
-        "pre_scale_value": "Scale to apply in case of pre-scaling of traces",
-        # remove duplicates (based on spk_evaluation_time)
-        "filter_duplicates": "Remove spike duplicates (based on spk_evaluation_time)",
+        "localize": "Perform spike localization. (`bool`, `True`)",
+        "save_shape": "Save spike shape. (`bool`, `True`)",
+        "out_file": "Path and filename to store detection and clustering results. (`str`, `HS2_detected`)",
+        "verbose": "Print progress information. (`bool`, `True`)",
+        "chunk_size": "Number of samples per chunk during detection. If `None`, a suitable value will be estimated. (`int`, `None`)",
+        "common_reference": "Method for common reference filtering, can be `average` or `median`. (`str`, `median`)",
+        "rescale": "Automatically re-scale the data. (`bool`, `True`)",
+        "rescale_value": "Factor by which data is re-scaled. (`float`, `-1280.0`)",
+        "threshold": "Spike detection threshold. (`float`, `8.0`)",
+        "spike_duration": "Maximum duration over which a spike is evaluated (ms). (`float`, `1.0`)",
+        "amp_avg_duration": "Maximum duration over which the spike amplitude is evaluated (ms). (`float`, `0.4`)",
+        "min_avg_amp": "Minimum integrated spike amplitude for a true spike. (`float`, `1.0`)",
+        "AHP_thr": "Minimum value of the spike repolarization for a true spike. (`float`, `0.0`)",
+        "neighbor_radius": "Radius of area around probe channel for neighbor classification (microns). (`float`, `90.0`)",
+        "inner_radius": "Radius of area around probe channel for spike localization (microns). (`float`, `70.0`)",
+        "peak_jitter": "Maximum peak misalignment for synchronous spike (ms). (`float`, `0.25`)",
+        "rise_duration": "Maximum spike rise time (ms). (`float`, `0.26`)",
+        "decay_filtering": "Exclude duplicate spikes based on spatial decay pattern, experimental. (`bool`, `False`)",
+        "decay_ratio": "Spatial decay rate for `decay_filtering`. (`float`, `1.0`)",
+        "left_cutout_time": "Length of cutout before peak (ms). (`float`, `0.3`)",
+        "right_cutout_time": "Length of cutout after peak (ms). (`float`, `1.8`)",
+        "pca_ncomponents": "Number of principal components to use when clustering. (`int`, `2`)",
+        "pca_whiten": "If `True`, whiten data for PCA. (`bool`, `True`)",
+        "clustering_bandwidth": "Meanshift bandwidth, average spatial extent of spike clusters (microns). (`float`, `4.0`)",
+        "clustering_alpha": "Scalar for the waveform PC features when clustering. (`float`, `4.5`)",
+        "clustering_n_jobs": "Number of cores to use for clustering, use `-1` for all available cores. (`int`, `-1`)",
+        "clustering_bin_seeding": "Enable clustering bin seeding. (`bool`, `True`)",
+        "clustering_min_bin_freq": "Minimum spikes per bin for bin seeding. (`int`, `4`)",
+        "clustering_subset": "Number of spikes used to build clusters. All by default. (`int`, `None`)",
     }

-    sorter_description = """Herding Spikes is a density-based spike sorter designed for high-density retinal recordings.
+    sorter_description = """Herding Spikes is a density-based spike sorter designed for large-scale high-density recordings.
     It uses both PCA features and an estimate of the spike location to cluster different units.
-    For more information see https://doi.org/10.1016/j.jneumeth.2016.06.006"""
+    For more information see https://www.sciencedirect.com/science/article/pii/S221112471730236X"""

     installation_mesg = """\nTo use HerdingSpikes run:\n
        >>> pip install herdingspikes

@@ -130,100 +109,50 @@ def get_sorter_version(cls):

     @classmethod
     def _check_apply_filter_in_params(cls, params):
-        return params["filter"]
+        return False

     @classmethod
     def _setup_recording(cls, recording, sorter_output_folder, params, verbose):
-        # nothing to copy inside the folder : Herdingspikes used natively spikeinterface
+        # nothing to copy inside the folder: HerdingSpikes uses SpikeInterface natively
         pass

     @classmethod
     def _run_from_folder(cls, sorter_output_folder, params, verbose):
         import herdingspikes as hs
-        from spikeinterface.preprocessing import bandpass_filter, normalize_by_quantile

         hs_version = version.parse(hs.__version__)

-        if hs_version >= version.parse("0.3.99"):
-            new_api = True
+        if hs_version >= version.parse("0.4.001"):
+            lightning_api = True
         else:
-            new_api = False
+            lightning_api = False
+
+        assert (
+            lightning_api
+        ), "HerdingSpikes version <0.4.001 is no longer supported. To upgrade, run:\n>>> pip install --upgrade herdingspikes"

         recording = cls.load_recording_from_folder(sorter_output_folder.parent, with_warnings=False)

+        sorted_file = str(sorter_output_folder / "HS2_sorted.hdf5")
+        params["out_file"] = str(sorter_output_folder / "HS2_detected")
         p = params

-        # Bandpass filter
-        if p["filter"] and p["freq_min"] is not None and p["freq_max"] is not None:
-            recording = bandpass_filter(recording=recording, freq_min=p["freq_min"], freq_max=p["freq_max"])
-
-        if p["pre_scale"]:
-            recording = normalize_by_quantile(
-                recording=recording, scale=p["pre_scale_value"], median=0.0, q1=0.05, q2=0.95
-            )
-
-        if new_api:
-            recording_to_hs = recording
-        else:
-            print(
-                "herdingspikes version<0.3.99 uses the OLD spikeextractors with NewToOldRecording.\n"
-                "Consider updating herdingspikes (pip install herdingspikes>=0.3.99)"
-            )
-            recording_to_hs = NewToOldRecording(recording)
-
-        # this should have its name changed
-        Probe = hs.probe.RecordingExtractor(
-            recording_to_hs,
-            masked_channels=p["probe_masked_channels"],
-            inner_radius=p["probe_inner_radius"],
-            neighbor_radius=p["probe_neighbor_radius"],
-            event_length=p["probe_event_length"],
-            peak_jitter=p["probe_peak_jitter"],
-        )
-
-        H = hs.HSDetection(
-            Probe,
-            file_directory_name=str(sorter_output_folder),
-            left_cutout_time=p["left_cutout_time"],
-            right_cutout_time=p["right_cutout_time"],
-            threshold=p["detect_threshold"],
-            to_localize=True,
-            num_com_centers=p["num_com_centers"],
-            maa=p["maa"],
-            ahpthr=p["ahpthr"],
-            out_file_name=p["out_file_name"],
-            decay_filtering=p["decay_filtering"],
-            save_all=p["save_all"],
-            amp_evaluation_time=p["amp_evaluation_time"],
-            spk_evaluation_time=p["spk_evaluation_time"],
+        det = hs.HSDetectionLightning(recording, p)
+        det.DetectFromRaw()
+        C = hs.HSClustering(det)
+        C.ShapePCA()
+        C.CombinedClustering(
+            alpha=p["clustering_alpha"],
+            cluster_subset=p["clustering_subset"],
+            bandwidth=p["clustering_bandwidth"],
+            bin_seeding=p["clustering_bin_seeding"],
+            min_bin_freq=p["clustering_min_bin_freq"],
+            n_jobs=p["clustering_n_jobs"],
         )
-        H.DetectFromRaw(load=True, tInc=int(p["t_inc"]))
-
-        sorted_file = str(sorter_output_folder / "HS2_sorted.hdf5")
-        if not H.spikes.empty:
-            C = hs.HSClustering(H)
-            C.ShapePCA(pca_ncomponents=p["pca_ncomponents"], pca_whiten=p["pca_whiten"])
-            C.CombinedClustering(
-                alpha=p["clustering_alpha"],
-                cluster_subset=p["clustering_subset"],
-                bandwidth=p["clustering_bandwidth"],
-                bin_seeding=p["clustering_bin_seeding"],
-                n_jobs=p["clustering_n_jobs"],
-                min_bin_freq=p["clustering_min_bin_freq"],
-            )
-        else:
-            C = hs.HSClustering(H)
-
-        if p["filter_duplicates"]:
-            uids = C.spikes.cl.unique()
-            for u in uids:
-                s = C.spikes[C.spikes.cl == u].t.diff() < p["spk_evaluation_time"] / 1000 * Probe.fps
-                C.spikes = C.spikes.drop(s.index[s])

         if verbose:
             print("Saving to", sorted_file)
-        C.SaveHDF5(sorted_file, sampling=Probe.fps)
+        C.SaveHDF5(sorted_file, sampling=recording.get_sampling_frequency())

     @classmethod
     def _get_result_from_folder(cls, sorter_output_folder):
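
For reference, the updated wrapper is still invoked through SpikeInterface's standard ``run_sorter`` entry point, so user code changes only where the old parameter names were used. A minimal usage sketch follows; the toy recording from ``generate_ground_truth_recording`` is just a stand-in for a real recording with channel locations (which HerdingSpikes requires), and the parameter values shown are simply the new defaults::

    import spikeinterface.full as si

    # HerdingSpikes needs channel locations; this generated recording has a probe attached.
    recording, _ = si.generate_ground_truth_recording(durations=[10.0], num_channels=16)

    # New-style parameter names; the values below match the new defaults.
    sorting = si.run_sorter(
        sorter_name="herdingspikes",
        recording=recording,
        clustering_bandwidth=4.0,
        clustering_alpha=4.5,
    )
    print(sorting)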