Unify template and quality metrics #3537

Merged
121 changes: 77 additions & 44 deletions src/spikeinterface/postprocessing/template_metrics.py
@@ -63,23 +63,11 @@ class ComputeTemplateMetrics(AnalyzerExtension):
include_multi_channel_metrics : bool, default: False
Whether to compute multi-channel metrics
delete_existing_metrics : bool, default: False
If True, any template metrics attached to the `sorting_analyzer` are deleted. If False, any metrics which were previously calculated but are not included in `metric_names` are kept, provided the `metrics_kwargs` are unchanged.
metrics_kwargs : dict
Additional arguments to pass to the metric functions. Including:
* recovery_window_ms: the window in ms after the peak to compute the recovery_slope, default: 0.7
* peak_relative_threshold: the relative threshold to detect positive and negative peaks, default: 0.2
* peak_width_ms: the width in samples to detect peaks, default: 0.2
* depth_direction: the direction to compute velocity above and below, default: "y" (see notes)
* min_channels_for_velocity: the minimum number of channels above or below to compute velocity, default: 5
* min_r2_velocity: the minimum r2 to accept the velocity fit, default: 0.7
* exp_peak_function: the function to use to compute the peak amplitude for the exp decay, default: "ptp"
* min_r2_exp_decay: the minimum r2 to accept the exp decay fit, default: 0.5
* spread_threshold: the threshold to compute the spread, default: 0.2
* spread_smooth_um: the smoothing in um to compute the spread, default: 20
* column_range: the range in um in the horizontal direction to consider channels for velocity, default: None
- If None, all channels are considered
- If 0 or 1, only the "column" that includes the max channel is considered
- If > 1, only channels within range (+/-) um from the max channel horizontal position are used
If True, any template metrics attached to the `sorting_analyzer` are deleted. If False, any metrics which were previously calculated but are not included in `metric_names` are kept, provided the `metric_params` are unchanged.
metric_params : dict of dicts
metric_params : dict of dicts or None
Dictionary with parameters for template metrics calculation.
Default parameters can be obtained with: `si.postprocessing.template_metrics.get_default_tm_params()`

Returns
-------
@@ -109,6 +97,7 @@ def _set_params(
peak_sign="neg",
upsampling_factor=10,
sparsity=None,
metric_params=None,
metrics_kwargs=None,
include_multi_channel_metrics=False,
delete_existing_metrics=False,
@@ -134,33 +123,24 @@ def _set_params(
if include_multi_channel_metrics:
metric_names += get_multi_channel_template_metric_names()

if metrics_kwargs is None:
metrics_kwargs_ = _default_function_kwargs.copy()
if len(other_kwargs) > 0:
for m in other_kwargs:
if m in metrics_kwargs_:
metrics_kwargs_[m] = other_kwargs[m]
else:
metrics_kwargs_ = _default_function_kwargs.copy()
metrics_kwargs_.update(metrics_kwargs)
if metrics_kwargs is not None and metric_params is None:
deprecation_msg = "`metrics_kwargs` is deprecated and will be removed in version 0.104.0. Please use metric_params instead"
warnings.warn(deprecation_msg, category=DeprecationWarning)

metric_params = {}
for metric_name in metric_names:
metric_params[metric_name] = deepcopy(metrics_kwargs)
Review comment (Collaborator): a deep copy is expensive, but there's no way around it I guess right?

Reply (Collaborator, author): Yeah, this is the backwards compatibility case, where the individual metric_kwargs dict (which used to be forced to be the same for every template metric) gets copied repeatedly into metric_params. If there isn't a deepcopy, metric_params is just a dict of references to the metric_kwargs dict. So if one gets altered, they all do. Luckily, the dict is just a dict of kwargs (so is small) and this code only runs on instantiation if someone tries to use the old notation. So it shouldn't be a performance issue.
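
A minimal sketch of the aliasing issue described above (illustrative only, not part of the diff):

```python
from copy import deepcopy

metric_names = ["exp_decay", "spread"]
metrics_kwargs = {"recovery_window_ms": 0.7}

# Without deepcopy, every entry references the *same* dict ...
aliased = {name: metrics_kwargs for name in metric_names}
aliased["exp_decay"]["recovery_window_ms"] = 0.8
assert aliased["spread"]["recovery_window_ms"] == 0.8  # mutated everywhere

# ... whereas deepcopy gives each metric an independent copy.
metrics_kwargs = {"recovery_window_ms": 0.7}
independent = {name: deepcopy(metrics_kwargs) for name in metric_names}
independent["exp_decay"]["recovery_window_ms"] = 0.8
assert independent["spread"]["recovery_window_ms"] == 0.7  # unchanged
```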


metric_params_ = get_default_tm_params(metric_names)
for k in metric_params_:
if metric_params is not None and k in metric_params:
metric_params_[k].update(metric_params[k])

metrics_to_compute = metric_names
tm_extension = self.sorting_analyzer.get_extension("template_metrics")
if delete_existing_metrics is False and tm_extension is not None:

existing_params = tm_extension.params["metrics_kwargs"]
# checks that existing metrics were calculated using the same params
if existing_params != metrics_kwargs_:
warnings.warn(
f"The parameters used to calculate the previous template metrics are different"
f"than those used now.\nPrevious parameters: {existing_params}\nCurrent "
f"parameters: {metrics_kwargs_}\nDeleting previous template metrics..."
)
tm_extension.params["metric_names"] = []
existing_metric_names = []
else:
existing_metric_names = tm_extension.params["metric_names"]

existing_metric_names = tm_extension.params["metric_names"]
existing_metric_names_propogated = [
metric_name for metric_name in existing_metric_names if metric_name not in metrics_to_compute
]
@@ -171,7 +151,7 @@ def _set_params(
sparsity=sparsity,
peak_sign=peak_sign,
upsampling_factor=int(upsampling_factor),
metrics_kwargs=metrics_kwargs_,
metric_params=metric_params_,
delete_existing_metrics=delete_existing_metrics,
metrics_to_compute=metrics_to_compute,
)
@@ -273,7 +253,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri
sampling_frequency=sampling_frequency_up,
trough_idx=trough_idx,
peak_idx=peak_idx,
**self.params["metrics_kwargs"],
**self.params["metric_params"][metric_name],
)
except Exception as e:
warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}")
@@ -312,7 +292,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri
template_upsampled,
channel_locations=channel_locations_sparse,
sampling_frequency=sampling_frequency_up,
**self.params["metrics_kwargs"],
**self.params["metric_params"][metric_name],
)
except Exception as e:
warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}")
@@ -326,8 +306,8 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri

def _run(self, verbose=False):

delete_existing_metrics = self.params["delete_existing_metrics"]
metrics_to_compute = self.params["metrics_to_compute"]
delete_existing_metrics = self.params["delete_existing_metrics"]

# compute the metrics which have been specified by the user
computed_metrics = self._compute_metrics(
@@ -343,15 +323,39 @@ def _run(self, verbose=False):
):
existing_metrics = tm_extension.params["metric_names"]

existing_metrics = []
# here we get the loaded extension via the dict only (to avoid full loading from disk after params reset)
tm_extension = self.sorting_analyzer.extensions.get("template_metrics", None)
if (
delete_existing_metrics is False
and tm_extension is not None
and tm_extension.data.get("metrics") is not None
):
existing_metrics = tm_extension.params["metric_names"]

# append the metrics which were previously computed
for metric_name in set(existing_metrics).difference(metrics_to_compute):
computed_metrics[metric_name] = tm_extension.data["metrics"][metric_name]
# some metric names produce data columns with other names. This deals with that.
for column_name in tm_compute_name_to_column_names[metric_name]:
computed_metrics[column_name] = tm_extension.data["metrics"][column_name]

self.data["metrics"] = computed_metrics

def _get_data(self):
return self.data["metrics"]

def load_params(self):
AnalyzerExtension.load_params(self)
# For backwards compatibility - this reformats metrics_kwargs as metric_params
if (metrics_kwargs := self.params.get("metrics_kwargs")) is not None:

metric_params = {}
for metric_name in self.params["metric_names"]:
metric_params[metric_name] = deepcopy(metrics_kwargs)
self.params["metric_params"] = metric_params

del self.params["metrics_kwargs"]


register_result_extension(ComputeTemplateMetrics)
compute_template_metrics = ComputeTemplateMetrics.function_factory()
@@ -372,6 +376,35 @@ def _get_data(self):
)


def get_default_tm_params(metric_names):
if metric_names is None:
metric_names = get_template_metric_names()

base_tm_params = _default_function_kwargs

metric_params = {}
for metric_name in metric_names:
metric_params[metric_name] = deepcopy(base_tm_params)

return metric_params
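
For example, a caller can fetch the defaults and override individual entries before computing (a sketch, not part of the diff; the override value is arbitrary):

```python
from spikeinterface.postprocessing.template_metrics import get_default_tm_params

metric_params = get_default_tm_params(["exp_decay", "spread"])
metric_params["exp_decay"]["recovery_window_ms"] = 0.8  # override one metric only

# then, e.g.:
# sorting_analyzer.compute("template_metrics",
#                          metric_names=["exp_decay", "spread"],
#                          metric_params=metric_params)
```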


# a dict converting the name of the metric for computation to the output of that computation
Review comment (Collaborator): So for templates we don't have the same structure as for quality metrics where we have a separate py file with all these names and a list of metrics?

Reply (Collaborator, author): Yeah, it's a bit different. Here we're copying the qm_compute_name_to_column_names dict, which is the important dict for getting the actual dataframe column names. The other stuff in quality_metric_list.py is mostly used for the pca/non-pca split. We could follow the quality_metrics structure and make a templatemetrics folder inside src/spikeinterface/postprocessing?? @alejoe91 thoughts?

tm_compute_name_to_column_names = {
"peak_to_valley": ["peak_to_valley"],
"peak_trough_ratio": ["peak_trough_ratio"],
"half_width": ["half_width"],
"repolarization_slope": ["repolarization_slope"],
"recovery_slope": ["recovery_slope"],
"num_positive_peaks": ["num_positive_peaks"],
"num_negative_peaks": ["num_negative_peaks"],
"velocity_above": ["velocity_above"],
"velocity_below": ["velocity_below"],
"exp_decay": ["exp_decay"],
"spread": ["spread"],
}
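
Illustration (not part of the diff) of how the propagation loop in `_run` uses this mapping; plain dicts stand in for the pandas DataFrames used there:

```python
old_metrics = {"half_width": [0.2], "exp_decay": [1.1]}  # previously computed
computed_metrics = {"exp_decay": [1.3]}  # freshly recomputed

existing_metrics = ["half_width", "exp_decay"]
metrics_to_compute = ["exp_decay"]

for metric_name in set(existing_metrics).difference(metrics_to_compute):
    # a compute-name can map to one or several output column names
    for column_name in tm_compute_name_to_column_names[metric_name]:
        computed_metrics[column_name] = old_metrics[column_name]

print(computed_metrics)  # {'exp_decay': [1.3], 'half_width': [0.2]}
```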


def get_trough_and_peak_idx(template):
"""
Return the indices into the input template of the detected trough
49 changes: 47 additions & 2 deletions src/spikeinterface/postprocessing/tests/test_template_metrics.py
@@ -1,5 +1,5 @@
from spikeinterface.postprocessing.tests.common_extension_tests import AnalyzerExtensionCommonTestSuite
from spikeinterface.postprocessing import ComputeTemplateMetrics
from spikeinterface.postprocessing import ComputeTemplateMetrics, compute_template_metrics
import pytest
import csv

@@ -8,6 +8,49 @@
template_metrics = list(_single_channel_metric_name_to_func.keys())


def test_different_params_template_metrics(small_sorting_analyzer):
"""
Computes template metrics using different params, and checks that they are
actually calculated using those params.
"""
compute_template_metrics(
sorting_analyzer=small_sorting_analyzer,
metric_names=["exp_decay", "spread", "half_width"],
metric_params={"exp_decay": {"recovery_window_ms": 0.8}, "spread": {"spread_smooth_um": 15}},
)

tm_extension = small_sorting_analyzer.get_extension("template_metrics")
tm_params = tm_extension.params["metric_params"]

assert tm_params["exp_decay"]["recovery_window_ms"] == 0.8
assert tm_params["spread"]["recovery_window_ms"] == 0.7
assert tm_params["half_width"]["recovery_window_ms"] == 0.7

assert tm_params["spread"]["spread_smooth_um"] == 15
assert tm_params["exp_decay"]["spread_smooth_um"] == 20
assert tm_params["half_width"]["spread_smooth_um"] == 20


def test_backwards_compat_params_template_metrics(small_sorting_analyzer):
"""
Computes template metrics using the metrics_kwargs keyword
"""
compute_template_metrics(
sorting_analyzer=small_sorting_analyzer,
metric_names=["exp_decay", "spread"],
metrics_kwargs={"recovery_window_ms": 0.8},
)

tm_extension = small_sorting_analyzer.get_extension("template_metrics")
tm_params = tm_extension.params["metric_params"]

assert tm_params["exp_decay"]["recovery_window_ms"] == 0.8
assert tm_params["spread"]["recovery_window_ms"] == 0.8

assert tm_params["spread"]["spread_smooth_um"] == 20
assert tm_params["exp_decay"]["spread_smooth_um"] == 20


def test_compute_new_template_metrics(small_sorting_analyzer):
"""
Computes template metrics then computes a subset of template metrics, and checks
Expand All @@ -17,6 +60,8 @@ def test_compute_new_template_metrics(small_sorting_analyzer):
are deleted.
"""

small_sorting_analyzer.delete_extension("template_metrics")

# calculate just exp_decay
small_sorting_analyzer.compute({"template_metrics": {"metric_names": ["exp_decay"]}})
template_metric_extension = small_sorting_analyzer.get_extension("template_metrics")
@@ -47,7 +92,7 @@

# check that, when parameters are changed, the old metrics are deleted
small_sorting_analyzer.compute(
{"template_metrics": {"metric_names": ["exp_decay"], "metrics_kwargs": {"recovery_window_ms": 0.6}}}
{"template_metrics": {"metric_names": ["exp_decay"], "metric_params": {"recovery_window_ms": 0.6}}}
)


36 changes: 23 additions & 13 deletions src/spikeinterface/qualitymetrics/pca_metrics.py
@@ -6,6 +6,7 @@
from copy import deepcopy
import platform
from tqdm.auto import tqdm
from warnings import warn

import numpy as np

@@ -52,6 +53,7 @@ def get_quality_pca_metric_list():
def compute_pc_metrics(
sorting_analyzer,
metric_names=None,
metric_params=None,
qm_params=None,
unit_ids=None,
seed=None,
@@ -70,7 +72,7 @@
metric_names : list of str, default: None
The list of PC metrics to compute.
If not provided, defaults to all PC metrics.
qm_params : dict or None
metric_params : dict or None
Dictionary with parameters for each PC metric function.
unit_ids : list of int or None
List of unit ids to compute metrics for.
@@ -86,15 +88,23 @@
pc_metrics : dict
The computed PC metrics.
"""

if qm_params is not None and metric_params is None:
deprecation_msg = (
"`qm_params` is deprecated and will be removed in version 0.104.0. Please use metric_params instead"
)
warn(deprecation_msg, category=DeprecationWarning, stacklevel=2)
metric_params = qm_params
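
# Illustration (not part of the diff): both call styles below reach the same
# code path; the first emits a DeprecationWarning and is mapped onto the second.
# The metric name and parameter value are examples only.
#   compute_pc_metrics(sorting_analyzer, metric_names=["nearest_neighbor"],
#                      qm_params={"nearest_neighbor": {"n_neighbors": 5}})      # deprecated
#   compute_pc_metrics(sorting_analyzer, metric_names=["nearest_neighbor"],
#                      metric_params={"nearest_neighbor": {"n_neighbors": 5}})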

pca_ext = sorting_analyzer.get_extension("principal_components")
assert pca_ext is not None, "calculate_pc_metrics() need extension 'principal_components'"

sorting = sorting_analyzer.sorting

if metric_names is None:
metric_names = _possible_pc_metric_names.copy()
if qm_params is None:
qm_params = _default_params
if metric_params is None:
metric_params = _default_params

extremum_channels = get_template_extremum_channel(sorting_analyzer)

@@ -147,7 +157,7 @@ def compute_pc_metrics(
pcs = dense_projections[np.isin(all_labels, neighbor_unit_ids)][:, :, neighbor_channel_indices]
pcs_flat = pcs.reshape(pcs.shape[0], -1)

func_args = (pcs_flat, labels, non_nn_metrics, unit_id, unit_ids, qm_params, max_threads_per_process)
func_args = (pcs_flat, labels, non_nn_metrics, unit_id, unit_ids, metric_params, max_threads_per_process)
items.append(func_args)

if not run_in_parallel and non_nn_metrics:
@@ -184,7 +194,7 @@ def compute_pc_metrics(
units_loop = tqdm(units_loop, desc=f"calculate {metric_name} metric", total=len(unit_ids))

func = _nn_metric_name_to_func[metric_name]
metric_params = qm_params[metric_name] if metric_name in qm_params else {}
metric_params = metric_params[metric_name] if metric_name in metric_params else {}

for _, unit_id in units_loop:
try:
@@ -213,7 +223,7 @@


def calculate_pc_metrics(
sorting_analyzer, metric_names=None, qm_params=None, unit_ids=None, seed=None, n_jobs=1, progress_bar=False
sorting_analyzer, metric_names=None, metric_params=None, unit_ids=None, seed=None, n_jobs=1, progress_bar=False
):
warnings.warn(
"The `calculate_pc_metrics` function is deprecated and will be removed in 0.103.0. Please use compute_pc_metrics instead",
@@ -224,7 +234,7 @@ def calculate_pc_metrics(
pc_metrics = compute_pc_metrics(
sorting_analyzer,
metric_names=metric_names,
qm_params=qm_params,
metric_params=metric_params,
unit_ids=unit_ids,
seed=seed,
n_jobs=n_jobs,
@@ -977,16 +987,16 @@ def _compute_isolation(pcs_target_unit, pcs_other_unit, n_neighbors: int):


def pca_metrics_one_unit(args):
(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params, max_threads_per_process) = args
(pcs_flat, labels, metric_names, unit_id, unit_ids, metric_params, max_threads_per_process) = args

if max_threads_per_process is None:
return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params)
return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, metric_params)
else:
with threadpool_limits(limits=int(max_threads_per_process)):
return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params)
return _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, metric_params)


def _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_params):
def _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, metric_params):
pc_metrics = {}
# metrics
if "isolation_distance" in metric_names or "l_ratio" in metric_names:
Expand Down Expand Up @@ -1015,7 +1025,7 @@ def _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_
if "nearest_neighbor" in metric_names:
try:
nn_hit_rate, nn_miss_rate = nearest_neighbors_metrics(
pcs_flat, labels, unit_id, **qm_params["nearest_neighbor"]
pcs_flat, labels, unit_id, **metric_params["nearest_neighbor"]
)
except:
nn_hit_rate = np.nan
@@ -1024,7 +1034,7 @@ def _pca_metrics_one_unit(pcs_flat, labels, metric_names, unit_id, unit_ids, qm_
pc_metrics["nn_miss_rate"] = nn_miss_rate

if "silhouette" in metric_names:
silhouette_method = qm_params["silhouette"]["method"]
silhouette_method = metric_params["silhouette"]["method"]
if "simplified" in silhouette_method:
try:
unit_silhouette_score = simplified_silhouette_score(pcs_flat, labels, unit_id)