From 530864bceda4b436f5cf16fd5258efd19ccaa76d Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 14 Nov 2024 09:33:03 +0000 Subject: [PATCH 01/13] Replace qm_params with metric_params --- .../qualitymetrics/pca_metrics.py | 36 ++++++++++++------- .../quality_metric_calculator.py | 30 ++++++++++------ .../tests/test_metrics_functions.py | 10 +++--- 3 files changed, 47 insertions(+), 29 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index 4c68dfea59..b4952bfe6d 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -6,6 +6,7 @@ from copy import deepcopy import platform from tqdm.auto import tqdm +from warnings import warn import numpy as np @@ -52,6 +53,7 @@ def get_quality_pca_metric_list(): def compute_pc_metrics( sorting_analyzer, metric_names=None, + metric_params=None, qm_params=None, unit_ids=None, seed=None, @@ -70,7 +72,7 @@ def compute_pc_metrics( metric_names : list of str, default: None The list of PC metrics to compute. If not provided, defaults to all PC metrics. - qm_params : dict or None + metric_params : dict or None Dictionary with parameters for each PC metric function. unit_ids : list of int or None List of unit ids to compute metrics for. @@ -86,6 +88,14 @@ def compute_pc_metrics( pc_metrics : dict The computed PC metrics. """ + + if qm_params is not None and metric_params is None: + deprecation_msg = ( + "`qm_params` is deprecated and will be removed in version 0.104.0 Please use metric_params instead" + ) + metric_params = qm_params + warn(deprecation_msg, category=DeprecationWarning, stacklevel=2) + pca_ext = sorting_analyzer.get_extension("principal_components") assert pca_ext is not None, "calculate_pc_metrics() need extension 'principal_components'" @@ -93,8 +103,8 @@ def compute_pc_metrics( if metric_names is None: metric_names = _possible_pc_metric_names.copy() - if qm_params is None: - qm_params = _default_params + if metric_params is None: + metric_params = _default_params extremum_channels = get_template_extremum_channel(sorting_analyzer) @@ -147,7 +157,7 @@ def compute_pc_metrics( pcs = dense_projections[np.isin(all_labels, neighbor_unit_ids)][:, :, neighbor_channel_indices] pcs_flat = pcs.reshape(pcs.shape[0], -1) - func_args = (pcs_flat, labels, non_nn_metrics, unit_id, unit_ids, qm_params, max_threads_per_process) + func_args = (pcs_flat, labels, non_nn_metrics, unit_id, unit_ids, metric_params, max_threads_per_process) items.append(func_args) if not run_in_parallel and non_nn_metrics: @@ -184,7 +194,7 @@ def compute_pc_metrics( units_loop = tqdm(units_loop, desc=f"calculate {metric_name} metric", total=len(unit_ids)) func = _nn_metric_name_to_func[metric_name] - metric_params = qm_params[metric_name] if metric_name in qm_params else {} + metric_params = metric_params[metric_name] if metric_name in metric_params else {} for _, unit_id in units_loop: try: @@ -213,7 +223,7 @@ def compute_pc_metrics( def calculate_pc_metrics( - sorting_analyzer, metric_names=None, qm_params=None, unit_ids=None, seed=None, n_jobs=1, progress_bar=False + sorting_analyzer, metric_names=None, metric_params=None, unit_ids=None, seed=None, n_jobs=1, progress_bar=False ): warnings.warn( "The `calculate_pc_metrics` function is deprecated and will be removed in 0.103.0. 
Please use compute_pc_metrics instead", @@ -224,7 +234,7 @@ def calculate_pc_metrics( pc_metrics = compute_pc_metrics( sorting_analyzer, metric_names=metric_names, - qm_params=qm_params, + metric_params=metric_params, unit_ids=unit_ids, seed=seed, n_jobs=n_jobs, @@ -977,16 +987,16 @@ def _compute_isolation(pcs_target_unit, pcs_other_unit, n_neighbors: int): def pca_metrics_one_unit(args): - (pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params, max_threads_per_process) = args + (pcs_flat, labels, metric_names, unit_id, unit_ids, metric_params, max_threads_per_process) = args if max_threads_per_process is None: - return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params) + return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, metric_params) else: with threadpool_limits(limits=int(max_threads_per_process)): - return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params) + return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, metric_params) -def _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params): +def _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, metric_params): pc_metrics = {} # metrics if "isolation_distance" in metric_names or "l_ratio" in metric_names: @@ -1015,7 +1025,7 @@ def _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_ if "nearest_neighbor" in metric_names: try: nn_hit_rate, nn_miss_rate = nearest_neighbors_metrics( - pcs_flat, labels, unit_id, **qm_params["nearest_neighbor"] + pcs_flat, labels, unit_id, **metric_params["nearest_neighbor"] ) except: nn_hit_rate = np.nan @@ -1024,7 +1034,7 @@ def _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_ pc_metrics["nn_miss_rate"] = nn_miss_rate if "silhouette" in metric_names: - silhouette_method = qm_params["silhouette"]["method"] + silhouette_method = metric_params["silhouette"]["method"] if "simplified" in silhouette_method: try: unit_silhouette_score = simplified_silhouette_score(pcs_flat, labels, unit_id) diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index b6a50d60f5..eb380304b6 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -6,6 +6,7 @@ from copy import deepcopy import numpy as np +from warnings import warn from spikeinterface.core.job_tools import fix_job_kwargs from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension @@ -31,7 +32,7 @@ class ComputeQualityMetrics(AnalyzerExtension): A SortingAnalyzer object. metric_names : list or None List of quality metrics to compute. - qm_params : dict or None + metric_params : dict or None Dictionary with parameters for quality metrics calculation. 
Default parameters can be obtained with: `si.qualitymetrics.get_default_qm_params()` skip_pc_metrics : bool, default: False @@ -58,6 +59,7 @@ class ComputeQualityMetrics(AnalyzerExtension): def _set_params( self, metric_names=None, + metric_params=None, qm_params=None, peak_sign=None, seed=None, @@ -65,6 +67,12 @@ def _set_params( delete_existing_metrics=False, metrics_to_compute=None, ): + if qm_params is not None and metric_params is None: + deprecation_msg = ( + "`qm_params` is deprecated and will be removed in version 0.104.0 Please use metric_params instead" + ) + metric_params = qm_params + warn(deprecation_msg, category=DeprecationWarning, stacklevel=2) if metric_names is None: metric_names = list(_misc_metric_name_to_func.keys()) @@ -80,12 +88,12 @@ def _set_params( if "drift" in metric_names: metric_names.remove("drift") - qm_params_ = get_default_qm_params() - for k in qm_params_: - if qm_params is not None and k in qm_params: - qm_params_[k].update(qm_params[k]) - if "peak_sign" in qm_params_[k] and peak_sign is not None: - qm_params_[k]["peak_sign"] = peak_sign + metric_params_ = get_default_qm_params() + for k in metric_params_: + if metric_params is not None and k in metric_params: + metric_params_[k].update(metric_params[k]) + if "peak_sign" in metric_params_[k] and peak_sign is not None: + metric_params_[k]["peak_sign"] = peak_sign metrics_to_compute = metric_names qm_extension = self.sorting_analyzer.get_extension("quality_metrics") @@ -101,7 +109,7 @@ def _set_params( metric_names=metric_names, peak_sign=peak_sign, seed=seed, - qm_params=qm_params_, + metric_params=metric_params_, skip_pc_metrics=skip_pc_metrics, delete_existing_metrics=delete_existing_metrics, metrics_to_compute=metrics_to_compute, @@ -141,7 +149,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri """ import pandas as pd - qm_params = self.params["qm_params"] + metric_params = self.params["metric_params"] # sparsity = self.params["sparsity"] seed = self.params["seed"] @@ -177,7 +185,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri func = _misc_metric_name_to_func[metric_name] - params = qm_params[metric_name] if metric_name in qm_params else {} + params = metric_params[metric_name] if metric_name in metric_params else {} res = func(sorting_analyzer, unit_ids=non_empty_unit_ids, **params) # QM with uninstall dependencies might return None if res is not None: @@ -205,7 +213,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri # sparsity=sparsity, progress_bar=progress_bar, n_jobs=n_jobs, - qm_params=qm_params, + metric_params=metric_params, seed=seed, ) for col, values in pc_metrics.items(): diff --git a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py index 4c0890b62b..20869aa44a 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py +++ b/src/spikeinterface/qualitymetrics/tests/test_metrics_functions.py @@ -69,7 +69,7 @@ def test_compute_new_quality_metrics(small_sorting_analyzer): assert calculated_metrics == ["snr"] small_sorting_analyzer.compute( - {"quality_metrics": {"metric_names": list(qm_params.keys()), "qm_params": qm_params}} + {"quality_metrics": {"metric_names": list(qm_params.keys()), "metric_params": qm_params}} ) small_sorting_analyzer.compute({"quality_metrics": {"metric_names": ["snr"]}}) @@ -96,13 +96,13 @@ def test_compute_new_quality_metrics(small_sorting_analyzer): # check 
that, when parameters are changed, the data and metadata are updated old_snr_data = deepcopy(quality_metric_extension.get_data()["snr"].values) small_sorting_analyzer.compute( - {"quality_metrics": {"metric_names": ["snr"], "qm_params": {"snr": {"peak_mode": "peak_to_peak"}}}} + {"quality_metrics": {"metric_names": ["snr"], "metric_params": {"snr": {"peak_mode": "peak_to_peak"}}}} ) new_quality_metric_extension = small_sorting_analyzer.get_extension("quality_metrics") new_snr_data = new_quality_metric_extension.get_data()["snr"].values assert np.all(old_snr_data != new_snr_data) - assert new_quality_metric_extension.params["qm_params"]["snr"]["peak_mode"] == "peak_to_peak" + assert new_quality_metric_extension.params["metric_params"]["snr"]["peak_mode"] == "peak_to_peak" # check that all quality metrics are deleted when parents are recomputed, even after # recomputation @@ -280,10 +280,10 @@ def test_unit_id_order_independence(small_sorting_analyzer): } quality_metrics_1 = compute_quality_metrics( - small_sorting_analyzer, metric_names=get_quality_metric_list(), qm_params=qm_params + small_sorting_analyzer, metric_names=get_quality_metric_list(), metric_params=qm_params ) quality_metrics_2 = compute_quality_metrics( - small_sorting_analyzer_2, metric_names=get_quality_metric_list(), qm_params=qm_params + small_sorting_analyzer_2, metric_names=get_quality_metric_list(), metric_params=qm_params ) for metric, metric_2_data in quality_metrics_2.items(): From 908318a04731f6e8723a9cf3b34914d8e782e900 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 14 Nov 2024 09:38:28 +0000 Subject: [PATCH 02/13] fix tests --- .../tests/test_quality_metric_calculator.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py index a6415c58e8..60f0490f51 100644 --- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py @@ -24,14 +24,14 @@ def test_compute_quality_metrics(sorting_analyzer_simple): metrics = compute_quality_metrics( sorting_analyzer, metric_names=["snr"], - qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), + metric_params=dict(isi_violation=dict(isi_threshold_ms=2)), skip_pc_metrics=True, seed=2205, ) # print(metrics) qm = sorting_analyzer.get_extension("quality_metrics") - assert qm.params["qm_params"]["isi_violation"]["isi_threshold_ms"] == 2 + assert qm.params["metric_params"]["isi_violation"]["isi_threshold_ms"] == 2 assert "snr" in metrics.columns assert "isolation_distance" not in metrics.columns @@ -40,7 +40,7 @@ def test_compute_quality_metrics(sorting_analyzer_simple): metrics = compute_quality_metrics( sorting_analyzer, metric_names=None, - qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), + metric_params=dict(isi_violation=dict(isi_threshold_ms=2)), skip_pc_metrics=False, seed=2205, ) @@ -54,7 +54,7 @@ def test_compute_quality_metrics_recordingless(sorting_analyzer_simple): metrics = compute_quality_metrics( sorting_analyzer, metric_names=None, - qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), + metric_params=dict(isi_violation=dict(isi_threshold_ms=2)), skip_pc_metrics=False, seed=2205, ) @@ -68,7 +68,7 @@ def test_compute_quality_metrics_recordingless(sorting_analyzer_simple): metrics_norec = compute_quality_metrics( sorting_analyzer_norec, 
metric_names=None, - qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), + metric_params=dict(isi_violation=dict(isi_threshold_ms=2)), skip_pc_metrics=False, seed=2205, ) @@ -101,7 +101,7 @@ def test_empty_units(sorting_analyzer_simple): metrics_empty = compute_quality_metrics( sorting_analyzer_empty, metric_names=None, - qm_params=dict(isi_violation=dict(isi_threshold_ms=2)), + metric_params=dict(isi_violation=dict(isi_threshold_ms=2)), skip_pc_metrics=True, seed=2205, ) From 2706a0bee9fef9dd3b3a4af99715366aae1c1625 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 14 Nov 2024 09:45:28 +0000 Subject: [PATCH 03/13] Change metrics_kwargs to metric_params and add depreciation message --- .../postprocessing/template_metrics.py | 36 +++++++++++-------- .../tests/test_template_metrics.py | 2 +- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 6e7bcf21b8..ef6abfe51f 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -63,8 +63,8 @@ class ComputeTemplateMetrics(AnalyzerExtension): include_multi_channel_metrics : bool, default: False Whether to compute multi-channel metrics delete_existing_metrics : bool, default: False - If True, any template metrics attached to the `sorting_analyzer` are deleted. If False, any metrics which were previously calculated but are not included in `metric_names` are kept, provided the `metrics_kwargs` are unchanged. - metrics_kwargs : dict + If True, any template metrics attached to the `sorting_analyzer` are deleted. If False, any metrics which were previously calculated but are not included in `metric_names` are kept, provided the `metric_params` are unchanged. + metric_params : dict Additional arguments to pass to the metric functions. 
Including: * recovery_window_ms: the window in ms after the peak to compute the recovery_slope, default: 0.7 * peak_relative_threshold: the relative threshold to detect positive and negative peaks, default: 0.2 @@ -109,12 +109,20 @@ def _set_params( peak_sign="neg", upsampling_factor=10, sparsity=None, + metric_params=None, metrics_kwargs=None, include_multi_channel_metrics=False, delete_existing_metrics=False, **other_kwargs, ): + if metrics_kwargs is not None and metric_params is None: + deprecation_msg = ( + "`qm_params` is deprecated and will be removed in version 0.104.0 Please use metric_params instead" + ) + metric_params = metrics_kwargs + warnings.warn(deprecation_msg, category=DeprecationWarning, stacklevel=2) + import pandas as pd # TODO alessio can you check this : this used to be in the function but now we have ComputeTemplateMetrics.function_factory() @@ -134,27 +142,27 @@ def _set_params( if include_multi_channel_metrics: metric_names += get_multi_channel_template_metric_names() - if metrics_kwargs is None: - metrics_kwargs_ = _default_function_kwargs.copy() + if metric_params is None: + metric_params_ = _default_function_kwargs.copy() if len(other_kwargs) > 0: for m in other_kwargs: - if m in metrics_kwargs_: - metrics_kwargs_[m] = other_kwargs[m] + if m in metric_params_: + metric_params_[m] = other_kwargs[m] else: - metrics_kwargs_ = _default_function_kwargs.copy() - metrics_kwargs_.update(metrics_kwargs) + metric_params_ = _default_function_kwargs.copy() + metric_params_.update(metric_params) metrics_to_compute = metric_names tm_extension = self.sorting_analyzer.get_extension("template_metrics") if delete_existing_metrics is False and tm_extension is not None: - existing_params = tm_extension.params["metrics_kwargs"] + existing_params = tm_extension.params["metric_params"] # checks that existing metrics were calculated using the same params - if existing_params != metrics_kwargs_: + if existing_params != metric_params_: warnings.warn( f"The parameters used to calculate the previous template metrics are different" f"than those used now.\nPrevious parameters: {existing_params}\nCurrent " - f"parameters: {metrics_kwargs_}\nDeleting previous template metrics..." + f"parameters: {metric_params_}\nDeleting previous template metrics..." 
) tm_extension.params["metric_names"] = [] existing_metric_names = [] @@ -171,7 +179,7 @@ def _set_params( sparsity=sparsity, peak_sign=peak_sign, upsampling_factor=int(upsampling_factor), - metrics_kwargs=metrics_kwargs_, + metric_params=metric_params_, delete_existing_metrics=delete_existing_metrics, metrics_to_compute=metrics_to_compute, ) @@ -273,7 +281,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri sampling_frequency=sampling_frequency_up, trough_idx=trough_idx, peak_idx=peak_idx, - **self.params["metrics_kwargs"], + **self.params["metric_params"], ) except Exception as e: warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}") @@ -312,7 +320,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri template_upsampled, channel_locations=channel_locations_sparse, sampling_frequency=sampling_frequency_up, - **self.params["metrics_kwargs"], + **self.params["metric_params"], ) except Exception as e: warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}") diff --git a/src/spikeinterface/postprocessing/tests/test_template_metrics.py b/src/spikeinterface/postprocessing/tests/test_template_metrics.py index 5056d4ff2a..1df723bfe3 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_metrics.py +++ b/src/spikeinterface/postprocessing/tests/test_template_metrics.py @@ -47,7 +47,7 @@ def test_compute_new_template_metrics(small_sorting_analyzer): # check that, when parameters are changed, the old metrics are deleted small_sorting_analyzer.compute( - {"template_metrics": {"metric_names": ["exp_decay"], "metrics_kwargs": {"recovery_window_ms": 0.6}}} + {"template_metrics": {"metric_names": ["exp_decay"], "metric_params": {"recovery_window_ms": 0.6}}} ) From 22460710bca1114e5f11f5ba4cbdad1b82941d70 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 14 Nov 2024 09:46:16 +0000 Subject: [PATCH 04/13] Update warning message (oups) --- src/spikeinterface/postprocessing/template_metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index ef6abfe51f..9b85f99c0d 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -118,7 +118,7 @@ def _set_params( if metrics_kwargs is not None and metric_params is None: deprecation_msg = ( - "`qm_params` is deprecated and will be removed in version 0.104.0 Please use metric_params instead" + "`metrics_kwargs` is deprecated and will be removed in version 0.104.0 Please use metric_params instead" ) metric_params = metrics_kwargs warnings.warn(deprecation_msg, category=DeprecationWarning, stacklevel=2) From 3579934baf43e25759ca1e8ee7f1e3288180be71 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 14 Nov 2024 10:04:02 +0000 Subject: [PATCH 05/13] Make compute work and add `get_default_tm_params` --- .../postprocessing/template_metrics.py | 45 +++++++++---------- 1 file changed, 20 insertions(+), 25 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 9b85f99c0d..25e0d0d490 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -64,22 +64,10 @@ class 
ComputeTemplateMetrics(AnalyzerExtension): Whether to compute multi-channel metrics delete_existing_metrics : bool, default: False If True, any template metrics attached to the `sorting_analyzer` are deleted. If False, any metrics which were previously calculated but are not included in `metric_names` are kept, provided the `metric_params` are unchanged. - metric_params : dict - Additional arguments to pass to the metric functions. Including: - * recovery_window_ms: the window in ms after the peak to compute the recovery_slope, default: 0.7 - * peak_relative_threshold: the relative threshold to detect positive and negative peaks, default: 0.2 - * peak_width_ms: the width in samples to detect peaks, default: 0.2 - * depth_direction: the direction to compute velocity above and below, default: "y" (see notes) - * min_channels_for_velocity: the minimum number of channels above or below to compute velocity, default: 5 - * min_r2_velocity: the minimum r2 to accept the velocity fit, default: 0.7 - * exp_peak_function: the function to use to compute the peak amplitude for the exp decay, default: "ptp" - * min_r2_exp_decay: the minimum r2 to accept the exp decay fit, default: 0.5 - * spread_threshold: the threshold to compute the spread, default: 0.2 - * spread_smooth_um: the smoothing in um to compute the spread, default: 20 - * column_range: the range in um in the horizontal direction to consider channels for velocity, default: None - - If None, all channels all channels are considered - - If 0 or 1, only the "column" that includes the max channel is considered - - If > 1, only channels within range (+/-) um from the max channel horizontal position are used + metric_params : dict of dicts + metric_params : dict of dicts or None + Dictionary with parameters for quality metrics calculation. 
+ Default parameters can be obtained with: `si.qualitymetrics.get_default_tm_params()` Returns ------- @@ -116,13 +104,6 @@ def _set_params( **other_kwargs, ): - if metrics_kwargs is not None and metric_params is None: - deprecation_msg = ( - "`metrics_kwargs` is deprecated and will be removed in version 0.104.0 Please use metric_params instead" - ) - metric_params = metrics_kwargs - warnings.warn(deprecation_msg, category=DeprecationWarning, stacklevel=2) - import pandas as pd # TODO alessio can you check this : this used to be in the function but now we have ComputeTemplateMetrics.function_factory() @@ -142,6 +123,13 @@ def _set_params( if include_multi_channel_metrics: metric_names += get_multi_channel_template_metric_names() + if metrics_kwargs is not None and metric_params is None: + deprecation_msg = ( + "`metrics_kwargs` is deprecated and will be removed in version 0.104.0 Please use metric_params instead" + ) + metric_params = dict(zip(metric_names, [metrics_kwargs] * len(metric_names))) + warnings.warn(deprecation_msg, category=DeprecationWarning, stacklevel=2) + if metric_params is None: metric_params_ = _default_function_kwargs.copy() if len(other_kwargs) > 0: @@ -281,7 +269,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri sampling_frequency=sampling_frequency_up, trough_idx=trough_idx, peak_idx=peak_idx, - **self.params["metric_params"], + **self.params["metric_params"][metric_name], ) except Exception as e: warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}") @@ -320,7 +308,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri template_upsampled, channel_locations=channel_locations_sparse, sampling_frequency=sampling_frequency_up, - **self.params["metric_params"], + **self.params["metric_params"][metric_name], ) except Exception as e: warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}") @@ -380,6 +368,13 @@ def _get_data(self): ) +def get_default_tm_params(): + metric_names = get_single_channel_template_metric_names() + get_multi_channel_template_metric_names() + base_tm_params = _default_function_kwargs + metric_params = dict(zip(metric_names, [base_tm_params] * len(metric_names))) + return metric_params + + def get_trough_and_peak_idx(template): """ Return the indices into the input template of the detected trough From 66190c3857bcabf30ca64af994693a3a029c41e1 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 14 Nov 2024 10:18:46 +0000 Subject: [PATCH 06/13] Update compute_name_to_column_names to qm_compute_name_to_column_names --- .../qualitymetrics/quality_metric_calculator.py | 6 +++--- src/spikeinterface/qualitymetrics/quality_metric_list.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index eb380304b6..365d7bcc09 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -16,7 +16,7 @@ compute_pc_metrics, _misc_metric_name_to_func, _possible_pc_metric_names, - compute_name_to_column_names, + qm_compute_name_to_column_names, ) from .misc_metrics import _default_params as misc_metrics_params from .pca_metrics import _default_params as pca_metrics_params @@ -32,7 +32,7 @@ class ComputeQualityMetrics(AnalyzerExtension): A SortingAnalyzer object. 
metric_names : list or None List of quality metrics to compute. - metric_params : dict or None + metric_params : dict of dicts or None Dictionary with parameters for quality metrics calculation. Default parameters can be obtained with: `si.qualitymetrics.get_default_qm_params()` skip_pc_metrics : bool, default: False @@ -254,7 +254,7 @@ def _run(self, verbose=False, **job_kwargs): # append the metrics which were previously computed for metric_name in set(existing_metrics).difference(metrics_to_compute): # some metrics names produce data columns with other names. This deals with that. - for column_name in compute_name_to_column_names[metric_name]: + for column_name in qm_compute_name_to_column_names[metric_name]: computed_metrics[column_name] = qm_extension.data["metrics"][column_name] self.data["metrics"] = computed_metrics diff --git a/src/spikeinterface/qualitymetrics/quality_metric_list.py b/src/spikeinterface/qualitymetrics/quality_metric_list.py index 375dd320ae..fc7e92b50d 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_list.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_list.py @@ -55,7 +55,7 @@ } # a dict converting the name of the metric for computation to the output of that computation -compute_name_to_column_names = { +qm_compute_name_to_column_names = { "num_spikes": ["num_spikes"], "firing_rate": ["firing_rate"], "presence_ratio": ["presence_ratio"], From 2de25b47013c8954ece03af1f47250e5db1f7ffb Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 14 Nov 2024 10:19:27 +0000 Subject: [PATCH 07/13] Unify template param checks with quality param checks --- .../postprocessing/template_metrics.py | 63 +++++++++++-------- 1 file changed, 38 insertions(+), 25 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 25e0d0d490..cfdbd122b3 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -130,33 +130,18 @@ def _set_params( metric_params = dict(zip(metric_names, [metrics_kwargs] * len(metric_names))) warnings.warn(deprecation_msg, category=DeprecationWarning, stacklevel=2) - if metric_params is None: - metric_params_ = _default_function_kwargs.copy() - if len(other_kwargs) > 0: - for m in other_kwargs: - if m in metric_params_: - metric_params_[m] = other_kwargs[m] - else: - metric_params_ = _default_function_kwargs.copy() - metric_params_.update(metric_params) + metric_params_ = get_default_tm_params() + for k in metric_params_: + if metric_params is not None and k in metric_params: + metric_params_[k].update(metric_params[k]) + if "peak_sign" in metric_params_[k] and peak_sign is not None: + metric_params_[k]["peak_sign"] = peak_sign metrics_to_compute = metric_names tm_extension = self.sorting_analyzer.get_extension("template_metrics") if delete_existing_metrics is False and tm_extension is not None: - existing_params = tm_extension.params["metric_params"] - # checks that existing metrics were calculated using the same params - if existing_params != metric_params_: - warnings.warn( - f"The parameters used to calculate the previous template metrics are different" - f"than those used now.\nPrevious parameters: {existing_params}\nCurrent " - f"parameters: {metric_params_}\nDeleting previous template metrics..." 
- ) - tm_extension.params["metric_names"] = [] - existing_metric_names = [] - else: - existing_metric_names = tm_extension.params["metric_names"] - + existing_metric_names = tm_extension.params["metric_names"] existing_metric_names_propogated = [ metric_name for metric_name in existing_metric_names if metric_name not in metrics_to_compute ] @@ -322,8 +307,8 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri def _run(self, verbose=False): - delete_existing_metrics = self.params["delete_existing_metrics"] metrics_to_compute = self.params["metrics_to_compute"] + delete_existing_metrics = self.params["delete_existing_metrics"] # compute the metrics which have been specified by the user computed_metrics = self._compute_metrics( @@ -339,9 +324,21 @@ def _run(self, verbose=False): ): existing_metrics = tm_extension.params["metric_names"] + existing_metrics = [] + # here we get in the loaded via the dict only (to avoid full loading from disk after params reset) + tm_extension = self.sorting_analyzer.extensions.get("template_metrics", None) + if ( + delete_existing_metrics is False + and tm_extension is not None + and tm_extension.data.get("metrics") is not None + ): + existing_metrics = tm_extension.params["metric_names"] + # append the metrics which were previously computed for metric_name in set(existing_metrics).difference(metrics_to_compute): - computed_metrics[metric_name] = tm_extension.data["metrics"][metric_name] + # some metrics names produce data columns with other names. This deals with that. + for column_name in tm_compute_name_to_column_names[metric_name]: + computed_metrics[column_name] = tm_extension.data["metrics"][column_name] self.data["metrics"] = computed_metrics @@ -369,12 +366,28 @@ def _get_data(self): def get_default_tm_params(): - metric_names = get_single_channel_template_metric_names() + get_multi_channel_template_metric_names() + metric_names = get_template_metric_names() base_tm_params = _default_function_kwargs metric_params = dict(zip(metric_names, [base_tm_params] * len(metric_names))) return metric_params +# a dict converting the name of the metric for computation to the output of that computation +tm_compute_name_to_column_names = { + "peak_to_valley": ["peak_to_valley"], + "peak_trough_ratio": ["peak_trough_ratio"], + "half_width": ["half_width"], + "repolarization_slope": ["repolarization_slope"], + "recovery_slope": ["recovery_slope"], + "num_positive_peaks": ["num_positive_peaks"], + "num_negative_peaks": ["num_negative_peaks"], + "velocity_above": ["velocity_above"], + "velocity_below": ["velocity_below"], + "exp_decay": ["exp_decay"], + "spread": ["spread"], +} + + def get_trough_and_peak_idx(template): """ Return the indices into the input template of the detected trough From 8f6602423d53dc625689da2183a24fe43cbb8629 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 14 Nov 2024 11:07:27 +0000 Subject: [PATCH 08/13] add some tests --- .../tests/test_template_metrics.py | 47 ++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/tests/test_template_metrics.py b/src/spikeinterface/postprocessing/tests/test_template_metrics.py index 1df723bfe3..1bf49f64c1 100644 --- a/src/spikeinterface/postprocessing/tests/test_template_metrics.py +++ b/src/spikeinterface/postprocessing/tests/test_template_metrics.py @@ -1,5 +1,5 @@ from spikeinterface.postprocessing.tests.common_extension_tests import 
AnalyzerExtensionCommonTestSuite -from spikeinterface.postprocessing import ComputeTemplateMetrics +from spikeinterface.postprocessing import ComputeTemplateMetrics, compute_template_metrics import pytest import csv @@ -8,6 +8,49 @@ template_metrics = list(_single_channel_metric_name_to_func.keys()) +def test_different_params_template_metrics(small_sorting_analyzer): + """ + Computes template metrics using different params, and check that they are + actually calculated using the different params. + """ + compute_template_metrics( + sorting_analyzer=small_sorting_analyzer, + metric_names=["exp_decay", "spread", "half_width"], + metric_params={"exp_decay": {"recovery_window_ms": 0.8}, "spread": {"spread_smooth_um": 15}}, + ) + + tm_extension = small_sorting_analyzer.get_extension("template_metrics") + tm_params = tm_extension.params["metric_params"] + + assert tm_params["exp_decay"]["recovery_window_ms"] == 0.8 + assert tm_params["spread"]["recovery_window_ms"] == 0.7 + assert tm_params["half_width"]["recovery_window_ms"] == 0.7 + + assert tm_params["spread"]["spread_smooth_um"] == 15 + assert tm_params["exp_decay"]["spread_smooth_um"] == 20 + assert tm_params["half_width"]["spread_smooth_um"] == 20 + + +def test_backwards_compat_params_template_metrics(small_sorting_analyzer): + """ + Computes template metrics using the metrics_kwargs keyword + """ + compute_template_metrics( + sorting_analyzer=small_sorting_analyzer, + metric_names=["exp_decay", "spread"], + metrics_kwargs={"recovery_window_ms": 0.8}, + ) + + tm_extension = small_sorting_analyzer.get_extension("template_metrics") + tm_params = tm_extension.params["metric_params"] + + assert tm_params["exp_decay"]["recovery_window_ms"] == 0.8 + assert tm_params["spread"]["recovery_window_ms"] == 0.8 + + assert tm_params["spread"]["spread_smooth_um"] == 20 + assert tm_params["exp_decay"]["spread_smooth_um"] == 20 + + def test_compute_new_template_metrics(small_sorting_analyzer): """ Computes template metrics then computes a subset of template metrics, and checks @@ -17,6 +60,8 @@ def test_compute_new_template_metrics(small_sorting_analyzer): are deleted. """ + small_sorting_analyzer.delete_extension("template_metrics") + # calculate just exp_decay small_sorting_analyzer.compute({"template_metrics": {"metric_names": ["exp_decay"]}}) template_metric_extension = small_sorting_analyzer.get_extension("template_metrics") From fdc01f5adb81f55c5787166fa469100d7bc06239 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 14 Nov 2024 15:00:15 +0000 Subject: [PATCH 09/13] little fixes --- .../postprocessing/template_metrics.py | 31 +++++++++++-------- .../qualitymetrics/pca_metrics.py | 4 +-- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index cfdbd122b3..cbcf38d19d 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -66,8 +66,8 @@ class ComputeTemplateMetrics(AnalyzerExtension): If True, any template metrics attached to the `sorting_analyzer` are deleted. If False, any metrics which were previously calculated but are not included in `metric_names` are kept, provided the `metric_params` are unchanged. metric_params : dict of dicts metric_params : dict of dicts or None - Dictionary with parameters for quality metrics calculation. 
- Default parameters can be obtained with: `si.qualitymetrics.get_default_tm_params()` + Dictionary with parameters for template metrics calculation. + Default parameters can be obtained with: `si.postprocessing.template_metrics.get_default_tm_params()` Returns ------- @@ -124,18 +124,17 @@ def _set_params( metric_names += get_multi_channel_template_metric_names() if metrics_kwargs is not None and metric_params is None: - deprecation_msg = ( - "`metrics_kwargs` is deprecated and will be removed in version 0.104.0 Please use metric_params instead" - ) - metric_params = dict(zip(metric_names, [metrics_kwargs] * len(metric_names))) - warnings.warn(deprecation_msg, category=DeprecationWarning, stacklevel=2) + deprecation_msg = "`metrics_kwargs` is deprecated and will be removed in version 0.104.0. Please use metric_params instead" + warnings.warn(deprecation_msg, category=DeprecationWarning) + + metric_params = {} + for metric_name in metric_names: + metric_params[metric_name] = deepcopy(metrics_kwargs) - metric_params_ = get_default_tm_params() + metric_params_ = get_default_tm_params(metric_names) for k in metric_params_: if metric_params is not None and k in metric_params: metric_params_[k].update(metric_params[k]) - if "peak_sign" in metric_params_[k] and peak_sign is not None: - metric_params_[k]["peak_sign"] = peak_sign metrics_to_compute = metric_names tm_extension = self.sorting_analyzer.get_extension("template_metrics") @@ -365,10 +364,16 @@ def _get_data(self): ) -def get_default_tm_params(): - metric_names = get_template_metric_names() +def get_default_tm_params(metric_names): + if metric_names is None: + metric_names = get_template_metric_names() + base_tm_params = _default_function_kwargs - metric_params = dict(zip(metric_names, [base_tm_params] * len(metric_names))) + + metric_params = {} + for metric_name in metric_names: + metric_params[metric_name] = deepcopy(base_tm_params) + return metric_params diff --git a/src/spikeinterface/qualitymetrics/pca_metrics.py b/src/spikeinterface/qualitymetrics/pca_metrics.py index b4952bfe6d..ca21f1e45f 100644 --- a/src/spikeinterface/qualitymetrics/pca_metrics.py +++ b/src/spikeinterface/qualitymetrics/pca_metrics.py @@ -91,10 +91,10 @@ def compute_pc_metrics( if qm_params is not None and metric_params is None: deprecation_msg = ( - "`qm_params` is deprecated and will be removed in version 0.104.0 Please use metric_params instead" + "`qm_params` is deprecated and will be removed in version 0.104.0. 
Please use metric_params instead" ) - metric_params = qm_params warn(deprecation_msg, category=DeprecationWarning, stacklevel=2) + metric_params = qm_params pca_ext = sorting_analyzer.get_extension("principal_components") assert pca_ext is not None, "calculate_pc_metrics() need extension 'principal_components'" From 9db0b83f7c7ad2c773c8673fa3dd09e5c3cdecb6 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Thu, 14 Nov 2024 15:11:01 +0000 Subject: [PATCH 10/13] backwards compatible loading --- .../postprocessing/template_metrics.py | 12 ++++++++++++ .../qualitymetrics/quality_metric_calculator.py | 7 +++++++ 2 files changed, 19 insertions(+) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index cbcf38d19d..477ad04440 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -344,6 +344,18 @@ def _run(self, verbose=False): def _get_data(self): return self.data["metrics"] + def load_params(self): + AnalyzerExtension.load_params(self) + # For backwards compatibility - this reformats metrics_kwargs as metric_params + if (metrics_kwargs := self.params.get("metrics_kwargs")) is not None: + + metric_params = {} + for metric_name in self.params["metric_names"]: + metric_params[metric_name] = deepcopy(metrics_kwargs) + self.params["metric_params"] = metric_params + + del self.params["metrics_kwargs"] + register_result_extension(ComputeTemplateMetrics) compute_template_metrics = ComputeTemplateMetrics.function_factory() diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index 365d7bcc09..e7e7c244ea 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -262,6 +262,13 @@ def _run(self, verbose=False, **job_kwargs): def _get_data(self): return self.data["metrics"] + def load_params(self): + AnalyzerExtension.load_params(self) + # For backwards compatibility - this renames qm_params as metric_params + if (qm_params := self.params.get("qm_params")) is not None: + self.params["metric_params"] = qm_params + del self.params["qm_params"] + register_result_extension(ComputeQualityMetrics) compute_quality_metrics = ComputeQualityMetrics.function_factory() From 039b408a59ce965b91908de82d0bc55114f8655e Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Mon, 2 Dec 2024 13:43:39 +0000 Subject: [PATCH 11/13] move backwards compat to `_handle_backward_compatibility_on_load` --- .../postprocessing/template_metrics.py | 25 ++++++++++--------- .../quality_metric_calculator.py | 14 +++++------ 2 files changed, 20 insertions(+), 19 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 477ad04440..7de6e8766a 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -88,9 +88,22 @@ class ComputeTemplateMetrics(AnalyzerExtension): need_recording = False use_nodepipeline = False need_job_kwargs = False + need_backward_compatibility_on_load = True min_channels_for_multi_channel_warning = 10 + def _handle_backward_compatibility_on_load(self): + + # For backwards compatibility - this reformats metrics_kwargs as metric_params + if (metrics_kwargs := 
self.params.get("metrics_kwargs")) is not None: + + metric_params = {} + for metric_name in self.params["metric_names"]: + metric_params[metric_name] = deepcopy(metrics_kwargs) + self.params["metric_params"] = metric_params + + del self.params["metrics_kwargs"] + def _set_params( self, metric_names=None, @@ -344,18 +357,6 @@ def _run(self, verbose=False): def _get_data(self): return self.data["metrics"] - def load_params(self): - AnalyzerExtension.load_params(self) - # For backwards compatibility - this reformats metrics_kwargs as metric_params - if (metrics_kwargs := self.params.get("metrics_kwargs")) is not None: - - metric_params = {} - for metric_name in self.params["metric_names"]: - metric_params[metric_name] = deepcopy(metrics_kwargs) - self.params["metric_params"] = metric_params - - del self.params["metrics_kwargs"] - register_result_extension(ComputeTemplateMetrics) compute_template_metrics = ComputeTemplateMetrics.function_factory() diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index e7e7c244ea..d71450853f 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -55,6 +55,13 @@ class ComputeQualityMetrics(AnalyzerExtension): need_recording = False use_nodepipeline = False need_job_kwargs = True + need_backward_compatibility_on_load = True + + def _handle_backward_compatibility_on_load(self): + # For backwards compatibility - this renames qm_params as metric_params + if (qm_params := self.params.get("qm_params")) is not None: + self.params["metric_params"] = qm_params + del self.params["qm_params"] def _set_params( self, @@ -262,13 +269,6 @@ def _run(self, verbose=False, **job_kwargs): def _get_data(self): return self.data["metrics"] - def load_params(self): - AnalyzerExtension.load_params(self) - # For backwards compatibility - this renames qm_params as metric_params - if (qm_params := self.params.get("qm_params")) is not None: - self.params["metric_params"] = qm_params - del self.params["qm_params"] - register_result_extension(ComputeQualityMetrics) compute_quality_metrics = ComputeQualityMetrics.function_factory() From 771de98c6ddd07e064cb28d6e2450e599d54d2a6 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Tue, 3 Dec 2024 09:51:28 +0000 Subject: [PATCH 12/13] Respond to z-man --- src/spikeinterface/postprocessing/template_metrics.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index 7de6e8766a..da917e673c 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -64,8 +64,7 @@ class ComputeTemplateMetrics(AnalyzerExtension): Whether to compute multi-channel metrics delete_existing_metrics : bool, default: False If True, any template metrics attached to the `sorting_analyzer` are deleted. If False, any metrics which were previously calculated but are not included in `metric_names` are kept, provided the `metric_params` are unchanged. - metric_params : dict of dicts - metric_params : dict of dicts or None + metric_params : dict of dicts or None, default: None Dictionary with parameters for template metrics calculation. 
Default parameters can be obtained with: `si.postprocessing.template_metrics.get_default_tm_params()` @@ -138,7 +137,7 @@ def _set_params( if metrics_kwargs is not None and metric_params is None: deprecation_msg = "`metrics_kwargs` is deprecated and will be removed in version 0.104.0. Please use metric_params instead" - warnings.warn(deprecation_msg, category=DeprecationWarning) + deprecation_msg = "`metrics_kwargs` is deprecated and will be removed in version 0.104.0. Please use `metric_params` instead" metric_params = {} for metric_name in metric_names: From de7210a43135c1164ee2f214e117543441935375 Mon Sep 17 00:00:00 2001 From: chrishalcrow <57948917+chrishalcrow@users.noreply.github.com> Date: Tue, 3 Dec 2024 09:52:16 +0000 Subject: [PATCH 13/13] oups --- src/spikeinterface/postprocessing/template_metrics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index da917e673c..1969480503 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -136,7 +136,7 @@ def _set_params( metric_names += get_multi_channel_template_metric_names() if metrics_kwargs is not None and metric_params is None: - deprecation_msg = "`metrics_kwargs` is deprecated and will be removed in version 0.104.0. Please use metric_params instead" + deprecation_msg = "`metrics_kwargs` is deprecated and will be removed in version 0.104.0. Please use `metric_params` instead" deprecation_msg = "`metrics_kwargs` is deprecated and will be removed in version 0.104.0. Please use `metric_params` instead" metric_params = {}