Add inference API and task for all BioImage.IO Model Zoo models #338

Merged
merged 35 commits into from
Dec 20, 2024
35 commits
a2d95b6
refactor(zoo): show `nickname` and `name` for bioimage.io models
qin-yu Sep 26, 2024
33b6c39
refactor(zoo): improve code
qin-yu Sep 26, 2024
be666b2
gui(zoo): set model choices font
qin-yu Sep 26, 2024
ea484c7
gui: set fixed width for widget dock
qin-yu Sep 26, 2024
28eea70
Merge remote-tracking branch 'origin' into qy/unify-prediction
qin-yu Oct 9, 2024
4220f61
Merge remote-tracking branch 'origin/master' into qy/unify-prediction
qin-yu Nov 26, 2024
e66fb1f
chore(macOS): development environment for Apple Silicon
qin-yu Nov 26, 2024
38f993a
fix(gui): always filters under mode
qin-yu Nov 26, 2024
841ea4d
refactor(zoo): use `ModelZoo.models_bioimageio`
qin-yu Dec 3, 2024
46736a2
test(zoo): `ModelZoo.models_bioimageio` is manually initialised
qin-yu Dec 3, 2024
1e119bf
refactor(zoo): use `.models_bioimageio` DataFrame instead of dict
qin-yu Dec 3, 2024
d204ee7
refactor(zoo): use `pydantic.HttpUrl` for URLs
qin-yu Dec 3, 2024
2cd47a4
fix(zoo): fix bioimageio model zoo df initialisation
qin-yu Dec 3, 2024
3656496
fix(zoo): update bioimageio model zoo collection file link
qin-yu Dec 5, 2024
c0ceb3a
feat(pred)!: use `bioimageio.core` for BioImage.IO Model Zoo model in…
qin-yu Dec 14, 2024
3585207
docs: explain functionals/tasks/widgets due to #371
qin-yu Dec 14, 2024
97e7c04
fix: `plantseg: command not found` for dev env
qin-yu Dec 16, 2024
f14afc1
fix: bioimageio `Tensor` needs `AxisId` in some versions
qin-yu Dec 16, 2024
7a16e04
feat: use `predict_sample_with_blocking`
qin-yu Dec 16, 2024
f43a529
feat: support both `blocksize_parameter` and `input_block_shape`
qin-yu Dec 16, 2024
cf0e69b
feat: support arbitrary network output such as Cellpose
qin-yu Dec 16, 2024
fca44de
fix(ci): fix conda build CI `setuptools` missing etc.
qin-yu Dec 16, 2024
74d644e
fix(docs): missing index page for API
qin-yu Dec 16, 2024
c6dd2a1
ci: inspect why action no.95 and no.96 can't finish
qin-yu Dec 17, 2024
f300890
fix: temp fix for ci hanging on napari GUI
qin-yu Dec 17, 2024
f45d5c0
Merge branch 'master' into qy/unify-prediction
qin-yu Dec 17, 2024
6deb8c1
feat: bioimageio.core prediction functional and task
qin-yu Dec 18, 2024
f4efde4
refactor: improve naming of bioimageio.core output
qin-yu Dec 18, 2024
bff023c
fix: prediction widget is not updated by import
qin-yu Dec 18, 2024
d8eef33
refactor: support more formats of bioimageio axes specs
qin-yu Dec 18, 2024
a01e65c
test: bioimage.io core prediction functional
qin-yu Dec 18, 2024
cfb9c14
fix: ZCYX doesn't throw error but corrected to CZYX
qin-yu Dec 19, 2024
b539c62
test: bioimage.io core prediction task
qin-yu Dec 19, 2024
3873575
fix: fix parts of #373
qin-yu Dec 19, 2024
fabae0a
feat: show both model ID and name for BioImage.IO models
qin-yu Dec 19, 2024
2 changes: 1 addition & 1 deletion .github/workflows/build-and-publish-docs.yml
Expand Up @@ -3,7 +3,7 @@ on:
push:
branches:
- master
- qy/mask-pred
- qy/unify-prediction

permissions:
contents: write
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/build-and-test-package.yml
Expand Up @@ -62,7 +62,7 @@ jobs:
shell: bash -l {0}
run: |
conda activate plant-seg
pytest --cov --cov-report=xml
pytest -s --cov --cov-report=xml
conda deactivate

# Upload Codecov report
Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Expand Up @@ -14,3 +14,6 @@ docs/_build/

# Codecov
.coverage

# macOS
.DS_Store
1 change: 1 addition & 0 deletions conda-recipe/meta.yaml
Expand Up @@ -17,6 +17,7 @@ requirements:
build:
- python
- pip
- setuptools

run:
- python >=3.9
Expand Down
4 changes: 4 additions & 0 deletions docs/chapters/getting_started/contributing.md
Expand Up @@ -17,6 +17,10 @@ To install PlantSeg in development mode, run:
pip install -e . --no-deps
```

## Hierarchical Design of PlantSeg

Please refer to [Python API](../python_api/index.md).

## Coding Style

PlantSeg uses _Ruff_ for linting and formatting. _Ruff_ is compatible with _Black_ for formatting. Ensure you have _Black_ set as the formatter with a line length of 120.
Expand Down
7 changes: 7 additions & 0 deletions docs/chapters/python_api/index.md
@@ -0,0 +1,7 @@
# Hierarchical Design of PlantSeg

PlantSeg is organized into three layers:

1. Functionals (Python API): The foundational layer of PlantSeg, providing its core functionality. This layer can be accessed directly in Python scripts or Jupyter notebooks.
2. Tasks: The intermediate layer of PlantSeg, which encapsulates the functionals to handle resource management and support distributed computing.
3. Napari Widgets: The top layer of PlantSeg, which integrates tasks into user-friendly widgets for easy interaction within graphical interfaces.
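
The layering described above can be illustrated with a small, self-contained sketch. All names here (`invert_functional`, `ImageRecord`, `invert_task`) are hypothetical and are not part of PlantSeg's API. The task only does simple bookkeeping, and the widget layer is left as a comment so that the example runs without napari.

```python
from dataclasses import dataclass, field


# Layer 1, functional: pure computation on raw data, no bookkeeping.
def invert_functional(pixels: list[float]) -> list[float]:
    """Invert intensities assumed to lie in [0, 1]."""
    return [1.0 - p for p in pixels]


# Hypothetical container standing in for an image object that carries metadata.
@dataclass
class ImageRecord:
    name: str
    pixels: list[float]
    history: list[str] = field(default_factory=list)


# Layer 2, task: wraps the functional and handles naming and provenance,
# which the functional itself should not need to know about.
def invert_task(image: ImageRecord) -> ImageRecord:
    return ImageRecord(
        name=f"{image.name}_inverted",
        pixels=invert_functional(image.pixels),
        history=[*image.history, "invert"],
    )


# Layer 3 (a widget) would collect user input and then call `invert_task`.
result = invert_task(ImageRecord(name="raw", pixels=[0.0, 0.25, 1.0]))
print(result.name, result.pixels, result.history)
```

Keeping the functional free of metadata is what lets it be called directly from a script or notebook, as the first list item states.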
46 changes: 46 additions & 0 deletions environment-dev-apple.yaml
@@ -0,0 +1,46 @@
name: plant-seg-dev
channels:
- pytorch
- conda-forge
# `defaults` is optional, unless e.g. `conda-forge` has no cudnn 9.* when `defaults` has.
# `defaults` of Anaconda is not accessible for many non-profit institutes such as EMBL.
# - defaults
dependencies:
- python
# Neural Network and GPU
- pytorch::pytorch
- torchvision
# Bioimage and CV
- tifffile
- h5py
- zarr
- vigra
- python-elf
- python-graphviz
- scikit-image
- bioimageio.core>=0.6.5
# GUI
- pyqt
- napari
# Other
- requests
- pyyaml
- pydantic>2,<2.10 # 2.10 causes problems, see spec-bioimage-io/issues/663
# Test
- pytest
- pytest-qt
- pytest-mock
- requests-mock
# CI/CD
- pre-commit
- bump-my-version
# Docs
- mkdocs-material
- mkdocs-autorefs
- mkdocs-git-revision-date-localized-plugin
- mkdocs-git-committers-plugin-2
- mkdocstrings-python
- pip:
- markdown-exec
# PlantSeg
- -e .
3 changes: 2 additions & 1 deletion mkdocs.yml
Expand Up @@ -4,7 +4,7 @@ site_description: Cell instance aware segmentation in densely packed 3D volumetr
repo_name: kreshuklab/plant-seg
repo_url: https://github.com/kreshuklab/plant-seg
edit_uri: edit/main/docs/
copyright: Copyright &copy; 2019 - 2024 Lorenzo Cerrone, Adrian Wolny, Qin Yu
copyright: Copyright &copy; 2019 - 2025 Lorenzo Cerrone, Adrian Wolny, Qin Yu

theme:
name: material
Expand Down Expand Up @@ -116,6 +116,7 @@ nav:
- Training: chapters/plantseg_models/training.md

- API:
- chapters/python_api/index.md
- tasks:
- plantseg.tasks.io_tasks: chapters/python_api/tasks/io_tasks.md
- plantseg.tasks.dataprocessing_tasks: chapters/python_api/tasks/dataprocessing_tasks.md
Expand Down
5 changes: 4 additions & 1 deletion plantseg/core/image.py
Expand Up @@ -417,7 +417,10 @@ def _check_shape(self, data: np.ndarray, properties: ImageProperties) -> tuple[n
return data[:, 0], properties

elif self.image_layout == ImageLayout.ZCYX:
raise ValueError(f"Image layout {self.image_layout} not supported, should have been converted to CZYX")
logger.warning("Image layout is ZCYX but should have been converted to CZYX. PlantSeg is doing this now.")
properties.image_layout = ImageLayout.CZYX
data = np.moveaxis(data, 0, 1)
return self._check_shape(data, properties)

return data, properties

Expand Down
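
The ZCYX fallback in the hunk above relies on `np.moveaxis(data, 0, 1)`. A minimal sketch of what that call does to the axis order, using an arbitrary array whose shape is chosen only for illustration:

```python
import numpy as np

# Arbitrary example volume in ZCYX order: 4 z-slices, 2 channels, 3x5 pixels.
zcyx = np.arange(4 * 2 * 3 * 5).reshape(4, 2, 3, 5)

# Move axis 0 (Z) to position 1, so the channel axis comes first (CZYX).
czyx = np.moveaxis(zcyx, 0, 1)

print(zcyx.shape, "->", czyx.shape)

# The plane at (z=3, c=1) is now addressed as (c=1, z=3).
same_plane = np.array_equal(czyx[1, 3], zcyx[3, 1])
print(same_plane)
```

`np.moveaxis` returns a view rather than a copy, so the conversion itself does not duplicate the image data.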
110 changes: 78 additions & 32 deletions plantseg/core/zoo.py
@@ -1,5 +1,7 @@
"""Model Zoo Singleton"""

# pylint: disable=C0116,C0103

import json
import logging
from enum import Enum
Expand All @@ -14,7 +16,7 @@
from bioimageio.spec.model.v0_5 import ModelDescr as ModelDescr_v0_5
from bioimageio.spec.utils import download
from pandas import DataFrame, concat
from pydantic import AliasChoices, BaseModel, Field, model_validator
from pydantic import AliasChoices, BaseModel, Field, HttpUrl, model_validator
from torch.nn import Conv2d, Conv3d, MaxPool2d, MaxPool3d, Module

from plantseg import (
Expand All @@ -37,18 +39,15 @@
USER = 'user'


BIOIMAGE_IO_COLLECTION_URL = (
"https://raw.githubusercontent.com/bioimage-io/collection-bioimage-io/gh-pages/collection.json"
)
BIOIMAGE_IO_COLLECTION_URL = "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/collection.json"


class ModelZooRecord(BaseModel):
"""Model Zoo Record"""

name: str
url: Optional[str] = Field(None, validation_alias=AliasChoices('model_url', 'url'))
url: Optional[HttpUrl] = Field(None, validation_alias=AliasChoices('model_url', 'url'))
path: Optional[str] = None
id: Optional[str] = None
description: Optional[str] = None
resolution: Optional[tuple[float, float, float]] = None
dimensionality: Optional[str] = None
Expand All @@ -58,11 +57,25 @@
doi: Optional[str] = None
added_by: Optional[str] = None

# BioImage.IO models specific fields. TODO: unify.
id: Optional[str] = None
name_display: Optional[str] = None
rdf_source: Optional[HttpUrl] = None
supported: Optional[bool] = None

@model_validator(mode='after')
def check_one_id_present(self) -> Self:
"""Check that one of url (zenodo), path (custom/local) or id (bioimage.io) is present"""
if self.url is None and self.path is None and self.id is None:
raise ValueError(f'One of url, path or id must be present: {self}')
raise ValueError(f'One of `url`, `path` or `id` must be present: {self}')

return self

@model_validator(mode='after')
def check_id_fields_present(self) -> Self:
if self.id is not None and (self.name_display is None or self.rdf_source is None or self.supported is None):
raise ValueError(

f'If `id` exists, then `name_display`, `rdf_source` and `supported` must be present: {self}'
)
return self
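
The validation pattern in `ModelZooRecord` can be tried in isolation. The sketch below assumes pydantic v2 and uses a hypothetical, trimmed-down `Record` class that keeps only four of the fields; it is not the class from the diff.

```python
from typing import Optional

from pydantic import AliasChoices, BaseModel, Field, HttpUrl, ValidationError, model_validator


class Record(BaseModel):
    """Hypothetical, trimmed-down stand-in for `ModelZooRecord` (pydantic v2 assumed)."""

    name: str
    # Accept either `model_url` or `url` as the input key.
    url: Optional[HttpUrl] = Field(None, validation_alias=AliasChoices("model_url", "url"))
    path: Optional[str] = None
    id: Optional[str] = None

    @model_validator(mode="after")
    def check_one_source_present(self) -> "Record":
        # Runs after field validation, so `self.url` is already a parsed URL or None.
        if self.url is None and self.path is None and self.id is None:
            raise ValueError("One of `url`, `path` or `id` must be present")
        return self


ok = Record(name="demo", model_url="https://example.org/model.pytorch")
print(str(ok.url))

try:
    Record(name="no-source")
    missing_source_rejected = False
except ValidationError:
    missing_source_rejected = True

try:
    Record(name="bad-url", url="not a url")
    bad_url_rejected = False
except ValidationError:
    bad_url_rejected = True

print(missing_source_rejected, bad_url_rejected)
```

With `HttpUrl` instead of `str`, a malformed URL is rejected when the record is constructed rather than when the download is attempted.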


Expand All @@ -82,7 +95,6 @@
_zoo_custom_dict: dict = {}
_bioimageio_zoo_collection: dict = {}
_bioimageio_zoo_all_model_url_dict: dict = {}
_bioimageio_zoo_plantseg_model_url_dict: dict = {}

path_zoo: Path = PATH_MODEL_ZOO
path_zoo_custom: Path = PATH_MODEL_ZOO_CUSTOM
Expand Down Expand Up @@ -328,13 +340,13 @@
https://bioimage-io.github.io/collection-bioimage-io/rdfs/10.5281/zenodo.8401064/8429203/rdf.yaml
"""

if not self._bioimageio_zoo_all_model_url_dict:
if not hasattr(self, 'models_bioimageio'):
self.refresh_bioimageio_zoo_urls()

if model_id not in self._bioimageio_zoo_all_model_url_dict:
if model_id not in self.models_bioimageio.index:
raise ValueError(f"Model ID {model_id} not found in BioImage.IO Model Zoo")

rdf_url = self._bioimageio_zoo_all_model_url_dict[model_id]
rdf_url = self.models_bioimageio.at[model_id, 'rdf_source']
model_description = load_description(rdf_url)

# Check if description is `ResourceDescr`
Expand All @@ -360,6 +372,8 @@
elif isinstance(model_description, ModelDescr_v0_5): # then it is `ArchitectureDescr` with `callable`
architecture_callable = model_description.weights.pytorch_state_dict.architecture.callable
architecture_kwargs = model_description.weights.pytorch_state_dict.architecture.kwargs
else:
raise ValueError(f"Unsupported model description format: {type(model_description).__name__}")

logger_zoo.info(f"Got {architecture_callable} model with kwargs {architecture_kwargs}.")

# Create model from architecture and kwargs
Expand All @@ -382,25 +396,54 @@
logger_zoo.info(f"Loaded model from BioImage.IO Model Zoo: {model_id}")
return model, model_config, model_weights_path

def _init_bioimageio_zoo_df(self) -> None:
records = []
for _, model in self._bioimageio_zoo_all_model_url_dict.items():
records.append(ModelZooRecord(**model, added_by=Author.BIOIMAGEIO).model_dump())

self.models_bioimageio = DataFrame(
records,
columns=list(ModelZooRecord.model_fields.keys()),
).set_index('id')

def refresh_bioimageio_zoo_urls(self):
"""Initialize the BioImage.IO Model Zoo collection and URL dictionaries.

The BioImage.IO Model Zoo collection is not downloaded during ModelZoo initialization to avoid unnecessary
network requests. This method downloads the collection and extracts the model URLs for all models.

Note that `models_bioimageio` doesn't exist until this method is called.
"""
logger_zoo.info(f"Fetching BioImage.IO Model Zoo collection from {BIOIMAGE_IO_COLLECTION_URL}")
collection_path = Path(pooch.retrieve(BIOIMAGE_IO_COLLECTION_URL, known_hash=None))
with collection_path.open(encoding='utf-8') as f:
collection = json.load(f)

def get_id(entry):
return entry["id"] if "nickname" not in entry else entry["nickname"]

models = [entry for entry in collection["collection"] if entry["type"] == "model"]
max_nickname_length = max(len(get_id(entry)) for entry in models)

def truncate_name(name, length=100):
return name[:length] + '...' if len(name) > length else name

def build_model_url_dict(filter_func=None):
filtered_models = filter(filter_func, models) if filter_func else models
return {
entry['name']: {
"id": get_id(entry),
"name": entry["name"],
"name_display": f"{get_id(entry):<{max_nickname_length}}: {truncate_name(entry['name'])}",
"rdf_source": entry["rdf_source"],
"supported": self._is_plantseg_model(entry),
}
for entry in filtered_models
}

self._bioimageio_zoo_collection = collection
self._bioimageio_zoo_all_model_url_dict = {
entry["nickname"]: entry["rdf_source"] for entry in collection["collection"] if entry["type"] == "model"
}
self._bioimageio_zoo_plantseg_model_url_dict = {
entry["nickname"]: entry["rdf_source"]
for entry in collection["collection"]
if entry["type"] == "model" and self._is_plantseg_model(entry)
}
self._bioimageio_zoo_all_model_url_dict = build_model_url_dict()
self._init_bioimageio_zoo_df()
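
The `name_display` string pads each ID to the width of the longest ID and truncates long names. A standalone sketch of that formatting, with made-up entries (the real ones come from the downloaded `collection.json`):

```python
# Made-up collection entries, for illustration only.
entries = [
    {"id": "some-plain-id", "nickname": "short-id", "name": "A 3D U-Net"},
    {"id": "a-much-longer-identifier", "name": "B" * 120},
]


def get_id(entry: dict) -> str:
    # Prefer the nickname when present, otherwise fall back to the id.
    return entry["nickname"] if "nickname" in entry else entry["id"]


def truncate_name(name: str, length: int = 100) -> str:
    return name[:length] + "..." if len(name) > length else name


# Width of the longest identifier, used as the left-aligned padding width.
width = max(len(get_id(e)) for e in entries)
display = [f"{get_id(e):<{width}}: {truncate_name(e['name'])}" for e in entries]
for line in display:
    print(line)
```

The padding only lines up visually when the strings are rendered in a monospaced font.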

def _is_plantseg_model(self, collection_entry: dict) -> bool:
"""Determines if the 'tags' field in a collection entry contains the keyword 'plantseg'."""
Expand All @@ -414,23 +457,26 @@
normalized_tags = ["".join(filter(str.isalnum, tag.lower())) for tag in tags]
return 'plantseg' in normalized_tags
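
The tag check lowercases each tag and drops every non-alphanumeric character before comparing. A standalone sketch of that normalisation (`has_plantseg_tag` is a hypothetical free function, not the method from the diff):

```python
def has_plantseg_tag(tags: list[str]) -> bool:
    """Hypothetical standalone version of the tag normalisation shown above."""
    normalized = ["".join(filter(str.isalnum, tag.lower())) for tag in tags]
    return "plantseg" in normalized


# Spelling variants of the same tag all match after normalisation.
for tags in (["Plant-Seg", "3d"], ["plant_seg"], ["plantseg2"], ["plant", "seg"]):
    print(tags, has_plantseg_tag(tags))
```

The comparison is per tag, so `plantseg2` or a tag split across two entries does not count as a match.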

def get_bioimageio_zoo_plantseg_model_names(self) -> list[str]:
"""Return a list of model names in the BioImage.IO Model Zoo tagged with 'plantseg'."""
if not self._bioimageio_zoo_plantseg_model_url_dict:
def get_bioimageio_zoo_all_model_names(self) -> list[tuple[str, str]]:
"""Return a sorted list of (model display name, model id) for all models in the BioImage.IO Model Zoo."""
if not hasattr(self, 'models_bioimageio'):

self.refresh_bioimageio_zoo_urls()
return sorted(list(self._bioimageio_zoo_plantseg_model_url_dict.keys()))
id_name = self.models_bioimageio[['name_display']]
return sorted([(name, id) for id, name in id_name.itertuples()])


def get_bioimageio_zoo_all_model_names(self) -> list[str]:
"""Return a list of all model names in the BioImage.IO Model Zoo."""
if not self._bioimageio_zoo_all_model_url_dict:
def get_bioimageio_zoo_plantseg_model_names(self) -> list[tuple[str, str]]:
"""Return a sorted list of (model display name, model id) for BioImage.IO Model Zoo models tagged with 'plantseg'."""
if not hasattr(self, 'models_bioimageio'):
self.refresh_bioimageio_zoo_urls()
return sorted(list(self._bioimageio_zoo_all_model_url_dict.keys()))
id_name = self.models_bioimageio[self.models_bioimageio["supported"]][['name_display']]
return sorted([(name, id) for id, name in id_name.itertuples()])

def get_bioimageio_zoo_other_model_names(self) -> list[str]:
"""Return a list of model names in the BioImage.IO Model Zoo not tagged with 'plantseg'."""
return sorted(
list(set(self.get_bioimageio_zoo_all_model_names()) - set(self.get_bioimageio_zoo_plantseg_model_names()))
)
def get_bioimageio_zoo_other_model_names(self) -> list[tuple[str, str]]:
"""Return a sorted list of (model display name, model id) for BioImage.IO Model Zoo models not tagged with 'plantseg'."""
if not hasattr(self, 'models_bioimageio'):
self.refresh_bioimageio_zoo_urls()
id_name = self.models_bioimageio[~self.models_bioimageio["supported"]][['name_display']]
return sorted([(name, id) for id, name in id_name.itertuples()])
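
The three getters share one pattern: filter the table by the boolean `supported` column, then turn `itertuples()` rows into sorted `(display name, id)` pairs. A minimal sketch with invented rows, where `choices` is a hypothetical helper rather than code from the PR:

```python
import pandas as pd

# Invented rows; the real table is built from the BioImage.IO collection.
df = pd.DataFrame(
    [
        {"id": "model-b", "name_display": "model-b: Second", "supported": False},
        {"id": "model-a", "name_display": "model-a: First", "supported": True},
        {"id": "model-c", "name_display": "model-c: Third", "supported": True},
    ]
).set_index("id")


def choices(table: pd.DataFrame) -> list[tuple[str, str]]:
    # itertuples() yields (index, name_display); swap to (name_display, id).
    return sorted((name, model_id) for model_id, name in table[["name_display"]].itertuples())


supported = choices(df[df["supported"]])
other = choices(df[~df["supported"]])
print(supported)
print(other)

# Single-cell lookup by index label, the same access used for `rdf_source` in the diff.
print(df.at["model-a", "name_display"])
```

Note that the tuples put the display string first and the ID second, so sorting orders the entries by what the user sees.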


def _flatten_module(self, module: Module) -> list[Module]:
"""Recursively flatten a PyTorch nn.Module into a list of its elemental layers."""
Expand Down
3 changes: 2 additions & 1 deletion plantseg/functionals/prediction/__init__.py
@@ -1,6 +1,7 @@
from plantseg.functionals.prediction.prediction import unet_prediction
from plantseg.functionals.prediction.prediction import biio_prediction, unet_prediction

# Use __all__ to let type checkers know what is part of the public API.
__all__ = [
"unet_prediction",
"biio_prediction",
]