diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
index d71450853f..11ce3d0160 100644
--- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
+++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py
@@ -17,6 +17,7 @@
     _misc_metric_name_to_func,
     _possible_pc_metric_names,
     qm_compute_name_to_column_names,
+    column_name_to_column_dtype,
 )
 from .misc_metrics import _default_params as misc_metrics_params
 from .pca_metrics import _default_params as pca_metrics_params
@@ -140,13 +141,20 @@ def _merge_extension_data(
         all_unit_ids = new_sorting_analyzer.unit_ids
         not_new_ids = all_unit_ids[~np.isin(all_unit_ids, new_unit_ids)]
 
+        # this creates a new metrics dataframe, but the dtype of every column will be
+        # object, so we will need to fix the dtypes after computing the new metrics
         metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns)
-
         metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :]
         metrics.loc[new_unit_ids, :] = self._compute_metrics(
             new_sorting_analyzer, new_unit_ids, verbose, metric_names, **job_kwargs
         )
 
+        # we need to fix the dtypes after everything is computed because of the nans:
+        # iterate through the columns and cast each one back to the dtype it has
+        # in the original quality-metrics dataframe
+        for column in old_metrics.columns:
+            metrics[column] = metrics[column].astype(old_metrics[column].dtype)
+
         new_data = dict(metrics=metrics)
         return new_data
 
@@ -229,10 +237,20 @@ def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metri
         # add NaN for empty units
         if len(empty_unit_ids) > 0:
             metrics.loc[empty_unit_ids] = np.nan
+            # num_spikes is an int and should be 0, not NaN
+            if "num_spikes" in metrics.columns:
+                metrics.loc[empty_unit_ids, ["num_spikes"]] = 0
 
         # we use the convert_dtypes to convert the columns to the most appropriate dtype and avoid object columns
         # (in case of NaN values)
         metrics = metrics.convert_dtypes()
+
+        # we do this because convert_dtypes sometimes infers the wrong type;
+        # the intended type of each column is recorded in the column_name_to_column_dtype dictionary
+        for column in metrics.columns:
+            if column in column_name_to_column_dtype:
+                metrics[column] = metrics[column].astype(column_name_to_column_dtype[column])
+
         return metrics
 
     def _run(self, verbose=False, **job_kwargs):
diff --git a/src/spikeinterface/qualitymetrics/quality_metric_list.py b/src/spikeinterface/qualitymetrics/quality_metric_list.py
index fc7e92b50d..23b781eb9d 100644
--- a/src/spikeinterface/qualitymetrics/quality_metric_list.py
+++ b/src/spikeinterface/qualitymetrics/quality_metric_list.py
@@ -66,7 +66,11 @@
     "amplitude_cutoff": ["amplitude_cutoff"],
     "amplitude_median": ["amplitude_median"],
     "amplitude_cv": ["amplitude_cv_median", "amplitude_cv_range"],
-    "synchrony": ["sync_spike_2", "sync_spike_4", "sync_spike_8"],
+    "synchrony": [
+        "sync_spike_2",
+        "sync_spike_4",
+        "sync_spike_8",
+    ],
     "firing_range": ["firing_range"],
     "drift": ["drift_ptp", "drift_std", "drift_mad"],
     "sd_ratio": ["sd_ratio"],
@@ -79,3 +83,38 @@
     "silhouette": ["silhouette"],
     "silhouette_full": ["silhouette_full"],
 }
+
+# this dict lets us enforce the appropriate dtype for each metric rather than letting pandas infer it
+column_name_to_column_dtype = {
+    "num_spikes": int,
+    "firing_rate": float,
+    "presence_ratio": float,
+    "snr": float,
+    "isi_violations_ratio": float,
+    "isi_violations_count": float,
+    "rp_violations": float,
+    "rp_contamination": float,
+    "sliding_rp_violation": float,
+    "amplitude_cutoff": float,
+    "amplitude_median": float,
+    "amplitude_cv_median": float,
+    "amplitude_cv_range": float,
+    "sync_spike_2": float,
+    "sync_spike_4": float,
+    "sync_spike_8": float,
+    "firing_range": float,
+    "drift_ptp": float,
+    "drift_std": float,
+    "drift_mad": float,
+    "sd_ratio": float,
+    "isolation_distance": float,
+    "l_ratio": float,
+    "d_prime": float,
+    "nn_hit_rate": float,
+    "nn_miss_rate": float,
+    "nn_isolation": float,
+    "nn_unit_id": float,
+    "nn_noise_overlap": float,
+    "silhouette": float,
+    "silhouette_full": float,
+}
diff --git a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py
index 60f0490f51..ea8939ebb4 100644
--- a/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py
+++ b/src/spikeinterface/qualitymetrics/tests/test_quality_metric_calculator.py
@@ -48,6 +48,33 @@ def test_compute_quality_metrics(sorting_analyzer_simple):
     assert "isolation_distance" in metrics.columns
 
 
+def test_merging_quality_metrics(sorting_analyzer_simple):
+
+    sorting_analyzer = sorting_analyzer_simple
+
+    metrics = compute_quality_metrics(
+        sorting_analyzer,
+        metric_names=None,
+        qm_params=dict(isi_violation=dict(isi_threshold_ms=2)),
+        skip_pc_metrics=False,
+        seed=2205,
+    )
+
+    # sorting_analyzer_simple has ten units
+    new_sorting_analyzer = sorting_analyzer.merge_units([[0, 1]])
+
+    new_metrics = new_sorting_analyzer.get_extension("quality_metrics").get_data()
+
+    # the metrics of the unmerged units should be copied over after the merge
+    for column in metrics.columns:
+        assert column in new_metrics.columns
+        # the dtype should be copied over too
+        assert metrics[column].dtype == new_metrics[column].dtype
+
+    # 10 units before the merge vs 9 units after
+    assert len(metrics.index) > len(new_metrics.index)
+
+
 def test_compute_quality_metrics_recordingless(sorting_analyzer_simple):
 
     sorting_analyzer = sorting_analyzer_simple
@@ -106,10 +133,15 @@ def test_empty_units(sorting_analyzer_simple):
         seed=2205,
     )
 
+    # num_spikes is an int, not nan, so we check that empty units are nan for
+    # everything except num_spikes, which should be 0
+    nan_containing_columns = [column for column in metrics_empty.columns if column != "num_spikes"]
     for empty_unit_id in sorting_empty.get_empty_unit_ids():
         from pandas import isnull
 
-        assert np.all(isnull(metrics_empty.loc[empty_unit_id].values))
+        assert np.all(isnull(metrics_empty.loc[empty_unit_id, nan_containing_columns].values))
+        if "num_spikes" in metrics_empty.columns:
+            assert metrics_empty.loc[empty_unit_id, "num_spikes"] == 0
 
 
 # TODO @alessio all theses old test should be moved in test_metric_functions.py or test_pca_metrics()
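Two pandas behaviours drive most of this patch, and both are easy to trip over. First, a DataFrame built from an index and columns alone is all object dtype, and .loc assignment into it stays object. Second, convert_dtypes() infers types from the values it sees, so a float metric that happens to hold only whole numbers comes back as a nullable integer column. The sketch below is illustrative only and is not SpikeInterface code: the unit ids (u0, u1, u2) and the toy num_spikes/snr values are made up.

import pandas as pd

old_metrics = pd.DataFrame({"num_spikes": [10, 20], "snr": [4.0, 5.0]}, index=["u0", "u1"])

# a frame built from index/columns alone starts out all-object,
# and .loc assignment does not change that
metrics = pd.DataFrame(index=["u0", "u1", "u2"], columns=old_metrics.columns)
metrics.loc[["u0", "u1"], :] = old_metrics.loc[["u0", "u1"], :]
metrics.loc["u2", :] = [7, 3.0]
print(metrics.dtypes.tolist())  # [dtype('O'), dtype('O')]

# convert_dtypes() mis-infers here: snr holds only whole-number floats,
# so it comes back as a nullable integer column instead of a float one
print(metrics.convert_dtypes().dtypes["snr"])  # Int64

# casting each column back to a known dtype (what the patch does with
# old_metrics in _merge_extension_data and with column_name_to_column_dtype
# in _compute_metrics) keeps the types stable
for column in old_metrics.columns:
    metrics[column] = metrics[column].astype(old_metrics[column].dtype)
print(metrics.dtypes.tolist())  # [dtype('int64'), dtype('float64')]

The zero-filling of num_spikes for empty units is the other half of the same story: a plain int column cannot hold NaN, so leaving NaN there would make the final astype(int) cast raise instead of producing an int64 column.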