Skip to content

Commit

Permalink
Continue adding docstring / types.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoeZiminski committed Sep 27, 2024
1 parent 17078e8 commit 1cf68ea
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 36 deletions.
103 changes: 77 additions & 26 deletions debugging/alignment_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from spikeinterface import BaseRecording

import numpy as np
import matplotlib.pyplot as plt
import spikeinterface.full as si
Expand All @@ -14,16 +19,50 @@
from scipy.ndimage import gaussian_filter
from spikeinterface.sortingcomponents.motion.iterative_template import kriging_kernel

# -----------------------------------------------------------------------------
# #############################################################################
# Get Histograms
# -----------------------------------------------------------------------------
# #############################################################################


def get_activity_histogram(
recording, peaks, peak_locations, spatial_bin_edges, log_scale, bin_s, depth_smooth_um
recording: BaseRecording,
peaks: np.ndarray,
peak_locations: np.ndarray,
spatial_bin_edges: np.ndarray,
log_scale: bool,
bin_s: float| None,
depth_smooth_um: float | None,
):
"""
TODO: assumes 1-segment recording
Generate a 2D activity histogram for the session. Wraps the underlying
spikeinterface function with some adjustments for scaling to time and
log transform.
Parameters
----------
recording: BaseRecording,
A SpikeInterface recording object.
peaks: np.ndarray,
A SpikeInterface `peaks` array.
peak_locations: np.ndarray,
A SpikeInterface `peak_locations` array.
spatial_bin_edges: np.ndarray,
A (1 x n_bins + 1) array of spatial (probe y dimension) bin edges.
log_scale: bool,
If `True`, histogram is log scaled.
bin_s : float | None,
If `None`, a single histogram will be generated from all session
peaks. Otherwise, multiple histograms will be generated, one for
each time bin.
depth_smooth_um: float | None
If not `None`, smooth the histogram across the spatial
axis. see `make_2d_motion_histogram()` for details.
TODO
----
- assumes 1-segment recording
- ask Sam whether it makes sense to integrate this function with `make_2d_motion_histogram`.
"""
activity_histogram, temporal_bin_edges, generated_spatial_bin_edges = make_2d_motion_histogram(
recording,
Expand Down Expand Up @@ -55,11 +94,6 @@ def get_activity_histogram(
return activity_histogram, temporal_bin_centers, spatial_bin_centers


# -----------------------------------------------------------------------------
# Utils
# -----------------------------------------------------------------------------


def get_bin_centers(bin_edges):
    """
    Return the center of each bin given an array of bin edges.

    Parameters
    ----------
    bin_edges : np.ndarray
        A (num_bins + 1,) array of bin edges.

    Returns
    -------
    np.ndarray
        A (num_bins,) array of bin centers (the midpoint of each
        pair of adjacent edges).
    """
    lower_edges = bin_edges[:-1]
    upper_edges = bin_edges[1:]
    return (lower_edges + upper_edges) / 2

Expand All @@ -71,6 +105,10 @@ def estimate_chunk_size(scaled_activity_histogram):
estimated to within 10% of its true value, 99% of the time,
corrected based on the assumption
of Poisson firing (based on CLT).
TODO
----
- make the details available.
"""
firing_rate = np.percentile(scaled_activity_histogram, 98)

Expand All @@ -84,32 +122,39 @@ def estimate_chunk_size(scaled_activity_histogram):
return t, lambda_


# -----------------------------------------------------------------------------
# #############################################################################
# Chunked Histogram estimation methods
# -----------------------------------------------------------------------------

# #############################################################################
# Given a set off chunked_session_histograms (num time chunks x num spatial bins)
# take the summary statistic over the time axis.

def get_chunked_hist_mean(chunked_session_histograms):
    """
    Summarise the chunked histograms by taking the bin-wise
    mean across the time-chunk axis.

    Parameters
    ----------
    chunked_session_histograms : np.ndarray
        A (num time chunks x num spatial bins) array of histograms.

    Returns
    -------
    np.ndarray
        A (num spatial bins,) array, the mean over time chunks.
    """
    return np.asarray(chunked_session_histograms).mean(axis=0)


def get_chunked_hist_median(chunked_session_histograms):
    """
    Summarise the chunked histograms by taking the bin-wise
    median across the time-chunk axis.

    Parameters
    ----------
    chunked_session_histograms : np.ndarray
        A (num time chunks x num spatial bins) array of histograms.

    Returns
    -------
    np.ndarray
        A (num spatial bins,) array, the median over time chunks.
    """
    return np.median(np.asarray(chunked_session_histograms), axis=0)


def get_chunked_hist_supremum(chunked_session_histograms):
    """
    Summarise the chunked histograms by taking the bin-wise
    maximum across the time-chunk axis.

    Parameters
    ----------
    chunked_session_histograms : np.ndarray
        A (num time chunks x num spatial bins) array of histograms.

    Returns
    -------
    np.ndarray
        A (num spatial bins,) array, the maximum over time chunks.
    """
    return np.asarray(chunked_session_histograms).max(axis=0)


def get_chunked_hist_poisson_estimate(chunked_session_histograms):
""" """
"""
Make a MLE estimate of the most likely value for each bin
given the assumption of Poisson firing. Turns out this is
basically identical to the mean :'D.
Keeping for now as opportunity to add prior or do some outlier
removal per bin. But if not useful, deprecate in future.
"""
def obj_fun(lambda_, m, sum_k):
return -(sum_k * np.log(lambda_) - m * lambda_)

Expand All @@ -123,17 +168,19 @@ def obj_fun(lambda_, m, sum_k):
m = ks.shape
sum_k = np.sum(ks)

# lol, this is painfully close to the mean, no meaningful
# prior comes to mind to extend the method with.
poisson_estimate[i] = minimize(obj_fun, 0.5, (m, sum_k), bounds=((1e-10, np.inf),)).x
poisson_estimate[i] = minimize(
obj_fun, 0.5, (m, sum_k), bounds=((1e-10, np.inf),)
).x

return poisson_estimate


# TODO: currently deprecated due to scaling issues between
# sessions. A much better (?) way will to make PCA from all
# sessions, then align based on projection
def get_chunked_hist_eigenvector(chunked_session_histograms):
""" """
def DEPRECATE_get_chunked_hist_eigenvector(chunked_session_histograms):
"""
TODO: currently deprecated due to scaling issues between
sessions. A much better (?) way would be to make a PCA from all
sessions, then align based on the projection
"""
if chunked_session_histograms.shape[0] == 1: # TODO: handle elsewhere
return chunked_session_histograms.squeeze()

Expand All @@ -150,9 +197,9 @@ def get_chunked_hist_eigenvector(chunked_session_histograms):
return first_eigenvector


# -----------------------------------------------------------------------------
# #############################################################################
# TODO: MOVE creating recordings
# -----------------------------------------------------------------------------
# #############################################################################


def compute_histogram_crosscorrelation(
Expand All @@ -168,6 +215,10 @@ def compute_histogram_crosscorrelation(
smoothing_sigma_window,
):
"""
# TODO: what happens when this bigger than thing. Also rename about shifts
# TODO: this is kind of wasteful, no optimisations made against redundant
# session computation, but these in generate very fast.
Expand Down
23 changes: 13 additions & 10 deletions debugging/session_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ def align_sessions(
a list of corrected `peak_locations` and activity
histogram generated after correction.
"""
non_rigid_window_kwargs = copy.deepcopy(non_rigid_window_kwargs)
estimate_histogram_kwargs = copy.deepcopy(estimate_histogram_kwargs)
compute_alignment_kwargs = copy.deepcopy(compute_alignment_kwargs)
interpolate_motion_kwargs = copy.deepcopy(interpolate_motion_kwargs)
Expand Down Expand Up @@ -457,7 +458,7 @@ def _get_single_session_activity_histogram(
Returns
-------
session_histogram : np.ndarray
Summary acitivity histogram for the session.
Summary activity histogram for the session.
temporal_bin_centers : np.ndarray
Temporal bin center (session mid-point as we only have
one time point) for the session.
Expand All @@ -469,7 +470,7 @@ def _get_single_session_activity_histogram(
for the chunked histograms, with length num_chunks.
"session_std" : The mean across bin-wise standard deviation
of the chunked histograms.
"chunked_bin_size_s" : time of each chunk used to TODO
"chunked_bin_size_s" : time of each chunk used to
calculate the chunked histogram.
"""
times = recording.get_times()
Expand Down Expand Up @@ -543,8 +544,8 @@ def _create_motion_recordings(
Returns
-------
corrected_recordings_list : list[BaseRecording]
A list of InterpolateMotionRecording recordings of shift-
corrected recordings coressponding to `recordings_list`.
A list of InterpolateMotionRecording recordings of shift-corrected
recordings corresponding to `recordings_list`.
motion_objects_list : list[Motion]
A list of Motion objects. If the recording in `recordings_list`
Expand Down Expand Up @@ -613,9 +614,9 @@ def _add_displacement_to_interpolate_recording(
TODO
----
Check + ask Sam if any other fields need to be chagned. This is a little
# hairy (4 possible combinations of new and
# old displacement shapes, rigid or nonrigid, so test thoroughly.
Check + ask Sam if any other fields need to be changed. This is a little
hairy (4 possible combinations of new and old displacement shapes,
rigid or nonrigid, so test thoroughly.
"""
# Everything is done in place, so keep a short variable
# name reference to the new recordings `motion` object
Expand Down Expand Up @@ -665,11 +666,13 @@ def _correct_session_displacement(
):
"""
Internal function to apply the correction from `align_sessions`
to build a corrected histogram for comparison. First,
to build a corrected histogram for comparison. First, create
new shifted peak locations. Then, create a new 'corrected'
activity histogram from the new peak locations.
Parameters
----------
see `align_sessions()` for parameters. TODO: can add motion shifts if we need.
see `align_sessions()` for parameters.
Returns
-------
Expand Down Expand Up @@ -796,7 +799,6 @@ def _compute_session_alignment(
else:
shifts = rigid_shifts + non_rigid_shifts

breakpoint()
return shifts, non_rigid_windows, non_rigid_window_centers


Expand Down Expand Up @@ -865,6 +867,7 @@ def _akima_interpolate_nonrigid_shifts(
An array (length num_spatial_bins) of shifts
interpolated from the non-rigid shifts.
TODO
----
requires scipy 14
"""
Expand Down

0 comments on commit 1cf68ea

Please sign in to comment.