From 1cf68ea476d55c01bbb984fa8cab6972e02691fa Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Fri, 27 Sep 2024 17:50:49 +0100 Subject: [PATCH] Continue adding docstring / types. --- debugging/alignment_utils.py | 103 ++++++++++++++++++++++++--------- debugging/session_alignment.py | 23 ++++---- 2 files changed, 90 insertions(+), 36 deletions(-) diff --git a/debugging/alignment_utils.py b/debugging/alignment_utils.py index a32488eca6..f16c144f89 100644 --- a/debugging/alignment_utils.py +++ b/debugging/alignment_utils.py @@ -1,3 +1,8 @@ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from spikeinterface import BaseRecording + import numpy as np import matplotlib.pyplot as plt import spikeinterface.full as si @@ -14,16 +19,50 @@ from scipy.ndimage import gaussian_filter from spikeinterface.sortingcomponents.motion.iterative_template import kriging_kernel -# ----------------------------------------------------------------------------- +# ############################################################################# # Get Histograms -# ----------------------------------------------------------------------------- +# ############################################################################# def get_activity_histogram( - recording, peaks, peak_locations, spatial_bin_edges, log_scale, bin_s, depth_smooth_um + recording: BaseRecording, + peaks: np.ndarray, + peak_locations: np.ndarray, + spatial_bin_edges: np.ndarray, + log_scale: bool, + bin_s: float| None, + depth_smooth_um: float | None, ): """ - TODO: assumes 1-segment recording + Generate a 2D activity histogram for the session. Wraps the underlying + spikeinterface function with some adjustments for scaling to time and + log transform. + + Parameters + ---------- + + recording: BaseRecording, + A SpikeInterface recording object. + peaks: np.ndarray, + A SpikeInterface `peaks` array. + peak_locations: np.ndarray, + A SpikeInterface `peak_locations` array. 
+ spatial_bin_edges: np.ndarray, + A (1 x n_bins + 1) array of spatial (probe y dimension) bin edges. + log_scale: bool, + If `True`, histogram is log scaled. + bin_s | None: float, + If `None`, a single histogram will be generated from all session + peaks. Otherwise, multiple histograms will be generated, one for + each time bin. + depth_smooth_um: float | None + If not `None`, smooth the histogram across the spatial + axis. see `make_2d_motion_histogram()` for details. + + TODO + ---- + - assumes 1-segment recording + - ask Sam whether it makes sense to integrate this function with `make_2d_motion_histogram`. """ activity_histogram, temporal_bin_edges, generated_spatial_bin_edges = make_2d_motion_histogram( recording, @@ -55,11 +94,6 @@ def get_activity_histogram( return activity_histogram, temporal_bin_centers, spatial_bin_centers -# ----------------------------------------------------------------------------- -# Utils -# ----------------------------------------------------------------------------- - - def get_bin_centers(bin_edges): return (bin_edges[1:] + bin_edges[:-1]) / 2 @@ -71,6 +105,10 @@ def estimate_chunk_size(scaled_activity_histogram): estimated within 10% 99% of the time, corrected based on assumption of Poisson firing (based on CLT). + + TODO + ---- + - make the details available. """ firing_rate = np.percentile(scaled_activity_histogram, 98) @@ -84,32 +122,39 @@ def estimate_chunk_size(scaled_activity_histogram): return t, lambda_ -# ----------------------------------------------------------------------------- +# ############################################################################# # Chunked Histogram estimation methods -# ----------------------------------------------------------------------------- - +# ############################################################################# +# Given a set off chunked_session_histograms (num time chunks x num spatial bins) +# take the summary statistic over the time axis. 
def get_chunked_hist_mean(chunked_session_histograms): - """ """ + mean_hist = np.mean(chunked_session_histograms, axis=0) return mean_hist def get_chunked_hist_median(chunked_session_histograms): - """ """ + median_hist = np.median(chunked_session_histograms, axis=0) return median_hist def get_chunked_hist_supremum(chunked_session_histograms): - """ """ + max_hist = np.max(chunked_session_histograms, axis=0) return max_hist def get_chunked_hist_poisson_estimate(chunked_session_histograms): - """ """ + """ + Make a MLE estimate of the most likely value for each bin + given the assumption of Poisson firing. Turns out this is + basically identical to the mean :'D. + Keeping for now as opportunity to add prior or do some outlier + removal per bin. But if not useful, deprecate in future. + """ def obj_fun(lambda_, m, sum_k): return -(sum_k * np.log(lambda_) - m * lambda_) @@ -123,17 +168,19 @@ def obj_fun(lambda_, m, sum_k): m = ks.shape sum_k = np.sum(ks) - # lol, this is painfully close to the mean, no meaningful - # prior comes to mind to extend the method with. - poisson_estimate[i] = minimize(obj_fun, 0.5, (m, sum_k), bounds=((1e-10, np.inf),)).x + poisson_estimate[i] = minimize( + obj_fun, 0.5, (m, sum_k), bounds=((1e-10, np.inf),) + ).x + return poisson_estimate -# TODO: currently deprecated due to scaling issues between -# sessions. A much better (?) way will to make PCA from all -# sessions, then align based on projection -def get_chunked_hist_eigenvector(chunked_session_histograms): - """ """ +def DEPRECATE_get_chunked_hist_eigenvector(chunked_session_histograms): + """ + TODO: currently deprecated due to scaling issues between + sessions. A much better (?) 
way will be to make PCA
@@ -469,7 +470,7 @@ def _get_single_session_activity_histogram( for the chunked histograms, with length num_chunks. "session_std" : The mean across bin-wise standard deviation of the chunked histograms. - "chunked_bin_size_s" : time of each chunk used to TODO + "chunked_bin_size_s" : time of each chunk used to calculate the chunked histogram. """ times = recording.get_times() @@ -543,8 +544,8 @@ def _create_motion_recordings( Returns ------- corrected_recordings_list : list[BaseRecording] - A list of InterpolateMotionRecording recordings of shift- - corrected recordings coressponding to `recordings_list`. + A list of InterpolateMotionRecording recordings of shift-corrected + recordings corresponding to `recordings_list`. motion_objects_list : list[Motion] A list of Motion objects. If the recording in `recordings_list` @@ -613,9 +614,9 @@ def _add_displacement_to_interpolate_recording( TODO ---- - Check + ask Sam if any other fields need to be chagned. This is a little - # hairy (4 possible combinations of new and - # old displacement shapes, rigid or nonrigid, so test thoroughly. + Check + ask Sam if any other fields need to be changed. This is a little + hairy (4 possible combinations of new and old displacement shapes, + rigid or nonrigid, so test thoroughly. """ # Everything is done in place, so keep a short variable # name reference to the new recordings `motion` object @@ -665,11 +666,13 @@ def _correct_session_displacement( ): """ Internal function to apply the correction from `align_sessions` - to build a corrected histogram for comparison. First, + to build a corrected histogram for comparison. First, create + new shifted peak locations. Then, create a new 'corrected' + activity histogram from the new peak locations. Parameters ---------- - see `align_sessions()` for parameters. TODO: can add motion shifts if we need. + see `align_sessions()` for parameters. 
Returns ------- @@ -796,7 +799,6 @@ def _compute_session_alignment( else: shifts = rigid_shifts + non_rigid_shifts - breakpoint() return shifts, non_rigid_windows, non_rigid_window_centers @@ -865,6 +867,7 @@ def _akima_interpolate_nonrigid_shifts( An array (length num_spatial_bins) of shifts interpolated from the non-rigid shifts. + TODO ---- requires scipy 14 """