Skip to content

Commit

Permalink
Merge branch 'main' into cli
Browse files Browse the repository at this point in the history
  • Loading branch information
bentaculum committed Jun 14, 2024
2 parents 0130df6 + 9125010 commit 647d326
Show file tree
Hide file tree
Showing 10 changed files with 144 additions and 55 deletions.
25 changes: 14 additions & 11 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

![Overview](overview.png)

If you are using this code in your research, please cite our [preprint](https://arxiv.org/abs/2405.15700)
If you are using this code in your research, please cite our [preprint](https://arxiv.org/abs/2405.15700):
> Benjamin Gallusser and Martin Weigert<br>*Trackastra - Transformer-based cell tracking for live-cell microscopy*<br> arXiv, 2024
## Examples
Expand All @@ -30,6 +30,7 @@ pip install trackastra
For tracking with an integer linear program (ILP, which is optional)
```bash
conda create --name trackastra python=3.10 --no-default-packages
conda activate trackastra
conda install -c conda-forge -c gurobi -c funkelab ilpy
pip install trackastra[ilp]
```
Expand All @@ -41,12 +42,23 @@ Notes:

2. The [SCIP Optimizer](https://www.scipopt.org/), a free and open source solver. If `motile` does not find a valid Gurobi license, it will fall back to using SCIP.
- On MacOS, installing packages into the conda environment before installing `ilpy` can cause problems.
- 2024-06-07: On Apple M3 chips, you might have to use the nightly build of `torch` and `torchvision`, or worst case build them yourself.

## Usage

The input to *Trackastra* is a sequence of images and their corresponding cell (instance) segmentations.
The input to Trackastra is a sequence of images and their corresponding cell (instance) segmentations.

### Napari plugin

For a quick try of Trackastra on your data, please use our [napari plugin](https://github.com/weigertlab/napari-trackastra/), which already comes with pretrained models included.

![demo](https://github.com/weigertlab/napari-trackastra/assets/8866751/097eb82d-0fef-423e-9275-3fb528c20f7d)


### Tracking with a pretrained model

> The available pretrained models are described in detail [here](trackastra/model/pretrained.json).
Consider the following python example script for tracking already segmented cells. All you need are the following two numpy arrays:
- `imgs`: a microscopy time lapse of shape `time,(z),y,x`.
- `masks`: corresponding instance segmentation of shape `time,(z),y,x`.
Expand All @@ -61,8 +73,6 @@ Otherwise, no hyperparameters to choose :)

```python
import torch
import numpy as np
from trackastra.utils import normalize
from trackastra.model import Trackastra
from trackastra.tracking import graph_to_ctc, graph_to_napari_tracks
from trackastra.data import example_data_bacteria
Expand All @@ -72,9 +82,6 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
# load some test data images and masks
imgs, masks = example_data_bacteria()

# Normalize your images
imgs = np.stack([normalize(x) for x in imgs])

# Load a pretrained model
model = Trackastra.from_pretrained("general_2d", device=device)

Expand Down Expand Up @@ -106,10 +113,6 @@ v.add_labels(masks_tracked)
v.add_tracks(data=napari_tracks, graph=napari_tracks_graph)
```

### Napari plugin

We additionally provide a [napari plugin](https://github.com/weigertlab/napari-trackastra/) which allows one to quickly apply pretrained and custom models on custom timeseries.

### Training a model on your own data

To run an example
Expand Down
2 changes: 0 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@ install_requires =
wandb
edt
joblib
pydantic >= 2.0
pydantic_numpy
python_requires = >=3.10
include_package_data = True

Expand Down
5 changes: 3 additions & 2 deletions trackastra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# from . import cli
import os

from ._version import __version__, __version_tuple__

# from .cli import cli
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
2 changes: 0 additions & 2 deletions trackastra/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,11 @@

# from .data import CTCData
import tifffile
from pydantic import validate_call
from tqdm import tqdm

logger = logging.getLogger(__name__)


@validate_call
def load_tiff_timeseries(
dir: Path,
dtype: str | type | None = None,
Expand Down
37 changes: 22 additions & 15 deletions trackastra/data/wrfeat.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
import numpy as np
import pandas as pd
from edt import edt
from pydantic import validate_call
from pydantic_numpy import NpNDArray
from skimage.measure import regionprops, regionprops_table
from tqdm import tqdm

Expand Down Expand Up @@ -89,7 +87,10 @@ def __init__(
self.timepoints = timepoints

def __repr__(self):
s = f"WindowRegionFeatures(ndim={self.ndim}, nregions={len(self.labels)}, ntimepoints={len(np.unique(self.timepoints))})\n\n"
s = (
f"WindowRegionFeatures(ndim={self.ndim}, nregions={len(self.labels)},"
f" ntimepoints={len(np.unique(self.timepoints))})\n\n"
)
for k, v in self.features.items():
s += f"{k:>20} -> {v.shape}\n"
return s
Expand Down Expand Up @@ -445,10 +446,9 @@ def __call__(self, feats: WRFeatures):
return feats


@validate_call
def get_features(
detections: NpNDArray,
imgs: NpNDArray | None = None,
detections: np.ndarray,
imgs: np.ndarray | None = None,
features: Literal["none", "wrfeat"] = "wrfeat",
ndim: int = 2,
n_workers=0,
Expand All @@ -458,25 +458,34 @@ def get_features(
imgs = _check_dimensions(imgs, ndim)
logger.info(f"Extracting features from {len(detections)} detections")
if n_workers > 0:
features = joblib.Parallel(n_jobs=n_workers, backend='multiprocessing')(
features = joblib.Parallel(n_jobs=n_workers, backend="multiprocessing")(
joblib.delayed(WRFeatures.from_mask_img)(
# New axis for time component
mask=mask[np.newaxis, ...],
img=img[np.newaxis, ...],
t_start=t,
)
for t, (mask, img) in progbar_class(enumerate(zip(detections, imgs)), total=len(imgs), desc="Extracting features")
for t, (mask, img) in progbar_class(
enumerate(zip(detections, imgs)),
total=len(imgs),
desc="Extracting features",
)
)
else:
logger.info("Using single process for feature extraction")
features = tuple(WRFeatures.from_mask_img(
features = tuple(
WRFeatures.from_mask_img(
mask=mask[np.newaxis, ...],
img=img[np.newaxis, ...],
t_start=t,
)
for t, (mask, img) in progbar_class(enumerate(zip(detections, imgs)), total=len(imgs), desc="Extracting features")
)

for t, (mask, img) in progbar_class(
enumerate(zip(detections, imgs)),
total=len(imgs),
desc="Extracting features",
)
)

return features


Expand All @@ -495,9 +504,7 @@ def _check_dimensions(x: np.ndarray, ndim: int):


def build_windows(
features: list[WRFeatures],
window_size: int,
progbar_class=tqdm
features: list[WRFeatures], window_size: int, progbar_class=tqdm
) -> list[dict]:
windows = []
for t1, t2 in progbar_class(
Expand Down
45 changes: 33 additions & 12 deletions trackastra/model/model_api.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import logging
import os
from pathlib import Path
from typing import Literal

import numpy as np
import torch
import yaml
from pydantic import validate_call
from tqdm import tqdm

from ..data import build_windows, get_features, load_tiff_timeseries
Expand All @@ -18,26 +19,46 @@


class Trackastra:
def __init__(self, transformer, train_args, device="cpu"):
if device == "mps":
raise NotImplementedError("Trackastra on mps not supported.")
# Hack: to(device) for some more submodules that map_location does cover
self.transformer = transformer.to(device)
def __init__(
self,
transformer: TrackingTransformer,
train_args: dict,
device: str | None = None,
):
if device is None:
should_use_mps = (
torch.backends.mps.is_available()
and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") is not None
and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") != "0"
)
self.device = (
"cuda"
if torch.cuda.is_available()
else (
"mps"
if should_use_mps and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK")
else "cpu"
)
)
else:
self.device = device

print(f"Using device {self.device}")

self.transformer = transformer.to(self.device)
self.train_args = train_args
self.device = device

@classmethod
@validate_call
def from_folder(cls, dir: Path, device: str = "cpu"):
transformer = TrackingTransformer.from_folder(dir, map_location=device)
def from_folder(cls, dir: Path, device: str | None = None):
# Always load to cpu first
transformer = TrackingTransformer.from_folder(dir, map_location="cpu")
train_args = yaml.load(open(dir / "train_config.yaml"), Loader=yaml.FullLoader)
return cls(transformer=transformer, train_args=train_args, device=device)

# TODO make safer
@classmethod
@validate_call
def from_pretrained(
cls, name: str, device: str = "cpu", download_dir: Path | None = None
cls, name: str, device: str | None = None, download_dir: Path | None = None
):
folder = download_pretrained(name, download_dir)
# download zip from github to location/name, then unzip
Expand Down
54 changes: 54 additions & 0 deletions trackastra/model/pretrained.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
{
"general_2d": {
"tags": ["cells, nuclei, bacteria, epithelial"],
"dimensionality": [2],
"description": "For tracking fluorescent nuclei, bacteria (PhC), whole cells (BF, PhC, DIC), epithelial cells with fluorescent membrane.",
"url": "https://github.com/weigertlab/trackastra-models/releases/download/v0.1.1/general_2d.zip",
"datasets": {
"Subset of Cell Tracking Challenge 2d datasets": {
"url": "https://celltrackingchallenge.net/2d-datasets/",
"reference": "Maška M, Ulman V, Delgado-Rodriguez P, Gómez-de-Mariscal E, Nečasová T, Guerrero Peña FA, Ren TI, Meyerowitz EM, Scherr T, Löffler K, Mikut R. The Cell Tracking Challenge: 10 years of objective benchmarking. Nature Methods. 2023 Jul;20(7):1010-20."
},
"Bacteria van Vliet": {
"url": "https://zenodo.org/records/268921",
"reference": "van Vliet S, Winkler AR, Spriewald S, Stecher B, Ackermann M. Spatially correlated gene expression in bacterial groups: the role of lineage history, spatial gradients, and cell-cell interactions. Cell systems. 2018 Apr 25;6(4):496-507."
},
"Bacteria ObiWan-Microbi": {
"url": "https://zenodo.org/records/7260137",
"reference": "Seiffarth J, Scherr T, Wollenhaupt B, Neumann O, Scharr H, Kohlheyer D, Mikut R, Nöh K. ObiWan-Microbi: OMERO-based integrated workflow for annotating microbes in the cloud. SoftwareX. 2024 May 1;26:101638."
},
"DeepCell": {
"url": "https://datasets.deepcell.org/data",
"reference": "Schwartz, M, Moen E, Miller G, Dougherty T, Borba E, Ding R, Graf W, Pao E, Van Valen D. Caliban: Accurate cell tracking and lineage construction in live-cell imaging experiments with deep learning. Biorxiv. 2023 Sept 13:803205."
},
"Ker phase contrast": {
"url": "https://osf.io/ysaq2/",
"reference": "Ker DF, Eom S, Sanami S, Bise R, Pascale C, Yin Z, Huh SI, Osuna-Highley E, Junkers SN, Helfrich CJ, Liang PY. Phase contrast time-lapse microscopy datasets with automated and manual cell tracking annotations. Scientific data. 2018 Nov 13;5(1):1-2."
},
"Epithelia benchmark": {
"reference": "Funke J, Mais L, Champion A, Dye N, Kainmueller D. A benchmark for epithelial cell tracking. InProceedings of The European Conference on Computer Vision (ECCV) Workshops 2018 (pp. 0-0)."
},
"T Cells": {
"url": "https://zenodo.org/records/5206119"
},
"Neisseria meningitidis bacterial growth": {
"url": "https://zenodo.org/records/5419619"
},
"Synthetic nuclei": {
"reference": "Weigert group live cell simulator."
}
}
},
"ctc": {
"tags": ["ctc", "isbi2024"],
"dimensionality": [2, 3],
"description": "For tracking Cell Tracking Challenge datasets. Winner of the ISBI 2024 CTC generalizable linking challenge.",
"url": "https://github.com/weigertlab/trackastra-models/releases/download/v0.1/ctc.zip",
"datasets": {
"All Cell Tracking Challenge 2d+3d datasets with available GT and ERR_SEG": {
"url": "https://celltrackingchallenge.net/3d-datasets/",
"reference": "Maška M, Ulman V, Delgado-Rodriguez P, Gómez-de-Mariscal E, Nečasová T, Guerrero Peña FA, Ren TI, Meyerowitz EM, Scherr T, Löffler K, Mikut R. The Cell Tracking Challenge: 10 years of objective benchmarking. Nature Methods. 2023 Jul;20(7):1010-20."
}
}
}
}
9 changes: 5 additions & 4 deletions trackastra/model/pretrained.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
from pathlib import Path

import requests
from pydantic import validate_call
from tqdm import tqdm

logger = logging.getLogger(__name__)

_MODELS = {
"ctc": "https://github.com/weigertlab/trackastra-models/releases/download/v0.1/ctc.zip",
"ctc": (
"https://github.com/weigertlab/trackastra-models/releases/download/v0.1/ctc.zip"
),
"general_2d": "https://github.com/weigertlab/trackastra-models/releases/download/v0.1.1/general_2d.zip",
}

Expand Down Expand Up @@ -53,7 +54,6 @@ def download(url: str, fname: Path):
bar.update(size)


@validate_call
def download_pretrained(name: str, download_dir: Path | None = None):
# TODO make safe, introduce versioning
if download_dir is None:
Expand All @@ -66,7 +66,8 @@ def download_pretrained(name: str, download_dir: Path | None = None):
url = _MODELS[name]
except KeyError:
raise ValueError(
f"Pretrained model `name` is not available. Choose from {list(_MODELS.keys())}"
"Pretrained model `name` is not available. Choose from"
f" {list(_MODELS.keys())}"
)
folder = download_dir / name
download_and_unzip(url=url, dst=folder)
Expand Down
1 change: 1 addition & 0 deletions trackastra/tracking/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
track_greedy,
)
from .utils import (
ctc_to_napari_tracks,
graph_to_ctc,
graph_to_napari_tracks,
linear_chains,
Expand Down
Loading

0 comments on commit 647d326

Please sign in to comment.