Merge branch 'main' into merge-qc
zm711 authored Jan 8, 2025
2 parents 807e771 + 0b1bf67 commit 77ed6ab
Showing 53 changed files with 409 additions and 282 deletions.
35 changes: 12 additions & 23 deletions .github/actions/build-test-environment/action.yml
@@ -1,41 +1,20 @@
name: Install packages
description: This action installs the package and its dependencies for testing

inputs:
python-version:
description: 'Python version to set up'
required: false
os:
description: 'Operating system to set up'
required: false

runs:
using: "composite"
steps:
- name: Install dependencies
run: |
sudo apt install git
git config --global user.email "[email protected]"
git config --global user.name "CI Almighty"
python -m venv ${{ github.workspace }}/test_env # Environment used in the caching step
python -m pip install -U pip # Official recommended way
source ${{ github.workspace }}/test_env/bin/activate
pip install tabulate # This produces summaries at the end
pip install -e .[test,extractors,streaming_extractors,test_extractors,full]
shell: bash
- name: Force installation of latest dev from key-packages when running dev (not release)
run: |
source ${{ github.workspace }}/test_env/bin/activate
spikeinterface_is_dev_version=$(python -c "import spikeinterface; print(spikeinterface.DEV_MODE)")
if [ $spikeinterface_is_dev_version = "True" ]; then
echo "Running spikeinterface dev version"
pip install --no-cache-dir git+https://github.com/NeuralEnsemble/python-neo
pip install --no-cache-dir git+https://github.com/SpikeInterface/probeinterface
fi
echo "Running tests for release, using pyproject.toml versions of neo and probeinterface"
- name: Install git-annex
shell: bash
- name: git-annex install
run: |
pip install datalad-installer
wget https://downloads.kitenet.net/git-annex/linux/current/git-annex-standalone-amd64.tar.gz
mkdir /home/runner/work/installation
mv git-annex-standalone-amd64.tar.gz /home/runner/work/installation/
@@ -44,4 +23,14 @@ runs:
tar xvzf git-annex-standalone-amd64.tar.gz
echo "$(pwd)/git-annex.linux" >> $GITHUB_PATH
cd $workdir
git config --global filter.annex.process "git-annex filter-process" # recommended for efficiency
- name: Force installation of latest dev from key-packages when running dev (not release)
run: |
spikeinterface_is_dev_version=$(python -c "import spikeinterface; print(spikeinterface.DEV_MODE)")
if [ $spikeinterface_is_dev_version = "True" ]; then
echo "Running spikeinterface dev version"
pip install --no-cache-dir git+https://github.com/NeuralEnsemble/python-neo
pip install --no-cache-dir git+https://github.com/SpikeInterface/probeinterface
fi
echo "Running tests for release, using pyproject.toml versions of neo and probeinterface"
shell: bash
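For reference, a minimal Python sketch of the gate this CI step relies on (assuming spikeinterface is importable in the test environment): DEV_MODE distinguishes dev installs from releases, which decides whether neo and probeinterface come from GitHub main or from the pyproject.toml pins.

import spikeinterface

# DEV_MODE is True for dev installs, False for released versions
if spikeinterface.DEV_MODE:
    print("dev: install python-neo and probeinterface from GitHub main")
else:
    print("release: keep the pyproject.toml versions of neo and probeinterface")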
2 changes: 1 addition & 1 deletion .github/workflows/all-tests.yml
@@ -47,7 +47,7 @@ jobs:
echo "$file was changed"
done
- name: Set testing environment # This decides which tests are run and whether to install especial dependencies
- name: Set testing environment # This decides which tests are run and whether to install special dependencies
shell: bash
run: |
changed_files="${{ steps.changed-files.outputs.all_changed_files }}"
1 change: 0 additions & 1 deletion .github/workflows/full-test-with-codecov.yml
@@ -45,7 +45,6 @@ jobs:
env:
HDF5_PLUGIN_PATH: ${{ github.workspace }}/hdf5_plugin_path_maxwell
run: |
source ${{ github.workspace }}/test_env/bin/activate
pytest -m "not sorters_external" --cov=./ --cov-report xml:./coverage.xml -vv -ra --durations=0 | tee report_full.txt; test ${PIPESTATUS[0]} -eq 0 || exit 1
echo "# Timing profile of full tests" >> $GITHUB_STEP_SUMMARY
python ./.github/scripts/build_job_summary.py report_full.txt >> $GITHUB_STEP_SUMMARY
2 changes: 1 addition & 1 deletion doc/get_started/quickstart.rst
@@ -673,7 +673,7 @@ compute quality metrics (some quality metrics require certain extensions
'min_spikes': 0,
'window_size_s': 1},
'snr': {'peak_mode': 'extremum', 'peak_sign': 'neg'},
'synchrony': {'synchrony_sizes': (2, 4, 8)}}
'synchrony': {}}
Since the recording is very short, let’s change some parameters to
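A minimal sketch of computing metrics under the new defaults (assuming the sorting_analyzer built earlier in the quickstart): synchrony no longer takes a synchrony_sizes parameter and always reports sizes 2, 4 and 8.

import spikeinterface.qualitymetrics as sqm

# `sorting_analyzer` is assumed to exist from the earlier quickstart steps
qm = sqm.compute_quality_metrics(sorting_analyzer, metric_names=["snr", "synchrony"])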
4 changes: 2 additions & 2 deletions doc/modules/qualitymetrics/synchrony.rst
@@ -12,7 +12,7 @@ trains. This way synchronous events can be found both in multi-unit and single-u
Complexity is calculated by counting the number of spikes (i.e. non-empty bins) that occur at the same sample index,
within and across spike trains.

Synchrony metrics can be computed for different synchrony sizes (>1), defining the number of simultaneous spikes to count.
Synchrony metrics are computed for 2, 4 and 8 synchronous spikes.



@@ -29,7 +29,7 @@ Example code
import spikeinterface.qualitymetrics as sqm
# Combine a sorting and recording into a sorting_analyzer
synchrony = sqm.compute_synchrony_metrics(sorting_analyzer=sorting_analyzer synchrony_sizes=(2, 4, 8))
synchrony = sqm.compute_synchrony_metrics(sorting_analyzer=sorting_analyzer)
# synchrony is a tuple of dicts with the synchrony metrics for each unit
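A short usage sketch of the simplified API (the result field names below are assumptions based on the fixed sizes, not verbatim from this diff):

import spikeinterface.qualitymetrics as sqm

synchrony = sqm.compute_synchrony_metrics(sorting_analyzer=sorting_analyzer)
# assumed fields, one dict (unit id -> value) per synchrony size:
# synchrony.sync_spike_2, synchrony.sync_spike_4, synchrony.sync_spike_8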
4 changes: 2 additions & 2 deletions examples/tutorials/widgets/plot_2_sort_gallery.py
@@ -31,14 +31,14 @@
# plot_autocorrelograms()
# ~~~~~~~~~~~~~~~~~~~~~~~~

w_ach = sw.plot_autocorrelograms(sorting, window_ms=150.0, bin_ms=5.0, unit_ids=[1, 2, 5])
w_ach = sw.plot_autocorrelograms(sorting, window_ms=150.0, bin_ms=5.0, unit_ids=['1', '2', '5'])

##############################################################################
# plot_crosscorrelograms()
# ~~~~~~~~~~~~~~~~~~~~~~~~


w_cch = sw.plot_crosscorrelograms(sorting, window_ms=150.0, bin_ms=5.0, unit_ids=[1, 2, 5])
w_cch = sw.plot_crosscorrelograms(sorting, window_ms=150.0, bin_ms=5.0, unit_ids=['1', '2', '5'])

plt.show()

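These widget calls change because generate_sorting() now labels units with string ids (see the src/spikeinterface/core/generate.py diff below); a minimal sketch:

from spikeinterface.core import generate_sorting

sorting = generate_sorting(num_units=6, durations=[10.0], sampling_frequency=30000.0, seed=0)
print(sorting.unit_ids)  # string ids, e.g. ['0' '1' '2' '3' '4' '5']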
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -73,7 +73,7 @@ extractors = [
]

streaming_extractors = [
"ONE-api>=2.7.0", # alf sorter and streaming IBL
"ONE-api>=2.7.0,<2.10.0", # alf sorter and streaming IBL
"ibllib>=2.36.0", # streaming IBL
# Following dependencies are for streaming with nwb files
"pynwb>=2.6.0",
3 changes: 2 additions & 1 deletion src/spikeinterface/benchmark/benchmark_base.py
@@ -208,10 +208,11 @@ def run(self, case_keys=None, keep=True, verbose=False, **job_kwargs):
for key in case_keys:

result_folder = self.folder / "results" / self.key_to_str(key)
sorter_folder = self.folder / "sorters" / self.key_to_str(key)

if keep and result_folder.exists():
continue
elif not keep and result_folder.exists():
elif not keep and (result_folder.exists() or sorter_folder.exists()):
self.remove_benchmark(key)
job_keys.append(key)

15 changes: 14 additions & 1 deletion src/spikeinterface/benchmark/benchmark_motion_estimation.py
@@ -109,6 +109,8 @@ def run(self, **job_kwargs):
estimate_motion=t4 - t3,
)

self.result["peaks"] = peaks
self.result["peak_locations"] = peak_locations
self.result["step_run_times"] = step_run_times
self.result["raw_motion"] = motion

@@ -131,6 +133,8 @@ def compute_result(self, **result_params):
self.result["motion"] = motion

_run_key_saved = [
("peaks", "npy"),
("peak_locations", "npy"),
("raw_motion", "Motion"),
("step_run_times", "pickle"),
]
@@ -161,7 +165,9 @@ def create_benchmark(self, key):
def plot_true_drift(self, case_keys=None, scaling_probe=1.5, figsize=(8, 6)):
self.plot_drift(case_keys=case_keys, tested_drift=False, scaling_probe=scaling_probe, figsize=figsize)

def plot_drift(self, case_keys=None, gt_drift=True, tested_drift=True, scaling_probe=1.0, figsize=(8, 6)):
def plot_drift(
self, case_keys=None, gt_drift=True, tested_drift=True, raster=False, scaling_probe=1.0, figsize=(8, 6)
):
import matplotlib.pyplot as plt

if case_keys is None:
@@ -195,6 +201,13 @@ def plot_drift(self, case_keys=None, gt_drift=True, tested_drift=True, scaling_p

# for i in range(self.gt_unit_positions.shape[1]):
# ax.plot(temporal_bins_s, self.gt_unit_positions[:, i], alpha=0.5, ls="--", c="0.5")
if raster:
peaks = bench.result["peaks"]
peak_locations = bench.result["peak_locations"]
rec = bench.recording
x = peaks["sample_index"] / rec.sampling_frequency
y = peak_locations[bench.direction]
ax.scatter(x, y, alpha=0.2, s=2, c=np.abs(peaks["amplitude"]), cmap="inferno")

for i in range(gt_motion.displacement[0].shape[1]):
depth = motion.spatial_bins_um[i]
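A hypothetical call exercising the new raster option (the study object and its construction are assumed, not shown in this diff): raster=True scatters detected peak times against depth, colored by absolute amplitude, underneath the drift traces.

# `study` is an already-built motion-estimation benchmark study (hypothetical here)
study.plot_drift(case_keys=None, gt_drift=True, tested_drift=True, raster=True)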
9 changes: 9 additions & 0 deletions src/spikeinterface/benchmark/benchmark_sorter.py
@@ -56,6 +56,15 @@ def create_benchmark(self, key):
benchmark = SorterBenchmark(recording, gt_sorting, params, sorter_folder)
return benchmark

def remove_benchmark(self, key):
BenchmarkStudy.remove_benchmark(self, key)

sorter_folder = self.folder / "sorters" / self.key_to_str(key)
import shutil

if sorter_folder.exists():
shutil.rmtree(sorter_folder)

def get_performance_by_unit(self, case_keys=None):
import pandas as pd

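With this override, removing a sorter benchmark also clears its sorter output folder; a sketch with a hypothetical study and case key:

# `study` and the key are hypothetical; remove_benchmark now also deletes
# self.folder / "sorters" / self.key_to_str(key) when that folder exists
study.remove_benchmark(key="kilosort4_tetrode")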
2 changes: 1 addition & 1 deletion src/spikeinterface/core/basesorting.py
@@ -135,7 +135,7 @@ def get_total_duration(self) -> float:

def get_unit_spike_train(
self,
unit_id,
unit_id: str | int,
segment_index: Union[int, None] = None,
start_frame: Union[int, None] = None,
end_frame: Union[int, None] = None,
13 changes: 8 additions & 5 deletions src/spikeinterface/core/generate.py
@@ -2,7 +2,7 @@
import math
import warnings
import numpy as np
from typing import Literal
from typing import Literal, Optional
from math import ceil

from .basesorting import SpikeVectorSortingSegment
@@ -134,7 +134,7 @@ def generate_sorting(
seed = _ensure_seed(seed)
rng = np.random.default_rng(seed)
num_segments = len(durations)
unit_ids = np.arange(num_units)
unit_ids = [str(idx) for idx in np.arange(num_units)]

spikes = []
for segment_index in range(num_segments):
@@ -1111,7 +1111,7 @@ def __init__(
"""

unit_ids = np.arange(num_units)
unit_ids = [str(idx) for idx in np.arange(num_units)]
super().__init__(sampling_frequency, unit_ids)

self.num_units = num_units
@@ -1138,6 +1139,7 @@ def __init__(
firing_rates=firing_rates,
refractory_period_seconds=self.refractory_period_seconds,
seed=segment_seed,
unit_ids=unit_ids,
t_start=None,
)
self.add_sorting_segment(segment)
@@ -1161,6 +1162,7 @@ def __init__(
firing_rates: float | np.ndarray,
refractory_period_seconds: float | np.ndarray,
seed: int,
unit_ids: list[str],
t_start: Optional[float] = None,
):
self.num_units = num_units
@@ -1177,7 +1179,8 @@ def __init__(
self.refractory_period_seconds = np.full(num_units, self.refractory_period_seconds, dtype="float64")

self.segment_seed = seed
self.units_seed = {unit_id: self.segment_seed + hash(unit_id) for unit_id in range(num_units)}
self.units_seed = {unit_id: abs(self.segment_seed + hash(unit_id)) for unit_id in unit_ids}

self.num_samples = math.ceil(sampling_frequency * duration)
super().__init__(t_start)

@@ -1280,7 +1283,7 @@ def __init__(
noise_block_size: int = 30000,
):

channel_ids = np.arange(num_channels)
channel_ids = [str(idx) for idx in np.arange(num_channels)]
dtype = np.dtype(dtype).name # Cast to string for serialization
if dtype not in ("float32", "float64"):
raise ValueError(f"'dtype' must be 'float32' or 'float64' but is {dtype}")
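The net effect of these generate.py changes, sketched (the printed reprs are assumptions): generated recordings and sortings now use string ids, and each unit's spike-train seed is derived from hashing its string id, wrapped in abs() so the seed is non-negative.

from spikeinterface.core import generate_ground_truth_recording

recording, sorting = generate_ground_truth_recording(
    num_channels=4, num_units=3, durations=[5.0], seed=0
)
print(recording.channel_ids)  # e.g. ['0' '1' '2' '3']
print(sorting.unit_ids)       # e.g. ['0' '1' '2']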
10 changes: 5 additions & 5 deletions src/spikeinterface/core/tests/test_basesnippets.py
@@ -41,8 +41,8 @@ def test_BaseSnippets(create_cache_folder):
assert snippets.get_num_segments() == len(duration)
assert snippets.get_num_channels() == num_channels

assert np.all(snippets.ids_to_indices([0, 1, 2]) == [0, 1, 2])
assert np.all(snippets.ids_to_indices([0, 1, 2], prefer_slice=True) == slice(0, 3, None))
assert np.all(snippets.ids_to_indices(["0", "1", "2"]) == [0, 1, 2])
assert np.all(snippets.ids_to_indices(["0", "1", "2"], prefer_slice=True) == slice(0, 3, None))

# annotations / properties
snippets.annotate(gre="ta")
@@ -60,7 +60,7 @@ def test_BaseSnippets(create_cache_folder):
)

# missing property
snippets.set_property("string_property", ["ciao", "bello"], ids=[0, 1])
snippets.set_property("string_property", ["ciao", "bello"], ids=["0", "1"])
values = snippets.get_property("string_property")
assert values[2] == ""

@@ -70,14 +70,14 @@ def test_BaseSnippets(create_cache_folder):
snippets.set_property,
key="string_property_nan",
values=["hola", "chabon"],
ids=[0, 1],
ids=["0", "1"],
missing_value=np.nan,
)

# int properties without missing values raise an error
assert_raises(Exception, snippets.set_property, key="int_property", values=[5, 6], ids=[1, 2])

snippets.set_property("int_property", [5, 6], ids=[1, 2], missing_value=200)
snippets.set_property("int_property", [5, 6], ids=["1", "2"], missing_value=200)
values = snippets.get_property("int_property")
assert values.dtype.kind == "i"

src/spikeinterface/core/tests/test_channelsaggregationrecording.py
@@ -38,10 +38,12 @@ def test_channelsaggregationrecording():

assert np.allclose(traces1_1, recording_agg.get_traces(channel_ids=[str(channel_ids[1])], segment_index=seg))
assert np.allclose(
traces2_0, recording_agg.get_traces(channel_ids=[str(num_channels + channel_ids[0])], segment_index=seg)
traces2_0,
recording_agg.get_traces(channel_ids=[str(num_channels + int(channel_ids[0]))], segment_index=seg),
)
assert np.allclose(
traces3_2, recording_agg.get_traces(channel_ids=[str(2 * num_channels + channel_ids[2])], segment_index=seg)
traces3_2,
recording_agg.get_traces(channel_ids=[str(2 * num_channels + int(channel_ids[2]))], segment_index=seg),
)
# all traces
traces1 = recording1.get_traces(segment_index=seg)
8 changes: 8 additions & 0 deletions src/spikeinterface/core/tests/test_sortinganalyzer.py
@@ -31,6 +31,14 @@ def get_dataset():
noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"),
seed=2205,
)

# TODO: the tests or the sorting analyzer make assumptions about the ids being integers
# So keeping this the way it was
integer_channel_ids = [int(id) for id in recording.get_channel_ids()]
integer_unit_ids = [int(id) for id in sorting.get_unit_ids()]

recording = recording.rename_channels(new_channel_ids=integer_channel_ids)
sorting = sorting.rename_units(new_unit_ids=integer_unit_ids)
return recording, sorting


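The fixture change relies on rename_channels / rename_units, which return objects with relabeled ids; mirroring the diff above:

# convert the generator's new string ids back to integers, as the fixture does
integer_unit_ids = [int(uid) for uid in sorting.get_unit_ids()]
sorting = sorting.rename_units(new_unit_ids=integer_unit_ids)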
22 changes: 13 additions & 9 deletions src/spikeinterface/core/tests/test_unitsselectionsorting.py
@@ -10,39 +10,43 @@
def test_basic_functions():
sorting = generate_sorting(num_units=3, durations=[0.100, 0.100], sampling_frequency=30000.0)

sorting2 = UnitsSelectionSorting(sorting, unit_ids=[0, 2])
assert np.array_equal(sorting2.unit_ids, [0, 2])
sorting2 = UnitsSelectionSorting(sorting, unit_ids=["0", "2"])
assert np.array_equal(sorting2.unit_ids, ["0", "2"])
assert sorting2.get_parent() == sorting

sorting3 = UnitsSelectionSorting(sorting, unit_ids=[0, 2], renamed_unit_ids=["a", "b"])
sorting3 = UnitsSelectionSorting(sorting, unit_ids=["0", "2"], renamed_unit_ids=["a", "b"])
assert np.array_equal(sorting3.unit_ids, ["a", "b"])

assert np.array_equal(
sorting.get_unit_spike_train(0, segment_index=0), sorting2.get_unit_spike_train(0, segment_index=0)
sorting.get_unit_spike_train(unit_id="0", segment_index=0),
sorting2.get_unit_spike_train(unit_id="0", segment_index=0),
)
assert np.array_equal(
sorting.get_unit_spike_train(0, segment_index=0), sorting3.get_unit_spike_train("a", segment_index=0)
sorting.get_unit_spike_train(unit_id="0", segment_index=0),
sorting3.get_unit_spike_train(unit_id="a", segment_index=0),
)

assert np.array_equal(
sorting.get_unit_spike_train(2, segment_index=0), sorting2.get_unit_spike_train(2, segment_index=0)
sorting.get_unit_spike_train(unit_id="2", segment_index=0),
sorting2.get_unit_spike_train(unit_id="2", segment_index=0),
)
assert np.array_equal(
sorting.get_unit_spike_train(2, segment_index=0), sorting3.get_unit_spike_train("b", segment_index=0)
sorting.get_unit_spike_train(unit_id="2", segment_index=0),
sorting3.get_unit_spike_train(unit_id="b", segment_index=0),
)


def test_failure_with_non_unique_unit_ids():
seed = 10
sorting = generate_sorting(num_units=3, durations=[0.100], sampling_frequency=30000.0, seed=seed)
with pytest.raises(AssertionError):
sorting2 = UnitsSelectionSorting(sorting, unit_ids=[0, 2], renamed_unit_ids=["a", "a"])
sorting2 = UnitsSelectionSorting(sorting, unit_ids=["0", "2"], renamed_unit_ids=["a", "a"])


def test_custom_cache_spike_vector():
sorting = generate_sorting(num_units=3, durations=[0.100, 0.100], sampling_frequency=30000.0)

sub_sorting = UnitsSelectionSorting(sorting, unit_ids=[2, 0], renamed_unit_ids=["b", "a"])
sub_sorting = UnitsSelectionSorting(sorting, unit_ids=["2", "0"], renamed_unit_ids=["b", "a"])
cached_spike_vector = sub_sorting.to_spike_vector(use_cache=True)
computed_spike_vector = sub_sorting.to_spike_vector(use_cache=False)
assert np.all(cached_spike_vector == computed_spike_vector)
