Skip to content

Commit

Permalink
Merge pull request #2317 from samuelgarcia/add_probe_to_template
Browse files Browse the repository at this point in the history
Add probe field in class Templates
  • Loading branch information
alejoe91 authored Dec 20, 2023
2 parents 3b17406 + 94eff00 commit 0b821f3
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
10 changes: 10 additions & 0 deletions src/spikeinterface/core/template.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import json
from dataclasses import dataclass, field, astuple
from probeinterface import Probe
from .sparsity import ChannelSparsity


Expand All @@ -24,6 +25,8 @@ class Templates:
Array of channel IDs. If `None`, defaults to an array of increasing integers.
unit_ids : np.ndarray, optional default: None
Array of unit IDs. If `None`, defaults to an array of increasing integers.
probe: Probe, default: None
A `probeinterface.Probe` object
check_for_consistent_sparsity : bool, optional default: None
When passing a sparsity_mask, this checks that the templates array is also sparse and that it matches the
structure fo the sparsity_masl.
Expand Down Expand Up @@ -57,6 +60,8 @@ class Templates:
channel_ids: np.ndarray = None
unit_ids: np.ndarray = None

probe: Probe = None

check_for_consistent_sparsity: bool = True

num_units: int = field(init=False)
Expand Down Expand Up @@ -135,6 +140,7 @@ def to_dict(self):
"unit_ids": self.unit_ids,
"sampling_frequency": self.sampling_frequency,
"nbefore": self.nbefore,
"probe": self.probe.to_dict() if self.probe is not None else None,
}

@classmethod
Expand All @@ -146,6 +152,7 @@ def from_dict(cls, data):
unit_ids=np.asarray(data["unit_ids"]),
sampling_frequency=data["sampling_frequency"],
nbefore=data["nbefore"],
probe=data["probe"] if data["probe"] is None else Probe.from_dict(data["probe"]),
)

def to_json(self):
Expand Down Expand Up @@ -189,6 +196,9 @@ def __eq__(self, other):
return False
if not np.array_equal(s_field.channel_ids, o_field.channel_ids):
return False
elif isinstance(s_field, Probe):
# TODO implement __eq__ in probeinterface...
pass
else:
if s_field != o_field:
return False
Expand Down
11 changes: 11 additions & 0 deletions src/spikeinterface/core/tests/test_template_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from spikeinterface.core.template import Templates
from spikeinterface.core.sparsity import ChannelSparsity

from probeinterface import generate_multi_columns_probe


def generate_test_template(template_type):
num_units = 2
Expand All @@ -15,6 +17,8 @@ def generate_test_template(template_type):
sampling_frequency = 30_000
nbefore = 2

probe = generate_multi_columns_probe(num_columns=1, num_contact_per_column=[3])

if template_type == "dense":
return Templates(templates_array=templates_array, sampling_frequency=sampling_frequency, nbefore=nbefore)
elif template_type == "sparse": # sparse with sparse templates
Expand All @@ -35,6 +39,7 @@ def generate_test_template(template_type):
sparsity_mask=sparsity_mask,
sampling_frequency=sampling_frequency,
nbefore=nbefore,
probe=probe,
)

elif template_type == "sparse_with_dense_templates": # sparse with dense templates
Expand All @@ -45,6 +50,7 @@ def generate_test_template(template_type):
sparsity_mask=sparsity_mask,
sampling_frequency=sampling_frequency,
nbefore=nbefore,
probe=probe,
)


Expand Down Expand Up @@ -84,3 +90,8 @@ def test_get_dense_templates(template_type):
def test_initialization_fail_with_dense_templates():
with pytest.raises(ValueError, match="Sparsity mask passed but the templates are not sparse"):
template = generate_test_template(template_type="sparse_with_dense_templates")


if __name__ == "__main__":
# test_json_serialization("sparse")
test_json_serialization("dense")

0 comments on commit 0b821f3

Please sign in to comment.