-
Notifications
You must be signed in to change notification settings - Fork 191
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Unify template and quality metrics #3537
Changes from all commits
530864b
908318a
2706a0b
2246071
3579934
66190c3
2de25b4
8f66024
fdc01f5
9db0b83
039b408
771de98
de7210a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -63,23 +63,10 @@ class ComputeTemplateMetrics(AnalyzerExtension): | |
include_multi_channel_metrics : bool, default: False | ||
Whether to compute multi-channel metrics | ||
delete_existing_metrics : bool, default: False | ||
If True, any template metrics attached to the `sorting_analyzer` are deleted. If False, any metrics which were previously calculated but are not included in `metric_names` are kept, provided the `metrics_kwargs` are unchanged. | ||
metrics_kwargs : dict | ||
Additional arguments to pass to the metric functions. Including: | ||
* recovery_window_ms: the window in ms after the peak to compute the recovery_slope, default: 0.7 | ||
* peak_relative_threshold: the relative threshold to detect positive and negative peaks, default: 0.2 | ||
* peak_width_ms: the width in samples to detect peaks, default: 0.2 | ||
* depth_direction: the direction to compute velocity above and below, default: "y" (see notes) | ||
* min_channels_for_velocity: the minimum number of channels above or below to compute velocity, default: 5 | ||
* min_r2_velocity: the minimum r2 to accept the velocity fit, default: 0.7 | ||
* exp_peak_function: the function to use to compute the peak amplitude for the exp decay, default: "ptp" | ||
* min_r2_exp_decay: the minimum r2 to accept the exp decay fit, default: 0.5 | ||
* spread_threshold: the threshold to compute the spread, default: 0.2 | ||
* spread_smooth_um: the smoothing in um to compute the spread, default: 20 | ||
* column_range: the range in um in the horizontal direction to consider channels for velocity, default: None | ||
- If None, all channels are considered | |
- If 0 or 1, only the "column" that includes the max channel is considered | ||
- If > 1, only channels within range (+/-) um from the max channel horizontal position are used | ||
If True, any template metrics attached to the `sorting_analyzer` are deleted. If False, any metrics which were previously calculated but are not included in `metric_names` are kept, provided the `metric_params` are unchanged. | ||
metric_params : dict of dicts or None, default: None | ||
Dictionary with parameters for template metrics calculation. | ||
Default parameters can be obtained with: `si.postprocessing.template_metrics.get_default_tm_params()` | ||
Returns | ||
------- | ||
|
@@ -100,15 +87,29 @@ class ComputeTemplateMetrics(AnalyzerExtension): | |
need_recording = False | ||
use_nodepipeline = False | ||
need_job_kwargs = False | ||
need_backward_compatibility_on_load = True | ||
|
||
min_channels_for_multi_channel_warning = 10 | ||
|
||
def _handle_backward_compatibility_on_load(self): | ||
|
||
# For backwards compatibility - this reformats metrics_kwargs as metric_params | ||
if (metrics_kwargs := self.params.get("metrics_kwargs")) is not None: | ||
|
||
metric_params = {} | ||
for metric_name in self.params["metric_names"]: | ||
metric_params[metric_name] = deepcopy(metrics_kwargs) | ||
self.params["metric_params"] = metric_params | ||
|
||
del self.params["metrics_kwargs"] | ||
|
||
def _set_params( | ||
self, | ||
metric_names=None, | ||
peak_sign="neg", | ||
upsampling_factor=10, | ||
sparsity=None, | ||
metric_params=None, | ||
metrics_kwargs=None, | ||
include_multi_channel_metrics=False, | ||
delete_existing_metrics=False, | ||
|
@@ -134,33 +135,24 @@ def _set_params( | |
if include_multi_channel_metrics: | ||
metric_names += get_multi_channel_template_metric_names() | ||
|
||
if metrics_kwargs is None: | ||
metrics_kwargs_ = _default_function_kwargs.copy() | ||
if len(other_kwargs) > 0: | ||
for m in other_kwargs: | ||
if m in metrics_kwargs_: | ||
metrics_kwargs_[m] = other_kwargs[m] | ||
else: | ||
metrics_kwargs_ = _default_function_kwargs.copy() | ||
metrics_kwargs_.update(metrics_kwargs) | ||
if metrics_kwargs is not None and metric_params is None: | ||
deprecation_msg = "`metrics_kwargs` is deprecated and will be removed in version 0.104.0. Please use `metric_params` instead" | ||
deprecation_msg = "`metrics_kwargs` is deprecated and will be removed in version 0.104.0. Please use `metric_params` instead" | ||
|
||
metric_params = {} | ||
for metric_name in metric_names: | ||
metric_params[metric_name] = deepcopy(metrics_kwargs) | ||
|
||
metric_params_ = get_default_tm_params(metric_names) | ||
for k in metric_params_: | ||
if metric_params is not None and k in metric_params: | ||
metric_params_[k].update(metric_params[k]) | ||
|
||
metrics_to_compute = metric_names | ||
tm_extension = self.sorting_analyzer.get_extension("template_metrics") | ||
if delete_existing_metrics is False and tm_extension is not None: | ||
|
||
existing_params = tm_extension.params["metrics_kwargs"] | ||
# checks that existing metrics were calculated using the same params | ||
if existing_params != metrics_kwargs_: | ||
warnings.warn( | ||
f"The parameters used to calculate the previous template metrics are different" | ||
f"than those used now.\nPrevious parameters: {existing_params}\nCurrent " | ||
f"parameters: {metrics_kwargs_}\nDeleting previous template metrics..." | ||
) | ||
tm_extension.params["metric_names"] = [] | ||
existing_metric_names = [] | ||
else: | ||
existing_metric_names = tm_extension.params["metric_names"] | ||
|
||
existing_metric_names = tm_extension.params["metric_names"] | ||
existing_metric_names_propogated = [ | ||
metric_name for metric_name in existing_metric_names if metric_name not in metrics_to_compute | ||
] | ||
|
@@ -171,7 +163,7 @@ def _set_params( | |
sparsity=sparsity, | ||
peak_sign=peak_sign, | ||
upsampling_factor=int(upsampling_factor), | ||
metrics_kwargs=metrics_kwargs_, | ||
metric_params=metric_params_, | ||
delete_existing_metrics=delete_existing_metrics, | ||
metrics_to_compute=metrics_to_compute, | ||
) | ||
|
@@ -273,7 +265,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri | |
sampling_frequency=sampling_frequency_up, | ||
trough_idx=trough_idx, | ||
peak_idx=peak_idx, | ||
**self.params["metrics_kwargs"], | ||
**self.params["metric_params"][metric_name], | ||
) | ||
except Exception as e: | ||
warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}") | ||
|
@@ -312,7 +304,7 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri | |
template_upsampled, | ||
channel_locations=channel_locations_sparse, | ||
sampling_frequency=sampling_frequency_up, | ||
**self.params["metrics_kwargs"], | ||
**self.params["metric_params"][metric_name], | ||
) | ||
except Exception as e: | ||
warnings.warn(f"Error computing metric {metric_name} for unit {unit_id}: {e}") | ||
|
@@ -326,8 +318,8 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri | |
|
||
def _run(self, verbose=False): | ||
|
||
delete_existing_metrics = self.params["delete_existing_metrics"] | ||
metrics_to_compute = self.params["metrics_to_compute"] | ||
delete_existing_metrics = self.params["delete_existing_metrics"] | ||
|
||
# compute the metrics which have been specified by the user | ||
computed_metrics = self._compute_metrics( | ||
|
@@ -343,9 +335,21 @@ def _run(self, verbose=False): | |
): | ||
existing_metrics = tm_extension.params["metric_names"] | ||
|
||
existing_metrics = [] | ||
# here we fetch the extension from the already-loaded extensions dict only (to avoid a full load from disk after params reset) | |
tm_extension = self.sorting_analyzer.extensions.get("template_metrics", None) | ||
if ( | ||
delete_existing_metrics is False | ||
and tm_extension is not None | ||
and tm_extension.data.get("metrics") is not None | ||
): | ||
existing_metrics = tm_extension.params["metric_names"] | ||
|
||
# append the metrics which were previously computed | ||
for metric_name in set(existing_metrics).difference(metrics_to_compute): | ||
computed_metrics[metric_name] = tm_extension.data["metrics"][metric_name] | ||
# some metrics names produce data columns with other names. This deals with that. | ||
for column_name in tm_compute_name_to_column_names[metric_name]: | ||
computed_metrics[column_name] = tm_extension.data["metrics"][column_name] | ||
|
||
self.data["metrics"] = computed_metrics | ||
|
||
|
@@ -372,6 +376,35 @@ def _get_data(self): | |
) | ||
|
||
|
||
def get_default_tm_params(metric_names=None):
    """
    Return the default parameters for each requested template metric.

    Parameters
    ----------
    metric_names : list of str or None, default: None
        Names of the metrics to build parameters for. If None, the names returned by
        `get_template_metric_names()` are used.

    Returns
    -------
    metric_params : dict of dicts
        Maps each metric name to its own copy of `_default_function_kwargs`.
    """
    if metric_names is None:
        metric_names = get_template_metric_names()

    # deepcopy per metric so that updating one metric's parameters in the returned
    # dict cannot modify the module-level defaults or another metric's entry.
    return {metric_name: deepcopy(_default_function_kwargs) for metric_name in metric_names}
|
||
|
||
# a dict converting the name of the metric for computation to the output of that computation | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So for templates we don't have the same structure as for quality metrics where we have a separate py file with all these names and a list of metrics? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, it's a bit different. Here we're copying the |
||
# Each metric currently produces exactly one data column, named after the metric itself.
tm_compute_name_to_column_names = {
    metric_name: [metric_name]
    for metric_name in (
        "peak_to_valley",
        "peak_trough_ratio",
        "half_width",
        "repolarization_slope",
        "recovery_slope",
        "num_positive_peaks",
        "num_negative_peaks",
        "velocity_above",
        "velocity_below",
        "exp_decay",
        "spread",
    )
}
|
||
|
||
def get_trough_and_peak_idx(template): | ||
""" | ||
Return the indices into the input template of the detected trough | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
a deep copy is expensive, but there's no way around it I guess right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, this is the backwards compatibility case, where the single `metrics_kwargs` dict (which used to be forced to be the same for every template metric) gets copied repeatedly into `metric_params`. If there isn't a deepcopy, `metric_params` is just a dict of references to the same `metrics_kwargs` dict, so if one gets altered, they all do. Luckily, the dict is just a dict of kwargs (so it is small), and this code only runs on instantiation if someone tries to use the old notation, so it shouldn't be a performance issue.