Merge pull request #338 from kreshuklab/qy/unify-prediction
Add `bioimageio.core` inference for BioImage.IO Model Zoo models
qin-yu authored Dec 20, 2024
2 parents f7e8e4b + fabae0a commit f22e0c1
Showing 26 changed files with 470 additions and 81 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build-and-publish-docs.yml
@@ -3,7 +3,7 @@ on:
push:
branches:
- master
- qy/mask-pred
- qy/unify-prediction

permissions:
contents: write
2 changes: 1 addition & 1 deletion .github/workflows/build-and-test-package.yml
@@ -62,7 +62,7 @@ jobs:
shell: bash -l {0}
run: |
conda activate plant-seg
pytest --cov --cov-report=xml
pytest -s --cov --cov-report=xml
conda deactivate
# Upload Codecov report
3 changes: 3 additions & 0 deletions .gitignore
@@ -14,3 +14,6 @@ docs/_build/

# Codecov
.coverage

# macOS
.DS_Store
1 change: 1 addition & 0 deletions conda-recipe/meta.yaml
@@ -17,6 +17,7 @@ requirements:
build:
- python
- pip
- setuptools

run:
- python >=3.9
4 changes: 4 additions & 0 deletions docs/chapters/getting_started/contributing.md
@@ -17,6 +17,10 @@ To install PlantSeg in development mode, run:
pip install -e . --no-deps
```

## Hierarchical Design of PlantSeg

Please refer to [Python API](../python_api/index.md).

## Coding Style

PlantSeg uses _Ruff_ for linting and formatting. _Ruff_ is compatible with _Black_ for formatting. Ensure you have _Black_ set as the formatter with a line length of 120.
7 changes: 7 additions & 0 deletions docs/chapters/python_api/index.md
@@ -0,0 +1,7 @@
# Hierarchical Design of PlantSeg

PlantSeg is organized into three layers:

1. Functionals (Python API): The foundational layer of PlantSeg, providing its core functionality. This layer can be accessed directly in Python scripts or Jupyter notebooks.
2. Tasks: The intermediate layer of PlantSeg, which encapsulates the functionals to handle resource management and support distributed computing.
3. Napari Widgets: The top layer of PlantSeg, which integrates tasks into user-friendly widgets for easy interaction within graphical interfaces.
46 changes: 46 additions & 0 deletions environment-dev-apple.yaml
@@ -0,0 +1,46 @@
name: plant-seg-dev
channels:
- pytorch
- conda-forge
# `defaults` is optional, unless e.g. `conda-forge` has no cudnn 9.* when `defaults` has.
# `defaults` of Anaconda is not accessible for many non-profit institutes such as EMBL.
# - defaults
dependencies:
- python
# Neural Network and GPU
- pytorch::pytorch
- torchvision
# Bioimage and CV
- tifffile
- h5py
- zarr
- vigra
- python-elf
- python-graphviz
- scikit-image
- bioimageio.core>=0.6.5
# GUI
- pyqt
- napari
# Other
- requests
- pyyaml
- pydantic>2,<2.10 # 2.10 cause problem spec-bioimage-io/issues/663
# Test
- pytest
- pytest-qt
- pytest-mock
- requests-mock
# CI/CD
- pre-commit
- bump-my-version
# Docs
- mkdocs-material
- mkdocs-autorefs
- mkdocs-git-revision-date-localized-plugin
- mkdocs-git-committers-plugin-2
- mkdocstrings-python
- pip:
- markdown-exec
# PlantSeg
- -e .
3 changes: 2 additions & 1 deletion mkdocs.yml
@@ -4,7 +4,7 @@ site_description: Cell instance aware segmentation in densely packed 3D volumetr
repo_name: kreshuklab/plant-seg
repo_url: https://github.com/kreshuklab/plant-seg
edit_uri: edit/main/docs/
copyright: Copyright &copy; 2019 - 2024 Lorenzo Cerrone, Adrian Wolny, Qin Yu
copyright: Copyright &copy; 2019 - 2025 Lorenzo Cerrone, Adrian Wolny, Qin Yu

theme:
name: material
@@ -116,6 +116,7 @@ nav:
- Training: chapters/plantseg_models/training.md

- API:
- chapters/python_api/index.md
- tasks:
- plantseg.tasks.io_tasks: chapters/python_api/tasks/io_tasks.md
- plantseg.tasks.dataprocessing_tasks: chapters/python_api/tasks/dataprocessing_tasks.md
5 changes: 4 additions & 1 deletion plantseg/core/image.py
@@ -417,7 +417,10 @@ def _check_shape(self, data: np.ndarray, properties: ImageProperties) -> tuple[n
return data[:, 0], properties

elif self.image_layout == ImageLayout.ZCYX:
raise ValueError(f"Image layout {self.image_layout} not supported, should have been converted to CZYX")
logger.warning("Image layout is ZCYX but should have been converted to CZYX. PlantSeg is doing this now.")
properties.image_layout = ImageLayout.CZYX
data = np.moveaxis(data, 0, 1)
return self._check_shape(data, properties)

return data, properties
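
The new ZCYX branch swaps the first two axes so that the channel axis comes first, then re-runs the shape check. A minimal sketch of that axis swap in isolation, assuming a made-up volume with 5 z-slices and 2 channels (the shape and the variable names are illustrative, not taken from PlantSeg):

```python
import numpy as np

# Hypothetical ZCYX volume: 5 z-slices, 2 channels, 4x3 pixels per plane.
zcyx = np.arange(5 * 2 * 4 * 3).reshape(5, 2, 4, 3)

# Move axis 0 (Z) to position 1, which brings C to the front: ZCYX -> CZYX.
czyx = np.moveaxis(zcyx, 0, 1)

# Channel 1 of z-slice 3 is the same plane in both layouts.
same_plane = bool((czyx[1, 3] == zcyx[3, 1]).all())

print(czyx.shape, same_plane)
```

`np.moveaxis` returns a view rather than a copy, so the conversion itself does not duplicate the image data.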

110 changes: 78 additions & 32 deletions plantseg/core/zoo.py
@@ -1,5 +1,7 @@
"""Model Zoo Singleton"""

# pylint: disable=C0116,C0103

import json
import logging
from enum import Enum
@@ -14,7 +16,7 @@
from bioimageio.spec.model.v0_5 import ModelDescr as ModelDescr_v0_5
from bioimageio.spec.utils import download
from pandas import DataFrame, concat
from pydantic import AliasChoices, BaseModel, Field, model_validator
from pydantic import AliasChoices, BaseModel, Field, HttpUrl, model_validator
from torch.nn import Conv2d, Conv3d, MaxPool2d, MaxPool3d, Module

from plantseg import (
@@ -37,18 +39,15 @@ class Author(str, Enum):
USER = 'user'


BIOIMAGE_IO_COLLECTION_URL = (
"https://raw.githubusercontent.com/bioimage-io/collection-bioimage-io/gh-pages/collection.json"
)
BIOIMAGE_IO_COLLECTION_URL = "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/collection.json"


class ModelZooRecord(BaseModel):
"""Model Zoo Record"""

name: str
url: Optional[str] = Field(None, validation_alias=AliasChoices('model_url', 'url'))
url: Optional[HttpUrl] = Field(None, validation_alias=AliasChoices('model_url', 'url'))
path: Optional[str] = None
id: Optional[str] = None
description: Optional[str] = None
resolution: Optional[tuple[float, float, float]] = None
dimensionality: Optional[str] = None
@@ -58,11 +57,25 @@ class ModelZooRecord(BaseModel):
doi: Optional[str] = None
added_by: Optional[str] = None

# BioImage.IO models specific fields. TODO: unify.
id: Optional[str] = None
name_display: Optional[str] = None
rdf_source: Optional[HttpUrl] = None
supported: Optional[bool] = None

@model_validator(mode='after')
def check_one_id_present(self) -> Self:
"""Check that one of url (zenodo), path (custom/local) or id (bioimage.io) is present"""
if self.url is None and self.path is None and self.id is None:
raise ValueError(f'One of url, path or id must be present: {self}')
raise ValueError(f'One of `url`, `path` or `id` must be present: {self}')
return self

@model_validator(mode='after')
def check_id_fields_present(self) -> Self:
if self.id is not None and (self.name_display is None or self.rdf_source is None or self.supported is None):
raise ValueError(
f'If `id` exists, then `name_display`, `rdf_source` and `supported` must be present: {self}'
)
return self
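
The two validators above encode two rules: a record needs at least one of `url`, `path` or `id`, and a record with an `id` must also carry `name_display`, `rdf_source` and `supported`. The sketch below restates those rules with a standard-library dataclass instead of pydantic, purely so that it runs without third-party packages. `RecordSketch` and `is_valid` are hypothetical names, and the field types are simplified to plain strings.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class RecordSketch:
    """Hypothetical stdlib stand-in for ModelZooRecord; not the PlantSeg class."""

    name: str
    url: Optional[str] = None
    path: Optional[str] = None
    id: Optional[str] = None
    name_display: Optional[str] = None
    rdf_source: Optional[str] = None
    supported: Optional[bool] = None

    def __post_init__(self) -> None:
        # Rule 1: at least one locator (url, path or id) must be given.
        if self.url is None and self.path is None and self.id is None:
            raise ValueError("One of `url`, `path` or `id` must be present")
        # Rule 2: an `id` implies the three BioImage.IO-specific fields.
        extras = (self.name_display, self.rdf_source, self.supported)
        if self.id is not None and any(value is None for value in extras):
            raise ValueError("`id` requires `name_display`, `rdf_source` and `supported`")


def is_valid(**fields) -> bool:
    """Return True if the fields pass both rules, False otherwise."""
    try:
        RecordSketch(**fields)
    except ValueError:
        return False
    return True


print(is_valid(name="local model", path="/tmp/model"))  # path alone satisfies rule 1
print(is_valid(name="zoo model", id="some-id"))         # id without the extra fields
```

Note that `supported=False` is a legitimate value, which is why the check compares against `None` rather than testing truthiness.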


@@ -82,7 +95,6 @@ class ModelZoo:
_zoo_custom_dict: dict = {}
_bioimageio_zoo_collection: dict = {}
_bioimageio_zoo_all_model_url_dict: dict = {}
_bioimageio_zoo_plantseg_model_url_dict: dict = {}

path_zoo: Path = PATH_MODEL_ZOO
path_zoo_custom: Path = PATH_MODEL_ZOO_CUSTOM
@@ -328,13 +340,13 @@ def get_model_by_id(self, model_id: str):
https://bioimage-io.github.io/collection-bioimage-io/rdfs/10.5281/zenodo.8401064/8429203/rdf.yaml
"""

if not self._bioimageio_zoo_all_model_url_dict:
if not hasattr(self, 'models_bioimageio'):
self.refresh_bioimageio_zoo_urls()

if model_id not in self._bioimageio_zoo_all_model_url_dict:
if model_id not in self.models_bioimageio.index:
raise ValueError(f"Model ID {model_id} not found in BioImage.IO Model Zoo")

rdf_url = self._bioimageio_zoo_all_model_url_dict[model_id]
rdf_url = self.models_bioimageio.at[model_id, 'rdf_source']
model_description = load_description(rdf_url)

# Check if description is `ResourceDescr`
@@ -360,6 +372,8 @@ def get_model_by_id(self, model_id: str):
elif isinstance(model_description, ModelDescr_v0_5): # then it is `ArchitectureDescr` with `callable`
architecture_callable = model_description.weights.pytorch_state_dict.architecture.callable
architecture_kwargs = model_description.weights.pytorch_state_dict.architecture.kwargs
else:
raise ValueError(f"Unsupported model description format: {type(model_description).__name__}")
logger_zoo.info(f"Got {architecture_callable} model with kwargs {architecture_kwargs}.")

# Create model from architecture and kwargs
@@ -382,25 +396,54 @@ def get_model_by_id(self, model_id: str):
logger_zoo.info(f"Loaded model from BioImage.IO Model Zoo: {model_id}")
return model, model_config, model_weights_path

def _init_bioimageio_zoo_df(self) -> None:
records = []
for _, model in self._bioimageio_zoo_all_model_url_dict.items():
records.append(ModelZooRecord(**model, added_by=Author.BIOIMAGEIO).model_dump())

self.models_bioimageio = DataFrame(
records,
columns=list(ModelZooRecord.model_fields.keys()),
).set_index('id')

def refresh_bioimageio_zoo_urls(self):
"""Initialize the BioImage.IO Model Zoo collection and URL dictionaries.
The BioImage.IO Model Zoo collection is not downloaded during ModelZoo initialization to avoid unnecessary
network requests. This method downloads the collection and extracts the model URLs for all models.
Note that `models_bioimageio` doesn't exist until this method is called.
"""
logger_zoo.info(f"Fetching BioImage.IO Model Zoo collection from {BIOIMAGE_IO_COLLECTION_URL}")
collection_path = Path(pooch.retrieve(BIOIMAGE_IO_COLLECTION_URL, known_hash=None))
with collection_path.open(encoding='utf-8') as f:
collection = json.load(f)

def get_id(entry):
return entry["id"] if "nickname" not in entry else entry["nickname"]

models = [entry for entry in collection["collection"] if entry["type"] == "model"]
max_nickname_length = max(len(get_id(entry)) for entry in models)

def truncate_name(name, length=100):
return name[:length] + '...' if len(name) > length else name

def build_model_url_dict(filter_func=None):
filtered_models = filter(filter_func, models) if filter_func else models
return {
entry['name']: {
"id": get_id(entry),
"name": entry["name"],
"name_display": f"{get_id(entry):<{max_nickname_length}}: {truncate_name(entry['name'])}",
"rdf_source": entry["rdf_source"],
"supported": self._is_plantseg_model(entry),
}
for entry in filtered_models
}

self._bioimageio_zoo_collection = collection
self._bioimageio_zoo_all_model_url_dict = {
entry["nickname"]: entry["rdf_source"] for entry in collection["collection"] if entry["type"] == "model"
}
self._bioimageio_zoo_plantseg_model_url_dict = {
entry["nickname"]: entry["rdf_source"]
for entry in collection["collection"]
if entry["type"] == "model" and self._is_plantseg_model(entry)
}
self._bioimageio_zoo_all_model_url_dict = build_model_url_dict()
self._init_bioimageio_zoo_df()
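
The `name_display` string built in `refresh_bioimageio_zoo_urls` left-pads each model ID to the width of the longest ID and truncates long names. A small sketch of that formatting, reusing the two helpers from the diff on two invented collection entries (the IDs, nickname and names are hypothetical, not real BioImage.IO records):

```python
def get_id(entry: dict) -> str:
    # Prefer the nickname when the entry has one, otherwise fall back to the id.
    return entry["id"] if "nickname" not in entry else entry["nickname"]


def truncate_name(name: str, length: int = 100) -> str:
    # Keep at most `length` characters and mark the cut with an ellipsis.
    return name[:length] + "..." if len(name) > length else name


# Hypothetical entries standing in for collection["collection"] models.
models = [
    {"id": "10.5281/zenodo.0000000", "nickname": "hypothetical-fox", "name": "Example 3D UNet"},
    {"id": "hypothetical-id", "name": "x" * 120},
]

# Width of the longest ID, used to align the colon across all rows.
width = max(len(get_id(entry)) for entry in models)

name_display = [f"{get_id(entry):<{width}}: {truncate_name(entry['name'])}" for entry in models]

for line in name_display:
    print(line)
```

The nested format spec `{...:<{width}}` left-aligns the ID in a field of `width` characters, so the display names line up when shown in a monospaced list.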

def _is_plantseg_model(self, collection_entry: dict) -> bool:
"""Determines if the 'tags' field in a collection entry contains the keyword 'plantseg'."""
@@ -414,23 +457,26 @@ def _is_plantseg_model(self, collection_entry: dict) -> bool:
normalized_tags = ["".join(filter(str.isalnum, tag.lower())) for tag in tags]
return 'plantseg' in normalized_tags
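
The tag check lowercases each tag and drops every non-alphanumeric character before comparing, so spelling variants of the tag still match. A standalone sketch of the two lines above, where `has_plantseg_tag` is a hypothetical helper name and the tag lists are invented:

```python
def has_plantseg_tag(tags: list[str]) -> bool:
    # Lowercase each tag and keep only letters and digits, e.g. "Plant-Seg" -> "plantseg".
    normalized_tags = ["".join(filter(str.isalnum, tag.lower())) for tag in tags]
    # The match is on the whole normalized tag, not on a substring.
    return "plantseg" in normalized_tags


print(has_plantseg_tag(["unet", "Plant-Seg"]))
print(has_plantseg_tag(["plantseg-model"]))
```

Because the comparison is against whole tags, a tag such as `plantseg-model` normalizes to `plantsegmodel` and does not count as a match.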

def get_bioimageio_zoo_plantseg_model_names(self) -> list[str]:
"""Return a list of model names in the BioImage.IO Model Zoo tagged with 'plantseg'."""
if not self._bioimageio_zoo_plantseg_model_url_dict:
def get_bioimageio_zoo_all_model_names(self) -> list[tuple[str, str]]:
"""Return a list of (model id, model display name) in the BioImage.IO Model Zoo."""
if not hasattr(self, 'models_bioimageio'):
self.refresh_bioimageio_zoo_urls()
return sorted(list(self._bioimageio_zoo_plantseg_model_url_dict.keys()))
id_name = self.models_bioimageio[['name_display']]
return sorted([(name, id) for id, name in id_name.itertuples()])

def get_bioimageio_zoo_all_model_names(self) -> list[str]:
"""Return a list of all model names in the BioImage.IO Model Zoo."""
if not self._bioimageio_zoo_all_model_url_dict:
def get_bioimageio_zoo_plantseg_model_names(self) -> list[tuple[str, str]]:
"""Return a list of (model id, model display name) in the BioImage.IO Model Zoo tagged with 'plantseg'."""
if not hasattr(self, 'models_bioimageio'):
self.refresh_bioimageio_zoo_urls()
return sorted(list(self._bioimageio_zoo_all_model_url_dict.keys()))
id_name = self.models_bioimageio[self.models_bioimageio["supported"]][['name_display']]
return sorted([(name, id) for id, name in id_name.itertuples()])

def get_bioimageio_zoo_other_model_names(self) -> list[str]:
"""Return a list of model names in the BioImage.IO Model Zoo not tagged with 'plantseg'."""
return sorted(
list(set(self.get_bioimageio_zoo_all_model_names()) - set(self.get_bioimageio_zoo_plantseg_model_names()))
)
def get_bioimageio_zoo_other_model_names(self) -> list[tuple[str, str]]:
"""Return a list of (model id, model display name) in the BioImage.IO Model Zoo not tagged with 'plantseg'."""
if not hasattr(self, 'models_bioimageio'):
self.refresh_bioimageio_zoo_urls()
id_name = self.models_bioimageio[~self.models_bioimageio["supported"]][['name_display']]
return sorted([(name, id) for id, name in id_name.itertuples()])

def _flatten_module(self, module: Module) -> list[Module]:
"""Recursively flatten a PyTorch nn.Module into a list of its elemental layers."""
3 changes: 2 additions & 1 deletion plantseg/functionals/prediction/__init__.py
@@ -1,6 +1,7 @@
from plantseg.functionals.prediction.prediction import unet_prediction
from plantseg.functionals.prediction.prediction import biio_prediction, unet_prediction

# Use __all__ to let type checkers know what is part of the public API.
__all__ = [
"unet_prediction",
"biio_prediction",
]