Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
bentaculum committed May 30, 2024
1 parent ca98874 commit b11fd1f
Show file tree
Hide file tree
Showing 6 changed files with 100 additions and 73 deletions.
10 changes: 5 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,21 @@ description = "trackastra tracking"
readme = "README.md"
license = {file = "LICENSE"}
authors = [
{name = "Martin Weigert"},
{email = "[email protected]"},
{name = "Martin Weigert, Benjamin Gallusser"},
{email = "[email protected], [email protected]"},
]
classifiers = [
"Development Status :: 2 - Pre-Alpha",
"Development Status :: 3 - Alpha",
"Framework :: napari",
"Intended Audience :: Developers",
"License :: OSI Approved :: BSD License",
"Operating System :: OS Independent",
"Programming Language :: Python",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Topic :: Scientific/Engineering :: Image Processing",
]
requires-python = ">=3.9"
requires-python = ">=3.10"
dependencies = [
"napari",
"numpy",
Expand Down Expand Up @@ -87,6 +86,7 @@ lint.select = [
"SIM", # flake8-simplify
]
lint.ignore = [
"F401",
"E501", # line too long. let black handle this
"UP006", "UP007", # type annotation. As using magicgui require runtime type annotation then we disable this.
"SIM117", # flake8-simplify - some of merged with statements are not looking great with black, reanble after drop python 3.9
Expand Down
2 changes: 0 additions & 2 deletions src/napari_trackastra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,2 @@
__version__ = "0.0.1"
from ._widget import Tracker


21 changes: 13 additions & 8 deletions src/napari_trackastra/_sample_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,24 @@
Replace code below according to your needs.
"""

from __future__ import annotations

from pathlib import Path
import numpy
import tifffile
from trackastra import data


def test_data_bacteria() -> list[tuple[numpy.ndarray, dict, str]]:
imgs, masks = data.example_data_bacteria()
return [(imgs, dict(name='img'), 'image'), (masks, dict(name='mask'), 'labels')]
def example_data_bacteria() -> list[tuple[numpy.ndarray, dict, str]]:
imgs, masks = data.example_data_bacteria()
return [
(imgs, {"name": "img"}, "image"),
(masks, {"name": "mask"}, "labels"),
]


def test_data_hela() -> list[tuple[numpy.ndarray, dict, str]]:
imgs, masks = data.example_data_hela()
return [(imgs, dict(name='img'), 'image'), (masks, dict(name='mask'), 'labels')]
def example_data_hela() -> list[tuple[numpy.ndarray, dict, str]]:
imgs, masks = data.example_data_hela()
return [
(imgs, {"name": "img"}, "image"),
(masks, {"name": "mask"}, "labels"),
]
9 changes: 4 additions & 5 deletions src/napari_trackastra/_tests/test_widget.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
import numpy as np
import napari
from trackastra.data import example_data_bacteria

from napari_trackastra._widget import Tracker
from trackastra.data import test_data_bacteria


def test_widget():
viewer = napari.Viewer()
img, mask = test_data_bacteria()
img, mask = example_data_bacteria()
viewer.add_image(img)
viewer.add_labels(mask)

viewer.window.add_dock_widget(Tracker(viewer))


if __name__ == "__main__":
test_widget()

napari.run()

119 changes: 72 additions & 47 deletions src/napari_trackastra/_widget.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,85 @@

import torch
import numpy as np


import napari
from magicgui import magic_factory, magicgui
from magicgui.widgets import CheckBox, Container, create_widget, PushButton, FileEdit, ComboBox, RadioButtons
from pathlib import Path
from typing import List, OrderedDict
from napari.utils import progress
import numpy as np
import torch
import trackastra
from trackastra.utils import normalize
from magicgui.widgets import (
ComboBox,
Container,
FileEdit,
PushButton,
RadioButtons,
create_widget,
)
from napari.utils import progress
from trackastra.model import Trackastra
from trackastra.tracking import graph_to_ctc, graph_to_napari_tracks

device = "cuda" if torch.cuda.is_available() else "cpu"
from trackastra.utils import normalize

device = "cuda" if torch.cuda.is_available() else "cpu"


# logo = Path(__file__).parent/"resources"/"trackastra_logo_small.png"
# logo_html = f"""<div style="display: flex;
# logo_html = f"""<div style="display: flex;
# align-items: center;">
# <img src="{logo}" alt="Logo" style="margin-right: 50px; width: 30px; height: 30px;">
# <img src="{logo}" alt="Logo" style="margin-right: 50px; width: 30px; height: 30px;">
# <span style="line-height: 50px;">
# Trackastra
# </div>"""



def _track_function(model, imgs, masks, mode="greedy", **kwargs):
def _track_function(model, imgs, masks, mode="greedy", **kwargs):
print("Normalizing...")
imgs = np.stack([normalize(x) for x in imgs])
print(f"Tracking with mode {mode}...")
track_graph = model.track(imgs, masks, mode=mode,
max_distance=128,
progbar_class=progress,
**kwargs) # or mode="ilp"
track_graph = model.track(
imgs,
masks,
mode=mode,
max_distance=128,
progbar_class=progress,
**kwargs,
) # or mode="ilp"
# Visualise in napari
df, masks_tracked = graph_to_ctc(track_graph,masks,outdir=None)
df, masks_tracked = graph_to_ctc(track_graph, masks, outdir=None)
napari_tracks, napari_tracks_graph, _ = graph_to_napari_tracks(track_graph)
return track_graph, masks_tracked, napari_tracks



class Tracker(Container):
def __init__(self, viewer: "napari.viewer.Viewer"):
super().__init__()
self._viewer = viewer
self._label = create_widget(widget_type="Label", label="<h2>Trackastra</h2>")
self._image_layer = create_widget(label="Images", annotation="napari.layers.Image")


self._mask_layer = create_widget(label="Masks", annotation="napari.layers.Labels")
self._model_type = RadioButtons(label="Model Type", choices=["Pretrained", "Custom"], orientation="horizontal", value="Pretrained")
self._model_pretrained = ComboBox(label="Pretrained Model",
choices=tuple(trackastra.model.pretrained._MODELS.keys()), value="general_2d")
self._label = create_widget(
widget_type="Label", label="<h2>Trackastra</h2>"
)
self._image_layer = create_widget(
label="Images", annotation="napari.layers.Image"
)

self._mask_layer = create_widget(
label="Masks", annotation="napari.layers.Labels"
)
self._model_type = RadioButtons(
label="Model Type",
choices=["Pretrained", "Custom"],
orientation="horizontal",
value="Pretrained",
)
self._model_pretrained = ComboBox(
label="Pretrained Model",
choices=tuple(trackastra.model.pretrained._MODELS.keys()),
value="general_2d",
)
self._model_path = FileEdit(label="Model Path", mode="d")
self._model_path.hide()
self._run_button = PushButton(label="Track")

self._linking_mode = ComboBox(label="Linking",
choices=("greedy_nodiv","greedy", "ilp"), value="greedy")


self._linking_mode = ComboBox(
label="Linking",
choices=("greedy_nodiv", "greedy", "ilp"),
value="greedy",
)

self._out_mask, self._out_tracks = None, None

self._model_type.changed.connect(self._model_type_changed)
Expand Down Expand Up @@ -93,40 +111,47 @@ def _model_type_changed(self, event):

def _update_model(self, event=None):
if self._model_type.value == "Pretrained":
self.model = Trackastra.from_pretrained(self._model_pretrained.value, device=device)
self.model = Trackastra.from_pretrained(
self._model_pretrained.value, device=device
)
else:
self.model = Trackastra.from_folder(self._model_path.value, device=device)

self.model = Trackastra.from_folder(
self._model_path.value, device=device
)

def _show_activity_dock(self, state=True):
# show/hide activity dock if there is actual progress to see
self._viewer.window._status_bar._toggle_activity_dock(state)


def _run(self, event=None):
self._update_model()

if self.model is None:
raise ValueError("Model not loaded")

imgs = np.asarray(self._image_layer.value.data)
masks = np.asarray(self._mask_layer.value.data)

self._show_activity_dock(True)
track_graph, masks_tracked, napari_tracks = _track_function(self.model, imgs, masks, mode=self._linking_mode.value)

track_graph, masks_tracked, napari_tracks = _track_function(
self.model, imgs, masks, mode=self._linking_mode.value
)

self._mask_layer.value.visible = False
self._show_activity_dock(False)

lays = tuple(lay for lay in self._viewer.layers if lay.name=="masks_tracked")
lays = tuple(
lay for lay in self._viewer.layers if lay.name == "masks_tracked"
)
if len(lays) > 0:
lays[0].data = masks_tracked
else:
self._viewer.add_labels(masks_tracked, name="masks_tracked")

lays = tuple(lay for lay in self._viewer.layers if lay.name=="tracks")

lays = tuple(
lay for lay in self._viewer.layers if lay.name == "tracks"
)
if len(lays) > 0:
lays[0].data = napari_tracks
else:
self._viewer.add_tracks(napari_tracks, name="tracks")

12 changes: 6 additions & 6 deletions src/napari_trackastra/napari.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,20 @@ visibility: public
categories: ["Segmentation"]
contributions:
commands:
- id: napari-trackastra.test_data_bacteria
python_name: napari_trackastra._sample_data:test_data_bacteria
- id: napari-trackastra.example_data_bacteria
python_name: napari_trackastra._sample_data:example_data_bacteria
title: Sample bacteria images and masks
- id: napari-trackastra.test_data_hela
python_name: napari_trackastra._sample_data:test_data_hela
- id: napari-trackastra.example_data_hela
python_name: napari_trackastra._sample_data:example_data_hela
title: Sample bacteria images and masks
- id: napari-trackastra.track
python_name: napari_trackastra:Tracker
title: Create Plugin
sample_data:
- command: napari-trackastra.test_data_bacteria
- command: napari-trackastra.example_data_bacteria
display_name: bacteria
key: unique_id.1
- command: napari-trackastra.test_data_hela
- command: napari-trackastra.example_data_hela
display_name: hela
key: unique_id.2
widgets:
Expand Down

0 comments on commit b11fd1f

Please sign in to comment.