diff --git a/.github/workflows/build-and-publish-docs.yml b/.github/workflows/build-and-publish-docs.yml index 1773007a..d4a33122 100644 --- a/.github/workflows/build-and-publish-docs.yml +++ b/.github/workflows/build-and-publish-docs.yml @@ -3,7 +3,7 @@ on: push: branches: - master - - qy/mask-pred + - qy/unify-prediction permissions: contents: write diff --git a/.github/workflows/build-and-test-package.yml b/.github/workflows/build-and-test-package.yml index bd9c8b14..43a0bd59 100644 --- a/.github/workflows/build-and-test-package.yml +++ b/.github/workflows/build-and-test-package.yml @@ -62,7 +62,7 @@ jobs: shell: bash -l {0} run: | conda activate plant-seg - pytest --cov --cov-report=xml + pytest -s --cov --cov-report=xml conda deactivate # Upload Codecov report diff --git a/.gitignore b/.gitignore index f5f1aa4d..0c3c4204 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,6 @@ docs/_build/ # Codecov .coverage + +# macOS +.DS_Store diff --git a/conda-recipe/meta.yaml b/conda-recipe/meta.yaml index 8bf014fe..dfc1fca6 100644 --- a/conda-recipe/meta.yaml +++ b/conda-recipe/meta.yaml @@ -17,6 +17,7 @@ requirements: build: - python - pip + - setuptools run: - python >=3.9 diff --git a/docs/chapters/getting_started/contributing.md b/docs/chapters/getting_started/contributing.md index 7a31fa46..02e00378 100644 --- a/docs/chapters/getting_started/contributing.md +++ b/docs/chapters/getting_started/contributing.md @@ -17,6 +17,10 @@ To install PlantSeg in development mode, run: pip install -e . --no-deps ``` +## Hierarchical Design of PlantSeg + +Please refer to [Python API](../python_api/index.md). + ## Coding Style PlantSeg uses _Ruff_ for linting and formatting. _Ruff_ is compatible with _Black_ for formatting. Ensure you have _Black_ set as the formatter with a line length of 120. diff --git a/docs/chapters/python_api/index.md b/docs/chapters/python_api/index.md new file mode 100644 index 00000000..d65dabbe --- /dev/null +++ b/docs/chapters/python_api/index.md @@ -0,0 +1,7 @@ +# Hierarchical Design of PlantSeg + +PlantSeg is organized into three layers: + + 1. Functionals (Python API): The foundational layer of PlantSeg, providing its core functionality. This layer can be accessed directly in Python scripts or Jupyter notebooks. + 2. Tasks: The intermediate layer of PlantSeg, which encapsulates the functionals to handle resource management and support distributed computing. + 3. Napari Widgets: The top layer of PlantSeg, which integrates tasks into user-friendly widgets for easy interaction within graphical interfaces. diff --git a/environment-dev-apple.yaml b/environment-dev-apple.yaml new file mode 100755 index 00000000..4e99a587 --- /dev/null +++ b/environment-dev-apple.yaml @@ -0,0 +1,46 @@ +name: plant-seg-dev +channels: + - pytorch + - conda-forge + # `defaults` is optional, unless e.g. `conda-forge` has no cudnn 9.* when `defaults` has. + # `defaults` of Anaconda is not accessible for many non-profit institutes such as EMBL. + # - defaults +dependencies: + - python + # Neural Network and GPU + - pytorch::pytorch + - torchvision + # Bioimage and CV + - tifffile + - h5py + - zarr + - vigra + - python-elf + - python-graphviz + - scikit-image + - bioimageio.core>=0.6.5 + # GUI + - pyqt + - napari + # Other + - requests + - pyyaml + - pydantic>2,<2.10 # 2.10 cause problem spec-bioimage-io/issues/663 + # Test + - pytest + - pytest-qt + - pytest-mock + - requests-mock + # CI/CD + - pre-commit + - bump-my-version + # Docs + - mkdocs-material + - mkdocs-autorefs + - mkdocs-git-revision-date-localized-plugin + - mkdocs-git-committers-plugin-2 + - mkdocstrings-python + - pip: + - markdown-exec + # PlantSeg + - -e . diff --git a/mkdocs.yml b/mkdocs.yml index c9ae3d30..f2438fd2 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -4,7 +4,7 @@ site_description: Cell instance aware segmentation in densely packed 3D volumetr repo_name: kreshuklab/plant-seg repo_url: https://github.com/kreshuklab/plant-seg edit_uri: edit/main/docs/ -copyright: Copyright © 2019 - 2024 Lorenzo Cerrone, Adrian Wolny, Qin Yu +copyright: Copyright © 2019 - 2025 Lorenzo Cerrone, Adrian Wolny, Qin Yu theme: name: material @@ -116,6 +116,7 @@ nav: - Training: chapters/plantseg_models/training.md - API: + - chapters/python_api/index.md - tasks: - plantseg.tasks.io_tasks: chapters/python_api/tasks/io_tasks.md - plantseg.tasks.dataprocessing_tasks: chapters/python_api/tasks/dataprocessing_tasks.md diff --git a/plantseg/core/image.py b/plantseg/core/image.py index 4c60eab5..b93cd13d 100644 --- a/plantseg/core/image.py +++ b/plantseg/core/image.py @@ -417,7 +417,10 @@ def _check_shape(self, data: np.ndarray, properties: ImageProperties) -> tuple[n return data[:, 0], properties elif self.image_layout == ImageLayout.ZCYX: - raise ValueError(f"Image layout {self.image_layout} not supported, should have been converted to CZYX") + logger.warning("Image layout is ZCYX but should have been converted to CZYX. PlantSeg is doing this now.") + properties.image_layout = ImageLayout.CZYX + data = np.moveaxis(data, 0, 1) + return self._check_shape(data, properties) return data, properties diff --git a/plantseg/core/zoo.py b/plantseg/core/zoo.py index fb2d1fc1..554c3588 100644 --- a/plantseg/core/zoo.py +++ b/plantseg/core/zoo.py @@ -1,5 +1,7 @@ """Model Zoo Singleton""" +# pylint: disable=C0116,C0103 + import json import logging from enum import Enum @@ -14,7 +16,7 @@ from bioimageio.spec.model.v0_5 import ModelDescr as ModelDescr_v0_5 from bioimageio.spec.utils import download from pandas import DataFrame, concat -from pydantic import AliasChoices, BaseModel, Field, model_validator +from pydantic import AliasChoices, BaseModel, Field, HttpUrl, model_validator from torch.nn import Conv2d, Conv3d, MaxPool2d, MaxPool3d, Module from plantseg import ( @@ -37,18 +39,15 @@ class Author(str, Enum): USER = 'user' -BIOIMAGE_IO_COLLECTION_URL = ( - "https://raw.githubusercontent.com/bioimage-io/collection-bioimage-io/gh-pages/collection.json" -) +BIOIMAGE_IO_COLLECTION_URL = "https://uk1s3.embassy.ebi.ac.uk/public-datasets/bioimage.io/collection.json" class ModelZooRecord(BaseModel): """Model Zoo Record""" name: str - url: Optional[str] = Field(None, validation_alias=AliasChoices('model_url', 'url')) + url: Optional[HttpUrl] = Field(None, validation_alias=AliasChoices('model_url', 'url')) path: Optional[str] = None - id: Optional[str] = None description: Optional[str] = None resolution: Optional[tuple[float, float, float]] = None dimensionality: Optional[str] = None @@ -58,11 +57,25 @@ class ModelZooRecord(BaseModel): doi: Optional[str] = None added_by: Optional[str] = None + # BioImage.IO models specific fields. TODO: unify. + id: Optional[str] = None + name_display: Optional[str] = None + rdf_source: Optional[HttpUrl] = None + supported: Optional[bool] = None + @model_validator(mode='after') def check_one_id_present(self) -> Self: """Check that one of url (zenodo), path (custom/local) or id (bioimage.io) is present""" if self.url is None and self.path is None and self.id is None: - raise ValueError(f'One of url, path or id must be present: {self}') + raise ValueError(f'One of `url`, `path` or `id` must be present: {self}') + return self + + @model_validator(mode='after') + def check_id_fields_present(self) -> Self: + if self.id is not None and (self.name_display is None or self.rdf_source is None or self.supported is None): + raise ValueError( + f'If `id` exists, then `name_display`, `rdf_source` and `supported` must be present: {self}' + ) return self @@ -82,7 +95,6 @@ class ModelZoo: _zoo_custom_dict: dict = {} _bioimageio_zoo_collection: dict = {} _bioimageio_zoo_all_model_url_dict: dict = {} - _bioimageio_zoo_plantseg_model_url_dict: dict = {} path_zoo: Path = PATH_MODEL_ZOO path_zoo_custom: Path = PATH_MODEL_ZOO_CUSTOM @@ -328,13 +340,13 @@ def get_model_by_id(self, model_id: str): https://bioimage-io.github.io/collection-bioimage-io/rdfs/10.5281/zenodo.8401064/8429203/rdf.yaml """ - if not self._bioimageio_zoo_all_model_url_dict: + if not hasattr(self, 'models_bioimageio'): self.refresh_bioimageio_zoo_urls() - if model_id not in self._bioimageio_zoo_all_model_url_dict: + if model_id not in self.models_bioimageio.index: raise ValueError(f"Model ID {model_id} not found in BioImage.IO Model Zoo") - rdf_url = self._bioimageio_zoo_all_model_url_dict[model_id] + rdf_url = self.models_bioimageio.at[model_id, 'rdf_source'] model_description = load_description(rdf_url) # Check if description is `ResourceDescr` @@ -360,6 +372,8 @@ def get_model_by_id(self, model_id: str): elif isinstance(model_description, ModelDescr_v0_5): # then it is `ArchitectureDescr` with `callable` architecture_callable = model_description.weights.pytorch_state_dict.architecture.callable architecture_kwargs = model_description.weights.pytorch_state_dict.architecture.kwargs + else: + raise ValueError(f"Unsupported model description format: {type(model_description).__name__}") logger_zoo.info(f"Got {architecture_callable} model with kwargs {architecture_kwargs}.") # Create model from architecture and kwargs @@ -382,25 +396,54 @@ def get_model_by_id(self, model_id: str): logger_zoo.info(f"Loaded model from BioImage.IO Model Zoo: {model_id}") return model, model_config, model_weights_path + def _init_bioimageio_zoo_df(self) -> None: + records = [] + for _, model in self._bioimageio_zoo_all_model_url_dict.items(): + records.append(ModelZooRecord(**model, added_by=Author.BIOIMAGEIO).model_dump()) + + self.models_bioimageio = DataFrame( + records, + columns=list(ModelZooRecord.model_fields.keys()), + ).set_index('id') + def refresh_bioimageio_zoo_urls(self): """Initialize the BioImage.IO Model Zoo collection and URL dictionaries. The BioImage.IO Model Zoo collection is not downloaded during ModelZoo initialization to avoid unnecessary network requests. This method downloads the collection and extracts the model URLs for all models. + + Note that `models_bioimageio` doesn't exist until this method is called. """ logger_zoo.info(f"Fetching BioImage.IO Model Zoo collection from {BIOIMAGE_IO_COLLECTION_URL}") collection_path = Path(pooch.retrieve(BIOIMAGE_IO_COLLECTION_URL, known_hash=None)) with collection_path.open(encoding='utf-8') as f: collection = json.load(f) + + def get_id(entry): + return entry["id"] if "nickname" not in entry else entry["nickname"] + + models = [entry for entry in collection["collection"] if entry["type"] == "model"] + max_nickname_length = max(len(get_id(entry)) for entry in models) + + def truncate_name(name, length=100): + return name[:length] + '...' if len(name) > length else name + + def build_model_url_dict(filter_func=None): + filtered_models = filter(filter_func, models) if filter_func else models + return { + entry['name']: { + "id": get_id(entry), + "name": entry["name"], + "name_display": f"{get_id(entry):<{max_nickname_length}}: {truncate_name(entry['name'])}", + "rdf_source": entry["rdf_source"], + "supported": self._is_plantseg_model(entry), + } + for entry in filtered_models + } + self._bioimageio_zoo_collection = collection - self._bioimageio_zoo_all_model_url_dict = { - entry["nickname"]: entry["rdf_source"] for entry in collection["collection"] if entry["type"] == "model" - } - self._bioimageio_zoo_plantseg_model_url_dict = { - entry["nickname"]: entry["rdf_source"] - for entry in collection["collection"] - if entry["type"] == "model" and self._is_plantseg_model(entry) - } + self._bioimageio_zoo_all_model_url_dict = build_model_url_dict() + self._init_bioimageio_zoo_df() def _is_plantseg_model(self, collection_entry: dict) -> bool: """Determines if the 'tags' field in a collection entry contains the keyword 'plantseg'.""" @@ -414,23 +457,26 @@ def _is_plantseg_model(self, collection_entry: dict) -> bool: normalized_tags = ["".join(filter(str.isalnum, tag.lower())) for tag in tags] return 'plantseg' in normalized_tags - def get_bioimageio_zoo_plantseg_model_names(self) -> list[str]: - """Return a list of model names in the BioImage.IO Model Zoo tagged with 'plantseg'.""" - if not self._bioimageio_zoo_plantseg_model_url_dict: + def get_bioimageio_zoo_all_model_names(self) -> list[tuple[str, str]]: + """Return a list of (model id, model display name) in the BioImage.IO Model Zoo.""" + if not hasattr(self, 'models_bioimageio'): self.refresh_bioimageio_zoo_urls() - return sorted(list(self._bioimageio_zoo_plantseg_model_url_dict.keys())) + id_name = self.models_bioimageio[['name_display']] + return sorted([(name, id) for id, name in id_name.itertuples()]) - def get_bioimageio_zoo_all_model_names(self) -> list[str]: - """Return a list of all model names in the BioImage.IO Model Zoo.""" - if not self._bioimageio_zoo_all_model_url_dict: + def get_bioimageio_zoo_plantseg_model_names(self) -> list[tuple[str, str]]: + """Return a list of (model id, model display name) in the BioImage.IO Model Zoo tagged with 'plantseg'.""" + if not hasattr(self, 'models_bioimageio'): self.refresh_bioimageio_zoo_urls() - return sorted(list(self._bioimageio_zoo_all_model_url_dict.keys())) + id_name = self.models_bioimageio[self.models_bioimageio["supported"]][['name_display']] + return sorted([(name, id) for id, name in id_name.itertuples()]) - def get_bioimageio_zoo_other_model_names(self) -> list[str]: - """Return a list of model names in the BioImage.IO Model Zoo not tagged with 'plantseg'.""" - return sorted( - list(set(self.get_bioimageio_zoo_all_model_names()) - set(self.get_bioimageio_zoo_plantseg_model_names())) - ) + def get_bioimageio_zoo_other_model_names(self) -> list[tuple[str, str]]: + """Return a list of (model id, model display name) in the BioImage.IO Model Zoo not tagged with 'plantseg'.""" + if not hasattr(self, 'models_bioimageio'): + self.refresh_bioimageio_zoo_urls() + id_name = self.models_bioimageio[~self.models_bioimageio["supported"]][['name_display']] + return sorted([(name, id) for id, name in id_name.itertuples()]) def _flatten_module(self, module: Module) -> list[Module]: """Recursively flatten a PyTorch nn.Module into a list of its elemental layers.""" diff --git a/plantseg/functionals/prediction/__init__.py b/plantseg/functionals/prediction/__init__.py index 8ce10d6b..956bd936 100644 --- a/plantseg/functionals/prediction/__init__.py +++ b/plantseg/functionals/prediction/__init__.py @@ -1,6 +1,7 @@ -from plantseg.functionals.prediction.prediction import unet_prediction +from plantseg.functionals.prediction.prediction import biio_prediction, unet_prediction # Use __all__ to let type checkers know what is part of the public API. __all__ = [ "unet_prediction", + "biio_prediction", ] diff --git a/plantseg/functionals/prediction/prediction.py b/plantseg/functionals/prediction/prediction.py index c8ae5e81..afa5d6d1 100644 --- a/plantseg/functionals/prediction/prediction.py +++ b/plantseg/functionals/prediction/prediction.py @@ -1,8 +1,16 @@ import logging from pathlib import Path +from typing import assert_never import numpy as np import torch +from bioimageio.core.axis import AxisId +from bioimageio.core.prediction import predict +from bioimageio.core.sample import Sample +from bioimageio.core.tensor import Tensor +from bioimageio.spec import load_model_description +from bioimageio.spec.model import v0_4, v0_5 +from bioimageio.spec.model.v0_5 import TensorId from plantseg.core.zoo import model_zoo from plantseg.functionals.dataprocessing.dataprocessing import ImageLayout, fix_layout_to_CZYX, fix_layout_to_ZYX @@ -16,6 +24,108 @@ logger = logging.getLogger(__name__) +def biio_prediction( + raw: np.ndarray, + input_layout: ImageLayout, + model_id: str, +) -> dict[str, np.ndarray]: + assert isinstance(input_layout, str) + + model = load_model_description(model_id) + if isinstance(model, v0_4.ModelDescr): + input_ids = [input_tensor.name for input_tensor in model.inputs] + elif isinstance(model, v0_5.ModelDescr): + input_ids = [input_tensor.id for input_tensor in model.inputs] + else: + assert_never(model) + + logger.info(f"Model expects these inputs: {input_ids}.") + if len(input_ids) < 1: + logger.error("Model needs no input tensor. PlantSeg does not support this yet.") + if len(input_ids) > 1: + logger.error("Model needs more than one input tensor. PlantSeg does not support this yet.") + + tensor_id = input_ids[0] + axes = model.inputs[0].axes # PlantSeg only supports one input tensor for now + dims = tuple( + AxisId('channel') if item.lower() == 'c' else AxisId(item.lower()) for item in input_layout + ) # `AxisId` has to be "channel" not "c" + + if isinstance(axes[0], str): # then it's a <=0.4.10 model, `predict_sample_block` is not implemented + logger.warning( + "Model is older than 0.5.0. PlantSeg will try to run BioImage.IO core inference, but it is not supported by BioImage.IO core." + ) + axis_mapping = {'b': 'batch', 'c': 'channel'} + axes = [AxisId(axis_mapping.get(a, a)) for a in list(axes)] + members = {TensorId(tensor_id): Tensor(array=raw, dims=dims).transpose([AxisId(a) for a in axes])} + sample = Sample(members=members, stat={}, id="raw") + sample_out = predict(model=model, inputs=sample) + + # If inference is supported by BioImage.IO core, this is how it should be done in PlantSeg: + # + # shape = model.inputs[0].shape + # input_block_shape = {TensorId(tensor_id): {AxisId(a): s for a, s in zip(axes, shape)}} + # sample_out = predict(model=model, inputs=sample, input_block_shape=input_block_shape) + else: + members = { + TensorId(tensor_id): Tensor(array=raw, dims=dims).transpose( + [AxisId(a) if isinstance(a, str) else a.id for a in axes] + ) + } + sample = Sample(members=members, stat={}, id="raw") + sizes_in_rdf = {a.id: a.size for a in axes} + assert 'x' in sizes_in_rdf, "Model does not have 'x' axis in input tensor." + size_to_check = sizes_in_rdf[AxisId('x')] + if isinstance(size_to_check, int): # e.g. 'emotional-cricket' + # 'emotional-cricket' has {'batch': None, 'channel': 1, 'z': 100, 'y': 128, 'x': 128} + input_block_shape = { + TensorId(tensor_id): { + a.id: a.size if isinstance(a.size, int) else 1 + for a in axes + if not isinstance(a, str) # for a.size/a.id type checking only + } + } + sample_out = predict(model=model, inputs=sample, input_block_shape=input_block_shape) + elif isinstance(size_to_check, v0_5.ParameterizedSize): # e.g. 'philosophical-panda' + # 'philosophical-panda' has: + # {'z': ParameterizedSize(min=1, step=1), + # 'channel': 2, + # 'y': ParameterizedSize(min=16, step=16), + # 'x': ParameterizedSize(min=16, step=16)} + blocksize_parameter = { + (TensorId(tensor_id), a.id): ( + (96 - a.size.min) // a.size.step if isinstance(a.size, v0_5.ParameterizedSize) else 1 + ) + for a in axes + if not isinstance(a, str) # for a.size/a.id type checking only + } + sample_out = predict(model=model, inputs=sample, blocksize_parameter=blocksize_parameter) + else: + assert_never(size_to_check) + + assert isinstance(sample_out, Sample) + if len(sample_out.members) != 1: + logger.warning("Model has more than one output tensor. PlantSeg does not support this yet.") + + desired_axes_short = [AxisId(a) for a in ['b', 'c', 'z', 'y', 'x']] + desired_axes = [AxisId(a) for a in ['batch', 'channel', 'z', 'y', 'x']] + t = { + i: o.transpose(desired_axes_short) if 'b' in o.dims or 'c' in o.dims else o.transpose(desired_axes) + for i, o in sample_out.members.items() + } + + named_pmaps = {} + for key, tensor_bczyx in t.items(): + bczyx = tensor_bczyx.data.to_numpy() + assert bczyx.ndim == 5, f"Expected 5D BCZYX-transposed prediction from `bioimageio.core`, got {bczyx.ndim}D" + if bczyx.shape[0] == 1: + named_pmaps[f'{key}'] = bczyx[0] + else: + for b, czyx in enumerate(bczyx): + named_pmaps[f'{key}_{b}'] = czyx + return named_pmaps # list of CZYX arrays + + def unet_prediction( raw: np.ndarray, input_layout: ImageLayout, @@ -35,6 +145,10 @@ def unet_prediction( This function handles both single and multi-channel outputs from the model, returning appropriately shaped arrays based on the output channel configuration. + For Bioimage.IO Model Zoo models, weights are downloaded and loaded into `UNet3D` or `UNet2D` + in `plantseg.training.model`, i.e. `bioimageio.core` is not used. `biio_prediction()` uses + `bioimageio.core` for loading and running models. + Args: raw (np.ndarray): Raw input data. Input_layout (ImageLayout): The layout of the input data. diff --git a/plantseg/tasks/prediction_tasks.py b/plantseg/tasks/prediction_tasks.py index 4f185615..e84a7744 100644 --- a/plantseg/tasks/prediction_tasks.py +++ b/plantseg/tasks/prediction_tasks.py @@ -2,7 +2,7 @@ from plantseg.core.image import ImageLayout, PlantSegImage, SemanticType from plantseg.functionals.dataprocessing import fix_layout -from plantseg.functionals.prediction import unet_prediction +from plantseg.functionals.prediction import biio_prediction, unet_prediction from plantseg.tasks import task_tracker @@ -68,3 +68,32 @@ def unet_prediction_task( ) return new_images + + +@task_tracker +def biio_prediction_task( + image: PlantSegImage, + model_id: str, + suffix: str = "_prediction", +) -> list[PlantSegImage]: + data = image.get_data() + input_layout = image.image_layout.value + + named_pmaps = biio_prediction( + raw=data, + input_layout=input_layout, + model_id=model_id, + ) + + new_images = [] + for name, pmap in named_pmaps.items(): + # Input layout is always CZYX this loop + new_images.append( + image.derive_new( + pmap, + name=f"{image.name}_{suffix}_{name}", + semantic_type=SemanticType.PREDICTION, + image_layout='CZYX', + ) + ) + return new_images diff --git a/plantseg/viewer_napari/containers.py b/plantseg/viewer_napari/containers.py index bc342236..474b7cd7 100644 --- a/plantseg/viewer_napari/containers.py +++ b/plantseg/viewer_napari/containers.py @@ -1,4 +1,5 @@ from magicgui.widgets import Container +from qtpy.QtGui import QFont from plantseg.viewer_napari.widgets import ( widget_add_custom_model, @@ -31,6 +32,7 @@ ) STYLE_SLIDER = "font-size: 9pt;" +MONOSPACE_FONT = QFont("Courier New", 9) # "Courier New" is a common monospaced font def get_data_io_tab(): @@ -64,6 +66,7 @@ def get_preprocessing_tab(): def get_segmentation_tab(): + widget_unet_prediction.model_id.native.setFont(MONOSPACE_FONT) container = Container( widgets=[ widget_unet_prediction, diff --git a/plantseg/viewer_napari/viewer.py b/plantseg/viewer_napari/viewer.py index 0e9bc8ef..97c35629 100644 --- a/plantseg/viewer_napari/viewer.py +++ b/plantseg/viewer_napari/viewer.py @@ -23,7 +23,8 @@ def run_viewer(): (get_postprocessing_tab(), 'Postprocessing'), (get_proofreading_tab(), 'Proofreading'), ]: - viewer.window.add_dock_widget(_containers, name=name, tabify=True) + this_widget = viewer.window.add_dock_widget(_containers, name=name, tabify=True) + this_widget.setFixedWidth(666) # Show data tab by default viewer.window._dock_widgets['Input/Output'].show() diff --git a/plantseg/viewer_napari/widgets/dataprocessing.py b/plantseg/viewer_napari/widgets/dataprocessing.py index 473c1403..129762a8 100644 --- a/plantseg/viewer_napari/widgets/dataprocessing.py +++ b/plantseg/viewer_napari/widgets/dataprocessing.py @@ -129,8 +129,15 @@ def widget_cropping( ) +initialised_widget_cropping: bool = ( + False # Avoid throwing an error when the first image is loaded but its layout is not supported +) + + @widget_cropping.image.changed.connect def _on_cropping_image_changed(image: Layer): + global initialised_widget_cropping + if image is None: widget_cropping.crop_z.hide() return None @@ -145,7 +152,10 @@ def _on_cropping_image_changed(image: Layer): return None if ps_image.is_multichannel: - raise ValueError("Multichannel images are not supported for cropping.") + if initialised_widget_cropping: + raise ValueError("Multichannel images are not supported for cropping.") + else: + initialised_widget_cropping = True widget_cropping.crop_z.show() image_shape_z = ps_image.shape[0] diff --git a/plantseg/viewer_napari/widgets/io.py b/plantseg/viewer_napari/widgets/io.py index 7a3d689d..f9bedfa9 100644 --- a/plantseg/viewer_napari/widgets/io.py +++ b/plantseg/viewer_napari/widgets/io.py @@ -14,6 +14,7 @@ from plantseg.tasks.io_tasks import export_image_task, import_image_task from plantseg.tasks.workflow_handler import workflow_handler from plantseg.viewer_napari import log +from plantseg.viewer_napari.widgets.prediction import widget_unet_prediction from plantseg.viewer_napari.widgets.utils import _return_value_if_widget, schedule_task current_dataset_keys: list[str] | None = None @@ -101,7 +102,10 @@ def widget_open_file( elif layer_type == ImageType.LABEL.value: semantic_type = SemanticType.SEGMENTATION - widgets_to_update = [widget_set_voxel_size.layer] + widgets_to_update = [ + widget_set_voxel_size.layer, + widget_unet_prediction.image, + ] return schedule_task( import_image_task, diff --git a/plantseg/viewer_napari/widgets/prediction.py b/plantseg/viewer_napari/widgets/prediction.py index e68ad581..311dd5fa 100644 --- a/plantseg/viewer_napari/widgets/prediction.py +++ b/plantseg/viewer_napari/widgets/prediction.py @@ -12,7 +12,7 @@ from plantseg.core.image import PlantSegImage from plantseg.core.zoo import model_zoo -from plantseg.tasks.prediction_tasks import unet_prediction_task +from plantseg.tasks.prediction_tasks import biio_prediction_task, unet_prediction_task from plantseg.viewer_napari import log from plantseg.viewer_napari.widgets.proofreading import widget_split_and_merge_from_scribbles from plantseg.viewer_napari.widgets.segmentation import widget_agglomeration, widget_dt_ws @@ -106,6 +106,7 @@ def to_choices(cls): 'label': 'BioImage.IO model', 'tooltip': 'Select a model from BioImage.IO model zoo.', 'choices': model_zoo.get_bioimageio_zoo_plantseg_model_names(), + 'value': model_zoo.get_bioimageio_zoo_plantseg_model_names()[0][1], }, advanced={ 'label': 'Show advanced parameters', @@ -128,45 +129,59 @@ def to_choices(cls): ) def widget_unet_prediction( image: Image, - plantseg_filter: bool = True, mode: UNetPredictionMode = UNetPredictionMode.PLANTSEG, + plantseg_filter: bool = True, model_name: Optional[str] = None, - model_id: Optional[str] = None, + model_id: Optional[str] = model_zoo.get_bioimageio_zoo_plantseg_model_names()[0][1], device: str = ALL_DEVICES[0], advanced: bool = False, patch_size: tuple[int, int, int] = (128, 128, 128), patch_halo: tuple[int, int, int] = (0, 0, 0), single_patch: bool = False, ) -> None: + ps_image = PlantSegImage.from_napari_layer(image) + if mode is UNetPredictionMode.PLANTSEG: suffix = model_name model_id = None + widgets_to_update = [ + widget_dt_ws.image, + widget_agglomeration.image, + widget_split_and_merge_from_scribbles.image, + ] + return schedule_task( + unet_prediction_task, + task_kwargs={ + "image": ps_image, + "model_name": model_name, + "model_id": model_id, + "suffix": suffix, + "patch": patch_size if advanced else None, + "patch_halo": patch_halo if advanced else None, + "single_batch_mode": single_patch if advanced else False, + "device": device, + }, + widgets_to_update=widgets_to_update, + ) elif mode is UNetPredictionMode.BIOIMAGEIO: suffix = model_id model_name = None + widgets_to_update = [ + # BioImage.IO models may output multi-channel 3D image or even multi-channel scalar in CZYX format. + # So PlantSeg widgets, which all take ZYX or YX, are better not to be updated. + ] + return schedule_task( + biio_prediction_task, + task_kwargs={ + "image": ps_image, + "model_id": model_id, + "suffix": suffix, + }, + widgets_to_update=widgets_to_update, + ) else: raise NotImplementedError(f'Mode {mode} not implemented yet.') - ps_image = PlantSegImage.from_napari_layer(image) - return schedule_task( - unet_prediction_task, - task_kwargs={ - "image": ps_image, - "model_name": model_name, - "model_id": model_id, - "suffix": suffix, - "patch": patch_size if advanced else None, - "patch_halo": patch_halo if advanced else None, - "single_batch_mode": single_patch if advanced else False, - "device": device, - }, - widgets_to_update=[ - widget_dt_ws.image, - widget_agglomeration.image, - widget_split_and_merge_from_scribbles.image, - ], - ) - widget_unet_prediction.insert(3, model_filters) @@ -201,17 +216,11 @@ def update_halo(): widget_unet_prediction.patch_size[0].enabled = True widget_unet_prediction.patch_halo[0].enabled = True elif widget_unet_prediction.mode.value is UNetPredictionMode.BIOIMAGEIO: - widget_unet_prediction.patch_halo.value = model_zoo.compute_3D_halo_for_bioimageio_models( - widget_unet_prediction.model_id.value + log( + 'Automatic halo not implemented for BioImage.IO models yet because they are handled by BioImage.IO Core.', + thread='BioImage.IO Core prediction', + level='info', ) - if model_zoo.is_2D_bioimageio_model(widget_unet_prediction.model_id.value): - widget_unet_prediction.patch_size[0].value = 0 - widget_unet_prediction.patch_size[0].enabled = False - widget_unet_prediction.patch_halo[0].enabled = False - else: - widget_unet_prediction.patch_size[0].value = widget_unet_prediction.patch_size[1].value - widget_unet_prediction.patch_size[0].enabled = True - widget_unet_prediction.patch_halo[0].enabled = True else: raise NotImplementedError(f'Automatic halo not implemented for {widget_unet_prediction.mode.value} mode.') @@ -261,7 +270,7 @@ def _on_widget_unet_prediction_plantseg_filter_change(plantseg_filter: bool): else: widget_unet_prediction.model_id.choices = ( model_zoo.get_bioimageio_zoo_plantseg_model_names() - + [Separator] + + [('', Separator)] # `[('', Separator)]` for list[tuple[str, str]], [Separator] for list[str] + model_zoo.get_bioimageio_zoo_other_model_names() ) diff --git a/plantseg/viewer_napari/widgets/segmentation.py b/plantseg/viewer_napari/widgets/segmentation.py index a4e5ccef..5e2868ae 100644 --- a/plantseg/viewer_napari/widgets/segmentation.py +++ b/plantseg/viewer_napari/widgets/segmentation.py @@ -219,8 +219,15 @@ def _on_show_advanced_changed(state: bool): widget.hide() +initialised_widget_dt_ws: bool = ( + False # Avoid throwing an error when the first image is loaded but its layout is not supported +) + + @widget_dt_ws.image.changed.connect def _on_image_changed(image: Image): + global initialised_widget_dt_ws + ps_image = PlantSegImage.from_napari_layer(image) if ps_image.image_layout == ImageLayout.ZYX: @@ -229,4 +236,7 @@ def _on_image_changed(image: Image): widget_dt_ws.stacked.hide() widget_dt_ws.stacked.value = False if ps_image.image_layout != ImageLayout.YX: - log(f"Unsupported image layout: {ps_image.image_layout}", thread="DT Watershed", level="error") + if initialised_widget_dt_ws: + log(f"Unsupported image layout: {ps_image.image_layout}", thread="DT Watershed", level="error") + else: + initialised_widget_dt_ws = True diff --git a/setup.py b/setup.py index 9fde15cc..7235b4dc 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,12 @@ 'plantseg': ['resources/logo_white.png'], }, description='PlantSeg is a tool for cell instance aware segmentation in densely packed 3D volumetric images.', - author='Lorenzo Cerrone, Adrian Wolny', + author='Lorenzo Cerrone, Adrian Wolny, Qin Yu', url='https://github.com/kreshuklab/plant-seg', - author_email='lorenzo.cerrone@iwr.uni-heidelberg.de', + author_email='lorenzo.cerrone@uzh.ch, qin.yu@embl.de', + entry_points={ + 'console_scripts': [ + 'plantseg=plantseg.run_plantseg:main', + ], + }, ) diff --git a/tests/conftest.py b/tests/conftest.py index d5076888..6567c858 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,15 +3,42 @@ import shutil from pathlib import Path +import numpy as np +import pooch import pytest +import skimage.transform as skt import torch import yaml +from plantseg.io.io import smart_load + TEST_FILES = Path(__file__).resolve().parent / "resources" VOXEL_SIZE = (0.235, 0.15, 0.15) KEY_ZARR = "volumes/new" IS_CUDA_AVAILABLE = torch.cuda.is_available() +CELLPOSE_TEST_IMAGE_RGB_3D = 'http://www.cellpose.org/static/data/rgb_3D.tif' + + +@pytest.fixture +def raw_zcyx_75x2x75x75(tmpdir) -> np.ndarray: + path_rgb_3d_75x2x75x75 = Path(pooch.retrieve(CELLPOSE_TEST_IMAGE_RGB_3D, path=tmpdir, known_hash=None)) + return smart_load(path_rgb_3d_75x2x75x75) + + +@pytest.fixture +def raw_zcyx_96x2x96x96(raw_zcyx_75x2x75x75): + return skt.resize(raw_zcyx_75x2x75x75, (96, 2, 96, 96), order=1) + + +@pytest.fixture +def raw_cell_3d_100x128x128(raw_zcyx_75x2x75x75): + return skt.resize(raw_zcyx_75x2x75x75[:, 1], (100, 128, 128), order=1) + + +@pytest.fixture +def raw_cell_2d_96x96(raw_cell_3d_100x128x128): + return raw_cell_3d_100x128x128[48] @pytest.fixture diff --git a/tests/core/test_zoo.py b/tests/core/test_zoo.py index 1f5dc757..8a1b92f5 100644 --- a/tests/core/test_zoo.py +++ b/tests/core/test_zoo.py @@ -47,6 +47,8 @@ def test_model_output_normalisation(self, model_name): class TestBioImageIOModelZoo: """Test the BioImage.IO model zoo""" + model_zoo.refresh_bioimageio_zoo_urls() + @pytest.mark.parametrize("model_id", MODEL_IDS) def test_get_model_by_id(self, model_id): """Try to load a model from the BioImage.IO model zoo by ID.""" diff --git a/tests/functionals/prediction/test_prediction_biio.py b/tests/functionals/prediction/test_prediction_biio.py new file mode 100644 index 00000000..7334e280 --- /dev/null +++ b/tests/functionals/prediction/test_prediction_biio.py @@ -0,0 +1,18 @@ +import pytest + +from plantseg.functionals.prediction.prediction import biio_prediction + + +@pytest.mark.parametrize( + "raw_fixture_name, input_layout, model_id", + ( + ('raw_zcyx_96x2x96x96', 'ZCYX', 'philosophical-panda'), + ('raw_cell_3d_100x128x128', 'ZYX', 'emotional-cricket'), + ('raw_cell_2d_96x96', 'YX', 'pioneering-rhino'), + ), +) +def test_biio_prediction(raw_fixture_name, input_layout, model_id, request): + named_pmaps = biio_prediction(request.getfixturevalue(raw_fixture_name), input_layout, model_id) + for key, pmap in named_pmaps.items(): + assert pmap is not None, f"Prediction map for {key} is None" + assert pmap.ndim == 4, f"Prediction map for {key} has {pmap.ndim} dimensions" diff --git a/tests/tasks/test_prediction_tasks.py b/tests/tasks/test_prediction_tasks.py index b3cad4ea..b4d8d2eb 100644 --- a/tests/tasks/test_prediction_tasks.py +++ b/tests/tasks/test_prediction_tasks.py @@ -3,7 +3,7 @@ from plantseg.core.image import ImageLayout, ImageProperties, PlantSegImage, SemanticType from plantseg.io.voxelsize import VoxelSize -from plantseg.tasks.prediction_tasks import unet_prediction_task +from plantseg.tasks.prediction_tasks import biio_prediction_task, unet_prediction_task @pytest.mark.parametrize( @@ -13,7 +13,7 @@ ((64, 64), ImageLayout.YX, 'confocal_2D_unet_ovules_ds2x'), ], ) -def test_unet_prediction(shape, layout, model_name): +def test_unet_prediction_task(shape, layout, model_name): mock_data = np.random.rand(*shape).astype('float32') property = ImageProperties( @@ -25,7 +25,12 @@ def test_unet_prediction(shape, layout, model_name): ) image = PlantSegImage(data=mock_data, properties=property) - result = unet_prediction_task(image=image, model_name=model_name, model_id=None, device='cpu') + result = unet_prediction_task( + image=image, + model_name=model_name, + model_id=None, + device='cpu', + ) assert len(result) == 1 result = result[0] @@ -34,3 +39,32 @@ def test_unet_prediction(shape, layout, model_name): assert result.image_layout == property.image_layout assert result.voxel_size == property.voxel_size assert result.shape == mock_data.shape + + +@pytest.mark.parametrize( + "raw_fixture_name, input_layout, model_id", + ( + ('raw_zcyx_96x2x96x96', 'ZCYX', 'philosophical-panda'), + ('raw_cell_3d_100x128x128', 'ZYX', 'emotional-cricket'), + ('raw_cell_2d_96x96', 'YX', 'pioneering-rhino'), + ), +) +def test_biio_prediction_task(raw_fixture_name, input_layout, model_id, request): + image = PlantSegImage( + data=request.getfixturevalue(raw_fixture_name), + properties=ImageProperties( + name='test', + voxel_size=VoxelSize(voxels_size=(1.0, 1.0, 1.0), unit='um'), + semantic_type=SemanticType.RAW, + image_layout=input_layout, + original_voxel_size=VoxelSize(voxels_size=(1.0, 1.0, 1.0), unit='um'), + ), + ) + result = biio_prediction_task( + image=image, + model_id=model_id, + suffix="_biio_prediction", + ) + for new_image in result: + assert new_image.semantic_type == SemanticType.PREDICTION + assert '_biio_prediction' in new_image.name diff --git a/tests/widgets/test_widget_open_file.py b/tests/widgets/test_widget_open_file.py index dc5a48ca..e0300986 100644 --- a/tests/widgets/test_widget_open_file.py +++ b/tests/widgets/test_widget_open_file.py @@ -1,11 +1,17 @@ +import os + import napari import numpy as np +import pytest from plantseg.io.h5 import create_h5 from plantseg.io.voxelsize import VoxelSize from plantseg.viewer_napari.widgets.io import PathMode, widget_open_file +IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" # set to true in GitHub Actions by default to skip CUDA tests + +@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="GUI tests hangs in GitHub Actions.") def test_widget_open_file(make_napari_viewer_proxy, path_h5): viewer = make_napari_viewer_proxy() shape = (10, 10, 10) diff --git a/tests/widgets/test_widget_preprocessing.py b/tests/widgets/test_widget_preprocessing.py index 6b1f8411..0173563e 100644 --- a/tests/widgets/test_widget_preprocessing.py +++ b/tests/widgets/test_widget_preprocessing.py @@ -1,3 +1,5 @@ +import os + import napari import numpy as np import pytest @@ -8,6 +10,8 @@ from plantseg.io.voxelsize import VoxelSize from plantseg.viewer_napari.widgets.dataprocessing import RescaleModes, widget_rescaling +IN_GITHUB_ACTIONS = os.getenv("GITHUB_ACTIONS") == "true" # set to true in GitHub Actions by default to skip CUDA tests + def create_layer_name(name: str, suffix: str): return f"{name}_{suffix}" @@ -47,6 +51,7 @@ def widget_add_image(image: PlantSegImage) -> LayerDataTuple: return image.to_napari_layer_tuple() +@pytest.mark.skipif(IN_GITHUB_ACTIONS, reason="GUI tests hangs in GitHub Actions.") class TestWidgetRescaling: def test_rescaling_from_factor(self, make_napari_viewer_proxy, sample_image): viewer = make_napari_viewer_proxy()