Skip to content

Commit

Permalink
Merge pull request #1 from weigertlab/save-tracks
Browse files Browse the repository at this point in the history
Add saving tracks to CTC format
  • Loading branch information
bentaculum authored Jun 5, 2024
2 parents bedfa08 + a6631a9 commit 0889c3c
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 24 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
demo.gif
demo.mp4

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
Expand Down
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
include LICENSE
include README.md
include src/napari_trackastra/resources/trackastra_logo_small.png
include src/napari_trackastra/resources/icon.png

recursive-exclude * __pycache__
recursive-exclude * *.py[co]
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

A napari plugin for cell tracking with [`trackastra`](https://github.com/weigertlab/trackastra).

https://github.com/weigertlab/napari-trackastra/assets/11042162/2751e81a-3992-4c60-bceb-c3f340732435
![demo](https://github.com/weigertlab/napari-trackastra/assets/8866751/097eb82d-0fef-423e-9275-3fb528c20f7d)


## Installation
Expand All @@ -28,10 +28,10 @@ Notes:

## Usage

`trackastra` expects a timeseries of raw images and corresponding segmentations masks as input. We provide some demo data at
```
File > Open Sample > trackastra
```
- `trackastra` expects a timeseries of raw images and corresponding segmentations masks as input.
- We provide some demo data at `File > Open Sample > trackastra`.
- Tracked cells can be directly saved to [Cell Tracking Challenge format](https://celltrackingchallenge.net/datasets/).
- Results can be drag-and-dropped back into napari for inspection.

[napari]: https://github.com/napari/napari
[tox]: https://tox.readthedocs.io/en/latest/
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ dependencies = [
"qtpy",
"scikit-image",
"trackastra",
"napari-open-ctc",
"napari-ctc-io",
]

[project.optional-dependencies]
Expand Down
5 changes: 4 additions & 1 deletion src/napari_trackastra/_tests/test_demo_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ def test_demo_widget():
viewer.add_image(img)
viewer.add_labels(mask)

viewer.window.add_dock_widget(Tracker(viewer))
tracker = Tracker(viewer)
viewer.window.add_dock_widget(tracker)
tracker._run()
tracker._save()


if __name__ == "__main__":
Expand Down
88 changes: 72 additions & 16 deletions src/napari_trackastra/_widget.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from pathlib import Path

import napari
import npe2
import numpy as np
import torch
import trackastra
Expand All @@ -10,20 +13,21 @@
RadioButtons,
create_widget,
)
from pathlib import Path
from napari import save_layers
from napari.utils import progress
from trackastra.model import Trackastra
from trackastra.tracking import graph_to_ctc, graph_to_napari_tracks
from trackastra.utils import normalize
from trackastra.tracking import (
ctc_to_napari_tracks,
graph_to_ctc,
)

device = "cuda" if torch.cuda.is_available() else "cpu"


logo_path = Path(__file__).parent/"resources"/"trackastra_logo_small.png"
logo_path = Path(__file__).parent / "resources" / "trackastra_logo_small.png"


def _track_function(model, imgs, masks, mode="greedy", **kwargs):
print("Normalizing...")
imgs = np.stack([normalize(x) for x in imgs])
print(f"Tracking with mode {mode}...")
track_graph = model.track(
imgs,
Expand All @@ -33,19 +37,26 @@ def _track_function(model, imgs, masks, mode="greedy", **kwargs):
progbar_class=progress,
**kwargs,
) # or mode="ilp"

# Visualise in napari
df, masks_tracked = graph_to_ctc(track_graph, masks, outdir=None)
napari_tracks, napari_tracks_graph, _ = graph_to_napari_tracks(track_graph)
return track_graph, masks_tracked, napari_tracks
napari_tracks, napari_tracks_graph = ctc_to_napari_tracks(
segmentation=masks_tracked, man_track=df
)

return track_graph, masks_tracked, napari_tracks, napari_tracks_graph


class Tracker(Container):
def __init__(self, viewer: "napari.viewer.Viewer"):
super().__init__()
self._viewer = viewer
self._label = create_widget(widget_type="Label",
label=f'<img src="{logo_path}"></img>')
self._image_layer = create_widget(label="Images", annotation="napari.layers.Image")
self._label = create_widget(
widget_type="Label", label=f'<img src="{logo_path}"></img>'
)
self._image_layer = create_widget(
label="Images", annotation="napari.layers.Image"
)

self._mask_layer = create_widget(
label="Masks", annotation="napari.layers.Labels"
Expand All @@ -63,7 +74,18 @@ def __init__(self, viewer: "napari.viewer.Viewer"):
)
self._model_path = FileEdit(label="Model Path", mode="d")
self._model_path.hide()
self._run_button = PushButton(label="Track")
self._run_button = PushButton(label="TRACK")

self._save_button = PushButton(
label="SAVE\n(CTC format from masks-tracked + tracks)",
visible=False,
)
self._save_path = FileEdit(
label="Save tracks to",
mode="d",
value="~/Desktop/TRA",
visible=False,
)

self._linking_mode = ComboBox(
label="Linking",
Expand All @@ -78,6 +100,9 @@ def __init__(self, viewer: "napari.viewer.Viewer"):
self._model_path.changed.connect(self._update_model)
self._run_button.changed.connect(self._run)

self._save_path.changed.connect(self._save)
self._save_button.changed.connect(self._save)

# append into/extend the container with your widgets
self.extend(
[
Expand All @@ -89,6 +114,8 @@ def __init__(self, viewer: "napari.viewer.Viewer"):
self._model_path,
self._linking_mode,
self._run_button,
self._save_path,
self._save_button,
]
)

Expand Down Expand Up @@ -124,8 +151,10 @@ def _run(self, event=None):
masks = np.asarray(self._mask_layer.value.data)

self._show_activity_dock(True)
track_graph, masks_tracked, napari_tracks = _track_function(
self.model, imgs, masks, mode=self._linking_mode.value
track_graph, masks_tracked, napari_tracks, napari_tracks_graph = (
_track_function(
self.model, imgs, masks, mode=self._linking_mode.value
)
)

self._mask_layer.value.visible = False
Expand All @@ -145,5 +174,32 @@ def _run(self, event=None):
if len(lays) > 0:
lays[0].data = napari_tracks
else:
self._viewer.add_tracks(napari_tracks, name="tracks", tail_length=5)

self._viewer.add_tracks(
napari_tracks,
graph=napari_tracks_graph,
name="tracks",
tail_length=5,
)

self._save_path.show()
self._save_button.show()

def _save(self, event=None):
pm = npe2.PluginManager.instance()

outdir = self._save_path.value
writer_contrib = pm.get_writer(
outdir,
["labels", "tracks"],
"napari-ctc-io",
)[0]

save_layers(
path=outdir,
layers=[
self._viewer.layers["masks_tracked"],
self._viewer.layers["tracks"],
],
plugin="napari-ctc-io",
_writer=writer_contrib,
)
3 changes: 2 additions & 1 deletion src/napari_trackastra/napari.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ name: napari-trackastra
display_name: trackastra
# use 'hidden' to remove plugin from napari hub search results
visibility: public
categories: ["Segmentation"]
categories: ["Image Processing"]
icon: resources/icon.png
contributions:
commands:
- id: napari-trackastra.example_data_bacteria
Expand Down
Binary file added src/napari_trackastra/resources/icon.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 0889c3c

Please sign in to comment.