From db53c14de678e580959016dcc9db7ab57dd10ba0 Mon Sep 17 00:00:00 2001 From: JoeZiminski Date: Mon, 30 Sep 2024 17:37:22 +0100 Subject: [PATCH] Continue with docstrings. --- debugging/_test_session_alignment.py | 43 ++++--- debugging/alignment_utils.py | 179 +++++++++++++++++---------- debugging/plotting.py | 83 +++++++++---- debugging/session_alignment.py | 4 +- 4 files changed, 195 insertions(+), 114 deletions(-) diff --git a/debugging/_test_session_alignment.py b/debugging/_test_session_alignment.py index 54b63d6357..ef92d995fa 100644 --- a/debugging/_test_session_alignment.py +++ b/debugging/_test_session_alignment.py @@ -1,36 +1,16 @@ from __future__ import annotations -import spikeinterface.full as si from spikeinterface.generation.session_displacement_generator import generate_session_displacement_recordings import matplotlib.pyplot as plt import numpy as np -from spikeinterface.sortingcomponents.peak_detection import detect_peaks -from spikeinterface.sortingcomponents.peak_localization import localize_peaks -from spikeinterface.sortingcomponents.motion.motion_utils import make_2d_motion_histogram, make_3d_motion_histograms -from scipy.optimize import minimize -from pathlib import Path import alignment_utils # TODO import pickle import session_alignment # TODO -from spikeinterface.sortingcomponents.motion import correct_motion_on_peaks -from spikeinterface.widgets.motion import DriftRasterMapWidget -from spikeinterface.widgets.base import BaseWidget import plotting - - -import spikeinterface.full as si -from spikeinterface.generation.session_displacement_generator import generate_session_displacement_recordings -import matplotlib.pyplot as plt -import numpy as np from spikeinterface.sortingcomponents.peak_detection import detect_peaks from spikeinterface.sortingcomponents.peak_localization import localize_peaks -from spikeinterface.sortingcomponents.motion.motion_utils import make_2d_motion_histogram, make_3d_motion_histograms -from scipy.optimize import 
minimize -from pathlib import Path -import alignment_utils # TODO -import pickle -import session_alignment # TODO -from spikeinterface.sortingcomponents.motion import correct_motion_on_peaks +import spikeinterface.full as si + # TODO: all of the nonrigid methods (and even rigid) could be having some strange affects on AP # waveforms. definately needs looking into! @@ -132,6 +112,25 @@ # - xcorr is not the best for large shifts due to lower num overlapping samples # - +def _prep_recording(recording, plot=False): + """ + :param recording: + :return: + """ + peaks = detect_peaks(recording, method="locally_exclusive") + + peak_locations = localize_peaks(recording, peaks, method="grid_convolution") + + if plot: + si.plot_drift_raster_map( + peaks=peaks, + peak_locations=peak_locations, + recording=recording, + clim=(-300, 0), # fix clim for comparability across plots + ) + plt.show() + + return peaks, peak_locations MOTION = True # True SAVE = False diff --git a/debugging/alignment_utils.py b/debugging/alignment_utils.py index f16c144f89..156cd8a6dc 100644 --- a/debugging/alignment_utils.py +++ b/debugging/alignment_utils.py @@ -1,21 +1,7 @@ -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from spikeinterface import BaseRecording - +from spikeinterface import BaseRecording import numpy as np -import matplotlib.pyplot as plt -import spikeinterface.full as si -from spikeinterface.generation.session_displacement_generator import generate_session_displacement_recordings -from spikeinterface.sortingcomponents.peak_detection import detect_peaks -from spikeinterface.sortingcomponents.peak_localization import localize_peaks -from spikeinterface.sortingcomponents.motion.motion_utils import make_2d_motion_histogram, make_3d_motion_histograms +from spikeinterface.sortingcomponents.motion.motion_utils import make_2d_motion_histogram from scipy.optimize import minimize -from pathlib import Path -from spikeinterface.sortingcomponents.motion import InterpolateMotionRecording 
-from spikeinterface.sortingcomponents.motion.motion_utils import get_spatial_windows, Motion -from spikeinterface.sortingcomponents.motion.iterative_template import iterative_template_registration -from spikeinterface.sortingcomponents.motion.motion_interpolation import correct_motion_on_peaks from scipy.ndimage import gaussian_filter from spikeinterface.sortingcomponents.motion.iterative_template import kriging_kernel @@ -203,35 +189,95 @@ def DEPRECATE_get_chunked_hist_eigenvector(chunked_session_histograms): def compute_histogram_crosscorrelation( - session_histogram_list, - non_rigid_windows, - num_shifts_block, - interpolate, - interp_factor, - kriging_sigma, - kriging_p, - kriging_d, - smoothing_sigma_bin, - smoothing_sigma_window, + session_histogram_list: list[np.ndarray], + non_rigid_windows: np.ndarray, + num_shifts_block: int, + interpolate: bool, + interp_factor: int, + kriging_sigma: float, + kriging_p: float, + kriging_d: float, + smoothing_sigma_bin: float, + smoothing_sigma_window: float, ): """ + Given a list of session activity histograms, cross-correlate + all histograms returning the peak correlation shift (in indices) + in a symmetric (num_session x num_session) matrix. + Supports non-rigid estimation by windowing the activity histogram + and performing separate cross-correlations on each window separately. + Parameters + ---------- - - # TODO: what happens when this bigger than thing. Also rename about shifts - # TODO: this is kind of wasteful, no optimisations made against redundant - # session computation, but these in generate very fast. - - # The problem is this stratergy completely fails when thexcorr is very bad. - # The smoothing and interpolation make it much worse, because bad xcorr are - # merged together. The xcorr can be bad when the recording is shifted a lot - # and so there are empty regions that are correlated with non-empty regions - # in the nonrigid approach. A different approach will need to be taken in - # this case. 
- - # Note that due to interpolation, ij vs ji are not exact duplicates. - # dont improve for now. + session_histogram_list : list[np.ndarray] + non_rigid_windows : np.ndarray + A (num windows x num_bins) binary of weights by which to window + the activity histogram for non-rigid-registration. For example, if + 2 rectangular masks were used, there would be a two row binary mask + the first row with mask of the first half of the probe and the second + row a mask for the second half of the probe. + num_shifts_block : int + Number of indices by which to shift the histogram to find the maximum + of the cross correlation. If `None`, the entire activity histograms + are cross-correlated. + interpolate : bool + If `True`, the cross-correlation is interpolated before maximum is taken. + interp_factor : int + Factor by which to interpolate the cross-correlation. + kriging_sigma : float + sigma parameter for kriging_kernel function. See `kriging_kernel`. + kriging_p : float + p parameter for kriging_kernel function. See `kriging_kernel`. + kriging_d : float + d parameter for kriging_kernel function. See `kriging_kernel`. + smoothing_sigma_bin : float + sigma parameter for the gaussian smoothing kernel over the + spatial bins. + smoothing_sigma_window : float + sigma parameter for the gaussian smoothing kernel over the + non-rigid windows. + + Returns + ------- + + shift_matrix : ndarray + A (num_session x num_session) symmetric matrix of shifts + (indices) between pairs of session activity histograms. + + Notes + ----- + + - This function is very similar to the IterativeTemplateRegistration + function used in motion correct, though slightly different in scope. + It was not convenient to merge them at this time, but worth looking + into in future. + + - Some obvious performance boosts, not done so because already fast + 1) the cross correlations for each session comparison are performed + twice. 
They are slightly different due to interpolation, but + still probably better to calculate once and flip. + 2) `num_shifts_block` is implemented by simply making the full + cross correlation. Would probably be nicer to explicitly calculate + only where needed. However, in general these cross correlations are + only a few thousand datapoints and so are already extremely + fast to cross correlate. + + Notes + ----- + + - The original kilosort method does not work in the inter-session + context because it averages over time bins to form a template to + align to. In this case, averaging over a small number of possibly + quite different session histograms does not work well. + + - In the nonrigid case, this strategy can completely fail when the xcorr + is very bad for a certain window. The smoothing and interpolation + make it much worse, because bad xcorr are merged together. The x-corr + can be bad when the recording is shifted a lot and so there are empty + regions that are correlated with non-empty regions in the nonrigid + approach. A different approach will need to be taken in this case. Note that kilosort method does not work because creating a mean does not make sense over sessions. 
@@ -259,7 +305,10 @@ def compute_histogram_crosscorrelation( xcorr = np.correlate(windowed_histogram_i, windowed_histogram_j, mode="full") if num_shifts_block: - window_indices = np.arange(center_bin - num_shifts_block, center_bin + num_shifts_block) + window_indices = np.arange( + center_bin - num_shifts_block, + center_bin + num_shifts_block + ) mask = np.zeros_like(xcorr) mask[window_indices] = 1 xcorr *= mask @@ -298,31 +347,33 @@ def compute_histogram_crosscorrelation( return shift_matrix -def shift_array_fill_zeros(array, shift): - abs_shift = np.abs(shift) - pad_tuple = (0, abs_shift) if shift > 0 else (abs_shift, 0) - padded_hist = np.pad(array, pad_tuple, mode="constant") - cut_padded_hist = padded_hist[abs_shift:] if shift >= 0 else padded_hist[:-abs_shift] - return cut_padded_hist +def shift_array_fill_zeros(array: np.ndarray, shift: int) -> np.ndarray: + """ + Shift an array by `shift` indices, padding with zero. + Samples going out of bounds are dropped i.e. the array is not + extended and samples are not wrapped around to the start of the array. + Parameters + ---------- -# TODO: deprecate -def prep_recording(recording, plot=False): - """ - :param recording: - :return: - """ - peaks = detect_peaks(recording, method="locally_exclusive") + array : np.ndarray + The array to pad. + shift : int + Number of indices by which to shift the array. If positive, the + zeros are added from the end of the array. If negative, the zeros + are added from the start of the array. - peak_locations = localize_peaks(recording, peaks, method="grid_convolution") + Returns + ------- - if plot: - si.plot_drift_raster_map( - peaks=peaks, - peak_locations=peak_locations, - recording=recording, - clim=(-300, 0), # fix clim for comparability across plots - ) - plt.show() + cut_padded_array : np.ndarray + The `array` padded with zeros and cut down (i.e. out of bounds + samples dropped). 
+ + """ + abs_shift = np.abs(shift) + pad_tuple = (0, abs_shift) if shift > 0 else (abs_shift, 0) + padded_hist = np.pad(array, pad_tuple, mode="constant") + cut_padded_array = padded_hist[abs_shift:] if shift >= 0 else padded_hist[:-abs_shift] - return peaks, peak_locations + return cut_padded_array diff --git a/debugging/plotting.py b/debugging/plotting.py index d863177d10..fefdd26299 100644 --- a/debugging/plotting.py +++ b/debugging/plotting.py @@ -1,18 +1,7 @@ import itertools -import spikeinterface.full as si -from spikeinterface.generation.session_displacement_generator import generate_session_displacement_recordings -import matplotlib.pyplot as plt +from spikeinterface.core import BaseRecording import numpy as np -from spikeinterface.sortingcomponents.peak_detection import detect_peaks -from spikeinterface.sortingcomponents.peak_localization import localize_peaks -from spikeinterface.sortingcomponents.motion.motion_utils import make_2d_motion_histogram, make_3d_motion_histograms -from scipy.optimize import minimize -from pathlib import Path -import alignment_utils # TODO -import pickle -import session_alignment # TODO -from spikeinterface.sortingcomponents.motion import correct_motion_on_peaks from spikeinterface.widgets.base import BaseWidget from spikeinterface.widgets.base import to_attr from spikeinterface.widgets.motion import DriftRasterMapWidget @@ -23,18 +12,59 @@ class SessionAlignmentWidget(BaseWidget): def __init__( self, - recordings_list, - peaks_list, - peak_locations_list, - session_histogram_list, - spatial_bin_centers=None, - corrected_peak_locations_list=None, - corrected_session_histogram_list=None, - drift_raster_map_kwargs=None, - session_alignment_histogram_kwargs=None, # TODO: rename, the widget too. 
+ recordings_list: list[BaseRecording], + peaks_list: list[np.ndarray], + peak_locations_list: list[np.ndarray], + session_histogram_list: list[np.ndarray], + spatial_bin_centers: np.ndarray | None = None, + corrected_peak_locations_list: list[np.ndarray] | None = None, + corrected_session_histogram_list: list[np.ndarray] = None, + drift_raster_map_kwargs : dict | None = None, + session_alignment_histogram_kwargs: dict | None = None, **backend_kwargs, ): + """ + Widget to display the output of inter-session alignment. + In the top section, `DriftRasterMapWidget`s are used to display + the raster maps for each session, before and after alignment. + The order of all lists should correspond to the same recording. + + If histograms are provided, `SessionAlignmentHistogramWidget` + are used to show the activity histograms, before and after alignment. + See `align_sessions` for context. + + Corrected and uncorrected activity histograms are generated + as part of the `align_sessions` step. + + Parameters + ---------- + + recordings_list : list[BaseRecording] + List of recordings to plot. + peaks_list : list[np.ndarray] + List of detected peaks for each session. + peak_locations_list : list[np.ndarray] + List of detected peak locations for each session. + session_histogram_list : np.ndarray | None + A list of activity histograms as output from `align_sessions`. + If `None`, no histograms will be displayed. + spatial_bin_centers : np.ndarray | None + Spatial bin centers for the histogram (each session activity + histogram will have the same spatial bin centers). + corrected_peak_locations_list : list[np.ndarray] | None + A list of corrected peak locations. If provided, the corrected + raster plots will be displayed. + corrected_session_histogram_list : list[np.ndarray] + A list of corrected session activity histograms, as + output from `align_sessions`. + drift_raster_map_kwargs : dict | None + Kwargs to be passed to `DriftRasterMapWidget`. 
+ session_alignment_histogram_kwargs : dict | None + Kwargs to be passed to `SessionAlignmentHistogramWidget`. + **backend_kwargs + """ + # TODO: check all lengths more carefully e.g. histogram vs. peaks. assert len(recordings_list) <= 8, ( @@ -49,7 +79,6 @@ def __init__( "version of `session_histogram_list`." ) if corrected_peak_locations_list is not None: - # TODO: this is almost identical to the above. if not len(corrected_peak_locations_list) == len(peak_locations_list): raise ValueError( "`corrected_peak_locations_list` must be the same length as `peak_locations_list`. " @@ -83,7 +112,9 @@ def __init__( BaseWidget.__init__(self, plot_data, backend="matplotlib", **backend_kwargs) def plot_matplotlib(self, data_plot, **backend_kwargs): - + """ + Create the `SessionAlignmentWidget` for matplotlib. + """ from spikeinterface.widgets.utils_matplotlib import make_mpl_figure dp = to_attr(data_plot) @@ -101,8 +132,7 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): max_y = np.max(np.hstack([locs["y"] for locs in dp.peak_locations_list])) if dp.corrected_peak_locations_list is None: - - # Own function + # TODO: Own function num_cols = np.min([4, len(dp.peak_locations_list)]) num_rows = 1 if num_cols <= 4 else 2 @@ -200,8 +230,9 @@ def plot_matplotlib(self, data_plot, **backend_kwargs): class SessionAlignmentHistogramWidget(BaseWidget): - """ """ + """ + """ def __init__( self, session_histogram_list: list[np.ndarray], diff --git a/debugging/session_alignment.py b/debugging/session_alignment.py index 1fc76b78c1..0120638aba 100644 --- a/debugging/session_alignment.py +++ b/debugging/session_alignment.py @@ -901,7 +901,7 @@ def _get_shifts_from_session_matrix( ---------- alignment_order : "to_middle" or "to_session_X" where "N" is the number of the session to align to. 
- session_offsets_matri : np.ndarray + session_offsets_matrix : np.ndarray The num_sessions x num_sessions symmetric matrix of displacements between all sessions, generated by `_compute_session_alignment()`. @@ -957,7 +957,7 @@ def _check_align_sesssions_inpus( accepted_hist_methods = ["entire_session", "chunked_mean", "chunked_median", "chunked_supremum", "chunked_poisson"] method = estimate_histogram_kwargs["method"] - if not method in ["entire_session", "chunked_mean", "chunked_median", "chunked_supremum", "chunked_poisson"]: + if method not in ["entire_session", "chunked_mean", "chunked_median", "chunked_supremum", "chunked_poisson"]: raise ValueError(f"`method` option must be one of: {accepted_hist_methods}") if alignment_order != "to_middle":