
Add a fast correlogram merge #3607

Open · wants to merge 2 commits into main

Conversation

chrishalcrow (Collaborator)

I've added a fast method to re-compute correlograms when units are merged.

When units are merged, the merged unit's correlograms with the other units are just sums of the original units' correlograms. Let $C_{i,j}$ be the cross-correlogram of units $i$ and $j$. If we merge units 2, 4 and 6 into a new unit 10, the cross-correlograms of unit 10 with unit 1 are:

$$C_{10,1} = C_{2,1} + C_{4,1} + C_{6,1}$$

$$C_{1,10} = C_{1,2} + C_{1,4} + C_{1,6}$$

Similarly, the auto-correlogram of the new unit 10 is

$$C_{10,10} = (C_{2,2} + C_{2,4} + C_{2,6}) + (C_{4,2} + C_{4,4} + C_{4,6}) + (C_{6,2} + C_{6,4} + C_{6,6})$$

All of this can be done with two matrix operations (sum all the merged-unit columns together, then sum all the merged-unit rows together), which is what I've implemented.
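A minimal NumPy sketch of those two operations (a hypothetical function, not the PR's actual implementation; it assumes the correlograms array has shape `(num_units, num_units, num_bins)` and that the merged unit takes the slot of the first unit in the group):

```python
import numpy as np

def fast_merge_correlograms(correlograms, merge_indices):
    # e.g. merge_indices = [2, 4, 6]: the merged unit lands at index 2
    keep, drop = merge_indices[0], merge_indices[1:]

    ccgs = correlograms.copy()
    # 1) sum all merged-unit columns into the kept column:
    #    ccgs[i, keep] = sum of C[i, j] over j in the merge group
    ccgs[:, keep, :] = correlograms[:, merge_indices, :].sum(axis=1)
    # 2) sum all merged-unit rows of the column-summed matrix into the kept row:
    #    ccgs[keep, keep] then holds the sum of C[i, j] over all i, j in the group
    ccgs[keep, :, :] = ccgs[merge_indices, :, :].sum(axis=0)
    # drop the now-stale rows/columns of the other merged units
    return np.delete(np.delete(ccgs, drop, axis=0), drop, axis=1)
```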

The most annoying part is tracking the `new_unit_index`: there's more code dealing with this than with the actual new algorithm. It's much easier to handle the `take_first` `new_id_strategy` than the `append` one. Any advice here is welcome.
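For what it's worth, the bookkeeping is roughly this (a hypothetical sketch with illustrative unit ids, not the PR code):

```python
# sketch of locating the merged unit's index in the new unit list
unit_ids = ["0", "1", "2", "3", "4", "5", "6"]  # unit ids before the merge
merge_group = ["2", "4", "6"]
new_id_strategy = "take_first"                  # or "append"

if new_id_strategy == "take_first":
    # the merged unit keeps the first id in the group, staying in place
    new_unit_ids = [u for u in unit_ids if u not in merge_group[1:]]
    new_unit_index = new_unit_ids.index(merge_group[0])
else:
    # the merged unit gets a fresh id, appended after the surviving units
    new_unit_ids = [u for u in unit_ids if u not in merge_group] + ["7"]
    new_unit_index = len(new_unit_ids) - 1
```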

If you don't believe my maths, I've added tests showing that the results are the same whether you use the fast merge method or merge and then re-compute the correlograms from scratch. If you can think of more painful tests than the ones I've written, let me know!
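In sketch form (not the actual test code), the check looks something like this, given a sorting analyzer `sa` with correlograms already computed, and assuming `merging_mode="soft"`/`"hard"` select the fast path and a full re-compute respectively, with the extension's `get_data()` returning `(ccgs, bins)`:

```python
import numpy as np

# fast path: merge and update the existing correlograms in place
fast = sa.merge_units(merge_unit_groups=[["0", "4"]], merging_mode="soft")
ccgs_fast, bins = fast.get_extension("correlograms").get_data()

# slow path: merge, then re-compute the correlograms from scratch
slow = sa.merge_units(merge_unit_groups=[["0", "4"]], merging_mode="hard")
ccgs_slow, _ = slow.get_extension("correlograms").get_data()

assert np.array_equal(ccgs_fast, ccgs_slow)
```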

This is a big speedup for my kilosort'd NP2.0 sortings (on my laptop, 25s -> 0.2s for a single pairwise merge). And yup, I do have numba installed.

For generated recordings, it only becomes a lot faster once the correlograms are reasonably full, which requires a high firing rate and a small refractory period. Here's some benchmarking code:

```python
from time import perf_counter

import matplotlib.pyplot as plt
import spikeinterface.full as si

nums_units = [40, 80, 160, 320, 640]
times = []
for num_units in nums_units:
    rec, sort = si.generate_ground_truth_recording(
        seed=1205,
        num_units=num_units,
        generate_sorting_kwargs={"firing_rates": 50, "refractory_period_ms": 1.0},
    )
    sa = si.create_sorting_analyzer(recording=rec, sorting=sort)
    sa.compute("correlograms")

    t1 = perf_counter()
    sa.merge_units(merge_unit_groups=[["0", "4"]], sparsity_overlap=0.5)
    t2 = perf_counter()
    times.append(t2 - t1)

# run this once on main and once on this branch, saving the results as
# old_times and new_times respectively
old_times = dict(zip(nums_units, times))

plt.plot(nums_units, new_times.values())
plt.plot(nums_units, old_times.values())
plt.legend(["new method", "old method"])
plt.xlabel("num units")
plt.ylabel("time (seconds)")
```

and I get:

```python
print(old_times)
# {40: 0.007854124996811152,
#  80: 0.02739199995994568,
#  160: 0.11776154092513025,
#  320: 1.210061291931197,
#  640: 6.01522733294405}

print(new_times)
# {40: 0.0016532500740140676,
#  80: 0.003977375105023384,
#  160: 0.013842040905728936,
#  320: 0.05750433378852904,
#  640: 0.20545295905321836}
```

giving the following plot:

[Plot: merge time vs. number of units for the old and new methods]
