Skip to content

Commit

Permalink
Continue with docstrings.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeZiminski committed Sep 30, 2024
1 parent 1cf68ea commit db53c14
Show file tree
Hide file tree
Showing 4 changed files with 195 additions and 114 deletions.
43 changes: 21 additions & 22 deletions debugging/_test_session_alignment.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,16 @@
from __future__ import annotations

import spikeinterface.full as si
from spikeinterface.generation.session_displacement_generator import generate_session_displacement_recordings
import matplotlib.pyplot as plt
import numpy as np
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.sortingcomponents.peak_localization import localize_peaks
from spikeinterface.sortingcomponents.motion.motion_utils import make_2d_motion_histogram, make_3d_motion_histograms
from scipy.optimize import minimize
from pathlib import Path
import alignment_utils # TODO
import pickle
import session_alignment # TODO
from spikeinterface.sortingcomponents.motion import correct_motion_on_peaks
from spikeinterface.widgets.motion import DriftRasterMapWidget
from spikeinterface.widgets.base import BaseWidget
import plotting


import spikeinterface.full as si
from spikeinterface.generation.session_displacement_generator import generate_session_displacement_recordings
import matplotlib.pyplot as plt
import numpy as np
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.sortingcomponents.peak_localization import localize_peaks
from spikeinterface.sortingcomponents.motion.motion_utils import make_2d_motion_histogram, make_3d_motion_histograms
from scipy.optimize import minimize
from pathlib import Path
import alignment_utils # TODO
import pickle
import session_alignment # TODO
from spikeinterface.sortingcomponents.motion import correct_motion_on_peaks
import spikeinterface.full as si


# TODO: all of the nonrigid methods (and even rigid) could be having some strange affects on AP
# waveforms. definately needs looking into!
Expand Down Expand Up @@ -132,6 +112,25 @@
# - xcorr is not the best for large shifts due to lower num overlapping samples
# -

def _prep_recording(recording, plot=False):
"""
:param recording:
:return:
"""
peaks = detect_peaks(recording, method="locally_exclusive")

peak_locations = localize_peaks(recording, peaks, method="grid_convolution")

if plot:
si.plot_drift_raster_map(
peaks=peaks,
peak_locations=peak_locations,
recording=recording,
clim=(-300, 0), # fix clim for comparability across plots
)
plt.show()

return peaks, peak_locations

MOTION = True # True
SAVE = False
Expand Down
179 changes: 115 additions & 64 deletions debugging/alignment_utils.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,7 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from spikeinterface import BaseRecording

from spikeinterface import BaseRecording
import numpy as np
import matplotlib.pyplot as plt
import spikeinterface.full as si
from spikeinterface.generation.session_displacement_generator import generate_session_displacement_recordings
from spikeinterface.sortingcomponents.peak_detection import detect_peaks
from spikeinterface.sortingcomponents.peak_localization import localize_peaks
from spikeinterface.sortingcomponents.motion.motion_utils import make_2d_motion_histogram, make_3d_motion_histograms
from spikeinterface.sortingcomponents.motion.motion_utils import make_2d_motion_histogram
from scipy.optimize import minimize
from pathlib import Path
from spikeinterface.sortingcomponents.motion import InterpolateMotionRecording
from spikeinterface.sortingcomponents.motion.motion_utils import get_spatial_windows, Motion
from spikeinterface.sortingcomponents.motion.iterative_template import iterative_template_registration
from spikeinterface.sortingcomponents.motion.motion_interpolation import correct_motion_on_peaks
from scipy.ndimage import gaussian_filter
from spikeinterface.sortingcomponents.motion.iterative_template import kriging_kernel

Expand Down Expand Up @@ -203,35 +189,95 @@ def DEPRECATE_get_chunked_hist_eigenvector(chunked_session_histograms):


def compute_histogram_crosscorrelation(
session_histogram_list,
non_rigid_windows,
num_shifts_block,
interpolate,
interp_factor,
kriging_sigma,
kriging_p,
kriging_d,
smoothing_sigma_bin,
smoothing_sigma_window,
session_histogram_list: list[np.ndarray],
non_rigid_windows: np.ndarray,
num_shifts_block: int,
interpolate: bool,
interp_factor: int,
kriging_sigma: float,
kriging_p: float,
kriging_d: float,
smoothing_sigma_bin: float,
smoothing_sigma_window: float,
):
"""
Given a list of session activity histograms, cross-correlate
all histograms returning the peak correlation shift (in indices)
in a symmetric (num_session x num_session) matrix.
Supports non-rigid estimation by windowing the activity histogram
and performing separate cross-correlations on each window separately.
Parameters
----------
# TODO: what happens when this bigger than thing. Also rename about shifts
# TODO: this is kind of wasteful, no optimisations made against redundant
# session computation, but these in generate very fast.
# The problem is this stratergy completely fails when thexcorr is very bad.
# The smoothing and interpolation make it much worse, because bad xcorr are
# merged together. The xcorr can be bad when the recording is shifted a lot
# and so there are empty regions that are correlated with non-empty regions
# in the nonrigid approach. A different approach will need to be taken in
# this case.
# Note that due to interpolation, ij vs ji are not exact duplicates.
# dont improve for now.
session_histogram_list : list[np.ndarray]
non_rigid_windows : np.ndarray
A (num windows x num_bins) binary of weights by which to window
the activity histogram for non-rigid-registration. For example, if
2 rectangular masks were used, there would be a two row binary mask
the first row with mask of the first half of the probe and the second
row a mask for the second half of the probe.
num_shifts_block : int
Number of indices by which to shift the histogram to find the maximum
of the cross correlation. If `None`, the entire activity histograms
are cross-correlated.
interpolate : bool
If `True`, the cross-correlation is interpolated before maximum is taken.
interp_factor:
Factor by which to interpolate the cross-correlation.
kriging_sigma : float
sigma parameter for kriging_kernel function. See `kriging_kernel`.
kriging_p : float
p parameter for kriging_kernel function. See `kriging_kernel`.
kriging_d : float
d parameter for kriging_kernel function. See `kriging_kernel`.
smoothing_sigma_bin : float
sigma parameter for the gaussian smoothing kernel over the
spatial bins.
smoothing_sigma_window : float
sigma parameter for the gaussian smoothing kernel over the
non-rigid windows.
Returns
-------
shift_matrix : ndarray
A (num_session x num_session) symmetric matrix of shifts
(indices) between pairs of session activity histograms.
Notes
-----
- This function is very similar to the IterativeTemplateRegistration
function used in motion correct, though slightly difference in scope.
It was not convenient to merge them at this time, but worth looking
into in future.
- Some obvious performances boosts, not done so because already fast
1) the cross correlations for each session comparison are performed
twice. They are slightly different due to interpolation, but
still probably better to calculate once and flip.
2) `num_shifts_block` is implemented by simply making the full
cross correlation. Would probably be nicer to explicitly calculate
only where needed. However, in general these cross correlations are
only a few thousand datapoints and so are already extremely
fast to cross correlate.
Notes
-----
- The original kilosort method does not work in the inter-session
context because it averages over time bins to form a template to
align too. In this case, averaging over a small number of possibly
quite different session histograms does not work well.
- In the nonrigid case, this strategy can completely fail when the xcorr
is very bad for a certain window. The smoothing and interpolation
make it much worse, because bad xcorr are merged together. The x-corr
can be bad when the recording is shifted a lot and so there are empty
regions that are correlated with non-empty regions in the nonrigid
approach. A different approach will need to be taken in this case.
Note that kilosort method does not work because creating a
mean does not make sense over sessions.
Expand Down Expand Up @@ -259,7 +305,10 @@ def compute_histogram_crosscorrelation(
xcorr = np.correlate(windowed_histogram_i, windowed_histogram_j, mode="full")

if num_shifts_block:
window_indices = np.arange(center_bin - num_shifts_block, center_bin + num_shifts_block)
window_indices = np.arange(
center_bin - num_shifts_block,
center_bin + num_shifts_block
)
mask = np.zeros_like(xcorr)
mask[window_indices] = 1
xcorr *= mask
Expand Down Expand Up @@ -298,31 +347,33 @@ def compute_histogram_crosscorrelation(
return shift_matrix


def shift_array_fill_zeros(array, shift):
abs_shift = np.abs(shift)
pad_tuple = (0, abs_shift) if shift > 0 else (abs_shift, 0)
padded_hist = np.pad(array, pad_tuple, mode="constant")
cut_padded_hist = padded_hist[abs_shift:] if shift >= 0 else padded_hist[:-abs_shift]
return cut_padded_hist
def shift_array_fill_zeros(array: np.ndarray, shift: int) -> np.ndarray:
"""
Shift an array by `shift` indices, padding with zero.
Samples going out of bounds are dropped i,e, the array is not
extended and samples are not wrapped around to the start of the array.
Parameters
----------
# TODO: deprecate
def prep_recording(recording, plot=False):
"""
:param recording:
:return:
"""
peaks = detect_peaks(recording, method="locally_exclusive")
array : np.ndarray
The array to pad.
shift : int
Number of indices why which to shift the array. If positive, the
zeros are added from the end of the array. If negative, the zeros
are added from the start of the array.
peak_locations = localize_peaks(recording, peaks, method="grid_convolution")
Returns
-------
if plot:
si.plot_drift_raster_map(
peaks=peaks,
peak_locations=peak_locations,
recording=recording,
clim=(-300, 0), # fix clim for comparability across plots
)
plt.show()
cut_padded_array : np.ndarray
The `array` padded with zeros and cut down (i.e. out of bounds
samples dropped).
"""
abs_shift = np.abs(shift)
pad_tuple = (0, abs_shift) if shift > 0 else (abs_shift, 0)
padded_hist = np.pad(array, pad_tuple, mode="constant")
cut_padded_array = padded_hist[abs_shift:] if shift >= 0 else padded_hist[:-abs_shift]

return peaks, peak_locations
return cut_padded_array
Loading

0 comments on commit db53c14

Please sign in to comment.