Only load template_metrics extension on compute if keeping some metrics #3478

Open · wants to merge 2 commits into main

Changes from 1 commit
5 changes: 2 additions & 3 deletions src/spikeinterface/core/sortinganalyzer.py
@@ -2033,19 +2033,18 @@ def load(cls, sorting_analyzer):
         return None

     def load_run_info(self):
+        run_info = None
         if self.format == "binary_folder":
             extension_folder = self._get_binary_extension_folder()
             run_info_file = extension_folder / "run_info.json"
             if run_info_file.is_file():
                 with open(str(run_info_file), "r") as f:
                     run_info = json.load(f)
-            else:
-                warnings.warn(f"Found no run_info file for {self.extension_name}, extension should be re-computed.")
Comment on lines -2042 to -2043
Member:
The warning was here to make users aware that the analyzer was computed before run_info was introduced.

Collaborator (Author):
Take a look at the full current implementation here:


If the first warning is triggered, the second is always triggered (I think!). So the first warning isn’t needed, and this is deleted here.

Contributor:

I think the second if block was meant to be indented under the zarr case, but I agree it works just as well to assign it to None in the first place and then do the warning at the end if it's still None.

Collaborator (Author):

Yes, I did wonder whether an indentation was missing. This solution is a tiny bit neater, as we only need one copy of the warning message. I like the run_info idea in general: great suggestion @jonahpearl :)

-                run_info = None

         elif self.format == "zarr":
             extension_group = self._get_zarr_extension_group(mode="r")
             run_info = extension_group.attrs.get("run_info", None)

         if run_info is None:
             warnings.warn(f"Found no run_info file for {self.extension_name}, extension should be re-computed.")
         self.run_info = run_info
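For reference, a minimal standalone sketch of the post-change control flow, assembled from the diff above (`fmt`, `extension_folder`, `extension_name`, and `zarr_attrs` are illustrative stand-ins for the analyzer's attributes and helper methods, not the real API): `run_info` defaults to `None`, so a single warning at the end covers both storage formats as well as the missing-file case.

import json
import warnings
from pathlib import Path

def load_run_info(fmt: str, extension_folder: Path, extension_name: str, zarr_attrs: dict):
    # Default to None; only overwritten if a stored run_info is found.
    run_info = None
    if fmt == "binary_folder":
        run_info_file = extension_folder / "run_info.json"
        if run_info_file.is_file():
            with open(run_info_file, "r") as f:
                run_info = json.load(f)
    elif fmt == "zarr":
        run_info = zarr_attrs.get("run_info", None)

    # One warning at the end replaces the duplicated per-branch warning.
    if run_info is None:
        warnings.warn(f"Found no run_info file for {extension_name}, extension should be re-computed.")
    return run_info

Whichever branch fails to find anything, the warning fires exactly once, which is why the first warning was redundant.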
19 changes: 12 additions & 7 deletions src/spikeinterface/postprocessing/template_metrics.py
@@ -335,13 +335,18 @@ def _run(self, verbose=False):
         )

         existing_metrics = []
-        tm_extension = self.sorting_analyzer.get_extension("template_metrics")
-        if (
-            delete_existing_metrics is False
-            and tm_extension is not None
-            and tm_extension.data.get("metrics") is not None
-        ):
-            existing_metrics = tm_extension.params["metric_names"]
+
+        # Check if we need to propagate any old metrics. If so, we'll do that.
+        # Otherwise, we'll avoid attempting to load an empty template_metrics.
+        if set(self.params["metrics_to_compute"]) != set(self.params["metric_names"]):
Collaborator:
This is a cool idea to make the comparison :)


+            tm_extension = self.sorting_analyzer.get_extension("template_metrics")
+            if (
+                delete_existing_metrics is False
+                and tm_extension is not None
+                and tm_extension.data.get("metrics") is not None
+            ):
+                existing_metrics = tm_extension.params["metric_names"]

         # append the metrics which were previously computed
         for metric_name in set(existing_metrics).difference(metrics_to_compute):
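To see why the set comparison works as a guard, here is a runnable sketch with a plain dict standing in for the extension's params (the metric names are illustrative, not taken from this PR):

# Old metrics need to be carried over only when we compute a strict subset
# of the requested metric names; if the two sets match, everything is being
# (re)computed and there is nothing to load from the old extension data.
params = {
    "metric_names": ["peak_to_valley", "half_width", "peak_trough_ratio"],
    "metrics_to_compute": ["peak_trough_ratio"],  # only the newly requested metric
}

if set(params["metrics_to_compute"]) != set(params["metric_names"]):
    print("some metrics are kept: load the existing template_metrics data")
else:
    print("fresh compute: skip loading the (possibly empty) extension")

When every requested metric is recomputed the sets are equal, so the extension is never loaded, which is exactly the behaviour the PR title describes.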