Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add correct_inter_session_displacementfunction #3126

Draft
wants to merge 102 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
102 commits
Select commit Hold shift + click to select a range
6abe4fc
Add debugging.
JoeZiminski Jul 18, 2024
8b30aee
Factor out get probe from generate drifting recording.
JoeZiminski Jul 18, 2024
f002380
Add start of session_displacement_generator, generate probe.
JoeZiminski Jul 18, 2024
b918b03
Factor out fixing of 'generate_templates_kwargs'.
JoeZiminski Jul 18, 2024
66a5490
Factor out unit_factor calculation, fix wrong unit_locations dim in d…
JoeZiminski Jul 18, 2024
41feacf
Factor out setup for InjectTemplatesRecording.
JoeZiminski Jul 18, 2024
5632a6a
Add first draft session_displacement_generator.py
JoeZiminski Jul 18, 2024
6c26bef
Start adding tests.
JoeZiminski Jul 18, 2024
7311ecd
Add annotations to hopefully fix tests.
JoeZiminski Jul 18, 2024
d1f14a5
Add return extra outputs.
JoeZiminski Jul 19, 2024
8bda497
Add tests.
JoeZiminski Jul 19, 2024
2ac1f1b
Improve tests and test documentation.
JoeZiminski Jul 22, 2024
576ba9d
Add optional return to 'generate_sorting' to get extra outputs, firin…
JoeZiminski Jul 23, 2024
8dd0177
Start adding scaling of template amplitudes across recordings.
JoeZiminski Jul 23, 2024
3bcf1a1
Finalising amplitude scalings.
JoeZiminski Jul 24, 2024
f477dcd
Add input checks.
JoeZiminski Jul 24, 2024
e0fae74
Add documentation.
JoeZiminski Jul 24, 2024
8ed2ced
Start finalising tests.
JoeZiminski Jul 24, 2024
f845fba
Finalise and tidy up tests.
JoeZiminski Jul 25, 2024
322240f
Temporarily fix the mutable defaults so tests path.
JoeZiminski Jul 29, 2024
c134c47
Fix string formatting, remove breakpoint.
JoeZiminski Aug 28, 2024
cd8c281
Add `shift_units_outside_probe`.
JoeZiminski Aug 28, 2024
dc5a8a8
Begin adding rough implementation.
JoeZiminski Jul 2, 2024
9a100d3
Doing some basic alignment.
JoeZiminski Jul 9, 2024
fb7ce0c
Finish naive alignment to first session.
JoeZiminski Jul 9, 2024
cbf68ea
Play around with slope drift for generate_drifting_recording.
JoeZiminski Jul 15, 2024
047ec7a
Very minimal first working version.
JoeZiminski Jul 16, 2024
ae5d51f
Start playing around with histogram estimation.
JoeZiminski Jul 30, 2024
45ab8b3
Continue playing with estimation.
JoeZiminski Jul 30, 2024
e5f229a
Continue playing around with estimation.
JoeZiminski Jul 31, 2024
4982d90
Add rough time estimatino.
JoeZiminski Aug 1, 2024
73f2fe1
To some refactoring.
JoeZiminski Aug 12, 2024
14d3897
Add in alignment and interpolation functions.
JoeZiminski Aug 13, 2024
b6e816a
Fix scaling on histograms and bin_s estimation, add benchmarking init…
JoeZiminski Aug 14, 2024
e5dc710
Adding some mp benchmarking.
JoeZiminski Aug 15, 2024
9e59862
Small fixes for hpc.
JoeZiminski Aug 15, 2024
5bc3996
Add arg input for SLURM.
JoeZiminski Aug 15, 2024
411e7ff
sbatch file.
JoeZiminski Aug 15, 2024
91bcaa5
small fixes to other methods.
JoeZiminski Aug 15, 2024
ec61cff
Start adding widgets.
JoeZiminski Aug 21, 2024
a0db02d
Add first rough nonrigid.
JoeZiminski Aug 23, 2024
76bbbc7
Try a recursive nonrigid alignment, doesn't really work.
JoeZiminski Aug 27, 2024
33a4f62
Revert "Try a recursive nonrigid alignment, doesn't really work."
JoeZiminski Aug 27, 2024
a9d355d
Adding a few more options.
JoeZiminski Aug 27, 2024
5801e78
small changes prior to refactoring.
JoeZiminski Aug 28, 2024
9477020
Small changes.
JoeZiminski Aug 28, 2024
4225f10
Remove old and benchmarking scripts.
JoeZiminski Aug 28, 2024
65afefc
Begin tidying up.
JoeZiminski Aug 28, 2024
d662f28
Major refactor, add smoothing, interpolation, better spatial binning.
JoeZiminski Aug 30, 2024
9ade367
Update some notes.
JoeZiminski Aug 30, 2024
4335be4
Tidy up, expose aling to session X.
JoeZiminski Sep 2, 2024
f0feac6
Thinking about trimmed versions and robust xcorr, leave until later.
JoeZiminski Sep 2, 2024
e0e389e
PLaying with large shifts.
JoeZiminski Sep 3, 2024
7b82f65
Playing with alignment algorithm, completely fails when rigid shift l…
JoeZiminski Sep 3, 2024
a980441
some small tidying up
JoeZiminski Sep 4, 2024
bc687b6
Lots of tidying up, need to look into nonrigid more something has reg…
JoeZiminski Sep 5, 2024
cfdb86c
Look further into rigid, it was working well, but the parameters chos…
JoeZiminski Sep 6, 2024
43d2041
Fix corrected histogram to use same method as compute.
JoeZiminski Sep 6, 2024
ba703a1
Big refactor, getting close to first draft.
JoeZiminski Sep 6, 2024
c790c0b
Begin typing and documentation.
JoeZiminski Sep 12, 2024
8360407
Fix from motion correction function, adding typing and docstrings.
JoeZiminski Sep 13, 2024
88fab86
Docstrings and typing for session_alignment.
JoeZiminski Sep 17, 2024
6d990ad
Continue adding docstring / types.
JoeZiminski Sep 27, 2024
da6693f
Continue with docstrings.
JoeZiminski Sep 30, 2024
de2c798
Small fixes
JoeZiminski Nov 21, 2024
29de949
Remove a TODO.
JoeZiminski Nov 21, 2024
604ed43
Add recording shifts argument.
JoeZiminski Nov 21, 2024
21534ae
Update time estimate method.
JoeZiminski Nov 22, 2024
27fddbf
Fix time estimate, other tidy ups.
JoeZiminski Nov 27, 2024
ef0bbe3
gaussian process first draft.
JoeZiminski Nov 29, 2024
af6c917
remove some debug stuff from gp first draft.
JoeZiminski Nov 29, 2024
4252ddc
Very rough introduce 2D versions.
JoeZiminski Nov 29, 2024
e466cd1
Add scaling to GP, beter convergence.
JoeZiminski Dec 6, 2024
8166e38
Tidy up.
JoeZiminski Dec 16, 2024
17e82b3
Handle the 1-bin case in Motion objects __repre__
JoeZiminski Dec 16, 2024
f194194
Tidy up and fixes for 'session_alignment.py'
JoeZiminski Dec 16, 2024
d4e9902
Tidying up, begin fixing alignment alg.
JoeZiminski Dec 16, 2024
f1e61b7
Remove pickle files.
JoeZiminski Dec 17, 2024
28115bd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 17, 2024
8a4b89e
Reformatting alignment methods and add 2D, need to tidy up.
JoeZiminski Dec 17, 2024
5a5d509
delete big files.
JoeZiminski Dec 17, 2024
ee6987e
Trying different alg
JoeZiminski Dec 19, 2024
a0f9b6e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 19, 2024
856adbd
Playing with alg rigid shift but scales for better rigid shift, okay …
JoeZiminski Dec 19, 2024
1190093
Testing on the session_alignment side.
JoeZiminski Dec 19, 2024
3fd26e4
playing with nonrigid.
JoeZiminski Dec 20, 2024
005bab6
Save.
JoeZiminski Dec 24, 2024
c8c5610
With additional windowing.
JoeZiminski Dec 24, 2024
867be0e
Doing some tidying, use amplitudes much more TODO!
JoeZiminski Dec 25, 2024
4555cef
Remove additional improvements on alignment alg to work on later.
JoeZiminski Jan 13, 2025
76cae6b
First draft add plotting 2D histograms.
JoeZiminski Jan 13, 2025
cb13d0c
Tidying up, checking alignment function.
JoeZiminski Jan 13, 2025
ce1d213
Add 2d histogram plot, move plots to SI widgets.
JoeZiminski Jan 13, 2025
f97ad78
Update playing.py
JoeZiminski Jan 13, 2025
119813d
Start adding tests.
JoeZiminski Jan 14, 2025
6210475
Continue working on tests.
JoeZiminski Jan 15, 2025
8151990
Continue adding tests.
JoeZiminski Jan 16, 2025
b75e769
Continue! working on tests.
JoeZiminski Jan 17, 2025
93b89f6
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 17, 2025
abe27c1
Tidy up tests and (slightly) improve num_iter handling.
JoeZiminski Jan 20, 2025
001ed03
Add how-to doc.
JoeZiminski Jan 21, 2025
34ba6bc
Small changes to docs.
JoeZiminski Jan 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
268 changes: 268 additions & 0 deletions debugging/DELinter_session_displacement.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,268 @@
from __future__ import annotations

import copy

import numpy as np
import json
from pathlib import Path
import time

from spikeinterface.core.baserecording import BaseRecording
from spikeinterface.core import get_noise_levels, fix_job_kwargs, get_random_data_chunks
from spikeinterface.core.job_tools import _shared_job_kwargs_doc
from spikeinterface.core.core_tools import SIJsonEncoder
from spikeinterface.core.job_tools import _shared_job_kwargs_doc

# TODO: update motion docstrings around the 'select' step.


# TODO:
# 1) detect peaks and peak locations if not already provided.
# - could use only a subset of data, for ease now just estimate
# everything on the entire dataset
# 2) Calcualte the activity histogram across the entire session
# - will be better ways to estimate this, i.e. from the end
# of the session, from periods of stability, etc.
# taking a weighted average of histograms
# 3) Optimise for drift correction for each session across
# all histograms, minimising lost data at edges and keeping
# shift similar for all sessions. Could alternatively shift
# to the average histogram but this seems like a bad idea.
# 4) Store the motion vectors, ether adding to existing (of motion
# objects passed) otherwise.


def correct_inter_session_displacement(
recordings_list: list[BaseRecording],
existing_motion_info: Optional[list[Dict]] = None,
keep_channels_constant=False,
detect_kwargs={}, # TODO: make non-mutable (same for motion.py)
select_kwargs={},
localize_peaks_kwargs={},
job_kwargs={},
):
from spikeinterface.sortingcomponents.peak_detection import detect_peaks, detect_peak_methods
from spikeinterface.sortingcomponents.peak_detection import detect_peaks, detect_peak_methods
from spikeinterface.sortingcomponents.peak_selection import select_peaks
from spikeinterface.sortingcomponents.peak_localization import localize_peaks, localize_peak_methods
from spikeinterface.sortingcomponents.motion.motion_estimation import estimate_motion
from spikeinterface.sortingcomponents.motion.motion_interpolation import InterpolateMotionRecording
from spikeinterface.core.node_pipeline import ExtractDenseWaveforms, run_node_pipeline
from spikeinterface.sortingcomponents.motion.motion_utils import Motion, get_spatial_windows

# TODO: do not accept multi-segment recordings.
# TODO: check all recordings have the same probe dimensions!
# Check if exsting_motion_info is passed then the recordings have the motion vector (I guess this is stored somewhere? maybe it is on the motion object)
if existing_motion_info is not None:
if not isinstance(existing_motion_info, list) and len(recordings_list) != len(existing_motion_info):
raise ValueError(
"`estimate_motion_info` if provided, must be"
"a list of `motion_info` with each associated with"
"the corresponding recording in `recordings_list`."
)

# TODO: do not handle select peaks option yet as probably better to chunk
# rather than select peaks? no sure can discuss.
if existing_motion_info is None:

peaks_list = []
peak_locations_list = []

for recording in recordings_list:
# TODO: this is a direct copy from motion.detect_motion().
# Factor into own function in motion.py
gather_mode = "memory"
# node detect
method = detect_kwargs.pop("method", "locally_exclusive")
method_class = detect_peak_methods[method]
node0 = method_class(recording, **detect_kwargs)

node1 = ExtractDenseWaveforms(recording, parents=[node0], ms_before=0.1, ms_after=0.3)

# node detect + localize
method = localize_peaks_kwargs.pop("method", "center_of_mass")
method_class = localize_peak_methods[method]
node2 = method_class(recording, parents=[node0, node1], return_output=True, **localize_peaks_kwargs)
pipeline_nodes = [node0, node1, node2]

peaks, peak_locations = run_node_pipeline(
recording,
pipeline_nodes,
job_kwargs,
job_name="detect and localize",
gather_mode=gather_mode,
gather_kwargs=None,
squeeze_output=False,
folder=None,
names=None,
)
peaks_list.append(peaks)
peak_locations_list.append(peak_locations)
else:
peaks_list = [info["peaks"] for info in existing_motion_info]
peak_locations_list = [info["peak_locations"] for info in existing_motion_info]

from spikeinterface.sortingcomponents.motion.motion_utils import make_2d_motion_histogram, make_3d_motion_histograms

# make motion histogram
motion_histogram_dim = "2D" # "2D" or "3D", for now only handle 2D case

motion_histogram_list = []
all_temporal_bin_edges = [] # TODO: fix naming

bin_um = 2 # TODO: critial paraneter. easier to take no binning and gaus smooth?

# TODO: own function
for recording, peaks, peak_locations in zip(
recordings_list,
peaks_list,
peak_locations_list, # TODO: this is overwriting above variable names. Own function!
): # TODO: do a lot of checks to make sure these bin sizes make sesnese
# Do some checks on temporal and spatial bin edges that they are all the same?

if motion_histogram_dim == "2D":
motion_histogram = make_2d_motion_histogram(
recording,
peaks,
peak_locations,
weight_with_amplitude=False,
direction="y",
bin_s=recording.get_duration(segment_index=0), # 1.0,
bin_um=bin_um,
hist_margin_um=50,
spatial_bin_edges=None,
)
else:
assert NotImplementedError # TODO: might be old API pre-dredge
motion_histogram = make_3d_motion_histograms(
recording,
peaks,
peak_locations,
direction="y",
bin_duration_s=recording.get_duration(segment_index=0), # 1.0,
bin_um=bin_um,
margin_um=50,
num_amp_bins=20,
log_transform=True,
spatial_bin_edges=None,
)
motion_histogram_list.append(motion_histogram[0].squeeze())
# store bin edges
all_temporal_bin_edges.append(motion_histogram[1])
spatial_bin_edges_um = motion_histogram[2] # should be same across all recordings

# Do some checks on temporal and spatial bin edges that they are all the same?
# TODO: do some smoothing? Try some other methds (e.g. NMI, KL divergence)
# Let's do a very basic optimisation to find the best midpoint, just
# align everything to the first session. This isn't great because
# introduces some bias. Maybe align to all sessions and then take some
# average. Certainly cannot optimise brute force over the whole space
# which is (2P-1)^N where P is length of motion histogram and N is number of recordings.
# TODO: double-check what is done in kilosort-like / DREDGE
# put histograms into X and do X^T X then mean(U), det or eigs of covar mat
# can try iterative template. Not sure it will work so well taking the mean
# over only a few histograms that could be wildy different.
# Displacemene
num_recordings = len(recordings_list)

shifts = np.zeros(num_recordings)

# TODO: not checked any of the below properly
first_hist = motion_histogram_list[0] / motion_histogram_list[0].sum()
# first_hist -= np.mean(first_hist) # TODO: pretty sure not necessary

for i in range(1, num_recordings):

hist = motion_histogram_list[i] / motion_histogram_list[i].sum()
# hist -= np.mean(hist) # TODO: pretty sure not necessary
conv = np.correlate(first_hist, hist, mode="full")

if conv.size % 2 == 0:
midpoint = conv.size / 2
else:
midpoint = (conv.size - 1) / 2 # TODO: carefully double check!

# TODO: think will need to make this negative
shifts[i] = (midpoint - np.argmax(conv)) * bin_um # # TODO: the bin spacing is super important for resoltuion

# half
# TODO: need to figure out interpolation to the center point, weird;y
# the below does not work
# shifts[0] = (shifts[1] / 2)
# shifts[1] = (shifts[1] / 2) * -1
# print("SHIFTS", shifts)
# TODO: handle only the 2D case for now
# TODO: do multi-session optimisation

# Handle drift
interpolate_motion_kwargs = {}

# TODO: add motion to motion if exists otherwise create InterpolateMotionRecording object!
# Will need the y-axis bins for this
all_recording_corrected = []
all_motion_info = []
for i, recording in enumerate(recordings_list):

# TODO: direct copy, use 'get_window' from motion machinery
if False:
bin_centers = spatial_bin_edges[:-1] + bin_um / 2.0
n = bin_centers.size
non_rigid_windows = [np.ones(n, dtype="float64")]
middle = (spatial_bin_edges[0] + spatial_bin_edges[-1]) / 2.0
non_rigid_window_centers = np.array([middle])

dim = 1 # ["x", "y", "z"].index(direction)
contact_depths = recording.get_channel_locations()[:, dim]
spatial_bin_centers = 0.5 * (spatial_bin_edges_um[1:] + spatial_bin_edges_um[:-1])

_, window_centers = get_spatial_windows(
contact_depths, spatial_bin_centers, rigid=True # TODO: handle non-rigid case
)
# win_shape=win_shape, TODO: handle defaults better
# win_step_um=win_step_um,
# win_scale_um=win_scale_um,
# win_margin_um=win_margin_um,
# zero_threshold=1e-5,

# if shifts[i] == 0:
## all_recording_corrected.append(recording) # TODO
# continue
temporal_bin_edges = all_temporal_bin_edges[i]
temporal_bins = 0.5 * (temporal_bin_edges[1:] + temporal_bin_edges[:-1])

motion_array = np.zeros((temporal_bins.size, window_centers.size)) # TODO: check this is the expected shape
motion_array[:, :] = shifts[i] # TODO: this is the rigid case!

motion = Motion(
[motion_array], [temporal_bins], window_centers, direction="y"
) # will be same for all except for shifts
all_motion_info.append(motion) # not certain on this

if isinstance(recording, InterpolateMotionRecording):
raise NotImplementedError
recording_corrected = copy.deepcopy(recording)
# TODO: add interpolation to the existing one.
# Not if inter-session motion correction already exists, but further
# up the preprocessing chain, it will NOT be added and interpolation
# will occur twice. Throw a warning here!
else:
recording_corrected = InterpolateMotionRecording(recording, motion, **interpolate_motion_kwargs)
all_recording_corrected.append(recording_corrected)

displacement_info = {
"all_motion_info": all_motion_info,
"all_motion_histograms": motion_histogram_list, # TODO: naming
"all_shifts": shifts,
}

if keep_channels_constant:
# TODO: use set
import functools

common_channels = functools.reduce(
np.intersect1d, [recording.channel_ids for recording in all_recording_corrected]
)

all_recording_corrected = [recording.channel_slice(common_channels) for recording in all_recording_corrected]

return all_recording_corrected, displacement_info # TODO: output more stuff later e.g. the Motion object
Empty file added debugging/__init__.py
Empty file.
Loading
Loading