Skip to content

Commit

Permalink
Merge pull request #3497 from zm711/merge-qc
Browse files Browse the repository at this point in the history
Fix dtype of quality metrics before and after merging
  • Loading branch information
alejoe91 authored Jan 10, 2025
2 parents 0b1bf67 + 33feca3 commit 82d62ca
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 4 deletions.
20 changes: 19 additions & 1 deletion src/spikeinterface/qualitymetrics/quality_metric_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
_misc_metric_name_to_func,
_possible_pc_metric_names,
qm_compute_name_to_column_names,
column_name_to_column_dtype,
)
from .misc_metrics import _default_params as misc_metrics_params
from .pca_metrics import _default_params as pca_metrics_params
Expand Down Expand Up @@ -140,13 +141,20 @@ def _merge_extension_data(
all_unit_ids = new_sorting_analyzer.unit_ids
not_new_ids = all_unit_ids[~np.isin(all_unit_ids, new_unit_ids)]

# this creates a new metrics dictionary, but the dtype for everything will be
# object. So we will need to fix this later after computing metrics
metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns)

metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :]
metrics.loc[new_unit_ids, :] = self._compute_metrics(
new_sorting_analyzer, new_unit_ids, verbose, metric_names, **job_kwargs
)

# we need to fix the dtypes after we compute everything because we have nans
# we can iterate through the columns and convert them back to the dtype
# of the original quality dataframe.
for column in old_metrics.columns:
metrics[column] = metrics[column].astype(old_metrics[column].dtype)

new_data = dict(metrics=metrics)
return new_data

Expand Down Expand Up @@ -229,10 +237,20 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri
# add NaN for empty units
if len(empty_unit_ids) > 0:
metrics.loc[empty_unit_ids] = np.nan
# num_spikes is an int and should be 0
if "num_spikes" in metrics.columns:
metrics.loc[empty_unit_ids, ["num_spikes"]] = 0

# we use the convert_dtypes to convert the columns to the most appropriate dtype and avoid object columns
# (in case of NaN values)
metrics = metrics.convert_dtypes()

# we do this because the convert_dtypes infers the wrong types sometimes.
# the actual types for columns can be found in column_name_to_column_dtype dictionary.
for column in metrics.columns:
if column in column_name_to_column_dtype:
metrics[column] = metrics[column].astype(column_name_to_column_dtype[column])

return metrics

def _run(self, verbose=False, **job_kwargs):
Expand Down
41 changes: 40 additions & 1 deletion src/spikeinterface/qualitymetrics/quality_metric_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,11 @@
"amplitude_cutoff": ["amplitude_cutoff"],
"amplitude_median": ["amplitude_median"],
"amplitude_cv": ["amplitude_cv_median", "amplitude_cv_range"],
"synchrony": ["sync_spike_2", "sync_spike_4", "sync_spike_8"],
"synchrony": [
"sync_spike_2",
"sync_spike_4",
"sync_spike_8",
],
"firing_range": ["firing_range"],
"drift": ["drift_ptp", "drift_std", "drift_mad"],
"sd_ratio": ["sd_ratio"],
Expand All @@ -79,3 +83,38 @@
"silhouette": ["silhouette"],
"silhouette_full": ["silhouette_full"],
}

# Explicit dtype for every quality-metric column, so we never have to rely on
# Pandas dtype inference (which can yield ``object`` columns when NaNs appear).
# Every metric column is a float except ``num_spikes``, which is a plain count.
_FLOAT_METRIC_COLUMNS = (
    "firing_rate",
    "presence_ratio",
    "snr",
    "isi_violations_ratio",
    "isi_violations_count",
    "rp_violations",
    "rp_contamination",
    "sliding_rp_violation",
    "amplitude_cutoff",
    "amplitude_median",
    "amplitude_cv_median",
    "amplitude_cv_range",
    "sync_spike_2",
    "sync_spike_4",
    "sync_spike_8",
    "firing_range",
    "drift_ptp",
    "drift_std",
    "drift_mad",
    "sd_ratio",
    "isolation_distance",
    "l_ratio",
    "d_prime",
    "nn_hit_rate",
    "nn_miss_rate",
    "nn_isolation",
    "nn_unit_id",
    "nn_noise_overlap",
    "silhouette",
    "silhouette_full",
)

column_name_to_column_dtype = {"num_spikes": int}
column_name_to_column_dtype.update({name: float for name in _FLOAT_METRIC_COLUMNS})
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,33 @@ def test_compute_quality_metrics(sorting_analyzer_simple):
assert "isolation_distance" in metrics.columns


def test_merging_quality_metrics(sorting_analyzer_simple):
    """Merging units must propagate quality metrics — columns and dtypes — to the merged analyzer."""
    analyzer = sorting_analyzer_simple

    pre_merge_metrics = compute_quality_metrics(
        analyzer,
        metric_names=None,
        qm_params=dict(isi_violation=dict(isi_threshold_ms=2)),
        skip_pc_metrics=False,
        seed=2205,
    )

    # sorting_analyzer_simple has ten units; merge the first two into one
    merged_analyzer = analyzer.merge_units([[0, 1]])
    post_merge_metrics = merged_analyzer.get_extension("quality_metrics").get_data()

    # every pre-merge column must survive the merge with its dtype intact
    for column in pre_merge_metrics.columns:
        assert column in post_merge_metrics.columns
        assert pre_merge_metrics[column].dtype == post_merge_metrics[column].dtype

    # 10 units before the merge vs 9 after
    assert len(pre_merge_metrics.index) > len(post_merge_metrics.index)


def test_compute_quality_metrics_recordingless(sorting_analyzer_simple):

sorting_analyzer = sorting_analyzer_simple
Expand Down Expand Up @@ -106,10 +133,15 @@ def test_empty_units(sorting_analyzer_simple):
seed=2205,
)

for empty_unit_id in sorting_empty.get_empty_unit_ids():
# num_spikes are ints not nans so we confirm empty units are nans for everything except
# num_spikes which should be 0
nan_containing_columns = [column for column in metrics_empty.columns if column != "num_spikes"]
for empty_unit_ids in sorting_empty.get_empty_unit_ids():
from pandas import isnull

assert np.all(isnull(metrics_empty.loc[empty_unit_id].values))
assert np.all(isnull(metrics_empty.loc[empty_unit_ids, nan_containing_columns].values))
if "num_spikes" in metrics_empty.columns:
assert sum(metrics_empty.loc[empty_unit_ids, ["num_spikes"]]) == 0


# TODO @alessio all theses old test should be moved in test_metric_functions.py or test_pca_metrics()
Expand Down

0 comments on commit 82d62ca

Please sign in to comment.