diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml new file mode 100644 index 00000000..8a9b5bdc --- /dev/null +++ b/.github/workflows/build.yaml @@ -0,0 +1,163 @@ +name: Test and Deploy bioimageio.core + +on: + push: + branches: [ main ] + pull_request: + branches: [ "**" ] + +defaults: + run: + shell: micromamba-shell {0} + +jobs: + black: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: psf/black@stable + with: + options: "--check --verbose" + src: "." + jupyter: true + version: "24.3" + + test-spec-conda: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ['3.8', '3.9', '3.10', '3.11', '3.12'] + steps: + - uses: actions/checkout@v4 + - name: Install Conda environment with Micromamba + if: matrix.python-version != '3.8' + uses: mamba-org/setup-micromamba@v1 + with: + cache-downloads: true + cache-environment: true + environment-file: dev/env-wo-python.yaml + create-args: >- + python=${{ matrix.python-version }} + post-cleanup: 'all' + - name: Install py3.8 environment + if: matrix.python-version == '3.8' + uses: mamba-org/setup-micromamba@v1 + with: + cache-downloads: true + cache-environment: true + environment-file: dev/env-py38.yaml + post-cleanup: 'all' + - name: additional setup + run: pip install --no-deps -e . 
+ - name: pytest-spec-conda + run: pytest --disable-pytest-warnings + + test-spec-main: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ['3.8', '3.12'] + steps: + - uses: actions/checkout@v4 + - name: Install Conda environment with Micromamba + if: matrix.python-version != '3.8' + uses: mamba-org/setup-micromamba@v1 + with: + cache-downloads: true + cache-environment: true + environment-file: dev/env-wo-python.yaml + create-args: >- + python=${{ matrix.python-version }} + post-cleanup: 'all' + - name: Install py3.8 environment + if: matrix.python-version == '3.8' + uses: mamba-org/setup-micromamba@v1 + with: + cache-downloads: true + cache-environment: true + environment-file: dev/env-py38.yaml + post-cleanup: 'all' + - name: additional setup + run: | + conda remove --yes --force bioimageio.spec || true # allow failure for cached env + pip install --no-deps git+https://github.com/bioimage-io/spec-bioimage-io + pip install --no-deps -e . + - name: pytest-spec-main + run: pytest --disable-pytest-warnings + + test-tf: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ['3.9', '3.11'] + steps: + - uses: actions/checkout@v4 + - name: Install Conda environment with Micromamba + uses: mamba-org/setup-micromamba@v1 + with: + cache-downloads: true + cache-environment: true + environment-file: dev/env-tf.yaml + condarc: | + channel-priority: flexible + create-args: >- + python=${{ matrix.python-version }} + post-cleanup: 'all' + - name: additional setup + run: pip install --no-deps -e . 
+ - name: pytest-spec-tf + run: pytest --disable-pytest-warnings + + conda-build: + runs-on: ubuntu-latest + needs: test-spec-conda + steps: + - name: checkout + uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Install Conda environment with Micromamba + uses: mamba-org/setup-micromamba@v1 + with: + cache-downloads: true + cache-environment: true + environment-name: build-env + condarc: | + channels: + - conda-forge + create-args: | + boa + - name: linux conda build + run: | + conda mambabuild -c conda-forge conda-recipe + + docs: + if: github.ref == 'refs/heads/main' + runs-on: ubuntu-latest + # this job uses actions/setup-python instead of micromamba, so the workflow-level 'micromamba-shell' default is not available here + defaults: + run: + shell: bash + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: '3.12' + cache: 'pip' + - run: pip install -e .[dev] + - id: get_version + run: python -c 'import bioimageio.core;print(f"version={bioimageio.core.__version__}")' >> "$GITHUB_OUTPUT" + - name: Generate developer docs + run: | + pdoc \ + --logo https://bioimage.io/static/img/bioimage-io-logo.svg \ + --logo-link https://bioimage.io/ \ + --favicon https://bioimage.io/static/img/bioimage-io-icon-small.svg \ + --footer-text 'bioimageio.core ${{steps.get_version.outputs.version}}' \ + -o ./dist bioimageio.core + - run: cp README.md ./dist/README.md + - name: Deploy to gh-pages 🚀 + uses: JamesIves/github-pages-deploy-action@v4 + with: + branch: gh-pages + folder: dist diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml deleted file mode 100644 index 76a06b5d..00000000 --- a/.github/workflows/build.yml +++ /dev/null @@ -1,140 +0,0 @@ -name: Test and Deploy bioimageio.core - -on: - push: - branches: [ main ] - pull_request: - branches: [ "**" ] - -defaults: - run: - shell: bash -l {0} - -jobs: - black: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - name: Check files using the 
black formatter - uses: rickstaa/action-black@v1 - id: action_black - with: - black_args: "." 
- - name: Annotate diff changes using reviewdog - if: steps.action_black.outputs.is_formatted == 'true' - uses: reviewdog/action-suggester@v1 - with: - tool_name: blackfmt - - test-spec-conda: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: [3.7, 3.8, 3.9] - steps: - - uses: actions/checkout@v3 - - name: Install Conda environment with Micromamba - uses: mamba-org/provision-with-micromamba@main - with: - cache-downloads: true - cache-env: true - environment-file: dev/environment-torch.yaml - extra-specs: | - python=${{ matrix.python-version }} - - name: additional setup - run: pip install --no-deps -e . - - name: pytest-spec-conda - run: pytest --disable-pytest-warnings - - test-spec-main: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: [3.7, 3.8, 3.9] - steps: - - uses: actions/checkout@v3 - - name: Install Conda environment with Micromamba - uses: mamba-org/provision-with-micromamba@main - with: - cache-downloads: true - cache-env: true - environment-file: dev/environment-torch.yaml - extra-specs: | - python=${{ matrix.python-version }} - - name: additional setup - run: | - conda remove --yes --force bioimageio.spec || true # allow failure for cached env - pip install --no-deps git+https://github.com/bioimage-io/spec-bioimage-io - pip install --no-deps -e . 
- - name: pytest-spec-main - run: pytest --disable-pytest-warnings - - test-spec-tf: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: [3.7, 3.8, 3.9] - steps: - - uses: actions/checkout@v3 - - name: Install Conda environment with Micromamba - uses: mamba-org/provision-with-micromamba@main - with: - cache-downloads: true - cache-env: true - environment-file: dev/environment-tf.yaml - channel-priority: flexible - extra-specs: | - python=${{ matrix.python-version }} - - name: additional setup - run: | - conda remove --yes --force bioimageio.spec || true # allow failure for cached env - pip install --no-deps git+https://github.com/bioimage-io/spec-bioimage-io - pip install --no-deps -e . - - name: pytest-spec-tf - run: pytest --disable-pytest-warnings - - test-spec-tf-legacy: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: [3.7] - steps: - - uses: actions/checkout@v3 - - name: Install Conda environment with Micromamba - uses: mamba-org/provision-with-micromamba@main - with: - cache-downloads: true - cache-env: true - environment-file: dev/environment-tf-legacy.yaml - channel-priority: flexible - extra-specs: | - python=${{ matrix.python-version }} - - name: additional setup - run: | - conda remove --yes --force bioimageio.spec || true # allow failure for cached env - pip install --no-deps git+https://github.com/bioimage-io/spec-bioimage-io - pip install --no-deps -e . 
- - name: pytest-spec-tf-legacy - run: pytest --disable-pytest-warnings - - conda-build: - runs-on: ubuntu-latest - needs: test-spec-conda - steps: - - name: checkout - uses: actions/checkout@v3 - with: - fetch-depth: 0 - - name: Install Conda environment with Micromamba - uses: mamba-org/provision-with-micromamba@main - with: - cache-downloads: true - cache-env: true - environment-file: false - environment-name: build-env - channels: conda-forge - extra-specs: | - boa - - name: linux conda build - run: | - conda mambabuild -c conda-forge conda-recipe diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 60223279..73d9b263 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -35,7 +35,7 @@ jobs: - name: Check if there is a parent commit id: check-parent-commit run: | - echo "::set-output name=sha::$(git rev-parse --verify --quiet HEAD^)" + echo "sha=$(git rev-parse --verify --quiet HEAD^)" >> $GITHUB_OUTPUT - name: Detect new version id: check-version diff --git a/.gitignore b/.gitignore index c75e8e7a..a603dade 100644 --- a/.gitignore +++ b/.gitignore @@ -1,8 +1,10 @@ -build/ -dist/ .idea/ -*.egg-info/ -cache -**/tmp .tox/ +*.egg-info/ *.pyc +**/tmp +build/ +cache +dist/ +docs/ +typings/pooch/ diff --git a/.markdownlint.json b/.markdownlint.json new file mode 100644 index 00000000..e3494375 --- /dev/null +++ b/.markdownlint.json @@ -0,0 +1,8 @@ +{ + "default": true, + "MD013": { + "line_length": 120 + }, + "MD033": false, + "MD041": false +} \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f4256fb2..ef0eba58 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,19 @@ repos: - repo: https://github.com/ambv/black - rev: 23.1.0 + rev: 24.2.0 hooks: - - id: black + - id: black-jupyter + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.3.2 + hooks: + - id: ruff + args: [--fix] + - repo: local + hooks: + - id: pyright + name: 
pyright + entry: pyright + language: system + always_run: true + pass_filenames: true + files: ^.*\.py$ diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..64b5cecc --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,15 @@ +{ + "window.title": "bioimageio.core", + "python.analysis.extraPaths": [ + "../spec-bioimage-io", + ], + "python.testing.unittestArgs": [ + "-v", + "-s", + "./tests", + "-p", + "test_*.py" + ], + "python.testing.pytestEnabled": true, + "python.testing.unittestEnabled": false, +} diff --git a/MANIFEST.in b/MANIFEST.in index 031d8dc7..e1d35f13 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1 +1,3 @@ include bioimageio/core/VERSION +include README.md +include LICENSE diff --git a/README.md b/README.md index da9e1d81..dd76c085 100644 --- a/README.md +++ b/README.md @@ -4,92 +4,101 @@ Python specific core utilities for running models in the [BioImage Model Zoo](ht ## Installation -### Via Conda +### Via Mamba/Conda The `bioimageio.core` package can be installed from conda-forge via -``` -conda install -c conda-forge bioimageio.core +```console +mamba install -c conda-forge bioimageio.core ``` -if you don't install any additional deep learning libraries, you will only be able to use general convenience functionality, but not any functionality for model prediction. +If you do not install any additional deep learning libraries, you will only be able to use general convenience +functionality, but not any functionality for model prediction. 
To install additional deep learning libraries use: * Pytorch/Torchscript: - ```bash - # cpu installation (if you don't have an nvidia graphics card) - conda install -c pytorch -c conda-forge bioimageio.core pytorch torchvision cpuonly + CPU installation (if you don't have an nvidia graphics card): + + ```console + mamba install -c pytorch -c conda-forge bioimageio.core pytorch torchvision cpuonly + ``` + + GPU installation (for cuda 11.8, please choose the appropriate cuda version for your system): - # gpu installation (for cuda 11.6, please choose the appropriate cuda version for your system) - conda install -c pytorch -c nvidia -c conda-forge bioimageio.core pytorch torchvision pytorch-cuda=11.6 + ```console + mamba install -c pytorch -c nvidia -c conda-forge bioimageio.core pytorch torchvision pytorch-cuda=11.8 ``` Note that the pytorch installation instructions may change in the future. For the latest instructions please refer to [pytorch.org](https://pytorch.org/). * Tensorflow - ```bash - # currently only cpu version supported - conda install -c conda-forge bioimageio.core tensorflow + Currently only the CPU version is supported: + + ```console + mamba install -c conda-forge bioimageio.core tensorflow ``` * ONNXRuntime - ```bash - # currently only cpu version supported - conda install -c conda-forge bioimageio.core onnxruntime + Currently only the CPU version is supported: + + ```console + mamba install -c conda-forge bioimageio.core onnxruntime ``` ### Via pip -The package is also available via pip: +The package is also available via pip +(e.g. with recommended extras `onnx` and `pytorch`): -``` -pip install bioimageio.core +```console +pip install bioimageio.core[onnx,pytorch] ``` ### Set up Development Environment To set up a development conda environment run the following commands: -``` -conda env create -f dev/environment-base.yaml -conda activate bio-core-dev +```console +mamba env create -f dev/env.yaml +mamba activate core pip install -e . 
--no-deps ``` There are different environment files that only install tensorflow or pytorch as dependencies available. -## Command Line +## 💻 Command Line -`bioimageio.core` installs a command line interface for testing models and other functionality. You can list all the available commands via: +`bioimageio.core` installs a command line interface (CLI) for testing models and other functionality. +You can list all the available commands via: -``` +```console bioimageio ``` Check that a model adheres to the model spec: -``` +```console bioimageio validate ``` Test a model (including prediction for the test input): -``` +```console bioimageio test-model ``` Run prediction for an image stored on disc: -``` +```console bioimageio predict-image --inputs --outputs ``` Run prediction for multiple images stored on disc: -``` +```console bioimagei predict-images -m -i - o ``` diff --git a/bioimageio/core/VERSION b/bioimageio/core/VERSION index 167d3d30..424d6096 100644 --- a/bioimageio/core/VERSION +++ b/bioimageio/core/VERSION @@ -1,3 +1,3 @@ { - "version": "0.5.11" + "version": "0.6.0" } diff --git a/bioimageio/core/__init__.py b/bioimageio/core/__init__.py index b47ac27d..27794eb6 100644 --- a/bioimageio/core/__init__.py +++ b/bioimageio/core/__init__.py @@ -1,15 +1,36 @@ -import json -import pathlib +""" +.. 
include:: ../../README.md +""" -__version__ = json.loads((pathlib.Path(__file__).parent / "VERSION").read_text())["version"] +from bioimageio.spec import build_description as build_description +from bioimageio.spec import dump_description as dump_description +from bioimageio.spec import load_description as load_description +from bioimageio.spec import ( + load_description_and_validate_format_only as load_description_and_validate_format_only, +) +from bioimageio.spec import save_bioimageio_package as save_bioimageio_package +from bioimageio.spec import ( + save_bioimageio_package_as_folder as save_bioimageio_package_as_folder, +) +from bioimageio.spec import save_bioimageio_yaml_only as save_bioimageio_yaml_only +from bioimageio.spec import validate_format as validate_format -from .resource_io import ( - export_resource_package, - load_raw_resource_description, - load_resource_description, - save_raw_resource_description, - serialize_raw_resource_description, +from ._prediction_pipeline import PredictionPipeline as PredictionPipeline +from ._prediction_pipeline import ( + create_prediction_pipeline as create_prediction_pipeline, ) -from .prediction_pipeline import create_prediction_pipeline -from .prediction import predict_image, predict_images, predict_with_padding, predict_with_tiling -from .resource_tests import check_input_shape, check_output_shape, test_resource +from ._resource_tests import load_description_and_test as load_description_and_test +from ._resource_tests import test_description as test_description +from ._resource_tests import test_model as test_model +from ._settings import settings as settings +from .axis import Axis as Axis +from .axis import AxisId as AxisId +from .block_meta import BlockMeta as BlockMeta +from .common import MemberId as MemberId +from .sample import Sample as Sample +from .tensor import Tensor as Tensor +from .utils import VERSION + +__version__ = VERSION + +test_resource = test_description diff --git 
a/bioimageio/core/__main__.py b/bioimageio/core/__main__.py index a26f43d1..6c944725 100644 --- a/bioimageio/core/__main__.py +++ b/bioimageio/core/__main__.py @@ -1,314 +1,227 @@ -import enum -import json -import os import sys -import warnings -from glob import glob - from pathlib import Path -from pprint import pformat, pprint -from typing import List, Optional - -import typer - -from bioimageio.core import __version__, prediction, commands, resource_tests, load_raw_resource_description -from bioimageio.core.common import TestSummary -from bioimageio.core.prediction_pipeline import get_weight_formats -from bioimageio.spec.__main__ import app, help_version as help_version_spec -from bioimageio.spec.model.raw_nodes import WeightsFormat - -try: - from typing import get_args -except ImportError: - from typing_extensions import get_args # type: ignore - -try: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - from bioimageio.core.weight_converter import torch as torch_converter -except ImportError: - torch_converter = None - -try: - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - from bioimageio.core.weight_converter import keras as keras_converter -except ImportError: - keras_converter = None - - -# extend help/version string by core version -help_version_core = f"bioimageio.core {__version__}" -help_version = f"{help_version_spec}\n{help_version_core}" -# prevent rewrapping with \b\n: https://click.palletsprojects.com/en/7.x/documentation/#preventing-rewrapping -app.info.help = "\b\n" + help_version - - -@app.callback() -def callback(): - typer.echo(help_version) - - -@app.command() -def package( - rdf_source: str = typer.Argument(..., help="RDF source as relative file path or URI"), - path: Path = typer.Argument(Path() / "{src_name}-package.zip", help="Save package as"), - weights_priority_order: Optional[List[str]] = typer.Option( - None, - "--weights-priority-order", - "-wpo", - help="For model packages only. 
" - "If given only the first weights matching the given weight formats are included. " - "Defaults to include all weights present in source.", - show_default=False, - ), - verbose: bool = typer.Option(False, help="show traceback of exceptions"), -): - # typer bug: typer returns empty tuple instead of None if weights_order_priority is not given - weights_priority_order = weights_priority_order or None - - ret_code = commands.package( - rdf_source=rdf_source, path=path, weights_priority_order=weights_priority_order, verbose=verbose - ) - sys.exit(ret_code) - - -package.__doc__ = commands.package.__doc__ - - -# if we want to use something like "choice" for the weight formats, we need to use an enum, see: -# https://github.com/tiangolo/typer/issues/182 -WeightFormatEnum = enum.Enum("WeightFormatEnum", {wf: wf for wf in get_args(WeightsFormat)}) -# Enum with in values does not work with click.Choice: https://github.com/pallets/click/issues/784 -# so a simple Enum with auto int values is not an option: -# WeightFormatEnum = enum.Enum("WeightFormatEnum", get_args(WeightsFormat)) - - -def _log_test_summaries(summaries: List[TestSummary], msg: str): - # todo: improve logging of multiple test summaries - ret_code = 0 - for summary in summaries: - print(f"{summary['name']}: {summary['status']}") - if summary["status"] != "passed": - s = { - k: v - for k, v in summary.items() - if k not in ("name", "status", "bioimageio_spec_version", "bioimageio_core_version") - } - tb = s.pop("traceback") - if tb: - print("traceback:") - print("".join(tb)) - - def show_part(part, show): - if show: - line = f"{part}: " - print(line + pformat(show, width=min(80, 120 - len(line))).replace("\n", " " * len(line) + "\n")) - - for part in ["error", "warnings", "source_name"]: - show_part(part, s.pop(part, None)) - - for part in sorted(s.keys()): - show_part(part, s[part]) - - ret_code = 1 - - if ret_code: - result = "FAILED!" - icon = "❌" - else: - result = "passed." 
- icon = "✔️" - - print(msg.format(icon=icon, result=result)) - return ret_code - - -@app.command() -def test_model( - model_rdf: str = typer.Argument( - ..., help="Path or URL to the model resource description file (rdf.yaml) or zipped model." - ), - weight_format: Optional[WeightFormatEnum] = typer.Option(None, help="The weight format to use."), - devices: Optional[List[str]] = typer.Option(None, help="Devices for running the model."), - decimal: int = typer.Option(4, help="The test precision."), -): - # this is a weird typer bug: default devices are empty tuple although they should be None - devices = devices or None - - summaries = resource_tests.test_model( - model_rdf, - weight_format=None if weight_format is None else weight_format.value, - devices=devices, - decimal=decimal, - ) - print(f"\ntesting model {model_rdf}...") - ret_code = _log_test_summaries(summaries, f"\n{{icon}} Model {model_rdf} {{result}}") - sys.exit(ret_code) - - -test_model.__doc__ = resource_tests.test_model.__doc__ - - -@app.command() -def test_resource( - rdf: str = typer.Argument( - ..., help="Path or URL to the resource description file (rdf.yaml) or zipped resource package." 
- ), - weight_format: Optional[WeightFormatEnum] = typer.Option(None, help="(for model only) The weight format to use."), - devices: Optional[List[str]] = typer.Option(None, help="(for model only) Devices for running the model."), - decimal: int = typer.Option(4, help="(for model only) The test precision."), -): - # this is a weird typer bug: default devices are empty tuple although they should be None - if len(devices) == 0: - devices = None - summaries = resource_tests.test_resource( - rdf, weight_format=None if weight_format is None else weight_format.value, devices=devices, decimal=decimal - ) - print(f"\ntesting {rdf}...") - ret_code = _log_test_summaries(summaries, f"{{icon}} Resource test for {rdf} has {{result}}") - sys.exit(ret_code) - - -test_resource.__doc__ = resource_tests.test_resource.__doc__ - - -@app.command() -def predict_image( - model_rdf: Path = typer.Argument( - ..., help="Path to the model resource description file (rdf.yaml) or zipped model." - ), - inputs: List[Path] = typer.Option(..., help="Path(s) to the model input(s)."), - outputs: List[Path] = typer.Option(..., help="Path(s) for saveing the model output(s)."), - # NOTE: typer currently doesn't support union types, so we only support boolean here - # padding: Optional[Union[str, bool]] = typer.Argument( - # None, help="Padding to apply in each dimension passed as json encoded string." - # ), - # tiling: Optional[Union[str, bool]] = typer.Argument( - # None, help="Padding to apply in each dimension passed as json encoded string." 
- # ), - padding: Optional[bool] = typer.Option(None, help="Whether to pad the image to a size suited for the model."), - tiling: Optional[bool] = typer.Option(None, help="Whether to run prediction in tiling mode."), - weight_format: Optional[WeightFormatEnum] = typer.Option(None, help="The weight format to use."), - devices: Optional[List[str]] = typer.Option(None, help="Devices for running the model."), -): - - if isinstance(padding, str): - padding = json.loads(padding.replace("'", '"')) - assert isinstance(padding, dict) - if isinstance(tiling, str): - tiling = json.loads(tiling.replace("'", '"')) - assert isinstance(tiling, dict) - - # this is a weird typer bug: default devices are empty tuple although they should be None - if len(devices) == 0: - devices = None - prediction.predict_image( - model_rdf, inputs, outputs, padding, tiling, None if weight_format is None else weight_format.value, devices - ) - - -predict_image.__doc__ = prediction.predict_image.__doc__ - - -@app.command() -def predict_images( - model_rdf: Path = typer.Argument( - ..., help="Path to the model resource description file (rdf.yaml) or zipped model." - ), - input_pattern: str = typer.Argument(..., help="Glob pattern for the input images."), - output_folder: str = typer.Argument(..., help="Folder to save the outputs."), - output_extension: Optional[str] = typer.Argument(None, help="Optional output extension."), - # NOTE: typer currently doesn't support union types, so we only support boolean here - # padding: Optional[Union[str, bool]] = typer.Argument( - # None, help="Padding to apply in each dimension passed as json encoded string." - # ), - # tiling: Optional[Union[str, bool]] = typer.Argument( - # None, help="Padding to apply in each dimension passed as json encoded string." 
- # ), - padding: Optional[bool] = typer.Option(None, help="Whether to pad the image to a size suited for the model."), - tiling: Optional[bool] = typer.Option(None, help="Whether to run prediction in tiling mode."), - weight_format: Optional[WeightFormatEnum] = typer.Option(None, help="The weight format to use."), - devices: Optional[List[str]] = typer.Option(None, help="Devices for running the model."), -): - input_files = glob(input_pattern) - input_names = [os.path.split(infile)[1] for infile in input_files] - output_files = [os.path.join(output_folder, fname) for fname in input_names] - if output_extension is not None: - output_files = [f"{os.path.splitext(outfile)[0]}{output_extension}" for outfile in output_files] - - if isinstance(padding, str): - padding = json.loads(padding.replace("'", '"')) - assert isinstance(padding, dict) - if isinstance(tiling, str): - tiling = json.loads(tiling.replace("'", '"')) - assert isinstance(tiling, dict) - - # this is a weird typer bug: default devices are empty tuple although they should be None - if len(devices) == 0: - devices = None - prediction.predict_images( - model_rdf, - input_files, - output_files, - padding=padding, - tiling=tiling, - weight_format=None if weight_format is None else weight_format.value, - devices=devices, - verbose=True, - ) - - -predict_images.__doc__ = prediction.predict_images.__doc__ - - -if torch_converter is not None: - - @app.command() - def convert_torch_weights_to_onnx( - model_rdf: Path = typer.Argument( - ..., help="Path to the model resource description file (rdf.yaml) or zipped model." 
- ), - output_path: Path = typer.Argument(..., help="Where to save the onnx weights."), - opset_version: Optional[int] = typer.Argument(12, help="Onnx opset version."), - use_tracing: bool = typer.Option(True, help="Whether to use torch.jit tracing or scripting."), - verbose: bool = typer.Option(True, help="Verbosity"), - ): - ret_code = torch_converter.convert_weights_to_onnx(model_rdf, output_path, opset_version, use_tracing, verbose) - sys.exit(ret_code) - - convert_torch_weights_to_onnx.__doc__ = torch_converter.convert_weights_to_onnx.__doc__ - - @app.command() - def convert_torch_weights_to_torchscript( - model_rdf: Path = typer.Argument( - ..., help="Path to the model resource description file (rdf.yaml) or zipped model." - ), - output_path: Path = typer.Argument(..., help="Where to save the torchscript weights."), - use_tracing: bool = typer.Option(True, help="Whether to use torch.jit tracing or scripting."), - ): - ret_code = torch_converter.convert_weights_to_torchscript(model_rdf, output_path, use_tracing) - sys.exit(ret_code) +from typing import List, Optional, Union - convert_torch_weights_to_torchscript.__doc__ = torch_converter.convert_weights_to_torchscript.__doc__ +import fire +from bioimageio.core import __version__, test_description +from bioimageio.spec import __version__ as spec_version, save_bioimageio_package +from bioimageio.spec.collection import CollectionDescr +from bioimageio.spec.dataset import DatasetDescr +from bioimageio.spec.model import ModelDescr +from bioimageio.spec.model.v0_5 import WeightsFormat +from bioimageio.spec.notebook import NotebookDescr -if keras_converter is not None: - @app.command() - def convert_keras_weights_to_tensorflow( - model_rdf: Path = typer.Argument( - ..., help="Path to the model resource description file (rdf.yaml) or zipped model." 
- ), - output_path: Path = typer.Argument(..., help="Where to save the tensorflow weights."), +class Bioimageio: + def package( + self, + source: str, + path: Path = Path("bioimageio-package.zip"), + weight_format: Optional[WeightsFormat] = None, ): - ret_code = keras_converter.convert_weights_to_tensorflow_saved_model_bundle(model_rdf, output_path) - sys.exit(ret_code) - - convert_keras_weights_to_tensorflow.__doc__ = ( - keras_converter.convert_weights_to_tensorflow_saved_model_bundle.__doc__ - ) + """Package a bioimageio resource as a zip file + + Args: + source: RDF source e.g. `bioimageio.yaml` or `http://example.com/rdf.yaml` + path: output path + weight_format: include only this single weight format + """ + _ = save_bioimageio_package( + source, + output_path=path, + weights_priority_order=None if weight_format is None else (weight_format,), + ) + + def test( + self, + source: str, + weight_format: Optional[WeightsFormat] = None, + *, + devices: Optional[Union[str, List[str]]] = None, + decimal: int = 4, + ): + """test a bioimageio resource + + Args: + source: Path or URL to the bioimageio resource description file + (bioimageio.yaml or rdf.yaml) or to a zipped resource + weight_format: (model only) The weight format to use + devices: Device(s) to use for testing + decimal: Precision for numerical comparisons + """ + print(f"\ntesting {source}...") # announce before running the test, not after it has finished + summary = test_description( + source, + weight_format=weight_format, + devices=[devices] if isinstance(devices, str) else devices, # accept a single device name as well as a list + decimal=decimal, + ) + print(summary.format()) + sys.exit(0 if summary.status == "passed" else 1) # non-zero exit code unless the summary status is "passed" + + +Bioimageio.__doc__ = f""" +work with resources shared on bioimage.io + +library versions: + bioimageio.core {__version__} + bioimageio.spec {spec_version} + +spec format versions: + model RDF 
{ModelDescr.implemented_format_version} + dataset RDF {DatasetDescr.implemented_format_version} + notebook RDF 
{NotebookDescr.implemented_format_version} + collection RDF {CollectionDescr.implemented_format_version} + +""" + +# TODO: add predict commands +# @app.command() +# def predict_image( +# model_rdf: Annotated[ +# Path, typer.Argument(help="Path to the model resource description file (rdf.yaml) or zipped model.") +# ], +# inputs: Annotated[List[Path], typer.Option(help="Path(s) to the model input(s).")], +# outputs: Annotated[List[Path], typer.Option(help="Path(s) for saving the model output(s).")], +# # NOTE: typer currently doesn't support union types, so we only support boolean here +# # padding: Optional[Union[str, bool]] = typer.Argument( +# # None, help="Padding to apply in each dimension passed as json encoded string." +# # ), +# # tiling: Optional[Union[str, bool]] = typer.Argument( +# # None, help="Tiling to apply in each dimension passed as json encoded string." +# # ), +# padding: Annotated[ +# Optional[bool], typer.Option(help="Whether to pad the image to a size suited for the model.") +# ] = None, +# tiling: Annotated[Optional[bool], typer.Option(help="Whether to run prediction in tiling mode.")] = None, +# weight_format: Annotated[Optional[WeightsFormatEnum], typer.Option(help="The weight format to use.")] = None, +# devices: Annotated[Optional[List[str]], typer.Option(help="Devices for running the model.")] = None, +# ): +# if isinstance(padding, str): +# padding = json.loads(padding.replace("'", '"')) +# assert isinstance(padding, dict) +# if isinstance(tiling, str): +# tiling = json.loads(tiling.replace("'", '"')) +# assert isinstance(tiling, dict) + +# # this is a weird typer bug: default devices are empty tuple although they should be None +# if devices is None or len(devices) == 0: +# devices = None + +# prediction.predict_image( +# model_rdf, inputs, outputs, padding, tiling, None if weight_format is None else weight_format.value, devices +# ) + + +# predict_image.__doc__ = prediction.predict_image.__doc__ + + +# @app.command() +# def 
predict_images( +# model_rdf: Annotated[ +# Path, typer.Argument(help="Path to the model resource description file (rdf.yaml) or zipped model.") +# ], +# input_pattern: Annotated[str, typer.Argument(help="Glob pattern for the input images.")], +# output_folder: Annotated[str, typer.Argument(help="Folder to save the outputs.")], +# output_extension: Annotated[Optional[str], typer.Argument(help="Optional output extension.")] = None, +# # NOTE: typer currently doesn't support union types, so we only support boolean here +# # padding: Optional[Union[str, bool]] = typer.Argument( +# # None, help="Padding to apply in each dimension passed as json encoded string." +# # ), +# # tiling: Optional[Union[str, bool]] = typer.Argument( +# # None, help="Padding to apply in each dimension passed as json encoded string." +# # ), +# padding: Annotated[ +# Optional[bool], typer.Option(help="Whether to pad the image to a size suited for the model.") +# ] = None, +# tiling: Annotated[Optional[bool], typer.Option(help="Whether to run prediction in tiling mode.")] = None, +# weight_format: Annotated[Optional[WeightsFormatEnum], typer.Option(help="The weight format to use.")] = None, +# devices: Annotated[Optional[List[str]], typer.Option(help="Devices for running the model.")] = None, +# ): +# input_files = glob(input_pattern) +# input_names = [os.path.split(infile)[1] for infile in input_files] +# output_files = [os.path.join(output_folder, fname) for fname in input_names] +# if output_extension is not None: +# output_files = [f"{os.path.splitext(outfile)[0]}{output_extension}" for outfile in output_files] + +# if isinstance(padding, str): +# padding = json.loads(padding.replace("'", '"')) +# assert isinstance(padding, dict) +# if isinstance(tiling, str): +# tiling = json.loads(tiling.replace("'", '"')) +# assert isinstance(tiling, dict) + +# # this is a weird typer bug: default devices are empty tuple although they should be None +# if len(devices) == 0: +# devices = None +# 
prediction.predict_images( +# model_rdf, +# input_files, +# output_files, +# padding=padding, +# tiling=tiling, +# weight_format=None if weight_format is None else weight_format.value, +# devices=devices, +# verbose=True, +# ) + + +# predict_images.__doc__ = prediction.predict_images.__doc__ + + +# if torch_converter is not None: + +# @app.command() +# def convert_torch_weights_to_onnx( +# model_rdf: Path = typer.Argument( +# ..., help="Path to the model resource description file (rdf.yaml) or zipped model." +# ), +# output_path: Path = typer.Argument(..., help="Where to save the onnx weights."), +# opset_version: Optional[int] = typer.Argument(12, help="Onnx opset version."), +# use_tracing: bool = typer.Option(True, help="Whether to use torch.jit tracing or scripting."), +# verbose: bool = typer.Option(True, help="Verbosity"), +# ): +# ret_code = torch_converter.convert_weights_to_onnx(model_rdf, output_path, opset_version, use_tracing, verbose) +# sys.exit(ret_code) + +# convert_torch_weights_to_onnx.__doc__ = torch_converter.convert_weights_to_onnx.__doc__ + +# @app.command() +# def convert_torch_weights_to_torchscript( +# model_rdf: Path = typer.Argument( +# ..., help="Path to the model resource description file (rdf.yaml) or zipped model." 
def main():
    """Command line entry point.

    Hands the `Bioimageio` class to `fire.Fire` under the program name
    "bioimageio".
    """
    fire.Fire(Bioimageio, name="bioimageio")


if __name__ == "__main__":
    main()
class MagicTensorOpsMixin:
    """Mixin mapping Python's operator methods onto three overridable hooks.

    Binary operators forward to `_binary_op`, augmented assignments to
    `_inplace_binary_op` and unary operators (as well as `round`, `argsort`,
    `conj` and `conjugate`) to `_unary_op`. All three hooks raise
    `NotImplementedError` here and have to be overridden by the class that
    uses this mixin.
    """

    __slots__ = ()
    _Compatible = Any

    def _binary_op(
        self,
        other: _Compatible,
        f: Callable[[Any, Any], Any],
        reflexive: bool = False,
    ) -> Self:
        """Hook for all binary operators; `f` is the operator function to apply."""
        raise NotImplementedError

    # --- binary operators ---
    def __add__(self, other: _Compatible) -> Self:
        return self._binary_op(other, operator.add)

    def __sub__(self, other: _Compatible) -> Self:
        return self._binary_op(other, operator.sub)

    def __mul__(self, other: _Compatible) -> Self:
        return self._binary_op(other, operator.mul)

    def __pow__(self, other: _Compatible) -> Self:
        return self._binary_op(other, operator.pow)

    def __truediv__(self, other: _Compatible) -> Self:
        return self._binary_op(other, operator.truediv)

    def __floordiv__(self, other: _Compatible) -> Self:
        return self._binary_op(other, operator.floordiv)

    def __mod__(self, other: _Compatible) -> Self:
        return self._binary_op(other, operator.mod)

    def __and__(self, other: _Compatible) -> Self:
        return self._binary_op(other, operator.and_)

    def __xor__(self, other: _Compatible) -> Self:
        return self._binary_op(other, operator.xor)

    def __or__(self, other: _Compatible) -> Self:
        return self._binary_op(other, operator.or_)

    def __lshift__(self, other: _Compatible) -> Self:
        return self._binary_op(other, operator.lshift)

    def __rshift__(self, other: _Compatible) -> Self:
        return self._binary_op(other, operator.rshift)

    # --- comparisons (routed through `_binary_op` like any other operator) ---
    def __lt__(self, other: _Compatible) -> Self:
        return self._binary_op(other, operator.lt)

    def __le__(self, other: _Compatible) -> Self:
        return self._binary_op(other, operator.le)

    def __gt__(self, other: _Compatible) -> Self:
        return self._binary_op(other, operator.gt)

    def __ge__(self, other: _Compatible) -> Self:
        return self._binary_op(other, operator.ge)

    def __eq__(self, other: _Compatible) -> Self:  # type: ignore[override]
        return self._binary_op(
            other, nputils.array_eq  # pyright: ignore[reportUnknownArgumentType]
        )

    def __ne__(self, other: _Compatible) -> Self:  # type: ignore[override]
        return self._binary_op(
            other, nputils.array_ne  # pyright: ignore[reportUnknownArgumentType]
        )

    # When __eq__ is defined but __hash__ is not, then an object is unhashable,
    # and it should be declared as follows:
    __hash__: None  # type:ignore[assignment]

    # --- reflected binary operators: same functions, flagged with reflexive=True ---
    def __radd__(self, other: _Compatible) -> Self:
        return self._binary_op(other, operator.add, reflexive=True)

    def __rsub__(self, other: _Compatible) -> Self:
        return self._binary_op(other, operator.sub, reflexive=True)

    def __rmul__(self, other: _Compatible) -> Self:
        return self._binary_op(other, operator.mul, reflexive=True)

    def __rpow__(self, other: _Compatible) -> Self:
        return self._binary_op(other, operator.pow, reflexive=True)

    def __rtruediv__(self, other: _Compatible) -> Self:
        return self._binary_op(other, operator.truediv, reflexive=True)

    def __rfloordiv__(self, other: _Compatible) -> Self:
        return self._binary_op(other, operator.floordiv, reflexive=True)

    def __rmod__(self, other: _Compatible) -> Self:
        return self._binary_op(other, operator.mod, reflexive=True)

    def __rand__(self, other: _Compatible) -> Self:
        return self._binary_op(other, operator.and_, reflexive=True)

    def __rxor__(self, other: _Compatible) -> Self:
        return self._binary_op(other, operator.xor, reflexive=True)

    def __ror__(self, other: _Compatible) -> Self:
        return self._binary_op(other, operator.or_, reflexive=True)

    def _inplace_binary_op(
        self, other: _Compatible, f: Callable[[Any, Any], Any]
    ) -> Self:
        """Hook for all augmented assignments; `f` is the in-place operator function."""
        raise NotImplementedError

    # --- augmented assignments ---
    def __iadd__(self, other: _Compatible) -> Self:
        return self._inplace_binary_op(other, operator.iadd)

    def __isub__(self, other: _Compatible) -> Self:
        return self._inplace_binary_op(other, operator.isub)

    def __imul__(self, other: _Compatible) -> Self:
        return self._inplace_binary_op(other, operator.imul)

    def __ipow__(self, other: _Compatible) -> Self:
        return self._inplace_binary_op(other, operator.ipow)

    def __itruediv__(self, other: _Compatible) -> Self:
        return self._inplace_binary_op(other, operator.itruediv)

    def __ifloordiv__(self, other: _Compatible) -> Self:
        return self._inplace_binary_op(other, operator.ifloordiv)

    def __imod__(self, other: _Compatible) -> Self:
        return self._inplace_binary_op(other, operator.imod)

    def __iand__(self, other: _Compatible) -> Self:
        return self._inplace_binary_op(other, operator.iand)

    def __ixor__(self, other: _Compatible) -> Self:
        return self._inplace_binary_op(other, operator.ixor)

    def __ior__(self, other: _Compatible) -> Self:
        return self._inplace_binary_op(other, operator.ior)

    def __ilshift__(self, other: _Compatible) -> Self:
        return self._inplace_binary_op(other, operator.ilshift)

    def __irshift__(self, other: _Compatible) -> Self:
        return self._inplace_binary_op(other, operator.irshift)

    def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self:
        """Hook for all unary operations; extra arguments are forwarded by the callers."""
        raise NotImplementedError

    # --- unary operators ---
    def __neg__(self) -> Self:
        return self._unary_op(operator.neg)

    def __pos__(self) -> Self:
        return self._unary_op(operator.pos)

    def __abs__(self) -> Self:
        return self._unary_op(operator.abs)

    def __invert__(self) -> Self:
        return self._unary_op(operator.invert)

    # --- named unary methods backed by functions from `xarray.core.ops` ---
    def round(self, *args: Any, **kwargs: Any) -> Self:
        return self._unary_op(
            ops.round_, *args, **kwargs  # pyright: ignore[reportUnknownArgumentType]
        )

    def argsort(self, *args: Any, **kwargs: Any) -> Self:
        return self._unary_op(
            ops.argsort, *args, **kwargs  # pyright: ignore[reportUnknownArgumentType]
        )

    def conj(self, *args: Any, **kwargs: Any) -> Self:
        return self._unary_op(
            ops.conj, *args, **kwargs  # pyright: ignore[reportUnknownArgumentType]
        )

    def conjugate(self, *args: Any, **kwargs: Any) -> Self:
        return self._unary_op(
            ops.conjugate, *args, **kwargs  # pyright: ignore[reportUnknownArgumentType]
        )

    # reuse the docstrings of the wrapped operator functions
    __add__.__doc__ = operator.add.__doc__
    __sub__.__doc__ = operator.sub.__doc__
    __mul__.__doc__ = operator.mul.__doc__
    __pow__.__doc__ = operator.pow.__doc__
    __truediv__.__doc__ = operator.truediv.__doc__
    __floordiv__.__doc__ = operator.floordiv.__doc__
    __mod__.__doc__ = operator.mod.__doc__
    __and__.__doc__ = operator.and_.__doc__
    __xor__.__doc__ = operator.xor.__doc__
    __or__.__doc__ = operator.or_.__doc__
    __lshift__.__doc__ = operator.lshift.__doc__
    __rshift__.__doc__ = operator.rshift.__doc__
    __lt__.__doc__ = operator.lt.__doc__
    __le__.__doc__ = operator.le.__doc__
    __gt__.__doc__ = operator.gt.__doc__
    __ge__.__doc__ = operator.ge.__doc__
    __eq__.__doc__ = nputils.array_eq.__doc__
    __ne__.__doc__ = nputils.array_ne.__doc__
    __radd__.__doc__ = operator.add.__doc__
    __rsub__.__doc__ = operator.sub.__doc__
    __rmul__.__doc__ = operator.mul.__doc__
    __rpow__.__doc__ = operator.pow.__doc__
    __rtruediv__.__doc__ = operator.truediv.__doc__
    __rfloordiv__.__doc__ = operator.floordiv.__doc__
    __rmod__.__doc__ = operator.mod.__doc__
    __rand__.__doc__ = operator.and_.__doc__
    __rxor__.__doc__ = operator.xor.__doc__
    __ror__.__doc__ = operator.or_.__doc__
    __iadd__.__doc__ = operator.iadd.__doc__
    __isub__.__doc__ = operator.isub.__doc__
    __imul__.__doc__ = operator.imul.__doc__
    __ipow__.__doc__ = operator.ipow.__doc__
    __itruediv__.__doc__ = operator.itruediv.__doc__
    __ifloordiv__.__doc__ = operator.ifloordiv.__doc__
    __imod__.__doc__ = operator.imod.__doc__
    __iand__.__doc__ = operator.iand.__doc__
    __ixor__.__doc__ = operator.ixor.__doc__
    __ior__.__doc__ = operator.ior.__doc__
    __ilshift__.__doc__ = operator.ilshift.__doc__
    __irshift__.__doc__ = operator.irshift.__doc__
    __neg__.__doc__ = operator.neg.__doc__
    __pos__.__doc__ = operator.pos.__doc__
    __abs__.__doc__ = operator.abs.__doc__
    __invert__.__doc__ = operator.invert.__doc__
@dataclass
class Operator(ABC):
    """Abstract base of a callable operation on a `Sample` or `SampleBlockWithOrigin`."""

    @abstractmethod
    def __call__(self, sample: Union[Sample, SampleBlockWithOrigin]) -> None:
        """Apply the operator to `sample`; nothing is returned."""
        ...

    @property
    @abstractmethod
    def required_measures(self) -> Collection[Measure]:
        """The measures this operator declares as required."""
        ...


@dataclass
class BlockedOperator(Operator, ABC):
    """Abstract `Operator` whose `__call__` additionally accepts a plain `SampleBlock`."""

    @abstractmethod
    def __call__(
        self, sample: Union[Sample, SampleBlock, SampleBlockWithOrigin]
    ) -> None:
        """Apply the operator to `sample`; nothing is returned."""
        ...

    @property
    @abstractmethod
    def required_measures(self) -> Collection[Measure]:
        """The measures this operator declares as required."""
        ...
class PredictionPipeline:
    """
    Represents model computation including preprocessing and postprocessing
    Note: Ideally use the PredictionPipeline as a context manager
    """

    def __init__(
        self,
        *,
        name: str,
        model_description: AnyModelDescr,
        preprocessing: List[Processing],
        postprocessing: List[Processing],
        model_adapter: ModelAdapter,
        default_ns: Union[
            v0_5.ParameterizedSize.N,
            Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize.N],
        ] = 10,
        default_batch_size: int = 1,
    ) -> None:
        """
        Args:
            name: name of this pipeline
            model_description: description of the model to run
            preprocessing: operators applied (in order) before the model
            postprocessing: operators applied (in order) after the model
            model_adapter: adapter whose `forward` runs the actual model
            default_ns: size parameter(s) used by `predict_sample_with_blocking`
                if none are given there
            default_batch_size: batch size used by `predict_sample_with_blocking`
                if none is given there
        """
        super().__init__()
        if model_description.run_mode:
            warnings.warn(
                f"Not yet implemented inference for run mode '{model_description.run_mode.name}'"
            )

        self.name = name
        self._preprocessing = preprocessing
        self._postprocessing = postprocessing

        self.model_description = model_description
        if isinstance(model_description, v0_4.ModelDescr):
            # no halo/block information is derived for v0_4 descriptions
            self._default_input_halo: PerMember[PerAxis[Halo]] = {}
            self._block_transform = None
        else:
            # symmetric halo for every output axis that declares one
            default_output_halo = {
                t.id: {
                    a.id: Halo(a.halo, a.halo)
                    for a in t.axes
                    if isinstance(a, v0_5.WithHalo)
                }
                for t in model_description.outputs
            }
            self._default_input_halo = get_input_halo(
                model_description, default_output_halo
            )
            self._block_transform = get_block_transform(model_description)

        self._default_ns = default_ns
        self._default_batch_size = default_batch_size

        self._input_ids = get_member_ids(model_description.inputs)
        self._output_ids = get_member_ids(model_description.outputs)

        self._adapter: ModelAdapter = model_adapter

    def __enter__(self):
        self.load()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):  # type: ignore
        self.unload()
        # do not suppress exceptions raised inside the `with` block
        return False

    def predict_sample_block(
        self,
        sample_block: SampleBlockWithOrigin,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
    ) -> SampleBlock:
        """predict a single sample block

        Raises:
            NotImplementedError: for v0_4 model descriptions
        """
        if isinstance(self.model_description, v0_4.ModelDescr):
            raise NotImplementedError(
                f"predict_sample_block not implemented for model {self.model_description.format_version}"
            )
        else:
            assert self._block_transform is not None

        if not skip_preprocessing:
            self.apply_preprocessing(sample_block)

        output_meta = sample_block.get_transformed_meta(self._block_transform)
        output = output_meta.with_data(
            {
                tid: out
                for tid, out in zip(
                    self._output_ids,
                    self._adapter.forward(
                        *(sample_block.members.get(t) for t in self._input_ids)
                    ),
                )
                if out is not None  # outputs the adapter returned as None are dropped
            },
            stat=sample_block.stat,
        )
        if not skip_postprocessing:
            self.apply_postprocessing(output)

        return output

    def predict_sample_without_blocking(
        self,
        sample: Sample,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
    ) -> Sample:
        """predict a sample.
        The sample's tensor shapes have to match the model's input tensor description.
        If that is not the case, consider `predict_sample_with_blocking`"""

        if not skip_preprocessing:
            self.apply_preprocessing(sample)

        output = Sample(
            members={
                out_id: out
                for out_id, out in zip(
                    self._output_ids,
                    self._adapter.forward(
                        *(sample.members.get(in_id) for in_id in self._input_ids)
                    ),
                )
                if out is not None  # outputs the adapter returned as None are dropped
            },
            stat=sample.stat,
            id=self.get_output_sample_id(sample.id),
        )
        if not skip_postprocessing:
            self.apply_postprocessing(output)

        return output

    def get_output_sample_id(self, input_sample_id: SampleId):
        """derive the id of an output sample: `None` stays `None`, otherwise the
        model id (or, if that is not set, the model name) is appended to the input id"""
        if input_sample_id is None:
            return None
        else:
            return f"{input_sample_id}_" + (
                self.model_description.id or self.model_description.name
            )

    def predict_sample_with_blocking(
        self,
        sample: Sample,
        skip_preprocessing: bool = False,
        skip_postprocessing: bool = False,
        ns: Optional[
            Union[
                v0_5.ParameterizedSize.N,
                Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize.N],
            ]
        ] = None,
        batch_size: Optional[int] = None,
    ) -> Sample:
        """predict a sample by splitting it into blocks according to the model and the `ns` parameter

        Args:
            sample: the sample to predict
            skip_preprocessing: do not apply the preprocessing operators
            skip_postprocessing: do not apply the postprocessing operators
            ns: size parameter(s); a single integer is applied to every
                parameterized input axis. Defaults to the pipeline's `default_ns`.
            batch_size: defaults to the pipeline's `default_batch_size`

        Raises:
            NotImplementedError: for v0_4 model descriptions
        """
        # reject unsupported descriptions first, so that `sample` is not
        # preprocessed in-place before the error is raised
        if isinstance(self.model_description, v0_4.ModelDescr):
            raise NotImplementedError(
                f"predict with blocking not implemented for v0_4.ModelDescr {self.model_description.name}"
            )

        if not skip_preprocessing:
            self.apply_preprocessing(sample)

        # an explicit integer (including 0) is respected; only a missing value
        # or an empty mapping falls back to the pipeline default
        if ns is None or (not isinstance(ns, int) and not ns):
            ns = self._default_ns

        if isinstance(ns, int):
            ns = {
                (ipt.id, a.id): ns
                for ipt in self.model_description.inputs
                for a in ipt.axes
                if isinstance(a.size, v0_5.ParameterizedSize)
            }
        input_block_shape = self.model_description.get_tensor_sizes(
            ns, batch_size or self._default_batch_size
        ).inputs

        n_blocks, input_blocks = sample.split_into_blocks(
            input_block_shape,
            halo=self._default_input_halo,
            pad_mode="reflect",
        )
        input_blocks = list(input_blocks)
        predicted_blocks: List[SampleBlock] = []
        for b in tqdm(
            input_blocks,
            desc=f"predict sample {sample.id or ''} with {self.model_description.id or self.model_description.name}",
            unit="block",
            unit_divisor=1,
            total=n_blocks,
        ):
            # pre- and postprocessing are applied to the whole sample, not per block
            predicted_blocks.append(
                self.predict_sample_block(
                    b, skip_preprocessing=True, skip_postprocessing=True
                )
            )

        predicted_sample = Sample.from_blocks(predicted_blocks)
        if not skip_postprocessing:
            self.apply_postprocessing(predicted_sample)

        return predicted_sample

    # def predict(
    #     self,
    #     inputs: Predict_IO,
    #     skip_preprocessing: bool = False,
    #     skip_postprocessing: bool = False,
    # ) -> Predict_IO:
    #     """Run model prediction **including** pre/postprocessing."""

    #     if isinstance(inputs, Sample):
    #         return self.predict_sample_with_blocking(
    #             inputs,
    #             skip_preprocessing=skip_preprocessing,
    #             skip_postprocessing=skip_postprocessing,
    #         )
    #     elif isinstance(inputs, collections.abc.Iterable):
    #         return (
    #             self.predict(
    #                 ipt,
    #                 skip_preprocessing=skip_preprocessing,
    #                 skip_postprocessing=skip_postprocessing,
    #             )
    #             for ipt in inputs
    #         )
    #     else:
    #         assert_never(inputs)

    def apply_preprocessing(self, sample: Union[Sample, SampleBlockWithOrigin]) -> None:
        """apply preprocessing in-place, also updates sample stats"""
        for op in self._preprocessing:
            op(sample)

    def apply_postprocessing(
        self, sample: Union[Sample, SampleBlock, SampleBlockWithOrigin]
    ) -> None:
        """apply postprocessing in-place, also updates samples stats

        Raises:
            NotImplementedError: if `sample` is a plain `SampleBlock` and an
                operator is not a `BlockedOperator`
        """
        for op in self._postprocessing:
            if isinstance(sample, (Sample, SampleBlockWithOrigin)):
                op(sample)
            elif not isinstance(op, BlockedOperator):
                raise NotImplementedError(
                    "block wise update of output statistics not yet implemented"
                )
            else:
                op(sample)

    def load(self):
        """
        optional step: load model onto devices before calling forward if not using it as context manager
        """
        pass

    def unload(self):
        """
        free any device memory in use
        """
        self._adapter.unload()
def create_prediction_pipeline(
    bioimageio_model: AnyModelDescr,
    *,
    devices: Optional[Sequence[str]] = None,
    weight_format: Optional[WeightsFormat] = None,
    weights_format: Optional[WeightsFormat] = None,
    dataset_for_initial_statistics: Iterable[Union[Sample, Sequence[Tensor]]] = tuple(),
    keep_updating_initial_dataset_statistics: bool = False,
    fixed_dataset_statistics: Mapping[DatasetMeasure, MeasureValue] = MappingProxyType(
        {}
    ),
    model_adapter: Optional[ModelAdapter] = None,
    ns: Union[
        v0_5.ParameterizedSize.N,
        Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize.N],
    ] = 10,
    **deprecated_kwargs: Any,
) -> PredictionPipeline:
    """Assemble a `PredictionPipeline` for a model description.

    A model adapter is created unless one is passed in, the pre- and
    postprocessing operators are set up from the description and the given
    dataset statistics, and everything is bundled into the pipeline.

    Args:
        bioimageio_model: description of the model
        devices: forwarded to the model adapter creation
        weight_format: weight format to use; wins over `weights_format`
        weights_format: alternative spelling of `weight_format`
        dataset_for_initial_statistics: samples (or sequences of tensors in
            the order of the model inputs) handed to the processing setup
        keep_updating_initial_dataset_statistics: forwarded to the processing setup
        fixed_dataset_statistics: forwarded to the processing setup
        model_adapter: ready-made adapter to use instead of creating one
        ns: default size parameter(s) of the resulting pipeline
        **deprecated_kwargs: no longer used; passing any triggers a warning
    """
    weights_format = weight_format or weights_format
    del weight_format
    if deprecated_kwargs:
        warnings.warn(
            f"deprecated create_prediction_pipeline kwargs: {set(deprecated_kwargs)}"
        )

    if not model_adapter:
        # a single requested format becomes a one-element priority order
        priority_order = (weights_format,) if weights_format else weights_format
        model_adapter = create_model_adapter(
            model_description=bioimageio_model,
            devices=devices,
            weight_format_priority_order=priority_order,
        )

    input_ids = get_member_ids(bioimageio_model.inputs)

    def dataset():
        # all samples built from raw tensors share one stat mapping
        shared_stat: Stat = {}
        for idx, item in enumerate(dataset_for_initial_statistics):
            if not isinstance(item, Sample):
                item = Sample(
                    members=dict(zip(input_ids, item)), stat=shared_stat, id=idx
                )
            yield item

    pre_ops, post_ops = setup_pre_and_postprocessing(
        bioimageio_model,
        dataset(),
        keep_updating_initial_dataset_stats=keep_updating_initial_dataset_statistics,
        fixed_dataset_stats=fixed_dataset_statistics,
    )

    pipeline = PredictionPipeline(
        name=bioimageio_model.name,
        model_description=bioimageio_model,
        model_adapter=model_adapter,
        preprocessing=pre_ops,
        postprocessing=post_ops,
        default_ns=ns,
    )
    return pipeline
def test_model(
    source: Union[v0_5.ModelDescr, PermissiveFileSource],
    weight_format: Optional[WeightsFormat] = None,
    devices: Optional[List[str]] = None,
    decimal: int = 4,
) -> ValidationSummary:
    """Test a model resource.

    Shortcut for `test_description` with the expected resource type fixed
    to "model"; all other arguments are passed through.
    """
    summary = test_description(
        source,
        weight_format=weight_format,
        devices=devices,
        decimal=decimal,
        expected_type="model",
    )
    return summary
def test_description(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[Literal["discover", "latest"], str] = "discover",
    weight_format: Optional[WeightsFormat] = None,
    devices: Optional[List[str]] = None,
    decimal: int = 4,
    expected_type: Optional[str] = None,
) -> ValidationSummary:
    """Run the dynamic tests for a resource and return only its validation summary.

    All arguments are forwarded to `load_description_and_test`.
    """
    tested = load_description_and_test(
        source,
        format_version=format_version,
        weight_format=weight_format,
        devices=devices,
        decimal=decimal,
        expected_type=expected_type,
    )
    return tested.validation_summary


def load_description_and_test(
    source: Union[ResourceDescr, PermissiveFileSource, BioimageioYamlContent],
    *,
    format_version: Union[Literal["discover", "latest"], str] = "discover",
    weight_format: Optional[WeightsFormat] = None,
    devices: Optional[List[str]] = None,
    decimal: int = 4,
    expected_type: Optional[str] = None,
) -> Union[ResourceDescr, InvalidDescr]:
    """Load (if necessary) and dynamically test a resource description.

    The returned description carries the test results in its
    `validation_summary`.

    Args:
        source: a resource description object, its content as a dict, or a
            file source to load it from
        format_version: format version to build/load the description with
        weight_format: (models only) passed on to the inference tests
        devices: (models only) passed on to the inference tests
        decimal: (models only) passed on to the output comparison
        expected_type: if given, the resource type is checked against it
    """
    needs_roundtrip = (
        isinstance(source, ResourceDescrBase)
        and format_version != "discover"
        and source.format_version != format_version
    )
    if needs_roundtrip:
        # dump the object so that it is rebuilt with the requested format version
        warnings.warn(
            f"deserializing source to ensure we validate and test using format {format_version}"
        )
        source = dump_description(source)

    if isinstance(source, ResourceDescrBase):
        rd = source
    elif isinstance(source, dict):
        rd = build_description(source, format_version=format_version)
    else:
        rd = load_description(source, format_version=format_version)

    # record the bioimageio.core version in the summary's environment
    core_package = InstalledPackage(name="bioimageio.core", version=VERSION)
    rd.validation_summary.env.append(core_package)

    if expected_type is not None:
        _test_expected_resource_type(rd, expected_type)

    if isinstance(rd, (v0_4.ModelDescr, v0_5.ModelDescr)):
        _test_model_inference(rd, weight_format, devices, decimal)
        # the parametrized test is skipped for v0_4 descriptions
        if not isinstance(rd, v0_4.ModelDescr):
            _test_model_inference_parametrized(rd, weight_format, devices)

    # TODO: add execution of jupyter notebooks
    # TODO: add more tests

    return rd
def _test_model_inference(
    model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
    weight_format: Optional[WeightsFormat],
    devices: Optional[List[str]],
    decimal: int,
) -> None:
    """Predict the model's test inputs and compare against its test outputs.

    The outcome is added to `model.validation_summary` as one detail named
    "Reproduce test outputs from test inputs"; nothing is returned and no
    exception escapes.
    """
    failure: Optional[str] = None
    trace: List[str] = []
    try:
        test_inputs = get_test_inputs(model)
        reference = get_test_outputs(model)

        with create_prediction_pipeline(
            bioimageio_model=model, devices=devices, weight_format=weight_format
        ) as pipeline:
            predicted = pipeline.predict_sample_without_blocking(test_inputs)

        n_expected = len(reference.members)
        n_actual = len(predicted.members)
        if n_actual != n_expected:
            failure = f"Expected {n_expected} outputs, but got {n_actual}"
        else:
            # stop at the first output that is missing or deviates
            for member_id, ref_tensor in reference.members.items():
                out_tensor = predicted.members.get(member_id)
                if out_tensor is None:
                    failure = "Output tensors for test case may not be None"
                    break
                try:
                    np.testing.assert_array_almost_equal(
                        out_tensor.data, ref_tensor.data, decimal=decimal
                    )
                except AssertionError as mismatch:
                    failure = f"Output and expected output disagree:\n {mismatch}"
                    break
    except Exception as exc:
        # any other problem is reported as a failed detail including the traceback
        failure = str(exc)
        trace = traceback.format_tb(exc.__traceback__)

    errors: List[ErrorEntry] = []
    if failure is not None:
        location = ("weights",) if weight_format is None else ("weights", weight_format)
        errors.append(
            ErrorEntry(
                loc=location,
                msg=failure,
                type="bioimageio.core",
                traceback=trace,
            )
        )

    model.validation_summary.add_detail(
        ValidationDetail(
            name="Reproduce test outputs from test inputs",
            status="passed" if failure is None else "failed",
            errors=errors,
        )
    )
f"Output and expected output disagree:\n {e}" + break + except Exception as e: + error = str(e) + tb = traceback.format_tb(e.__traceback__) + + model.validation_summary.add_detail( + ValidationDetail( + name="Reproduce test outputs from test inputs", + status="passed" if error is None else "failed", + errors=( + [] + if error is None + else [ + ErrorEntry( + loc=( + ("weights",) + if weight_format is None + else ("weights", weight_format) + ), + msg=error, + type="bioimageio.core", + traceback=tb, + ) + ] + ), + ) + ) + + +def _test_model_inference_parametrized( + model: v0_5.ModelDescr, + weight_format: Optional[WeightsFormat], + devices: Optional[List[str]], + test_cases: Set[Tuple[v0_5.ParameterizedSize.N, BatchSize]] = { + (0, 2), + (1, 3), + (2, 1), + (3, 2), + }, +) -> None: + if not test_cases: + return + + if not any( + isinstance(a.size, v0_5.ParameterizedSize) + for ipt in model.inputs + for a in ipt.axes + ): + # no parameterized sizes => set n=0 + test_cases = {(0, b) for _n, b in test_cases} + + if not any(isinstance(a, v0_5.BatchAxis) for ipt in model.inputs for a in ipt.axes): + # no batch axis => set b=1 + test_cases = {(n, 1) for n, _b in test_cases} + + def generate_test_cases(): + tested: Set[Hashable] = set() + + def get_ns(n: int): + return { + (t.id, a.id): n + for t in model.inputs + for a in t.axes + if isinstance(a.size, v0_5.ParameterizedSize) + } + + for n, batch_size in sorted(test_cases): + input_target_sizes, expected_output_sizes = model.get_axis_sizes( + get_ns(n), batch_size=batch_size + ) + hashable_target_size = tuple( + (k, input_target_sizes[k]) for k in sorted(input_target_sizes) + ) + if hashable_target_size in tested: + continue + else: + tested.add(hashable_target_size) + + resized_test_inputs = Sample( + members={ + t.id: test_inputs.members[t.id].resize_to( + { + aid: s + for (tid, aid), s in input_target_sizes.items() + if tid == t.id + }, + ) + for t in model.inputs + }, + stat=test_inputs.stat, + id=test_inputs.id, + ) 
+ expected_output_shapes = { + t.id: { + aid: s + for (tid, aid), s in expected_output_sizes.items() + if tid == t.id + } + for t in model.outputs + } + yield n, batch_size, resized_test_inputs, expected_output_shapes + + try: + test_inputs = get_test_inputs(model) + + with create_prediction_pipeline( + bioimageio_model=model, devices=devices, weight_format=weight_format + ) as prediction_pipeline: + for n, batch_size, inputs, exptected_output_shape in generate_test_cases(): + error: Optional[str] = None + result = prediction_pipeline.predict_sample_with_blocking(inputs) + if len(result.members) != len(exptected_output_shape): + error = ( + f"Expected {len(exptected_output_shape)} outputs," + + f" but got {len(result.members)}" + ) + + else: + for m, exp in exptected_output_shape.items(): + res = result.members.get(m) + if res is None: + error = "Output tensors may not be None for test case" + break + + diff: Dict[AxisId, int] = {} + for a, s in res.sizes.items(): + if isinstance((e_aid := exp[AxisId(a)]), int): + if s != e_aid: + diff[AxisId(a)] = s + elif ( + s < e_aid.min or e_aid.max is not None and s > e_aid.max + ): + diff[AxisId(a)] = s + if diff: + error = ( + f"(n={n}) Expected output shape {exp}," + + f" but got {res.sizes} (diff: {diff})" + ) + break + + model.validation_summary.add_detail( + ValidationDetail( + name="Run inference for inputs with batch_size:" + + f" {batch_size} and size parameter n: {n}", + status="passed" if error is None else "failed", + errors=( + [] + if error is None + else [ + ErrorEntry( + loc=( + ("weights",) + if weight_format is None + else ("weights", weight_format) + ), + msg=error, + type="bioimageio.core", + ) + ] + ), + ) + ) + except Exception as e: + error = str(e) + tb = traceback.format_tb(e.__traceback__) + model.validation_summary.add_detail( + ValidationDetail( + name="Run inference for parametrized inputs", + status="failed", + errors=[ + ErrorEntry( + loc=( + ("weights",) + if weight_format is None + else 
("weights", weight_format) + ), + msg=error, + type="bioimageio.core", + traceback=tb, + ) + ], + ) + ) + + +def _test_expected_resource_type( + rd: Union[InvalidDescr, ResourceDescr], expected_type: str +): + has_expected_type = rd.type == expected_type + rd.validation_summary.details.append( + ValidationDetail( + name="Has expected resource type", + status="passed" if has_expected_type else "failed", + errors=( + [] + if has_expected_type + else [ + ErrorEntry( + loc=("type",), + type="type", + msg=f"expected type {expected_type}, found {rd.type}", + ) + ] + ), + ) + ) + + +# def debug_model( +# model_rdf: Union[RawResourceDescr, ResourceDescr, URI, Path, str], +# *, +# weight_format: Optional[WeightsFormat] = None, +# devices: Optional[List[str]] = None, +# ): +# """Run the model test and return dict with inputs, results, expected results and intermediates. + +# Returns dict with tensors "inputs", "inputs_processed", "outputs_raw", "outputs", "expected" and "diff". +# """ +# inputs_raw: Optional = None +# inputs_processed: Optional = None +# outputs_raw: Optional = None +# outputs: Optional = None +# expected: Optional = None +# diff: Optional = None + +# model = load_description( +# model_rdf, weights_priority_order=None if weight_format is None else [weight_format] +# ) +# if not isinstance(model, Model): +# raise ValueError(f"Not a bioimageio.model: {model_rdf}") + +# prediction_pipeline = create_prediction_pipeline( +# bioimageio_model=model, devices=devices, weight_format=weight_format +# ) +# inputs = [ +# xr.DataArray(load_array(str(in_path)), dims=input_spec.axes) +# for in_path, input_spec in zip(model.test_inputs, model.inputs) +# ] +# input_dict = {input_spec.name: input for input_spec, input in zip(model.inputs, inputs)} + +# # keep track of the non-processed inputs +# inputs_raw = [deepcopy(input) for input in inputs] + +# computed_measures = {} + +# prediction_pipeline.apply_preprocessing(input_dict, computed_measures) +# inputs_processed = 
list(input_dict.values()) +# outputs_raw = prediction_pipeline.predict(*inputs_processed) +# output_dict = {output_spec.name: deepcopy(output) for output_spec, output in zip(model.outputs, outputs_raw)} +# prediction_pipeline.apply_postprocessing(output_dict, computed_measures) +# outputs = list(output_dict.values()) + +# if isinstance(outputs, (np.ndarray, xr.DataArray)): +# outputs = [outputs] + +# expected = [ +# xr.DataArray(load_array(str(out_path)), dims=output_spec.axes) +# for out_path, output_spec in zip(model.test_outputs, model.outputs) +# ] +# if len(outputs) != len(expected): +# error = f"Number of outputs and number of expected outputs disagree: {len(outputs)} != {len(expected)}" +# print(error) +# else: +# diff = [] +# for res, exp in zip(outputs, expected): +# diff.append(res - exp) + +# return { +# "inputs": inputs_raw, +# "inputs_processed": inputs_processed, +# "outputs_raw": outputs_raw, +# "outputs": outputs, +# "expected": expected, +# "diff": diff, +# } diff --git a/bioimageio/core/_settings.py b/bioimageio/core/_settings.py new file mode 100644 index 00000000..d09f3b8b --- /dev/null +++ b/bioimageio/core/_settings.py @@ -0,0 +1,20 @@ +from typing import Literal + +from dotenv import load_dotenv +from pydantic import Field +from typing_extensions import Annotated + +from bioimageio.spec._internal._settings import Settings as SpecSettings + +_ = load_dotenv() + + +class Settings(SpecSettings): + """environment variables""" + + keras_backend: Annotated[ + Literal["torch", "tensorflow", "jax"], Field(alias="KERAS_BACKEND") + ] = "torch" + + +settings = Settings() diff --git a/bioimageio/core/axis.py b/bioimageio/core/axis.py new file mode 100644 index 00000000..033b68d7 --- /dev/null +++ b/bioimageio/core/axis.py @@ -0,0 +1,112 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Literal, Mapping, Optional, TypeVar, Union + +from typing_extensions import assert_never + +from bioimageio.spec.model import 
v0_5 + + +def _get_axis_type(a: Literal["b", "t", "i", "c", "x", "y", "z"]): + if a == "b": + return "batch" + elif a == "t": + return "time" + elif a == "i": + return "index" + elif a == "c": + return "channel" + elif a in ("x", "y", "z"): + return "space" + else: + return "index" # return most unspecific axis + + +S = TypeVar("S", bound=str) + + +def _get_axis_id(a: Union[Literal["b", "t", "i", "c"], S]): + if a == "b": + return AxisId("batch") + elif a == "t": + return AxisId("time") + elif a == "i": + return AxisId("index") + elif a == "c": + return AxisId("channel") + else: + return AxisId(a) + + +AxisId = v0_5.AxisId + +T = TypeVar("T") +PerAxis = Mapping[AxisId, T] + +BatchSize = int + +AxisLetter = Literal["b", "i", "t", "c", "z", "y", "x"] +AxisLike = Union[AxisLetter, v0_5.AnyAxis, "Axis"] + + +@dataclass +class Axis: + id: AxisId + type: Literal["batch", "channel", "index", "space", "time"] + + @classmethod + def create(cls, axis: AxisLike) -> Axis: + if isinstance(axis, cls): + return axis + elif isinstance(axis, Axis): + return Axis(id=axis.id, type=axis.type) + elif isinstance(axis, str): + return Axis(id=_get_axis_id(axis), type=_get_axis_type(axis)) + elif isinstance(axis, v0_5.AxisBase): + return Axis(id=AxisId(axis.id), type=axis.type) + else: + assert_never(axis) + + +@dataclass +class AxisInfo(Axis): + maybe_singleton: bool + + @classmethod + def create(cls, axis: AxisLike, maybe_singleton: Optional[bool] = None) -> AxisInfo: + if isinstance(axis, AxisInfo): + return axis + + axis_base = super().create(axis) + if maybe_singleton is None: + if isinstance(axis, Axis): + maybe_singleton = False + elif isinstance(axis, str): + maybe_singleton = axis == "b" + else: + if axis.size is None: + maybe_singleton = True + elif isinstance(axis.size, int): + maybe_singleton = axis.size == 1 + elif isinstance(axis.size, v0_5.SizeReference): + maybe_singleton = ( + False # TODO: check if singleton is ok for a `SizeReference` + ) + elif isinstance( + axis.size, 
(v0_5.ParameterizedSize, v0_5.DataDependentSize) + ): + try: + maybe_size_one = axis.size.validate_size( + 1 + ) # TODO: refactor validate_size() to have boolean func here + except ValueError: + maybe_singleton = False + else: + maybe_singleton = maybe_size_one == 1 + else: + assert_never(axis.size) + + return AxisInfo( + id=axis_base.id, type=axis_base.type, maybe_singleton=maybe_singleton + ) diff --git a/bioimageio/core/block.py b/bioimageio/core/block.py new file mode 100644 index 00000000..3355f9b1 --- /dev/null +++ b/bioimageio/core/block.py @@ -0,0 +1,86 @@ +from dataclasses import dataclass +from typing import ( + Any, + Generator, + Iterable, + Optional, + Tuple, + Union, +) + +from typing_extensions import Self + +from .axis import PerAxis +from .block_meta import BlockMeta, LinearAxisTransform, split_shape_into_blocks +from .common import ( + Halo, + HaloLike, + PadMode, + TotalNumberOfBlocks, +) +from .tensor import Tensor + + +@dataclass(frozen=True) +class Block(BlockMeta): + """A block/tile of a (larger) tensor""" + + data: Tensor + """the block's tensor, e.g. 
a (padded) slice of some larger, original tensor""" + + @property + def inner_data(self): + return self.data[self.local_slice] + + def __post_init__(self): + super().__post_init__() + for a, s in self.data.sizes.items(): + slice_ = self.inner_slice[a] + halo = self.halo.get(a, Halo(0, 0)) + assert s == halo.left + (slice_.stop - slice_.start) + halo.right, ( + s, + slice_, + halo, + ) + + @classmethod + def from_sample_member( + cls, + sample_member: Tensor, + block: BlockMeta, + *, + pad_mode: PadMode, + ) -> Self: + return cls( + data=sample_member[block.outer_slice].pad(block.padding, pad_mode), + sample_shape=sample_member.tagged_shape, + inner_slice=block.inner_slice, + halo=block.halo, + block_index=block.block_index, + blocks_in_sample=block.blocks_in_sample, + ) + + def get_transformed( + self, new_axes: PerAxis[Union[LinearAxisTransform, int]] + ) -> Self: + raise NotImplementedError + + +def split_tensor_into_blocks( + tensor: Tensor, + block_shape: PerAxis[int], + *, + halo: PerAxis[HaloLike], + stride: Optional[PerAxis[int]] = None, + pad_mode: PadMode, +) -> Tuple[TotalNumberOfBlocks, Generator[Block, Any, None]]: + """divide a sample tensor into tensor blocks.""" + n_blocks, block_gen = split_shape_into_blocks( + tensor.tagged_shape, block_shape=block_shape, halo=halo, stride=stride + ) + return n_blocks, _block_generator(tensor, block_gen, pad_mode=pad_mode) + + +def _block_generator(sample: Tensor, blocks: Iterable[BlockMeta], *, pad_mode: PadMode): + for block in blocks: + yield Block.from_sample_member(sample, block, pad_mode=pad_mode) diff --git a/bioimageio/core/block_meta.py b/bioimageio/core/block_meta.py new file mode 100644 index 00000000..0fe5b6c5 --- /dev/null +++ b/bioimageio/core/block_meta.py @@ -0,0 +1,387 @@ +import itertools +from dataclasses import dataclass +from functools import cached_property +from math import floor, prod +from typing import ( + Any, + Callable, + Collection, + Dict, + Generator, + Iterable, + List, + Optional, 
+ Tuple, + Union, +) + +from loguru import logger +from typing_extensions import Self + +from .axis import AxisId, PerAxis +from .common import ( + BlockIndex, + Frozen, + Halo, + HaloLike, + MemberId, + PadWidth, + PerMember, + SliceInfo, + TotalNumberOfBlocks, +) + + +@dataclass +class LinearAxisTransform: + axis: AxisId + scale: float + offset: int + + def compute(self, s: int, round: Callable[[float], int] = floor) -> int: + return round(s * self.scale) + self.offset + + +@dataclass(frozen=True) +class BlockMeta: + """Block meta data of a sample member (a tensor in a sample) + + Figure for illustration: + The first 2d block (dashed) of a sample member (**bold**). + The inner slice (thin) is expanded by a halo in both dimensions on both sides. + The outer slice reaches from the sample member origin (0, 0) to the right halo point. + + ```terminal + ┌ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ┐ + ╷ halo(left) ╷ + ╷ ╷ + ╷ (0, 0)┏━━━━━━━━━━━━━━━━━┯━━━━━━━━━┯━━━➔ + ╷ ┃ │ ╷ sample member + ╷ ┃ inner │ ╷ + ╷ ┃ (and outer) │ outer ╷ + ╷ ┃ slice │ slice ╷ + ╷ ┃ │ ╷ + ╷ ┣─────────────────┘ ╷ + ╷ ┃ outer slice ╷ + ╷ ┃ halo(right) ╷ + └ ─ ─ ─ ─┃─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─ ─┘ + ⬇ + ``` + + note: + - Inner and outer slices are specified in sample member coordinates. + - The outer_slice of a block at the sample edge may overlap by more than the + halo with the neighboring block (the inner slices will not overlap though). 
+ + """ + + sample_shape: PerAxis[int] + """the axis sizes of the whole (unblocked) sample""" + + inner_slice: PerAxis[SliceInfo] + """inner region (without halo) wrt the sample""" + + halo: PerAxis[Halo] + """halo enlarging the inner region to the block's sizes""" + + block_index: BlockIndex + """the i-th block of the sample""" + + blocks_in_sample: TotalNumberOfBlocks + """total number of blocks in the sample""" + + @cached_property + def shape(self) -> PerAxis[int]: + """axis lengths of the block""" + return Frozen( + { + a: s.stop - s.start + (sum(self.halo[a]) if a in self.halo else 0) + for a, s in self.inner_slice.items() + } + ) + + @cached_property + def padding(self) -> PerAxis[PadWidth]: + """padding to realize the halo at the sample edge + where we cannot simply enlarge the inner slice""" + return Frozen( + { + a: PadWidth( + ( + self.halo[a].left + - (self.inner_slice[a].start - self.outer_slice[a].start) + if a in self.halo + else 0 + ), + ( + self.halo[a].right + - (self.outer_slice[a].stop - self.inner_slice[a].stop) + if a in self.halo + else 0 + ), + ) + for a in self.inner_slice + } + ) + + @cached_property + def outer_slice(self) -> PerAxis[SliceInfo]: + """slice of the outer block (without padding) wrt the sample""" + return Frozen( + { + a: SliceInfo( + max( + 0, + min( + self.inner_slice[a].start + - (self.halo[a].left if a in self.halo else 0), + self.sample_shape[a] + - self.inner_shape[a] + - (self.halo[a].left if a in self.halo else 0), + ), + ), + min( + self.sample_shape[a], + self.inner_slice[a].stop + + (self.halo[a].right if a in self.halo else 0), + ), + ) + for a in self.inner_slice + } + ) + + @cached_property + def inner_shape(self) -> PerAxis[int]: + """axis lengths of the inner region (without halo)""" + return Frozen({a: s.stop - s.start for a, s in self.inner_slice.items()}) + + @cached_property + def local_slice(self) -> PerAxis[SliceInfo]: + """inner slice wrt the block, **not** the sample""" + return Frozen( + { + a: 
SliceInfo( + self.padding[a].left, + self.padding[a].left + self.inner_shape[a], + ) + for a in self.inner_slice + } + ) + + @property + def dims(self) -> Collection[AxisId]: + return set(self.inner_shape) + + @property + def tagged_shape(self) -> PerAxis[int]: + """alias for shape""" + return self.shape + + @property + def inner_slice_wo_overlap(self): + """subslice of the inner slice, such that all `inner_slice_wo_overlap` can be + stiched together trivially to form the original sample. + + This can also be used to calculate statistics + without overrepresenting block edge regions.""" + # TODO: update inner_slice_wo_overlap when adding block overlap + return self.inner_slice + + def __post_init__(self): + # freeze mutable inputs + if not isinstance(self.sample_shape, Frozen): + object.__setattr__(self, "sample_shape", Frozen(self.sample_shape)) + + if not isinstance(self.inner_slice, Frozen): + object.__setattr__(self, "inner_slice", Frozen(self.inner_slice)) + + if not isinstance(self.halo, Frozen): + object.__setattr__(self, "halo", Frozen(self.halo)) + + assert all( + a in self.sample_shape for a in self.inner_slice + ), "block has axes not present in sample" + + assert all( + a in self.inner_slice for a in self.halo + ), "halo has axes not present in block" + + if any(s > self.sample_shape[a] for a, s in self.shape.items()): + logger.warning( + "block {} larger than sample {}", self.shape, self.sample_shape + ) + + def get_transformed( + self, new_axes: PerAxis[Union[LinearAxisTransform, int]] + ) -> Self: + return self.__class__( + sample_shape={ + a: ( + trf + if isinstance(trf, int) + else trf.compute(self.sample_shape[trf.axis]) + ) + for a, trf in new_axes.items() + }, + inner_slice={ + a: ( + SliceInfo(0, trf) + if isinstance(trf, int) + else SliceInfo( + trf.compute(self.inner_slice[trf.axis].start), + trf.compute(self.inner_slice[trf.axis].stop), + ) + ) + for a, trf in new_axes.items() + }, + halo={ + a: ( + Halo(0, 0) + if isinstance(trf, int) + 
else Halo(self.halo[trf.axis].left, self.halo[trf.axis].right) + ) + for a, trf in new_axes.items() + }, + block_index=self.block_index, + blocks_in_sample=self.blocks_in_sample, + ) + + +def split_shape_into_blocks( + shape: PerAxis[int], + block_shape: PerAxis[int], + halo: PerAxis[HaloLike], + stride: Optional[PerAxis[int]] = None, +) -> Tuple[TotalNumberOfBlocks, Generator[BlockMeta, Any, None]]: + assert all(a in shape for a in block_shape), ( + tuple(shape), + set(block_shape), + ) + if any(shape[a] < block_shape[a] for a in block_shape): + raise ValueError(f"shape {shape} is smaller than block shape {block_shape}") + + assert all(a in shape for a in halo), (tuple(shape), set(halo)) + + # fill in default halo (0) and block axis length (from tensor shape) + halo = {a: Halo.create(halo.get(a, 0)) for a in shape} + block_shape = {a: block_shape.get(a, s) for a, s in shape.items()} + if stride is None: + stride = {} + + inner_1d_slices: Dict[AxisId, List[SliceInfo]] = {} + for a, s in shape.items(): + inner_size = block_shape[a] - sum(halo[a]) + stride_1d = stride.get(a, inner_size) + inner_1d_slices[a] = [ + SliceInfo(min(p, s - inner_size), min(p + inner_size, s)) + for p in range(0, s, stride_1d) + ] + + n_blocks = prod(map(len, inner_1d_slices.values())) + + return n_blocks, _block_meta_generator( + shape, + blocks_in_sample=n_blocks, + inner_1d_slices=inner_1d_slices, + halo=halo, + ) + + +def _block_meta_generator( + sample_shape: PerAxis[int], + *, + blocks_in_sample: int, + inner_1d_slices: Dict[AxisId, List[SliceInfo]], + halo: PerAxis[HaloLike], +): + assert all(a in sample_shape for a in halo) + + halo = {a: Halo.create(halo.get(a, 0)) for a in inner_1d_slices} + for i, nd_tile in enumerate(itertools.product(*inner_1d_slices.values())): + inner_slice: PerAxis[SliceInfo] = dict(zip(inner_1d_slices, nd_tile)) + + yield BlockMeta( + sample_shape=sample_shape, + inner_slice=inner_slice, + halo=halo, + block_index=i, + blocks_in_sample=blocks_in_sample, + ) 
+ + +def split_multiple_shapes_into_blocks( + shapes: PerMember[PerAxis[int]], + block_shapes: PerMember[PerAxis[int]], + *, + halo: PerMember[PerAxis[HaloLike]], + strides: Optional[PerMember[PerAxis[int]]] = None, + broadcast: bool = False, +) -> Tuple[TotalNumberOfBlocks, Iterable[PerMember[BlockMeta]]]: + assert not ( + missing := [t for t in block_shapes if t not in shapes] + ), f"block shape specified for unknown tensors: {missing}" + if not block_shapes: + block_shapes = shapes + + assert broadcast or not ( + missing := [t for t in shapes if t not in block_shapes] + ), f"no block shape specified for {missing} (set `broadcast` to True if these tensors should be repeated for each block)" + assert not ( + missing := [t for t in halo if t not in block_shapes] + ), f"`halo` specified for tensors without block shape: {missing}" + + if strides is None: + strides = {} + + assert not ( + missing := [t for t in strides if t not in block_shapes] + ), f"`stride` specified for tensors without block shape: {missing}" + + blocks: Dict[MemberId, Iterable[BlockMeta]] = {} + n_blocks: Dict[MemberId, TotalNumberOfBlocks] = {} + for t in block_shapes: + n_blocks[t], blocks[t] = split_shape_into_blocks( + shape=shapes[t], + block_shape=block_shapes[t], + halo=halo.get(t, {}), + stride=strides.get(t), + ) + assert n_blocks[t] > 0 + + assert len(blocks) > 0, blocks + assert len(n_blocks) > 0, n_blocks + unique_n_blocks = set(n_blocks.values()) + n = max(unique_n_blocks) + if len(unique_n_blocks) == 2 and 1 in unique_n_blocks: + if not broadcast: + raise ValueError( + f"Mismatch for total number of blocks due to unsplit (single block) tensors: {n_blocks}." + + " Set `broadcast` to True if you want to repeat unsplit (single block) tensors." 
+ ) + + blocks = { + t: _repeat_single_block(block_gen, n) if n_blocks[t] == 1 else block_gen + for t, block_gen in blocks.items() + } + elif len(unique_n_blocks) != 1: + raise ValueError(f"Mismatch for total number of blocks: {n_blocks}") + + return n, _aligned_blocks_generator(n, blocks) + + +def _aligned_blocks_generator( + n: TotalNumberOfBlocks, blocks: Dict[MemberId, Iterable[BlockMeta]] +): + iterators = {t: iter(gen) for t, gen in blocks.items()} + for _ in range(n): + yield {t: next(it) for t, it in iterators.items()} + + +def _repeat_single_block(block_generator: Iterable[BlockMeta], n: TotalNumberOfBlocks): + round_two = False + for block in block_generator: + assert not round_two + for _ in range(n): + yield block + + round_two = True diff --git a/bioimageio/core/build_spec/__init__.py b/bioimageio/core/build_spec/__init__.py deleted file mode 100644 index c11615df..00000000 --- a/bioimageio/core/build_spec/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .add_weights import add_weights -from .build_model import build_model diff --git a/bioimageio/core/build_spec/add_weights.py b/bioimageio/core/build_spec/add_weights.py deleted file mode 100644 index 0e4e7949..00000000 --- a/bioimageio/core/build_spec/add_weights.py +++ /dev/null @@ -1,82 +0,0 @@ -import os -from pathlib import Path -from shutil import copyfile -from typing import Dict, Optional, Union, List - -from bioimageio.core import export_resource_package, load_raw_resource_description -from bioimageio.spec.shared.raw_nodes import ResourceDescription as RawResourceDescription -from .build_model import _get_weights - - -def add_weights( - model: Union[RawResourceDescription, os.PathLike, str], - weight_uri: Union[str, Path], - output_path: Union[str, Path], - *, - weight_type: Optional[str] = None, - architecture: Optional[str] = None, - model_kwargs: Optional[Dict[str, Union[int, float, str]]] = None, - tensorflow_version: Optional[str] = None, - opset_version: Optional[str] = None, - 
pytorch_version: Optional[str] = None, - attachments: Optional[Dict[str, Union[str, List[str]]]] = None, -): - """Add weight entry to bioimage.io model. - - Args: - model: the resource description of the model to which the weight format is added - weight_uri: the weight file to be added - output_path: where to serialize the new model with additional weight format - weight_type: the format of the weights to be added - architecture: the file with the source code for the model architecture and the corresponding class. - Only required for models with pytorch_state_dict weight format. - model_kwargs: the keyword arguments for the model class. - Only required for models with pytorch_state_dict weight format. - tensorflow_version: the tensorflow version for this model. Only for tensorflow or keras weights. - opset_version: the opset version for this model. Only for onnx weights. - pytorch_version: the pytorch version for this model. Only for pytoch_state_dict or torchscript weights. - attachments: extra weight specific attachments. 
- """ - model = load_raw_resource_description(model) - if not isinstance(model.root_path, Path): - # ensure model is available locally - model = load_raw_resource_description(export_resource_package(model)) - - assert isinstance(model.root_path, Path), model.root_path - - # copy the weight path to the input model's root, otherwise it will - # not be found when packaging the new model - weight_out = os.path.join(model.root_path, Path(weight_uri).name) - if Path(weight_out).absolute() != Path(weight_uri).absolute(): - copyfile(weight_uri, weight_out) - - new_weights, tmp_arch = _get_weights( - weight_out, - weight_type, - root=Path("."), - architecture=architecture, - model_kwargs=model_kwargs, - tensorflow_version=tensorflow_version, - opset_version=opset_version, - pytorch_version=pytorch_version, - attachments=attachments, - ) - model.weights.update(new_weights) - - try: - model_package = export_resource_package(model, output_path=output_path) - model = load_raw_resource_description(model_package) - except Exception as e: - raise e - finally: - # clean up tmp files - if Path(weight_out).absolute() != Path(weight_uri).absolute(): - os.remove(weight_out) - if tmp_arch is not None: - os.remove(tmp_arch) - # for some reason the weights are also copied to the cwd. 
- # not sure why this happens, but it needs to be cleaned up, unless these are the input weigths - weights_cwd = Path(os.path.split(weight_uri)[1]) - if weights_cwd.exists() and weights_cwd.absolute() != Path(weight_uri).absolute(): - os.remove(weights_cwd) - return model diff --git a/bioimageio/core/build_spec/build_model.py b/bioimageio/core/build_spec/build_model.py deleted file mode 100644 index 59a72994..00000000 --- a/bioimageio/core/build_spec/build_model.py +++ /dev/null @@ -1,931 +0,0 @@ -import datetime -import hashlib -import os -from pathlib import Path -from typing import Any, Dict, List, Optional, Union -from warnings import warn - -import imageio -import numpy as np -import requests -import tifffile - -import bioimageio.spec as spec -import bioimageio.spec.model as model_spec -from bioimageio.core import export_resource_package, load_raw_resource_description -from bioimageio.core.resource_io.nodes import URI -from bioimageio.spec.shared.raw_nodes import ImportableModule, ImportableSourceFile -from bioimageio.spec.shared import resolve_local_source, resolve_source - -try: - from typing import get_args -except ImportError: - from typing_extensions import get_args # type: ignore - - -# -# utility functions to build the spec from python -# - - -def _get_hash(path): - with open(path, "rb") as f: - data = f.read() - return hashlib.sha256(data).hexdigest() - - -def _infer_weight_type(path): - ext = os.path.splitext(path)[-1] - if ext in (".pt", ".pth", ".torch"): - return "pytorch_state_dict" - elif ext == ".onnx": - return "onnx" - elif ext in (".hdf", ".hdf5", ".h5"): - return "keras_hdf5" - elif ext == ".zip": - return "tensorflow_saved_model_bundle" - elif ext == ".json": - return "tensorflow_js" - else: - raise ValueError(f"Could not infer weight type from extension {ext} for weight file {path}") - - -def _get_pytorch_state_dict_weight_kwargs(architecture, model_kwargs, root): - assert architecture is not None - tmp_archtecture = None - weight_kwargs = 
{"kwargs": model_kwargs} if model_kwargs else {} - if ":" in architecture: - # note: path itself might include : for absolute paths in windows - *arch_file_parts, callable_name = architecture.replace("::", ":").split(":") - arch_file = _ensure_local(":".join(arch_file_parts), root) - arch = ImportableSourceFile(callable_name, arch_file) - arch_hash = _get_hash(root / arch.source_file) - weight_kwargs["architecture_sha256"] = arch_hash - else: - arch = spec.shared.fields.ImportableSource().deserialize(architecture) - assert isinstance(arch, ImportableModule) - - weight_kwargs["architecture"] = arch - return weight_kwargs, tmp_archtecture - - -def _get_attachments(attachments, root): - assert isinstance(attachments, dict) - if "files" in attachments: - afiles = attachments["files"] - if isinstance(afiles, str): - afiles = [afiles] - - if isinstance(afiles, list): - afiles = _ensure_local_or_url(afiles, root) - else: - raise TypeError(attachments) - - attachments["files"] = afiles - return attachments - - -def _get_weights( - original_weight_source, - weight_type, - root, - architecture=None, - model_kwargs=None, - tensorflow_version=None, - opset_version=None, - pytorch_version=None, - dependencies=None, - attachments=None, -): - weight_path = resolve_source(original_weight_source, root) - if weight_type is None: - weight_type = _infer_weight_type(weight_path) - weight_hash = _get_hash(weight_path) - - weight_types = model_spec.raw_nodes.WeightsFormat - weight_source = _ensure_local_or_url(original_weight_source, root) - - weight_kwargs = {"source": weight_source, "sha256": weight_hash} - if attachments is not None: - weight_kwargs["attachments"] = _get_attachments(attachments, root) - if dependencies is not None: - weight_kwargs["dependencies"] = _get_dependencies(dependencies, root) - - tmp_archtecture = None - if weight_type == "pytorch_state_dict": - # pytorch-state-dict -> we need an architecture definition - pytorch_weight_kwargs, tmp_file = 
_get_pytorch_state_dict_weight_kwargs(architecture, model_kwargs, root) - weight_kwargs.update(**pytorch_weight_kwargs) - if pytorch_version is not None: - weight_kwargs["pytorch_version"] = pytorch_version - elif dependencies is None: - warn( - "You are building a pytorch model but have neither passed dependencies nor the pytorch_version." - "It may not be possible to create an environmnet where your model can be used." - ) - weights = model_spec.raw_nodes.PytorchStateDictWeightsEntry(**weight_kwargs) - - elif weight_type == "onnx": - if opset_version is not None: - weight_kwargs["opset_version"] = opset_version - elif dependencies is None: - warn( - "You are building an onnx model but have neither passed dependencies nor the opset_version." - "It may not be possible to create an environmnet where your model can be used." - ) - weights = model_spec.raw_nodes.OnnxWeightsEntry(**weight_kwargs) - - elif weight_type == "torchscript": - if pytorch_version is not None: - weight_kwargs["pytorch_version"] = pytorch_version - elif dependencies is None: - warn( - "You are building a pytorch model but have neither passed dependencies nor the pytorch_version." - "It may not be possible to create an environmnet where your model can be used." - ) - weights = model_spec.raw_nodes.TorchscriptWeightsEntry(**weight_kwargs) - - elif weight_type == "keras_hdf5": - if tensorflow_version is not None: - weight_kwargs["tensorflow_version"] = tensorflow_version - elif dependencies is None: - warn( - "You are building a keras model but have neither passed dependencies nor the tensorflow_version." - "It may not be possible to create an environmnet where your model can be used." 
- ) - weights = model_spec.raw_nodes.KerasHdf5WeightsEntry(**weight_kwargs) - - elif weight_type == "tensorflow_saved_model_bundle": - if tensorflow_version is not None: - weight_kwargs["tensorflow_version"] = tensorflow_version - elif dependencies is None: - warn( - "You are building a tensorflow model but have neither passed dependencies nor the tensorflow_version." - "It may not be possible to create an environmnet where your model can be used." - ) - weights = model_spec.raw_nodes.TensorflowSavedModelBundleWeightsEntry(**weight_kwargs) - - elif weight_type == "tensorflow_js": - if tensorflow_version is not None: - weight_kwargs["tensorflow_version"] = tensorflow_version - elif dependencies is None: - warn( - "You are building a tensorflow model but have neither passed dependencies nor the tensorflow_version." - "It may not be possible to create an environmnet where your model can be used." - ) - weights = model_spec.raw_nodes.TensorflowJsWeightsEntry(**weight_kwargs) - - elif weight_type in weight_types: - raise ValueError(f"Weight type {weight_type} is not supported yet in 'build_spec'") - else: - raise ValueError(f"Invalid weight type {weight_type}, expect one of {weight_types}") - - return {weight_type: weights}, tmp_archtecture - - -def _get_data_range(data_range, dtype): - if data_range is None: - if np.issubdtype(np.dtype(dtype), np.integer): - min_, max_ = np.iinfo(dtype).min, np.iinfo(dtype).max - # for floating point numbers we assume valid range from -inf to inf - elif np.issubdtype(np.dtype(dtype), np.floating): - min_, max_ = -np.inf, np.inf - elif np.issubdtype(np.dtype(dtype), np.bool): - min_, max_ = 0, 1 - else: - raise RuntimeError(f"Cannot derived data range for dtype {dtype}") - data_range = (min_, max_) - assert isinstance(data_range, (tuple, list)), type(data_range) - assert len(data_range) == 2 - return data_range - - -def _get_input_tensor(path, name, step, min_shape, data_range, axes, preprocessing): - test_in = np.load(path) - shape = 
test_in.shape - if step is None: - assert min_shape is None - shape_description = shape - else: - shape_description = {"min": shape if min_shape is None else min_shape, "step": step} - - data_range = _get_data_range(data_range, test_in.dtype) - kwargs = {} - if preprocessing is not None: - kwargs["preprocessing"] = preprocessing - - inputs = model_spec.raw_nodes.InputTensor( - name="input" if name is None else name, - data_type=str(test_in.dtype), - axes=axes, - shape=shape_description, - data_range=data_range, - **kwargs, - ) - return inputs - - -def _get_output_tensor(path, name, reference_tensor, scale, offset, axes, data_range, postprocessing, halo): - test_out = np.load(path) - shape = test_out.shape - if reference_tensor is None: - assert scale is None - assert offset is None - shape_description = shape - else: - assert scale is not None - assert offset is not None - shape_description = {"reference_tensor": reference_tensor, "scale": scale, "offset": offset} - - data_range = _get_data_range(data_range, test_out.dtype) - kwargs = {} - if postprocessing is not None: - kwargs["postprocessing"] = postprocessing - if halo is not None: - kwargs["halo"] = halo - - outputs = model_spec.raw_nodes.OutputTensor( - name="output" if name is None else name, - data_type=str(test_out.dtype), - axes=axes, - data_range=data_range, - shape=shape_description, - **kwargs, - ) - return outputs - - -def _build_cite(cite: List[Dict[str, str]]): - citation_list = [] - for entry in cite: - if "doi" in entry: - spec_entry = spec.rdf.raw_nodes.CiteEntry(text=entry["text"], doi=entry["doi"]) - elif "url" in entry: - spec_entry = spec.rdf.raw_nodes.CiteEntry(text=entry["text"], url=entry["url"]) - else: - raise ValueError(f"Expect one of doi or url in citation enrty {entry}") - citation_list.append(spec_entry) - return citation_list - - -def _get_dependencies(dependencies, root): - if isinstance(dependencies, Path) or ":" not in dependencies: - manager = "conda" - path = dependencies - 
else: - manager, path = dependencies.split(":") - - return model_spec.raw_nodes.Dependencies(manager=manager, file=_ensure_local(path, root)) - - -def _get_deepimagej_macro(name, kwargs, export_folder): - # macros available in deepimagej - macro_names = ("binarize", "scale_linear", "scale_range", "zero_mean_unit_variance") - if name == "scale_linear": - macro = "scale_linear.ijm" - replace = {"gain": kwargs["gain"], "offset": kwargs["offset"]} - - elif name == "scale_range": - macro = "per_sample_scale_range.ijm" - replace = {"min_precentile": kwargs["min_percentile"], "max_percentile": kwargs["max_percentile"]} - - elif name == "zero_mean_unit_variance": - mode = kwargs["mode"] - if mode == "fixed": - macro = "fixed_zero_mean_unit_variance.ijm" - replace = {"paramMean": kwargs["mean"], "paramStd": kwargs["std"]} - else: - macro = "zero_mean_unit_variance.ijm" - replace = {} - - elif name == "binarize": - macro = "binarize.ijm" - replace = {"optimalThreshold": kwargs["threshold"]} - - else: - raise ValueError(f"Macro {name} is not available, must be one of {macro_names}.") - - url = f"https://raw.githubusercontent.com/deepimagej/imagej-macros/master/bioimage.io/{macro}" - - path = os.path.join(export_folder, macro) - # use https://github.com/bioimage-io/core-bioimage-io-python/blob/main/bioimageio/core/resource_io/utils.py#L267 - # instead if the implementation is update s.t. 
an output path is accepted - with requests.get(url, stream=True) as r: - text = r.text - if text.startswith("4"): - raise RuntimeError(f"An error occured when downloading {url}: {r.text}") - with open(path, "w") as f: - f.write(r.text) - - # replace the kwargs in the macro file - if replace: - lines = [] - with open(path) as f: - for line in f: - kwarg = [kwarg for kwarg in replace if line.startswith(kwarg)] - if kwarg: - assert len(kwarg) == 1 - kwarg = kwarg[0] - # each kwarg should only be replaced ones - val = replace.pop(kwarg) - lines.append(f"{kwarg} = {val};\n") - else: - lines.append(line) - - with open(path, "w") as f: - for line in lines: - f.write(line) - - return {"spec": "ij.IJ::runMacroFile", "kwargs": macro} - - -def _get_deepimagej_config( - export_folder, test_inputs, test_outputs, input_axes, output_axes, pixel_sizes, preprocessing, postprocessing -): - assert len(test_inputs) == len(test_outputs) == 1, "deepimagej config only valid for single input/output" - - if any(preproc is not None for preproc in preprocessing): - assert len(preprocessing) == 1 - preprocess_ij = [ - _get_deepimagej_macro(preproc["name"], preproc["kwargs"], export_folder) for preproc in preprocessing[0] - ] - attachments = [preproc["kwargs"] for preproc in preprocess_ij] - else: - preprocess_ij = [{"spec": None}] - attachments = [] - - if any(postproc is not None for postproc in postprocessing): - assert len(postprocessing) == 1 - postprocess_ij = [ - _get_deepimagej_macro(postproc["name"], postproc["kwargs"], export_folder) for postproc in postprocessing[0] - ] - attachments.extend([postproc["kwargs"] for postproc in postprocess_ij]) - else: - postprocess_ij = [{"spec": None}] - - def get_size(fname, axes): - shape = np.load(export_folder / fname).shape - assert len(shape) == len(axes) - shape = [sh for sh, ax in zip(shape, axes) if ax != "b"] - axes = [ax for ax in axes if ax != "b"] - # the shape for deepij is always given as xyzc - if len(shape) == 3: - axes_ij = "xyc" - 
else: - axes_ij = "xyzc" - assert set(axes) == set(axes_ij) - axis_permutation = [axes_ij.index(ax) for ax in axes] - shape = [shape[permut] for permut in axis_permutation] - if len(shape) == 3: - shape = shape[:2] + [1] + shape[-1:] - assert len(shape) == 4 - return " x ".join(map(str, shape)) - - # deepimagej always expexts a pixel size for the z axis - pixel_sizes_ = [pix_size if "z" in pix_size else dict(z=1.0, **pix_size) for pix_size in pixel_sizes] - - test_info = { - "inputs": [ - {"name": in_path, "size": get_size(in_path, axes), "pixel_size": pix_size} - for in_path, axes, pix_size in zip(test_inputs, input_axes, pixel_sizes_) - ], - "outputs": [ - {"name": out_path, "type": "image", "size": get_size(out_path, axes)} - for out_path, axes in zip(test_outputs, output_axes) - ], - "memory_peak": None, - "runtime": None, - } - - config = { - "prediction": {"preprocess": preprocess_ij, "postprocess": postprocess_ij}, - "test_information": test_info, - # other stuff deepimagej needs - "pyramidal_model": False, - "allow_tiling": True, - "model_keys": None, - } - return {"deepimagej": config}, [Path(a) for a in attachments] - - -def _write_sample_data(input_paths, output_paths, input_axes, output_axes, pixel_sizes, export_folder: Path): - def write_im(path, im, axes, pixel_size=None): - assert len(axes) == im.ndim, f"{len(axes), {im.ndim}}" - assert im.ndim in (4, 5), f"{im.ndim}" - - # convert the image to expects (Z)CYX axis order - if im.ndim == 4: - assert set(axes) == {"b", "x", "y", "c"}, f"{axes}" - resolution_axes_ij = "cyxb" - else: - assert set(axes) == {"b", "x", "y", "z", "c"}, f"{axes}" - resolution_axes_ij = "bzcyx" - - def addMissingAxes(im_axes): - needed_axes = ["b", "c", "x", "y", "z", "s"] - for ax in needed_axes: - if ax not in im_axes: - im_axes += ax - return im_axes - - axes_ij = "bzcyxs" - # Expand the image to ImageJ dimensions - im = np.expand_dims(im, axis=tuple(range(len(axes), len(axes_ij)))) - - axis_permutation = 
tuple(addMissingAxes(axes).index(ax) for ax in axes_ij) - im = im.transpose(axis_permutation) - - if pixel_size is None: - resolution = None - else: - spatial_axes = list(set(resolution_axes_ij) - set("bc")) - resolution = tuple(1.0 / pixel_size[ax] for ax in resolution_axes_ij if ax in spatial_axes) - # does not work for double - if np.dtype(im.dtype) == np.dtype("float64"): - im = im.astype("float32") - tifffile.imwrite(path, im, imagej=True, resolution=resolution) - - sample_in_paths = [] - for i, (in_path, axes) in enumerate(zip(input_paths, input_axes)): - inp = np.load(export_folder / in_path) - sample_in_path = export_folder / f"sample_input_{i}.tif" - pixel_size = None if pixel_sizes is None else pixel_sizes[i] - write_im(sample_in_path, inp, axes, pixel_size) - sample_in_paths.append(sample_in_path) - - sample_out_paths = [] - for i, (out_path, axes) in enumerate(zip(output_paths, output_axes)): - outp = np.load(export_folder / out_path) - sample_out_path = export_folder / f"sample_output_{i}.tif" - write_im(sample_out_path, outp, axes) - sample_out_paths.append(sample_out_path) - - return [Path(p.name) for p in sample_in_paths], [Path(p.name) for p in sample_out_paths] - - -# create better cover images for 3d data and non-image outputs -def _generate_covers(in_path, out_path, input_axes, output_axes, root): - def normalize(data, axis, eps=1e-7): - data = data.astype("float32") - data -= data.min(axis=axis, keepdims=True) - data /= data.max(axis=axis, keepdims=True) + eps - return data - - def to_image(data, data_axes): - assert data.ndim in (4, 5) - - # transpose the data to "bczyx" / "bcyx" order - axes = "bczyx" if data.ndim == 5 else "bcyx" - assert set(data_axes) == set(axes) - if axes != data_axes: - ax_permutation = tuple(data_axes.index(ax) for ax in axes) - data = data.transpose(ax_permutation) - - # select single image with channels from the data - if data.ndim == 5: - z0 = data.shape[2] // 2 - data = data[0, :, z0] - else: - data = data[0, :] - 
- # normalize the data and map to 8 bit - data = normalize(data, axis=(1, 2)) - data = (data * 255).astype("uint8") - return data - - cover_path = os.path.join(root, "cover.png") - input_, output = np.load(in_path), np.load(out_path) - - input_ = to_image(input_, input_axes) - # this is not image data so we only save the input image - if output.ndim < 4: - imageio.imwrite(cover_path, input_.transpose((1, 2, 0))) - return [_ensure_local(cover_path, root)] - output = to_image(output, output_axes) - - chan_in = input_.shape[0] - # make sure the input is rgb - if chan_in == 1: # single channel -> repeat it 3 times - input_ = np.repeat(input_, 3, axis=0) - elif chan_in != 3: # != 3 channels -> take first channe and repeat it 3 times - input_ = np.repeat(input_[0:1], 3, axis=0) - - im_shape = input_.shape[1:] - # we just save the input image if the shapes don't agree - if im_shape != output.shape[1:]: - imageio.imwrite(cover_path, input_.transpose((1, 2, 0))) - return [_ensure_local(cover_path, root)] - - def diagonal_split(im0, im1): - assert im0.shape[0] == im1.shape[0] == 3 - n, m = im_shape - out = np.ones((3, n, m), dtype="uint8") - for c in range(3): - outc = np.tril(im0[c]) - mask = outc == 0 - outc[mask] = np.triu(im1[c])[mask] - out[c] = outc - return out - - def grid_im(im0, im1): - ims_per_row = 3 - n_chan = im1.shape[0] - n_images = n_chan + 1 - n_rows = int(np.ceil(float(n_images) / ims_per_row)) - - n, m = im_shape - x, y = ims_per_row * n, n_rows * m - out = np.zeros((3, y, x), dtype=im0.dtype) - images = [im0] + [np.repeat(im1[i : i + 1], 3, axis=0) for i in range(n_chan)] - - i, j = 0, 0 - for im in images: - x0, x1 = i * n, (i + 1) * n - y0, y1 = j * m, (j + 1) * m - out[:, y0:y1, x0:x1] = im - - i += 1 - if i == ims_per_row: - i = 0 - j += 1 - - return out - - chan_out = output.shape[0] - if chan_out == 1: # single prediction channel: create diagonal split - im = diagonal_split(input_, np.repeat(output, 3, axis=0)) - elif chan_out == 3: # three 
prediction channel: create diagonal split with rgb - im = diagonal_split(input_, output) - else: # otherwise create grid image - im = grid_im(input_, output) - - # to channel last - imageio.imwrite(cover_path, im.transpose((1, 2, 0))) - return [_ensure_local(cover_path, root)] - - -def _ensure_local(source: Union[Path, URI, str, list], root: Path) -> Union[Path, URI, list]: - """ensure source is local relative path in root""" - if isinstance(source, list): - return [_ensure_local(s, root) for s in source] - - local_source = resolve_source(source, root) - local_source = resolve_source(local_source, root, root / local_source.name) - return local_source.relative_to(root) - - -def _ensure_local_or_url(source: Union[Path, URI, str, list], root: Path) -> Union[Path, URI, list]: - """ensure source is remote URI or local relative path in root""" - if isinstance(source, list): - return [_ensure_local_or_url(s, root) for s in source] - - local_source = resolve_local_source(source, root) - if not isinstance(local_source, URI): - local_source = resolve_local_source(local_source, root, root / local_source.name) - return local_source.relative_to(root) - - -def build_model( - # model or tensor specific and required - weight_uri: str, - test_inputs: List[Union[str, Path]], - test_outputs: List[Union[str, Path]], - input_axes: List[str], - output_axes: List[str], - # general required - name: str, - description: str, - authors: List[Dict[str, str]], - tags: List[Union[str, Path]], - documentation: Union[str, Path], - cite: List[Dict[str, str]], - output_path: Union[str, Path], - # model specific optional - architecture: Optional[str] = None, - model_kwargs: Optional[Dict[str, Union[int, float, str]]] = None, - weight_type: Optional[str] = None, - sample_inputs: Optional[List[str]] = None, - sample_outputs: Optional[List[str]] = None, - # tensor specific - input_names: Optional[List[str]] = None, - input_step: Optional[List[List[int]]] = None, - input_min_shape: 
Optional[List[List[int]]] = None, - input_data_range: Optional[List[List[Union[int, str]]]] = None, - output_names: Optional[List[str]] = None, - output_reference: Optional[List[str]] = None, - output_scale: Optional[List[List[int]]] = None, - output_offset: Optional[List[List[int]]] = None, - output_data_range: Optional[List[List[Union[int, str]]]] = None, - halo: Optional[List[List[int]]] = None, - preprocessing: Optional[List[List[Dict[str, Dict[str, Union[int, float, str]]]]]] = None, - postprocessing: Optional[List[List[Dict[str, Dict[str, Union[int, float, str]]]]]] = None, - pixel_sizes: Optional[List[Dict[str, float]]] = None, - # general optional - maintainers: Optional[List[Dict[str, str]]] = None, - license: Optional[str] = None, - covers: Optional[List[str]] = None, - git_repo: Optional[str] = None, - attachments: Optional[Dict[str, Union[str, List[str]]]] = None, - packaged_by: Optional[List[str]] = None, - run_mode: Optional[str] = None, - parent: Optional[Dict[str, str]] = None, - config: Optional[Dict[str, Any]] = None, - dependencies: Optional[Union[Path, str]] = None, - links: Optional[List[str]] = None, - training_data: Optional[Dict[str, str]] = None, - root: Optional[Union[Path, str]] = None, - add_deepimagej_config: bool = False, - tensorflow_version: Optional[str] = None, - opset_version: Optional[int] = None, - pytorch_version: Optional[str] = None, - weight_attachments: Optional[Dict[str, Union[str, List[str]]]] = None, -): - """Create a zipped bioimage.io model. 
- - Example usage: - ``` - from pathlib import Path - import bioimageio.spec as spec - import bioimageio.core.build_spec as build_spec - model_spec = build_spec.build_model( - weight_uri="test_weights.pt", - test_inputs=["./test_inputs"], - test_outputs=["./test_outputs"], - input_axes=["bcyx"], - output_axes=["bcyx"], - name="my-model", - description="My very fancy model.", - authors=[{"name": "John Doe", "affiliation": "My Institute"}], - tags=["segmentation", "light sheet data"], - license="CC-BY-4.0", - documentation="./documentation.md", - cite=[{"text": "Ronneberger et al. U-Net", "doi": "10.1007/978-3-319-24574-4_28"}], - output_path="my-model.zip" - ) - ``` - - Args: - weight_uri: the url or relative local file path to the weight file for this model. - test_inputs: list of test input files stored in numpy format. - test_outputs: list of test outputs corresponding to test_inputs, stored in numpy format. - input_axes: axis names of the input tensors. - output_axes: axiss names of the output tensors. - name: name of this model. - description: short description of this model. - authors: the authors of this model. - tags: list of tags for this model. - documentation: relative file path to markdown documentation for this model. - cite: references for this model. - output_path: where to save the zipped model package. - architecture: the file with the source code for the model architecture and the corresponding class. - Only required for models with pytorch_state_dict weight format. - model_kwargs: the keyword arguments for the model class. - Only required for models with pytorch_state_dict weight format. - weight_type: the type of the weights. - sample_inputs: list of sample inputs to demonstrate the model performance. - sample_outputs: list of sample outputs corresponding to sample_inputs. - input_names: names of the input tensors. - input_step: minimal valid increase of the input tensor shape. - input_min_shape: minimal input tensor shape. 
- input_data_range: valid data range for the input tensor. - output_names: names of the output tensors. - output_reference: name of the input reference tensor used to cimpute the output tensor shape. - output_scale: multiplicative factor to compute the output tensor shape. - output_offset: additive term to compute the output tensor shape. - output_data_range: valid data range for the output tensor. - halo: halo to be cropped from the output tensor. - preprocessing: list of preprocessing operations for the input. - postprocessing: list of postprocessing operations for the output. - pixel_sizes: the pixel sizes for the input tensors, only for spatial axes. - This information is currently only used by deepimagej, but will be added to the spec soon. - license: the license for this model. By default CC-BY-4.0 will be set as license. - covers: list of file paths for cover images. - By default a cover will be generated from the input and output data. - git_repo: reference git repository for this model. - attachments: list of additional files to package with the model. - packaged_by: list of authors that have packaged this model. - run_mode: custom run mode for this model. - parent: id of the parent model from which this model is derived and sha256 of the corresponding rdf file. - config: custom configuration for this model. - dependencies: relative path to file with dependencies for this model. - training_data: the training data for this model, either id for a bioimageio dataset or a dataset spec. - root: optional root path for relative paths. This can be helpful when building a spec from another model spec. - add_deepimagej_config: add the deepimagej config to the model. - tensorflow_version: the tensorflow version for this model. Only for tensorflow or keras weights. - opset_version: the opset version for this model. Only for onnx weights. - pytorch_version: the pytorch version for this model. Only for pytoch_state_dict or torchscript weights. 
- weight_attachments: extra weight specific attachments. - """ - assert architecture is None or isinstance(architecture, str) - if root is None: - root = "." - root = Path(root) - - if attachments is not None: - attachments = _get_attachments(attachments, root) - - # - # generate the model specific fields - # - - assert len(test_inputs) - assert len(test_outputs) - test_inputs = _ensure_local_or_url(test_inputs, root) - test_outputs = _ensure_local_or_url(test_outputs, root) - - n_inputs = len(test_inputs) - if input_names is None: - input_names = [f"input{i}" for i in range(n_inputs)] - else: - assert len(input_names) == len(test_inputs) - - input_step = n_inputs * [None] if input_step is None else input_step - input_min_shape = n_inputs * [None] if input_min_shape is None else input_min_shape - input_data_range = n_inputs * [None] if input_data_range is None else input_data_range - preprocessing = n_inputs * [None] if preprocessing is None else preprocessing - - inputs = [ - _get_input_tensor(root / test_in, name, step, min_shape, data_range, axes, preproc) - for test_in, name, step, min_shape, axes, data_range, preproc in zip( - test_inputs, input_names, input_step, input_min_shape, input_axes, input_data_range, preprocessing - ) - ] - - n_outputs = len(test_outputs) - if output_names is None: - output_names = [f"output{i}" for i in range(n_outputs)] - else: - assert len(output_names) == len(test_outputs) - - output_reference = n_outputs * [None] if output_reference is None else output_reference - output_scale = n_outputs * [None] if output_scale is None else output_scale - output_offset = n_outputs * [None] if output_offset is None else output_offset - output_data_range = n_outputs * [None] if output_data_range is None else output_data_range - postprocessing = n_outputs * [None] if postprocessing is None else postprocessing - halo = n_outputs * [None] if halo is None else halo - - outputs = [ - _get_output_tensor(root / test_out, name, reference, scale, offset, 
axes, data_range, postproc, hal) - for test_out, name, reference, scale, offset, axes, data_range, postproc, hal in zip( - test_outputs, - output_names, - output_reference, - output_scale, - output_offset, - output_axes, - output_data_range, - postprocessing, - halo, - ) - ] - - # validate the pixel sizes (currently only used by deepimagej) - spatial_axes = [[ax for ax in inp.axes if ax in "xyz"] for inp in inputs] - if pixel_sizes is None: - pixel_sizes = [{ax: 1.0 for ax in axes} for axes in spatial_axes] - else: - assert len(pixel_sizes) == n_inputs - for pix_size, axes in zip(pixel_sizes, spatial_axes): - assert isinstance(pix_size, dict) - assert set(pix_size.keys()) == set(axes) - - # - # generate general fields - # - format_version = get_args(model_spec.raw_nodes.FormatVersion)[-1] - timestamp = datetime.datetime.now() - - authors = [model_spec.raw_nodes.Author(**a) for a in authors] - cite = _build_cite(cite) - documentation = _ensure_local(documentation, root) - if covers is None: - covers = _generate_covers(root / test_inputs[0], root / test_outputs[0], input_axes[0], output_axes[0], root) - else: - covers = _ensure_local(covers, root) - if license is None: - license = "CC-BY-4.0" - - # parse the weights - weights, tmp_archtecture = _get_weights( - weight_uri, - weight_type, - root, - architecture, - model_kwargs, - tensorflow_version=tensorflow_version, - opset_version=opset_version, - pytorch_version=pytorch_version, - dependencies=dependencies, - attachments=weight_attachments, - ) - - # validate the sample inputs and outputs (if given) - if sample_inputs is not None: - assert sample_outputs is not None - assert len(sample_inputs) == n_inputs - assert len(sample_outputs) == n_outputs - - # add the deepimagej config if specified - if add_deepimagej_config: - if sample_inputs is None: - sample_inputs, sample_outputs = _write_sample_data( - test_inputs, test_outputs, input_axes, output_axes, pixel_sizes, root - ) - # deepimagej expect tifs as sample data 
- assert all(os.path.splitext(path)[1] in (".tif", ".tiff") for path in sample_inputs) - assert all(os.path.splitext(path)[1] in (".tif", ".tiff") for path in sample_outputs) - - ij_config, ij_attachments = _get_deepimagej_config( - root, test_inputs, test_outputs, input_axes, output_axes, pixel_sizes, preprocessing, postprocessing - ) - - if config is None: - config = ij_config - else: - config.update(ij_config) - - if ij_attachments is not None: - if attachments is None: - attachments = {"files": ij_attachments} - elif "files" not in attachments: - attachments["files"] = ij_attachments - else: - attachments["files"] = list(set(attachments["files"]) | set(ij_attachments)) - - if links is None: - links = ["deepimagej/deepimagej"] - else: - links.append("deepimagej/deepimagej") - - # make sure links are unique - if links is not None: - links = list(set(links)) - - # make sure sample inputs / outputs are relative paths - if sample_inputs is not None: - sample_inputs = _ensure_local_or_url(sample_inputs, root) - - if sample_outputs is not None: - sample_outputs = _ensure_local_or_url(sample_outputs, root) - - # optional kwargs, don't pass them if none - optional_kwargs = { - "config": config, - "git_repo": git_repo, - "packaged_by": packaged_by, - "run_mode": run_mode, - "sample_inputs": sample_inputs, - "sample_outputs": sample_outputs, - "links": links, - } - kwargs = {k: v for k, v in optional_kwargs.items() if v is not None} - - if attachments is not None: - kwargs["attachments"] = spec.rdf.raw_nodes.Attachments(**attachments) - - if maintainers is not None: - kwargs["maintainers"] = [model_spec.raw_nodes.Maintainer(**m) for m in maintainers] - - if parent is not None: - kwargs["parent"] = parent - - if training_data is not None: - if "id" in training_data: - msg = f"If training data is specified via 'id' no other keys are allowed, got {training_data}" - assert len(training_data) == 1, msg - kwargs["training_data"] = training_data - else: - if "type" not in 
training_data: - training_data["type"] = "dataset" - if "format_version" not in training_data: - training_data["format_version"] = spec.dataset.format_version - - try: - model = model_spec.raw_nodes.Model( - authors=authors, - cite=cite, - covers=covers, - description=description, - documentation=documentation, - format_version=format_version, - inputs=inputs, - license=license, - name=name, - outputs=outputs, - root_path=root, - tags=tags, - test_inputs=test_inputs, - test_outputs=test_outputs, - timestamp=timestamp, - weights=weights, - **kwargs, - ) - model_package = export_resource_package(model, output_path=output_path) - except Exception as e: - raise e - finally: - if tmp_archtecture is not None: - os.remove(tmp_archtecture) - - model = load_raw_resource_description(model_package) - return model diff --git a/bioimageio/core/commands.py b/bioimageio/core/commands.py deleted file mode 100644 index 640493a6..00000000 --- a/bioimageio/core/commands.py +++ /dev/null @@ -1,51 +0,0 @@ -import shutil -import traceback -from pathlib import Path -from typing import List, Optional, Union - -from bioimageio.core import export_resource_package -from bioimageio.core.resource_io.utils import resolve_source -from bioimageio.spec.commands import validate -from bioimageio.spec.shared.raw_nodes import URI - - -def package( - rdf_source: Union[Path, str, URI, dict], - path: Path = Path() / "{src_name}-package.zip", - weights_priority_order: Optional[List[str]] = None, - verbose: bool = False, -) -> int: - """Package a BioImage.IO resource described by a BioImage.IO Resource Description File (RDF).""" - code = validate(rdf_source, update_format=True, update_format_inner=True) - source_name = rdf_source.get("name") if isinstance(rdf_source, dict) else rdf_source - if code["error"]: - print(f"Cannot package invalid BioImage.IO RDF {source_name}") - return 1 - - try: - tmp_package_path = export_resource_package(rdf_source, weights_priority_order=weights_priority_order) - except 
Exception as e: - print(f"Failed to package {source_name} due to: {e}") - if verbose: - traceback.print_exc() - return 1 - - try: - rdf_local_source = resolve_source(rdf_source) - except Exception as e: - print(f"Failed to resolve RDF source {rdf_source}: {e}") - if verbose: - traceback.print_exc() - return 1 - - try: - path = path.with_name(path.name.format(src_name=rdf_local_source.stem)) - shutil.move(tmp_package_path, path) - except Exception as e: - print(f"Failed to move package from {tmp_package_path} to {path} due to: {e}") - if verbose: - traceback.print_exc() - return 1 - - print(f"exported bioimageio package from {source_name} to {path}") - return 0 diff --git a/bioimageio/core/common.py b/bioimageio/core/common.py index 2a660d0c..78a85886 100644 --- a/bioimageio/core/common.py +++ b/bioimageio/core/common.py @@ -1,5 +1,127 @@ -from bioimageio.spec.shared.common import ValidationSummary +from __future__ import annotations +from types import MappingProxyType +from typing import ( + Hashable, + Literal, + Mapping, + NamedTuple, + Tuple, + TypeVar, + Union, +) -class TestSummary(ValidationSummary): - bioimageio_core_version: str +from typing_extensions import Self, assert_never + +from bioimageio.spec.model import v0_5 + +DTypeStr = Literal[ + "bool", + "float32", + "float64", + "int8", + "int16", + "int32", + "int64", + "uint8", + "uint16", + "uint32", + "uint64", +] + + +_LeftRight_T = TypeVar("_LeftRight_T", bound="_LeftRight") +_LeftRightLike = Union[int, Tuple[int, int], _LeftRight_T] + + +class _LeftRight(NamedTuple): + left: int + right: int + + @classmethod + def create(cls, like: _LeftRightLike[Self]) -> Self: + if isinstance(like, cls): + return like + elif isinstance(like, tuple): + return cls(*like) + elif isinstance(like, int): + return cls(like, like) + else: + assert_never(like) + + +_Where = Literal["left", "right", "left_and_right"] + + +class CropWidth(_LeftRight): + pass + + +CropWidthLike = _LeftRightLike[CropWidth] +CropWhere = _Where + 
+ +class Halo(_LeftRight): + pass + + +HaloLike = _LeftRightLike[Halo] + + +class OverlapWidth(_LeftRight): + pass + + +class PadWidth(_LeftRight): + pass + + +PadWidthLike = _LeftRightLike[PadWidth] +PadMode = Literal["edge", "reflect", "symmetric"] +PadWhere = _Where + + +class SliceInfo(NamedTuple): + start: int + stop: int + + +SampleId = Hashable +MemberId = v0_5.TensorId +T = TypeVar("T") +PerMember = Mapping[MemberId, T] + +BlockIndex = int +TotalNumberOfBlocks = int + + +K = TypeVar("K", bound=Hashable) +V = TypeVar("V") + +Frozen = MappingProxyType +# class Frozen(Mapping[K, V]): # adapted from xarray.core.utils.Frozen +# """Wrapper around an object implementing the mapping interface to make it +# immutable.""" + +# __slots__ = ("mapping",) + +# def __init__(self, mapping: Mapping[K, V]): +# super().__init__() +# self.mapping = deepcopy( +# mapping +# ) # added deepcopy (compared to xarray.core.utils.Frozen) + +# def __getitem__(self, key: K) -> V: +# return self.mapping[key] + +# def __iter__(self) -> Iterator[K]: +# return iter(self.mapping) + +# def __len__(self) -> int: +# return len(self.mapping) + +# def __contains__(self, key: object) -> bool: +# return key in self.mapping + +# def __repr__(self) -> str: +# return f"{type(self).__name__}({self.mapping!r})" diff --git a/bioimageio/core/dataset.py b/bioimageio/core/dataset.py new file mode 100644 index 00000000..59361b2d --- /dev/null +++ b/bioimageio/core/dataset.py @@ -0,0 +1,5 @@ +from typing import Iterable + +from bioimageio.core.sample import Sample + +Dataset = Iterable[Sample] diff --git a/bioimageio/core/digest_spec.py b/bioimageio/core/digest_spec.py new file mode 100644 index 00000000..b3a693f8 --- /dev/null +++ b/bioimageio/core/digest_spec.py @@ -0,0 +1,379 @@ +from __future__ import annotations + +import importlib.util +from functools import singledispatch +from itertools import chain +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Mapping, + NamedTuple, + 
Optional, + Sequence, + Tuple, + Union, +) + +from numpy.typing import NDArray +from typing_extensions import Unpack, assert_never + +from bioimageio.spec._internal.io_utils import HashKwargs, download +from bioimageio.spec.common import FileSource +from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 +from bioimageio.spec.model.v0_4 import CallableFromDepencency, CallableFromFile +from bioimageio.spec.model.v0_5 import ( + ArchitectureFromFileDescr, + ArchitectureFromLibraryDescr, + ParameterizedSize, +) +from bioimageio.spec.utils import load_array + +from .axis import AxisId, AxisInfo, PerAxis +from .block_meta import split_multiple_shapes_into_blocks +from .common import Halo, MemberId, PerMember, SampleId, TotalNumberOfBlocks +from .sample import ( + LinearSampleAxisTransform, + Sample, + SampleBlockMeta, + sample_block_meta_generator, +) +from .stat_measures import Stat +from .tensor import Tensor + + +@singledispatch +def import_callable(node: type, /) -> Callable[..., Any]: + """import a callable (e.g. 
a torch.nn.Module) from a spec node describing it""" + raise TypeError(type(node)) + + +@import_callable.register +def _(node: CallableFromDepencency) -> Callable[..., Any]: + module = importlib.import_module(node.module_name) + c = getattr(module, str(node.callable_name)) + if not callable(c): + raise ValueError(f"{node} (imported: {c}) is not callable") + + return c + + +@import_callable.register +def _(node: ArchitectureFromLibraryDescr) -> Callable[..., Any]: + module = importlib.import_module(node.import_from) + c = getattr(module, str(node.callable)) + if not callable(c): + raise ValueError(f"{node} (imported: {c}) is not callable") + + return c + + +@import_callable.register +def _(node: CallableFromFile, **kwargs: Unpack[HashKwargs]): + return _import_from_file_impl(node.source_file, str(node.callable_name), **kwargs) + + +@import_callable.register +def _(node: ArchitectureFromFileDescr, **kwargs: Unpack[HashKwargs]): + return _import_from_file_impl(node.source, str(node.callable), sha256=node.sha256) + + +def _import_from_file_impl( + source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs] +): + local_file = download(source, **kwargs) + module_name = local_file.path.stem + importlib_spec = importlib.util.spec_from_file_location( + module_name, local_file.path + ) + if importlib_spec is None: + raise ImportError(f"Failed to import {module_name} from {source}.") + + dep = importlib.util.module_from_spec(importlib_spec) + importlib_spec.loader.exec_module(dep) # type: ignore # todo: possible to use "loader.load_module"? 
+ return getattr(dep, callable_name) + + +def get_axes_infos( + io_descr: Union[ + v0_4.InputTensorDescr, + v0_4.OutputTensorDescr, + v0_5.InputTensorDescr, + v0_5.OutputTensorDescr, + ] +) -> List[AxisInfo]: + """get a unified, simplified axis representation from spec axes""" + return [ + ( + AxisInfo.create("i") + if isinstance(a, str) and a not in ("b", "i", "t", "c", "z", "y", "x") + else AxisInfo.create(a) + ) + for a in io_descr.axes + ] + + +def get_member_id( + tensor_description: Union[ + v0_4.InputTensorDescr, + v0_4.OutputTensorDescr, + v0_5.InputTensorDescr, + v0_5.OutputTensorDescr, + ] +) -> MemberId: + """get the normalized tensor ID, usable as a sample member ID""" + + if isinstance(tensor_description, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)): + return MemberId(tensor_description.name) + elif isinstance( + tensor_description, (v0_5.InputTensorDescr, v0_5.OutputTensorDescr) + ): + return tensor_description.id + else: + assert_never(tensor_description) + + +def get_member_ids( + tensor_descriptions: Sequence[ + Union[ + v0_4.InputTensorDescr, + v0_4.OutputTensorDescr, + v0_5.InputTensorDescr, + v0_5.OutputTensorDescr, + ] + ] +) -> List[MemberId]: + """get normalized tensor IDs to be used as sample member IDs""" + return [get_member_id(descr) for descr in tensor_descriptions] + + +def get_test_inputs(model: AnyModelDescr) -> Sample: + """returns a model's test input sample""" + member_ids = get_member_ids(model.inputs) + if isinstance(model, v0_4.ModelDescr): + arrays = [load_array(tt) for tt in model.test_inputs] + else: + arrays = [load_array(d.test_tensor) for d in model.inputs] + + axes = [get_axes_infos(t) for t in model.inputs] + return Sample( + members={ + m: Tensor.from_numpy(arr, dims=ax) + for m, arr, ax in zip(member_ids, arrays, axes) + }, + stat={}, + id="test-input", + ) + + +def get_test_outputs(model: AnyModelDescr) -> Sample: + """returns a model's test output sample""" + member_ids = get_member_ids(model.outputs) + + if 
isinstance(model, v0_4.ModelDescr): + arrays = [load_array(tt) for tt in model.test_outputs] + else: + arrays = [load_array(d.test_tensor) for d in model.outputs] + + axes = [get_axes_infos(t) for t in model.outputs] + + return Sample( + members={ + m: Tensor.from_numpy(arr, dims=ax) + for m, arr, ax in zip(member_ids, arrays, axes) + }, + stat={}, + id="test-output", + ) + + +class IO_SampleBlockMeta(NamedTuple): + input: SampleBlockMeta + output: SampleBlockMeta + + +def get_input_halo(model: v0_5.ModelDescr, output_halo: PerMember[PerAxis[Halo]]): + """returns which halo input tensors need to be divided into blocks with such that + `output_halo` can be cropped from their outputs without introducing gaps.""" + input_halo: Dict[MemberId, Dict[AxisId, Halo]] = {} + outputs = {t.id: t for t in model.outputs} + all_tensors = {**{t.id: t for t in model.inputs}, **outputs} + + for t, th in output_halo.items(): + axes = {a.id: a for a in outputs[t].axes} + + for a, ah in th.items(): + s = axes[a].size + if not isinstance(s, v0_5.SizeReference): + raise ValueError( + f"Unable to map output halo for {t}.{a} to an input axis" + ) + + axis = axes[a] + ref_axis = {a.id: a for a in all_tensors[s.tensor_id].axes}[s.axis_id] + + total_output_halo = sum(ah) + total_input_halo = total_output_halo * axis.scale / ref_axis.scale + assert ( + total_input_halo == int(total_input_halo) and total_input_halo % 2 == 0 + ) + input_halo.setdefault(s.tensor_id, {})[a] = Halo( + int(total_input_halo // 2), int(total_input_halo // 2) + ) + + return input_halo + + +def get_block_transform(model: v0_5.ModelDescr): + """returns how a model's output tensor shapes relate to its input shapes""" + ret: Dict[MemberId, Dict[AxisId, Union[LinearSampleAxisTransform, int]]] = {} + batch_axis_trf = None + for ipt in model.inputs: + for a in ipt.axes: + if a.type == "batch": + batch_axis_trf = LinearSampleAxisTransform( + axis=a.id, scale=1, offset=0, member=ipt.id + ) + break + if batch_axis_trf is not 
None: + break + axis_scales = { + t.id: {a.id: a.scale for a in t.axes} + for t in chain(model.inputs, model.outputs) + } + for out in model.outputs: + new_axes: Dict[AxisId, Union[LinearSampleAxisTransform, int]] = {} + for a in out.axes: + if a.size is None: + assert a.type == "batch" + if batch_axis_trf is None: + raise ValueError( + "no batch axis found in any input tensor, but output tensor" + + f" '{out.id}' has one." + ) + s = batch_axis_trf + elif isinstance(a.size, int): + s = a.size + elif isinstance(a.size, v0_5.DataDependentSize): + s = -1 + elif isinstance(a.size, v0_5.SizeReference): + s = LinearSampleAxisTransform( + axis=a.size.axis_id, + scale=axis_scales[a.size.tensor_id][a.size.axis_id] / a.scale, + offset=a.size.offset, + member=a.size.tensor_id, + ) + else: + assert_never(a.size) + + new_axes[a.id] = s + + ret[out.id] = new_axes + + return ret + + +def get_io_sample_block_metas( + model: v0_5.ModelDescr, + input_sample_shape: PerMember[PerAxis[int]], + ns: Mapping[Tuple[MemberId, AxisId], ParameterizedSize.N], + batch_size: int = 1, +) -> Tuple[TotalNumberOfBlocks, Iterable[IO_SampleBlockMeta]]: + """returns an iterable yielding meta data for corresponding input and output samples""" + if not isinstance(model, v0_5.ModelDescr): + raise TypeError(f"get_block_meta() not implemented for {type(model)}") + + block_axis_sizes = model.get_axis_sizes(ns=ns, batch_size=batch_size) + input_block_shape = { + t: {aa: s for (tt, aa), s in block_axis_sizes.inputs.items() if tt == t} + for t in {tt for tt, _ in block_axis_sizes.inputs} + } + output_block_shape = { + t: { + aa: s + for (tt, aa), s in block_axis_sizes.outputs.items() + if tt == t and not isinstance(s, tuple) + } + for t in {tt for tt, _ in block_axis_sizes.outputs} + } + output_halo = { + t.id: { + a.id: Halo(a.halo, a.halo) for a in t.axes if isinstance(a, v0_5.WithHalo) + } + for t in model.outputs + } + input_halo = get_input_halo(model, output_halo) + + # TODO: fix 
output_sample_shape_data_dep + # (below only valid if input_sample_shape is a valid model input, + # which is not a valid assumption) + output_sample_shape_data_dep = model.get_output_tensor_sizes(input_sample_shape) + + output_sample_shape = { + t: { + a: -1 if isinstance(s, tuple) else s + for a, s in output_sample_shape_data_dep[t].items() + } + for t in output_sample_shape_data_dep + } + n_input_blocks, input_blocks = split_multiple_shapes_into_blocks( + input_sample_shape, input_block_shape, halo=input_halo + ) + n_output_blocks, output_blocks = split_multiple_shapes_into_blocks( + output_sample_shape, output_block_shape, halo=output_halo + ) + assert n_input_blocks == n_output_blocks + return n_input_blocks, ( + IO_SampleBlockMeta(ipt, out) + for ipt, out in zip( + sample_block_meta_generator( + input_blocks, sample_shape=input_sample_shape, sample_id=None + ), + sample_block_meta_generator( + output_blocks, + sample_shape=output_sample_shape, + sample_id=None, + ), + ) + ) + + +def create_sample_for_model( + model: AnyModelDescr, + *, + stat: Optional[Stat] = None, + sample_id: SampleId = None, + inputs: Optional[PerMember[NDArray[Any]]] = None, # TODO: make non-optional + **kwargs: NDArray[Any], # TODO: deprecate in favor of `inputs` +) -> Sample: + """Create a sample from a single set of input(s) for a specific bioimage.io model + + Args: + model: a bioimage.io model description + stat: dictionary with sample and dataset statistics (may be updated in-place!) + inputs: the input(s) constituting a single sample. 
+ """ + inputs = {MemberId(k): v for k, v in {**kwargs, **(inputs or {})}.items()} + + model_inputs = {get_member_id(d): d for d in model.inputs} + if unknown := {k for k in inputs if k not in model_inputs}: + raise ValueError(f"Got unexpected inputs: {unknown}") + + if missing := { + k + for k, v in model_inputs.items() + if k not in inputs and not (isinstance(v, v0_5.InputTensorDescr) and v.optional) + }: + raise ValueError(f"Missing non-optional model inputs: {missing}") + + return Sample( + members={ + m: Tensor.from_numpy(inputs[m], dims=get_axes_infos(ipt)) + for m, ipt in model_inputs.items() + if m in inputs + }, + stat={} if stat is None else stat, + id=sample_id, + ) diff --git a/bioimageio/core/image_helper.py b/bioimageio/core/image_helper.py deleted file mode 100644 index 0468b61f..00000000 --- a/bioimageio/core/image_helper.py +++ /dev/null @@ -1,180 +0,0 @@ -import os -from copy import deepcopy -from typing import Dict, List, Optional, Sequence, Tuple, Union - -import imageio -import numpy as np -from xarray import DataArray -from bioimageio.core.resource_io.nodes import InputTensor, OutputTensor - - -# -# helper functions to transform input images / output tensors to the required axes -# - - -def transform_input_image(image: np.ndarray, tensor_axes: str, image_axes: Optional[str] = None): - """Transform input image into output tensor with desired axes. 
- - Args: - image: the input image - tensor_axes: the desired tensor axes - input_axes: the axes of the input image (optional) - """ - # if the image axes are not given deduce them from the required axes and image shape - if image_axes is None: - has_z_axis = "z" in tensor_axes - ndim = image.ndim - if ndim == 2: - image_axes = "yx" - elif ndim == 3: - image_axes = "zyx" if has_z_axis else "cyx" - elif ndim == 4: - image_axes = "czyx" - elif ndim == 5: - image_axes = "bczyx" - else: - raise ValueError(f"Invalid number of image dimensions: {ndim}") - tensor = DataArray(image, dims=tuple(image_axes)) - # expand the missing image axes - missing_axes = tuple(set(tensor_axes) - set(image_axes)) - tensor = tensor.expand_dims(dim=missing_axes) - # transpose to the correct axis order - tensor = tensor.transpose(*tuple(tensor_axes)) - # return numpy array - return tensor.values - - -def _drop_axis_default(axis_name, axis_len): - # spatial axes: drop at middle coordnate - # other axes (channel or batch): drop at 0 coordinate - return axis_len // 2 if axis_name in "zyx" else 0 - - -def transform_output_tensor(tensor: np.ndarray, tensor_axes: str, output_axes: str, drop_function=_drop_axis_default): - """Transform output tensor into image with desired axes. 
- - Args: - tensor: the output tensor - tensor_axes: bioimageio model spec - output_axes: the desired output axes - drop_function: function that determines how to drop unwanted axes - """ - if len(tensor_axes) != tensor.ndim: - raise ValueError(f"Number of axes {len(tensor_axes)} and dimension of tensor {tensor.ndim} don't match") - shape = {ax_name: sh for ax_name, sh in zip(tensor_axes, tensor.shape)} - output = DataArray(tensor, dims=tuple(tensor_axes)) - # drop unwanted axes - drop_axis_names = tuple(set(tensor_axes) - set(output_axes)) - drop_axes = {ax_name: drop_function(ax_name, shape[ax_name]) for ax_name in drop_axis_names} - output = output[drop_axes] - # transpose to the desired axis order - output = output.transpose(*tuple(output_axes)) - # return numpy array - return output.values - - -def to_channel_last(image): - chan_id = image.dims.index("c") - if chan_id != image.ndim - 1: - target_axes = tuple(ax for ax in image.dims if ax != "c") + ("c",) - image = image.transpose(*target_axes) - return image - - -# -# helper functions for loading and saving images -# - - -def load_image(in_path, axes: Sequence[str]) -> DataArray: - ext = os.path.splitext(in_path)[1] - if ext == ".npy": - im = np.load(in_path) - else: - is_volume = "z" in axes - im = imageio.volread(in_path) if is_volume else imageio.imread(in_path) - im = transform_input_image(im, axes) - return DataArray(im, dims=axes) - - -def load_tensors(sources, tensor_specs: List[Union[InputTensor, OutputTensor]]) -> List[DataArray]: - return [load_image(s, sspec.axes) for s, sspec in zip(sources, tensor_specs)] - - -def save_image(out_path, image): - ext = os.path.splitext(out_path)[1] - if ext == ".npy": - np.save(out_path, image) - else: - is_volume = "z" in image.dims - - # squeeze batch or channel axes if they are singletons - squeeze = {ax: 0 if (ax in "bc" and sh == 1) else slice(None) for ax, sh in zip(image.dims, image.shape)} - image = image[squeeze] - - if "b" in image.dims: - raise 
RuntimeError(f"Cannot save prediction with batchsize > 1 as {ext}-file") - if "c" in image.dims: # image formats need channel last - image = to_channel_last(image) - - save_function = imageio.volsave if is_volume else imageio.imsave - # most image formats only support channel dimensions of 1, 3 or 4; - # if not we need to save the channels separately - ndim = 3 if is_volume else 2 - save_as_single_image = image.ndim == ndim or (image.shape[-1] in (3, 4)) - - if save_as_single_image: - save_function(out_path, image) - else: - out_prefix, ext = os.path.splitext(out_path) - for c in range(image.shape[-1]): - chan_out_path = f"{out_prefix}-c{c}{ext}" - save_function(chan_out_path, image[..., c]) - - -# -# helper function for padding -# - - -def pad(image, axes: Sequence[str], padding, pad_right=True) -> Tuple[np.ndarray, Dict[str, slice]]: - assert image.ndim == len(axes), f"{image.ndim}, {len(axes)}" - - padding_ = deepcopy(padding) - mode = padding_.pop("mode", "dynamic") - assert mode in ("dynamic", "fixed") - - is_volume = "z" in axes - if is_volume: - assert len(padding_) == 3 - else: - assert len(padding_) == 2 - - if isinstance(pad_right, bool): - pad_right = len(axes) * [pad_right] - - pad_width = [] - crop = {} - for ax, dlen, pr in zip(axes, image.shape, pad_right): - - if ax in "zyx": - pad_to = padding_[ax] - - if mode == "dynamic": - r = dlen % pad_to - pwidth = 0 if r == 0 else (pad_to - r) - else: - if pad_to < dlen: - msg = f"Padding for axis {ax} failed; pad shape {pad_to} is smaller than the image shape {dlen}." 
- raise RuntimeError(msg) - pwidth = pad_to - dlen - - pad_width.append([0, pwidth] if pr else [pwidth, 0]) - crop[ax] = slice(0, dlen) if pr else slice(pwidth, None) - else: - pad_width.append([0, 0]) - crop[ax] = slice(None) - - image = np.pad(image, pad_width, mode="symmetric") - return image, crop diff --git a/bioimageio/core/io.py b/bioimageio/core/io.py new file mode 100644 index 00000000..f053077e --- /dev/null +++ b/bioimageio/core/io.py @@ -0,0 +1,80 @@ +from pathlib import Path +from typing import Any, Dict, Optional, Sequence + +import imageio +from loguru import logger +from numpy.typing import NDArray + +from bioimageio.spec.model import AnyModelDescr +from bioimageio.spec.utils import load_array + +from .axis import Axis, AxisLike +from .common import MemberId, PerMember, SampleId +from .digest_spec import get_axes_infos, get_member_id +from .sample import Sample +from .stat_measures import Stat +from .tensor import Tensor + + +def load_image(path: Path, is_volume: bool) -> NDArray[Any]: + """load a single image as numpy array""" + ext = path.suffix + if ext == ".npy": + return load_array(path) + else: + return imageio.volread(path) if is_volume else imageio.imread(path) + + +def load_tensor(path: Path, axes: Optional[Sequence[AxisLike]] = None) -> Tensor: + array = load_image( + path, + is_volume=( + axes is None or sum(Axis.create(a).type != "channel" for a in axes) > 2 + ), + ) + + return Tensor.from_numpy(array, dims=axes) + + +def load_sample_for_model( + *, + model: AnyModelDescr, + paths: PerMember[Path], + axes: Optional[PerMember[Sequence[AxisLike]]] = None, + stat: Optional[Stat] = None, + sample_id: Optional[SampleId] = None, +): + """load a single sample from `paths` that can be processed by `model`""" + + if axes is None: + axes = {} + + # make sure members are keyed by MemberId, not string + paths = {MemberId(k): v for k, v in paths.items()} + axes = {MemberId(k): v for k, v in axes.items()} + + model_inputs = {get_member_id(d): d for d 
in model.inputs} + + if unknown := {k for k in paths if k not in model_inputs}: + raise ValueError(f"Got unexpected paths for {unknown}") + + if unknown := {k for k in axes if k not in model_inputs}: + raise ValueError(f"Got unexpected axes hints for: {unknown}") + + members: Dict[MemberId, Tensor] = {} + for m, p in paths.items(): + if m not in axes: + axes[m] = get_axes_infos(model_inputs[m]) + logger.warning( + "loading paths with {}'s default input axes {} for input '{}'", + model.id or model.name, + axes[m], + m, + ) + members[m] = load_tensor(p, axes[m]) + + return Sample( + members=members, + stat={} if stat is None else stat, + id=sample_id or tuple(sorted(paths.values())), + ) diff --git a/bioimageio/core/model_adapters/__init__.py b/bioimageio/core/model_adapters/__init__.py new file mode 100644 index 00000000..85387b6a --- /dev/null +++ b/bioimageio/core/model_adapters/__init__.py @@ -0,0 +1,3 @@ +from ._model_adapter import ModelAdapter as ModelAdapter +from ._model_adapter import create_model_adapter as create_model_adapter +from ._model_adapter import get_weight_formats as get_weight_formats diff --git a/bioimageio/core/model_adapters/_keras_model_adapter.py b/bioimageio/core/model_adapters/_keras_model_adapter.py new file mode 100644 index 00000000..c5d74132 --- /dev/null +++ b/bioimageio/core/model_adapters/_keras_model_adapter.py @@ -0,0 +1,103 @@ +import os +from typing import Any, List, Optional, Sequence, Union + +from loguru import logger +from numpy.typing import NDArray + +from bioimageio.spec._internal.io_utils import download +from bioimageio.spec.model import v0_4, v0_5 +from bioimageio.spec.model.v0_5 import Version + +from .._settings import settings +from ..digest_spec import get_axes_infos +from ..tensor import Tensor +from ._model_adapter import ModelAdapter + +os.environ["KERAS_BACKEND"] = settings.keras_backend + +# by default, we use the keras integrated with tensorflow +try: + import tensorflow as tf # pyright: 
ignore[reportMissingImports] + from tensorflow import ( # pyright: ignore[reportMissingImports] + keras, # pyright: ignore[reportUnknownVariableType] + ) + + tf_version = Version(tf.__version__) # pyright: ignore[reportUnknownArgumentType] +except Exception: + try: + import keras # pyright: ignore[reportMissingImports] + except Exception: + keras = None + + tf_version = None + + +class KerasModelAdapter(ModelAdapter): + def __init__( + self, + *, + model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], + devices: Optional[Sequence[str]] = None, + ) -> None: + if keras is None: + raise ImportError("keras") + + super().__init__() + if model_description.weights.keras_hdf5 is None: + raise ValueError("model has not keras_hdf5 weights specified") + model_tf_version = model_description.weights.keras_hdf5.tensorflow_version + + if tf_version is None or model_tf_version is None: + logger.warning("Could not check tensorflow versions.") + elif model_tf_version > tf_version: + logger.warning( + "The model specifies a newer tensorflow version than installed: {} > {}.", + model_tf_version, + tf_version, + ) + elif (model_tf_version.major, model_tf_version.minor) != ( + tf_version.major, + tf_version.minor, + ): + logger.warning( + "Model tensorflow version {} does not match {}.", + model_tf_version, + tf_version, + ) + + # TODO keras device management + if devices is not None: + logger.warning( + "Device management is not implemented for keras yet, ignoring the devices {}", + devices, + ) + + weight_path = download(model_description.weights.keras_hdf5.source).path + + self._network = keras.models.load_model(weight_path) + self._output_axes = [ + tuple(a.id for a in get_axes_infos(out)) + for out in model_description.outputs + ] + + def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]: + _result: Union[Sequence[NDArray[Any]], NDArray[Any]] + _result = self._network.predict( # pyright: ignore[reportUnknownVariableType] + *[None if t is None else 
t.data.data for t in input_tensors] + ) + if isinstance(_result, (tuple, list)): + result: Sequence[NDArray[Any]] = _result + else: + result = [_result] # type: ignore + + assert len(result) == len(self._output_axes) + ret: List[Optional[Tensor]] = [] + ret.extend( + [Tensor(r, dims=axes) for r, axes, in zip(result, self._output_axes)] + ) + return ret + + def unload(self) -> None: + logger.warning( + "Device management is not implemented for keras yet, cannot unload model" + ) diff --git a/bioimageio/core/model_adapters/_model_adapter.py b/bioimageio/core/model_adapters/_model_adapter.py new file mode 100644 index 00000000..1d3c2b95 --- /dev/null +++ b/bioimageio/core/model_adapters/_model_adapter.py @@ -0,0 +1,163 @@ +import warnings +from abc import ABC, abstractmethod +from typing import List, Optional, Sequence, Tuple, Union, final + +from bioimageio.spec.model import v0_4, v0_5 + +from ..tensor import Tensor + +WeightsFormat = Union[v0_4.WeightsFormat, v0_5.WeightsFormat] + +# Known weight formats in order of priority +# First match wins +DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER: Tuple[WeightsFormat, ...] = ( + "pytorch_state_dict", + "tensorflow_saved_model_bundle", + "torchscript", + "onnx", + "keras_hdf5", +) + + +class ModelAdapter(ABC): + """ + Represents model *without* any preprocessing or postprocessing. + + ``` + from bioimageio.core import load_description + + model = load_description(...) + + # option 1: + adapter = ModelAdapter.create(model) + adapter.forward(...) + adapter.unload() + + # option 2: + with ModelAdapter.create(model) as adapter: + adapter.forward(...) 
+ ``` + """ + + @final + @classmethod + def create( + cls, + model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], + *, + devices: Optional[Sequence[str]] = None, + weight_format_priority_order: Optional[Sequence[WeightsFormat]] = None, + ): + """ + Creates model adapter based on the passed spec + Note: All specific adapters should happen inside this function to prevent different framework + initializations interfering with each other + """ + if not isinstance(model_description, (v0_4.ModelDescr, v0_5.ModelDescr)): + raise TypeError( + f"expected v0_4.ModelDescr or v0_5.ModelDescr, but got {type(model_description)}" + ) + + weights = model_description.weights + errors: List[str] = [] + weight_format_priority_order = ( + DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER + if weight_format_priority_order is None + else weight_format_priority_order + ) + for wf in weight_format_priority_order: + if wf == "pytorch_state_dict" and weights.pytorch_state_dict is not None: + try: + from ._pytorch_model_adapter import PytorchModelAdapter + + return PytorchModelAdapter( + outputs=model_description.outputs, + weights=weights.pytorch_state_dict, + devices=devices, + ) + except Exception as e: + errors.append(f"{wf}: {e}") + elif ( + wf == "tensorflow_saved_model_bundle" + and weights.tensorflow_saved_model_bundle is not None + ): + try: + from ._tensorflow_model_adapter import TensorflowModelAdapter + + return TensorflowModelAdapter( + model_description=model_description, devices=devices + ) + except Exception as e: + errors.append(f"{wf}: {e}") + elif wf == "onnx" and weights.onnx is not None: + try: + from ._onnx_model_adapter import ONNXModelAdapter + + return ONNXModelAdapter( + model_description=model_description, devices=devices + ) + except Exception as e: + errors.append(f"{wf}: {e}") + elif wf == "torchscript" and weights.torchscript is not None: + try: + from ._torchscript_model_adapter import TorchscriptModelAdapter + + return TorchscriptModelAdapter( + 
model_description=model_description, devices=devices + ) + except Exception as e: + errors.append(f"{wf}: {e}") + elif wf == "keras_hdf5" and weights.keras_hdf5 is not None: + # keras can either be installed as a separate package or used as part of tensorflow + # we try to first import the keras model adapter using the separate package and, + # if it is not available, try to load the one using tf + try: + from ._keras_model_adapter import ( + KerasModelAdapter, + keras, # type: ignore + ) + + if keras is None: + from ._tensorflow_model_adapter import KerasModelAdapter + + return KerasModelAdapter( + model_description=model_description, devices=devices + ) + except Exception as e: + errors.append(f"{wf}: {e}") + + assert errors + error_list = "\n - ".join(errors) + raise ValueError( + "None of the weight format specific model adapters could be created for" + + f" '{model_description.id or model_description.name}'" + + f" in this environment. Errors are:\n\n{error_list}.\n\n" + ) + + @final + def load(self, *, devices: Optional[Sequence[str]] = None) -> None: + warnings.warn("Deprecated. ModelAdapter is loaded on initialization") + + @abstractmethod + def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]: + """ + Run forward pass of model to get model predictions + """ + # TODO: handle tensor.transpose in here and make _forward_impl the abstract impl + + @abstractmethod + def unload(self): + """ + Unload model from any devices, freeing their memory. + The model adapter should be considered unusable afterwards. 
+ """ + + +def get_weight_formats() -> List[str]: + """ + Return list of supported weight types + """ + return list(DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER) + + +create_model_adapter = ModelAdapter.create diff --git a/bioimageio/core/model_adapters/_onnx_model_adapter.py b/bioimageio/core/model_adapters/_onnx_model_adapter.py new file mode 100644 index 00000000..e7bdfc05 --- /dev/null +++ b/bioimageio/core/model_adapters/_onnx_model_adapter.py @@ -0,0 +1,68 @@ +import warnings +from typing import Any, List, Optional, Sequence, Union + +from numpy.typing import NDArray + +from bioimageio.spec.model import v0_4, v0_5 +from bioimageio.spec.utils import download + +from ..digest_spec import get_axes_infos +from ..tensor import Tensor +from ._model_adapter import ModelAdapter + +try: + import onnxruntime as rt +except Exception: + rt = None + + +class ONNXModelAdapter(ModelAdapter): + def __init__( + self, + *, + model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], + devices: Optional[Sequence[str]] = None, + ): + if rt is None: + raise ImportError("onnxruntime") + + super().__init__() + self._internal_output_axes = [ + tuple(a.id for a in get_axes_infos(out)) + for out in model_description.outputs + ] + if model_description.weights.onnx is None: + raise ValueError(f"No ONNX weights specified for {model_description.name}") + + self._session = rt.InferenceSession( + str(download(model_description.weights.onnx.source).path) + ) + onnx_inputs = self._session.get_inputs() # type: ignore + self._input_names: List[str] = [ipt.name for ipt in onnx_inputs] # type: ignore + + if devices is not None: + warnings.warn( + f"Device management is not implemented for onnx yet, ignoring the devices {devices}" + ) + + def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]: + assert len(input_tensors) == len(self._input_names) + input_arrays = [None if ipt is None else ipt.data.data for ipt in input_tensors] + result: Union[Sequence[Optional[NDArray[Any]]], 
Optional[NDArray[Any]]] + result = self._session.run( # pyright: ignore[reportUnknownVariableType] + None, dict(zip(self._input_names, input_arrays)) + ) + if isinstance(result, (list, tuple)): + result_seq: Sequence[Optional[NDArray[Any]]] = result + else: + result_seq = [result] # type: ignore + + return [ + None if r is None else Tensor(r, dims=axes) + for r, axes in zip(result_seq, self._internal_output_axes) + ] + + def unload(self) -> None: + warnings.warn( + "Device management is not implemented for onnx yet, cannot unload model" + ) diff --git a/bioimageio/core/model_adapters/_pytorch_model_adapter.py b/bioimageio/core/model_adapters/_pytorch_model_adapter.py new file mode 100644 index 00000000..b647aeff --- /dev/null +++ b/bioimageio/core/model_adapters/_pytorch_model_adapter.py @@ -0,0 +1,149 @@ +import gc +import warnings +from typing import Any, List, Optional, Sequence, Tuple, Union + +from bioimageio.spec.model import v0_4, v0_5 +from bioimageio.spec.utils import download + +from ..axis import AxisId +from ..digest_spec import get_axes_infos, import_callable +from ..tensor import Tensor +from ._model_adapter import ModelAdapter + +try: + import torch +except Exception: + torch = None + + +class PytorchModelAdapter(ModelAdapter): + def __init__( + self, + *, + outputs: Union[ + Sequence[v0_4.OutputTensorDescr], Sequence[v0_5.OutputTensorDescr] + ], + weights: Union[ + v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr + ], + devices: Optional[Sequence[str]] = None, + ): + if torch is None: + raise ImportError("torch") + super().__init__() + self.output_dims = [tuple(a.id for a in get_axes_infos(out)) for out in outputs] + self._network = self.get_network(weights) + self._devices = self.get_devices(devices) + self._network = self._network.to(self._devices[0]) + + self._primary_device = self._devices[0] + state: Any = torch.load( + download(weights.source).path, + map_location=self._primary_device, # pyright: 
ignore[reportUnknownArgumentType] + ) + self._network.load_state_dict(state) + + self._network = self._network.eval() + + def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]: + if torch is None: + raise ImportError("torch") + with torch.no_grad(): + tensors = [ + None if ipt is None else torch.from_numpy(ipt.data.data) + for ipt in input_tensors + ] + tensors = [ + ( + None + if t is None + else t.to( + self._primary_device # pyright: ignore[reportUnknownArgumentType] + ) + ) + for t in tensors + ] + result: Union[Tuple[Any, ...], List[Any], Any] + result = self._network( # pyright: ignore[reportUnknownVariableType] + *tensors + ) + if not isinstance(result, (tuple, list)): + result = [result] + + result = [ + ( + None + if r is None + else r.detach().cpu().numpy() if isinstance(r, torch.Tensor) else r + ) + for r in result # pyright: ignore[reportUnknownVariableType] + ] + if len(result) > len(self.output_dims): + raise ValueError( + f"Expected at most {len(self.output_dims)} outputs, but got {len(result)}" + ) + + return [ + None if r is None else Tensor(r, dims=out) + for r, out in zip(result, self.output_dims) + ] + + def unload(self) -> None: + del self._network + _ = gc.collect() # deallocate memory + assert torch is not None + torch.cuda.empty_cache() # release reserved memory + + @staticmethod + def get_network( # pyright: ignore[reportUnknownParameterType] + weight_spec: Union[ + v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr + ] + ) -> "torch.nn.Module": # pyright: ignore[reportInvalidTypeForm] + if torch is None: + raise ImportError("torch") + arch = import_callable( + weight_spec.architecture, + sha256=( + weight_spec.architecture_sha256 + if isinstance(weight_spec, v0_4.PytorchStateDictWeightsDescr) + else weight_spec.sha256 + ), + ) + model_kwargs = ( + weight_spec.kwargs + if isinstance(weight_spec, v0_4.PytorchStateDictWeightsDescr) + else weight_spec.architecture.kwargs + ) + network = 
arch(**model_kwargs) + if not isinstance(network, torch.nn.Module): + raise ValueError( + f"calling {weight_spec.architecture.callable} did not return a torch.nn.Module" + ) + + return network + + @staticmethod + def get_devices( # pyright: ignore[reportUnknownParameterType] + devices: Optional[Sequence[str]] = None, + ) -> List["torch.device"]: # pyright: ignore[reportInvalidTypeForm] + if torch is None: + raise ImportError("torch") + if not devices: + torch_devices = [ + ( + torch.device("cuda") + if torch.cuda.is_available() + else torch.device("cpu") + ) + ] + else: + torch_devices = [torch.device(d) for d in devices] + + if len(torch_devices) > 1: + warnings.warn( + f"Multiple devices for single pytorch model not yet implemented; ignoring {torch_devices[1:]}" + ) + torch_devices = torch_devices[:1] + + return torch_devices diff --git a/bioimageio/core/model_adapters/_tensorflow_model_adapter.py b/bioimageio/core/model_adapters/_tensorflow_model_adapter.py new file mode 100644 index 00000000..8a44ba5c --- /dev/null +++ b/bioimageio/core/model_adapters/_tensorflow_model_adapter.py @@ -0,0 +1,270 @@ +import warnings +import zipfile +from typing import List, Literal, Optional, Sequence, Union + +import numpy as np + +from bioimageio.spec.common import FileSource +from bioimageio.spec.model import v0_4, v0_5 +from bioimageio.spec.utils import download + +from ..digest_spec import get_axes_infos +from ..tensor import Tensor +from ._model_adapter import ModelAdapter + +try: + import tensorflow as tf # pyright: ignore[reportMissingImports] +except Exception: + tf = None + + +class TensorflowModelAdapterBase(ModelAdapter): + weight_format: Literal["keras_hdf5", "tensorflow_saved_model_bundle"] + + def __init__( + self, + *, + devices: Optional[Sequence[str]] = None, + weights: Union[ + v0_4.KerasHdf5WeightsDescr, + v0_4.TensorflowSavedModelBundleWeightsDescr, + v0_5.KerasHdf5WeightsDescr, + v0_5.TensorflowSavedModelBundleWeightsDescr, + ], + model_description: 
Union[v0_4.ModelDescr, v0_5.ModelDescr], + ): + if tf is None: + raise ImportError("tensorflow") + + super().__init__() + self.model_description = model_description + tf_version = v0_5.Version( + tf.__version__ # pyright: ignore[reportUnknownArgumentType] + ) + model_tf_version = weights.tensorflow_version + if model_tf_version is None: + warnings.warn( + "The model does not specify the tensorflow version." + + f" Cannot check if it is compatible with installed tensorflow {tf_version}." + ) + elif model_tf_version > tf_version: + warnings.warn( + f"The model specifies a newer tensorflow version than installed: {model_tf_version} > {tf_version}." + ) + elif (model_tf_version.major, model_tf_version.minor) != ( + tf_version.major, + tf_version.minor, + ): + warnings.warn( + "The tensorflow version specified by the model does not match the installed: " + + f"{model_tf_version} != {tf_version}." + ) + + self.use_keras_api = ( + tf_version.major > 1 + or self.weight_format == KerasModelAdapter.weight_format + ) + + # TODO tf device management + if devices is not None: + warnings.warn( + f"Device management is not implemented for tensorflow yet, ignoring the devices {devices}" + ) + + weight_file = self.require_unzipped(weights.source) + self._network = self._get_network(weight_file) + self._internal_output_axes = [ + tuple(a.id for a in get_axes_infos(out)) + for out in model_description.outputs + ] + + def require_unzipped(self, weight_file: FileSource): + loacl_weights_file = download(weight_file).path + if zipfile.is_zipfile(loacl_weights_file): + out_path = loacl_weights_file.with_suffix(".unzipped") + with zipfile.ZipFile(loacl_weights_file, "r") as f: + f.extractall(out_path) + + return out_path + else: + return loacl_weights_file + + def _get_network( # pyright: ignore[reportUnknownParameterType] + self, weight_file: FileSource + ): + weight_file = self.require_unzipped(weight_file) + assert tf is not None + if self.use_keras_api: + return 
tf.keras.models.load_model( + weight_file, compile=False + ) # pyright: ignore[reportUnknownVariableType] + else: + # NOTE in tf1 the model needs to be loaded inside of the session, so we cannot preload the model + return str(weight_file) + + # TODO currently we relaod the model every time. it would be better to keep the graph and session + # alive in between of forward passes (but then the sessions need to be properly opened / closed) + def _forward_tf( # pyright: ignore[reportUnknownParameterType] + self, *input_tensors: Optional[Tensor] + ): + assert tf is not None + input_keys = [ + ipt.name if isinstance(ipt, v0_4.InputTensorDescr) else ipt.id + for ipt in self.model_description.inputs + ] + output_keys = [ + out.name if isinstance(out, v0_4.OutputTensorDescr) else out.id + for out in self.model_description.outputs + ] + # TODO read from spec + tag = ( # pyright: ignore[reportUnknownVariableType] + tf.saved_model.tag_constants.SERVING + ) + signature_key = ( # pyright: ignore[reportUnknownVariableType] + tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + ) + + graph = tf.Graph() # pyright: ignore[reportUnknownVariableType] + with graph.as_default(): + with tf.Session( + graph=graph + ) as sess: # pyright: ignore[reportUnknownVariableType] + # load the model and the signature + graph_def = tf.saved_model.loader.load( # pyright: ignore[reportUnknownVariableType] + sess, [tag], self._network + ) + signature = ( # pyright: ignore[reportUnknownVariableType] + graph_def.signature_def + ) + + # get the tensors into the graph + in_names = [ # pyright: ignore[reportUnknownVariableType] + signature[signature_key].inputs[key].name for key in input_keys + ] + out_names = [ # pyright: ignore[reportUnknownVariableType] + signature[signature_key].outputs[key].name for key in output_keys + ] + in_tensors = [ # pyright: ignore[reportUnknownVariableType] + graph.get_tensor_by_name(name) + for name in in_names # pyright: ignore[reportUnknownVariableType] + ] 
+ out_tensors = [ # pyright: ignore[reportUnknownVariableType] + graph.get_tensor_by_name(name) + for name in out_names # pyright: ignore[reportUnknownVariableType] + ] + + # run prediction + res = sess.run( # pyright: ignore[reportUnknownVariableType] + dict( + zip( + out_names, # pyright: ignore[reportUnknownArgumentType] + out_tensors, # pyright: ignore[reportUnknownArgumentType] + ) + ), + dict( + zip( + in_tensors, # pyright: ignore[reportUnknownArgumentType] + input_tensors, + ) + ), + ) + # from dict to list of tensors + res = [ # pyright: ignore[reportUnknownVariableType] + res[out] + for out in out_names # pyright: ignore[reportUnknownVariableType] + ] + + return res # pyright: ignore[reportUnknownVariableType] + + def _forward_keras( # pyright: ignore[reportUnknownParameterType] + self, *input_tensors: Optional[Tensor] + ): + assert self.use_keras_api + assert not isinstance(self._network, str) + assert tf is not None + tf_tensor = [ # pyright: ignore[reportUnknownVariableType] + None if ipt is None else tf.convert_to_tensor(ipt) for ipt in input_tensors + ] + + try: + result = ( # pyright: ignore[reportUnknownVariableType] + self._network.forward(*tf_tensor) + ) + except AttributeError: + result = ( # pyright: ignore[reportUnknownVariableType] + self._network.predict(*tf_tensor) + ) + + if not isinstance(result, (tuple, list)): + result = [result] # pyright: ignore[reportUnknownVariableType] + + return [ # pyright: ignore[reportUnknownVariableType] + ( + None + if r is None + else r if isinstance(r, np.ndarray) else tf.make_ndarray(r) + ) + for r in result # pyright: ignore[reportUnknownVariableType] + ] + + def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]: + data = [None if ipt is None else ipt.data for ipt in input_tensors] + if self.use_keras_api: + result = self._forward_keras( # pyright: ignore[reportUnknownVariableType] + *data + ) + else: + result = self._forward_tf( # pyright: ignore[reportUnknownVariableType] + *data 
+ ) + + return [ + None if r is None else Tensor(r, dims=axes) + for r, axes in zip( # pyright: ignore[reportUnknownVariableType] + result, # pyright: ignore[reportUnknownArgumentType] + self._internal_output_axes, + ) + ] + + def unload(self) -> None: + warnings.warn( + "Device management is not implemented for keras yet, cannot unload model" + ) + + +class TensorflowModelAdapter(TensorflowModelAdapterBase): + weight_format = "tensorflow_saved_model_bundle" + + def __init__( + self, + *, + model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], + devices: Optional[Sequence[str]] = None, + ): + if model_description.weights.tensorflow_saved_model_bundle is None: + raise ValueError("missing tensorflow_saved_model_bundle weights") + + super().__init__( + devices=devices, + weights=model_description.weights.tensorflow_saved_model_bundle, + model_description=model_description, + ) + + +class KerasModelAdapter(TensorflowModelAdapterBase): + weight_format = "keras_hdf5" + + def __init__( + self, + *, + model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], + devices: Optional[Sequence[str]] = None, + ): + if model_description.weights.keras_hdf5 is None: + raise ValueError("missing keras_hdf5 weights") + + super().__init__( + model_description=model_description, + devices=devices, + weights=model_description.weights.keras_hdf5, + ) diff --git a/bioimageio/core/model_adapters/_torchscript_model_adapter.py b/bioimageio/core/model_adapters/_torchscript_model_adapter.py new file mode 100644 index 00000000..0d28f019 --- /dev/null +++ b/bioimageio/core/model_adapters/_torchscript_model_adapter.py @@ -0,0 +1,92 @@ +import gc +import warnings +from typing import Any, List, Optional, Sequence, Tuple, Union + +import numpy as np +from numpy.typing import NDArray + +from bioimageio.spec.model import v0_4, v0_5 +from bioimageio.spec.utils import download + +from ..digest_spec import get_axes_infos +from ..tensor import Tensor +from ._model_adapter import ModelAdapter + +try: 
+ import torch +except Exception: + torch = None + + +class TorchscriptModelAdapter(ModelAdapter): + def __init__( + self, + *, + model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr], + devices: Optional[Sequence[str]] = None, + ): + if torch is None: + raise ImportError("torch") + + super().__init__() + if model_description.weights.torchscript is None: + raise ValueError( + f"No torchscript weights found for model {model_description.name}" + ) + + weight_path = download(model_description.weights.torchscript.source).path + if devices is None: + self.devices = ["cuda" if torch.cuda.is_available() else "cpu"] + else: + self.devices = [torch.device(d) for d in devices] + + if len(self.devices) > 1: + warnings.warn( + "Multiple devices for single torchscript model not yet implemented" + ) + + self._model = torch.jit.load(weight_path) + self._model.to(self.devices[0]) + self._internal_output_axes = [ + tuple(a.id for a in get_axes_infos(out)) + for out in model_description.outputs + ] + + def forward(self, *batch: Optional[Tensor]) -> List[Optional[Tensor]]: + assert torch is not None + with torch.no_grad(): + torch_tensor = [ + None if b is None else torch.from_numpy(b.data.data).to(self.devices[0]) + for b in batch + ] + _result: Union[ # pyright: ignore[reportUnknownVariableType] + Tuple[Optional[NDArray[Any]], ...], + List[Optional[NDArray[Any]]], + Optional[NDArray[Any]], + ] = self._model.forward(*torch_tensor) + if isinstance(_result, (tuple, list)): + result: Sequence[Optional[NDArray[Any]]] = _result + else: + result = [_result] + + result = [ + ( + None + if r is None + else r.cpu().numpy() if not isinstance(r, np.ndarray) else r + ) + for r in result + ] + + assert len(result) == len(self._internal_output_axes) + return [ + None if r is None else Tensor(r, dims=axes) + for r, axes in zip(result, self._internal_output_axes) + ] + + def unload(self) -> None: + assert torch is not None + self._devices = None + del self._model + _ = gc.collect() # 
deallocate memory + torch.cuda.empty_cache() # release reserved memory diff --git a/bioimageio/core/prediction.py b/bioimageio/core/prediction.py index cb30668e..3d10d31d 100644 --- a/bioimageio/core/prediction.py +++ b/bioimageio/core/prediction.py @@ -1,514 +1,7 @@ -import collections -import os -from fractions import Fraction -from itertools import product -from pathlib import Path -from typing import Dict, Iterator, List, NamedTuple, Optional, OrderedDict, Sequence, Tuple, Union +"""convenience functions for prediction coming soon. +For now, please use `create_prediction_pipeline` to get a `PredictionPipeline` +and then `PredictionPipeline.predict(sample)` +e..g load samples with core.io.load_sample_for_model() +""" -import numpy as np -import xarray as xr - -from bioimageio.core import image_helper, load_resource_description -from bioimageio.core.prediction_pipeline import PredictionPipeline, create_prediction_pipeline -from bioimageio.core.resource_io.nodes import ImplicitOutputShape, Model, ResourceDescription -from bioimageio.spec.shared import raw_nodes -from bioimageio.spec.shared.common import tqdm -from bioimageio.spec.shared.raw_nodes import ResourceDescription as RawResourceDescription - - -def _apply_crop(data, crop): - crop = tuple(crop[ax] for ax in data.dims) - return data[crop] - - -class TileDef(NamedTuple): - outer: Dict[str, slice] - inner: Dict[str, slice] - local: Dict[str, slice] - - -def get_tiling( - shape: Sequence[int], - tile_shape: Dict[str, int], - halo: Dict[str, int], - input_axes: Sequence[str], - scaling: Dict[str, float], -) -> Iterator[TileDef]: - # outer_tile is the "input" tile, inner_tile is the "output" tile with the halo removed - # tile_shape is the shape of the outer_tile - assert len(shape) == len(input_axes) - scaling = {ax: Fraction(sc).limit_denominator() for ax, sc in scaling.items()} - - shape_ = [sh for sh, ax in zip(shape, input_axes) if ax in "xyz"] - spatial_axes = [ax for ax in input_axes if ax in "xyz"] - 
inner_tile_shape_ = [tile_shape[ax] - 2 * halo[ax] for ax in spatial_axes] - scaling_ = [scaling[ax] for ax in spatial_axes] - assert all([sh % fr.denominator == 0 for sh, fr in zip(shape_, scaling_)]) - assert all([ish % fr.denominator == 0 for ish, fr in zip(inner_tile_shape_, scaling_)]) - halo_ = [halo[ax] for ax in spatial_axes] - assert len(shape_) == len(inner_tile_shape_) == len(spatial_axes) == len(halo_) - - ranges = [range(sh // tsh if sh % tsh == 0 else sh // tsh + 1) for sh, tsh in zip(shape_, inner_tile_shape_)] - start_points = product(*ranges) - - for start_point in start_points: - positions = [sp * tsh for sp, tsh in zip(start_point, inner_tile_shape_)] - - inner_tile = { - ax: slice(int(pos * fr), int(min(pos + tsh, sh) * fr)) - for ax, pos, tsh, sh, fr in zip(spatial_axes, positions, inner_tile_shape_, shape_, scaling_) - } - inner_tile["b"] = slice(None) - inner_tile["c"] = slice(None) - - outer_tile = { - ax: slice(max(pos - ha, 0), min(pos + tsh + ha, sh)) - for ax, pos, tsh, sh, ha in zip(spatial_axes, positions, inner_tile_shape_, shape_, halo_) - } - outer_tile["b"] = slice(None) - outer_tile["c"] = slice(None) - - local_tile = { - ax: slice( - inner_tile[ax].start - int(outer_tile[ax].start * scaling[ax]), - -(int(outer_tile[ax].stop * scaling[ax]) - inner_tile[ax].stop) - if int(outer_tile[ax].stop * scaling[ax]) != inner_tile[ax].stop - else None, - ) - for ax in spatial_axes - } - local_tile["b"] = slice(None) - local_tile["c"] = slice(None) - - yield TileDef(outer_tile, inner_tile, local_tile) - - -def _predict_with_tiling_impl( - prediction_pipeline: PredictionPipeline, - inputs: Sequence[xr.DataArray], - outputs: Sequence[xr.DataArray], - tile_shapes: Sequence[Dict[str, int]], - halos: Sequence[Dict[str, int]], - scales: Sequence[Dict[str, Tuple[int, int]]], - verbose: bool = False, -): - if len(inputs) > 1: - raise NotImplementedError("Tiling with multiple inputs not implemented yet") - - if len(outputs) > 1: - raise 
NotImplementedError("Tiling with multiple outputs not implemented yet") - - assert len(tile_shapes) == len(outputs) - assert len(halos) == len(outputs) - - input_ = inputs[0] - output = outputs[0] - tile_shape = tile_shapes[0] - halo = halos[0] - scaling = scales[0] - - tiles = get_tiling(shape=input_.shape, tile_shape=tile_shape, halo=halo, input_axes=input_.dims, scaling=scaling) - - assert all(isinstance(ax, str) for ax in input_.dims) - input_axes: Tuple[str, ...] = input_.dims # noqa - - def load_tile(tile): - inp = input_[tile] - # whether to pad on the right or left of the dim for the spatial dims - # + placeholders for batch and axis dimension, where we don't pad - pad_right = [tile[ax].start == 0 if ax in "xyz" else None for ax in input_axes] - return inp, pad_right - - if verbose: - shape = {ax: sh for ax, sh in zip(prediction_pipeline.input_specs[0].axes, input_.shape)} - n_tiles = int(np.prod([np.ceil(float(shape[ax]) / (tsh - 2 * halo[ax])) for ax, tsh in tile_shape.items()])) - tiles = tqdm(tiles, total=n_tiles, desc="prediction with tiling") - - # we need to use padded prediction for the individual tiles in case the - # border tiles don't match the requested tile shape - padding = {ax: tile_shape[ax] for ax in input_axes if ax in "xyz"} - padding["mode"] = "fixed" - for outer_tile, inner_tile, local_tile in tiles: - inp, pad_right = load_tile(outer_tile) - out = predict_with_padding(prediction_pipeline, inp, padding, pad_right) - assert len(out) == 1 - out = out[0] - output[inner_tile] = out[local_tile] - - -# -# prediction functions -# - - -def predict( - prediction_pipeline: PredictionPipeline, - inputs: Union[ - xr.DataArray, List[xr.DataArray], Tuple[xr.DataArray], np.ndarray, List[np.ndarray], Tuple[np.ndarray] - ], -) -> List[xr.DataArray]: - """Run prediction for a single set of input(s) with a bioimage.io model - - Args: - prediction_pipeline: the prediction pipeline for the input model. 
- inputs: the input(s) for this model represented as xarray data or numpy nd array. - """ - if not isinstance(inputs, (tuple, list)): - inputs = [inputs] - - assert len(inputs) == len(prediction_pipeline.input_specs) - tagged_data = [ - ipt if isinstance(ipt, xr.DataArray) else xr.DataArray(ipt, dims=ipt_spec.axes) - for ipt, ipt_spec in zip(inputs, prediction_pipeline.input_specs) - ] - return prediction_pipeline.forward(*tagged_data) - - -def _parse_padding(padding, input_specs): - if padding is None: # no padding - return padding - if len(input_specs) > 1: - raise NotImplementedError("Padding for multiple inputs not yet implemented") - - input_spec = input_specs[0] - pad_keys = tuple(input_spec.axes) + ("mode",) - - def check_padding(padding): - assert all(k in pad_keys for k in padding.keys()) - - if isinstance(padding, dict): # pre-defined padding - check_padding(padding) - elif isinstance(padding, bool): # determine padding from spec - if padding: - axes = input_spec.axes - shape = input_spec.shape - if isinstance(shape, list): # fixed padding - padding = {ax: sh for ax, sh in zip(axes, shape) if ax in "xyz"} - padding["mode"] = "fixed" - else: # dynamic padding - step = shape.step - padding = {ax: st for ax, st in zip(axes, step) if ax in "xyz"} - padding["mode"] = "dynamic" - check_padding(padding) - else: # no padding - padding = None - else: - raise ValueError(f"Invalid argument for padding: {padding}") - return padding - - -def predict_with_padding( - prediction_pipeline: PredictionPipeline, - inputs: Union[xr.DataArray, List[xr.DataArray], Tuple[xr.DataArray]], - padding: Union[bool, Dict[str, int]] = True, - pad_right: bool = True, -) -> List[xr.DataArray]: - """Run prediction with padding for a single set of input(s) with a bioimage.io model. - - Args: - prediction_pipeline: the prediction pipeline for the input model. - inputs: the input(s) for this model represented as xarray data. - padding: the padding settings. 
Pass True to derive from the model spec. - pad_right: whether to applying padding to the right or left of the input. - """ - if not padding: - raise ValueError - assert len(inputs) == len(prediction_pipeline.input_specs) - - output_spec = prediction_pipeline.output_specs[0] - if hasattr(output_spec.shape, "scale"): - scale = dict(zip(output_spec.axes, output_spec.shape.scale)) - offset = dict(zip(output_spec.axes, output_spec.shape.offset)) - network_resizes = any(sc != 1 for ax, sc in scale.items() if ax in "xyz") or any( - off != 0 for ax, off in offset.items() if ax in "xyz" - ) - else: - network_resizes = False - - padding = _parse_padding(padding, prediction_pipeline.input_specs) - if not isinstance(inputs, (tuple, list)): - inputs = [inputs] - if not isinstance(padding, (tuple, list)): - padding = [padding] - assert len(padding) == len(prediction_pipeline.input_specs) - inputs, crops = zip( - *[ - image_helper.pad(inp, spec.axes, p, pad_right=pad_right) - for inp, spec, p in zip(inputs, prediction_pipeline.input_specs, padding) - ] - ) - result = predict(prediction_pipeline, inputs) - if network_resizes: - crops = tuple( - { - ax: slice( - crp.start if crp.start is None else int(crp.start * scale[ax] + 2 * offset[ax]), - crp.stop if crp.stop is None else int(crp.stop * scale[ax] + 2 * offset[ax]), - ) - if ax in "xyz" - else crp - for ax, crp in crop.items() - } - for crop in crops - ) - return [_apply_crop(res, crop) for res, crop in zip(result, crops)] - - -# simple heuristic to determine suitable shape from min and step -def _determine_shape(min_shape, step, axes): - is3d = "z" in axes - min_len = 64 if is3d else 256 - shape = [] - for ax, min_ax, step_ax in zip(axes, min_shape, step): - if ax in "zyx" and step_ax > 0: - len_ax = min_ax - while len_ax < min_len: - len_ax += step_ax - shape.append(len_ax) - else: - shape.append(min_ax) - return shape - - -def _parse_tiling(tiling, input_specs, output_specs): - if tiling is None: # no tiling - return tiling 
- if len(input_specs) > 1: - raise NotImplementedError("Tiling for multiple inputs not yet implemented") - if len(output_specs) > 1: - raise NotImplementedError("Tiling for multiple outputs not yet implemented") - - input_spec = input_specs[0] - output_spec = output_specs[0] - if isinstance(output_spec.shape, list): - assert isinstance(input_spec.shape, list) and input_spec.shape == output_spec.shape, ( - "When predicting with tiling, output_shape and input_shape must either be specified " - "explictly and must be identical, or output_shape must be" - "implicitly defined by input_shape, otherwise relationship between " - "input and output shapes per tile cannot be known." - ) - axes = input_spec.axes - - def check_tiling(tiling): - assert "halo" in tiling and "tile" in tiling - spatial_axes = [ax for ax in axes if ax in "xyz"] - halo = tiling["halo"] - tile = tiling["tile"] - scale = tiling.get("scale", dict()) - assert all(halo.get(ax, 0) >= 0 for ax in spatial_axes) - assert all(tile.get(ax, 0) > 0 for ax in spatial_axes) - assert all(scale.get(ax, 1) > 0 for ax in spatial_axes) - - if isinstance(tiling, dict) or (isinstance(tiling, bool) and tiling): - # NOTE we assume here that shape in input and output are the same - # for different input and output shapes, we should actually tile in the - # output space and then request the corresponding input tiles - # so we would need to apply the output scale and offset to the - # input shape to compute the tile size and halo here - shape = input_spec.shape - if not isinstance(shape, list): - shape = _determine_shape(shape.min, shape.step, axes) - assert isinstance(shape, list) - assert len(shape) == len(axes) - - scale = None - output_shape = output_spec.shape - scale = [1.0] * len(output_spec.shape) if isinstance(output_shape, list) else output_shape.scale - assert len(scale) == len(axes) - - halo = output_spec.halo - if not isinstance(halo, list): - halo = [0] * len(axes) - assert len(halo) == len(axes) - - 
default_tiling = { - "halo": {ax: ha for ax, ha in zip(axes, halo) if ax in "xyz"}, - "tile": {ax: sh for ax, sh in zip(axes, shape) if ax in "xyz"}, - "scale": {ax: sc for ax, sc in zip(axes, scale) if ax in "xyz"}, - } - - # override metadata defaults with provided dict - if isinstance(tiling, dict): - for key in ["halo", "tile", "scale"]: - default_tiling[key].update(tiling.get(key, dict())) - tiling = default_tiling - check_tiling(tiling) - - elif isinstance(tiling, bool) and not tiling: - raise NotImplementedError("Should be unreachable") - - else: - raise ValueError(f"Invalid argument for tiling: {tiling}") - - return tiling - - -def predict_with_tiling( - prediction_pipeline: PredictionPipeline, - inputs: Union[xr.DataArray, List[xr.DataArray], Tuple[xr.DataArray]], - tiling: Union[bool, Dict[str, Dict[str, int]]] = True, - verbose: bool = False, -) -> List[xr.DataArray]: - """Run prediction with tiling for a single set of input(s) with a bioimage.io model. - - Args: - prediction_pipeline: the prediction pipeline for the input model. - inputs: the input(s) for this model represented as xarray data. - tiling: the tiling settings. Pass True to derive from the model spec. - verbose: whether to print the prediction progress. 
- """ - if not tiling: - raise ValueError("cannot call predict_with_tiling with tiling=False") - assert len(inputs) == len(prediction_pipeline.input_specs) - - tiling = _parse_tiling(tiling, prediction_pipeline.input_specs, prediction_pipeline.output_specs) - if not isinstance(inputs, (list, tuple)): - inputs = [inputs] - named_inputs: OrderedDict[str, xr.DataArray] = collections.OrderedDict( - **{ - ipt_spec.name: xr.DataArray(ipt_data, dims=tuple(ipt_spec.axes)) - for ipt_data, ipt_spec in zip(inputs, prediction_pipeline.input_specs) - } - ) - - outputs = [] - for output_spec in prediction_pipeline.output_specs: - if isinstance(output_spec.shape, ImplicitOutputShape): - scale = dict(zip(output_spec.axes, output_spec.shape.scale)) - offset = dict(zip(output_spec.axes, output_spec.shape.offset)) - - ref_input = named_inputs[output_spec.shape.reference_tensor] - ref_input_shape = dict(zip(ref_input.dims, ref_input.shape)) - output_shape = tuple(int(scale[ax] * ref_input_shape[ax] + 2 * offset[ax]) for ax in output_spec.axes) - else: - if len(inputs) > 1: - raise NotImplementedError - input_spec = prediction_pipeline.input_specs[0] - if input_spec.axes != output_spec.axes: - raise NotImplementedError("Tiling with a different output shape is not yet supported") - out_axes = output_spec.axes - fixed_shape = tuple(output_spec.shape) - if not all(fixed_shape[out_axes.index(ax)] == tile_shape for ax, tile_shape in tiling["tile"].items()): - raise NotImplementedError("Tiling with a different output shape is not yet supported") - - output_shape = list(inputs[0].shape) - chan_id = out_axes.index("c") - if fixed_shape[chan_id] != output_shape[chan_id]: - output_shape[chan_id] = fixed_shape[chan_id] - output_shape = tuple(output_shape) - - outputs.append(xr.DataArray(np.zeros(output_shape, dtype=output_spec.data_type), dims=tuple(output_spec.axes))) - - _predict_with_tiling_impl( - prediction_pipeline, - list(named_inputs.values()), - outputs, - tile_shapes=[tiling["tile"]], # 
todo: update tiling for multiple inputs/outputs - halos=[tiling["halo"]], - scales=[tiling["scale"]], - verbose=verbose, - ) - - return outputs - - -def _predict_sample(prediction_pipeline, inputs, outputs, padding, tiling): - if padding and tiling: - raise ValueError("Only one of padding or tiling is supported") - - input_data = image_helper.load_tensors(inputs, prediction_pipeline.input_specs) - if padding is not None: - result = predict_with_padding(prediction_pipeline, input_data, padding) - elif tiling is not None: - result = predict_with_tiling(prediction_pipeline, input_data, tiling) - else: - result = predict(prediction_pipeline, input_data) - - assert isinstance(result, list) - assert len(result) == len(outputs) - for res, out in zip(result, outputs): - image_helper.save_image(out, res) - - -def predict_image( - model_rdf: Union[RawResourceDescription, ResourceDescription, os.PathLike, str, dict, raw_nodes.URI], - inputs: Union[Tuple[Path, ...], List[Path], Path], - outputs: Union[Tuple[Path, ...], List[Path], Path], - padding: Optional[Union[bool, Dict[str, int]]] = None, - tiling: Optional[Union[bool, Dict[str, Dict[str, int]]]] = None, - weight_format: Optional[str] = None, - devices: Optional[List[str]] = None, - verbose: bool = False, -): - """Run prediction for a single set of input image(s) with a bioimage.io model. - - Args: - model_rdf: the bioimageio model. - inputs: the filepaths for the input images. - outputs: the filepaths for saving the input images. - padding: the padding settings for prediction. By default no padding is used. - tiling: the tiling settings for prediction. By default no tiling is used. - weight_format: the weight format to use for predictions. - devices: the devices to use for prediction. - verbose: run prediction in verbose mode. 
- """ - if not isinstance(inputs, (tuple, list)): - inputs = [inputs] - - if not isinstance(outputs, (tuple, list)): - outputs = [outputs] - - model = load_resource_description(model_rdf) - assert isinstance(model, Model) - if len(model.inputs) != len(inputs): - raise ValueError - if len(model.outputs) != len(outputs): - raise ValueError - - with create_prediction_pipeline( - bioimageio_model=model, weight_format=weight_format, devices=devices - ) as prediction_pipeline: - _predict_sample(prediction_pipeline, inputs, outputs, padding, tiling) - - -def predict_images( - model_rdf: Union[RawResourceDescription, ResourceDescription, os.PathLike, str, dict, raw_nodes.URI], - inputs: Sequence[Union[Tuple[Path, ...], List[Path], Path]], - outputs: Sequence[Union[Tuple[Path, ...], List[Path], Path]], - padding: Optional[Union[bool, Dict[str, int]]] = None, - tiling: Optional[Union[bool, Dict[str, Dict[str, int]]]] = None, - weight_format: Optional[str] = None, - devices: Optional[List[str]] = None, - verbose: bool = False, -): - """Predict multiple input images with a bioimage.io model. - - Args: - model_rdf: the bioimageio model. - inputs: the filepaths for the input images. - outputs: the filepaths for saving the input images. - padding: the padding settings for prediction. By default no padding is used. - tiling: the tiling settings for prediction. By default no tiling is used. - weight_format: the weight format to use for predictions. - devices: the devices to use for prediction. - verbose: run prediction in verbose mode. 
- """ - - model = load_resource_description(model_rdf) - assert isinstance(model, Model) - - with create_prediction_pipeline( - bioimageio_model=model, weight_format=weight_format, devices=devices - ) as prediction_pipeline: - prog = zip(inputs, outputs) - if verbose: - prog = tqdm(prog, total=len(inputs)) - - for inp, outp in prog: - if not isinstance(inp, (tuple, list)): - inp = [inp] - - if not isinstance(outp, (tuple, list)): - outp = [outp] - - _predict_sample(prediction_pipeline, inp, outp, padding, tiling) +# TODO: add convenience functions for predictions diff --git a/bioimageio/core/prediction_pipeline/__init__.py b/bioimageio/core/prediction_pipeline/__init__.py deleted file mode 100644 index d982ea1b..00000000 --- a/bioimageio/core/prediction_pipeline/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from ._model_adapters import get_weight_formats -from ._prediction_pipeline import PredictionPipeline, create_prediction_pipeline diff --git a/bioimageio/core/prediction_pipeline/_combined_processing.py b/bioimageio/core/prediction_pipeline/_combined_processing.py deleted file mode 100644 index bbd3e354..00000000 --- a/bioimageio/core/prediction_pipeline/_combined_processing.py +++ /dev/null @@ -1,103 +0,0 @@ -import dataclasses -from typing import Any, Dict, List, Optional, Sequence, Union - -from bioimageio.core.resource_io import nodes -from ._processing import AssertDtype, EnsureDtype, KNOWN_PROCESSING, Processing, TensorName -from ._utils import ComputedMeasures, PER_DATASET, PER_SAMPLE, RequiredMeasures, Sample - -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal # type: ignore - - -@dataclasses.dataclass -class ProcessingInfoStep: - name: str - kwargs: Dict[str, Any] - - -@dataclasses.dataclass -class ProcessingInfo: - steps: List[ProcessingInfoStep] - assert_dtype_before: Optional[Union[str, Sequence[str]]] = None # throw AssertionError if data type doesn't match - ensure_dtype_before: Optional[str] = None # cast 
data type if needed - assert_dtype_after: Optional[Union[str, Sequence[str]]] = None # throw AssertionError if data type doesn't match - ensure_dtype_after: Optional[str] = None # throw AssertionError if data type doesn't match - - -class CombinedProcessing: - def __init__(self, combine_tensors: Dict[TensorName, ProcessingInfo]): - self._procs = [] - known = dict(KNOWN_PROCESSING["pre"]) - known.update(KNOWN_PROCESSING["post"]) - - # ensure all tensors have correct data type before any processing - for tensor_name, info in combine_tensors.items(): - if info.assert_dtype_before is not None: - self._procs.append(AssertDtype(tensor_name=tensor_name, dtype=info.assert_dtype_before)) - - if info.ensure_dtype_before is not None: - self._procs.append(EnsureDtype(tensor_name=tensor_name, dtype=info.ensure_dtype_before)) - - for tensor_name, info in combine_tensors.items(): - for step in info.steps: - self._procs.append(known[step.name](tensor_name=tensor_name, **step.kwargs)) - - if info.assert_dtype_after is not None: - self._procs.append(AssertDtype(tensor_name=tensor_name, dtype=info.assert_dtype_after)) - - # ensure tensor has correct data type right after its processing - if info.ensure_dtype_after is not None: - self._procs.append(EnsureDtype(tensor_name=tensor_name, dtype=info.ensure_dtype_after)) - - self.required_measures: RequiredMeasures = self._collect_required_measures(self._procs) - self.tensor_names = list(combine_tensors) - - @classmethod - def from_tensor_specs(cls, tensor_specs: List[Union[nodes.InputTensor, nodes.OutputTensor]]): - combine_tensors = {} - for ts in tensor_specs: - # There is a difference between pre-and postprocessing: - # After preprocessing we ensure float32, because the output is consumed by the model. - # After postprocessing the dtype that is specified in the model spec needs to be ensured. 
- assert ts.name not in combine_tensors - if isinstance(ts, nodes.InputTensor): - # todo: assert nodes.InputTensor.dtype with assert_dtype_before? - # todo: in the long run we do not want to limit model inputs to float32... - combine_tensors[ts.name] = ProcessingInfo( - [ProcessingInfoStep(p.name, kwargs=p.kwargs) for p in ts.preprocessing or []], - ensure_dtype_after="float32", - ) - elif isinstance(ts, nodes.OutputTensor): - combine_tensors[ts.name] = ProcessingInfo( - [ProcessingInfoStep(p.name, kwargs=p.kwargs) for p in ts.postprocessing or []], - ensure_dtype_after=ts.data_type, - ) - else: - raise NotImplementedError(type(ts)) - - inst = cls(combine_tensors) - for ts in tensor_specs: - if isinstance(ts, nodes.OutputTensor) and ts.name in inst.required_measures[PER_DATASET]: - raise NotImplementedError("computing statistics for output tensors per dataset is not yet implemented") - - return inst - - def apply(self, sample: Sample, computed_measures: ComputedMeasures) -> None: - for proc in self._procs: - proc.set_computed_measures(computed_measures) - sample[proc.tensor_name] = proc.apply(sample[proc.tensor_name]) - - @staticmethod - def _collect_required_measures(proc: Sequence[Processing]) -> RequiredMeasures: - ret: RequiredMeasures = {PER_SAMPLE: {}, PER_DATASET: {}} - for p in proc: - for mode, ms_per_mode in p.get_required_measures().items(): - for tn, ms_per_tn in ms_per_mode.items(): - if tn not in ret[mode]: - ret[mode][tn] = set() - - ret[mode][tn].update(ms_per_tn) - - return ret diff --git a/bioimageio/core/prediction_pipeline/_measure_groups.py b/bioimageio/core/prediction_pipeline/_measure_groups.py deleted file mode 100644 index 59e0ce74..00000000 --- a/bioimageio/core/prediction_pipeline/_measure_groups.py +++ /dev/null @@ -1,343 +0,0 @@ -from __future__ import annotations - -import collections -import warnings -from collections import defaultdict -from itertools import product -from typing import DefaultDict, Dict, Hashable, Iterator, List, 
Mapping, Optional, Sequence, Set, Tuple, Type, Union - -import numpy -import xarray as xr - -from bioimageio.core.statistical_measures import Mean, Measure, Percentile, Std, Var -from ._utils import ComputedMeasures, PER_DATASET, PER_SAMPLE, RequiredMeasures, Sample, TensorName - -try: - from typing import Literal, TypedDict -except ImportError: - from typing_extensions import Literal, TypedDict # type: ignore - -try: - import crick -except ImportError: - crick = None - -MeasureValue = xr.DataArray - - -class SampleMeasureGroup: - """group of measures for more efficient computation of multiple measures per sample""" - - def compute(self, sample: Sample) -> Dict[TensorName, Dict[Measure, MeasureValue]]: - raise NotImplementedError - - -class DatasetMeasureGroup: - """group of measures for more efficient computation of multiple measures per dataset""" - - def reset(self) -> None: - """reset any accumulated intermediates""" - raise NotImplementedError - - def update_with_sample(self, sample: Sample) -> None: - """update intermediate representation with a data sample""" - raise NotImplementedError - - def finalize(self) -> Dict[TensorName, Dict[Measure, MeasureValue]]: - """compute statistics from intermediate representation""" - raise NotImplementedError - - -MeasureGroups = TypedDict( - "MeasureGroups", {PER_SAMPLE: Sequence[SampleMeasureGroup], PER_DATASET: Sequence[DatasetMeasureGroup]} -) - - -class DatasetMean(DatasetMeasureGroup): - n: int - mean: Optional[xr.DataArray] - - def __init__(self, tensor_name: TensorName, axes: Optional[Tuple[int]]): - self.axes: Optional[Tuple[str]] = axes - self.tensor_name = tensor_name - self.reset() - - def reset(self): - self.n = 0 - self.mean = None - - def update_with_sample(self, sample: Sample): - tensor = sample[self.tensor_name].astype(numpy.float64, copy=False) - mean_b = tensor.mean(dim=self.axes) - assert mean_b.dtype == numpy.float64 - n_b = numpy.prod(tensor.shape) / numpy.prod(mean_b.shape) # reduced voxel count - 
if self.n == 0: - assert self.mean is None - self.n = n_b - self.mean = mean_b - else: - n_a = self.n - mean_a = self.mean - self.n = n = n_a + n_b - self.mean = (n_a * mean_a + n_b * mean_b) / n - assert self.mean.dtype == numpy.float64 - - def finalize(self) -> Dict[TensorName, Dict[Measure, MeasureValue]]: - if self.n == 0: - return {} - else: - return {self.tensor_name: {Mean(axes=self.axes): self.mean}} - - -class MeanVarStd(SampleMeasureGroup, DatasetMeasureGroup): - n: int - mean: Optional[xr.DataArray] - m2: Optional[xr.DataArray] - - def __init__(self, tensor_name: TensorName, axes: Optional[Tuple[int]]): - self.axes: Optional[Tuple[str]] = axes - self.tensor_name = tensor_name - self.reset() - - def compute(self, sample: Sample) -> Dict[TensorName, Dict[Measure, MeasureValue]]: - tensor = sample[self.tensor_name] - mean = tensor.mean(dim=self.axes) - c = tensor - mean - n = tensor.size if self.axes is None else numpy.prod([tensor.sizes[d] for d in self.axes]) - var = xr.dot(c, c, dims=self.axes) / n - std = numpy.sqrt(var) - return {self.tensor_name: {Mean(axes=self.axes): mean, Var(axes=self.axes): var, Std(axes=self.axes): std}} - - def reset(self): - self.n = 0 - self.mean = None - self.m2 = None - - def update_with_sample(self, sample: Sample): - tensor = sample[self.tensor_name].astype(numpy.float64, copy=False) - mean_b = tensor.mean(dim=self.axes) - assert mean_b.dtype == numpy.float64 - n_b = numpy.prod(tensor.shape) / numpy.prod(mean_b.shape) # reduced voxel count - m2_b = ((tensor - mean_b) ** 2).sum(dim=self.axes) - assert m2_b.dtype == numpy.float64 - if self.n == 0: - assert self.mean is None - assert self.m2 is None - self.n = n_b - self.mean = mean_b - self.m2 = m2_b - else: - n_a = self.n - mean_a = self.mean - m2_a = self.m2 - self.n = n = n_a + n_b - self.mean = (n_a * mean_a + n_b * mean_b) / n - assert self.mean.dtype == numpy.float64 - d = mean_b - mean_a - self.m2 = m2_a + m2_b + d ** 2 * n_a * n_b / n - assert self.m2.dtype == 
numpy.float64 - - def finalize(self) -> Dict[TensorName, Dict[Measure, MeasureValue]]: - if self.n == 0: - return {} - else: - var = self.m2 / self.n - return { - self.tensor_name: { - Mean(axes=self.axes): self.mean, - Var(axes=self.axes): var, - Std(axes=self.axes): numpy.sqrt(var), - } - } - - -class SamplePercentiles(SampleMeasureGroup): - def __init__(self, tensor_name: TensorName, axes: Optional[Tuple[str]], ns: Sequence[float]): - assert all(0 <= n <= 100 for n in ns) - self.ns = ns - self.qs = [n / 100 for n in ns] - self.axes = axes - self.tensor_name = tensor_name - - def compute(self, sample: Sample) -> Dict[TensorName, Dict[Measure, MeasureValue]]: - tensor = sample[self.tensor_name] - ps = tensor.quantile(self.qs, dim=self.axes) - return {self.tensor_name: {Percentile(n=n, axes=self.axes): p for n, p in zip(self.ns, ps)}} - - -class MeanPercentiles(DatasetMeasureGroup): - n: int - estimates: Optional[xr.DataArray] - - def __init__(self, tensor_name: TensorName, axes: Optional[Tuple[str]], ns: Sequence[float]): - assert all(0 <= n <= 100 for n in ns) - self.ns = ns - self.qs = [n / 100 for n in ns] - self.axes = axes - self.tensor_name = tensor_name - self.reset() - - def reset(self): - self.n = 0 - self.estimates = None - - def update_with_sample(self, sample: Sample): - tensor = sample[self.tensor_name] - sample_estimates = tensor.quantile(self.qs, dim=self.axes).astype(numpy.float64, copy=False) - - n = numpy.prod(tensor.shape) / numpy.prod(sample_estimates.shape[1:]) # reduced voxel count - - if self.n == 0: - self.estimates = sample_estimates - else: - self.estimates = (self.n * self.estimates + n * sample_estimates) / (self.n + n) - assert self.estimates.dtype == numpy.float64 - - self.n += n - - def finalize(self) -> Dict[TensorName, Dict[Percentile, MeasureValue]]: - if self.n == 0: - return {} - else: - warnings.warn(f"Computed dataset percentiles naively by averaging percentiles of samples.") - return {self.tensor_name: {Percentile(n=n, 
axes=self.axes): e for n, e in zip(self.ns, self.estimates)}} - - -class CrickPercentiles(DatasetMeasureGroup): - digest: Optional[List["crick.TDigest"]] - dims: Optional[Tuple[Hashable, ...]] - indices: Optional[Iterator[Tuple[int, ...]]] - shape: Optional[Tuple[int, ...]] - - def __init__(self, tensor_name: TensorName, axes: Optional[Tuple[str]], ns: Sequence[float]): - assert all(0 <= n <= 100 for n in ns) - assert axes is None or "_percentiles" not in axes - warnings.warn(f"Computing dataset percentiles with experimental 'crick' library.") - self.ns = ns - self.qs = [n / 100 for n in ns] - self.axes = axes - self.tensor_name = tensor_name - self.reset() - - def reset(self): - self.digest = None - self.dims = None - self.indices = None - self.shape = None - - def _initialize(self, tensor_sizes: Mapping[Hashable, int]): - out_sizes = collections.OrderedDict(_percentiles=len(self.ns)) - if self.axes is not None: - for d, s in tensor_sizes.items(): - if d not in self.axes: - out_sizes[d] = s - - self.dims, self.shape = zip(*out_sizes.items()) - self.digest = [crick.TDigest() for _ in range(int(numpy.prod(self.shape[1:])))] - self.indices = product(*map(range, self.shape[1:])) - - def update_with_sample(self, sample: Sample): - tensor = sample[self.tensor_name] - assert "_percentiles" not in tensor.dims - if self.digest is None: - self._initialize(tensor.sizes) - assert self.digest is not None - - for i, idx in enumerate(self.indices): - self.digest[i].update(tensor.isel(dict(zip(self.dims[1:], idx)))) - - def finalize(self) -> Dict[TensorName, Dict[Measure, MeasureValue]]: - if self.digest is None: - return {} - else: - vs = numpy.asarray([[d.quantile(q) for d in self.digest] for q in self.qs]).reshape(self.shape) - return { - self.tensor_name: { - Percentile(n=n, axes=self.axes): xr.DataArray(v, dims=self.dims[1:]) for n, v in zip(self.ns, vs) - } - } - - -if crick is None: - DatasetPercentileGroup: Union[Type[MeanPercentiles], Type[CrickPercentiles]] = 
MeanPercentiles -else: - DatasetPercentileGroup = CrickPercentiles - - -class SingleMeasureAsGroup(SampleMeasureGroup): - """wrapper for measures to match interface of SampleMeasureGroup""" - - def __init__(self, tensor_name: TensorName, measure: Measure): - self.tensor_name = tensor_name - self.measure = measure - - def compute(self, sample: Sample) -> Dict[TensorName, Dict[Measure, MeasureValue]]: - return {self.tensor_name: {self.measure: self.measure.compute(sample[self.tensor_name])}} - - -def get_measure_groups(measures: RequiredMeasures) -> MeasureGroups: - """find a list of MeasureGroups to compute measures efficiently""" - - measure_groups = {PER_SAMPLE: [], PER_DATASET: []} - means: Set[Tuple[TensorName, Mean]] = set() - mean_var_std_groups: Set[Tuple[TensorName, Optional[Tuple[str, ...]]]] = set() - percentile_groups: DefaultDict[Tuple[TensorName, Optional[Tuple[str, ...]]], List[float]] = defaultdict(list) - for mode, ms_per_mode in measures.items(): - for tn, ms_per_tn in ms_per_mode.items(): - for m in ms_per_tn: - if isinstance(m, Mean): - means.add((tn, m)) - elif isinstance(m, (Var, Std)): - mean_var_std_groups.add((tn, m.axes)) - elif isinstance(m, Percentile): - percentile_groups[(tn, m.axes)].append(m.n) - elif mode == PER_SAMPLE: - measure_groups[mode].append(SingleMeasureAsGroup(tensor_name=tn, measure=m)) - else: - raise NotImplementedError(f"Computing statistics for {m} {mode} not yet implemented") - - # add all mean measures that are not included in a mean/var/std group - for (tn, m) in means: - if (tn, m.axes) not in mean_var_std_groups: - # compute only mean - if mode == PER_SAMPLE: - measure_groups[mode].append(SingleMeasureAsGroup(tensor_name=tn, measure=m)) - elif mode == PER_DATASET: - measure_groups[mode].append(DatasetMean(tensor_name=tn, axes=m.axes)) - else: - raise NotImplementedError(mode) - - for (tn, axes) in mean_var_std_groups: - measure_groups[mode].append(MeanVarStd(tensor_name=tn, axes=axes)) - - for (tn, axes), ns in 
percentile_groups.items(): - if mode == PER_SAMPLE: - measure_groups[mode].append(SamplePercentiles(tensor_name=tn, axes=axes, ns=ns)) - elif mode == PER_DATASET: - measure_groups[mode].append(DatasetPercentileGroup(tensor_name=tn, axes=axes, ns=ns)) - else: - raise NotImplementedError(mode) - - return measure_groups - - -def compute_measures( - measures: RequiredMeasures, *, sample: Optional[Sample] = None, dataset: Iterator[Sample] = tuple() -) -> ComputedMeasures: - ms_groups = get_measure_groups(measures) - ret = {PER_SAMPLE: {}, PER_DATASET: {}} - if sample is not None: - for mg in ms_groups[PER_SAMPLE]: - assert isinstance(mg, SampleMeasureGroup) - ret[PER_SAMPLE].update(mg.compute(sample)) - - for sample in dataset: - for mg in ms_groups[PER_DATASET]: - assert isinstance(mg, DatasetMeasureGroup) - mg.update_with_sample(sample) - - for mg in ms_groups[PER_DATASET]: - assert isinstance(mg, DatasetMeasureGroup) - ret[PER_DATASET].update(mg.finalize()) - - return ret diff --git a/bioimageio/core/prediction_pipeline/_model_adapters/__init__.py b/bioimageio/core/prediction_pipeline/_model_adapters/__init__.py deleted file mode 100644 index 5d2745d6..00000000 --- a/bioimageio/core/prediction_pipeline/_model_adapters/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from ._model_adapter import ModelAdapter, create_model_adapter, get_weight_formats - -__all__ = ["ModelAdapter", "create_model_adapter", "get_weight_formats"] diff --git a/bioimageio/core/prediction_pipeline/_model_adapters/_keras_model_adapter.py b/bioimageio/core/prediction_pipeline/_model_adapters/_keras_model_adapter.py deleted file mode 100644 index a9ee132b..00000000 --- a/bioimageio/core/prediction_pipeline/_model_adapters/_keras_model_adapter.py +++ /dev/null @@ -1,53 +0,0 @@ -import warnings -from typing import List, Optional, Sequence -from marshmallow import missing - -# by default, we use the keras integrated with tensorflow -try: - from tensorflow import keras - import tensorflow as tf - - 
TF_VERSION = tf.__version__ -except Exception: - import keras - - TF_VERSION = None -import xarray as xr - -from ._model_adapter import ModelAdapter - - -class KerasModelAdapter(ModelAdapter): - def _load(self, *, devices: Optional[Sequence[str]] = None) -> None: - model_tf_version = self.bioimageio_model.weights["keras_hdf5"].tensorflow_version - if model_tf_version is missing: - model_tf_version = None - else: - model_tf_version = (int(model_tf_version.major), int(model_tf_version.minor)) - - if TF_VERSION is None or model_tf_version is None: - warnings.warn("Could not check tensorflow versions. The prediction results may be wrong.") - elif tuple(model_tf_version[:2]) != tuple(map(int, TF_VERSION.split(".")))[:2]: - warnings.warn( - f"Model tensorflow version {model_tf_version} does not match {TF_VERSION}." - "The prediction results may be wrong" - ) - - # TODO keras device management - if devices is not None: - warnings.warn(f"Device management is not implemented for keras yet, ignoring the devices {devices}") - - weight_file = self.bioimageio_model.weights["keras_hdf5"].source - self._model = keras.models.load_model(weight_file) - self._output_axes = [tuple(out.axes) for out in self.bioimageio_model.outputs] - - def _unload(self) -> None: - warnings.warn("Device management is not implemented for keras yet, cannot unload model") - - def _forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]: - result = self._model.predict(*input_tensors) - if not isinstance(result, (tuple, list)): - result = [result] - - assert len(result) == len(self._output_axes) - return [xr.DataArray(r, dims=axes) for r, axes, in zip(result, self._output_axes)] diff --git a/bioimageio/core/prediction_pipeline/_model_adapters/_model_adapter.py b/bioimageio/core/prediction_pipeline/_model_adapters/_model_adapter.py deleted file mode 100644 index 8ab3fa88..00000000 --- a/bioimageio/core/prediction_pipeline/_model_adapters/_model_adapter.py +++ /dev/null @@ -1,178 +0,0 @@ -import abc 
-from typing import List, Optional, Sequence, Type, Union - -import xarray as xr - -from bioimageio.core import load_resource_description -from bioimageio.core.resource_io import nodes - -#: Known weight formats in order of priority -#: First match wins -from bioimageio.spec.model import raw_nodes - -_WEIGHT_FORMATS = ["pytorch_state_dict", "tensorflow_saved_model_bundle", "torchscript", "onnx", "keras_hdf5"] - - -class ModelAdapter(abc.ABC): - """ - Represents model *without* any preprocessing and postprocessing - """ - - def __init__( - self, *, bioimageio_model: Union[nodes.Model, raw_nodes.Model], devices: Optional[Sequence[str]] = None - ): - self.bioimageio_model = self._prepare_model(bioimageio_model) - self.default_devices = devices - self.loaded = False - - @staticmethod - def _prepare_model(bioimageio_model): - """the (raw) model node is prepared (here: loaded as non-raw model node) for the model adapter to be ready - for operation. - Note: To write a model adapter that uses the raw model node one can overwrite this method. - """ - if isinstance(bioimageio_model, nodes.Model): - return bioimageio_model - else: - return load_resource_description(bioimageio_model) - - def __enter__(self): - """load on entering context""" - assert not self.loaded - self.load() # using default_devices - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - """unload on exiting context""" - assert self.loaded - self.unload() - return False - - def load(self, *, devices: Optional[Sequence[str]] = None) -> None: - """ - Note: Use ModelAdapter as context to not worry about calling unload()! - Load model onto devices. If devices is None, self.default_devices are chosen - (which may be None as well, in which case a framework dependent default is chosen) - """ - self._load(devices=devices or self.default_devices) - self.loaded = True - - @abc.abstractmethod - def _load(self, *, devices: Optional[Sequence[str]] = None) -> None: - """ - Load model onto devices. 
If devices is None a framework dependent default is chosen - """ - ... - - def forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]: - """ - Load model if unloaded/outside context; then run forward pass of model to get model predictions - """ - if not self.loaded: - self.load() - - assert self.loaded - return self._forward(*input_tensors) - - @abc.abstractmethod - def _forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]: - """ - Run forward pass of model to get model predictions - Note: model is responsible converting it's data representation to - xarray.DataArray - """ - ... - - def unload(self): - """ - Unload model from any devices, freeing their memory. - Note: Use ModelAdapter as context to not worry about calling unload()! - """ - # implementation of non-state-machine logic in _unload() - assert self.loaded - self._unload() - self.loaded = False - - @abc.abstractmethod - def _unload(self) -> None: - """ - Unload model from any devices, freeing their memory. - """ - ... 
- - -def get_weight_formats() -> List[str]: - """ - Return list of supported weight types - """ - return _WEIGHT_FORMATS.copy() - - -def create_model_adapter( - *, - bioimageio_model: Union[nodes.Model, raw_nodes.Model], - devices=Optional[Sequence[str]], - weight_format: Optional[str] = None, -) -> ModelAdapter: - """ - Creates model adapter based on the passed spec - Note: All specific adapters should happen inside this function to prevent different framework - initializations interfering with each other - """ - weights = bioimageio_model.weights - weight_formats = get_weight_formats() - - if weight_format is not None: - if weight_format not in weight_formats: - raise ValueError(f"Weight format {weight_format} is not in supported formats {weight_formats}") - weight_formats = [weight_format] - - for weight in weight_formats: - if weight in weights: - adapter_cls = _get_model_adapter(weight) - return adapter_cls(bioimageio_model=bioimageio_model, devices=devices) - - raise RuntimeError( - f"weight format {weight_format} not among formats listed in model: {list(bioimageio_model.weights.keys())}" - ) - - -def _get_model_adapter(weight_format: str) -> Type[ModelAdapter]: - """ - Return adapter class based on the weight format - Note: All specific adapters should happen inside this function to prevent different framework - initializations interfering with each other - """ - if weight_format == "pytorch_state_dict": - from ._pytorch_model_adapter import PytorchModelAdapter - - return PytorchModelAdapter - - elif weight_format == "tensorflow_saved_model_bundle": - from ._tensorflow_model_adapter import TensorflowModelAdapter - - return TensorflowModelAdapter - - elif weight_format == "onnx": - from ._onnx_model_adapter import ONNXModelAdapter - - return ONNXModelAdapter - - elif weight_format == "torchscript": - from ._torchscript_model_adapter import TorchscriptModelAdapter - - return TorchscriptModelAdapter - - elif weight_format == "keras_hdf5": - # keras can either 
be installed as a separate package or used as part of tensorflow - # we try to first import the keras model adapter using the separate package and, - # if it is not available, try to load the one using tf - try: - from ._keras_model_adapter import KerasModelAdapter - except ImportError: - from ._tensorflow_model_adapter import KerasModelAdapter - - return KerasModelAdapter - - else: - raise ValueError(f"Weight format {weight_format} is not supported.") diff --git a/bioimageio/core/prediction_pipeline/_model_adapters/_onnx_model_adapter.py b/bioimageio/core/prediction_pipeline/_model_adapters/_onnx_model_adapter.py deleted file mode 100644 index 5495c8bd..00000000 --- a/bioimageio/core/prediction_pipeline/_model_adapters/_onnx_model_adapter.py +++ /dev/null @@ -1,34 +0,0 @@ -import logging -import warnings -from typing import List, Optional - -import onnxruntime as rt -import xarray as xr - -from ._model_adapter import ModelAdapter - -logger = logging.getLogger(__name__) - - -class ONNXModelAdapter(ModelAdapter): - def _load(self, *, devices: Optional[List[str]] = None): - self._internal_output_axes = [tuple(out.axes) for out in self.bioimageio_model.outputs] - - self._session = rt.InferenceSession(str(self.bioimageio_model.weights["onnx"].source)) - onnx_inputs = self._session.get_inputs() - self._input_names = [ipt.name for ipt in onnx_inputs] - - if devices is not None: - warnings.warn(f"Device management is not implemented for onnx yet, ignoring the devices {devices}") - - def _forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]: - assert len(input_tensors) == len(self._input_names) - input_arrays = [ipt.data for ipt in input_tensors] - result = self._session.run(None, dict(zip(self._input_names, input_arrays))) - if not isinstance(result, (list, tuple)): - result = [] - - return [xr.DataArray(r, dims=axes) for r, axes in zip(result, self._internal_output_axes)] - - def _unload(self) -> None: - warnings.warn("Device management is not implemented 
for onnx yet, cannot unload model") diff --git a/bioimageio/core/prediction_pipeline/_model_adapters/_pytorch_model_adapter.py b/bioimageio/core/prediction_pipeline/_model_adapters/_pytorch_model_adapter.py deleted file mode 100644 index f47aa1d7..00000000 --- a/bioimageio/core/prediction_pipeline/_model_adapters/_pytorch_model_adapter.py +++ /dev/null @@ -1,62 +0,0 @@ -import gc -import warnings -from typing import List, Optional - -import torch -import xarray as xr -from marshmallow import missing - -from bioimageio.core.resource_io import nodes -from ._model_adapter import ModelAdapter - - -class PytorchModelAdapter(ModelAdapter): - def _load(self, *, devices: Optional[List[str]] = None): - self._model = self.get_nn_instance(self.bioimageio_model) - - if devices is None: - self._devices = ["cuda" if torch.cuda.is_available() else "cpu"] - else: - self._devices = [torch.device(d) for d in devices] - - if len(self._devices) > 1: - warnings.warn("Multiple devices for single pytorch model not yet implemented") - - self._model.to(self._devices[0]) - - assert isinstance(self._model, torch.nn.Module) - weights = self.bioimageio_model.weights.get("pytorch_state_dict") - if weights is not None and weights.source: - state = torch.load(weights.source, map_location=self._devices[0]) - self._model.load_state_dict(state) - - self._model.eval() - self._internal_output_axes = [tuple(out.axes) for out in self.bioimageio_model.outputs] - - def _forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]: - with torch.no_grad(): - tensors = [torch.from_numpy(ipt.data) for ipt in input_tensors] - tensors = [t.to(self._devices[0]) for t in tensors] - result = self._model(*tensors) - if not isinstance(result, (tuple, list)): - result = [result] - - result = [r.detach().cpu().numpy() if isinstance(r, torch.Tensor) else r for r in result] - - return [xr.DataArray(r, dims=axes) for r, axes in zip(result, self._internal_output_axes)] - - def _unload(self) -> None: - self._devices = 
None - del self._model - gc.collect() # deallocate memory - torch.cuda.empty_cache() # release reserved memory - - @staticmethod - def get_nn_instance(model_node: nodes.Model, **kwargs): - weight_spec = model_node.weights.get("pytorch_state_dict") - assert weight_spec is not None - assert isinstance(weight_spec.architecture, nodes.ImportedSource) - model_kwargs = weight_spec.kwargs - joined_kwargs = {} if model_kwargs is missing else dict(model_kwargs) - joined_kwargs.update(kwargs) - return weight_spec.architecture(**joined_kwargs) diff --git a/bioimageio/core/prediction_pipeline/_model_adapters/_tensorflow_model_adapter.py b/bioimageio/core/prediction_pipeline/_model_adapters/_tensorflow_model_adapter.py deleted file mode 100644 index 57f8de41..00000000 --- a/bioimageio/core/prediction_pipeline/_model_adapters/_tensorflow_model_adapter.py +++ /dev/null @@ -1,131 +0,0 @@ -import warnings -import zipfile -from typing import List, Optional - -import numpy as np -import tensorflow as tf -import xarray as xr -from marshmallow import missing - -from ._model_adapter import ModelAdapter - -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal # type: ignore - - -class TensorflowModelAdapterBase(ModelAdapter): - weight_format: Literal["keras_hdf5", "tensorflow_saved_model_bundle"] - - def require_unzipped(self, weight_file): - if zipfile.is_zipfile(weight_file): - out_path = weight_file.with_suffix("") - with zipfile.ZipFile(weight_file, "r") as f: - f.extractall(out_path) - return out_path - return weight_file - - def _load_model(self, weight_file): - weight_file = self.require_unzipped(weight_file) - if self.use_keras_api: - return tf.keras.models.load_model(weight_file, compile=False) - else: - # NOTE in tf1 the model needs to be loaded inside of the session, so we cannot preload the model - return str(weight_file) - - def _load(self, *, devices: Optional[List[str]] = None): - model_tf_version = 
self.bioimageio_model.weights[self.weight_format].tensorflow_version - if model_tf_version is missing: - model_tf_version = None - else: - model_tf_version = (int(model_tf_version.major), int(model_tf_version.minor)) - - tf_version = tf.__version__ - tf_major_and_minor = tuple(map(int, tf_version.split(".")))[:2] - if model_tf_version is None: - warnings.warn( - "The model did not contain metadata about the tensorflow version used for training." - f"Cannot check if it is compatible with tf {tf_version}. The prediction result may be wrong." - ) - elif tuple(model_tf_version[:2]) != tf_major_and_minor: - warnings.warn( - f"Model tensorflow version {model_tf_version} does not match {tf_version}." - "The prediction results may be wrong" - ) - - tf_major_ver = tf_major_and_minor[0] - assert tf_major_ver in (1, 2) - self.use_keras_api = tf_major_ver > 1 or self.weight_format == KerasModelAdapter.weight_format - - # TODO tf device management - if devices is not None: - warnings.warn(f"Device management is not implemented for tensorflow yet, ignoring the devices {devices}") - - weight_file = self.require_unzipped(self.bioimageio_model.weights[self.weight_format].source) - self._model = self._load_model(weight_file) - self._internal_output_axes = [tuple(out.axes) for out in self.bioimageio_model.outputs] - - # TODO currently we relaod the model every time. 
it would be better to keep the graph and session - # alive in between of forward passes (but then the sessions need to be properly opened / closed) - def _forward_tf(self, *input_tensors): - input_keys = [ipt.name for ipt in self.bioimageio_model.inputs] - output_keys = [out.name for out in self.bioimageio_model.outputs] - - # TODO read from spec - tag = tf.saved_model.tag_constants.SERVING - signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY - - graph = tf.Graph() - with graph.as_default(): - with tf.Session(graph=graph) as sess: - - # load the model and the signature - graph_def = tf.saved_model.loader.load(sess, [tag], self._model) - signature = graph_def.signature_def - - # get the tensors into the graph - in_names = [signature[signature_key].inputs[key].name for key in input_keys] - out_names = [signature[signature_key].outputs[key].name for key in output_keys] - in_tensors = [graph.get_tensor_by_name(name) for name in in_names] - out_tensors = [graph.get_tensor_by_name(name) for name in out_names] - - # run prediction - res = sess.run(dict(zip(out_names, out_tensors)), dict(zip(in_tensors, input_tensors))) - # from dict to list of tensors - res = [res[out] for out in out_names] - - return res - - def _forward_keras(self, *input_tensors): - tf_tensor = [tf.convert_to_tensor(ipt) for ipt in input_tensors] - - try: - result = self._model.forward(*tf_tensor) - except AttributeError: - result = self._model.predict(*tf_tensor) - - if not isinstance(result, (tuple, list)): - result = [result] - - return [r if isinstance(r, np.ndarray) else tf.make_ndarray(r) for r in result] - - def _forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]: - data = [ipt.data for ipt in input_tensors] - if self.use_keras_api: - result = self._forward_keras(*data) - else: - result = self._forward_tf(*data) - - return [xr.DataArray(r, dims=axes) for r, axes in zip(result, self._internal_output_axes)] - - def _unload(self) -> None: - 
warnings.warn("Device management is not implemented for keras yet, cannot unload model") - - -class TensorflowModelAdapter(TensorflowModelAdapterBase): - weight_format = "tensorflow_saved_model_bundle" - - -class KerasModelAdapter(TensorflowModelAdapterBase): - weight_format = "keras_hdf5" diff --git a/bioimageio/core/prediction_pipeline/_model_adapters/_torchscript_model_adapter.py b/bioimageio/core/prediction_pipeline/_model_adapters/_torchscript_model_adapter.py deleted file mode 100644 index 3b339722..00000000 --- a/bioimageio/core/prediction_pipeline/_model_adapters/_torchscript_model_adapter.py +++ /dev/null @@ -1,43 +0,0 @@ -import gc -import warnings -from typing import List, Optional - -import numpy as np -import torch -import xarray as xr - -from ._model_adapter import ModelAdapter - - -class TorchscriptModelAdapter(ModelAdapter): - def _load(self, *, devices: Optional[List[str]] = None): - weight_path = str(self.bioimageio_model.weights["torchscript"].source.resolve()) - if devices is None: - self.devices = ["cuda" if torch.cuda.is_available() else "cpu"] - else: - self.devices = [torch.device(d) for d in devices] - - if len(self.devices) > 1: - warnings.warn("Multiple devices for single torchscript model not yet implemented") - - self._model = torch.jit.load(weight_path) - self._model.to(self.devices[0]) - self._internal_output_axes = [tuple(out.axes) for out in self.bioimageio_model.outputs] - - def _forward(self, *batch: xr.DataArray) -> List[xr.DataArray]: - with torch.no_grad(): - torch_tensor = [torch.from_numpy(b.data).to(self.devices[0]) for b in batch] - result = self._model.forward(*torch_tensor) - if not isinstance(result, (tuple, list)): - result = [result] - - result = [r.cpu().numpy() if not isinstance(r, np.ndarray) else r for r in result] - - assert len(result) == len(self._internal_output_axes) - return [xr.DataArray(r, dims=axes) for r, axes in zip(result, self._internal_output_axes)] - - def _unload(self) -> None: - self._devices = 
None - del self._model - gc.collect() # deallocate memory - torch.cuda.empty_cache() # release reserved memory diff --git a/bioimageio/core/prediction_pipeline/_prediction_pipeline.py b/bioimageio/core/prediction_pipeline/_prediction_pipeline.py deleted file mode 100644 index dc98a373..00000000 --- a/bioimageio/core/prediction_pipeline/_prediction_pipeline.py +++ /dev/null @@ -1,244 +0,0 @@ -import abc -import warnings -from dataclasses import dataclass -from typing import Iterable, List, Optional, Sequence, Tuple, Union - -import xarray as xr -from marshmallow import missing - -from bioimageio.core.resource_io import nodes -from bioimageio.core.resource_io.utils import resolve_raw_node -from bioimageio.spec.model import raw_nodes -from ._combined_processing import CombinedProcessing -from ._model_adapters import ModelAdapter, create_model_adapter -from ._stat_state import StatsState -from ._utils import ComputedMeasures, Sample, TensorName - - -@dataclass -class NamedImplicitOutputShape: - reference_input: TensorName = missing - scale: List[Tuple[str, float]] = missing - offset: List[Tuple[str, int]] = missing - - def __len__(self): - return len(self.scale) - - -class PredictionPipeline(abc.ABC): - """ - Represents model computation including preprocessing and postprocessing - Note: Ideally use the PredictionPipeline as a context manager - """ - - @abc.abstractmethod - def __enter__(self): - ... - - @abc.abstractmethod - def __exit__(self, exc_type, exc_val, exc_tb): - ... - - @abc.abstractmethod - def forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]: - """ - Compute predictions - """ - ... - - @property - @abc.abstractmethod - def name(self) -> str: - """ - Name of the pipeline - """ - ... - - @property - @abc.abstractmethod - def input_specs(self) -> List[nodes.InputTensor]: - """ - specs of inputs - """ - ... - - @property - @abc.abstractmethod - def output_specs(self) -> List[nodes.OutputTensor]: - """ - specs of outputs - """ - ... 
- - @abc.abstractmethod - def load(self) -> None: - """ - optional step: load model onto devices before calling forward if not using it as context manager - """ - ... - - @abc.abstractmethod - def unload(self) -> None: - """ - free any device memory in use - """ - ... - - -class _PredictionPipelineImpl(PredictionPipeline): - def __init__( - self, - *, - name: str, - bioimageio_model: Union[nodes.Model, raw_nodes.Model], - preprocessing: CombinedProcessing, - postprocessing: CombinedProcessing, - ipt_stats: StatsState, - out_stats: StatsState, - model: ModelAdapter, - ) -> None: - if bioimageio_model.run_mode: - warnings.warn(f"Not yet implemented inference for run mode '{bioimageio_model.run_mode.name}'") - - self._name = name - if isinstance(bioimageio_model, nodes.Model): - self._input_specs = bioimageio_model.inputs - self._output_specs = bioimageio_model.outputs - else: - assert isinstance(bioimageio_model, raw_nodes.Model) - self._input_specs = [resolve_raw_node(s, nodes) for s in bioimageio_model.inputs] - self._output_specs = [resolve_raw_node(s, nodes) for s in bioimageio_model.outputs] - - self._preprocessing = preprocessing - self._postprocessing = postprocessing - self._ipt_stats = ipt_stats - self._out_stats = out_stats - self._model: ModelAdapter = model - - def __call__(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]: - return self.forward(*input_tensors) - - def __enter__(self): - self.load() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.unload() - return False - - @property - def name(self): - return self._name - - @property - def input_specs(self): - return self._input_specs - - @property - def output_specs(self): - return self._output_specs - - def predict(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]: - """Predict input_tensor with the model without applying pre/postprocessing.""" - return self._model.forward(*input_tensors) - - def apply_preprocessing(self, sample: Sample, computed_measures: 
ComputedMeasures) -> None: - """apply preprocessing in-place, also updates given computed_measures""" - self._ipt_stats.update_with_sample(sample) - for mode, stats in self._ipt_stats.compute_measures().items(): - if mode not in computed_measures: - computed_measures[mode] = {} - computed_measures[mode].update(stats) - - self._preprocessing.apply(sample, computed_measures) - - def apply_postprocessing(self, sample: Sample, computed_measures: ComputedMeasures) -> None: - """apply postprocessing in-place, also updates given computed_measures""" - self._out_stats.update_with_sample(sample) - for mode, stats in self._out_stats.compute_measures().items(): - if mode not in computed_measures: - computed_measures[mode] = {} - computed_measures[mode].update(stats) - - self._postprocessing.apply(sample, computed_measures) - - def forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]: - """Apply preprocessing, run prediction and apply postprocessing. - Note: The preprocessing might change input_tensors in-pace. 
- """ - input_sample = dict(zip([ipt.name for ipt in self.input_specs], input_tensors)) - computed_measures = {} - self.apply_preprocessing(input_sample, computed_measures) - - prediction_tensors = self.predict(*list(input_sample.values())) - prediction = dict(zip([out.name for out in self.output_specs], prediction_tensors)) - self.apply_postprocessing(prediction, computed_measures) - - return [prediction[tn] for tn in [out.name for out in self.output_specs]] - - def load(self): - self._model.load() - - def unload(self): - self._model.unload() - - -def create_prediction_pipeline( - bioimageio_model: Union[nodes.Model, raw_nodes.Model], - *, - devices: Optional[Sequence[str]] = None, - weight_format: Optional[str] = None, - dataset_for_initial_statistics: Iterable[Sequence[xr.DataArray]] = tuple(), - update_dataset_stats_after_n_samples: Optional[int] = None, - update_dataset_stats_for_n_samples: int = float("inf"), - model_adapter: Optional[ModelAdapter] = None, -) -> PredictionPipeline: - """ - Creates prediction pipeline which includes: - * computation of input statistics - * preprocessing - * model prediction - * computation of output statistics - * postprocessing - """ - model_adapter: ModelAdapter = model_adapter or create_model_adapter( - bioimageio_model=bioimageio_model, devices=devices, weight_format=weight_format - ) - if isinstance(bioimageio_model, nodes.Model): - ipts = bioimageio_model.inputs - outs = bioimageio_model.outputs - - else: - assert isinstance(bioimageio_model, raw_nodes.Model) - ipts = [resolve_raw_node(s, nodes) for s in bioimageio_model.inputs] - outs = [resolve_raw_node(s, nodes) for s in bioimageio_model.outputs] - - preprocessing = CombinedProcessing.from_tensor_specs(ipts) - - def sample_dataset(): - for tensors in dataset_for_initial_statistics: - yield dict(zip([ipt.name for ipt in bioimageio_model.inputs], tensors)) - - ipt_stats = StatsState( - preprocessing.required_measures, - dataset=sample_dataset(), - 
update_dataset_stats_after_n_samples=update_dataset_stats_after_n_samples, - update_dataset_stats_for_n_samples=update_dataset_stats_for_n_samples, - ) - postprocessing = CombinedProcessing.from_tensor_specs(outs) - out_stats = StatsState( - postprocessing.required_measures, - dataset=tuple(), - update_dataset_stats_after_n_samples=0, - update_dataset_stats_for_n_samples=ipt_stats.sample_count + update_dataset_stats_for_n_samples, - ) - - return _PredictionPipelineImpl( - name=bioimageio_model.name, - bioimageio_model=bioimageio_model, - model=model_adapter, - preprocessing=preprocessing, - postprocessing=postprocessing, - ipt_stats=ipt_stats, - out_stats=out_stats, - ) diff --git a/bioimageio/core/prediction_pipeline/_processing.py b/bioimageio/core/prediction_pipeline/_processing.py deleted file mode 100644 index 6fbea8c6..00000000 --- a/bioimageio/core/prediction_pipeline/_processing.py +++ /dev/null @@ -1,299 +0,0 @@ -"""Here pre- and postprocessing operations are implemented according to their definitions in bioimageio.spec: -see https://github.com/bioimage-io/spec-bioimage-io/blob/gh-pages/preprocessing_spec_latest.md -and https://github.com/bioimage-io/spec-bioimage-io/blob/gh-pages/postprocessing_spec_latest.md -""" -import numbers -from dataclasses import InitVar, dataclass, field, fields -from typing import List, Mapping, Optional, Sequence, Tuple, Type, Union - -import numpy -import numpy as np -import xarray as xr - -from bioimageio.core.statistical_measures import Mean, Measure, Percentile, Std -from bioimageio.spec.model.raw_nodes import PostprocessingName, PreprocessingName -from ._utils import ComputedMeasures, DatasetMode, FIXED, Mode, PER_DATASET, PER_SAMPLE, RequiredMeasures, SampleMode - -try: - from typing import Literal, get_args, TypedDict -except ImportError: - from typing_extensions import Literal, get_args, TypedDict # type: ignore - - -def _get_fixed( - fixed: Union[float, Sequence[float]], tensor: xr.DataArray, axes: 
Optional[Sequence[str]] -) -> Union[float, xr.DataArray]: - if axes is None: - return fixed - - fixed_shape = tuple(s for d, s in tensor.sizes.items() if d not in axes) - fixed_dims = tuple(d for d in tensor.dims if d not in axes) - fixed = np.array(fixed).reshape(fixed_shape) - return xr.DataArray(fixed, dims=fixed_dims) - - -TensorName = str - -MISSING = "MISSING" - - -@dataclass -class Processing: - """base class for all Pre- and Postprocessing transformations.""" - - tensor_name: str - # todo: in python>=3.10 we should use dataclasses.KW_ONLY instead of MISSING (see child classes) to make inheritance work properly - computed_measures: ComputedMeasures = field(default_factory=dict) - mode: Mode = FIXED - - def get_required_measures(self) -> RequiredMeasures: - return {} - - def set_computed_measures(self, computed: ComputedMeasures): - # check if computed contains all required measures - for mode, req_per_mode in self.get_required_measures().items(): - for tn, req_per_tn in req_per_mode.items(): - comp_measures = computed.get(mode, {}).get(tn, {}) - for req_measure in req_per_tn: - if req_measure not in comp_measures: - raise ValueError(f"Missing required {req_measure} for {tn} {mode}.") - - self.computed_measures = computed - - def get_computed_measure(self, tensor_name: TensorName, measure: Measure, *, mode: Optional[Mode] = None): - """helper to unpack self.computed_measures""" - ret = self.computed_measures.get(mode or self.mode, {}).get(tensor_name, {}).get(measure) - if ret is None: - raise RuntimeError(f"Missing computed {measure} for {tensor_name} {mode}.") - - return ret - - def __call__(self, tensor: xr.DataArray) -> xr.DataArray: - return self.apply(tensor) - - def apply(self, tensor: xr.DataArray) -> xr.DataArray: - """apply processing""" - raise NotImplementedError - - def __post_init__(self): - # validate common kwargs by their annotations - for f in fields(self): - # check MISSING - if getattr(self, f.name) is MISSING: - raise TypeError(f"missing 
required argument {f.name}") - - if f.name == "mode": - # mode is always annotated as literals (or literals of literals) - valid_modes = get_args(f.type) - for inner in get_args(f.type): - valid_modes += get_args(inner) - - if self.mode not in valid_modes: - raise NotImplementedError(f"Unsupported mode {self.mode} for {self.__class__.__name__}") - - -# -# Pre- and Postprocessing implementations -# - - -@dataclass -class AssertDtype(Processing): - """Helper Processing to assert dtype.""" - - dtype: Union[str, Sequence[str]] = MISSING - assert_with: Tuple[Type[numpy.dtype], ...] = field(init=False) - - def __post_init__(self): - if isinstance(self.dtype, str): - dtype = [self.dtype] - else: - dtype = self.dtype - - self.assert_with = tuple(type(numpy.dtype(dt)) for dt in dtype) - - def apply(self, tensor: xr.DataArray) -> xr.DataArray: - assert isinstance(tensor.dtype, self.assert_with) - return tensor - - -@dataclass -class Binarize(Processing): - """'output = tensor > threshold'.""" - - threshold: float = MISSING # make dataclass inheritance work for py<3.10 by using an explicit MISSING value. 
- - def apply(self, tensor: xr.DataArray) -> xr.DataArray: - return tensor > self.threshold - - -@dataclass -class Clip(Processing): - """Limit tensor values to [min, max].""" - - min: float = MISSING - max: float = MISSING - - def apply(self, tensor: xr.DataArray) -> xr.DataArray: - return tensor.clip(min=self.min, max=self.max) - - -@dataclass -class EnsureDtype(Processing): - """Helper Processing to cast dtype if needed.""" - - dtype: str = MISSING - - def apply(self, tensor: xr.DataArray) -> xr.DataArray: - return tensor.astype(self.dtype) - - -@dataclass -class ScaleLinear(Processing): - """Scale the tensor with a fixed multiplicative and additive factor.""" - - gain: Union[float, Sequence[float]] = MISSING - offset: Union[float, Sequence[float]] = MISSING - axes: Optional[Sequence[str]] = None - - def apply(self, tensor: xr.DataArray) -> xr.DataArray: - scale_axes = tuple(ax for ax in tensor.dims if (ax not in self.axes and ax != "b")) - if scale_axes: - gain = xr.DataArray(np.atleast_1d(self.gain), dims=scale_axes) - offset = xr.DataArray(np.atleast_1d(self.offset), dims=scale_axes) - else: - gain = self.gain - offset = self.offset - - return tensor * gain + offset - - def __post_init__(self): - super().__post_init__() - if self.axes is None: - assert isinstance(self.gain, (int, float)) - assert isinstance(self.offset, (int, float)) - - -@dataclass -class ScaleMeanVariance(Processing): - """Scale the tensor s.t. 
its mean and variance match a reference tensor.""" - - mode: Literal[SampleMode, DatasetMode] = PER_SAMPLE - reference_tensor: TensorName = MISSING - axes: Optional[Sequence[str]] = None - eps: float = 1e-6 - - def get_required_measures(self) -> RequiredMeasures: - axes = None if self.axes is None else tuple(self.axes) - return { - self.mode: { - self.tensor_name: {Mean(axes=axes), Std(axes=axes)}, - self.reference_tensor: {Mean(axes=axes), Std(axes=axes)}, - } - } - - def apply(self, tensor: xr.DataArray) -> xr.DataArray: - axes = None if self.axes is None else tuple(self.axes) - assert self.mode in (PER_SAMPLE, PER_DATASET) - mean = self.get_computed_measure(self.tensor_name, Mean(axes), mode=self.mode) - std = self.get_computed_measure(self.tensor_name, Std(axes), mode=self.mode) - ref_mean = self.get_computed_measure(self.reference_tensor, Mean(axes), mode=self.mode) - ref_std = self.get_computed_measure(self.reference_tensor, Std(axes), mode=self.mode) - - return (tensor - mean) / (std + self.eps) * (ref_std + self.eps) + ref_mean - - -@dataclass -class ScaleRange(Processing): - """Scale with percentiles.""" - - mode: Literal[SampleMode, DatasetMode] = PER_SAMPLE - axes: Optional[Sequence[str]] = None - min_percentile: float = 0.0 - max_percentile: float = 100.0 - eps: float = 1e-6 - reference_tensor: Optional[TensorName] = None - - def get_required_measures(self) -> RequiredMeasures: - axes = None if self.axes is None else tuple(self.axes) - measures = {Percentile(self.min_percentile, axes=axes), Percentile(self.max_percentile, axes=axes)} - return {self.mode: {self.reference_tensor or self.tensor_name: measures}} - - def apply(self, tensor: xr.DataArray) -> xr.DataArray: - ref_name = self.reference_tensor or self.tensor_name - axes = None if self.axes is None else tuple(self.axes) - v_lower = self.get_computed_measure(ref_name, Percentile(self.min_percentile, axes=axes)) - v_upper = self.get_computed_measure(ref_name, Percentile(self.max_percentile, 
axes=axes)) - - return (tensor - v_lower) / (v_upper - v_lower + self.eps) - - def __post_init__(self): - super().__post_init__() - self.axes = None if self.axes is None else tuple(self.axes) # make sure axes is Tuple[str] or None - - -@dataclass -class Sigmoid(Processing): - """1 / (1 + e^(-tensor)).""" - - def apply(self, tensor: xr.DataArray) -> xr.DataArray: - return 1.0 / (1.0 + np.exp(-tensor)) - - -@dataclass -class ZeroMeanUnitVariance(Processing): - """normalize to zero mean, unit variance.""" - - mode: Mode = PER_SAMPLE - mean: Optional[Union[float, Sequence[float]]] = None - std: Optional[Union[float, Sequence[float]]] = None - axes: Optional[Sequence[str]] = None - eps: float = 1.0e-6 - - def get_required_measures(self) -> RequiredMeasures: - if self.mode == FIXED: - return {} - else: - axes = None if self.axes is None else tuple(self.axes) - return {self.mode: {self.tensor_name: {Mean(axes=axes), Std(axes=axes)}}} - - def apply(self, tensor: xr.DataArray) -> xr.DataArray: - axes = None if self.axes is None else tuple(self.axes) - if self.mode == FIXED: - assert self.mean is not None and self.std is not None - mean = _get_fixed(self.mean, tensor, axes) - std = _get_fixed(self.std, tensor, axes) - elif self.mode in (PER_SAMPLE, PER_DATASET): - assert self.mean is None and self.std is None - mean = self.get_computed_measure(self.tensor_name, Mean(axes), mode=self.mode) - std = self.get_computed_measure(self.tensor_name, Std(axes), mode=self.mode) - else: - raise ValueError(self.mode) - - return (tensor - mean) / (std + self.eps) - - -_KnownProcessing = TypedDict( - "_KnownProcessing", - dict(pre=Mapping[PreprocessingName, Type[Processing]], post=Mapping[PostprocessingName, Type[Processing]]), -) - -KNOWN_PROCESSING: _KnownProcessing = dict( - pre={ - "binarize": Binarize, - "clip": Clip, - "scale_linear": ScaleLinear, - "scale_range": ScaleRange, - "sigmoid": Sigmoid, - "zero_mean_unit_variance": ZeroMeanUnitVariance, - }, - post={ - "binarize": Binarize, 
- "clip": Clip, - "scale_linear": ScaleLinear, - "scale_mean_variance": ScaleMeanVariance, - "scale_range": ScaleRange, - "sigmoid": Sigmoid, - "zero_mean_unit_variance": ZeroMeanUnitVariance, - }, -) diff --git a/bioimageio/core/prediction_pipeline/_stat_state.py b/bioimageio/core/prediction_pipeline/_stat_state.py deleted file mode 100644 index 6de4d68d..00000000 --- a/bioimageio/core/prediction_pipeline/_stat_state.py +++ /dev/null @@ -1,96 +0,0 @@ -from typing import Dict, Iterable, Optional - -from bioimageio.core.statistical_measures import Measure -from bioimageio.spec.shared.common import tqdm -from ._measure_groups import MeasureGroups, MeasureValue, get_measure_groups -from ._utils import ComputedMeasures, PER_DATASET, PER_SAMPLE, RequiredMeasures, Sample, TensorName - -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal # type: ignore - - -class StatsState: - """class to compute, hold and update dataset and sample statistics""" - - sample_count: int - last_sample: Optional[Sample] - measure_groups: MeasureGroups - _n_start: int - _n_stop: int - _final_dataset_stats: Optional[Dict[TensorName, Dict[Measure, MeasureValue]]] - - def __init__( - self, - required_measures: RequiredMeasures, - *, - dataset: Iterable[Sample] = tuple(), - update_dataset_stats_after_n_samples: Optional[int] = None, - update_dataset_stats_for_n_samples: int = float("inf"), - ): - """iterates over dataset to compute dataset statistics (if required). The resulting dataset statistics are further updated with each new sample. A sample in this context may be a mini-batch. - - Args: - required_measures: measures to be computed - dataset: (partial) dataset to initialize dataset statistics with - update_dataset_stats_after_n_samples: Update dataset statistics for new samples S_i if i > n. - (default: len(dataset)) - This parameter allows to avoid weighting the first n processed - samples to count twice if they make up the given 'dataset'. 
- update_dataset_stats_for_n_samples: stop updating dataset statistics with new samples S_i if - i > for_n_samples (+ update_dataset_stats_after_n_samples) - """ - self.required_measures = required_measures - self.update_dataset_stats_after_n_samples = update_dataset_stats_after_n_samples - self.update_dataset_stats_for_n_samples = update_dataset_stats_for_n_samples - self.reset(dataset) - - def reset(self, dataset: Iterable[Sample]): - self.sample_count = 0 - self.last_sample = None - self._final_dataset_stats = None - self.measure_groups = get_measure_groups(self.required_measures) - - len_dataset = 0 - if self.measure_groups[PER_DATASET]: - for sample in tqdm(dataset, "computing dataset statistics"): - len_dataset += 1 - self._update_dataset_measure_groups(sample) - - if self.update_dataset_stats_after_n_samples is None: - self._n_start = len_dataset - else: - self._n_start = self.update_dataset_stats_after_n_samples - - self._n_stop = self._n_start + self.update_dataset_stats_for_n_samples - - def update_with_sample(self, sample: Sample): - self.last_sample = sample - self.sample_count += 1 - if self._n_start < self.sample_count <= self._n_stop: - self._update_dataset_measure_groups(sample) - - def _update_dataset_measure_groups(self, sample: Sample): - for mg in self.measure_groups[PER_DATASET]: - mg.update_with_sample(sample) - - def compute_measures(self) -> ComputedMeasures: - ret = {PER_SAMPLE: {}, PER_DATASET: {}} - if self.last_sample is not None: - for mg in self.measure_groups[PER_SAMPLE]: - ret[PER_SAMPLE].update(mg.compute(self.last_sample)) - - if self._final_dataset_stats is None: - dataset_stats = {} - for mg in self.measure_groups[PER_DATASET]: - dataset_stats.update(mg.finalize()) - - if self.sample_count > self._n_stop: - # stop recomputing final dataset statistics - self._final_dataset_stats = dataset_stats - else: - dataset_stats = self._final_dataset_stats - - ret[PER_DATASET] = dataset_stats - return ret diff --git 
a/bioimageio/core/prediction_pipeline/_utils.py b/bioimageio/core/prediction_pipeline/_utils.py deleted file mode 100644 index 8b39753d..00000000 --- a/bioimageio/core/prediction_pipeline/_utils.py +++ /dev/null @@ -1,25 +0,0 @@ -from typing import Dict, Set - -import xarray as xr - -from bioimageio.core.statistical_measures import Measure, MeasureValue - -try: - from typing import Literal -except ImportError: - from typing_extensions import Literal # type: ignore - -TensorName = str -FixedMode = Literal["fixed"] -SampleMode = Literal["per_sample"] -DatasetMode = Literal["per_dataset"] -Mode = Literal[FixedMode, SampleMode, DatasetMode] - -FIXED: FixedMode = "fixed" -PER_SAMPLE: SampleMode = "per_sample" -PER_DATASET: DatasetMode = "per_dataset" -MODES: Set[Mode] = {FIXED, PER_SAMPLE, PER_DATASET} - -Sample = Dict[TensorName, xr.DataArray] -RequiredMeasures = Dict[Literal[SampleMode, DatasetMode], Dict[TensorName, Set[Measure]]] -ComputedMeasures = Dict[Literal[SampleMode, DatasetMode], Dict[TensorName, Dict[Measure, MeasureValue]]] diff --git a/bioimageio/core/proc_ops.py b/bioimageio/core/proc_ops.py new file mode 100644 index 00000000..b05d7c8b --- /dev/null +++ b/bioimageio/core/proc_ops.py @@ -0,0 +1,688 @@ +import collections.abc +from abc import ABC, abstractmethod +from dataclasses import InitVar, dataclass, field +from typing import ( + Collection, + Literal, + Mapping, + Optional, + Sequence, + Set, + Tuple, + Union, +) + +import numpy as np +import xarray as xr +from typing_extensions import Self, assert_never + +from bioimageio.core.block import Block +from bioimageio.core.sample import Sample, SampleBlock, SampleBlockWithOrigin +from bioimageio.spec.model import v0_4, v0_5 + +from ._op_base import BlockedOperator, Operator +from .axis import AxisId, PerAxis +from .common import DTypeStr, MemberId +from .stat_calculators import StatsCalculator +from .stat_measures import ( + DatasetMean, + DatasetMeasure, + DatasetPercentile, + DatasetStd, + 
MeanMeasure, + Measure, + MeasureValue, + SampleMean, + SampleQuantile, + SampleStd, + Stat, + StdMeasure, +) +from .tensor import Tensor + + +def convert_axis_ids( + axes: Union[Sequence[AxisId], v0_4.AxesInCZYX], + mode: Literal["per_sample", "per_dataset"], +) -> Tuple[AxisId, ...]: + if not isinstance(axes, str): + return tuple(axes) + + axis_map = dict(b=AxisId("batch"), c=AxisId("channel"), i=AxisId("index")) + if mode == "per_sample": + ret = [] + elif mode == "per_dataset": + ret = [AxisId("batch")] + else: + assert_never(mode) + + ret.extend([axis_map.get(a, AxisId(a)) for a in axes]) + return tuple(ret) + + +@dataclass +class _SimpleOperator(BlockedOperator, ABC): + input: MemberId + output: MemberId + + @property + def required_measures(self) -> Collection[Measure]: + return set() + + @abstractmethod + def get_output_shape(self, input_shape: PerAxis[int]) -> PerAxis[int]: ... + + def __call__(self, sample: Union[Sample, SampleBlock]) -> None: + input_tensor = sample.members[self.input] + output_tensor = self._apply(input_tensor, sample.stat) + + if self.output in sample.members: + assert ( + sample.members[self.output].tagged_shape == output_tensor.tagged_shape + ) + + if isinstance(sample, Sample): + sample.members[self.output] = output_tensor + elif isinstance(sample, SampleBlock): + b = sample.blocks[self.input] + sample.blocks[self.output] = Block( + sample_shape=self.get_output_shape(sample.shape[self.input]), + data=output_tensor, + inner_slice=b.inner_slice, + halo=b.halo, + block_index=b.block_index, + blocks_in_sample=b.blocks_in_sample, + ) + else: + assert_never(sample) + + @abstractmethod + def _apply(self, input: Tensor, stat: Stat) -> Tensor: ... 
+ + +@dataclass +class AddKnownDatasetStats(BlockedOperator): + dataset_stats: Mapping[DatasetMeasure, MeasureValue] + + @property + def required_measures(self) -> Set[Measure]: + return set() + + def __call__(self, sample: Union[Sample, SampleBlock]) -> None: + sample.stat.update(self.dataset_stats.items()) + + +# @dataclass +# class UpdateStats(Operator): +# """Calculates sample and/or dataset measures""" + +# measures: Union[Sequence[Measure], Set[Measure], Mapping[Measure, MeasureValue]] +# """sample and dataset `measuers` to be calculated by this operator. Initial/fixed +# dataset measure values may be given, see `keep_updating_dataset_stats` for details. +# """ +# keep_updating_dataset_stats: Optional[bool] = None +# """indicates if operator calls should keep updating dataset statistics or not + +# default (None): if `measures` is a `Mapping` (i.e. initial measure values are +# given) no further updates to dataset statistics is conducted, otherwise (w.o. +# initial measure values) dataset statistics are updated by each processed sample. 
+# """ +# _keep_updating_dataset_stats: bool = field(init=False) +# _stats_calculator: StatsCalculator = field(init=False) + +# @property +# def required_measures(self) -> Set[Measure]: +# return set() + +# def __post_init__(self): +# self._stats_calculator = StatsCalculator(self.measures) +# if self.keep_updating_dataset_stats is None: +# self._keep_updating_dataset_stats = not isinstance(self.measures, collections.abc.Mapping) +# else: +# self._keep_updating_dataset_stats = self.keep_updating_dataset_stats + +# def __call__(self, sample_block: SampleBlockWithOrigin> None: +# if self._keep_updating_dataset_stats: +# sample.stat.update(self._stats_calculator.update_and_get_all(sample)) +# else: +# sample.stat.update(self._stats_calculator.skip_update_and_get_all(sample)) + + +@dataclass +class UpdateStats(Operator): + """Calculates sample and/or dataset measures""" + + stats_calculator: StatsCalculator + """`StatsCalculator` to be used by this operator.""" + keep_updating_initial_dataset_stats: bool = False + """indicates if operator calls should keep updating initial dataset statistics or not; + if the `stats_calculator` was not provided with any initial dataset statistics, + these are always updated with every new sample. 
+ """ + _keep_updating_dataset_stats: bool = field(init=False) + + @property + def required_measures(self) -> Set[Measure]: + return set() + + def __post_init__(self): + self._keep_updating_dataset_stats = ( + self.keep_updating_initial_dataset_stats + or not self.stats_calculator.has_dataset_measures + ) + + def __call__(self, sample: Union[Sample, SampleBlockWithOrigin]) -> None: + if isinstance(sample, SampleBlockWithOrigin): + # update stats with whole sample on first block + if sample.block_index != 0: + return + + origin = sample.origin + else: + origin = sample + + if self._keep_updating_dataset_stats: + sample.stat.update(self.stats_calculator.update_and_get_all(origin)) + else: + sample.stat.update(self.stats_calculator.skip_update_and_get_all(origin)) + + +@dataclass +class Binarize(_SimpleOperator): + """'output = tensor > threshold'.""" + + threshold: Union[float, Sequence[float]] + axis: Optional[AxisId] = None + + def _apply(self, input: Tensor, stat: Stat) -> Tensor: + return input > self.threshold + + def get_output_shape( + self, input_shape: Mapping[AxisId, int] + ) -> Mapping[AxisId, int]: + return input_shape + + @classmethod + def from_proc_descr( + cls, descr: Union[v0_4.BinarizeDescr, v0_5.BinarizeDescr], member_id: MemberId + ) -> Self: + if isinstance(descr.kwargs, (v0_4.BinarizeKwargs, v0_5.BinarizeKwargs)): + return cls( + input=member_id, output=member_id, threshold=descr.kwargs.threshold + ) + elif isinstance(descr.kwargs, v0_5.BinarizeAlongAxisKwargs): + return cls( + input=member_id, + output=member_id, + threshold=descr.kwargs.threshold, + axis=descr.kwargs.axis, + ) + else: + assert_never(descr.kwargs) + + +@dataclass +class Clip(_SimpleOperator): + min: Optional[float] = None + """minimum value for clipping""" + max: Optional[float] = None + """maximum value for clipping""" + + def __post_init__(self): + assert self.min is not None or self.max is not None, "missing min or max value" + assert ( + self.min is None or self.max is None 
or self.min < self.max + ), f"expected min < max, but {self.min} !< {self.max}" + + def _apply(self, input: Tensor, stat: Stat) -> Tensor: + return input.clip(self.min, self.max) + + def get_output_shape( + self, input_shape: Mapping[AxisId, int] + ) -> Mapping[AxisId, int]: + return input_shape + + @classmethod + def from_proc_descr( + cls, descr: Union[v0_4.ClipDescr, v0_5.ClipDescr], member_id: MemberId + ) -> Self: + return cls( + input=member_id, + output=member_id, + min=descr.kwargs.min, + max=descr.kwargs.max, + ) + + +@dataclass +class EnsureDtype(_SimpleOperator): + dtype: DTypeStr + + @classmethod + def from_proc_descr(cls, descr: v0_5.EnsureDtypeDescr, member_id: MemberId): + return cls(input=member_id, output=member_id, dtype=descr.kwargs.dtype) + + def get_descr(self): + return v0_5.EnsureDtypeDescr(kwargs=v0_5.EnsureDtypeKwargs(dtype=self.dtype)) + + def get_output_shape( + self, input_shape: Mapping[AxisId, int] + ) -> Mapping[AxisId, int]: + return input_shape + + def _apply(self, input: Tensor, stat: Stat) -> Tensor: + return input.astype(self.dtype) + + +@dataclass +class ScaleLinear(_SimpleOperator): + gain: Union[float, xr.DataArray] = 1.0 + """multiplicative factor""" + + offset: Union[float, xr.DataArray] = 0.0 + """additive term""" + + def _apply(self, input: Tensor, stat: Stat) -> Tensor: + return input * self.gain + self.offset + + def get_output_shape( + self, input_shape: Mapping[AxisId, int] + ) -> Mapping[AxisId, int]: + return input_shape + + @classmethod + def from_proc_descr( + cls, + descr: Union[v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr], + member_id: MemberId, + ) -> Self: + kwargs = descr.kwargs + if isinstance(kwargs, v0_5.ScaleLinearAlongAxisKwargs): + axis = kwargs.axis + elif isinstance(kwargs, (v0_4.ScaleLinearKwargs, v0_5.ScaleLinearKwargs)): + axis = None + else: + assert_never(kwargs) + + if axis: + gain = xr.DataArray(np.atleast_1d(kwargs.gain), dims=axis) + offset = xr.DataArray(np.atleast_1d(kwargs.offset), 
dims=axis) + else: + assert ( + isinstance(kwargs.gain, (float, int)) or len(kwargs.gain) == 1 + ), kwargs.gain + gain = ( + kwargs.gain if isinstance(kwargs.gain, (float, int)) else kwargs.gain[0] + ) + assert isinstance(kwargs.offset, (float, int)) or len(kwargs.offset) == 1 + offset = ( + kwargs.offset + if isinstance(kwargs.offset, (float, int)) + else kwargs.offset[0] + ) + + return cls(input=member_id, output=member_id, gain=gain, offset=offset) + + +@dataclass +class ScaleMeanVariance(_SimpleOperator): + axes: Optional[Sequence[AxisId]] = None + reference_tensor: Optional[MemberId] = None + eps: float = 1e-6 + mean: Union[SampleMean, DatasetMean] = field(init=False) + std: Union[SampleStd, DatasetStd] = field(init=False) + ref_mean: Union[SampleMean, DatasetMean] = field(init=False) + ref_std: Union[SampleStd, DatasetStd] = field(init=False) + + @property + def required_measures(self): + return {self.mean, self.std, self.ref_mean, self.ref_std} + + def __post_init__(self): + axes = None if self.axes is None else tuple(self.axes) + ref_tensor = self.reference_tensor or self.input + if axes is None or AxisId("batch") not in axes: + Mean = SampleMean + Std = SampleStd + else: + Mean = DatasetMean + Std = DatasetStd + + self.mean = Mean(member_id=self.input, axes=axes) + self.std = Std(member_id=self.input, axes=axes) + self.ref_mean = Mean(member_id=ref_tensor, axes=axes) + self.ref_std = Std(member_id=ref_tensor, axes=axes) + + def _apply(self, input: Tensor, stat: Stat) -> Tensor: + mean = stat[self.mean] + std = stat[self.std] + self.eps + ref_mean = stat[self.ref_mean] + ref_std = stat[self.ref_std] + self.eps + return (input - mean) / std * ref_std + ref_mean + + def get_output_shape( + self, input_shape: Mapping[AxisId, int] + ) -> Mapping[AxisId, int]: + return input_shape + + @classmethod + def from_proc_descr( + cls, + descr: Union[v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr], + member_id: MemberId, + ) -> Self: + kwargs = descr.kwargs + 
axes = _get_axes(descr.kwargs) + + return cls( + input=member_id, + output=member_id, + reference_tensor=MemberId(str(kwargs.reference_tensor)), + axes=axes, + eps=kwargs.eps, + ) + + +def _get_axes( + kwargs: Union[ + v0_4.ZeroMeanUnitVarianceKwargs, + v0_5.ZeroMeanUnitVarianceKwargs, + v0_4.ScaleRangeKwargs, + v0_5.ScaleRangeKwargs, + v0_4.ScaleMeanVarianceKwargs, + v0_5.ScaleMeanVarianceKwargs, + ] +) -> Union[Tuple[AxisId, ...], None]: + if kwargs.axes is None: + axes = None + elif isinstance(kwargs.axes, str): + axes = convert_axis_ids(kwargs.axes, kwargs["mode"]) + elif isinstance(kwargs.axes, collections.abc.Sequence): + axes = tuple(kwargs.axes) + else: + assert_never(kwargs.axes) + + return axes + + +@dataclass +class ScaleRange(_SimpleOperator): + lower_percentile: InitVar[Optional[Union[SampleQuantile, DatasetPercentile]]] = None + upper_percentile: InitVar[Optional[Union[SampleQuantile, DatasetPercentile]]] = None + lower: Union[SampleQuantile, DatasetPercentile] = field(init=False) + upper: Union[SampleQuantile, DatasetPercentile] = field(init=False) + + eps: float = 1e-6 + + def __post_init__( + self, + lower_percentile: Optional[Union[SampleQuantile, DatasetPercentile]], + upper_percentile: Optional[Union[SampleQuantile, DatasetPercentile]], + ): + if lower_percentile is None: + tid = self.input if upper_percentile is None else upper_percentile.member_id + self.lower = DatasetPercentile(q=0.0, member_id=tid) + else: + self.lower = lower_percentile + + if upper_percentile is None: + self.upper = DatasetPercentile(q=1.0, member_id=self.lower.member_id) + else: + self.upper = upper_percentile + + assert self.lower.member_id == self.upper.member_id + assert self.lower.q < self.upper.q + assert self.lower.axes == self.upper.axes + + @property + def required_measures(self): + return {self.lower, self.upper} + + def get_output_shape( + self, input_shape: Mapping[AxisId, int] + ) -> Mapping[AxisId, int]: + return input_shape + + @classmethod + def 
from_proc_descr( + cls, + descr: Union[v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr], + member_id: MemberId, + ): + kwargs = descr.kwargs + ref_tensor = ( + member_id + if kwargs.reference_tensor is None + else MemberId(str(kwargs.reference_tensor)) + ) + axes = _get_axes(descr.kwargs) + if axes is None or AxisId("batch") in axes: + Percentile = DatasetPercentile + else: + Percentile = SampleQuantile + + return cls( + input=member_id, + output=member_id, + lower_percentile=Percentile( + q=kwargs.min_percentile / 100, axes=axes, member_id=ref_tensor + ), + upper_percentile=Percentile( + q=kwargs.max_percentile / 100, axes=axes, member_id=ref_tensor + ), + ) + + def _apply(self, input: Tensor, stat: Stat) -> Tensor: + lower = stat[self.lower] + upper = stat[self.upper] + return (input - lower) / (upper - lower + self.eps) + + def get_descr(self): + assert self.lower.axes == self.upper.axes + assert self.lower.member_id == self.upper.member_id + + return v0_5.ScaleRangeDescr( + kwargs=v0_5.ScaleRangeKwargs( + axes=self.lower.axes, + min_percentile=self.lower.q * 100, + max_percentile=self.upper.q * 100, + eps=self.eps, + reference_tensor=self.lower.member_id, + ) + ) + + +@dataclass +class Sigmoid(_SimpleOperator): + """1 / (1 + e^(-input)).""" + + def _apply(self, input: Tensor, stat: Stat) -> Tensor: + return Tensor(1.0 / (1.0 + np.exp(-input)), dims=input.dims) + + @property + def required_measures(self) -> Collection[Measure]: + return {} + + def get_output_shape( + self, input_shape: Mapping[AxisId, int] + ) -> Mapping[AxisId, int]: + return input_shape + + @classmethod + def from_proc_descr( + cls, descr: Union[v0_4.SigmoidDescr, v0_5.SigmoidDescr], member_id: MemberId + ) -> Self: + assert isinstance(descr, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)) + return cls(input=member_id, output=member_id) + + def get_descr(self): + return v0_5.SigmoidDescr() + + +@dataclass +class ZeroMeanUnitVariance(_SimpleOperator): + """normalize to zero mean, unit variance.""" + + mean: 
MeanMeasure + std: StdMeasure + + eps: float = 1e-6 + + def __post_init__(self): + assert self.mean.axes == self.std.axes + + @property + def required_measures(self) -> Set[Union[MeanMeasure, StdMeasure]]: + return {self.mean, self.std} + + def get_output_shape( + self, input_shape: Mapping[AxisId, int] + ) -> Mapping[AxisId, int]: + return input_shape + + @classmethod + def from_proc_descr( + cls, + descr: Union[v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr], + member_id: MemberId, + ): + axes = _get_axes(descr.kwargs) + + if axes is None or AxisId("batch") in axes: + Mean = DatasetMean + Std = DatasetStd + else: + Mean = SampleMean + Std = SampleStd + + return cls( + input=member_id, + output=member_id, + mean=Mean(axes=axes, member_id=member_id), + std=Std(axes=axes, member_id=member_id), + ) + + def _apply(self, input: Tensor, stat: Stat) -> Tensor: + mean = stat[self.mean] + std = stat[self.std] + return (input - mean) / (std + self.eps) + + def get_descr(self): + return v0_5.ZeroMeanUnitVarianceDescr( + kwargs=v0_5.ZeroMeanUnitVarianceKwargs(axes=self.mean.axes, eps=self.eps) + ) + + +@dataclass +class FixedZeroMeanUnitVariance(_SimpleOperator): + """normalize to zero mean, unit variance with precomputed values.""" + + mean: Union[float, xr.DataArray] + std: Union[float, xr.DataArray] + + eps: float = 1e-6 + + def __post_init__(self): + assert ( + isinstance(self.mean, (int, float)) + or isinstance(self.std, (int, float)) + or self.mean.dims == self.std.dims + ) + + def get_output_shape( + self, input_shape: Mapping[AxisId, int] + ) -> Mapping[AxisId, int]: + return input_shape + + @classmethod + def from_proc_descr( + cls, + descr: v0_5.FixedZeroMeanUnitVarianceDescr, + member_id: MemberId, + ) -> Self: + if isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceKwargs): + dims = None + elif isinstance(descr.kwargs, v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs): + dims = (descr.kwargs.axis,) + else: + assert_never(descr.kwargs) + + return 
cls( + input=member_id, + output=member_id, + mean=xr.DataArray(descr.kwargs.mean, dims=dims), + std=xr.DataArray(descr.kwargs.std, dims=dims), + ) + + def get_descr(self): + if isinstance(self.mean, (int, float)): + assert isinstance(self.std, (int, float)) + kwargs = v0_5.FixedZeroMeanUnitVarianceKwargs(mean=self.mean, std=self.std) + else: + assert isinstance(self.std, xr.DataArray) + assert len(self.mean.dims) == 1 + kwargs = v0_5.FixedZeroMeanUnitVarianceAlongAxisKwargs( + axis=AxisId(str(self.mean.dims[0])), + mean=list(self.mean), + std=list(self.std), + ) + + return v0_5.FixedZeroMeanUnitVarianceDescr(kwargs=kwargs) + + def _apply(self, input: Tensor, stat: Stat) -> Tensor: + return (input - self.mean) / (self.std + self.eps) + + +ProcDescr = Union[ + v0_4.PreprocessingDescr, + v0_4.PostprocessingDescr, + v0_5.PreprocessingDescr, + v0_5.PostprocessingDescr, +] + +Processing = Union[ + AddKnownDatasetStats, + Binarize, + Clip, + EnsureDtype, + FixedZeroMeanUnitVariance, + ScaleLinear, + ScaleMeanVariance, + ScaleRange, + Sigmoid, + UpdateStats, + ZeroMeanUnitVariance, +] + + +def get_proc_class(proc_spec: ProcDescr): + if isinstance(proc_spec, (v0_4.BinarizeDescr, v0_5.BinarizeDescr)): + return Binarize + elif isinstance(proc_spec, (v0_4.ClipDescr, v0_5.ClipDescr)): + return Clip + elif isinstance(proc_spec, v0_5.EnsureDtypeDescr): + return EnsureDtype + elif isinstance(proc_spec, v0_5.FixedZeroMeanUnitVarianceDescr): + return FixedZeroMeanUnitVariance + elif isinstance(proc_spec, (v0_4.ScaleLinearDescr, v0_5.ScaleLinearDescr)): + return ScaleLinear + elif isinstance( + proc_spec, (v0_4.ScaleMeanVarianceDescr, v0_5.ScaleMeanVarianceDescr) + ): + return ScaleMeanVariance + elif isinstance(proc_spec, (v0_4.ScaleRangeDescr, v0_5.ScaleRangeDescr)): + return ScaleRange + elif isinstance(proc_spec, (v0_4.SigmoidDescr, v0_5.SigmoidDescr)): + return Sigmoid + elif ( + isinstance(proc_spec, v0_4.ZeroMeanUnitVarianceDescr) + and proc_spec.kwargs.mode == "fixed" + ): + 
return FixedZeroMeanUnitVariance + elif isinstance( + proc_spec, + (v0_4.ZeroMeanUnitVarianceDescr, v0_5.ZeroMeanUnitVarianceDescr), + ): + return ZeroMeanUnitVariance + else: + assert_never(proc_spec) diff --git a/bioimageio/core/proc_setup.py b/bioimageio/core/proc_setup.py new file mode 100644 index 00000000..947ea0c2 --- /dev/null +++ b/bioimageio/core/proc_setup.py @@ -0,0 +1,151 @@ +from typing import ( + Iterable, + List, + Mapping, + NamedTuple, + Optional, + Sequence, + Set, + Union, +) + +from typing_extensions import assert_never + +from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 +from bioimageio.spec.model.v0_5 import TensorId + +from .proc_ops import AddKnownDatasetStats, Processing, UpdateStats, get_proc_class +from .sample import Sample +from .stat_calculators import StatsCalculator +from .stat_measures import DatasetMeasure, Measure, MeasureValue + +TensorDescr = Union[ + v0_4.InputTensorDescr, + v0_4.OutputTensorDescr, + v0_5.InputTensorDescr, + v0_5.OutputTensorDescr, +] + + +class PreAndPostprocessing(NamedTuple): + pre: List[Processing] + post: List[Processing] + + +class _SetupProcessing(NamedTuple): + pre: List[Processing] + post: List[Processing] + pre_measures: Set[Measure] + post_measures: Set[Measure] + + +def setup_pre_and_postprocessing( + model: AnyModelDescr, + dataset_for_initial_statistics: Iterable[Sample], + keep_updating_initial_dataset_stats: bool = False, + fixed_dataset_stats: Optional[Mapping[DatasetMeasure, MeasureValue]] = None, +) -> PreAndPostprocessing: + """ + Get pre- and postprocessing operators for a `model` description. 
+ userd in `bioimageio.core.create_prediction_pipeline""" + prep, post, prep_meas, post_meas = _prepare_setup_pre_and_postprocessing(model) + + missing_dataset_stats = { + m + for m in prep_meas | post_meas + if fixed_dataset_stats is None or m not in fixed_dataset_stats + } + initial_stats_calc = StatsCalculator(missing_dataset_stats) + for sample in dataset_for_initial_statistics: + initial_stats_calc.update(sample) + + initial_stats = initial_stats_calc.finalize() + prep.insert( + 0, + UpdateStats( + StatsCalculator(prep_meas, initial_stats), + keep_updating_initial_dataset_stats=keep_updating_initial_dataset_stats, + ), + ) + if post_meas: + post.insert( + 0, + UpdateStats( + StatsCalculator(post_meas, initial_stats), + keep_updating_initial_dataset_stats=keep_updating_initial_dataset_stats, + ), + ) + + if fixed_dataset_stats: + prep.insert(0, AddKnownDatasetStats(fixed_dataset_stats)) + post.insert(0, AddKnownDatasetStats(fixed_dataset_stats)) + + return PreAndPostprocessing(prep, post) + + +def _prepare_setup_pre_and_postprocessing(model: AnyModelDescr) -> _SetupProcessing: + pre_measures: Set[Measure] = set() + post_measures: Set[Measure] = set() + + if isinstance(model, v0_4.ModelDescr): + input_ids = {TensorId(str(d.name)) for d in model.inputs} + output_ids = {TensorId(str(d.name)) for d in model.outputs} + else: + input_ids = {d.id for d in model.inputs} + output_ids = {d.id for d in model.outputs} + + def prepare_procs(tensor_descrs: Sequence[TensorDescr]): + procs: List[Processing] = [] + for t_descr in tensor_descrs: + if isinstance(t_descr, (v0_4.InputTensorDescr, v0_5.InputTensorDescr)): + proc_descrs: List[ + Union[ + v0_4.PreprocessingDescr, + v0_5.PreprocessingDescr, + v0_4.PostprocessingDescr, + v0_5.PostprocessingDescr, + ] + ] = list(t_descr.preprocessing) + elif isinstance( + t_descr, + (v0_4.OutputTensorDescr, v0_5.OutputTensorDescr), + ): + proc_descrs = list(t_descr.postprocessing) + else: + assert_never(t_descr) + + if 
isinstance(t_descr, (v0_4.InputTensorDescr, v0_4.OutputTensorDescr)): + ensure_dtype = v0_5.EnsureDtypeDescr( + kwargs=v0_5.EnsureDtypeKwargs(dtype=t_descr.data_type) + ) + if isinstance(t_descr, v0_4.InputTensorDescr) and proc_descrs: + proc_descrs.insert(0, ensure_dtype) + + proc_descrs.append(ensure_dtype) + + for proc_d in proc_descrs: + proc_class = get_proc_class(proc_d) + member_id = ( + TensorId(str(t_descr.name)) + if isinstance(t_descr, v0_4.TensorDescrBase) + else t_descr.id + ) + req = proc_class.from_proc_descr( + proc_d, member_id # pyright: ignore[reportArgumentType] + ) + for m in req.required_measures: + if m.member_id in input_ids: + pre_measures.add(m) + elif m.member_id in output_ids: + post_measures.add(m) + else: + raise ValueError("When to raise ") + procs.append(req) + return procs + + return _SetupProcessing( + pre=prepare_procs(model.inputs), + post=prepare_procs(model.outputs), + pre_measures=pre_measures, + post_measures=post_measures, + ) diff --git a/bioimageio/core/py.typed b/bioimageio/core/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/bioimageio/core/resource_io/__init__.py b/bioimageio/core/resource_io/__init__.py deleted file mode 100644 index bdfb805f..00000000 --- a/bioimageio/core/resource_io/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -import bioimageio.spec -from .io_ import ( - export_resource_package, - load_resource_description, - save_raw_resource_description, - serialize_raw_resource_description, -) - -load_raw_resource_description = bioimageio.spec.load_raw_resource_description diff --git a/bioimageio/core/resource_io/io_.py b/bioimageio/core/resource_io/io_.py deleted file mode 100644 index 3bb98ead..00000000 --- a/bioimageio/core/resource_io/io_.py +++ /dev/null @@ -1,185 +0,0 @@ -import os -import pathlib -from copy import deepcopy -from tempfile import TemporaryDirectory -from typing import Dict, Optional, Sequence, Union -from zipfile import ZIP_DEFLATED, ZipFile - -from marshmallow import 
missing - -from bioimageio import spec -from bioimageio.core.resource_io.nodes import ResourceDescription -from bioimageio.spec import load_raw_resource_description -from bioimageio.spec.shared import raw_nodes -from bioimageio.spec.shared.common import ( - BIOIMAGEIO_CACHE_PATH, - BIOIMAGEIO_USE_CACHE, - get_class_name_from_type, - no_cache_tmp_list, -) -from bioimageio.spec.shared.raw_nodes import ResourceDescription as RawResourceDescription -from . import nodes -from .utils import resolve_raw_node, resolve_source - -serialize_raw_resource_description = spec.io_.serialize_raw_resource_description -save_raw_resource_description = spec.io_.save_raw_resource_description - - -def load_resource_description( - source: Union[RawResourceDescription, ResourceDescription, os.PathLike, str, dict, raw_nodes.URI], - *, - weights_priority_order: Optional[Sequence[str]] = None, # model only -) -> ResourceDescription: - """load a BioImage.IO resource description file (RDF). - This includes some transformations for convenience, e.g. importing `source`. - Use `load_raw_resource_description` to obtain a raw representation instead. 
- - Args: - source: resource description file (RDF) or raw BioImage.IO resource - weights_priority_order: If given only the first weights format present in the model resource is included - Returns: - BioImage.IO resource - """ - source = deepcopy(source) - if isinstance(source, ResourceDescription): - return source - - raw_rd = load_raw_resource_description(source, update_to_format="latest") - - if raw_rd.type == "model" and weights_priority_order is not None: - for wf in weights_priority_order: - if wf in raw_rd.weights: - raw_rd.weights = {wf: raw_rd.weights[wf]} - break - else: - raise ValueError(f"Not found any of the specified weights formats {weights_priority_order}") - - rd: ResourceDescription = resolve_raw_node(raw_rd=raw_rd, nodes_module=nodes) - assert isinstance(rd, getattr(nodes, get_class_name_from_type(raw_rd.type))) - - return rd - - -def get_local_resource_package_content( - source: RawResourceDescription, - weights_priority_order: Optional[Sequence[Union[str]]], - update_to_format: Optional[str] = None, -) -> Dict[str, Union[pathlib.Path, str]]: - """ - - Args: - source: raw resource description - weights_priority_order: If given only the first weights format present in the model is included. - If none of the prioritized weights formats is found all are included. - update_to_format: update resource to specific major.minor format version; ignoring patch version. - - Returns: - Package content of local file paths or text content keyed by file names. 
- - """ - raw_rd = load_raw_resource_description(source, update_to_format=update_to_format) - package_content = spec.get_resource_package_content(raw_rd, weights_priority_order=weights_priority_order) - - local_package_content = {} - for k, v in package_content.items(): - if isinstance(v, raw_nodes.URI): - v = resolve_source(v, raw_rd.root_path) - elif isinstance(v, pathlib.Path): - v = raw_rd.root_path / v - - local_package_content[k] = v - - return local_package_content - - -def export_resource_package( - source: Union[RawResourceDescription, os.PathLike, str, dict, raw_nodes.URI], - *, - compression: int = ZIP_DEFLATED, - compression_level: int = 1, - output_path: Optional[os.PathLike] = None, - update_to_format: Optional[str] = None, - weights_priority_order: Optional[Sequence[Union[str]]] = None, -) -> pathlib.Path: - """Package a BioImage.IO resource as a zip file. - - Args: - source: raw resource description, path, URI or raw data as dict - compression: The numeric constant of compression method. - compression_level: Compression level to use when writing files to the archive. - See https://docs.python.org/3/library/zipfile.html#zipfile.ZipFile - output_path: file path to write package to - update_to_format: update resource to specific "major.minor" or "latest" format version; ignoring patch version. - weights_priority_order: If given only the first weights format present in the model is included. - If none of the prioritized weights formats is found all are included. 
- - Returns: - path to zipped BioImage.IO package in BIOIMAGEIO_CACHE_PATH or 'output_path' - """ - raw_rd = load_raw_resource_description(source, update_to_format=update_to_format) - package_content = get_local_resource_package_content( - raw_rd, weights_priority_order, update_to_format=update_to_format - ) - if output_path is None: - package_path = _get_tmp_package_path(raw_rd, weights_priority_order) - else: - package_path = output_path - - make_zip(package_path, package_content, compression=compression, compression_level=compression_level) - return package_path - - -def _get_package_base_name(raw_rd: RawResourceDescription, weights_priority_order: Optional[Sequence[str]]) -> str: - package_file_name = raw_rd.name - if raw_rd.version is not missing: - package_file_name += f"_{raw_rd.version}" - - package_file_name = package_file_name.replace(" ", "_").replace(".", "_") - - return package_file_name - - -def _get_tmp_package_path(raw_rd: RawResourceDescription, weights_priority_order: Optional[Sequence[str]]): - if BIOIMAGEIO_USE_CACHE: - package_file_name = _get_package_base_name(raw_rd, weights_priority_order) - cache_folder = BIOIMAGEIO_CACHE_PATH / "packages" - cache_folder.mkdir(exist_ok=True, parents=True) - - package_path = (cache_folder / package_file_name).with_suffix(".zip") - max_cached_packages_with_same_name = 100 - for p in range(max_cached_packages_with_same_name): - if package_path.exists(): - package_path = (cache_folder / f"{package_file_name}p{p}").with_suffix(".zip") - else: - break - else: - raise FileExistsError( - f"Already caching {max_cached_packages_with_same_name} versions of {cache_folder / package_file_name}!" - ) - else: - tmp_dir = TemporaryDirectory() - no_cache_tmp_list.append(tmp_dir) - package_path = pathlib.Path(tmp_dir.name) / "file" - - return package_path - - -def make_zip( - path: os.PathLike, content: Dict[str, Union[str, pathlib.Path]], *, compression: int, compression_level: int -) -> None: - """Write a zip archive. 
- - Args: - path: output path to write to. - content: dict with archive names and local file paths or strings for text files. - compression: The numeric constant of compression method. - compression_level: Compression level to use when writing files to the archive. - See https://docs.python.org/3/library/zipfile.html#zipfile.ZipFile - - """ - with ZipFile(path, "w", compression=compression, compresslevel=compression_level) as myzip: - for arc_name, file_or_str_content in content.items(): - if isinstance(file_or_str_content, str): - myzip.writestr(arc_name, file_or_str_content) - else: - myzip.write(file_or_str_content, arcname=arc_name) diff --git a/bioimageio/core/resource_io/nodes.py b/bioimageio/core/resource_io/nodes.py deleted file mode 100644 index 47e2035f..00000000 --- a/bioimageio/core/resource_io/nodes.py +++ /dev/null @@ -1,187 +0,0 @@ -import pathlib -from dataclasses import dataclass -from pathlib import Path -from typing import Any, Callable, Dict, List, Tuple, Union - -from marshmallow import missing -from marshmallow.utils import _Missing - -from bioimageio.spec.model import raw_nodes as model_raw_nodes -from bioimageio.spec.rdf import raw_nodes as rdf_raw_nodes -from bioimageio.spec.collection import raw_nodes as collection_raw_nodes -from bioimageio.spec.shared import raw_nodes - - -@dataclass -class Node(raw_nodes.RawNode): - pass - - -@dataclass -class ResourceDescription(Node, raw_nodes.ResourceDescription): - pass - - -@dataclass -class URI(Node, raw_nodes.URI): - pass - - -@dataclass -class ParametrizedInputShape(Node, raw_nodes.ParametrizedInputShape): - pass - - -@dataclass -class ImplicitOutputShape(Node, raw_nodes.ImplicitOutputShape): - pass - - -@dataclass -class Dependencies(Node, raw_nodes.Dependencies): - file: pathlib.Path = missing - - -@dataclass -class CiteEntry(Node, rdf_raw_nodes.CiteEntry): - pass - - -@dataclass -class Author(Node, model_raw_nodes.Author): - pass - - -@dataclass -class Maintainer(Node, 
model_raw_nodes.Maintainer): - pass - - -@dataclass -class Badge(Node, rdf_raw_nodes.Badge): - pass - - -@dataclass -class RDF(rdf_raw_nodes.RDF, ResourceDescription): - badges: Union[_Missing, List[Badge]] = missing - covers: Union[_Missing, List[Path]] = missing - - -@dataclass -class CollectionEntry(Node, collection_raw_nodes.CollectionEntry): - source: URI = missing - - -@dataclass -class LinkedDataset(Node, model_raw_nodes.LinkedDataset): - pass - - -@dataclass -class ModelParent(Node, model_raw_nodes.ModelParent): - pass - - -@dataclass -class Collection(collection_raw_nodes.Collection, RDF): - pass - - -@dataclass -class RunMode(Node, model_raw_nodes.RunMode): - pass - - -@dataclass -class Preprocessing(Node, model_raw_nodes.Preprocessing): - pass - - -@dataclass -class Postprocessing(Node, model_raw_nodes.Postprocessing): - pass - - -@dataclass -class InputTensor(Node, model_raw_nodes.InputTensor): - axes: Tuple[str, ...] = missing - - def __post_init__(self): - super().__post_init__() - # raw node has string with single letter axes names. Here we use a tuple of string here (like xr.DataArray). - self.axes = tuple(self.axes) - - -@dataclass -class OutputTensor(Node, model_raw_nodes.OutputTensor): - axes: Tuple[str, ...] = missing - - def __post_init__(self): - super().__post_init__() - # raw node has string with single letter axes names. Here we use a tuple of string here (like xr.DataArray). 
- self.axes = tuple(self.axes) - - -@dataclass -class ImportedSource(Node): - factory: Callable - - def __call__(self, *args, **kwargs): - return self.factory(*args, **kwargs) - - -@dataclass -class KerasHdf5WeightsEntry(Node, model_raw_nodes.KerasHdf5WeightsEntry): - source: Path = missing - - -@dataclass -class OnnxWeightsEntry(Node, model_raw_nodes.OnnxWeightsEntry): - source: Path = missing - - -@dataclass -class PytorchStateDictWeightsEntry(Node, model_raw_nodes.PytorchStateDictWeightsEntry): - source: Path = missing - architecture: Union[_Missing, ImportedSource] = missing - - -@dataclass -class TorchscriptWeightsEntry(Node, model_raw_nodes.TorchscriptWeightsEntry): - source: Path = missing - - -@dataclass -class TensorflowJsWeightsEntry(Node, model_raw_nodes.TensorflowJsWeightsEntry): - source: Path = missing - - -@dataclass -class TensorflowSavedModelBundleWeightsEntry(Node, model_raw_nodes.TensorflowSavedModelBundleWeightsEntry): - source: Path = missing - - -@dataclass -class Attachments(Node, rdf_raw_nodes.Attachments): - files: Union[_Missing, List[Path]] = missing - unknown: Union[_Missing, Dict[str, Any]] = missing - - -WeightsEntry = Union[ - KerasHdf5WeightsEntry, - OnnxWeightsEntry, - PytorchStateDictWeightsEntry, - TensorflowJsWeightsEntry, - TensorflowSavedModelBundleWeightsEntry, - TorchscriptWeightsEntry, -] - - -@dataclass -class Model(model_raw_nodes.Model, RDF): - authors: List[Author] = missing - maintainers: Union[_Missing, List[Maintainer]] = missing - test_inputs: List[Path] = missing - test_outputs: List[Path] = missing - weights: Dict[model_raw_nodes.WeightsFormat, WeightsEntry] = missing diff --git a/bioimageio/core/resource_io/utils.py b/bioimageio/core/resource_io/utils.py deleted file mode 100644 index 575f049b..00000000 --- a/bioimageio/core/resource_io/utils.py +++ /dev/null @@ -1,124 +0,0 @@ -import dataclasses -import importlib.util -import os -import pathlib -import sys -import typing -from types import ModuleType - -from 
bioimageio.spec.shared import raw_nodes, resolve_source, source_available -from bioimageio.spec.shared.node_transformer import ( - GenericRawNode, - GenericResolvedNode, - NodeTransformer, - NodeVisitor, - UriNodeTransformer, -) -from . import nodes - -GenericNode = typing.Union[GenericRawNode, GenericResolvedNode] - - -def iter_fields(node: GenericNode): - for field in dataclasses.fields(node): - yield field.name, getattr(node, field.name) - - -class SourceNodeChecker(NodeVisitor): - """raises FileNotFoundError for unavailable URIs and paths""" - - def __init__(self, *, root_path: os.PathLike): - self.root_path = root_path if isinstance(root_path, raw_nodes.URI) else pathlib.Path(root_path).resolve() - - def _visit_source(self, leaf: typing.Union[pathlib.Path, raw_nodes.URI]): - if not source_available(leaf, self.root_path): - raise FileNotFoundError(leaf) - - def visit_URI(self, node: raw_nodes.URI): - self._visit_source(node) - - def visit_PosixPath(self, leaf: pathlib.PosixPath): - self._visit_source(leaf) - - def visit_WindowsPath(self, leaf: pathlib.WindowsPath): - self._visit_source(leaf) - - def generic_visit(self, node): - """Called if no explicit visitor function exists for a node.""" - - if isinstance(node, raw_nodes.RawNode): - for field, value in iter_fields(node): - if field != "root_path": # do not visit root_path, as it might be an incomplete (non-available) URL - self.visit(value) - else: - super().generic_visit(node) - - -class SourceNodeTransformer(NodeTransformer): - """ - Imports all source callables - note: Requires previous transformation by UriNodeTransformer - """ - - class TemporaryInsertionIntoPythonPath: - def __init__(self, path: str): - self.path = path - - def __enter__(self): - sys.path.insert(0, self.path) - - def __exit__(self, exc_type, exc_value, traceback): - sys.path.remove(self.path) - - def transform_LocalImportableModule(self, node: raw_nodes.LocalImportableModule) -> nodes.ImportedSource: - with 
self.TemporaryInsertionIntoPythonPath(str(node.root_path)): - module = importlib.import_module(node.module_name) - - return nodes.ImportedSource(factory=getattr(module, node.callable_name)) - - @staticmethod - def transform_ResolvedImportableSourceFile(node: raw_nodes.ResolvedImportableSourceFile) -> nodes.ImportedSource: - module_path = resolve_source(node.source_file) - module_name = f"module_from_source.{module_path.stem}" - importlib_spec = importlib.util.spec_from_file_location(module_name, module_path) - assert importlib_spec is not None - dep = importlib.util.module_from_spec(importlib_spec) - importlib_spec.loader.exec_module(dep) # type: ignore # todo: possible to use "loader.load_module"? - return nodes.ImportedSource(factory=getattr(dep, node.callable_name)) - - -class RawNodeTypeTransformer(NodeTransformer): - def __init__(self, nodes_module: ModuleType): - super().__init__() - self.nodes = nodes_module - - def generic_transformer(self, node: GenericRawNode) -> GenericResolvedNode: - if isinstance(node, raw_nodes.RawNode): - resolved_data = { - field.name: self.transform(getattr(node, field.name)) for field in dataclasses.fields(node) - } - resolved_node_type: typing.Type[GenericResolvedNode] = getattr(self.nodes, node.__class__.__name__) - return resolved_node_type(**resolved_data) # type: ignore - else: - return super().generic_transformer(node) - - -def all_sources_available( - node: typing.Union[GenericNode, list, tuple, dict], root_path: os.PathLike = pathlib.Path() -) -> bool: - try: - SourceNodeChecker(root_path=root_path).visit(node) - except FileNotFoundError: - return False - else: - return True - - -def resolve_raw_node( - raw_rd: GenericRawNode, nodes_module: typing.Any, uri_only_if_in_package: bool = True -) -> GenericResolvedNode: - """resolve all uris and paths (that are included when packaging)""" - rd = UriNodeTransformer(root_path=raw_rd.root_path, uri_only_if_in_package=uri_only_if_in_package).transform(raw_rd) - rd = 
SourceNodeTransformer().transform(rd) - rd = RawNodeTypeTransformer(nodes_module).transform(rd) - return rd diff --git a/bioimageio/core/resource_tests.py b/bioimageio/core/resource_tests.py deleted file mode 100644 index d807bed0..00000000 --- a/bioimageio/core/resource_tests.py +++ /dev/null @@ -1,333 +0,0 @@ -import os -import re -import traceback -import warnings -from copy import deepcopy -from pathlib import Path -from typing import List, Optional, Tuple, Union - -import numpy -import numpy as np -import xarray as xr -from marshmallow import ValidationError - -from bioimageio.core import __version__ as bioimageio_core_version, load_resource_description -from bioimageio.core.common import TestSummary -from bioimageio.core.prediction import predict -from bioimageio.core.prediction_pipeline import create_prediction_pipeline -from bioimageio.core.resource_io.nodes import ( - ImplicitOutputShape, - Model, - ParametrizedInputShape, - ResourceDescription, - URI, -) -from bioimageio.core.resource_io.utils import SourceNodeChecker -from bioimageio.spec import __version__ as bioimageio_spec_version -from bioimageio.spec.model.raw_nodes import WeightsFormat -from bioimageio.spec.shared import resolve_source -from bioimageio.spec.shared.common import ValidationWarning -from bioimageio.spec.shared.raw_nodes import ResourceDescription as RawResourceDescription - - -def test_model( - model_rdf: Union[URI, Path, str], - weight_format: Optional[WeightsFormat] = None, - devices: Optional[List[str]] = None, - decimal: int = 4, -) -> List[TestSummary]: - """Test whether the test output(s) of a model can be reproduced.""" - return test_resource( - model_rdf, weight_format=weight_format, devices=devices, decimal=decimal, expected_type="model" - ) - - -def check_input_shape(shape: Tuple[int, ...], shape_spec) -> bool: - if isinstance(shape_spec, list): - if shape != tuple(shape_spec): - return False - elif isinstance(shape_spec, ParametrizedInputShape): - assert len(shape_spec.min) 
== len(shape_spec.step) - if len(shape) != len(shape_spec.min): - return False - min_shape = shape_spec.min - step = shape_spec.step - # check if the shape is valid for all dimension by seeing if it can be reached with an integer number of steps - # NOTE we allow that the valid shape is reached using a different number of steps for each axis here - # this is usually valid because dimensions are independent in neural networks - is_valid = [(sh - minsh) % st == 0 if st > 0 else sh == minsh for sh, st, minsh in zip(shape, step, min_shape)] - return all(is_valid) - else: - raise TypeError(f"Encountered unexpected shape description of type {type(shape_spec)}") - - return True - - -def check_output_shape(shape: Tuple[int, ...], shape_spec, input_shapes) -> bool: - if isinstance(shape_spec, list): - return shape == tuple(shape_spec) - elif isinstance(shape_spec, ImplicitOutputShape): - ref_tensor = shape_spec.reference_tensor - if ref_tensor not in input_shapes: - raise ValidationError(f"The reference tensor name {ref_tensor} is not in {input_shapes}") - ipt_shape = numpy.array(input_shapes[ref_tensor]) - scale = numpy.array([0.0 if sc is None else sc for sc in shape_spec.scale]) - offset = numpy.array(shape_spec.offset) - exp_shape = numpy.round_(ipt_shape * scale) + 2 * offset - - return shape == tuple(exp_shape) - else: - raise TypeError(f"Encountered unexpected shape description of type {type(shape_spec)}") - - -def _test_resource_urls(rd: ResourceDescription) -> TestSummary: - assert isinstance(rd, ResourceDescription) - with warnings.catch_warnings(record=True) as all_warnings: - try: - SourceNodeChecker(root_path=rd.root_path).visit(rd) - except FileNotFoundError as e: - error = str(e) - tb = traceback.format_tb(e.__traceback__) - else: - error = None - tb = None - - return dict( - name="All URLs and paths available", - status="passed" if error is None else "failed", - error=error, - traceback=tb, - bioimageio_spec_version=bioimageio_spec_version, - 
bioimageio_core_version=bioimageio_core_version, - nested_errors=None, - source_name=rd.id or rd.id or rd.name if hasattr(rd, "id") else rd.name, - warnings={"SourceNodeChecker": [str(w.message) for w in all_warnings]} if all_warnings else {}, - ) - - -def _test_model_documentation(rd: ResourceDescription) -> TestSummary: - assert isinstance(rd, Model) - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - doc_path: Path = resolve_source(rd.documentation, root_path=rd.root_path) - doc = doc_path.read_text() - wrn = "" - if not re.match("#.*[vV]alidation", doc): - wrn = "No '# Validation' (sub)section found." - - return dict( - name="Test documentation completeness.", - status="passed", - error=None, - traceback=None, - bioimageio_spec_version=bioimageio_spec_version, - bioimageio_core_version=bioimageio_core_version, - source_name=rd.id or rd.name if hasattr(rd, "id") else rd.name, - warnings={"documentation": wrn} if wrn else {}, - ) - - -def _test_model_inference(model: Model, weight_format: str, devices: Optional[List[str]], decimal: int) -> TestSummary: - error: Optional[str] = None - tb: Optional = None - with warnings.catch_warnings(record=True) as all_warnings: - try: - inputs = [np.load(str(in_path)) for in_path in model.test_inputs] - expected = [np.load(str(out_path)) for out_path in model.test_outputs] - - assert len(inputs) == len(model.inputs) # should be checked by validation - input_shapes = {} - for idx, (ipt, ipt_spec) in enumerate(zip(inputs, model.inputs)): - if not check_input_shape(tuple(ipt.shape), ipt_spec.shape): - raise ValidationError( - f"Shape {tuple(ipt.shape)} of test input {idx} '{ipt_spec.name}' does not match " - f"input shape description: {ipt_spec.shape}." 
- ) - input_shapes[ipt_spec.name] = ipt.shape - - assert len(expected) == len(model.outputs) # should be checked by validation - for idx, (out, out_spec) in enumerate(zip(expected, model.outputs)): - if not check_output_shape(tuple(out.shape), out_spec.shape, input_shapes): - error = (error or "") + ( - f"Shape {tuple(out.shape)} of test output {idx} '{out_spec.name}' does not match " - f"output shape description: {out_spec.shape}." - ) - - with create_prediction_pipeline( - bioimageio_model=model, devices=devices, weight_format=weight_format - ) as prediction_pipeline: - results = predict(prediction_pipeline, inputs) - - if len(results) != len(expected): - error = (error or "") + ( - f"Number of outputs and number of expected outputs disagree: {len(results)} != {len(expected)}" - ) - else: - for res, exp in zip(results, expected): - try: - np.testing.assert_array_almost_equal(res, exp, decimal=decimal) - except AssertionError as e: - error = (error or "") + f"Output and expected output disagree:\n {e}" - except Exception as e: - error = str(e) - tb = traceback.format_tb(e.__traceback__) - - return dict( - name=f"reproduce test outputs from test inputs (bioimageio.core {bioimageio_core_version})", - status="passed" if error is None else "failed", - error=error, - traceback=tb, - bioimageio_spec_version=bioimageio_spec_version, - bioimageio_core_version=bioimageio_core_version, - warnings=ValidationWarning.get_warning_summary(all_warnings), - source_name=model.id or model.name, - ) - - -def _test_load_resource( - rdf: Union[RawResourceDescription, ResourceDescription, URI, Path, str], - weight_format: Optional[WeightsFormat] = None, -) -> Tuple[Optional[ResourceDescription], TestSummary]: - if isinstance(rdf, (URI, os.PathLike)): - source_name = str(rdf) - elif isinstance(rdf, str): - source_name = rdf[:120] - else: - source_name = rdf.id if hasattr(rdf, "id") else rdf.name - - main_test_warnings = [] - try: - with warnings.catch_warnings(record=True) as 
all_warnings: - rd: Optional[ResourceDescription] = load_resource_description( - rdf, weights_priority_order=None if weight_format is None else [weight_format] - ) - - main_test_warnings += list(all_warnings) - except Exception as e: - rd = None - error: Optional[str] = str(e) - tb: Optional = traceback.format_tb(e.__traceback__) - else: - error = None - tb = None - - load_summary = TestSummary( - name="load resource description", - status="passed" if error is None else "failed", - error=error, - nested_errors=None, - traceback=tb, - bioimageio_spec_version=bioimageio_spec_version, - bioimageio_core_version=bioimageio_core_version, - warnings={}, - source_name=source_name, - ) - - return rd, load_summary - - -def _test_expected_resource_type(rd: ResourceDescription, expected_type: str) -> TestSummary: - has_expected_type = rd.type == expected_type - return dict( - name="has expected resource type", - status="passed" if has_expected_type else "failed", - error=None if has_expected_type else f"expected type {expected_type}, found {rd.type}", - traceback=None, - source_name=rd.id or rd.name if hasattr(rd, "id") else rd.name, - ) - - -def test_resource( - rdf: Union[RawResourceDescription, ResourceDescription, URI, Path, str], - *, - weight_format: Optional[WeightsFormat] = None, - devices: Optional[List[str]] = None, - decimal: int = 4, - expected_type: Optional[str] = None, -) -> List[TestSummary]: - """Test RDF dynamically - - Returns: summary dict with keys: name, status, error, traceback, bioimageio_spec_version, bioimageio_core_version - """ - rd, load_test = _test_load_resource(rdf, weight_format) - tests: List[TestSummary] = [load_test] - if rd is not None: - if expected_type is not None: - tests.append(_test_expected_resource_type(rd, expected_type)) - - tests.append(_test_resource_urls(rd)) - - if isinstance(rd, Model): - tests.append(_test_model_documentation(rd)) - tests.append(_test_model_inference(rd, weight_format, devices, decimal)) - - return tests - - 
-def debug_model( - model_rdf: Union[RawResourceDescription, ResourceDescription, URI, Path, str], - *, - weight_format: Optional[WeightsFormat] = None, - devices: Optional[List[str]] = None, -): - """Run the model test and return dict with inputs, results, expected results and intermediates. - - Returns dict with tensors "inputs", "inputs_processed", "outputs_raw", "outputs", "expected" and "diff". - """ - inputs_raw: Optional = None - inputs_processed: Optional = None - outputs_raw: Optional = None - outputs: Optional = None - expected: Optional = None - diff: Optional = None - - model = load_resource_description( - model_rdf, weights_priority_order=None if weight_format is None else [weight_format] - ) - if not isinstance(model, Model): - raise ValueError(f"Not a bioimageio.model: {model_rdf}") - - prediction_pipeline = create_prediction_pipeline( - bioimageio_model=model, devices=devices, weight_format=weight_format - ) - inputs = [ - xr.DataArray(np.load(str(in_path)), dims=input_spec.axes) - for in_path, input_spec in zip(model.test_inputs, model.inputs) - ] - input_dict = {input_spec.name: input for input_spec, input in zip(model.inputs, inputs)} - - # keep track of the non-processed inputs - inputs_raw = [deepcopy(input) for input in inputs] - - computed_measures = {} - - prediction_pipeline.apply_preprocessing(input_dict, computed_measures) - inputs_processed = list(input_dict.values()) - outputs_raw = prediction_pipeline.predict(*inputs_processed) - output_dict = {output_spec.name: deepcopy(output) for output_spec, output in zip(model.outputs, outputs_raw)} - prediction_pipeline.apply_postprocessing(output_dict, computed_measures) - outputs = list(output_dict.values()) - - if isinstance(outputs, (np.ndarray, xr.DataArray)): - outputs = [outputs] - - expected = [ - xr.DataArray(np.load(str(out_path)), dims=output_spec.axes) - for out_path, output_spec in zip(model.test_outputs, model.outputs) - ] - if len(outputs) != len(expected): - error = f"Number of 
outputs and number of expected outputs disagree: {len(outputs)} != {len(expected)}" - print(error) - else: - diff = [] - for res, exp in zip(outputs, expected): - diff.append(res - exp) - - return { - "inputs": inputs_raw, - "inputs_processed": inputs_processed, - "outputs_raw": outputs_raw, - "outputs": outputs, - "expected": expected, - "diff": diff, - } diff --git a/bioimageio/core/sample.py b/bioimageio/core/sample.py new file mode 100644 index 00000000..8cec90fa --- /dev/null +++ b/bioimageio/core/sample.py @@ -0,0 +1,352 @@ +from __future__ import annotations + +from dataclasses import dataclass +from math import ceil, floor +from typing import ( + Callable, + Dict, + Generic, + Iterable, + Optional, + Tuple, + TypeVar, + Union, +) + +import numpy as np +from typing_extensions import Self + +from bioimageio.core.block import Block + +from .axis import AxisId, PerAxis +from .block_meta import ( + BlockMeta, + LinearAxisTransform, + split_multiple_shapes_into_blocks, +) +from .common import ( + BlockIndex, + Halo, + HaloLike, + MemberId, + PadMode, + PerMember, + SampleId, + SliceInfo, + TotalNumberOfBlocks, +) +from .stat_measures import Stat +from .tensor import Tensor + +# TODO: allow for lazy samples to read/write to disk + + +@dataclass +class Sample: + """A dataset sample""" + + members: Dict[MemberId, Tensor] + """the sample's tensors""" + + stat: Stat + """sample and dataset statistics""" + + id: SampleId + """identifier within the sample's dataset""" + + @property + def shape(self) -> PerMember[PerAxis[int]]: + return {tid: t.sizes for tid, t in self.members.items()} + + def split_into_blocks( + self, + block_shapes: PerMember[PerAxis[int]], + halo: PerMember[PerAxis[HaloLike]], + pad_mode: PadMode, + broadcast: bool = False, + ) -> Tuple[TotalNumberOfBlocks, Iterable[SampleBlockWithOrigin]]: + assert not ( + missing := [m for m in block_shapes if m not in self.members] + ), f"`block_shapes` specified for unknown members: {missing}" + assert not ( + 
missing := [m for m in halo if m not in block_shapes] + ), f"`halo` specified for members without `block_shape`: {missing}" + + n_blocks, blocks = split_multiple_shapes_into_blocks( + shapes=self.shape, + block_shapes=block_shapes, + halo=halo, + broadcast=broadcast, + ) + return n_blocks, sample_block_generator(blocks, origin=self, pad_mode=pad_mode) + + def as_single_block(self, halo: Optional[PerMember[PerAxis[Halo]]] = None): + if halo is None: + halo = {} + return SampleBlockWithOrigin( + sample_shape=self.shape, + sample_id=self.id, + blocks={ + m: Block( + sample_shape=self.shape[m], + data=data, + inner_slice={ + a: SliceInfo(0, s) for a, s in data.tagged_shape.items() + }, + halo=halo.get(m, {}), + block_index=0, + blocks_in_sample=1, + ) + for m, data in self.members.items() + }, + stat=self.stat, + origin=self, + block_index=0, + blocks_in_sample=1, + ) + + @classmethod + def from_blocks( + cls, + sample_blocks: Iterable[SampleBlock], + *, + fill_value: float = float("nan"), + ) -> Self: + members: PerMember[Tensor] = {} + stat: Stat = {} + sample_id = None + for sample_block in sample_blocks: + assert sample_id is None or sample_id == sample_block.sample_id + sample_id = sample_block.sample_id + stat = sample_block.stat + for m, block in sample_block.blocks.items(): + if m not in members: + if -1 in block.sample_shape.values(): + raise NotImplementedError( + "merging blocks with data dependent axis not yet implemented" + ) + + members[m] = Tensor( + np.full( + tuple(block.sample_shape[a] for a in block.data.dims), + fill_value, + dtype=block.data.dtype, + ), + dims=block.data.dims, + ) + + members[m][block.inner_slice] = block.inner_data + + return cls(members=members, stat=stat, id=sample_id) + + +BlockT = TypeVar("BlockT", Block, BlockMeta) + + +@dataclass +class SampleBlockBase(Generic[BlockT]): + """base class for `SampleBlockMeta` and `SampleBlock`""" + + sample_shape: PerMember[PerAxis[int]] + """the sample shape this block represents a part of""" 
+ + sample_id: SampleId + """identifier for the sample within its dataset""" + + blocks: Dict[MemberId, BlockT] + """Individual tensor blocks comprising this sample block""" + + block_index: BlockIndex + """the n-th block of the sample""" + + blocks_in_sample: TotalNumberOfBlocks + """total number of blocks in the sample""" + + @property + def shape(self) -> PerMember[PerAxis[int]]: + return {mid: b.shape for mid, b in self.blocks.items()} + + @property + def inner_shape(self) -> PerMember[PerAxis[int]]: + return {mid: b.inner_shape for mid, b in self.blocks.items()} + + +@dataclass +class LinearSampleAxisTransform(LinearAxisTransform): + member: MemberId + + +@dataclass +class SampleBlockMeta(SampleBlockBase[BlockMeta]): + """Meta data of a dataset sample block""" + + def get_transformed( + self, new_axes: PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]] + ) -> Self: + sample_shape = { + m: { + a: ( + trf + if isinstance(trf, int) + else trf.compute(self.sample_shape[trf.member][trf.axis]) + ) + for a, trf in new_axes[m].items() + } + for m in new_axes + } + + def get_member_halo(m: MemberId, round: Callable[[float], int]): + return { + a: ( + Halo(0, 0) + if isinstance(trf, int) + or trf.axis not in self.blocks[trf.member].halo + else Halo( + round(self.blocks[trf.member].halo[trf.axis].left * trf.scale), + round(self.blocks[trf.member].halo[trf.axis].right * trf.scale), + ) + ) + for a, trf in new_axes[m].items() + } + + halo: Dict[MemberId, Dict[AxisId, Halo]] = {} + for m in new_axes: + halo[m] = get_member_halo(m, floor) + assert halo[m] == get_member_halo( + m, ceil + ), f"failed to unambiguously scale halo {halo[m]} with {new_axes[m]}" + + inner_slice = { + m: { + a: ( + SliceInfo(0, trf) + if isinstance(trf, int) + else SliceInfo( + trf.compute( + self.blocks[trf.member].inner_slice[trf.axis].start + ), + trf.compute(self.blocks[trf.member].inner_slice[trf.axis].stop), + ) + ) + for a, trf in new_axes[m].items() + } + for m in new_axes + } + 
return self.__class__( + blocks={ + m: BlockMeta( + sample_shape=sample_shape[m], + inner_slice=inner_slice[m], + halo=halo[m], + block_index=self.block_index, + blocks_in_sample=self.blocks_in_sample, + ) + for m in new_axes + }, + sample_shape=sample_shape, + sample_id=self.sample_id, + block_index=self.block_index, + blocks_in_sample=self.blocks_in_sample, + ) + + def with_data(self, data: PerMember[Tensor], *, stat: Stat) -> SampleBlock: + return SampleBlock( + sample_shape=self.sample_shape, + sample_id=self.sample_id, + blocks={ + m: Block( + sample_shape=self.sample_shape[m], + inner_slice=b.inner_slice, + halo=b.halo, + block_index=b.block_index, + blocks_in_sample=b.blocks_in_sample, + data=data[m], + ) + for m, b in self.blocks.items() + }, + stat=stat, + block_index=self.block_index, + blocks_in_sample=self.blocks_in_sample, + ) + + +@dataclass +class SampleBlock(SampleBlockBase[Block]): + """A block of a dataset sample""" + + stat: Stat + """computed statistics""" + + @property + def members(self) -> PerMember[Tensor]: + """the sample block's tensors""" + return {m: b.data for m, b in self.blocks.items()} + + def get_transformed_meta( + self, new_axes: PerMember[PerAxis[Union[LinearSampleAxisTransform, int]]] + ) -> SampleBlockMeta: + return SampleBlockMeta( + sample_id=self.sample_id, + blocks=dict(self.blocks), + sample_shape=self.sample_shape, + block_index=self.block_index, + blocks_in_sample=self.blocks_in_sample, + ).get_transformed(new_axes) + + +@dataclass +class SampleBlockWithOrigin(SampleBlock): + origin: Sample + """the sample this sample block was taken from""" + + +class _ConsolidatedMemberBlocks: + def __init__(self, blocks: PerMember[BlockMeta]): + super().__init__() + block_indices = {b.block_index for b in blocks.values()} + assert len(block_indices) == 1 + self.block_index = block_indices.pop() + blocks_in_samples = {b.blocks_in_sample for b in blocks.values()} + assert len(blocks_in_samples) == 1 + self.blocks_in_sample = 
blocks_in_samples.pop() + + +def sample_block_meta_generator( + blocks: Iterable[PerMember[BlockMeta]], + *, + sample_shape: PerMember[PerAxis[int]], + sample_id: SampleId, +): + for member_blocks in blocks: + cons = _ConsolidatedMemberBlocks(member_blocks) + yield SampleBlockMeta( + blocks=dict(member_blocks), + sample_shape=sample_shape, + sample_id=sample_id, + block_index=cons.block_index, + blocks_in_sample=cons.blocks_in_sample, + ) + + +def sample_block_generator( + blocks: Iterable[PerMember[BlockMeta]], + *, + origin: Sample, + pad_mode: PadMode, +): + for member_blocks in blocks: + cons = _ConsolidatedMemberBlocks(member_blocks) + yield SampleBlockWithOrigin( + blocks={ + m: Block.from_sample_member( + origin.members[m], block=member_blocks[m], pad_mode=pad_mode + ) + for m in origin.members + }, + sample_shape=origin.shape, + origin=origin, + stat=origin.stat, + sample_id=origin.id, + block_index=cons.block_index, + blocks_in_sample=cons.blocks_in_sample, + ) diff --git a/bioimageio/core/stat_calculators.py b/bioimageio/core/stat_calculators.py new file mode 100644 index 00000000..afd0ce24 --- /dev/null +++ b/bioimageio/core/stat_calculators.py @@ -0,0 +1,611 @@ +from __future__ import annotations + +import collections.abc +import warnings +from itertools import product +from typing import ( + Any, + Collection, + Dict, + Iterable, + Iterator, + List, + Mapping, + Optional, + OrderedDict, + Sequence, + Set, + Tuple, + Type, + Union, +) + +import numpy as np +import xarray as xr +from numpy.typing import NDArray +from typing_extensions import assert_never + +from .axis import AxisId, PerAxis +from .common import MemberId +from .sample import Sample +from .stat_measures import ( + DatasetMean, + DatasetMeasure, + DatasetMeasureBase, + DatasetPercentile, + DatasetStd, + DatasetVar, + Measure, + MeasureValue, + SampleMean, + SampleMeasure, + SampleQuantile, + SampleStd, + SampleVar, +) +from .tensor import Tensor + +try: + import crick + +except Exception: + 
crick = None + + class TDigest: + def update(self, obj: Any): + pass + + def quantile(self, q: Any) -> Any: + pass + +else: + TDigest = crick.TDigest # type: ignore + + +class MeanCalculator: + """to calculate sample and dataset mean for in-memory samples""" + + def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]): + super().__init__() + self._n: int = 0 + self._mean: Optional[Tensor] = None + self._axes = None if axes is None else tuple(axes) + self._member_id = member_id + self._sample_mean = SampleMean(member_id=self._member_id, axes=self._axes) + self._dataset_mean = DatasetMean(member_id=self._member_id, axes=self._axes) + + def compute(self, sample: Sample) -> Dict[SampleMean, MeasureValue]: + return {self._sample_mean: self._compute_impl(sample)} + + def _compute_impl(self, sample: Sample) -> Tensor: + tensor = sample.members[self._member_id].astype("float64", copy=False) + return tensor.mean(dim=self._axes) + + def update(self, sample: Sample) -> None: + mean = self._compute_impl(sample) + self._update_impl(sample.members[self._member_id], mean) + + def compute_and_update(self, sample: Sample) -> Dict[SampleMean, MeasureValue]: + mean = self._compute_impl(sample) + self._update_impl(sample.members[self._member_id], mean) + return {self._sample_mean: mean} + + def _update_impl(self, tensor: Tensor, tensor_mean: Tensor): + assert tensor_mean.dtype == "float64" + # reduced voxel count + n_b = int(tensor.size / tensor_mean.size) + + if self._mean is None: + assert self._n == 0 + self._n = n_b + self._mean = tensor_mean + else: + assert self._n != 0 + n_a = self._n + mean_old = self._mean + self._n = n_a + n_b + self._mean = (n_a * mean_old + n_b * tensor_mean) / self._n + assert self._mean.dtype == "float64" + + def finalize(self) -> Dict[DatasetMean, MeasureValue]: + if self._mean is None: + return {} + else: + return {self._dataset_mean: self._mean} + + +class MeanVarStdCalculator: + """to calculate sample and dataset mean, variance or 
standard deviation""" + + def __init__(self, member_id: MemberId, axes: Optional[Sequence[AxisId]]): + super().__init__() + self._axes = None if axes is None else tuple(axes) + self._member_id = member_id + self._n: int = 0 + self._mean: Optional[Tensor] = None + self._m2: Optional[Tensor] = None + + def compute( + self, sample: Sample + ) -> Dict[Union[SampleMean, SampleVar, SampleStd], MeasureValue]: + tensor = sample.members[self._member_id] + mean = tensor.mean(dim=self._axes) + c = (tensor - mean).data + if self._axes is None: + n = tensor.size + else: + n = int(np.prod([tensor.sizes[d] for d in self._axes])) + + var = xr.dot(c, c, dims=self._axes) / n + assert isinstance(var, xr.DataArray) + std = np.sqrt(var) + assert isinstance(std, xr.DataArray) + return { + SampleMean(axes=self._axes, member_id=self._member_id): mean, + SampleVar(axes=self._axes, member_id=self._member_id): Tensor.from_xarray( + var + ), + SampleStd(axes=self._axes, member_id=self._member_id): Tensor.from_xarray( + std + ), + } + + def update(self, sample: Sample): + tensor = sample.members[self._member_id].astype("float64", copy=False) + mean_b = tensor.mean(dim=self._axes) + assert mean_b.dtype == "float64" + # reduced voxel count + n_b = int(tensor.size / mean_b.size) + m2_b = ((tensor - mean_b) ** 2).sum(dim=self._axes) + assert m2_b.dtype == "float64" + if self._mean is None: + assert self._m2 is None + self._n = n_b + self._mean = mean_b + self._m2 = m2_b + else: + n_a = self._n + mean_a = self._mean + m2_a = self._m2 + self._n = n = n_a + n_b + self._mean = (n_a * mean_a + n_b * mean_b) / n + assert self._mean.dtype == "float64" + d = mean_b - mean_a + self._m2 = m2_a + m2_b + d**2 * n_a * n_b / n + assert self._m2.dtype == "float64" + + def finalize( + self, + ) -> Dict[Union[DatasetMean, DatasetVar, DatasetStd], MeasureValue]: + if self._mean is None: + return {} + else: + assert self._m2 is not None + var = self._m2 / self._n + sqrt = np.sqrt(var) + if isinstance(sqrt, (int, 
float)): + # var and mean are scalar tensors, let's keep it consistent + sqrt = Tensor.from_xarray(xr.DataArray(sqrt)) + + assert isinstance(sqrt, Tensor), type(sqrt) + return { + DatasetMean(member_id=self._member_id, axes=self._axes): self._mean, + DatasetVar(member_id=self._member_id, axes=self._axes): var, + DatasetStd(member_id=self._member_id, axes=self._axes): sqrt, + } + + +class SamplePercentilesCalculator: + """to calculate sample percentiles""" + + def __init__( + self, + member_id: MemberId, + axes: Optional[Sequence[AxisId]], + qs: Collection[float], + ): + super().__init__() + assert all(0.0 <= q <= 1.0 for q in qs) + self._qs = sorted(set(qs)) + self._axes = None if axes is None else tuple(axes) + self._member_id = member_id + + def compute(self, sample: Sample) -> Dict[SampleQuantile, MeasureValue]: + tensor = sample.members[self._member_id] + ps = tensor.quantile(self._qs, dim=self._axes) + return { + SampleQuantile(q=q, axes=self._axes, member_id=self._member_id): p + for q, p in zip(self._qs, ps) + } + + +class MeanPercentilesCalculator: + """to calculate dataset percentiles heuristically by averaging across samples + **note**: the returned dataset percentiles are an estimate and **not mathematically correct** + """ + + def __init__( + self, + member_id: MemberId, + axes: Optional[Sequence[AxisId]], + qs: Collection[float], + ): + super().__init__() + assert all(0.0 <= q <= 1.0 for q in qs) + self._qs = sorted(set(qs)) + self._axes = None if axes is None else tuple(axes) + self._member_id = member_id + self._n: int = 0 + self._estimates: Optional[Tensor] = None + + def update(self, sample: Sample): + tensor = sample.members[self._member_id] + sample_estimates = tensor.quantile(self._qs, dim=self._axes).astype( + "float64", copy=False + ) + + # reduced voxel count + n = int(tensor.size / np.prod(sample_estimates.shape_tuple[1:])) + + if self._estimates is None: + assert self._n == 0 + self._estimates = sample_estimates + else: + self._estimates = 
(self._n * self._estimates + n * sample_estimates) / ( + self._n + n + ) + assert self._estimates.dtype == "float64" + + self._n += n + + def finalize(self) -> Dict[DatasetPercentile, MeasureValue]: + if self._estimates is None: + return {} + else: + warnings.warn( + "Computed dataset percentiles naively by averaging percentiles of samples." + ) + return { + DatasetPercentile(q=q, axes=self._axes, member_id=self._member_id): e + for q, e in zip(self._qs, self._estimates) + } + + +class CrickPercentilesCalculator: + """to calculate dataset percentiles with the experimental [crick library](https://github.com/dask/crick)""" + + def __init__( + self, + member_id: MemberId, + axes: Optional[Sequence[AxisId]], + qs: Collection[float], + ): + warnings.warn( + "Computing dataset percentiles with experimental 'crick' library." + ) + super().__init__() + assert all(0.0 <= q <= 1.0 for q in qs) + assert axes is None or "_percentiles" not in axes + self._qs = sorted(set(qs)) + self._axes = None if axes is None else tuple(axes) + self._member_id = member_id + self._digest: Optional[List[TDigest]] = None + self._dims: Optional[Tuple[AxisId, ...]] = None + self._indices: Optional[Iterator[Tuple[int, ...]]] = None + self._shape: Optional[Tuple[int, ...]] = None + + def _initialize(self, tensor_sizes: PerAxis[int]): + assert crick is not None + out_sizes: OrderedDict[AxisId, int] = collections.OrderedDict( + _percentiles=len(self._qs) + ) + if self._axes is not None: + for d, s in tensor_sizes.items(): + if d not in self._axes: + out_sizes[d] = s + + self._dims, self._shape = zip(*out_sizes.items()) + d = int(np.prod(self._shape[1:])) # type: ignore + self._digest = [TDigest() for _ in range(d)] + self._indices = product(*map(range, self._shape[1:])) + + def update(self, part: Sample): + tensor = ( + part.members[self._member_id] + if isinstance(part, Sample) + else part.members[self._member_id].data + ) + assert "_percentiles" not in tensor.dims + if self._digest is None: + 
self._initialize(tensor.tagged_shape) + + assert self._digest is not None + assert self._shape is not None # index iterator is rebuilt per update + assert self._dims is not None + for i, idx in enumerate(product(*map(range, self._shape[1:]))): + self._digest[i].update(tensor[dict(zip(self._dims[1:], idx))]) + + def finalize(self) -> Dict[DatasetPercentile, MeasureValue]: + if self._digest is None: + return {} + else: + assert self._dims is not None + assert self._shape is not None + + vs: NDArray[Any] = np.asarray( + [[d.quantile(q) for d in self._digest] for q in self._qs] + ).reshape(self._shape) + return { + DatasetPercentile( + q=q, axes=self._axes, member_id=self._member_id + ): Tensor(v, dims=self._dims[1:]) + for q, v in zip(self._qs, vs) + } + + +if crick is None: + DatasetPercentilesCalculator: Type[ + Union[MeanPercentilesCalculator, CrickPercentilesCalculator] + ] = MeanPercentilesCalculator +else: + DatasetPercentilesCalculator = CrickPercentilesCalculator + + +class NaiveSampleMeasureCalculator: + """wrapper for measures to match interface of other sample measure calculators""" + + def __init__(self, member_id: MemberId, measure: SampleMeasure): + super().__init__() + self.tensor_name = member_id + self.measure = measure + + def compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]: + return {self.measure: self.measure.compute(sample)} + + +SampleMeasureCalculator = Union[ + MeanCalculator, + MeanVarStdCalculator, + SamplePercentilesCalculator, + NaiveSampleMeasureCalculator, +] +DatasetMeasureCalculator = Union[ + MeanCalculator, MeanVarStdCalculator, DatasetPercentilesCalculator +] + + +class StatsCalculator: + """Estimates dataset statistics and computes sample statistics efficiently""" + + def __init__( + self, + measures: Collection[Measure], + initial_dataset_measures: Optional[ + Mapping[DatasetMeasure, MeasureValue] + ] = None, + ): + super().__init__() + self.sample_count = 0 + self.sample_calculators, self.dataset_calculators = get_measure_calculators( + measures + ) + if 
initial_dataset_measures is None: + self._current_dataset_measures: Optional[ + Dict[DatasetMeasure, MeasureValue] + ] = None + else: + missing_dataset_meas = { + m + for m in measures + if isinstance(m, DatasetMeasureBase) + and m not in initial_dataset_measures + } + if missing_dataset_meas: + warnings.warn( + f"ignoring `initial_dataset_measure` as it is missing {missing_dataset_meas}" + ) + self._current_dataset_measures = None + else: + self._current_dataset_measures = dict(initial_dataset_measures) + + @property + def has_dataset_measures(self): + return self._current_dataset_measures is not None + + def update( + self, + sample: Union[Sample, Iterable[Sample]], + ) -> None: + _ = self._update(sample) + + def finalize(self) -> Dict[DatasetMeasure, MeasureValue]: + """returns aggregated dataset statistics""" + if self._current_dataset_measures is None: + self._current_dataset_measures = {} + for calc in self.dataset_calculators: + values = calc.finalize() + self._current_dataset_measures.update(values.items()) + + return self._current_dataset_measures + + def update_and_get_all( + self, + sample: Union[Sample, Iterable[Sample]], + ) -> Dict[Measure, MeasureValue]: + """Returns sample as well as updated dataset statistics""" + last_sample = self._update(sample) + if last_sample is None: + raise ValueError("`sample` was not a `Sample`, nor did it yield any.") + + return {**self._compute(last_sample), **self.finalize()} + + def skip_update_and_get_all(self, sample: Sample) -> Dict[Measure, MeasureValue]: + """Returns sample as well as previously computed dataset statistics""" + return {**self._compute(sample), **self.finalize()} + + def _compute(self, sample: Sample) -> Dict[SampleMeasure, MeasureValue]: + ret: Dict[SampleMeasure, MeasureValue] = {} + for calc in self.sample_calculators: + values = calc.compute(sample) + ret.update(values.items()) + + return ret + + def _update(self, sample: Union[Sample, Iterable[Sample]]) -> Optional[Sample]: + 
self.sample_count += 1 + samples = [sample] if isinstance(sample, Sample) else sample + last_sample = None + for el in samples: + last_sample = el + for calc in self.dataset_calculators: + calc.update(el) + + self._current_dataset_measures = None + return last_sample + + +def get_measure_calculators( + required_measures: Iterable[Measure], +) -> Tuple[List[SampleMeasureCalculator], List[DatasetMeasureCalculator]]: + """determines which calculators are needed to compute the required measures efficiently""" + + sample_calculators: List[SampleMeasureCalculator] = [] + dataset_calculators: List[DatasetMeasureCalculator] = [] + + # split required measures into groups + required_sample_means: Set[SampleMean] = set() + required_dataset_means: Set[DatasetMean] = set() + required_sample_mean_var_std: Set[Union[SampleMean, SampleVar, SampleStd]] = set() + required_dataset_mean_var_std: Set[Union[DatasetMean, DatasetVar, DatasetStd]] = ( + set() + ) + required_sample_percentiles: Dict[ + Tuple[MemberId, Optional[Tuple[AxisId, ...]]], Set[float] + ] = {} + required_dataset_percentiles: Dict[ + Tuple[MemberId, Optional[Tuple[AxisId, ...]]], Set[float] + ] = {} + + for rm in required_measures: + if isinstance(rm, SampleMean): + required_sample_means.add(rm) + elif isinstance(rm, DatasetMean): + required_dataset_means.add(rm) + elif isinstance(rm, (SampleVar, SampleStd)): + required_sample_mean_var_std.update( + { + msv(axes=rm.axes, member_id=rm.member_id) + for msv in (SampleMean, SampleStd, SampleVar) + } + ) + assert rm in required_sample_mean_var_std + elif isinstance(rm, (DatasetVar, DatasetStd)): + required_dataset_mean_var_std.update( + { + msv(axes=rm.axes, member_id=rm.member_id) + for msv in (DatasetMean, DatasetStd, DatasetVar) + } + ) + assert rm in required_dataset_mean_var_std + elif isinstance(rm, SampleQuantile): + required_sample_percentiles.setdefault((rm.member_id, rm.axes), set()).add( + rm.q + ) + elif isinstance(rm, DatasetPercentile): + 
required_dataset_percentiles.setdefault((rm.member_id, rm.axes), set()).add( + rm.q + ) + else: + assert_never(rm) + + for rm in required_sample_means: + if rm in required_sample_mean_var_std: + # computed together with var and std + continue + + sample_calculators.append(MeanCalculator(member_id=rm.member_id, axes=rm.axes)) + + for rm in required_sample_mean_var_std: + sample_calculators.append( + MeanVarStdCalculator(member_id=rm.member_id, axes=rm.axes) + ) + + for rm in required_dataset_means: + if rm in required_dataset_mean_var_std: + # computed together with var and std + continue + + dataset_calculators.append(MeanCalculator(member_id=rm.member_id, axes=rm.axes)) + + for rm in required_dataset_mean_var_std: + dataset_calculators.append( + MeanVarStdCalculator(member_id=rm.member_id, axes=rm.axes) + ) + + for (tid, axes), qs in required_sample_percentiles.items(): + sample_calculators.append( + SamplePercentilesCalculator(member_id=tid, axes=axes, qs=qs) + ) + + for (tid, axes), qs in required_dataset_percentiles.items(): + dataset_calculators.append( + DatasetPercentilesCalculator(member_id=tid, axes=axes, qs=qs) + ) + + return sample_calculators, dataset_calculators + + +def compute_dataset_measures( + measures: Iterable[DatasetMeasure], dataset: Iterable[Sample] +) -> Dict[DatasetMeasure, MeasureValue]: + """compute all dataset `measures` for the given `dataset`""" + sample_calculators, calculators = get_measure_calculators(measures) + assert not sample_calculators + + ret: Dict[DatasetMeasure, MeasureValue] = {} + + for sample in dataset: + for calc in calculators: + calc.update(sample) + + for calc in calculators: + ret.update(calc.finalize().items()) + + return ret + + +def compute_sample_measures( + measures: Iterable[SampleMeasure], sample: Sample +) -> Dict[SampleMeasure, MeasureValue]: + """compute all sample `measures` for the given `sample`""" + calculators, dataset_calculators = get_measure_calculators(measures) + assert not dataset_calculators 
+ ret: Dict[SampleMeasure, MeasureValue] = {} + + for calc in calculators: + ret.update(calc.compute(sample).items()) + + return ret + + +def compute_measures( + measures: Iterable[Measure], dataset: Iterable[Sample] +) -> Dict[Measure, MeasureValue]: + """compute all `measures` for the given `dataset` + sample measures are computed for the last sample in `dataset`""" + sample_calculators, dataset_calculators = get_measure_calculators(measures) + ret: Dict[Measure, MeasureValue] = {} + sample = None + for sample in dataset: + for calc in dataset_calculators: + calc.update(sample) + if sample is None: + raise ValueError("empty dataset") + + for calc in dataset_calculators: + ret.update(calc.finalize().items()) + + for calc in sample_calculators: + ret.update(calc.compute(sample).items()) + + return ret diff --git a/bioimageio/core/stat_measures.py b/bioimageio/core/stat_measures.py new file mode 100644 index 00000000..e581916f --- /dev/null +++ b/bioimageio/core/stat_measures.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Dict, Optional, Protocol, Tuple, TypeVar, Union + +from .axis import AxisId +from .common import MemberId, PerMember +from .tensor import Tensor + +MeasureValue = Union[float, Tensor] + + +# using Sample Protocol really only to avoid circular imports +class SampleLike(Protocol): + @property + def members(self) -> PerMember[Tensor]: ... + + +@dataclass(frozen=True) +class MeasureBase: + member_id: MemberId + + +@dataclass(frozen=True) +class SampleMeasureBase(MeasureBase, ABC): + @abstractmethod + def compute(self, sample: SampleLike) -> MeasureValue: + """compute the measure""" + ... 
+ + +@dataclass(frozen=True) +class DatasetMeasureBase(MeasureBase, ABC): + pass + + +@dataclass(frozen=True) +class _Mean: + axes: Optional[Tuple[AxisId, ...]] = None + """`axes` to reduce""" + + +@dataclass(frozen=True) +class SampleMean(_Mean, SampleMeasureBase): + """The mean value of a single tensor""" + + def compute(self, sample: SampleLike) -> MeasureValue: + tensor = sample.members[self.member_id] + return tensor.mean(dim=self.axes) + + def __post_init__(self): + assert self.axes is None or AxisId("batch") not in self.axes + + +@dataclass(frozen=True) +class DatasetMean(_Mean, DatasetMeasureBase): + """The mean value across multiple samples""" + + def __post_init__(self): + assert self.axes is None or AxisId("batch") in self.axes + + +@dataclass(frozen=True) +class _Std: + axes: Optional[Tuple[AxisId, ...]] = None + """`axes` to reduce""" + + +@dataclass(frozen=True) +class SampleStd(_Std, SampleMeasureBase): + """The standard deviation of a single tensor""" + + def compute(self, sample: SampleLike) -> MeasureValue: + tensor = sample.members[self.member_id] + return tensor.std(dim=self.axes) + + def __post_init__(self): + assert self.axes is None or AxisId("batch") not in self.axes + + +@dataclass(frozen=True) +class DatasetStd(_Std, DatasetMeasureBase): + """The standard deviation across multiple samples""" + + def __post_init__(self): + assert self.axes is None or AxisId("batch") in self.axes + + +@dataclass(frozen=True) +class _Var: + axes: Optional[Tuple[AxisId, ...]] = None + """`axes` to reduce""" + + +@dataclass(frozen=True) +class SampleVar(_Var, SampleMeasureBase): + """The variance of a single tensor""" + + def compute(self, sample: SampleLike) -> MeasureValue: + tensor = sample.members[self.member_id] + return tensor.var(dim=self.axes) + + def __post_init__(self): + assert self.axes is None or AxisId("batch") not in self.axes + + +@dataclass(frozen=True) +class DatasetVar(_Var, DatasetMeasureBase): + """The variance across multiple samples""" + 
+ def __post_init__(self): + assert self.axes is None or AxisId("batch") in self.axes + + +@dataclass(frozen=True) +class _Quantile: + q: float + axes: Optional[Tuple[AxisId, ...]] = None + """`axes` to reduce""" + + def __post_init__(self): + assert self.q >= 0.0 + assert self.q <= 1.0 + + +@dataclass(frozen=True) +class SampleQuantile(_Quantile, SampleMeasureBase): + """The `n`th percentile of a single tensor""" + + def compute(self, sample: SampleLike) -> MeasureValue: + tensor = sample.members[self.member_id] + return tensor.quantile(self.q, dim=self.axes) + + def __post_init__(self): + super().__post_init__() + assert self.axes is None or AxisId("batch") not in self.axes + + +@dataclass(frozen=True) +class DatasetPercentile(_Quantile, DatasetMeasureBase): + """The `n`th percentile across multiple samples""" + + def __post_init__(self): + super().__post_init__() + assert self.axes is None or AxisId("batch") in self.axes + + +SampleMeasure = Union[SampleMean, SampleStd, SampleVar, SampleQuantile] +DatasetMeasure = Union[DatasetMean, DatasetStd, DatasetVar, DatasetPercentile] +Measure = Union[SampleMeasure, DatasetMeasure] +Stat = Dict[Measure, MeasureValue] + +MeanMeasure = Union[SampleMean, DatasetMean] +StdMeasure = Union[SampleStd, DatasetStd] +VarMeasure = Union[SampleVar, DatasetVar] +PercentileMeasure = Union[SampleQuantile, DatasetPercentile] +MeanMeasureT = TypeVar("MeanMeasureT", bound=MeanMeasure) +StdMeasureT = TypeVar("StdMeasureT", bound=StdMeasure) +VarMeasureT = TypeVar("VarMeasureT", bound=VarMeasure) +PercentileMeasureT = TypeVar("PercentileMeasureT", bound=PercentileMeasure) diff --git a/bioimageio/core/statistical_measures.py b/bioimageio/core/statistical_measures.py deleted file mode 100644 index 0a3df99b..00000000 --- a/bioimageio/core/statistical_measures.py +++ /dev/null @@ -1,52 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -from typing import Optional, Tuple - -import xarray as xr - -MeasureValue = 
xr.DataArray - - -@dataclass(frozen=True) -class Measure: - def compute(self, tensor: xr.DataArray) -> MeasureValue: - """compute the measure (and also associated other Measures)""" - raise NotImplementedError(self.__class__.__name__) - - -@dataclass(frozen=True) -class Mean(Measure): - axes: Optional[Tuple[str, ...]] = None - - def compute(self, tensor: xr.DataArray) -> xr.DataArray: - return tensor.mean(dim=self.axes) - - -@dataclass(frozen=True) -class Std(Measure): - axes: Optional[Tuple[str, ...]] = None - - def compute(self, tensor: xr.DataArray) -> xr.DataArray: - return tensor.std(dim=self.axes) - - -@dataclass(frozen=True) -class Var(Measure): - axes: Optional[Tuple[str, ...]] = None - - def compute(self, tensor: xr.DataArray) -> xr.DataArray: - return tensor.var(dim=self.axes) - - -@dataclass(frozen=True) -class Percentile(Measure): - n: float - axes: Optional[Tuple[str, ...]] = None - - def __post_init__(self): - assert self.n >= 0 - assert self.n <= 100 - - def compute(self, tensor: xr.DataArray) -> xr.DataArray: - return tensor.quantile(self.n / 100.0, dim=self.axes) diff --git a/bioimageio/core/tensor.py b/bioimageio/core/tensor.py new file mode 100644 index 00000000..c93bd31a --- /dev/null +++ b/bioimageio/core/tensor.py @@ -0,0 +1,492 @@ +from __future__ import annotations + +import collections.abc +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Iterator, + Mapping, + Optional, + Sequence, + Tuple, + Union, + cast, + get_args, +) + +import numpy as np +import xarray as xr +from loguru import logger +from numpy.typing import DTypeLike, NDArray +from typing_extensions import Self, assert_never + +from bioimageio.spec.model import v0_5 + +from ._magic_tensor_ops import MagicTensorOpsMixin +from .axis import Axis, AxisId, AxisInfo, AxisLike, PerAxis +from .common import ( + CropWhere, + DTypeStr, + PadMode, + PadWhere, + PadWidth, + PadWidthLike, + SliceInfo, +) + +if TYPE_CHECKING: + from numpy.typing import ArrayLike, NDArray + + 
+_ScalarOrArray = Union["ArrayLike", np.generic, "NDArray[Any]"] # TODO: add "DaskArray" + + +# TODO: complete docstrings +class Tensor(MagicTensorOpsMixin): + """A wrapper around an xr.DataArray for better integration with bioimageio.spec + and improved type annotations.""" + + _Compatible = Union["Tensor", xr.DataArray, _ScalarOrArray] + + def __init__( + self, + array: NDArray[Any], + dims: Sequence[AxisId], + ) -> None: + super().__init__() + if any(not isinstance(d, AxisId) for d in dims): + raise TypeError( + f"Expected sequence of `AxisId`, but got {list(map(type, dims))}" + ) + + self._data = xr.DataArray(array, dims=dims) + + def __array__(self, dtype: DTypeLike = None): + return np.asarray(self._data, dtype=dtype) + + def __getitem__( + self, key: Union[SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]] + ) -> Self: + if isinstance(key, SliceInfo): + key = slice(*key) + elif isinstance(key, collections.abc.Mapping): + key = { + a: s if isinstance(s, int) else s if isinstance(s, slice) else slice(*s) + for a, s in key.items() + } + return self.__class__.from_xarray(self._data[key]) + + def __setitem__(self, key: PerAxis[Union[SliceInfo, slice]], value: Tensor) -> None: + key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()} + self._data[key] = value._data + + def __len__(self) -> int: + return len(self.data) + + def _iter(self: Any) -> Iterator[Any]: + for n in range(len(self)): + yield self[n] + + def __iter__(self: Any) -> Iterator[Any]: + if self.ndim == 0: + raise TypeError("iteration over a 0-d array") + return self._iter() + + def _binary_op( + self, + other: _Compatible, + f: Callable[[Any, Any], Any], + reflexive: bool = False, + ) -> Self: + data = self._data._binary_op( # pyright: ignore[reportPrivateUsage] + (other._data if isinstance(other, Tensor) else other), + f, + reflexive, + ) + return self.__class__.from_xarray(data) + + def _inplace_binary_op( + self, + other: _Compatible, + f: Callable[[Any, Any], Any], 
+ ) -> Self: + _ = self._data._inplace_binary_op( # pyright: ignore[reportPrivateUsage] + ( + other_d + if (other_d := getattr(other, "data", None)) is not None + and isinstance( + other_d, + xr.DataArray, + ) + else other + ), + f, + ) + return self + + def _unary_op(self, f: Callable[[Any], Any], *args: Any, **kwargs: Any) -> Self: + data = self._data._unary_op( # pyright: ignore[reportPrivateUsage] + f, *args, **kwargs + ) + return self.__class__.from_xarray(data) + + @classmethod + def from_xarray(cls, data_array: xr.DataArray) -> Self: + """create a `Tensor` from an xarray data array + + note for internal use: this factory method is round-trip safe + for any `Tensor`'s `data` property (an xarray.DataArray). + """ + return cls( + array=data_array.data, dims=tuple(AxisId(d) for d in data_array.dims) + ) + + @classmethod + def from_numpy( + cls, + array: NDArray[Any], + *, + dims: Optional[Union[AxisLike, Sequence[AxisLike]]], + ) -> Tensor: + """create a `Tensor` from a numpy array + + Args: + array: the nd numpy array + dims: A description of the array's axes, + if None axes are guessed (which might fail and raise a ValueError.) + + Raises: + ValueError: if `dims` is None and axes guessing fails. 
+ """ + + if dims is None: + return cls._interprete_array_wo_known_axes(array) + elif isinstance(dims, (str, Axis, v0_5.AxisBase)): + dims = [dims] + + axis_infos = [AxisInfo.create(a) for a in dims] + original_shape = tuple(array.shape) + if len(array.shape) > len(dims): + # remove singletons; shift `i` by the number of axes already dropped + for i, s in enumerate(array.shape): + if s == 1: + array = np.take(array, 0, axis=i - len(original_shape) + array.ndim) + if len(array.shape) == len(dims): + break + + # add singletons if necessary + for a in axis_infos: + + if len(array.shape) >= len(dims): + break + + if a.maybe_singleton: + array = array[None] + + if len(array.shape) != len(dims): + raise ValueError( + f"Array shape {original_shape} does not map to axes {dims}" + ) + + return Tensor(array, dims=tuple(a.id for a in axis_infos)) + + @property + def data(self): + return self._data + + @property + def dims(self): # TODO: rename to `axes`? + """Tuple of dimension names associated with this tensor.""" + return cast(Tuple[AxisId, ...], self._data.dims) + + @property + def tagged_shape(self): + """(alias for `sizes`) Ordered, immutable mapping from axis ids to lengths.""" + return self.sizes + + @property + def shape_tuple(self): + """Tuple of tensor axes lengths""" + return self._data.shape + + @property + def size(self): + """Number of elements in the tensor. + + Equal to math.prod(tensor.shape), i.e., the product of the tensors’ dimensions. 
+ """ + return self._data.size + + def sum(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self: + """Reduce this Tensor's data by applying sum along some dimension(s).""" + return self.__class__.from_xarray(self._data.sum(dim=dim)) + + @property + def ndim(self): + """Number of tensor dimensions.""" + return self._data.ndim + + @property + def dtype(self) -> DTypeStr: + dt = str(self.data.dtype) # pyright: ignore[reportUnknownArgumentType] + assert dt in get_args(DTypeStr) + return dt # pyright: ignore[reportReturnType] + + @property + def sizes(self): + """Ordered, immutable mapping from axis ids to axis lengths.""" + return cast(Mapping[AxisId, int], self.data.sizes) + + def astype(self, dtype: DTypeStr, *, copy: bool = False): + """Return tensor cast to `dtype` + + note: if dtype is already satisfied copy if `copy`""" + return self.__class__.from_xarray(self._data.astype(dtype, copy=copy)) + + def clip(self, min: Optional[float] = None, max: Optional[float] = None): + """Return a tensor whose values are limited to [min, max]. 
+ At least one of max or min must be given.""" + return self.__class__.from_xarray(self._data.clip(min, max)) + + def crop_to( + self, + sizes: PerAxis[int], + crop_where: Union[ + CropWhere, + PerAxis[CropWhere], + ] = "left_and_right", + ) -> Self: + """crop to match `sizes`""" + if isinstance(crop_where, str): + crop_axis_where: PerAxis[CropWhere] = {a: crop_where for a in self.dims} + else: + crop_axis_where = crop_where + + slices: Dict[AxisId, SliceInfo] = {} + + for a, s_is in self.sizes.items(): + if a not in sizes or sizes[a] == s_is: + pass + elif sizes[a] > s_is: + logger.warning( + "Cannot crop axis {} of size {} to larger size {}", + a, + s_is, + sizes[a], + ) + elif a not in crop_axis_where: + raise ValueError( + f"Don't know where to crop axis {a}, `crop_where`={crop_where}" + ) + else: + crop_this_axis_where = crop_axis_where[a] + if crop_this_axis_where == "left": + slices[a] = SliceInfo(s_is - sizes[a], s_is) + elif crop_this_axis_where == "right": + slices[a] = SliceInfo(0, sizes[a]) + elif crop_this_axis_where == "left_and_right": + slices[a] = SliceInfo( + start := (s_is - sizes[a]) // 2, sizes[a] + start + ) + else: + assert_never(crop_this_axis_where) + + return self[slices] + + def expand_dims(self, dims: Union[Sequence[AxisId], PerAxis[int]]) -> Self: + return self.__class__.from_xarray(self._data.expand_dims(dims=dims)) + + def mean(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self: + return self.__class__.from_xarray(self._data.mean(dim=dim)) + + def std(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self: + return self.__class__.from_xarray(self._data.std(dim=dim)) + + def var(self, dim: Optional[Union[AxisId, Sequence[AxisId]]] = None) -> Self: + return self.__class__.from_xarray(self._data.var(dim=dim)) + + def pad( + self, + pad_width: PerAxis[PadWidthLike], + mode: PadMode = "symmetric", + ) -> Self: + pad_width = {a: PadWidth.create(p) for a, p in pad_width.items()} + return 
self.__class__.from_xarray( + self._data.pad(pad_width=pad_width, mode=mode) + ) + + def pad_to( + self, + sizes: PerAxis[int], + pad_where: Union[PadWhere, PerAxis[PadWhere]] = "left_and_right", + mode: PadMode = "symmetric", + ) -> Self: + """pad `tensor` to match `sizes`""" + if isinstance(pad_where, str): + pad_axis_where: PerAxis[PadWhere] = {a: pad_where for a in self.dims} + else: + pad_axis_where = pad_where + + pad_width: Dict[AxisId, PadWidth] = {} + for a, s_is in self.sizes.items(): + if a not in sizes or sizes[a] == s_is: + pad_width[a] = PadWidth(0, 0) + elif s_is > sizes[a]: + pad_width[a] = PadWidth(0, 0) + logger.warning( + "Cannot pad axis {} of size {} to smaller size {}", + a, + s_is, + sizes[a], + ) + elif a not in pad_axis_where: + raise ValueError( + f"Don't know where to pad axis {a}, `pad_where`={pad_where}" + ) + else: + pad_this_axis_where = pad_axis_where[a] + d = sizes[a] - s_is + if pad_this_axis_where == "left": + pad_width[a] = PadWidth(d, 0) + elif pad_this_axis_where == "right": + pad_width[a] = PadWidth(0, d) + elif pad_this_axis_where == "left_and_right": + pad_width[a] = PadWidth(left := d // 2, d - left) + else: + assert_never(pad_this_axis_where) + + return self.pad(pad_width, mode) + + def quantile( + self, + q: Union[float, Sequence[float]], + dim: Optional[Union[AxisId, Sequence[AxisId]]] = None, + ) -> Self: + assert ( + isinstance(q, (float, int)) + and q >= 0.0 + or not isinstance(q, (float, int)) + and all(qq >= 0.0 for qq in q) + ) + assert ( + isinstance(q, (float, int)) + and q <= 1.0 + or not isinstance(q, (float, int)) + and all(qq <= 1.0 for qq in q) + ) + assert dim is None or ( + (quantile_dim := AxisId("quantile")) != dim and quantile_dim not in set(dim) + ) + return self.__class__.from_xarray(self._data.quantile(q, dim=dim)) + + def resize_to( + self, + sizes: PerAxis[int], + *, + pad_where: Union[ + PadWhere, + PerAxis[PadWhere], + ] = "left_and_right", + crop_where: Union[ + CropWhere, + PerAxis[CropWhere], 
+ ] = "left_and_right", + pad_mode: PadMode = "symmetric", + ): + """return cropped/padded tensor with `sizes`""" + crop_to_sizes: Dict[AxisId, int] = {} + pad_to_sizes: Dict[AxisId, int] = {} + new_axes = dict(sizes) + for a, s_is in self.sizes.items(): + a = AxisId(str(a)) + _ = new_axes.pop(a, None) + if a not in sizes or sizes[a] == s_is: + pass + elif s_is > sizes[a]: + crop_to_sizes[a] = sizes[a] + else: + pad_to_sizes[a] = sizes[a] + + tensor = self + if crop_to_sizes: + tensor = tensor.crop_to(crop_to_sizes, crop_where=crop_where) + + if pad_to_sizes: + tensor = tensor.pad_to(pad_to_sizes, pad_where=pad_where, mode=pad_mode) + + if new_axes: + tensor = tensor.expand_dims(new_axes) + + return tensor + + def transpose( + self, + axes: Sequence[AxisId], + ) -> Self: + """return a transposed tensor + + Args: + axes: the desired tensor axes + """ + # expand missing tensor axes + missing_axes = tuple(a for a in axes if a not in self.dims) + array = self._data + if missing_axes: + array = array.expand_dims(missing_axes) + + # transpose to the correct axis order + return self.__class__.from_xarray(array.transpose(*axes)) + + @classmethod + def _interprete_array_wo_known_axes(cls, array: NDArray[Any]): + ndim = array.ndim + if ndim == 2: + current_axes = ( + v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[0]), + v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[1]), + ) + elif ndim == 3 and any(s <= 3 for s in array.shape): + current_axes = ( + v0_5.ChannelAxis( + channel_names=[ + v0_5.Identifier(f"channel{i}") for i in range(array.shape[0]) + ] + ), + v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]), + v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]), + ) + elif ndim == 3: + current_axes = ( + v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[0]), + v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[1]), + v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[2]), + ) + elif ndim == 4: + current_axes = ( + v0_5.ChannelAxis( + 
channel_names=[ + v0_5.Identifier(f"channel{i}") for i in range(array.shape[0]) + ] + ), + v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[1]), + v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[2]), + v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[3]), + ) + elif ndim == 5: + current_axes = ( + v0_5.BatchAxis(), + v0_5.ChannelAxis( + channel_names=[ + v0_5.Identifier(f"channel{i}") for i in range(array.shape[1]) + ] + ), + v0_5.SpaceInputAxis(id=AxisId("z"), size=array.shape[2]), + v0_5.SpaceInputAxis(id=AxisId("y"), size=array.shape[3]), + v0_5.SpaceInputAxis(id=AxisId("x"), size=array.shape[4]), + ) + else: + raise ValueError(f"Could not guess an axis mapping for {array.shape}") + + return cls(array, dims=tuple(a.id for a in current_axes)) diff --git a/bioimageio/core/utils/__init__.py b/bioimageio/core/utils/__init__.py new file mode 100644 index 00000000..84e94d38 --- /dev/null +++ b/bioimageio/core/utils/__init__.py @@ -0,0 +1,17 @@ +import json +import sys +from pathlib import Path + +if sys.version_info < (3, 9): + + def files(package_name: str): + assert package_name == "bioimageio.core" + return Path(__file__).parent.parent + +else: + from importlib.resources import files as files + + +with files("bioimageio.core").joinpath("VERSION").open("r", encoding="utf-8") as f: + VERSION = json.load(f)["version"] + assert isinstance(VERSION, str) diff --git a/bioimageio/core/utils.py b/bioimageio/core/utils/testing.py similarity index 70% rename from bioimageio/core/utils.py rename to bioimageio/core/utils/testing.py index 770f7d21..acd65d95 100644 --- a/bioimageio/core/utils.py +++ b/bioimageio/core/utils/testing.py @@ -1,23 +1,28 @@ -from functools import wraps -from typing import Type - - -def skip_on(exception: Type[Exception], reason: str): - """adapted from https://stackoverflow.com/a/63522579""" - import pytest - - # Func below is the real decorator and will receive the test function as param - def decorator_func(f): - @wraps(f) - def 
wrapper(*args, **kwargs): - try: - # Try to run the test - return f(*args, **kwargs) - except exception: - # If exception of given type happens - # just swallow it and raise pytest.Skip with given reason - pytest.skip(reason) - - return wrapper - - return decorator_func +# TODO: move to tests/ +from functools import wraps +from typing import Any, Protocol, Type + + +class test_func(Protocol): + def __call__(*args: Any, **kwargs: Any): ... + + +def skip_on(exception: Type[Exception], reason: str): + """adapted from https://stackoverflow.com/a/63522579""" + import pytest + + # Func below is the real decorator and will receive the test function as param + def decorator_func(f: test_func): + @wraps(f) + def wrapper(*args: Any, **kwargs: Any): + try: + # Try to run the test + return f(*args, **kwargs) + except exception: + # If exception of given type happens + # just swallow it and raise pytest.Skip with given reason + pytest.skip(reason) + + return wrapper + + return decorator_func diff --git a/bioimageio/core/weight_converter/__init__.py b/bioimageio/core/weight_converter/__init__.py index e69de29b..5f1674c9 100644 --- a/bioimageio/core/weight_converter/__init__.py +++ b/bioimageio/core/weight_converter/__init__.py @@ -0,0 +1 @@ +"""coming soon""" diff --git a/bioimageio/core/weight_converter/keras/__init__.py b/bioimageio/core/weight_converter/keras/__init__.py index 471713e2..195b42b8 100644 --- a/bioimageio/core/weight_converter/keras/__init__.py +++ b/bioimageio/core/weight_converter/keras/__init__.py @@ -1 +1 @@ -from .tensorflow import convert_weights_to_tensorflow_saved_model_bundle +# TODO: update keras weight converters diff --git a/bioimageio/core/weight_converter/keras/_tensorflow.py b/bioimageio/core/weight_converter/keras/_tensorflow.py new file mode 100644 index 00000000..c901f458 --- /dev/null +++ b/bioimageio/core/weight_converter/keras/_tensorflow.py @@ -0,0 +1,151 @@ +# type: ignore # TODO: type +import os +import shutil +from pathlib import Path 
+from typing import no_type_check +from zipfile import ZipFile + +try: + import tensorflow.saved_model +except Exception: + tensorflow = None + +from bioimageio.spec._internal.io_utils import download +from bioimageio.spec.model.v0_5 import ModelDescr + + +def _zip_model_bundle(model_bundle_folder: Path): + zipped_model_bundle = model_bundle_folder.with_suffix(".zip") + + with ZipFile(zipped_model_bundle, "w") as zip_obj: + for root, _, files in os.walk(model_bundle_folder): + for filename in files: + src = os.path.join(root, filename) + zip_obj.write(src, os.path.relpath(src, model_bundle_folder)) + + try: + shutil.rmtree(model_bundle_folder) + except Exception: + print("TensorFlow bundled model was not removed after compression") + + return zipped_model_bundle + + +# adapted from +# https://github.com/deepimagej/pydeepimagej/blob/master/pydeepimagej/yaml/create_config.py#L236 +def _convert_tf1( + keras_weight_path: Path, + output_path: Path, + input_name: str, + output_name: str, + zip_weights: bool, +): + try: + # try to build the tf model with the keras import from tensorflow + from bioimageio.core.weight_converter.keras._tensorflow import ( + keras, # type: ignore + ) + + except Exception: + # if the above fails try to export with the standalone keras + import keras + + @no_type_check + def build_tf_model(): + keras_model = keras.models.load_model(keras_weight_path) + assert tensorflow is not None + builder = tensorflow.saved_model.builder.SavedModelBuilder(output_path) + signature = tensorflow.saved_model.signature_def_utils.predict_signature_def( + inputs={input_name: keras_model.input}, + outputs={output_name: keras_model.output}, + ) + + signature_def_map = { + tensorflow.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature + } + + builder.add_meta_graph_and_variables( + keras.backend.get_session(), + [tensorflow.saved_model.tag_constants.SERVING], + signature_def_map=signature_def_map, + ) + builder.save() + + build_tf_model() + + 
if zip_weights: + output_path = _zip_model_bundle(output_path) + print("TensorFlow model exported to", output_path) + + return 0 + + +def _convert_tf2(keras_weight_path: Path, output_path: Path, zip_weights: bool): + try: + # try to build the tf model with the keras import from tensorflow + from bioimageio.core.weight_converter.keras._tensorflow import keras + except Exception: + # if the above fails try to export with the standalone keras + import keras + + model = keras.models.load_model(keras_weight_path) + keras.models.save_model(model, output_path) + + if zip_weights: + output_path = _zip_model_bundle(output_path) + print("TensorFlow model exported to", output_path) + + return 0 + + +def convert_weights_to_tensorflow_saved_model_bundle( + model: ModelDescr, output_path: Path +): + """Convert model weights from format 'keras_hdf5' to 'tensorflow_saved_model_bundle'. + + Adapted from + https://github.com/deepimagej/pydeepimagej/blob/5aaf0e71f9b04df591d5ca596f0af633a7e024f5/pydeepimagej/yaml/create_config.py + + Args: + model: The bioimageio model description + output_path: where to save the tensorflow weights. This path must not exist yet. 
+ """ + assert tensorflow is not None + tf_major_ver = int(tensorflow.__version__.split(".")[0]) + + if output_path.suffix == ".zip": + output_path = output_path.with_suffix("") + zip_weights = True + else: + zip_weights = False + + if output_path.exists(): + raise ValueError(f"The ouptut directory at {output_path} must not exist.") + + if model.weights.keras_hdf5 is None: + raise ValueError("Missing Keras Hdf5 weights to convert from.") + + weight_spec = model.weights.keras_hdf5 + weight_path = download(weight_spec.source).path + + if weight_spec.tensorflow_version: + model_tf_major_ver = int(weight_spec.tensorflow_version.major) + if model_tf_major_ver != tf_major_ver: + raise RuntimeError( + f"Tensorflow major versions of model {model_tf_major_ver} is not {tf_major_ver}" + ) + + if tf_major_ver == 1: + if len(model.inputs) != 1 or len(model.outputs) != 1: + raise NotImplementedError( + "Weight conversion for models with multiple inputs or outputs is not yet implemented." + ) + return _convert_tf1( + weight_path, + output_path, + model.inputs[0].id, + model.outputs[0].id, + zip_weights, + ) + else: + return _convert_tf2(weight_path, output_path, zip_weights) diff --git a/bioimageio/core/weight_converter/keras/tensorflow.py b/bioimageio/core/weight_converter/keras/tensorflow.py deleted file mode 100644 index 0656ac9e..00000000 --- a/bioimageio/core/weight_converter/keras/tensorflow.py +++ /dev/null @@ -1,130 +0,0 @@ -import os -import shutil -from pathlib import Path -from typing import Union -from zipfile import ZipFile - -import bioimageio.spec as spec -from bioimageio.core import load_resource_description - -import tensorflow -from tensorflow import saved_model - - -def _zip_weights(output_path): - zipped_model = f"{output_path}.zip" - # zip the weights - file_paths = [] - for folder_names, subfolder, filenames in os.walk(os.path.join(output_path)): - for filename in filenames: - # create complete filepath of file in directory - 
file_paths.append(os.path.join(folder_names, filename)) - - with ZipFile(zipped_model, "w") as zip_obj: - for f in file_paths: - # Add file to zip - zip_obj.write(f, os.path.relpath(f, output_path)) - - try: - shutil.rmtree(output_path) - except Exception: - print("TensorFlow bundled model was not removed after compression") - - return zipped_model - - -# adapted from -# https://github.com/deepimagej/pydeepimagej/blob/master/pydeepimagej/yaml/create_config.py#L236 -def _convert_tf1(keras_weight_path, output_path, input_name, output_name, zip_weights): - def build_tf_model(): - keras_model = keras.models.load_model(keras_weight_path) - - builder = saved_model.builder.SavedModelBuilder(output_path) - signature = saved_model.signature_def_utils.predict_signature_def( - inputs={input_name: keras_model.input}, outputs={output_name: keras_model.output} - ) - - signature_def_map = {saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature} - - builder.add_meta_graph_and_variables( - keras.backend.get_session(), [saved_model.tag_constants.SERVING], signature_def_map=signature_def_map - ) - builder.save() - - try: - # try to build the tf model with the keras import from tensorflow - from tensorflow import keras - build_tf_model() - except Exception: - # if the above fails try to export with the standalone keras - import keras - - build_tf_model() - - if zip_weights: - output_path = _zip_weights(output_path) - print("TensorFlow model exported to", output_path) - - return 0 - - -def _convert_tf2(keras_weight_path, output_path, zip_weights): - try: - # try to build the tf model with the keras import from tensorflow - from tensorflow import keras - except Exception: - # if the above fails try to export with the standalone keras - import keras - - model = keras.models.load_model(keras_weight_path) - keras.models.save_model(model, output_path) - - if zip_weights: - output_path = _zip_weights(output_path) - print("TensorFlow model exported to", output_path) - - 
return 0 - - -def convert_weights_to_tensorflow_saved_model_bundle( - model_spec: Union[str, Path, spec.model.raw_nodes.Model], output_path: Union[str, Path] -): - """Convert model weights from format 'keras_hdf5' to 'tensorflow_saved_model_bundle'. - - Adapted from - https://github.com/deepimagej/pydeepimagej/blob/5aaf0e71f9b04df591d5ca596f0af633a7e024f5/pydeepimagej/yaml/create_config.py - - Args: - model_spec: location of the resource for the input bioimageio model - output_path: where to save the tensorflow weights. This path must not exist yet. - """ - tf_major_ver = int(tensorflow.__version__.split(".")[0]) - - path_ = Path(output_path) - if path_.suffix == ".zip": - path_ = Path(os.path.splitext(path_)[0]) - zip_weights = True - else: - zip_weights = False - - if path_.exists(): - raise ValueError(f"The ouptut directory at {path_} must not exist.") - - model = load_resource_description(model_spec) - assert "keras_hdf5" in model.weights - weight_spec = model.weights["keras_hdf5"] - weight_path = str(weight_spec.source) - - if weight_spec.tensorflow_version: - model_tf_major_ver = int(weight_spec.tensorflow_version.major) - if model_tf_major_ver != tf_major_ver: - raise RuntimeError(f"Tensorflow major versions of model {model_tf_major_ver} is not {tf_major_ver}") - - if tf_major_ver == 1: - if len(model.inputs) != 1 or len(model.outputs) != 1: - raise NotImplementedError( - "Weight conversion for models with multiple inputs or outputs is not yet implemented." 
- ) - return _convert_tf1(weight_path, str(path_), model.inputs[0].name, model.outputs[0].name, zip_weights) - else: - return _convert_tf2(weight_path, str(path_), zip_weights) diff --git a/bioimageio/core/weight_converter/torch/__init__.py b/bioimageio/core/weight_converter/torch/__init__.py index 27b20c99..1b1ba526 100644 --- a/bioimageio/core/weight_converter/torch/__init__.py +++ b/bioimageio/core/weight_converter/torch/__init__.py @@ -1,2 +1 @@ -from .onnx import convert_weights_to_onnx -from .torchscript import convert_weights_to_torchscript +# TODO: torch weight converters diff --git a/bioimageio/core/weight_converter/torch/_onnx.py b/bioimageio/core/weight_converter/torch/_onnx.py new file mode 100644 index 00000000..3935e1d1 --- /dev/null +++ b/bioimageio/core/weight_converter/torch/_onnx.py @@ -0,0 +1,108 @@ +# type: ignore # TODO: type +import warnings +from pathlib import Path +from typing import Any, List, Sequence, cast + +import numpy as np +from numpy.testing import assert_array_almost_equal + +from bioimageio.spec import load_description +from bioimageio.spec.common import InvalidDescr +from bioimageio.spec.model import v0_4, v0_5 + +from ...digest_spec import get_member_id, get_test_inputs +from ...weight_converter.torch._utils import load_torch_model + +try: + import torch +except ImportError: + torch = None + + +def add_onnx_weights( + model_spec: "str | Path | v0_4.ModelDescr | v0_5.ModelDescr", + *, + output_path: Path, + use_tracing: bool = True, + test_decimal: int = 4, + verbose: bool = False, + opset_version: "int | None" = None, +): + """Convert model weights from format 'pytorch_state_dict' to 'onnx'. 
+ + Args: + model_spec: model without onnx weights + opset_version: onnx opset version + use_tracing: whether to use tracing or scripting to export the onnx format + test_decimal: precision for testing whether the results agree + """ + if isinstance(model_spec, (str, Path)): + loaded_spec = load_description(Path(model_spec)) + if isinstance(loaded_spec, InvalidDescr): + raise ValueError(f"Bad resource description: {loaded_spec}") + if not isinstance(loaded_spec, (v0_4.ModelDescr, v0_5.ModelDescr)): + raise TypeError( + f"Path {model_spec} is a {loaded_spec.__class__.__name__}, expected a v0_4.ModelDescr or v0_5.ModelDescr" + ) + model_spec = loaded_spec + + state_dict_weights_descr = model_spec.weights.pytorch_state_dict + if state_dict_weights_descr is None: + raise ValueError( + "The provided model does not have weights in the pytorch state dict format" + ) + + assert torch is not None + with torch.no_grad(): + + sample = get_test_inputs(model_spec) + input_data = [sample[get_member_id(ipt)].data.data for ipt in model_spec.inputs] + input_tensors = [torch.from_numpy(ipt) for ipt in input_data] + model = load_torch_model(state_dict_weights_descr) + + expected_tensors = model(*input_tensors) + if isinstance(expected_tensors, torch.Tensor): + expected_tensors = [expected_tensors] + expected_outputs: List[np.ndarray[Any, Any]] = [ + out.numpy() for out in expected_tensors + ] + + if use_tracing: + torch.onnx.export( + model, + tuple(input_tensors) if len(input_tensors) > 1 else input_tensors[0], + str(output_path), + verbose=verbose, + opset_version=opset_version, + ) + else: + raise NotImplementedError + + try: + import onnxruntime as rt # pyright: ignore [reportMissingTypeStubs] + except ImportError: + msg = "The onnx weights were exported, but onnx rt is not available and weights cannot be checked." 
+ warnings.warn(msg) + return + + # check the onnx model + sess = rt.InferenceSession(str(output_path)) + onnx_input_node_args = cast( + List[Any], sess.get_inputs() + ) # fixme: remove cast, try using rt.NodeArg instead of Any + onnx_inputs = { + input_name.name: inp + for input_name, inp in zip(onnx_input_node_args, input_data) + } + outputs = cast( + Sequence[np.ndarray[Any, Any]], sess.run(None, onnx_inputs) + ) # FIXME: remove cast + + try: + for exp, out in zip(expected_outputs, outputs): + assert_array_almost_equal(exp, out, decimal=test_decimal) + return 0 + except AssertionError as e: + msg = f"The onnx weights were exported, but results before and after conversion do not agree:\n {str(e)}" + warnings.warn(msg) + return 1 diff --git a/bioimageio/core/weight_converter/torch/_torchscript.py b/bioimageio/core/weight_converter/torch/_torchscript.py new file mode 100644 index 00000000..5ca16069 --- /dev/null +++ b/bioimageio/core/weight_converter/torch/_torchscript.py @@ -0,0 +1,146 @@ +# type: ignore # TODO: type +from pathlib import Path +from typing import List, Sequence, Union + +import numpy as np +from numpy.testing import assert_array_almost_equal +from typing_extensions import Any, assert_never + +from bioimageio.spec.model import v0_4, v0_5 +from bioimageio.spec.model.v0_5 import Version + +from ._utils import load_torch_model + +try: + import torch +except ImportError: + torch = None + + +# FIXME: remove Any +def _check_predictions( + model: Any, + scripted_model: Any, + model_spec: "v0_4.ModelDescr | v0_5.ModelDescr", + input_data: Sequence["torch.Tensor"], +): + assert torch is not None + + def _check(input_: Sequence[torch.Tensor]) -> None: + expected_tensors = model(*input_) + if isinstance(expected_tensors, torch.Tensor): + expected_tensors = [expected_tensors] + expected_outputs: List[np.ndarray[Any, Any]] = [ + out.numpy() for out in expected_tensors + ] + + output_tensors = scripted_model(*input_) + if isinstance(output_tensors, torch.Tensor): 
+ output_tensors = [output_tensors] + outputs: List[np.ndarray[Any, Any]] = [out.numpy() for out in output_tensors] + + try: + for exp, out in zip(expected_outputs, outputs): + assert_array_almost_equal(exp, out, decimal=4) + except AssertionError as e: + raise ValueError( + f"Results before and after weights conversion do not agree:\n {str(e)}" + ) + + _check(input_data) + + if len(model_spec.inputs) > 1: + return # FIXME: why don't we check multiple inputs? + + input_descr = model_spec.inputs[0] + if isinstance(input_descr, v0_4.InputTensorDescr): + if not isinstance(input_descr.shape, v0_4.ParameterizedInputShape): + return + min_shape = input_descr.shape.min + step = input_descr.shape.step + else: + min_shape: List[int] = [] + step: List[int] = [] + for axis in input_descr.axes: + if isinstance(axis.size, v0_5.ParameterizedSize): + min_shape.append(axis.size.min) + step.append(axis.size.step) + elif isinstance(axis.size, int): + min_shape.append(axis.size) + step.append(0) + elif axis.size is None: + raise NotImplementedError( + f"Can't verify inputs that don't specify their shape fully: {axis}" + ) + elif isinstance(axis.size, v0_5.SizeReference): + raise NotImplementedError(f"Can't handle axes like '{axis}' yet") + else: + assert_never(axis.size) + + half_step = [st // 2 for st in step] + max_steps = 4 + + # check that input and output agree for decreasing input sizes + for step_factor in range(1, max_steps + 1): + slice_ = tuple( + slice(None) if st == 0 else slice(step_factor * st, -step_factor * st) + for st in half_step + ) + this_input = [inp[slice_] for inp in input_data] + this_shape = this_input[0].shape + if any(tsh < msh for tsh, msh in zip(this_shape, min_shape)): + raise ValueError( + f"Mismatched shapes: {this_shape}. 
Expected at least {min_shape}" + ) + _check(this_input) + + +def convert_weights_to_torchscript( + model_descr: Union[v0_4.ModelDescr, v0_5.ModelDescr], + output_path: Path, + use_tracing: bool = True, +) -> v0_5.TorchscriptWeightsDescr: + """Convert model weights from format 'pytorch_state_dict' to 'torchscript'. + + Args: + model_descr: location of the resource for the input bioimageio model + output_path: where to save the torchscript weights + use_tracing: whether to use tracing or scripting to export the torchscript format + """ + + state_dict_weights_descr = model_descr.weights.pytorch_state_dict + if state_dict_weights_descr is None: + raise ValueError( + "The provided model does not have weights in the pytorch state dict format" + ) + + input_data = model_descr.get_input_test_arrays() + + with torch.no_grad(): + input_data = [torch.from_numpy(inp.astype("float32")) for inp in input_data] + + model = load_torch_model(state_dict_weights_descr) + + # FIXME: remove Any + if use_tracing: + scripted_model: Any = torch.jit.trace(model, input_data) + else: + scripted_model: Any = torch.jit.script(model) + + _check_predictions( + model=model, + scripted_model=scripted_model, + model_spec=model_descr, + input_data=input_data, + ) + + # save the torchscript model + scripted_model.save( + str(output_path) + ) # does not support Path, so need to cast to str + + return v0_5.TorchscriptWeightsDescr( + source=output_path, + pytorch_version=Version(torch.__version__), + parent="pytorch_state_dict", + ) diff --git a/bioimageio/core/weight_converter/torch/_utils.py b/bioimageio/core/weight_converter/torch/_utils.py new file mode 100644 index 00000000..01df0747 --- /dev/null +++ b/bioimageio/core/weight_converter/torch/_utils.py @@ -0,0 +1,24 @@ +from typing import Union + +from bioimageio.core.model_adapters._pytorch_model_adapter import PytorchModelAdapter +from bioimageio.spec.model import v0_4, v0_5 +from bioimageio.spec.utils import download + +try: + import torch +except 
ImportError: + torch = None + + +# additional convenience for pytorch state dict, eventually we want this in python-bioimageio too +# and for each weight format +def load_torch_model( # pyright: ignore[reportUnknownParameterType] + node: Union[v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr], +): + assert torch is not None + model = ( # pyright: ignore[reportUnknownVariableType] + PytorchModelAdapter.get_network(node) + ) + state = torch.load(download(node.source).path, map_location="cpu") + model.load_state_dict(state) # FIXME: check incompatible keys? + return model.eval() # pyright: ignore[reportUnknownVariableType] diff --git a/bioimageio/core/weight_converter/torch/onnx.py b/bioimageio/core/weight_converter/torch/onnx.py deleted file mode 100644 index 6f9ac1d2..00000000 --- a/bioimageio/core/weight_converter/torch/onnx.py +++ /dev/null @@ -1,82 +0,0 @@ -import warnings -from pathlib import Path -from typing import Union, Optional - -import numpy as np -import torch -from numpy.testing import assert_array_almost_equal - -import bioimageio.spec as spec -from bioimageio.core import load_resource_description -from bioimageio.core.resource_io import nodes -from .utils import load_model - -try: - import onnxruntime as rt -except ImportError: - rt = None - - -def convert_weights_to_onnx( - model_spec: Union[str, Path, spec.model.raw_nodes.Model], - output_path: Union[str, Path], - opset_version: Optional[int] = 12, - use_tracing: bool = True, - verbose: bool = True, - test_decimal: int = 4 -): - """Convert model weights from format 'pytorch_state_dict' to 'onnx'. 
- - Args: - model_spec: location of the resource for the input bioimageio model - output_path: where to save the onnx weights - opset_version: onnx opset version - use_tracing: whether to use tracing or scripting to export the onnx format - verbose: be verbose during the onnx export - test_decimal: precision for testing whether the results agree - """ - if isinstance(model_spec, (str, Path)): - model_spec = load_resource_description(Path(model_spec)) - - assert isinstance(model_spec, nodes.Model) - with torch.no_grad(): - # load input and expected output data - input_data = [np.load(inp).astype("float32") for inp in model_spec.test_inputs] - input_tensors = [torch.from_numpy(inp) for inp in input_data] - - # instantiate and generate the expected output - model = load_model(model_spec) - expected_outputs = model(*input_tensors) - if isinstance(expected_outputs, torch.Tensor): - expected_outputs = [expected_outputs] - expected_outputs = [out.numpy() for out in expected_outputs] - - if use_tracing: - torch.onnx.export( - model, - input_tensors if len(input_tensors) > 1 else input_tensors[0], - output_path, - verbose=verbose, - opset_version=opset_version, - ) - else: - raise NotImplementedError - - if rt is None: - msg = "The onnx weights were exported, but onnx rt is not available and weights cannot be checked." 
- warnings.warn(msg) - return 1 - - # check the onnx model - sess = rt.InferenceSession(str(output_path)) # does not support Path, so need to cast to str - onnx_inputs = {input_name.name: inp for input_name, inp in zip(sess.get_inputs(), input_data)} - outputs = sess.run(None, onnx_inputs) - - try: - for exp, out in zip(expected_outputs, outputs): - assert_array_almost_equal(exp, out, decimal=test_decimal) - return 0 - except AssertionError as e: - msg = f"The onnx weights were exported, but results before and after conversion do not agree:\n {str(e)}" - warnings.warn(msg) - return 1 diff --git a/bioimageio/core/weight_converter/torch/torchscript.py b/bioimageio/core/weight_converter/torch/torchscript.py deleted file mode 100644 index 7da79bfe..00000000 --- a/bioimageio/core/weight_converter/torch/torchscript.py +++ /dev/null @@ -1,107 +0,0 @@ -import warnings - -from pathlib import Path -from typing import Union - -import numpy as np -import torch -from numpy.testing import assert_array_almost_equal - -import bioimageio.spec as spec -from bioimageio.core import load_resource_description -from .utils import load_model - - -def _check_predictions(model, scripted_model, model_spec, input_data): - assert isinstance(input_data, list) - - def _check(input_): - # get the expected output to validate the torchscript weights - expected_outputs = model(*input_) - if isinstance(expected_outputs, (torch.Tensor)): - expected_outputs = [expected_outputs] - expected_outputs = [out.numpy() for out in expected_outputs] - - outputs = scripted_model(*input_) - if isinstance(outputs, (torch.Tensor)): - outputs = [outputs] - outputs = [out.numpy() for out in outputs] - - try: - for exp, out in zip(expected_outputs, outputs): - assert_array_almost_equal(exp, out, decimal=4) - return 0 - except AssertionError as e: - msg = f"The onnx weights were exported, but results before and after conversion do not agree:\n {str(e)}" - warnings.warn(msg) - return 1 - - ret = _check(input_data) - 
n_inputs = len(model_spec.inputs) - # check has not passed or we have more tahn one input? then return immediately - if ret == 1 or n_inputs > 1: - return ret - - # do we have fixed input size or variable? - # if variable, we need to check multiple sizes! - shape_spec = model_spec.inputs[0].shape - try: # we have a variable shape - min_shape = shape_spec.min - step = shape_spec.step - except AttributeError: # we have fixed shape - return ret - - half_step = [st // 2 for st in step] - max_steps = 4 - step_factor = 1 - - # check that input and output agree for decreasing input sizes - while True: - - slice_ = tuple(slice(None) if st == 0 else slice(step_factor * st, -step_factor * st) for st in half_step) - this_input = [inp[slice_] for inp in input_data] - this_shape = this_input[0].shape - if any(tsh < msh for tsh, msh in zip(this_shape, min_shape)): - return ret - - ret = _check(this_input) - if ret == 1: - return ret - step_factor += 1 - if step_factor > max_steps: - return ret - - -def convert_weights_to_torchscript( - model_spec: Union[str, Path, spec.model.raw_nodes.Model], output_path: Union[str, Path], use_tracing: bool = True -): - """Convert model weights from format 'pytorch_state_dict' to 'torchscript'. 
- - Args: - model_spec: location of the resource for the input bioimageio model - output_path: where to save the torchscript weights - use_tracing: whether to use tracing or scripting to export the torchscript format - """ - if isinstance(model_spec, (str, Path)): - model_spec = load_resource_description(Path(model_spec)) - - with torch.no_grad(): - # load input and expected output data - input_data = [np.load(inp).astype("float32") for inp in model_spec.test_inputs] - input_data = [torch.from_numpy(inp) for inp in input_data] - - # instantiate model and get reference output - model = load_model(model_spec) - - # make scripted model - if use_tracing: - scripted_model = torch.jit.trace(model, input_data) - else: - scripted_model = torch.jit.script(model) - - # check the scripted model - ret = _check_predictions(model, scripted_model, model_spec, input_data) - - # save the torchscript model - scripted_model.save(str(output_path)) # does not support Path, so need to cast to str - return ret diff --git a/bioimageio/core/weight_converter/torch/utils.py b/bioimageio/core/weight_converter/torch/utils.py deleted file mode 100644 index 9c122ad5..00000000 --- a/bioimageio/core/weight_converter/torch/utils.py +++ /dev/null @@ -1,12 +0,0 @@ -import torch -from bioimageio.core.prediction_pipeline._model_adapters._pytorch_model_adapter import PytorchModelAdapter - - -# additional convenience for pytorch state dict, eventually we want this in python-bioimageio too -# and for each weight format -def load_model(node): - model = PytorchModelAdapter.get_nn_instance(node) - state = torch.load(node.weights["pytorch_state_dict"].source, map_location="cpu") - model.load_state_dict(state) - model.eval() - return model diff --git a/conda-recipe/meta.yaml b/conda-recipe/meta.yaml index 482aaf33..7654f5ab 100644 --- a/conda-recipe/meta.yaml +++ b/conda-recipe/meta.yaml @@ -18,12 +18,10 @@ build: requirements: host: - - python >=3.7,<3.10 + - python >=3.8,<3.13 - pip run: - - python 
>=3.7,<3.10 - - tqdm - - typer + - python >=3.8,<3.13 {% for dep in setup_py_data['install_requires'] %} - {{ dep.lower() }} {% endfor %} @@ -47,12 +45,15 @@ requirements: test: imports: - bioimageio.core - - bioimageio.core.build_spec source_files: - tests requires: - {% for dep in setup_py_data['extras_require']['test'] %} + {% for dep in setup_py_data['extras_require']['dev'] %} + {% if dep.startswith('torch>=') %} # pip: torch -> conda: pytorch + - py{{ dep.lower() }} + {% else %} - {{ dep.lower() }} + {% endif %} {% endfor %} commands: - pytest @@ -64,6 +65,5 @@ about: license_family: MIT license_file: LICENSE summary: 'Tools for running BioimageIO compliant neural networks in Python.' - doc_url: https://github.com/bioimage-io/core-bioimage-io-python dev_url: https://github.com/bioimage-io/core-bioimage-io-python diff --git a/dev/env-py38.yaml b/dev/env-py38.yaml new file mode 100644 index 00000000..760d2f97 --- /dev/null +++ b/dev/env-py38.yaml @@ -0,0 +1,39 @@ +# manipulated copy of env.yaml +name: core38 +channels: + - conda-forge + - defaults +dependencies: + - bioimageio.spec>=0.5.2.post1 + - black + - crick # uncommented + - filelock + - fire + - imageio>=2.5 + - jupyter + - jupyter-black + # - keras>=3.0 # removed + - loguru + - numpy + - onnxruntime + - packaging>=17.0 + - pip + - pre-commit + - psutil + - pydantic + - pydantic-settings + - pyright + - pytest + - pytest-xdist + - python-dotenv + - python=3.8 # changed + - pytorch>=2.1 + - rich + - ruff + - ruyaml + - torchvision + - tqdm + - typing-extensions + - xarray + - pip: + - -e .. 
diff --git a/dev/env-tf.yaml b/dev/env-tf.yaml new file mode 100644 index 00000000..d51c5ad3 --- /dev/null +++ b/dev/env-tf.yaml @@ -0,0 +1,40 @@ +# modified copy of env.yaml +name: core-tf # changed +channels: + - conda-forge + - defaults +dependencies: + - bioimageio.spec>=0.5.2.post1 + - black + # - crick # currently requires python<=3.9 + - filelock + - fire + - imageio>=2.5 + - jupyter + - jupyter-black + - keras>=2.15 # changed + - loguru + - numpy + - onnxruntime + - packaging>=17.0 + - pip + - pre-commit + - psutil + - pydantic + - pydantic-settings + - pyright + - pytest + - pytest-xdist + - python-dotenv + # - python=3.9 # removed + # - pytorch>=2.1 # removed + - rich + # - ruff # removed + - ruyaml + - tensorflow>=2.15 # added + # - torchvision # removed + - tqdm + - typing-extensions + - xarray + - pip: + - -e .. diff --git a/dev/env-wo-python.yaml b/dev/env-wo-python.yaml new file mode 100644 index 00000000..fedd86d3 --- /dev/null +++ b/dev/env-wo-python.yaml @@ -0,0 +1,39 @@ +# modified copy of env.yaml +name: core +channels: + - conda-forge + - defaults +dependencies: + - bioimageio.spec>=0.5.2.post1 + - black + # - crick # currently requires python<=3.9 + - filelock + - fire + - imageio>=2.5 + - jupyter + - jupyter-black + - keras>=3.0 + - loguru + - numpy + - onnxruntime + - packaging>=17.0 + - pip + - pre-commit + - psutil + - pydantic + - pydantic-settings + - pyright + - pytest + - pytest-xdist + - python-dotenv + # - python=3.9 # removed + - pytorch>=2.1 + - rich + # - ruff # removed + - ruyaml + - torchvision + - tqdm + - typing-extensions + - xarray + - pip: + - -e .. 
diff --git a/dev/env.yaml b/dev/env.yaml new file mode 100644 index 00000000..4ac24ecf --- /dev/null +++ b/dev/env.yaml @@ -0,0 +1,38 @@ +name: core +channels: + - conda-forge + - defaults +dependencies: + - bioimageio.spec>=0.5.2.post1 + - black + # - crick # currently requires python<=3.9 + - filelock + - fire + - imageio>=2.5 + - jupyter + - jupyter-black + - keras>=3.0 + - loguru + - numpy + - onnxruntime + - packaging>=17.0 + - pip + - pre-commit + - psutil + - pydantic + - pydantic-settings + - pyright + - pytest + - pytest-xdist + - python-dotenv + - python=3.9 + - pytorch>=2.1 + - rich + - ruff + - ruyaml + - torchvision + - tqdm + - typing-extensions + - xarray + - pip: + - -e .. diff --git a/dev/environment-base.yaml b/dev/environment-base.yaml deleted file mode 100644 index 88a336ee..00000000 --- a/dev/environment-base.yaml +++ /dev/null @@ -1,22 +0,0 @@ -name: bio-core-dev -channels: - - conda-forge - - defaults -dependencies: - - black - - bioimageio.spec - - conda-build - - h5py >=2.10,<2.11 - - mypy - - pip - - pre-commit - - pytest - - python >=3.7,<3.8 # this environment is only available for python 3.7 - - xarray - - pytorch - - onnx - - onnxruntime - - tensorflow >=1.12,<2.0 - - tifffile <=2022.4.8 # pin fixes Syntax error; see https://github.com/bioimage-io/core-bioimage-io-python/pull/259 - - pip: - - keras==1.2.2 diff --git a/dev/environment-tf-legacy.yaml b/dev/environment-tf-legacy.yaml deleted file mode 100644 index 976ea3d6..00000000 --- a/dev/environment-tf-legacy.yaml +++ /dev/null @@ -1,17 +0,0 @@ -name: bio-core-tf-legacy -channels: - - conda-forge - - defaults -dependencies: - - black - - bioimageio.spec - - conda-build - - h5py >=2.10,<2.11 - - mypy - - pip - - pytest - - python >=3.7,<3.8 # this environment is only available for python 3.7 - - xarray - - tensorflow >1.14,<2.0 - - tifffile <=2022.4.8 # pin fixes Syntax error; see https://github.com/bioimage-io/core-bioimage-io-python/pull/259 - - keras diff --git 
a/dev/environment-tf.yaml b/dev/environment-tf.yaml deleted file mode 100644 index 4ecd57d8..00000000 --- a/dev/environment-tf.yaml +++ /dev/null @@ -1,15 +0,0 @@ -name: bio-core-tf -channels: - - conda-forge - - defaults -dependencies: - - black - - bioimageio.spec - - conda-build - - mypy - - pip - - pytest - - python - - xarray - - tensorflow >=2.9,<3.0 - - tifffile <=2022.4.8 # pin fixes Syntax error; see https://github.com/bioimage-io/core-bioimage-io-python/pull/259 diff --git a/dev/environment-torch.yaml b/dev/environment-torch.yaml deleted file mode 100644 index 98a944cd..00000000 --- a/dev/environment-torch.yaml +++ /dev/null @@ -1,19 +0,0 @@ -name: bio-core-torch -channels: - - conda-forge - - defaults -dependencies: - - black - - bioimageio.spec >=0.4.4 - - conda-build - - h5py - - mypy - - pip - - pytest - - python >=3.7 - - xarray - - pytorch - - onnx - - onnxruntime - - tifffile <=2022.4.8 # pin fixes Syntax error; see https://github.com/bioimage-io/core-bioimage-io-python/pull/259 - diff --git a/example/dataset_creation.ipynb b/example/dataset_creation.ipynb new file mode 100644 index 00000000..13596956 --- /dev/null +++ b/example/dataset_creation.ipynb @@ -0,0 +1,118 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from bioimageio.spec.pretty_validation_errors import (\n", + " enable_pretty_validation_errors_in_ipynb,\n", + ")\n", + "\n", + "enable_pretty_validation_errors_in_ipynb()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from bioimageio.spec.dataset.v0_3 import (\n", + " Author,\n", + " CiteEntry,\n", + " Dataset,\n", + " HttpUrl,\n", + " RelativeFilePath,\n", + ")\n", + "\n", + "nuclei_broad_data = Dataset(\n", + " name=\"Kaggle 2018 Data Science Bowl\",\n", + " description=\"This image data set contains a large number of segmented nuclei images and was created for the Kaggle \"\n", + " \"2018 
Data Science Bowl sponsored by Booz Allen Hamilton with cash prizes. The image set was a testing ground \"\n", + " \"for the application of novel and cutting edge approaches in computer vision and machine learning to the \"\n", + " \"segmentation of the nuclei belonging to cells from a breadth of biological contexts.\",\n", + " documentation=RelativeFilePath(\"README.md\"),\n", + " covers=(\n", + " HttpUrl(\n", + " \"https://data.broadinstitute.org/bbbc/BBBC038/BBBC038exampleimage1.png\"\n", + " ),\n", + " HttpUrl(\n", + " \"https://data.broadinstitute.org/bbbc/BBBC038/BBBC038exampleimage2.png\"\n", + " ),\n", + " HttpUrl(\n", + " \"https://data.broadinstitute.org/bbbc/BBBC038/BBBC038exampleimage3.png\"\n", + " ),\n", + " HttpUrl(\n", + " \"https://data.broadinstitute.org/bbbc/BBBC038/BBBC038exampleimage4.png\"\n", + " ),\n", + " HttpUrl(\n", + " \"https://data.broadinstitute.org/bbbc/BBBC038/BBBC038exampleimage5.png\"\n", + " ),\n", + " ),\n", + " authors=(\n", + " Author(\n", + " name=\"Fynn Beuttenmueller\",\n", + " affiliation=\"EMBL\",\n", + " github_user=\"fynnbe\",\n", + " orcid=\"0000-0002-8567-6389\",\n", + " ),\n", + " ),\n", + " source=HttpUrl(\"https://bbbc.lbroadinstitute.org/BBBC038/\"),\n", + " cite=(\n", + " CiteEntry(\n", + " text=\"Caicedo, J.C., Goodman, A., Karhohs, K.W. et al. Nucleus segmentation across imaging experiments: \"\n", + " \"the 2018 Data Science Bowl. Nat Methods 16, 1247–1253 (2019).\",\n", + " url=\"10.1038/s41592-019-0612-7\",\n", + " ),\n", + " CiteEntry(\n", + " text=\"Allen Goodman, Anne Carpenter, Elizabeth Park, jlefman-nvidia, Josette_BoozAllen, Kyle, Maggie, \"\n", + " \"Nilofer, Peter Sedivec, Will Cukierski. (2018). 2018 Data Science Bowl . 
Kaggle.\",\n", + " url=\"https://kaggle.com/competitions/data-science-bowl-2018\",\n", + " ),\n", + " ),\n", + " license=\"CC0-1.0\",\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "nuclei_broad_data.source" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "bio38", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.17" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/example/dataset_statistics_demo.ipynb b/example/dataset_statistics_demo.ipynb index a8f30526..e59f4fd3 100644 --- a/example/dataset_statistics_demo.ipynb +++ b/example/dataset_statistics_demo.ipynb @@ -98,27 +98,41 @@ "source": [ "from bioimageio.core.prediction import get_tiling\n", "\n", - "tile_shape = dict(zip(\n", - " model_resource.inputs[0].axes, \n", - " np.asarray(model_resource.inputs[0].shape.min) + np.asarray(model_resource.inputs[0].shape.step)\n", - "))\n", - "\n", - "tiles = list(get_tiling(\n", - " shape=input_image.shape,\n", - " tile_shape=tile_shape,\n", - " halo=dict(zip(model_resource.inputs[0].axes, model_resource.outputs[0].halo)),\n", - " input_axes=model_resource.inputs[0].axes\n", - "))\n", + "tile_shape = dict(\n", + " zip(\n", + " model_resource.inputs[0].axes,\n", + " np.asarray(model_resource.inputs[0].shape.min)\n", + " + np.asarray(model_resource.inputs[0].shape.step),\n", + " )\n", + ")\n", + "\n", + "tiles = list(\n", + " get_tiling(\n", + " shape=input_image.shape,\n", + " tile_shape=tile_shape,\n", + " halo=dict(zip(model_resource.inputs[0].axes, model_resource.outputs[0].halo)),\n", + " 
input_axes=model_resource.inputs[0].axes,\n", + " )\n", + ")\n", + "\n", "\n", "def add_tile_box(ax, t):\n", " x = t[\"x\"].start\n", " w = t[\"x\"].stop - x\n", " y = t[\"y\"].start\n", " h = t[\"y\"].stop - y\n", - " \n", - " box = Rectangle((x, y), w, h, linewidth=1, edgecolor=np.random.choice(list(\"rgbcmykw\")), facecolor=\"none\")\n", + "\n", + " box = Rectangle(\n", + " (x, y),\n", + " w,\n", + " h,\n", + " linewidth=1,\n", + " edgecolor=np.random.choice(list(\"rgbcmykw\")),\n", + " facecolor=\"none\",\n", + " )\n", " ax.add_patch(box)\n", "\n", + "\n", "fig, ax = plt.subplots(1, 2)\n", "fig.suptitle(\"'samples' of test image 'dataset'\")\n", "ax[0].set_title(\"input (outer) tiles\")\n", @@ -175,19 +189,28 @@ "def process_dataset(pp, dataset):\n", " stats = pp._ipt_stats.compute_measures()[\"per_dataset\"]\n", " print(f\"initial stats:\")\n", - " pprint(None if not stats else {k: f\"{v.item():.2f}\" for k, v in stats[\"input0\"].items()})\n", + " pprint(\n", + " None\n", + " if not stats\n", + " else {k: f\"{v.item():.2f}\" for k, v in stats[\"input0\"].items()}\n", + " )\n", " stats = {}\n", " sample_dataset = [{\"input0\": s} for s in dataset]\n", " [pp.apply_preprocessing(s, stats) for s in sample_dataset]\n", " print(f\"final stats:\")\n", - " pprint(None if not stats else {k: f\"{v.item():.2f}\" for k, v in stats[\"per_dataset\"][\"input0\"].items()})\n", + " pprint(\n", + " None\n", + " if not stats\n", + " else {k: f\"{v.item():.2f}\" for k, v in stats[\"per_dataset\"][\"input0\"].items()}\n", + " )\n", " return [s[\"input0\"] for s in sample_dataset]\n", "\n", + "\n", "# accumulate dataset statistics exclusively while processing samples (no initial dataset statistics are computed)\n", "with create_prediction_pipeline(\n", - " bioimageio_model=model_resource,\n", - " # dataset_for_initial_statistics=tuple(), # an empty dataset is the default\n", - " ) as pp:\n", + " bioimageio_model=model_resource,\n", + " # dataset_for_initial_statistics=tuple(), # 
an empty dataset is the default\n", + ") as pp:\n", " wo_init_dataset_stats = process_dataset(pp, dataset)" ] }, @@ -220,9 +243,9 @@ "source": [ "# accumulate dataset statistics exclusively while processing samples for a limited number of samples\n", "with create_prediction_pipeline(\n", - " bioimageio_model=model_resource,\n", - " \tupdate_dataset_stats_for_n_samples=len(dataset) // 2,\n", - " ) as pp:\n", + " bioimageio_model=model_resource,\n", + " update_dataset_stats_for_n_samples=len(dataset) // 2,\n", + ") as pp:\n", " wo_init_dataset_stats_limit = process_dataset(pp, dataset)" ] }, @@ -256,14 +279,14 @@ ], "source": [ "# initialize dataset statistics with first n samples and keep update dataset statistics after the n sample\n", - "# this assumes that the n samples present in 'dataset_for_initial_statistics' are those that will be processed \n", - "# by the prediction pipeline and thus should not update the dataset statistics. \n", + "# this assumes that the n samples present in 'dataset_for_initial_statistics' are those that will be processed\n", + "# by the prediction pipeline and thus should not update the dataset statistics.\n", "# Use 'update_dataset_stats_after_n_samples=0' if that is not your use case.\n", "with create_prediction_pipeline(\n", - " bioimageio_model=model_resource,\n", - " dataset_for_initial_statistics=dataset[:len(dataset) // 2],\n", - " # update_dataset_stats_after_n_samples=None, # defaults to len(dataset_for_initial_statistics)\n", - " ) as pp:\n", + " bioimageio_model=model_resource,\n", + " dataset_for_initial_statistics=dataset[: len(dataset) // 2],\n", + " # update_dataset_stats_after_n_samples=None, # defaults to len(dataset_for_initial_statistics)\n", + ") as pp:\n", " partial_init_dataset_stats = process_dataset(pp, dataset)" ] }, @@ -299,11 +322,11 @@ "# compute dataset statistics on all samples\n", "# (in this case we should really use the non-overlapping tiles as samples in dataset_for_initial_statistics)\n", "with 
create_prediction_pipeline(\n", - " bioimageio_model=model_resource,\n", - " dataset_for_initial_statistics=dataset,\n", - " update_dataset_stats_for_n_samples=0, # if you call the prediciton pipeline more then len(dataset) \n", + " bioimageio_model=model_resource,\n", + " dataset_for_initial_statistics=dataset,\n", + " update_dataset_stats_for_n_samples=0, # if you call the prediciton pipeline more then len(dataset)\n", " # times you might want to set this to zero to avoid further updates to the dataset statistics\n", - " ) as pp:\n", + ") as pp:\n", " only_init_dataset_stats = process_dataset(pp, dataset)" ] }, @@ -323,15 +346,32 @@ "outputs": [], "source": [ "def untile(outputs):\n", - " untiled = xr.DataArray(np.empty((1, 2, *input_image.squeeze().shape)), dims=model_resource.outputs[0].axes)\n", + " untiled = xr.DataArray(\n", + " np.empty((1, 2, *input_image.squeeze().shape)),\n", + " dims=model_resource.outputs[0].axes,\n", + " )\n", " for out, t in zip(outputs, tiles):\n", " untiled[t.inner] = out[0]\n", "\n", " return untiled.data.squeeze()\n", "\n", + "\n", "# prepare image comparisons\n", - "titles = [\"wo_init_dataset_stats\", \"wo_init_dataset_stats_limit\", \"partial_init_dataset_stats\", \"only_init_dataset_stats\"]\n", - "images = [untile(out)[0] for out in [wo_init_dataset_stats, wo_init_dataset_stats_limit, partial_init_dataset_stats, only_init_dataset_stats]]" + "titles = [\n", + " \"wo_init_dataset_stats\",\n", + " \"wo_init_dataset_stats_limit\",\n", + " \"partial_init_dataset_stats\",\n", + " \"only_init_dataset_stats\",\n", + "]\n", + "images = [\n", + " untile(out)[0]\n", + " for out in [\n", + " wo_init_dataset_stats,\n", + " wo_init_dataset_stats_limit,\n", + " partial_init_dataset_stats,\n", + " only_init_dataset_stats,\n", + " ]\n", + "]" ] }, { @@ -364,17 +404,22 @@ "source": [ "fig, axes = plt.subplots(2, 4, figsize=(30, 15))\n", "for ax in axes[1]:\n", - " ax.set_axis_off()\n", + " ax.set_axis_off()\n", "\n", "zoom_roi = np.s_[20:60, 
50:90]\n", + "\n", + "\n", "def get_box():\n", " return Rectangle(\n", " (zoom_roi[1].start, zoom_roi[0].start),\n", " zoom_roi[1].stop - zoom_roi[1].start,\n", " zoom_roi[0].stop - zoom_roi[0].start,\n", - " linewidth=1, edgecolor='r', facecolor='none'\n", + " linewidth=1,\n", + " edgecolor=\"r\",\n", + " facecolor=\"none\",\n", " )\n", "\n", + "\n", "vmin = min([img.min() for img in images])\n", "vmax = max([img.max() for img in images])\n", "zoom_vmin = min([img[zoom_roi].min() for img in images])\n", @@ -385,7 +430,9 @@ " axes[0, i].add_patch(get_box())\n", " axes[0, i].set_title(f\"{title} (min: {img.min():.2f} max: {img.max():.2f})\")\n", " axes[1, i].imshow(img[zoom_roi], vmin=zoom_vmin, vmax=zoom_vmax)\n", - " axes[1, i].set_title(f\"zooom in (min: {img[zoom_roi].min():.2f} max: {img[zoom_roi].max():.2f})\")\n", + " axes[1, i].set_title(\n", + " f\"zooom in (min: {img[zoom_roi].min():.2f} max: {img[zoom_roi].max():.2f})\"\n", + " )\n", "\n", "plt.show()" ] @@ -421,7 +468,7 @@ "fig, ax = plt.subplots(4, 4, figsize=(20, 20))\n", "for ai, (atitle, a) in enumerate(zip(titles, images)):\n", " for bi, (btitle, b) in enumerate(zip(titles, images)):\n", - " ax[ai, bi].imshow(np.abs(a-b))\n", + " ax[ai, bi].imshow(np.abs(a - b))\n", " if ai == 0:\n", " ax[ai, bi].set_title(btitle)\n", " if bi == 0:\n", diff --git a/example/model_creation.ipynb b/example/model_creation.ipynb.needs_update similarity index 99% rename from example/model_creation.ipynb rename to example/model_creation.ipynb.needs_update index dc81c52e..e45714e7 100644 --- a/example/model_creation.ipynb +++ b/example/model_creation.ipynb.needs_update @@ -162,7 +162,7 @@ "# it will output a list of dictionaries. 
each dict gives the status of a different test that is being run\n", "# if all of them contain \"status\": \"passed\" then all tests were successful\n", "from bioimageio.core.resource_tests import test_model\n", - "my_model = bioimageio.core.load_resource_description(\"my-model/model.zip\") \n", + "my_model = bioimageio.core.load_resource_description(\"my-model/model.zip\")\n", "test_model(my_model)" ] }, @@ -272,7 +272,7 @@ "zip_path = os.path.join(model_root, f\"{name}.zip\")\n", "\n", "# `build_model` needs some additional information about the model, like citation information\n", - "# all this additional information is passed as plain python types and will be converted into the bioimageio representation internally \n", + "# all this additional information is passed as plain python types and will be converted into the bioimageio representation internally\n", "# for more informantion, check out the function signature\n", "# https://github.com/bioimage-io/core-bioimage-io-python/blob/main/bioimageio/core/build_spec/build_model.py#L252\n", "cite = [{\"text\": cite_entry.text, \"url\": cite_entry.url} for cite_entry in model_resource.cite]\n", @@ -388,7 +388,7 @@ "# the path to save the new model with torchscript weights\n", "temp_zip_path = f\"{model_root}/new_model3.zip\"\n", "\n", - "_ = build_model( \n", + "_ = build_model(\n", " weight_uri=weight_file,\n", " weight_type=\"pytorch_state_dict\",\n", " architecture=model_source,\n", @@ -494,7 +494,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.8" + "version": "3.8.17" } }, "nbformat": 4, diff --git a/example/model_usage.ipynb b/example/model_usage.ipynb index 7e7d7a8a..5f7c835d 100644 --- a/example/model_usage.ipynb +++ b/example/model_usage.ipynb @@ -37,7 +37,7 @@ "# helper function for showing multiple images in napari\n", "def show_images(*images, names=None):\n", " v = napari.Viewer()\n", - " for i, im in enumerate(images):\n", + " for i, im in 
enumerate(images):\n", " name = None if names is None else names[i]\n", " if isinstance(im, str):\n", " im = imageio.imread(im)\n", @@ -77,7 +77,9 @@ "# - go to https://bioimage.io/#/?id=10.5281%2Fzenodo.5764892%2F5764893\n", "# - click the download icon\n", "# - select \"ilastik\" weight format\n", - "rdf_path = \"/home/pape/Downloads/nuclei-segmentation-boundarymodel_pytorch_state_dict.zip\"" + "rdf_path = (\n", + " \"/home/pape/Downloads/nuclei-segmentation-boundarymodel_pytorch_state_dict.zip\"\n", + ")" ] }, { @@ -126,7 +128,10 @@ "# we can e.g. check what weight formats are available in the model (pytorch_state_dict for the model used here)\n", "print(\"Available weight formats for this model:\", model_resource.weights.keys())\n", "# or where the (downloaded) weight files are stored\n", - "print(\"Pytorch state dict weights are stored at:\", model_resource.weights[\"pytorch_state_dict\"].source)\n", + "print(\n", + " \"Pytorch state dict weights are stored at:\",\n", + " model_resource.weights[\"pytorch_state_dict\"].source,\n", + ")\n", "print()\n", "# or what inputs the model expects\n", "print(\"The model requires as inputs:\")\n", @@ -151,15 +156,16 @@ "# before using a model, it is recommended to check that it properly works with this function\n", "# 'test_model' returns a dict with 'status'='passed'/'failed' and more detailed information\n", "from bioimageio.core.resource_tests import test_model\n", - "for test_result in test_model(model_resource):\n", - " if test_result[\"status\"] == \"failed\":\n", - " print(\"model test:\", test_result[\"name\"])\n", - " print(\"The model test failed with:\", test_result[\"error\"])\n", - " print(\"with the traceback:\")\n", - " print(\"\".join(test_result[\"traceback\"]))\n", - " else:\n", - " test_result[\"status\"] == \"passed\"\n", - " print(\"The model passed all tests\")" + "\n", + "test_result = test_model(model_resource)\n", + "if test_result[\"status\"] == \"failed\":\n", + " print(\"model test:\", 
test_result[\"name\"])\n", + " print(\"The model test failed with:\", test_result[\"error\"])\n", + " print(\"with the traceback:\")\n", + " print(\"\".join(test_result[\"traceback\"]))\n", + "else:\n", + " test_result[\"status\"] == \"passed\"\n", + " print(\"The model passed all tests\")" ] }, { @@ -235,7 +241,9 @@ "# The prediction pipeline always returns a tuple (even if the model only has a single output tensor).\n", "# So we access the first element of the prediction to get the predicted tensor.\n", "prediction = prediction_pipeline(input_array)[0]\n", - "show_images(input_image, prediction, names=[\"image\", \"prediction\"]) # show the prediction result" + "show_images(\n", + " input_image, prediction, names=[\"image\", \"prediction\"]\n", + ") # show the prediction result" ] }, { @@ -273,7 +281,9 @@ "source": [ "# Instead, we can use the function `predict_with_padding`, which will pad the image to a shape that fits the model.\n", "prediction = bioimageio.core.predict_with_padding(prediction_pipeline, cropped_array)\n", - "show_images(cropped_image, prediction, names=[\"image\", \"prediction\"]) # show the prediction result" + "show_images(\n", + " cropped_image, prediction, names=[\"image\", \"prediction\"]\n", + ") # show the prediction result" ] }, { @@ -290,11 +300,16 @@ "# that is cropped in order to reduce boundary artifacts.\n", "# Alternatively, `tiling` can also be set to `True`, than the tile size and halo will be deduced from the model config\n", "# (this is also the default behavior when the `tiling` parameter is not passed).\n", - "tiling = {\"tile\": {\"x\": 128, \"y\": 128}, \"halo\": {\"x\": 16, \"y\": 16}} # use a tile size of 128x128 and crop a halo of 16 pixels\n", + "tiling = {\n", + " \"tile\": {\"x\": 128, \"y\": 128},\n", + " \"halo\": {\"x\": 16, \"y\": 16},\n", + "} # use a tile size of 128x128 and crop a halo of 16 pixels\n", "\n", - "# if `verbose` is set to True a progress bar will be printed \n", - "prediction = 
bioimageio.core.predict_with_tiling(prediction_pipeline, cropped_array, tiling=tiling, verbose=True)\n", - "show_images(cropped_image, prediction, names=[\"image\", \"prediction\"]) " + "# if `verbose` is set to True a progress bar will be printed\n", + "prediction = bioimageio.core.predict_with_tiling(\n", + " prediction_pipeline, cropped_array, tiling=tiling, verbose=True\n", + ")\n", + "show_images(cropped_image, prediction, names=[\"image\", \"prediction\"])" ] }, { @@ -321,15 +336,18 @@ "\n", "# The filepath where the output should be stored; supports most common image formats as well as npy fileformat.\n", "outputs = [\"prediction.tif\"]\n", - "predict_image(\n", - " model_resource, model_resource.test_inputs, outputs\n", - ")\n", + "predict_image(model_resource, model_resource.test_inputs, outputs)\n", "\n", "# The output tensor contains 2 channels, which is not supported by normal tif.\n", "# Thus, these 2 channels are stored as 2 separate images.\n", "fg_pred = imageio.imread(\"prediction-c0.tif\")\n", "bd_pred = imageio.imread(\"prediction-c1.tif\")\n", - "show_images(input_image, fg_pred, bd_pred, names=[\"image\", \"foreground-prediction\", \"boundary-prediction\"])" + "show_images(\n", + " input_image,\n", + " fg_pred,\n", + " bd_pred,\n", + " names=[\"image\", \"foreground-prediction\", \"boundary-prediction\"],\n", + ")" ] }, { @@ -349,6 +367,7 @@ "\n", "# Get all paths to the images in the \"example-images\" folder.\n", "from glob import glob\n", + "\n", "inputs = glob(\"./example-images/*.png\")\n", "\n", "# Create an output folder and specify the output path for each image.\n", @@ -374,12 +393,14 @@ "# `{\"x\": 512, \"y\": 512, \"mode\": \"fixed\"}` will always pad to a size of 512x512.\n", "# The padding is cropped again after the prediction to restore the input shape.\n", "padding = {\"x\": 16, \"y\": 16, \"mode\": \"dynamic\"}\n", - "predict_images(\n", - " model_resource, inputs, outputs, padding=padding, verbose=True\n", - ")\n", + 
"predict_images(model_resource, inputs, outputs, padding=padding, verbose=True)\n", "\n", "# check the first input/output\n", - "show_images(inputs[0], outputs[0].replace(\".png\", \"-c0.png\"), outputs[0].replace(\".png\", \"-c1.png\"))" + "show_images(\n", + " inputs[0],\n", + " outputs[0].replace(\".png\", \"-c0.png\"),\n", + " outputs[0].replace(\".png\", \"-c1.png\"),\n", + ")" ] }, { @@ -395,12 +416,14 @@ " \"tile\": {\"x\": 256, \"y\": 256},\n", " \"halo\": {\"x\": 16, \"y\": 16},\n", "}\n", - "predict_images(\n", - " model_resource, inputs, outputs, tiling=tiling, verbose=True\n", - ")\n", + "predict_images(model_resource, inputs, outputs, tiling=tiling, verbose=True)\n", "\n", "# Check the first input output pair.\n", - "show_images(inputs[0], outputs[0].replace(\".png\", \"-c0.png\"), outputs[0].replace(\".png\", \"-c1.png\"))" + "show_images(\n", + " inputs[0],\n", + " outputs[0].replace(\".png\", \"-c0.png\"),\n", + " outputs[0].replace(\".png\", \"-c1.png\"),\n", + ")" ] } ], diff --git a/pyproject.toml b/pyproject.toml index 4e42573b..083aaf84 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,38 @@ [tool.black] -line-length = 120 -target-version = ['py38'] +line-length = 88 +target-version = ["py38", "py39", "py310", "py311", "py312"] + +[tool.pyright] +exclude = ["**/node_modules", "**/__pycache__", "tests/old_*"] +include = ["bioimageio", "scripts", "tests"] +pythonPlatform = "All" +pythonVersion = "3.8" +reportDuplicateImport = "error" +reportImplicitStringConcatenation = "error" +reportIncompatibleMethodOverride = true +reportMatchNotExhaustive = "error" +reportMissingSuperCall = "error" +reportMissingTypeArgument = true +reportMissingTypeStubs = "warning" +reportPropertyTypeMismatch = "error" +reportUninitializedInstanceVariable = "error" +reportUnknownMemberType = false +reportUnnecessaryIsInstance = false +reportUnnecessaryTypeIgnoreComment = "error" +reportUnsupportedDunderAll = "error" +reportUnusedCallResult = "error" 
+reportUnusedClass = "error" +reportUnusedExpression = "error" +reportUnusedFunction = "error" +reportUnusedVariable = "error" +reportWildcardImportFromLibrary = "error" +typeCheckingMode = "strict" +useLibraryCodeForTypes = true + +[tool.pytest.ini_options] +addopts = " -n auto --capture=no --doctest-modules --failed-first" + +[tool.ruff] +line-length = 88 +include = ["*.py", "*.pyi", "**/pyproject.toml", "*.ipynb"] +target-version = "py38" diff --git a/pytest.ini b/pytest.ini deleted file mode 100644 index c2c366b2..00000000 --- a/pytest.ini +++ /dev/null @@ -1,5 +0,0 @@ -[pytest] -add_opts = -s --doctest-modules -testpaths = tests -#log_format = %(asctime)s.%(msecs)03d %(levelname)s %(message)s -#log_date_format = %M:%S.%f diff --git a/scripts/setup_dev_env.py b/scripts/setup_dev_env.py new file mode 100644 index 00000000..b1df230f --- /dev/null +++ b/scripts/setup_dev_env.py @@ -0,0 +1,25 @@ +# untested draft! +import subprocess +from os import chdir +from pathlib import Path + + +def run(prompt: str): + _ = subprocess.run(prompt, check=True, capture_output=True) + + +if __name__ == "__main__": + repo_dir = Path(__file__).parent.parent.parent + cur_dir = Path().resolve() + chdir(str(repo_dir)) + try: + run("mamba env create --file core-bioimage-io/dev/env.yaml") + run( + "pip install --no-deps --config-settings editable_mode=compat -e spec-bioimage-io" + ) + run( + "pip install --no-deps --config-settings editable_mode=compat -e core-bioimage-io" + ) + except Exception: + chdir(cur_dir) + raise diff --git a/scripts/show_diff.py b/scripts/show_diff.py new file mode 100644 index 00000000..1b0163bb --- /dev/null +++ b/scripts/show_diff.py @@ -0,0 +1,26 @@ +import subprocess +from pathlib import Path +from tempfile import TemporaryDirectory + +import pooch + +from bioimageio.core import load_description, save_bioimageio_yaml_only + +if __name__ == "__main__": + rdf_source = 
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/unet2d_nuclei_broad/v0_4_9.bioimageio.yaml" + + local_source = Path(pooch.retrieve(rdf_source, None)) # type: ignore + model_as_is = load_description(rdf_source, format_version="discover") + model_latest = load_description(rdf_source, format_version="latest") + print(model_latest.validation_summary) + + with TemporaryDirectory() as tmp: + as_is = Path(tmp) / "as_is.bioimageio.yaml" + + save_bioimageio_yaml_only( + model_as_is, file=as_is + ) # write out as is to avoid sorting diff + latest = Path(tmp) / "latest.bioimageio.yaml" + save_bioimageio_yaml_only(model_latest, file=latest) + + _ = subprocess.run(f"git diff --no-index --ignore-all-space {as_is} {latest}") diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index d8020326..00000000 --- a/setup.cfg +++ /dev/null @@ -1,10 +0,0 @@ -[tool:isort] -line_length = 120 -multi_line_output = 3 -include_trailing_comma = true - -[flake8] -max-line-length = 120 - -[pylint] -max-line-length = 120 diff --git a/setup.py b/setup.py index 3fc27a76..72a4bce4 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,7 @@ VERSION = json.loads(VERSION_FILE.read_text())["version"] -setup( +_ = setup( name="bioimageio.core", version=VERSION, description="Python functionality for the bioimage model zoo", @@ -18,33 +18,56 @@ long_description_content_type="text/markdown", url="https://github.com/bioimage-io/core-bioimage-io-python", author="Bioimage Team", - classifiers=[ # Optional + classifiers=[ "Development Status :: 3 - Alpha", "Intended Audience :: Developers", - "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", ], - packages=find_namespace_packages(exclude=["tests"]), # Required + 
packages=find_namespace_packages(exclude=["tests"]), install_requires=[ - "bioimageio.spec==0.4.9.*", + "bioimageio.spec==0.5.2.*", + "fire", "imageio>=2.5", + "loguru", "numpy", - "ruamel.yaml", + "pydantic-settings", + "pydantic", + "python-dotenv", + "ruyaml", "tqdm", + "typing-extensions", "xarray", - "tifffile", ], include_package_data=True, extras_require={ - "test": ["pytest", "black", "mypy"], - "dev": ["pre-commit"], - "pytorch": ["torch>=1.6", "torchvision"], - "tensorflow": ["tensorflow"], + "pytorch": ["torch>=1.6", "torchvision", "keras>=3.0"], + "tensorflow": ["tensorflow", "keras>=2.15"], "onnx": ["onnxruntime"], + "dev": [ + "black", + # "crick", # currently requires python<=3.9 + "filelock", + "jupyter", + "jupyter-black", + "keras>=3.0", + "onnxruntime", + "packaging>=17.0", + "pre-commit", + "psutil", # parallel pytest with 'pytest -n auto' + "pyright", + "pytest-xdist", # parallel pytest + "pytest", + "torch>=1.6", + "torchvision", + ], }, - project_urls={ # Optional + project_urls={ "Bug Reports": "https://github.com/bioimage-io/core-bioimage-io-python/issues", "Source": "https://github.com/bioimage-io/core-bioimage-io-python", }, - entry_points={"console_scripts": ["bioimageio = bioimageio.core.__main__:app"]}, + entry_points={"console_scripts": ["bioimageio = bioimageio.core.__main__:main"]}, ) diff --git a/tests/build_spec/test_add_weights.py b/tests/build_spec/test_add_weights.py deleted file mode 100644 index 2f8300b0..00000000 --- a/tests/build_spec/test_add_weights.py +++ /dev/null @@ -1,49 +0,0 @@ -import os -from bioimageio.core import export_resource_package, load_raw_resource_description, load_resource_description -from bioimageio.core.resource_tests import test_model as _test_model - - -def _test_add_weights(model, tmp_path, base_weights, added_weights, **kwargs): - from bioimageio.core.build_spec import add_weights - - rdf = load_raw_resource_description(model) - assert base_weights in rdf.weights - assert added_weights in 
rdf.weights - - weight_path = load_resource_description(model).weights[added_weights].source - assert weight_path.exists() - - drop_weights = set(rdf.weights.keys()) - {base_weights} - for drop in drop_weights: - rdf.weights.pop(drop) - assert tuple(rdf.weights.keys()) == (base_weights,) - - in_path = tmp_path / "model1.zip" - export_resource_package(rdf, output_path=in_path) - - out_path = tmp_path / "model2.zip" - add_weights(in_path, weight_path, weight_type=added_weights, output_path=out_path, **kwargs) - - assert out_path.exists() - new_rdf = load_resource_description(out_path) - assert set(new_rdf.weights.keys()) == {base_weights, added_weights} - for weight in new_rdf.weights.values(): - assert weight.source.exists() - - test_res = _test_model(out_path, added_weights) - failed = [s for s in test_res if s["status"] != "passed"] - assert not failed, failed - test_res = _test_model(out_path) - failed = [s for s in test_res if s["status"] != "passed"] - assert not failed, failed - - # make sure the weights were cleaned from the cwd - assert not os.path.exists(os.path.split(weight_path)[1]) - - -def test_add_torchscript(unet2d_nuclei_broad_model, tmp_path): - _test_add_weights(unet2d_nuclei_broad_model, tmp_path, "pytorch_state_dict", "torchscript") - - -def test_add_onnx(unet2d_nuclei_broad_model, tmp_path): - _test_add_weights(unet2d_nuclei_broad_model, tmp_path, "pytorch_state_dict", "onnx", opset_version=12) diff --git a/tests/build_spec/test_build_spec.py b/tests/build_spec/test_build_spec.py deleted file mode 100644 index 8edd9436..00000000 --- a/tests/build_spec/test_build_spec.py +++ /dev/null @@ -1,238 +0,0 @@ -from typing import Optional - -from marshmallow import missing - -import bioimageio.spec as spec -from bioimageio.core import load_raw_resource_description, load_resource_description -from bioimageio.core.resource_io import nodes -from bioimageio.core.resource_io.utils import resolve_source -from bioimageio.core.resource_tests import test_model as 
_test_model - -try: - import tensorflow -except ImportError: - tf_version = None -else: - tf_version: Optional[str] = ".".join(tensorflow.__version__.split(".")[:2]) - - -def _test_build_spec( - spec_path, - out_path, - weight_type, - tensorflow_version=None, - opset_version=None, - use_implicit_output_shape=False, - add_deepimagej_config=False, - use_original_covers=False, - training_data=None, - parent=None, -): - from bioimageio.core.build_spec import build_model - - model_spec = load_raw_resource_description(spec_path, update_to_format="latest") - root = model_spec.root_path - assert isinstance(model_spec, spec.model.raw_nodes.Model) - weight_source = model_spec.weights[weight_type].source - - cite = [] - for entry in model_spec.cite: - entry_ = {"text": entry.text} - has_url = entry.url is not missing - has_doi = entry.doi is not missing - assert has_url != has_doi - if has_doi: - entry_["doi"] = entry.doi - else: - entry_["url"] = entry.url - cite.append(entry_) - - weight_spec = model_spec.weights[weight_type] - dep_file = None if weight_spec.dependencies is missing else resolve_source(weight_spec.dependencies.file, root) - if weight_type == "pytorch_state_dict": - model_kwargs = None if weight_spec.kwargs is missing else weight_spec.kwargs - architecture = str(weight_spec.architecture) - weight_type_ = None # the weight type can be auto-detected - elif weight_type == "torchscript": - architecture = None - model_kwargs = None - weight_type_ = "torchscript" # the weight type CANNOT be auto-detected - else: - architecture = None - model_kwargs = None - weight_type_ = None # the weight type can be auto-detected - - authors = [{"name": auth.name, "affiliation": auth.affiliation} for auth in model_spec.authors] - - input_axes = [input_.axes for input_ in model_spec.inputs] - output_axes = [output.axes for output in model_spec.outputs] - preprocessing = [ - None - if input_.preprocessing is missing - else [{"name": preproc.name, "kwargs": preproc.kwargs} for 
preproc in input_.preprocessing] - for input_ in model_spec.inputs - ] - postprocessing = [ - None - if output.postprocessing is missing - else [{"name": preproc.name, "kwargs": preproc.kwargs} for preproc in output.preprocessing] - for output in model_spec.outputs - ] - - kwargs = dict( - weight_uri=weight_source, - test_inputs=resolve_source(model_spec.test_inputs, root), - test_outputs=resolve_source(model_spec.test_outputs, root), - name=model_spec.name, - description=model_spec.description, - authors=authors, - tags=model_spec.tags, - license=model_spec.license, - documentation=model_spec.documentation, - dependencies=dep_file, - cite=cite, - root=model_spec.root_path, - weight_type=weight_type_, - input_axes=input_axes, - output_axes=output_axes, - preprocessing=preprocessing, - postprocessing=postprocessing, - output_path=out_path, - add_deepimagej_config=add_deepimagej_config, - maintainers=[{"github_user": "jane_doe"}], - input_names=[inp.name for inp in model_spec.inputs], - output_names=[out.name for out in model_spec.outputs], - ) - if architecture is not None: - kwargs["architecture"] = architecture - if model_kwargs is not None: - kwargs["model_kwargs"] = model_kwargs - if tensorflow_version is not None: - kwargs["tensorflow_version"] = tensorflow_version - if opset_version is not None: - kwargs["opset_version"] = opset_version - if use_implicit_output_shape: - kwargs["input_names"] = ["input"] - kwargs["output_reference"] = ["input"] - kwargs["output_scale"] = [[1.0, 1.0, 1.0, 1.0]] - kwargs["output_offset"] = [[0.0, 0.0, 0.0, 0.0]] - if add_deepimagej_config: - kwargs["pixel_sizes"] = [{"x": 5.0, "y": 5.0}] - if use_original_covers: - kwargs["covers"] = resolve_source(model_spec.covers, root) - if training_data is not None: - kwargs["training_data"] = training_data - if parent is not None: - kwargs["parent"] = parent - - build_model(**kwargs) - assert out_path.exists() - loaded_model = load_resource_description(out_path) - assert 
isinstance(loaded_model, nodes.Model) - if add_deepimagej_config: - loaded_config = loaded_model.config - assert "deepimagej" in loaded_config - - if loaded_model.sample_inputs is not missing: - for sample in loaded_model.sample_inputs: - assert sample.exists() - if loaded_model.sample_outputs is not missing: - for sample in loaded_model.sample_outputs: - assert sample.exists() - - assert loaded_model.maintainers[0].github_user == "jane_doe" - - attachments = loaded_model.attachments - if attachments is not missing and attachments.files is not missing: - for attached_file in attachments.files: - assert attached_file.exists() - - # make sure there is one attachment per pre/post-processing - if add_deepimagej_config: - preproc, postproc = preprocessing[0], postprocessing[0] - n_processing = 0 - if preproc is not None: - n_processing += len(preproc) - if postproc is not None: - n_processing += len(postproc) - if n_processing > 0: - assert attachments.files is not missing - assert n_processing == len(attachments.files) - - # test inference for the model to ensure that the weights were written correctly - test_res = _test_model(out_path) - assert all([s["status"] == "passed" for s in test_res]) - - -def test_build_spec_pytorch(any_torch_model, tmp_path): - _test_build_spec(any_torch_model, tmp_path / "model.zip", "pytorch_state_dict") - - -def test_build_spec_implicit_output_shape(unet2d_nuclei_broad_model, tmp_path): - _test_build_spec( - unet2d_nuclei_broad_model, tmp_path / "model.zip", "pytorch_state_dict", use_implicit_output_shape=True - ) - - -def test_build_spec_torchscript(any_torchscript_model, tmp_path): - _test_build_spec(any_torchscript_model, tmp_path / "model.zip", "torchscript") - - -def test_build_spec_onnx(any_onnx_model, tmp_path): - _test_build_spec(any_onnx_model, tmp_path / "model.zip", "onnx", opset_version=12) - - -def test_build_spec_keras(any_keras_model, tmp_path): - _test_build_spec( - any_keras_model, tmp_path / "model.zip", "keras_hdf5", 
tensorflow_version=tf_version - ) # todo: keras for tf 2?? - - -def test_build_spec_tf(any_tensorflow_model, tmp_path): - _test_build_spec( - any_tensorflow_model, tmp_path / "model.zip", "tensorflow_saved_model_bundle", tensorflow_version=tf_version - ) # check tf version - - -def test_build_spec_tfjs(any_tensorflow_js_model, tmp_path): - _test_build_spec(any_tensorflow_js_model, tmp_path / "model.zip", "tensorflow_js", tensorflow_version=tf_version) - - -def test_build_spec_deepimagej(unet2d_nuclei_broad_model, tmp_path): - _test_build_spec(unet2d_nuclei_broad_model, tmp_path / "model.zip", "torchscript", add_deepimagej_config=True) - - -def test_build_spec_training_data1(unet2d_nuclei_broad_model, tmp_path): - training_data = {"id": "ilastik/stradist_dsb_training_data"} - _test_build_spec(unet2d_nuclei_broad_model, tmp_path / "model.zip", "torchscript", training_data=training_data) - - -def test_build_spec_training_data2(unet2d_nuclei_broad_model, tmp_path): - training_data = { - "type": "dataset", - "name": "nucleus-training-data", - "description": "stardist nucleus training data", - "source": "https://github.com/stardist/stardist/releases/download/0.1.0/dsb2018.zip", - } - _test_build_spec(unet2d_nuclei_broad_model, tmp_path / "model.zip", "torchscript", training_data=training_data) - - -def test_build_spec_parent1(unet2d_nuclei_broad_model, tmp_path): - parent = {"uri": "https://doi.org/10.5281/zenodo.5764892"} - _test_build_spec(unet2d_nuclei_broad_model, tmp_path / "model.zip", "torchscript", parent=parent) - - -def test_build_spec_parent2(unet2d_nuclei_broad_model, tmp_path): - parent = {"id": "10.5281/zenodo.5764892"} - _test_build_spec(unet2d_nuclei_broad_model, tmp_path / "model.zip", "torchscript", parent=parent) - - -def test_build_spec_deepimagej_keras(unet2d_keras, tmp_path): - _test_build_spec( - unet2d_keras, tmp_path / "model.zip", "keras_hdf5", add_deepimagej_config=True, tensorflow_version=tf_version - ) - - -# test with original covers -def 
test_build_spec_with_original_covers(unet2d_nuclei_broad_model, tmp_path): - _test_build_spec(unet2d_nuclei_broad_model, tmp_path / "model.zip", "torchscript", use_original_covers=True) diff --git a/tests/conftest.py b/tests/conftest.py index 10dfe53a..c4fa5ff7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,85 +1,15 @@ -import logging -import os +from __future__ import annotations + import subprocess import warnings +from itertools import chain +from typing import Dict, List -import pytest +from loguru import logger +from pytest import FixtureRequest, fixture -os.environ["BIOIMAGEIO_COUNT_RDF_DOWNLOADS"] = "false" # disable tracking before bioimageio imports -from bioimageio.core import export_resource_package from bioimageio.spec import __version__ as bioimageio_spec_version - -logger = logging.getLogger(__name__) -warnings.warn(f"testing with bioimageio.spec {bioimageio_spec_version}") - -# test models for various frameworks -torch_models = [ - "unet2d_fixed_shape", - "unet2d_multi_tensor", - "unet2d_nuclei_broad_model", - "unet2d_diff_output_shape", - "shape_change", -] -torchscript_models = ["unet2d_multi_tensor", "unet2d_nuclei_broad_model"] -onnx_models = ["unet2d_multi_tensor", "unet2d_nuclei_broad_model", "hpa_densenet"] -tensorflow1_models = ["stardist"] -tensorflow2_models = ["unet2d_keras_tf2"] -keras_tf1_models = ["unet2d_keras"] -keras_tf2_models = ["unet2d_keras_tf2"] -tensorflow_js_models = [] - - -model_sources = { - "unet2d_keras": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/" - "unet2d_keras_tf/rdf.yaml" - ), - "unet2d_keras_tf2": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/" - "unet2d_keras_tf2/rdf.yaml" - ), - "unet2d_nuclei_broad_model": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/" - "unet2d_nuclei_broad/rdf.yaml" - ), - "unet2d_expand_output_shape": ( - 
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/" - "unet2d_nuclei_broad/rdf_expand_output_shape.yaml" - ), - "unet2d_fixed_shape": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/" - "unet2d_fixed_shape/rdf.yaml" - ), - "unet2d_multi_tensor": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/" - "unet2d_multi_tensor/rdf.yaml" - ), - "unet2d_diff_output_shape": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/" - "unet2d_diff_output_shape/rdf.yaml" - ), - "hpa_densenet": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/hpa-densenet/rdf.yaml" - ), - "stardist": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models" - "/stardist_example_model/rdf.yaml" - ), - "stardist_wrong_shape": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/" - "stardist_example_model/rdf_wrong_shape.yaml" - ), - "stardist_wrong_shape2": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/" - "stardist_example_model/rdf_wrong_shape2.yaml" - ), - "shape_change": ( - "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/" - "upsample_test_model/rdf.yaml" - ), -} - try: import torch @@ -91,58 +21,158 @@ skip_torch = torch is None try: - import onnxruntime + import onnxruntime # type: ignore except ImportError: onnxruntime = None skip_onnx = onnxruntime is None try: - import tensorflow + import tensorflow # type: ignore - tf_major_version = int(tensorflow.__version__.split(".")[0]) + tf_major_version = int(tensorflow.__version__.split(".")[0]) # type: ignore except ImportError: tensorflow = None tf_major_version = None -skip_tensorflow = tensorflow is None -skip_tensorflow_js = True # TODO: add a tensorflow_js example 
model - -# load all model packages we need for testing -load_model_packages = set() -if not skip_torch: - load_model_packages |= set(torch_models + torchscript_models) -if not skip_onnx: - load_model_packages |= set(onnx_models) - -if not skip_tensorflow: - load_model_packages |= set(tensorflow_js_models) - if tf_major_version == 1: - load_model_packages |= set(keras_tf1_models) - load_model_packages |= set(tensorflow1_models) - load_model_packages.add("stardist_wrong_shape") - load_model_packages.add("stardist_wrong_shape2") - elif tf_major_version == 2: - load_model_packages |= set(keras_tf2_models) - load_model_packages |= set(tensorflow2_models) +try: + import keras # type: ignore +except ImportError: + keras = None +skip_tensorflow = tensorflow is None -def pytest_configure(): - # explicit skip flags needed for some tests - pytest.skip_torch = skip_torch - pytest.skip_onnx = skip_onnx +warnings.warn(f"testing with bioimageio.spec {bioimageio_spec_version}") - # load all model packages used in tests - pytest.model_packages = {name: export_resource_package(model_sources[name]) for name in load_model_packages} +# TODO: use models from new collection on S3 +MODEL_SOURCES: Dict[str, str] = { + "hpa_densenet": ( + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/hpa-densenet/rdf.yaml" + ), + "stardist": ( + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models" + "/stardist_example_model/v0_4.bioimageio.yaml" + ), + "shape_change": ( + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" + "upsample_test_model/v0_4.bioimageio.yaml" + ), + "stardist_wrong_shape": ( + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" + "stardist_example_model/rdf_wrong_shape.yaml" + ), + "stardist_wrong_shape2": ( + 
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" + "stardist_example_model/rdf_wrong_shape2_v0_4.yaml" + ), + "unet2d_diff_output_shape": ( + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" + "unet2d_diff_output_shape/v0_4.bioimageio.yaml" + ), + "unet2d_expand_output_shape": ( + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" + "unet2d_nuclei_broad/expand_output_shape_v0_4.bioimageio.yaml" + ), + "unet2d_fixed_shape": ( + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" + "unet2d_fixed_shape/v0_4.bioimageio.yaml" + ), + "unet2d_keras_tf2": ( + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" + "unet2d_keras_tf2/v0_4.bioimageio.yaml" + ), + "unet2d_keras": ( + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" + "unet2d_keras_tf/v0_4.bioimageio.yaml" + ), + "unet2d_multi_tensor": ( + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" + "unet2d_multi_tensor/v0_4.bioimageio.yaml" + ), + "unet2d_nuclei_broad_model": ( + "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_descriptions/models/" + "unet2d_nuclei_broad/bioimageio.yaml" + ), +} - pytest.mamba_cmd = "micromamba" +# test models for various frameworks +TORCH_MODELS = ( + [] + if torch is None + else [ + "shape_change", + "unet2d_diff_output_shape", + "unet2d_expand_output_shape", + "unet2d_fixed_shape", + "unet2d_multi_tensor", + "unet2d_nuclei_broad_model", + ] +) +TORCHSCRIPT_MODELS = ( + [] + if torch is None + else [ + "unet2d_multi_tensor", + "unet2d_nuclei_broad_model", + ] +) +ONNX_MODELS = ( + [] + if onnxruntime is None + else [ + "hpa_densenet", + "unet2d_multi_tensor", + "unet2d_nuclei_broad_model", + ] +) +TENSORFLOW_MODELS = ( + 
[] + if tensorflow is None + else ( + [ + "hpa_densenet", + "stardist", + ] + if tf_major_version == 1 + else [ + "unet2d_keras_tf2", + ] + ) +) +KERAS_MODELS = ( + [] + if keras is None + else ["unet2d_keras"] if tf_major_version == 1 else ["unet2d_keras_tf2"] +) +TENSORFLOW_JS_MODELS: List[str] = [] # TODO: add a tensorflow_js example model + +ALL_MODELS = sorted( + { + m + for m in chain( + TORCH_MODELS, + TORCHSCRIPT_MODELS, + ONNX_MODELS, + TENSORFLOW_MODELS, + KERAS_MODELS, + TENSORFLOW_JS_MODELS, + ) + } +) + + +@fixture(scope="session") +def mamba_cmd(): + mamba_cmd = "micromamba" try: - subprocess.run(["which", pytest.mamba_cmd], check=True) + _ = subprocess.run(["which", mamba_cmd], check=True) except (subprocess.CalledProcessError, FileNotFoundError): - pytest.mamba_cmd = "mamba" + mamba_cmd = "mamba" try: - subprocess.run(["which", pytest.mamba_cmd], check=True) + _ = subprocess.run(["which", mamba_cmd], check=True) except (subprocess.CalledProcessError, FileNotFoundError): - pytest.mamba_cmd = None + mamba_cmd = None + + return mamba_cmd # @@ -150,42 +180,41 @@ def pytest_configure(): # -@pytest.fixture(params=[] if skip_torch else torch_models) -def any_torch_model(request): - return pytest.model_packages[request.param] +@fixture(params=TORCH_MODELS) +def any_torch_model(request: FixtureRequest): + return MODEL_SOURCES[request.param] -@pytest.fixture(params=[] if skip_torch else torchscript_models) -def any_torchscript_model(request): - return pytest.model_packages[request.param] +@fixture(params=TORCHSCRIPT_MODELS) +def any_torchscript_model(request: FixtureRequest): + return MODEL_SOURCES[request.param] -@pytest.fixture(params=[] if skip_onnx else onnx_models) -def any_onnx_model(request): - return pytest.model_packages[request.param] +@fixture(params=ONNX_MODELS) +def any_onnx_model(request: FixtureRequest): + return MODEL_SOURCES[request.param] -@pytest.fixture(params=[] if skip_tensorflow else tensorflow1_models if tf_major_version == 1 else 
tensorflow2_models) -def any_tensorflow_model(request): - return pytest.model_packages[request.param] +@fixture(params=TENSORFLOW_MODELS) +def any_tensorflow_model(request: FixtureRequest): + return MODEL_SOURCES[request.param] -@pytest.fixture(params=[] if skip_tensorflow else keras_tf1_models if tf_major_version == 1 else keras_tf2_models) -def any_keras_model(request): - return pytest.model_packages[request.param] +@fixture(params=KERAS_MODELS) +def any_keras_model(request: FixtureRequest): + return MODEL_SOURCES[request.param] -@pytest.fixture(params=[] if skip_tensorflow_js else tensorflow_js_models) -def any_tensorflow_js_model(request): - return pytest.model_packages[request.param] +@fixture(params=TENSORFLOW_JS_MODELS) +def any_tensorflow_js_model(request: FixtureRequest): + return MODEL_SOURCES[request.param] # fixture to test with all models that should run in the current environment -# we exclude stardist_wrong_shape here because it is not a valid model -# and included only to test that validation for this model fails -@pytest.fixture(params=load_model_packages - {"stardist_wrong_shape", "stardist_wrong_shape2"}) -def any_model(request): - return pytest.model_packages[request.param] +# we exclude any 'wrong' model here +@fixture(params=sorted({m for m in ALL_MODELS if "wrong" not in m})) +def any_model(request: FixtureRequest): + return MODEL_SOURCES[request.param] # TODO it would be nice to just generate fixtures for all the individual models dynamically @@ -195,64 +224,78 @@ def any_model(request): # -@pytest.fixture(params=[] if skip_torch else ["unet2d_nuclei_broad_model", "unet2d_fixed_shape"]) -def unet2d_fixed_shape_or_not(request): - return pytest.model_packages[request.param] +@fixture( + params=[] if skip_torch else ["unet2d_nuclei_broad_model", "unet2d_fixed_shape"] +) +def unet2d_fixed_shape_or_not(request: FixtureRequest): + return MODEL_SOURCES[request.param] -@pytest.fixture(params=[] if skip_onnx or skip_torch else 
["unet2d_nuclei_broad_model", "unet2d_multi_tensor"]) -def convert_to_onnx(request): - return pytest.model_packages[request.param] +@fixture( + params=( + [] + if skip_onnx or skip_torch + else ["unet2d_nuclei_broad_model", "unet2d_multi_tensor"] + ) +) +def convert_to_onnx(request: FixtureRequest): + return MODEL_SOURCES[request.param] -@pytest.fixture(params=[] if skip_tensorflow else ["unet2d_keras" if tf_major_version == 1 else "unet2d_keras_tf2"]) -def unet2d_keras(request): - return pytest.model_packages[request.param] +@fixture( + params=( + [] + if tf_major_version is None + else ["unet2d_keras"] if tf_major_version == 1 else ["unet2d_keras_tf2"] + ) +) +def unet2d_keras(request: FixtureRequest): + return MODEL_SOURCES[request.param] # written as model group to automatically skip on missing torch -@pytest.fixture(params=[] if skip_torch else ["unet2d_nuclei_broad_model"]) -def unet2d_nuclei_broad_model(request): - return pytest.model_packages[request.param] +@fixture(params=[] if skip_torch else ["unet2d_nuclei_broad_model"]) +def unet2d_nuclei_broad_model(request: FixtureRequest): + return MODEL_SOURCES[request.param] # written as model group to automatically skip on missing torch -@pytest.fixture(params=[] if skip_torch else ["unet2d_diff_output_shape"]) -def unet2d_diff_output_shape(request): - return pytest.model_packages[request.param] +@fixture(params=[] if skip_torch else ["unet2d_diff_output_shape"]) +def unet2d_diff_output_shape(request: FixtureRequest): + return MODEL_SOURCES[request.param] # written as model group to automatically skip on missing torch -@pytest.fixture(params=[] if skip_torch else ["unet2d_expand_output_shape"]) -def unet2d_expand_output_shape(request): - return pytest.model_packages[request.param] +@fixture(params=[] if skip_torch else ["unet2d_expand_output_shape"]) +def unet2d_expand_output_shape(request: FixtureRequest): + return MODEL_SOURCES[request.param] # written as model group to automatically skip on missing torch 
-@pytest.fixture(params=[] if skip_torch else ["unet2d_fixed_shape"]) -def unet2d_fixed_shape(request): - return pytest.model_packages[request.param] +@fixture(params=[] if skip_torch else ["unet2d_fixed_shape"]) +def unet2d_fixed_shape(request: FixtureRequest): + return MODEL_SOURCES[request.param] # written as model group to automatically skip on missing torch -@pytest.fixture(params=[] if skip_torch else ["shape_change"]) -def shape_change_model(request): - return pytest.model_packages[request.param] +@fixture(params=[] if skip_torch else ["shape_change"]) +def shape_change_model(request: FixtureRequest): + return MODEL_SOURCES[request.param] # written as model group to automatically skip on missing tensorflow 1 -@pytest.fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist_wrong_shape"]) -def stardist_wrong_shape(request): - return pytest.model_packages[request.param] +@fixture(params=["stardist_wrong_shape"] if tf_major_version == 1 else []) +def stardist_wrong_shape(request: FixtureRequest): + return MODEL_SOURCES[request.param] # written as model group to automatically skip on missing tensorflow 1 -@pytest.fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist_wrong_shape2"]) -def stardist_wrong_shape2(request): - return pytest.model_packages[request.param] +@fixture(params=["stardist_wrong_shape2"] if tf_major_version == 1 else []) +def stardist_wrong_shape2(request: FixtureRequest): + return MODEL_SOURCES[request.param] # written as model group to automatically skip on missing tensorflow 1 -@pytest.fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist"]) -def stardist(request): - return pytest.model_packages[request.param] +@fixture(params=["stardist"] if tf_major_version == 1 else []) +def stardist(request: FixtureRequest): + return MODEL_SOURCES[request.param] diff --git a/tests/prediction_pipeline/test_combined_processing.py b/tests/prediction_pipeline/test_combined_processing.py 
deleted file mode 100644 index 744e236c..00000000 --- a/tests/prediction_pipeline/test_combined_processing.py +++ /dev/null @@ -1,34 +0,0 @@ -import numpy as np -import xarray as xr - -from bioimageio.core.resource_io import nodes - - -def test_postprocessing_dtype(): - from bioimageio.core.prediction_pipeline._combined_processing import CombinedProcessing - - shape = [3, 32, 32] - axes = ("c", "y", "x") - np_data = np.random.rand(*shape) - data = xr.DataArray(np_data, dims=axes) - - threshold = 0.5 - exp = xr.DataArray(np_data > threshold, dims=axes) - - for dtype in ("float32", "float64", "uint8", "uint16"): - outputs = [ - nodes.OutputTensor( - "out1", - data_type=dtype, - axes=axes, - shape=shape, - postprocessing=[nodes.Postprocessing("binarize", dict(threshold=threshold))], - ) - ] - com_proc = CombinedProcessing.from_tensor_specs(outputs) - - sample = {"out1": data} - com_proc.apply(sample, {}) - res = sample["out1"] - assert np.dtype(res.dtype) == np.dtype(dtype) - xr.testing.assert_allclose(res, exp.astype(dtype)) diff --git a/tests/prediction_pipeline/test_device_management.py b/tests/prediction_pipeline/test_device_management.py deleted file mode 100644 index 98ab13c2..00000000 --- a/tests/prediction_pipeline/test_device_management.py +++ /dev/null @@ -1,79 +0,0 @@ -import numpy as np -import pytest -import xarray as xr -from numpy.testing import assert_array_almost_equal - -from bioimageio.core import load_resource_description -from bioimageio.core.resource_io.nodes import Model -from bioimageio.core.utils import skip_on - - -class TooFewDevicesException(Exception): - pass - - -def _test_device_management(model_package, weight_format): - import torch - - if torch.cuda.device_count() == 0: - raise TooFewDevicesException("Need at least one cuda device for this test") - - from bioimageio.core.prediction_pipeline import create_prediction_pipeline - - bio_model = load_resource_description(model_package) - assert isinstance(bio_model, Model) - pred_pipe = 
create_prediction_pipeline(bioimageio_model=bio_model, weight_format=weight_format, devices=["cuda:0"]) - - inputs = [ - xr.DataArray(np.load(str(test_tensor)), dims=tuple(spec.axes)) - for test_tensor, spec in zip(bio_model.test_inputs, bio_model.inputs) - ] - with pred_pipe as pp: - outputs = pp.forward(*inputs) - - assert isinstance(outputs, list) - - expected_outputs = [ - xr.DataArray(np.load(str(test_tensor)), dims=tuple(spec.axes)) - for test_tensor, spec in zip(bio_model.test_outputs, bio_model.outputs) - ] - - assert len(outputs) == len(expected_outputs) - for out, exp in zip(outputs, expected_outputs): - assert_array_almost_equal(out, exp, decimal=4) - - # repeat inference with context manager to test load/unload/load/forward - with pred_pipe as pp: - outputs = pp.forward(*inputs) - - assert len(outputs) == len(expected_outputs) - for out, exp in zip(outputs, expected_outputs): - assert_array_almost_equal(out, exp, decimal=4) - - -@skip_on(TooFewDevicesException, reason="Too few devices") -def test_device_management_torch(any_torch_model): - _test_device_management(any_torch_model, "pytorch_state_dict") - - -@skip_on(TooFewDevicesException, reason="Too few devices") -def test_device_management_torchscript(any_torchscript_model): - _test_device_management(any_torchscript_model, "torchscript") - - -@pytest.mark.skipif(pytest.skip_torch, reason="requires torch for device discovery") -@skip_on(TooFewDevicesException, reason="Too few devices") -def test_device_management_onnx(any_onnx_model): - _test_device_management(any_onnx_model, "onnx") - - -@pytest.mark.skipif(pytest.skip_torch, reason="requires torch for device discovery") -@skip_on(TooFewDevicesException, reason="Too few devices") -def test_device_management_tensorflow(any_tensorflow_model): - _test_device_management(any_tensorflow_model, "tensorflow_saved_model_bundle") - - -@pytest.mark.skipif(pytest.skip_torch, reason="requires torch for device discovery") -@skip_on(TooFewDevicesException, 
reason="Too few devices") -def test_device_management_keras(any_keras_model): - _test_device_management(any_keras_model, "keras_hdf5") diff --git a/tests/prediction_pipeline/test_measures.py b/tests/prediction_pipeline/test_measures.py deleted file mode 100644 index 6eeaa7fd..00000000 --- a/tests/prediction_pipeline/test_measures.py +++ /dev/null @@ -1,103 +0,0 @@ -import dataclasses -from itertools import product - -import numpy as np -import numpy.testing -import pytest -import xarray as xr - -from bioimageio.core import statistical_measures -from bioimageio.core.prediction_pipeline._measure_groups import get_measure_groups -from bioimageio.core.prediction_pipeline._utils import PER_DATASET, PER_SAMPLE -from bioimageio.core.statistical_measures import Mean, Percentile, Std, Var - - -@pytest.mark.parametrize("name_axes", product(["mean", "var", "std"], [None, ("x", "y")])) -def test_individual_normal_measure(name_axes): - name, axes = name_axes - measure = getattr(statistical_measures, name.title())(axes=axes) - data = xr.DataArray(np.random.random((5, 6, 3)), dims=("x", "y", "c")) - - expected = getattr(data, name)(dim=axes) - actual = measure.compute(data) - xr.testing.assert_allclose(expected, actual) - - -@pytest.mark.parametrize("axes_n", product([None, ("x", "y")], [0, 10, 50, 100])) -def test_individual_percentile_measure(axes_n): - axes, n = axes_n - measure = statistical_measures.Percentile(axes=axes, n=n) - data = xr.DataArray(np.random.random((5, 6, 3)), dims=("x", "y", "c")) - - expected = data.quantile(q=n / 100, dim=axes) - actual = measure.compute(data) - xr.testing.assert_allclose(expected, actual) - - -@pytest.mark.parametrize( - "measures_mode", - product( - [ - {"t1": {Mean()}, "t2": {Mean(), Std()}}, - {"t1": {Mean(), Var(), Std()}, "t2": {Std(axes=("x", "y"))}}, - {"t1": {Mean(axes=("x", "y"))}, "t2": {Mean(), Std(axes=("x", "y"))}}, - { - "t1": {Percentile(n=10), Percentile(n=35), Percentile(n=10, axes=("x", "y"))}, - "t2": {Percentile(n=10, 
axes=("x", "y")), Percentile(n=35, axes=("x", "y")), Percentile(n=10)}, - }, - ], - [PER_SAMPLE, PER_DATASET], - ), -) -def test_measure_groups(measures_mode): - measures, mode = measures_mode - - def get_sample(): - return { - "t1": xr.DataArray(np.random.random((2, 500, 600, 3)), dims=("b", "x", "y", "c")), - "t2": xr.DataArray(np.random.random((1, 500, 600)), dims=("c", "x", "y")), - } - - sample = get_sample() - dataset_seq = [sample, get_sample()] - dataset_full = {tn: xr.concat([s[tn] for s in dataset_seq], dim="dataset") for tn in sample.keys()} - - # compute independently - expected = {} - for tn, ms in measures.items(): - for m in ms: - if mode == PER_SAMPLE: - expected[(tn, m)] = m.compute(sample[tn]) - elif mode == PER_DATASET: - if m.axes is None: - m_d = m - else: - m_d = dataclasses.replace(m, axes=("dataset",) + m.axes) - - expected[(tn, m)] = m_d.compute(dataset_full[tn]) - else: - raise NotImplementedError(mode) - - groups = get_measure_groups({mode: measures})[mode] - actual = {} - for g in groups: - if mode == PER_SAMPLE: - res = g.compute(sample) - elif mode == PER_DATASET: - for s in dataset_seq: - g.update_with_sample(s) - - res = g.finalize() - else: - raise NotImplementedError(mode) - - for tn, vs in res.items(): - for m, v in vs.items(): - actual[(tn, m)] = v - - # discard additionally computed measures by groups - actual = {k: v for k, v in actual.items() if k in expected} - - for k in expected.keys(): - assert k in actual - numpy.testing.assert_array_almost_equal(expected[k].data, actual[k].data, decimal=2) diff --git a/tests/prediction_pipeline/test_postprocessing.py b/tests/prediction_pipeline/test_postprocessing.py deleted file mode 100644 index 52c3e151..00000000 --- a/tests/prediction_pipeline/test_postprocessing.py +++ /dev/null @@ -1,70 +0,0 @@ -import numpy as np -import pytest -import xarray as xr - -from bioimageio.core.prediction_pipeline._measure_groups import compute_measures - - -def test_binarize(): - from 
bioimageio.core.prediction_pipeline._processing import Binarize - - shape = (3, 32, 32) - axes = ("c", "y", "x") - np_data = np.random.rand(*shape) - data = xr.DataArray(np_data, dims=axes) - - threshold = 0.5 - exp = xr.DataArray(np_data > threshold, dims=axes) - - binarize = Binarize("data_name", threshold=threshold) - res = binarize(data) - xr.testing.assert_allclose(res, exp) - - -@pytest.mark.parametrize("axes", [None, tuple("cy"), tuple("cyx"), tuple("x")]) -def test_scale_mean_variance(axes): - from bioimageio.core.prediction_pipeline._processing import ScaleMeanVariance - - shape = (3, 32, 46) - ipt_axes = ("c", "y", "x") - np_data = np.random.rand(*shape) - ipt_data = xr.DataArray(np_data, dims=ipt_axes) - ref_data = xr.DataArray((np_data * 2) + 3, dims=ipt_axes) - - scale_mean_variance = ScaleMeanVariance("data_name", reference_tensor="ref_name", axes=axes) - required = scale_mean_variance.get_required_measures() - computed = compute_measures(required, sample={"data_name": ipt_data, "ref_name": ref_data}) - scale_mean_variance.set_computed_measures(computed) - - res = scale_mean_variance(ipt_data) - xr.testing.assert_allclose(res, ref_data) - - -@pytest.mark.parametrize("axes", [None, tuple("cy"), tuple("y"), tuple("yx")]) -def test_scale_mean_variance_per_channel(axes): - from bioimageio.core.prediction_pipeline._processing import ScaleMeanVariance - - shape = (3, 32, 46) - ipt_axes = ("c", "y", "x") - np_data = np.random.rand(*shape) - ipt_data = xr.DataArray(np_data, dims=ipt_axes) - - # set different mean, std per channel - np_ref_data = np.stack([d * i + i for i, d in enumerate(np_data, start=2)]) - print(np_ref_data.shape) - ref_data = xr.DataArray(np_ref_data, dims=ipt_axes) - - scale_mean_variance = ScaleMeanVariance("data_name", reference_tensor="ref_name", axes=axes) - required = scale_mean_variance.get_required_measures() - computed = compute_measures(required, sample={"data_name": ipt_data, "ref_name": ref_data}) - 
scale_mean_variance.set_computed_measures(computed) - - res = scale_mean_variance(ipt_data) - - if axes is not None and "c" not in axes: - # mean,std per channel should match exactly - xr.testing.assert_allclose(res, ref_data) - else: - # mean,std across channels should not match - with pytest.raises(AssertionError): - xr.testing.assert_allclose(res, ref_data) diff --git a/tests/prediction_pipeline/test_prediction_pipeline.py b/tests/prediction_pipeline/test_prediction_pipeline.py deleted file mode 100644 index ac3c6a65..00000000 --- a/tests/prediction_pipeline/test_prediction_pipeline.py +++ /dev/null @@ -1,50 +0,0 @@ -import numpy as np -import xarray as xr -from numpy.testing import assert_array_almost_equal - -from bioimageio.core import load_resource_description -from bioimageio.core.resource_io.nodes import Model - - -def _test_prediction_pipeline(model_package, weight_format): - from bioimageio.core.prediction_pipeline import create_prediction_pipeline - - bio_model = load_resource_description(model_package) - assert isinstance(bio_model, Model) - pp = create_prediction_pipeline(bioimageio_model=bio_model, weight_format=weight_format) - - inputs = [ - xr.DataArray(np.load(str(test_tensor)), dims=tuple(spec.axes)) - for test_tensor, spec in zip(bio_model.test_inputs, bio_model.inputs) - ] - outputs = pp.forward(*inputs) - assert isinstance(outputs, list) - - expected_outputs = [ - xr.DataArray(np.load(str(test_tensor)), dims=tuple(spec.axes)) - for test_tensor, spec in zip(bio_model.test_outputs, bio_model.outputs) - ] - assert len(outputs) == len(expected_outputs) - - for out, exp in zip(outputs, expected_outputs): - assert_array_almost_equal(out, exp, decimal=4) - - -def test_prediction_pipeline_torch(any_torch_model): - _test_prediction_pipeline(any_torch_model, "pytorch_state_dict") - - -def test_prediction_pipeline_torchscript(any_torchscript_model): - _test_prediction_pipeline(any_torchscript_model, "torchscript") - - -def 
test_prediction_pipeline_onnx(any_onnx_model): - _test_prediction_pipeline(any_onnx_model, "onnx") - - -def test_prediction_pipeline_tensorflow(any_tensorflow_model): - _test_prediction_pipeline(any_tensorflow_model, "tensorflow_saved_model_bundle") - - -def test_prediction_pipeline_keras(any_keras_model): - _test_prediction_pipeline(any_keras_model, "keras_hdf5") diff --git a/tests/prediction_pipeline/test_preprocessing.py b/tests/prediction_pipeline/test_preprocessing.py deleted file mode 100644 index fb8efa06..00000000 --- a/tests/prediction_pipeline/test_preprocessing.py +++ /dev/null @@ -1,212 +0,0 @@ -import numpy as np -import xarray as xr - -from bioimageio.core.prediction_pipeline._measure_groups import compute_measures -from bioimageio.core.prediction_pipeline._utils import PER_SAMPLE - - -def test_scale_linear(): - from bioimageio.core.prediction_pipeline._processing import ScaleLinear - - preprocessing = ScaleLinear("data_name", offset=[1, 2, 42], gain=[1, 2, 3], axes="yx") - data = xr.DataArray(np.arange(6).reshape((1, 2, 3)), dims=("x", "y", "c")) - expected = xr.DataArray(np.array([[[1, 4, 48], [4, 10, 57]]]), dims=("x", "y", "c")) - result = preprocessing.apply(data) - xr.testing.assert_allclose(expected, result) - - -def test_scale_linear_no_channel(): - from bioimageio.core.prediction_pipeline._processing import ScaleLinear - - preprocessing = ScaleLinear("data_name", offset=1, gain=2, axes="yx") - data = xr.DataArray(np.arange(6).reshape(2, 3), dims=("x", "y")) - expected = xr.DataArray(np.array([[1, 3, 5], [7, 9, 11]]), dims=("x", "y")) - result = preprocessing.apply(data) - xr.testing.assert_allclose(expected, result) - - -def test_zero_mean_unit_variance_preprocessing(): - from bioimageio.core.prediction_pipeline._processing import ZeroMeanUnitVariance - - data = xr.DataArray(np.arange(9).reshape(3, 3), dims=("x", "y")) - - preprocessing = ZeroMeanUnitVariance("data_name", mode=PER_SAMPLE) - required = preprocessing.get_required_measures() - 
computed = compute_measures(required, sample={"data_name": data}) - preprocessing.set_computed_measures(computed) - - expected = xr.DataArray( - np.array( - [ - [-1.54919274, -1.16189455, -0.77459637], - [-0.38729818, 0.0, 0.38729818], - [0.77459637, 1.16189455, 1.54919274], - ] - ), - dims=("x", "y"), - ) - result = preprocessing(data) - xr.testing.assert_allclose(expected, result) - - -def test_zero_mean_unit_variance_preprocessing_fixed(): - from bioimageio.core.prediction_pipeline._processing import ZeroMeanUnitVariance - - preprocessing = ZeroMeanUnitVariance( - "data_name", mode="fixed", axes=["y"], mean=[1, 4, 7], std=[0.81650, 0.81650, 0.81650] - ) - data = xr.DataArray(np.arange(9).reshape((1, 1, 3, 3)), dims=("b", "c", "x", "y")) - expected = xr.DataArray( - np.array([[-1.224743, 0.0, 1.224743], [-1.224743, 0.0, 1.224743], [-1.224743, 0.0, 1.224743]])[None, None], - dims=("b", "c", "x", "y"), - ) - result = preprocessing(data) - xr.testing.assert_allclose(expected, result) - - -def test_zero_mean_unit_across_axes(): - from bioimageio.core.prediction_pipeline._processing import ZeroMeanUnitVariance - - data = xr.DataArray(np.arange(18).reshape((2, 3, 3)), dims=("c", "x", "y")) - - axes = ("x", "y") - preprocessing = ZeroMeanUnitVariance("data_name", axes=axes, mode=PER_SAMPLE) - required = preprocessing.get_required_measures() - computed = compute_measures(required, sample={"data_name": data}) - preprocessing.set_computed_measures(computed) - - expected = xr.DataArray( - np.array( - [ - [-1.54919274, -1.16189455, -0.77459637], - [-0.38729818, 0.0, 0.38729818], - [0.77459637, 1.16189455, 1.54919274], - ] - ), - dims=("x", "y"), - ) - result = preprocessing(data) - xr.testing.assert_allclose(expected, result[dict(c=0)]) - - -def test_zero_mean_unit_variance_fixed(): - from bioimageio.core.prediction_pipeline._processing import ZeroMeanUnitVariance - - np_data = np.arange(9).reshape(3, 3) - mean = np_data.mean() - std = np_data.mean() - eps = 1.0e-7 - 
preprocessing = ZeroMeanUnitVariance("data_name", mode="fixed", mean=mean, std=std, eps=eps) - - data = xr.DataArray(np_data, dims=("x", "y")) - expected = xr.DataArray((np_data - mean) / (std + eps), dims=("x", "y")) - result = preprocessing(data) - xr.testing.assert_allclose(expected, result) - - -def test_binarize(): - from bioimageio.core.prediction_pipeline._processing import Binarize - - preprocessing = Binarize("data_name", threshold=14) - data = xr.DataArray(np.arange(30).reshape((2, 3, 5)), dims=("x", "y", "c")) - expected = xr.zeros_like(data) - expected[{"x": slice(1, None)}] = 1 - result = preprocessing(data) - xr.testing.assert_allclose(expected, result) - - -def test_clip_preprocessing(): - from bioimageio.core.prediction_pipeline._processing import Clip - - preprocessing = Clip("data_name", min=3, max=5) - data = xr.DataArray(np.arange(9).reshape(3, 3), dims=("x", "y")) - expected = xr.DataArray(np.array([[3, 3, 3], [3, 4, 5], [5, 5, 5]]), dims=("x", "y")) - result = preprocessing(data) - xr.testing.assert_equal(expected, result) - - -def test_combination_of_preprocessing_steps_with_dims_specified(): - from bioimageio.core.prediction_pipeline._processing import ZeroMeanUnitVariance - - data = xr.DataArray(np.arange(18).reshape((2, 3, 3)), dims=("c", "x", "y")) - axes = ("x", "y") - preprocessing = ZeroMeanUnitVariance("data_name", axes=axes, mode=PER_SAMPLE) - required = preprocessing.get_required_measures() - computed = compute_measures(required, sample={"data_name": data}) - preprocessing.set_computed_measures(computed) - - expected = xr.DataArray( - np.array( - [ - [-1.54919274, -1.16189455, -0.77459637], - [-0.38729818, 0.0, 0.38729818], - [0.77459637, 1.16189455, 1.54919274], - ] - ), - dims=("x", "y"), - ) - - result = preprocessing(data) - xr.testing.assert_allclose(expected, result[dict(c=0)]) - - -def test_scale_range(): - from bioimageio.core.prediction_pipeline._processing import ScaleRange - - preprocessing = ScaleRange("data_name") - 
np_data = np.arange(9).reshape(3, 3).astype("float32") - data = xr.DataArray(np_data, dims=("x", "y")) - required = preprocessing.get_required_measures() - computed = compute_measures(required, sample={"data_name": data}) - preprocessing.set_computed_measures(computed) - - eps = 1.0e-6 - mi, ma = np_data.min(), np_data.max() - exp_data = (np_data - mi) / (ma - mi + eps) - expected = xr.DataArray(exp_data, dims=("x", "y")) - - result = preprocessing(data) - # NOTE xarray.testing.assert_allclose compares irrelavant properties here and fails although the result is correct - np.testing.assert_allclose(expected, result) - - -def test_scale_range_axes(): - from bioimageio.core.prediction_pipeline._processing import ScaleRange - - min_percentile = 1.0 - max_percentile = 99.0 - preprocessing = ScaleRange( - "data_name", axes=("x", "y"), min_percentile=min_percentile, max_percentile=max_percentile - ) - - np_data = np.arange(18).reshape((2, 3, 3)).astype("float32") - data = xr.DataArray(np_data, dims=("c", "x", "y")) - - required = preprocessing.get_required_measures() - computed = compute_measures(required, sample={"data_name": data}) - preprocessing.set_computed_measures(computed) - - eps = 1.0e-6 - p_low = np.percentile(np_data, min_percentile, axis=(1, 2), keepdims=True) - p_up = np.percentile(np_data, max_percentile, axis=(1, 2), keepdims=True) - exp_data = (np_data - p_low) / (p_up - p_low + eps) - expected = xr.DataArray(exp_data, dims=("c", "x", "y")) - - result = preprocessing(data) - # NOTE xarray.testing.assert_allclose compares irrelavant properties here and fails although the result is correct - np.testing.assert_allclose(expected, result) - - -def test_sigmoid(): - from bioimageio.core.prediction_pipeline._processing import Sigmoid - - shape = (3, 32, 32) - axes = ("c", "y", "x") - np_data = np.random.rand(*shape) - data = xr.DataArray(np_data, dims=axes) - - sigmoid = Sigmoid("data_name") - res = sigmoid(data) - - exp = xr.DataArray(1.0 / (1 + 
np.exp(-np_data)), dims=axes) - xr.testing.assert_allclose(res, exp) diff --git a/tests/prediction_pipeline/test_processing.py b/tests/prediction_pipeline/test_processing.py deleted file mode 100644 index b693bcfc..00000000 --- a/tests/prediction_pipeline/test_processing.py +++ /dev/null @@ -1,55 +0,0 @@ -import dataclasses - -import numpy as np -import pytest -import xarray as xr - -from bioimageio.core.prediction_pipeline._processing import KNOWN_PROCESSING -from bioimageio.core.prediction_pipeline._utils import FIXED - -try: - from typing import get_args -except ImportError: - from typing_extensions import get_args # type: ignore - - -def test_assert_dtype(): - from bioimageio.core.prediction_pipeline._processing import AssertDtype - - proc = AssertDtype("test_tensor", dtype="uint8") - tensor = xr.DataArray(np.zeros((1,), dtype="uint8"), dims=tuple("c")) - out = proc(tensor) - assert out is tensor - - tensor = tensor.astype("uint16") - with pytest.raises(AssertionError): - out = proc(tensor) - assert out is tensor - - -@pytest.mark.parametrize( - "proc", - list(KNOWN_PROCESSING["pre"].values()) + list(KNOWN_PROCESSING["post"].values()), -) -def test_no_req_measures_for_mode_fixed(proc): - # check if mode=fixed is valid for this proc - for f in dataclasses.fields(proc): - if f.name == "mode": - break - else: - raise AttributeError("Processing is missing mode attribute") - # mode is always annotated as literals (or literals of literals) - valid_modes = get_args(f.type) - for inner in get_args(f.type): - valid_modes += get_args(inner) - - if FIXED not in valid_modes: - return - - # we might be missing required kwargs. 
These have marshmallow.missing value as default - # and raise a TypeError is in __post_init__() - proc.__post_init__ = lambda self: None # ignore missing kwargs - - proc_instance = proc(tensor_name="tensor_name", mode=FIXED) - req_measures = proc_instance.get_required_measures() - assert not req_measures diff --git a/tests/resource_io/test_load_rdf.py b/tests/resource_io/test_load_rdf.py deleted file mode 100644 index ca86750a..00000000 --- a/tests/resource_io/test_load_rdf.py +++ /dev/null @@ -1,96 +0,0 @@ -import os.path -import pathlib -from pathlib import Path - -import pytest - -from bioimageio.core.resource_io.utils import resolve_source - - -def test_load_non_existing_rdf(): - from bioimageio.core import load_resource_description - - spec_path = Path("some/none/existing/path/to/spec.model.yaml") - - with pytest.raises(FileNotFoundError): - load_resource_description(spec_path) - - -def test_load_raw_model(any_model): - from bioimageio.core import load_raw_resource_description - - raw_model = load_raw_resource_description(any_model) - assert raw_model - - -def test_load_model(any_model): - from bioimageio.core import load_resource_description - - model = load_resource_description(any_model) - assert model - - -def test_load_model_with_abs_path_source(unet2d_nuclei_broad_model): - from bioimageio.core.resource_io import load_raw_resource_description, load_resource_description - - raw_rd = load_raw_resource_description(unet2d_nuclei_broad_model) - path_source = (raw_rd.root_path / "rdf.yaml").absolute() - assert path_source.is_absolute() - model = load_resource_description(path_source) - assert model - - -def test_load_model_with_rel_path_source(unet2d_nuclei_broad_model): - from bioimageio.core.resource_io import load_raw_resource_description, load_resource_description - - raw_rd = load_raw_resource_description(unet2d_nuclei_broad_model) - path_source = pathlib.Path(os.path.relpath(raw_rd.root_path / "rdf.yaml", os.curdir)) - assert not 
path_source.is_absolute() - model = load_resource_description(path_source) - assert model - - -def test_load_model_with_abs_str_source(unet2d_nuclei_broad_model): - from bioimageio.core.resource_io import load_raw_resource_description, load_resource_description - - raw_rd = load_raw_resource_description(unet2d_nuclei_broad_model) - path_source = (raw_rd.root_path / "rdf.yaml").absolute() - assert path_source.is_absolute() - model = load_resource_description(str(path_source)) - assert model - - -def test_load_model_with_rel_str_source(unet2d_nuclei_broad_model): - from bioimageio.core.resource_io import load_raw_resource_description, load_resource_description - - raw_rd = load_raw_resource_description(unet2d_nuclei_broad_model) - path_source = pathlib.Path(os.path.relpath(raw_rd.root_path / "rdf.yaml", os.curdir)) - assert not path_source.is_absolute() - model = load_resource_description(str(path_source)) - assert model - - -@pytest.mark.skipif(pytest.skip_torch, reason="remote model is a pytorch model") -def test_load_remote_rdf(): - from bioimageio.core import load_resource_description - from bioimageio.core.resource_io import nodes - - remote_rdf = "https://zenodo.org/api/files/63b44f05-a187-4fc9-81c8-c4568535531b/rdf.yaml" - model = load_resource_description(remote_rdf) - assert isinstance(model, nodes.Model) - - -@pytest.mark.skipif(True, reason="No suitable test model available yet") -def test_load_remote_rdf_with_folders(): - from bioimageio.core import load_resource_description, load_raw_resource_description - from bioimageio.core.resource_io import nodes - from bioimageio.spec.model import raw_nodes - - rdf_doi = "" - raw_model = load_raw_resource_description(rdf_doi, update_to_format="latest") - assert isinstance(raw_model, raw_nodes.Model) - model = load_resource_description(rdf_doi) - assert isinstance(model, nodes.Model) - - # test for field value with folder, e.g. 
- assert resolve_source(raw_model.documentation) == model.documentation diff --git a/tests/resource_io/test_utils.py b/tests/resource_io/test_utils.py deleted file mode 100644 index 30889a1d..00000000 --- a/tests/resource_io/test_utils.py +++ /dev/null @@ -1,57 +0,0 @@ -from pathlib import Path - -from bioimageio.core.resource_io import nodes, utils -from bioimageio.spec.shared import raw_nodes - - -def test_resolve_import_path(tmpdir): - tmpdir = Path(tmpdir) - manifest_path = tmpdir / "manifest.yaml" - manifest_path.touch() - source_file = Path("my_mod.py") - (tmpdir / str(source_file)).write_text("class Foo: pass", encoding="utf8") - node = raw_nodes.ImportableSourceFile(source_file=source_file, callable_name="Foo") - uri_transformed = utils.UriNodeTransformer(root_path=tmpdir).transform(node) - source_transformed = utils.SourceNodeTransformer().transform(uri_transformed) - assert isinstance(source_transformed, nodes.ImportedSource) - Foo = source_transformed.factory - assert Foo.__name__ == "Foo" - assert isinstance(Foo, type) - - -def test_resolve_directory_uri(tmpdir): - node = raw_nodes.URI(Path(tmpdir).as_uri()) - uri_transformed = utils.UriNodeTransformer(root_path=Path(tmpdir)).transform(node) - assert uri_transformed == Path(tmpdir) - - -def test_uri_available(): - pass # todo - - -def test_all_uris_available(): - from bioimageio.core.resource_io.utils import all_sources_available - - not_available = { - "uri": raw_nodes.URI(scheme="file", path="non_existing_file_in/non_existing_dir/ftw"), - "uri_exists": raw_nodes.URI(scheme="file", path="."), - } - assert not all_sources_available(not_available) - - -def test_uri_node_transformer_is_ok_with_abs_path(): - from bioimageio.core.resource_io.utils import UriNodeTransformer - - # note: the call of .absolute() is required to add the drive letter for windows paths, which are relative otherwise - tree = {"rel_path": Path("something/relative"), "abs_path": Path("/something/absolute").absolute()} - assert not 
tree["rel_path"].is_absolute() - assert tree["abs_path"].is_absolute() - - root = Path("/root").absolute() - print(root) - - tree = UriNodeTransformer(root_path=root).transform(tree) - assert tree["rel_path"].is_absolute() - assert tree["rel_path"] == Path("/root/something/relative").absolute() - assert tree["abs_path"].is_absolute() - assert tree["abs_path"] == Path("/something/absolute").absolute() diff --git a/tests/test_any_model_fixture.py b/tests/test_any_model_fixture.py new file mode 100644 index 00000000..a4cc1bce --- /dev/null +++ b/tests/test_any_model_fixture.py @@ -0,0 +1,6 @@ +from bioimageio.spec import load_description_and_validate_format_only + + +def test_model(any_model: str): + summary = load_description_and_validate_format_only(any_model) + assert summary.status == "passed", summary.format() diff --git a/tests/test_bioimageio_spec_version.py b/tests/test_bioimageio_spec_version.py index fcd95582..719796ef 100644 --- a/tests/test_bioimageio_spec_version.py +++ b/tests/test_bioimageio_spec_version.py @@ -1,19 +1,22 @@ import json import subprocess -import sys +from typing import Optional import pytest from packaging.version import Version -@pytest.mark.skipif(sys.version_info < (3, 8), reason="requires python 3.8") -@pytest.mark.skipif(pytest.mamba_cmd is None, reason="requires mamba") -def test_bioimageio_spec_version(): +def test_bioimageio_spec_version(mamba_cmd: Optional[str]): + if mamba_cmd is None: + pytest.skip("requires mamba") + from importlib.metadata import metadata # get latest released bioimageio.spec version mamba_repoquery = subprocess.run( - f"{pytest.mamba_cmd} repoquery search -c conda-forge --json bioimageio.spec".split(" "), + f"{mamba_cmd} repoquery search -c conda-forge --json bioimageio.spec".split( + " " + ), encoding="utf-8", capture_output=True, check=True, @@ -26,7 +29,6 @@ def test_bioimageio_spec_version(): # get currently pinned bioimageio.spec version meta = metadata("bioimageio.core") req = meta["Requires-Dist"] - 
print(req) assert req.startswith("bioimageio.spec ==") spec_ver = req[len("bioimageio.spec ==") :] assert spec_ver.count(".") == 3 diff --git a/tests/test_cli.py b/tests/test_cli.py index c0de99d4..b9a8246f 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -1,141 +1,137 @@ -import os import subprocess -from typing import Sequence +from typing import Any, List, Sequence -import numpy as np import pytest - -from bioimageio.core import load_resource_description - - -def run_subprocess(commands: Sequence[str], **kwargs) -> subprocess.CompletedProcess: - return subprocess.run(commands, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, encoding="utf-8", **kwargs) - - -def test_validate_model(unet2d_nuclei_broad_model): - ret = run_subprocess(["bioimageio", "validate", unet2d_nuclei_broad_model]) - assert ret.returncode == 0, ret.stdout - - -def test_cli_package(unet2d_nuclei_broad_model, tmp_path): - out_path = tmp_path / "model.zip" - ret = run_subprocess(["bioimageio", "package", unet2d_nuclei_broad_model, str(out_path)]) - assert ret.returncode == 0, ret.stdout - assert out_path.exists() - - -def test_cli_package_wo_cache(unet2d_nuclei_broad_model): - env = os.environ.copy() - env["BIOIMAGEIO_USE_CACHE"] = "false" - ret = run_subprocess(["bioimageio", "package", unet2d_nuclei_broad_model], env=env) - assert ret.returncode == 0, ret.stdout - - -def test_cli_test_model(unet2d_nuclei_broad_model): - ret = run_subprocess(["bioimageio", "test-model", unet2d_nuclei_broad_model]) - assert ret.returncode == 0, ret.stdout - - -def test_cli_test_model_fail(stardist_wrong_shape): - ret = run_subprocess(["bioimageio", "test-model", stardist_wrong_shape]) - assert ret.returncode == 1 - - -def test_cli_test_model_with_weight_format(unet2d_nuclei_broad_model): - ret = run_subprocess( - ["bioimageio", "test-model", unet2d_nuclei_broad_model, "--weight-format", "pytorch_state_dict"] +from pydantic import FilePath + + +def run_subprocess( + commands: Sequence[str], **kwargs: 
Any +) -> "subprocess.CompletedProcess[str]": + return subprocess.run( + commands, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + encoding="utf-8", + **kwargs, ) - assert ret.returncode == 0, ret.stdout -def test_cli_test_resource(unet2d_nuclei_broad_model): - ret = run_subprocess(["bioimageio", "test-resource", unet2d_nuclei_broad_model]) +@pytest.mark.parametrize( + "args", + [ + [ + "package", + "unet2d_nuclei_broad_model", + "--weight-format", + "pytorch_state_dict", + ], + ["package", "unet2d_nuclei_broad_model"], + [ + "test", + "unet2d_nuclei_broad_model", + "--weight-format", + "pytorch_state_dict", + ], + ["test", "unet2d_nuclei_broad_model"], + ], +) +def test_cli(args: List[str], unet2d_nuclei_broad_model: str): + resolved_args = [ + str(unet2d_nuclei_broad_model) if arg == "unet2d_nuclei_broad_model" else arg + for arg in args + ] + ret = run_subprocess(["bioimageio", *resolved_args]) assert ret.returncode == 0, ret.stdout -def test_cli_test_resource_with_weight_format(unet2d_nuclei_broad_model): - ret = run_subprocess( - ["bioimageio", "test-resource", unet2d_nuclei_broad_model, "--weight-format", "pytorch_state_dict"] - ) - assert ret.returncode == 0, ret.stdout +@pytest.mark.parametrize("args", [["test", "stardist_wrong_shape"]]) +def test_cli_fails(args: List[str], stardist_wrong_shape: FilePath): + resolved_args = [ + str(stardist_wrong_shape) if arg == "stardist_wrong_shape" else arg + for arg in args + ] + ret = run_subprocess(["bioimageio", *resolved_args]) + assert ret.returncode == 1, ret.stdout -def _test_cli_predict_image(model, tmp_path, extra_kwargs=None): - spec = load_resource_description(model) - in_path = spec.test_inputs[0] - out_path = tmp_path.with_suffix(".npy") - cmd = ["bioimageio", "predict-image", model, "--inputs", str(in_path), "--outputs", str(out_path)] - if extra_kwargs is not None: - cmd.extend(extra_kwargs) - ret = run_subprocess(cmd) - assert ret.returncode == 0, ret.stdout - assert out_path.exists() +# TODO: 
update CLI test +# def _test_cli_predict_image(model: Path, tmp_path: Path, extra_cmd_args: Optional[List[str]] = None): +# spec = load_description(model) +# in_path = spec.test_inputs[0] +# out_path = tmp_path.with_suffix(".npy") +# cmd = ["bioimageio", "predict-image", model, "--input", str(in_path), "--output", str(out_path)] +# if extra_cmd_args is not None: +# cmd.extend(extra_cmd_args) +# ret = run_subprocess(cmd) +# assert ret.returncode == 0, ret.stdout +# assert out_path.exists() -def test_cli_predict_image(unet2d_nuclei_broad_model, tmp_path): - _test_cli_predict_image(unet2d_nuclei_broad_model, tmp_path) +# def test_cli_predict_image(unet2d_nuclei_broad_model: Path, tmp_path: Path): +# _test_cli_predict_image(unet2d_nuclei_broad_model, tmp_path) -def test_cli_predict_image_with_weight_format(unet2d_nuclei_broad_model, tmp_path): - _test_cli_predict_image(unet2d_nuclei_broad_model, tmp_path, ["--weight-format", "pytorch_state_dict"]) +# def test_cli_predict_image_with_weight_format(unet2d_nuclei_broad_model: Path, tmp_path: Path): +# _test_cli_predict_image(unet2d_nuclei_broad_model, tmp_path, ["--weight-format", "pytorch_state_dict"]) -def _test_cli_predict_images(model, tmp_path, extra_kwargs=None): - n_images = 3 - shape = (1, 1, 128, 128) - expected_shape = (1, 1, 128, 128) - in_folder = tmp_path / "inputs" - in_folder.mkdir() - out_folder = tmp_path / "outputs" - out_folder.mkdir() +# def _test_cli_predict_images(model: Path, tmp_path: Path, extra_cmd_args: Optional[List[str]] = None): +# n_images = 3 +# shape = (1, 1, 128, 128) +# expected_shape = (1, 1, 128, 128) - expected_outputs = [] - for i in range(n_images): - path = in_folder / f"im-{i}.npy" - im = np.random.randint(0, 255, size=shape).astype("uint8") - np.save(path, im) - expected_outputs.append(out_folder / f"im-{i}.npy") +# in_folder = tmp_path / "inputs" +# in_folder.mkdir() +# out_folder = tmp_path / "outputs" +# out_folder.mkdir() - input_pattern = str(in_folder / "*.npy") - cmd = 
["bioimageio", "predict-images", model, input_pattern, str(out_folder)] - if extra_kwargs is not None: - cmd.extend(extra_kwargs) - ret = run_subprocess(cmd) - assert ret.returncode == 0, ret.stdout +# expected_outputs: List[Path] = [] +# for i in range(n_images): +# path = in_folder / f"im-{i}.npy" +# im = np.random.randint(0, 255, size=shape).astype("uint8") +# np.save(path, im) +# expected_outputs.append(out_folder / f"im-{i}.npy") - for out_path in expected_outputs: - assert out_path.exists() - assert np.load(out_path).shape == expected_shape +# input_pattern = str(in_folder / "*.npy") +# cmd = ["bioimageio", "predict-images", str(model), input_pattern, str(out_folder)] +# if extra_cmd_args is not None: +# cmd.extend(extra_cmd_args) +# ret = run_subprocess(cmd) +# assert ret.returncode == 0, ret.stdout +# for out_path in expected_outputs: +# assert out_path.exists() +# assert np.load(out_path).shape == expected_shape -def test_cli_predict_images(unet2d_nuclei_broad_model, tmp_path): - _test_cli_predict_images(unet2d_nuclei_broad_model, tmp_path) +# def test_cli_predict_images(unet2d_nuclei_broad_model: Path, tmp_path: Path): +# _test_cli_predict_images(unet2d_nuclei_broad_model, tmp_path) -def test_cli_predict_images_with_weight_format(unet2d_nuclei_broad_model, tmp_path): - _test_cli_predict_images(unet2d_nuclei_broad_model, tmp_path, ["--weight-format", "pytorch_state_dict"]) +# def test_cli_predict_images_with_weight_format(unet2d_nuclei_broad_model: Path, tmp_path: Path): +# _test_cli_predict_images(unet2d_nuclei_broad_model, tmp_path, ["--weight-format", "pytorch_state_dict"]) -def test_torch_to_torchscript(unet2d_nuclei_broad_model, tmp_path): - out_path = tmp_path.with_suffix(".pt") - ret = run_subprocess( - ["bioimageio", "convert-torch-weights-to-torchscript", str(unet2d_nuclei_broad_model), str(out_path)] - ) - assert ret.returncode == 0, ret.stdout - assert out_path.exists() +# def test_torch_to_torchscript(unet2d_nuclei_broad_model: Path, tmp_path: 
Path): +# out_path = tmp_path.with_suffix(".pt") +# ret = run_subprocess( +# ["bioimageio", "convert-torch-weights-to-torchscript", str(unet2d_nuclei_broad_model), str(out_path)] +# ) +# assert ret.returncode == 0, ret.stdout +# assert out_path.exists() -@pytest.mark.skipif(pytest.skip_onnx, reason="requires torch and onnx") -def test_torch_to_onnx(unet2d_nuclei_broad_model, tmp_path): - out_path = tmp_path.with_suffix(".onnx") - ret = run_subprocess(["bioimageio", "convert-torch-weights-to-onnx", str(unet2d_nuclei_broad_model), str(out_path)]) - assert ret.returncode == 0, ret.stdout - assert out_path.exists() +# def test_torch_to_onnx(convert_to_onnx: Path, tmp_path: Path): +# out_path = tmp_path.with_suffix(".onnx") +# ret = run_subprocess(["bioimageio", "convert-torch-weights-to-onnx", str(convert_to_onnx), str(out_path)]) +# assert ret.returncode == 0, ret.stdout +# assert out_path.exists() -def test_keras_to_tf(unet2d_keras, tmp_path): - out_path = tmp_path / "weights.zip" - ret = run_subprocess(["bioimageio", "convert-keras-weights-to-tensorflow", str(unet2d_keras), str(out_path)]) - assert ret.returncode == 0, ret.stdout - assert out_path.exists() + +# def test_keras_to_tf(unet2d_keras: Path, tmp_path: Path): +# out_path = tmp_path / "weights.zip" +# ret = run_subprocess(["bioimageio", "convert-keras-weights-to-tensorflow", str(unet2d_keras), str(out_path)]) +# assert ret.returncode == 0, ret.stdout +# assert out_path.exists() diff --git a/tests/test_digest_spec.py b/tests/test_digest_spec.py new file mode 100644 index 00000000..08022ab2 --- /dev/null +++ b/tests/test_digest_spec.py @@ -0,0 +1,42 @@ +import pytest + +from bioimageio.spec import load_description +from bioimageio.spec.model import v0_5 + + +# TODO: don't just test with unet2d_nuclei_broad_model +@pytest.mark.skip("get_io_sample_block_metas needs improvements") +def test_get_block_transform(unet2d_nuclei_broad_model: str): + from bioimageio.core.axis import AxisId + from bioimageio.core.common 
import MemberId + from bioimageio.core.digest_spec import ( + get_block_transform, + get_io_sample_block_metas, + ) + + model = load_description(unet2d_nuclei_broad_model) + assert isinstance(model, v0_5.ModelDescr) + block_transform = get_block_transform(model) + + ns = { + (ipt.id, a.id): 1 + for ipt in model.inputs + for a in ipt.axes + if isinstance(a.size, v0_5.ParameterizedSize) + } + + _, blocks = get_io_sample_block_metas( + model, + input_sample_shape={ + MemberId("raw"): { + AxisId("batch"): 3, + AxisId("channel"): 1, + AxisId("x"): 4000, + AxisId("y"): 3000, + } + }, + ns=ns, + ) + for ipt_block, out_block in blocks: + trf_block = ipt_block.get_transformed(block_transform) + assert out_block == trf_block diff --git a/tests/test_export_package.py b/tests/test_export_package.py deleted file mode 100644 index a04aa994..00000000 --- a/tests/test_export_package.py +++ /dev/null @@ -1,60 +0,0 @@ -import shutil -from pathlib import Path -from tempfile import TemporaryDirectory -from zipfile import ZipFile - -from marshmallow import missing - -from bioimageio.spec.model import raw_nodes - - -def test_export_package(any_onnx_model): - from bioimageio.core import export_resource_package, load_raw_resource_description - - package_path = export_resource_package(any_onnx_model, weights_priority_order=["onnx"]) - assert isinstance(package_path, Path), package_path - assert package_path.exists(), package_path - - raw_model = load_raw_resource_description(package_path) - assert isinstance(raw_model, raw_nodes.Model) - - -def test_package_with_folder(unet2d_nuclei_broad_model): - from bioimageio.core import export_resource_package, load_raw_resource_description - - with TemporaryDirectory() as tmp_dir: - tmp_dir = Path(tmp_dir) - - # extract package (to not cache to BIOIMAGEIO_CACHE) - package_folder = tmp_dir / "package" - with ZipFile(unet2d_nuclei_broad_model) as zf: - zf.extractall(package_folder) - - # load package - model = 
load_raw_resource_description(package_folder / "rdf.yaml") - assert isinstance(model, raw_nodes.Model) - - # alter package to have its documentation in a nested folder - doc = model.documentation - assert doc is not missing - doc = doc.relative_to(model.root_path) - assert not doc.is_absolute() - new_doc = Path("nested") / "folder" / doc - (package_folder / new_doc).parent.mkdir(parents=True) - shutil.move(package_folder / doc, package_folder / new_doc) - model.documentation = new_doc - - # export altered package - altered_package = tmp_dir / "altered_package.zip" - altered_package = export_resource_package(model, output_path=altered_package, weights_priority_order=["onnx"]) - - # extract altered package (to not cache to BIOIMAGEIO_CACHE) - altered_package_folder = tmp_dir / "altered_package" - with ZipFile(altered_package) as zf: - zf.extractall(altered_package_folder) - - # load altered package - reloaded_model = load_raw_resource_description(altered_package_folder / "rdf.yaml") - assert isinstance(reloaded_model, raw_nodes.Model) - assert reloaded_model.documentation.as_posix().endswith(new_doc.as_posix()) - assert reloaded_model.documentation.exists() diff --git a/tests/test_image_helper.py b/tests/test_image_helper.py deleted file mode 100644 index 9c495de1..00000000 --- a/tests/test_image_helper.py +++ /dev/null @@ -1,29 +0,0 @@ -import numpy as np - - -def test_transform_input_image(): - from bioimageio.core.image_helper import transform_input_image - - ax_list = ["yx", "xy", "cyx", "yxc", "bczyx", "xyz", "xyzc", "bzyxc"] - im = np.random.rand(256, 256) - for axes in ax_list: - inp = transform_input_image(im, axes) - assert inp.ndim == len(axes) - - ax_list = ["zyx", "cyx", "yxc", "bczyx", "xyz", "xyzc", "bzyxc"] - vol = np.random.rand(64, 64, 64) - for axes in ax_list: - inp = transform_input_image(vol, axes) - assert inp.ndim == len(axes) - - -def test_transform_output_tensor(): - from bioimageio.core.image_helper import transform_output_tensor - - tensor 
= np.random.rand(1, 3, 64, 64, 64) - tensor_axes = "bczyx" - - out_ax_list = ["bczyx", "cyx", "xyc", "byxc", "zyx", "xyz"] - for out_axes in out_ax_list: - out = transform_output_tensor(tensor, tensor_axes, out_axes) - assert out.ndim == len(out_axes) diff --git a/tests/test_package.py b/tests/test_package.py new file mode 100644 index 00000000..bb087375 --- /dev/null +++ b/tests/test_package.py @@ -0,0 +1,7 @@ +def test_save_bioimageio_package(unet2d_nuclei_broad_model: str): + from bioimageio.spec._package import save_bioimageio_package + + _ = save_bioimageio_package( + unet2d_nuclei_broad_model, + weights_priority_order=("pytorch_state_dict",), + ) diff --git a/tests/test_prediction.py b/tests/test_prediction.py index 7992f9f5..de8b8062 100644 --- a/tests/test_prediction.py +++ b/tests/test_prediction.py @@ -1,208 +1,228 @@ -from pathlib import Path - -import imageio -import numpy as np -from numpy.testing import assert_array_almost_equal - -from bioimageio.core import load_resource_description -from bioimageio.core.resource_io.nodes import Model - - -def test_predict_image(any_model, tmpdir): - from bioimageio.core.prediction import predict_image - - spec = load_resource_description(any_model) - assert isinstance(spec, Model) - inputs = spec.test_inputs - - outputs = [Path(tmpdir) / f"out{i}.npy" for i in range(len(spec.test_outputs))] - predict_image(any_model, inputs, outputs) - for out_path in outputs: - assert out_path.exists() - - result = [np.load(str(p)) for p in outputs] - expected = [np.load(str(p)) for p in spec.test_outputs] - for res, exp in zip(result, expected): - assert_array_almost_equal(res, exp, decimal=4) - - -def test_predict_image_with_weight_format(unet2d_fixed_shape_or_not, tmpdir): - from bioimageio.core.prediction import predict_image - - spec = load_resource_description(unet2d_fixed_shape_or_not) - assert isinstance(spec, Model) - inputs = spec.test_inputs - - outputs = [Path(tmpdir) / f"out{i}.npy" for i in 
range(len(spec.test_outputs))] - predict_image(unet2d_fixed_shape_or_not, inputs, outputs, weight_format="pytorch_state_dict") - for out_path in outputs: - assert out_path.exists() - - result = [np.load(str(p)) for p in outputs] - expected = [np.load(str(p)) for p in spec.test_outputs] - for res, exp in zip(result, expected): - assert_array_almost_equal(res, exp, decimal=4) +# TODO: update +# from pathlib import Path + +# import imageio +# import numpy as np +# from numpy.testing import assert_array_almost_equal + +# from bioimageio.spec import load_description +# from bioimageio.spec.model.v0_4 import InputTensorDescr as InputTensorDescr_v0_4 +# from bioimageio.spec.model.v0_4 import ModelDescr as ModelDescr_v0_4 +# from bioimageio.spec.model.v0_5 import ModelDescr + + +# def test_predict_image(any_model: Path, tmpdir: Path): +# from bioimageio.core.prediction import predict_image + +# spec = load_description(any_model) +# assert isinstance(spec, ModelDescr) +# inputs = spec.test_inputs + +# outputs = [Path(tmpdir) / f"out{i}.npy" for i in range(len(spec.test_outputs))] +# predict_image(any_model, inputs, outputs) +# for out_path in outputs: +# assert out_path.exists() + +# result = [np.load(str(p)) for p in outputs] +# expected = [np.load(str(p)) for p in spec.test_outputs] +# for res, exp in zip(result, expected): +# assert_array_almost_equal(res, exp, decimal=4) + + +# def test_predict_image_with_weight_format( +# unet2d_fixed_shape_or_not: Path, tmpdir: Path +# ): +# from bioimageio.core.prediction import predict_image + +# spec = load_description(unet2d_fixed_shape_or_not) +# assert isinstance(spec, Model) +# inputs = spec.test_inputs + +# outputs = [Path(tmpdir) / f"out{i}.npy" for i in range(len(spec.test_outputs))] +# predict_image( +# unet2d_fixed_shape_or_not, inputs, outputs, weight_format="pytorch_state_dict" +# ) +# for out_path in outputs: +# assert out_path.exists() + +# result = [np.load(str(p)) for p in outputs] +# expected = [np.load(str(p)) for 
p in spec.test_outputs] +# for res, exp in zip(result, expected): +# assert_array_almost_equal(res, exp, decimal=4) + + +# def _test_predict_with_padding(any_model: Path, tmp_path: Path): +# from bioimageio.core.digest_spec import get_test_inputs +# from bioimageio.core.prediction import predict_image + +# model = load_description(any_model) +# assert isinstance(model, (ModelDescr_v0_4, ModelDescr)) + +# input_spec, output_spec = model.inputs[0], model.outputs[0] +# channel_axis = ( +# "c" +# if isinstance(input_spec, InputTensorDescr_v0_4) +# else [a.id for a in input_spec.axes][0] +# ) +# channel_first = channel_axis == 1 + +# # TODO: check more tensors +# image = get_test_inputs(model)[0] + +# if isinstance(output_spec.shape, list): +# n_channels = output_spec.shape[channel_axis] +# else: +# scale = output_spec.shape.scale[channel_axis] +# offset = output_spec.shape.offset[channel_axis] +# in_channels = 1 +# n_channels = int(2 * offset + scale * in_channels) + +# # write the padded image +# image = image[3:-2, 1:-12] +# in_path = tmp_path / "in.tif" +# out_path = tmp_path / "out.tif" +# imageio.imwrite(in_path, image) + +# if hasattr(output_spec.shape, "scale"): +# scale = dict(zip(output_spec.axes, output_spec.shape.scale)) +# offset = dict(zip(output_spec.axes, output_spec.shape.offset)) +# spatial_axes = [ax for ax in output_spec.axes if ax in "xyz"] +# network_resizes = any( +# sc != 1 for ax, sc in scale.items() if ax in spatial_axes +# ) or any(off != 0 for ax, off in offset.items() if ax in spatial_axes) +# else: +# network_resizes = False + +# if network_resizes: +# exp_shape = tuple( +# int(sh * scale[ax] + 2 * offset[ax]) +# for sh, ax in zip(image.shape, spatial_axes) +# ) +# else: +# exp_shape = image.shape + +# def check_result(): +# if n_channels == 1: +# assert out_path.exists() +# res = imageio.imread(out_path) +# assert res.shape == exp_shape +# else: +# path = str(out_path) +# for c in range(n_channels): +# channel_out_path = 
Path(path.replace(".tif", f"-c{c}.tif")) +# assert channel_out_path.exists() +# res = imageio.imread(channel_out_path) +# assert res.shape == exp_shape + +# # test with dynamic padding +# predict_image( +# any_model, in_path, out_path, padding={"x": 16, "y": 16, "mode": "dynamic"} +# ) +# check_result() + +# # test with fixed padding +# predict_image( +# any_model, +# in_path, +# out_path, +# padding={"x": original_shape[0], "y": original_shape[1], "mode": "fixed"}, +# ) +# check_result() + +# # test with automated padding +# predict_image(any_model, in_path, out_path, padding=True) +# check_result() + + +# # prediction with padding with the parameters above may not be suited for any model +# # so we only run it for the pytorch unet2d here +# def test_predict_image_with_padding(unet2d_fixed_shape_or_not, tmp_path): +# _test_predict_with_padding(unet2d_fixed_shape_or_not, tmp_path) + + +# # and with different output shape +# def test_predict_image_with_padding_diff_output_shape( +# unet2d_diff_output_shape, tmp_path +# ): +# _test_predict_with_padding(unet2d_diff_output_shape, tmp_path) + + +# def test_predict_image_with_padding_channel_last(stardist, tmp_path): +# _test_predict_with_padding(stardist, tmp_path) + + +# def _test_predict_image_with_tiling(model: Path, tmp_path: Path, exp_mean_deviation): +# from bioimageio.core.prediction import predict_image + +# spec = load_description(model) +# assert isinstance(spec, Model) +# inputs = spec.test_inputs +# assert len(inputs) == 1 +# exp = np.load(str(spec.test_outputs[0])) + +# out_path = tmp_path.with_suffix(".npy") + +# def check_result(): +# assert out_path.exists() +# res = np.load(out_path) +# assert res.shape == exp.shape +# # check that the mean deviation is smaller than the expected value +# # note that we can't use array_almost_equal here, because the numerical differences +# # between tiled and normal prediction are too large +# mean_deviation = np.abs(res - exp).mean() +# assert mean_deviation <= 
exp_mean_deviation + +# # with tiling config +# tiling = {"halo": {"x": 32, "y": 32}, "tile": {"x": 256, "y": 256}} +# predict_image(model, inputs, [out_path], tiling=tiling) +# check_result() + +# # with tiling determined from spec +# predict_image(model, inputs, [out_path], tiling=True) +# check_result() + + +# # prediction with tiling with the parameters above may not be suited for any model +# # so we only run it for the pytorch unet2d here +# def test_predict_image_with_tiling_1(unet2d_nuclei_broad_model: Path, tmp_path: Path): +# _test_predict_image_with_tiling(unet2d_nuclei_broad_model, tmp_path, 0.012) + + +# def test_predict_image_with_tiling_2(unet2d_diff_output_shape: Path, tmp_path: Path): +# _test_predict_image_with_tiling(unet2d_diff_output_shape, tmp_path, 0.06) + + +# def test_predict_image_with_tiling_3(shape_change_model: Path, tmp_path: Path): +# _test_predict_image_with_tiling(shape_change_model, tmp_path, 0.012) -def _test_predict_with_padding(model, tmp_path): - from bioimageio.core.prediction import predict_image - - spec = load_resource_description(model) - assert isinstance(spec, Model) - - input_spec, output_spec = spec.inputs[0], spec.outputs[0] - channel_axis = input_spec.axes.index("c") - channel_first = channel_axis == 1 - - image = np.load(str(spec.test_inputs[0])) - assert image.shape[channel_axis] == 1 - if channel_first: - image = image[0, 0] - else: - image = image[0, ..., 0] - original_shape = image.shape - assert image.ndim == 2 - - if isinstance(output_spec.shape, list): - n_channels = output_spec.shape[channel_axis] - else: - scale = output_spec.shape.scale[channel_axis] - offset = output_spec.shape.offset[channel_axis] - in_channels = 1 - n_channels = int(2 * offset + scale * in_channels) - - # write the padded image - image = image[3:-2, 1:-12] - in_path = tmp_path / "in.tif" - out_path = tmp_path / "out.tif" - imageio.imwrite(in_path, image) +# def test_predict_image_with_tiling_channel_last(stardist: Path, tmp_path: Path): 
+# _test_predict_image_with_tiling(stardist, tmp_path, 0.13) - if hasattr(output_spec.shape, "scale"): - scale = dict(zip(output_spec.axes, output_spec.shape.scale)) - offset = dict(zip(output_spec.axes, output_spec.shape.offset)) - spatial_axes = [ax for ax in output_spec.axes if ax in "xyz"] - network_resizes = any(sc != 1 for ax, sc in scale.items() if ax in spatial_axes) or any( - off != 0 for ax, off in offset.items() if ax in spatial_axes - ) - else: - network_resizes = False - if network_resizes: - exp_shape = tuple(int(sh * scale[ax] + 2 * offset[ax]) for sh, ax in zip(image.shape, spatial_axes)) - else: - exp_shape = image.shape +# def test_predict_image_with_tiling_fixed_output_shape( +# unet2d_fixed_shape: Path, tmp_path: Path +# ): +# _test_predict_image_with_tiling(unet2d_fixed_shape, tmp_path, 0.025) - def check_result(): - if n_channels == 1: - assert out_path.exists() - res = imageio.imread(out_path) - assert res.shape == exp_shape - else: - path = str(out_path) - for c in range(n_channels): - channel_out_path = Path(path.replace(".tif", f"-c{c}.tif")) - assert channel_out_path.exists() - res = imageio.imread(channel_out_path) - assert res.shape == exp_shape - # test with dynamic padding - predict_image(model, in_path, out_path, padding={"x": 16, "y": 16, "mode": "dynamic"}) - check_result() +# def test_predict_images(unet2d_nuclei_broad_model: Path, tmp_path: Path): +# from bioimageio.core.prediction import predict_images - # test with fixed padding - predict_image(model, in_path, out_path, padding={"x": original_shape[0], "y": original_shape[1], "mode": "fixed"}) - check_result() +# n_images = 5 +# shape = (256, 256) - # test with automated padding - predict_image(model, in_path, out_path, padding=True) - check_result() +# in_paths = [] +# out_paths = [] +# for i in range(n_images): +# in_path = tmp_path / f"in{i}.tif" +# im = np.random.randint(0, 255, size=shape).astype("uint8") +# imageio.imwrite(in_path, im) +# in_paths.append(in_path) +# 
out_paths.append(tmp_path / f"out{i}.tif") +# predict_images(unet2d_nuclei_broad_model, in_paths, out_paths) - -# prediction with padding with the parameters above may not be suited for any model -# so we only run it for the pytorch unet2d here -def test_predict_image_with_padding(unet2d_fixed_shape_or_not, tmp_path): - _test_predict_with_padding(unet2d_fixed_shape_or_not, tmp_path) - - -# and with different output shape -def test_predict_image_with_padding_diff_output_shape(unet2d_diff_output_shape, tmp_path): - _test_predict_with_padding(unet2d_diff_output_shape, tmp_path) - - -def test_predict_image_with_padding_channel_last(stardist, tmp_path): - _test_predict_with_padding(stardist, tmp_path) - - -def _test_predict_image_with_tiling(model, tmp_path: Path, exp_mean_deviation): - from bioimageio.core.prediction import predict_image - - spec = load_resource_description(model) - assert isinstance(spec, Model) - inputs = spec.test_inputs - assert len(inputs) == 1 - exp = np.load(str(spec.test_outputs[0])) - - out_path = tmp_path.with_suffix(".npy") - - def check_result(): - assert out_path.exists() - res = np.load(out_path) - assert res.shape == exp.shape - # check that the mean deviation is smaller than the expected value - # note that we can't use array_almost_equal here, because the numerical differences - # between tiled and normal prediction are too large - mean_deviation = np.abs(res - exp).mean() - assert mean_deviation <= exp_mean_deviation - - # with tiling config - tiling = {"halo": {"x": 32, "y": 32}, "tile": {"x": 256, "y": 256}} - predict_image(model, inputs, [out_path], tiling=tiling) - check_result() - - # with tiling determined from spec - predict_image(model, inputs, [out_path], tiling=True) - check_result() - - -# prediction with tiling with the parameters above may not be suited for any model -# so we only run it for the pytorch unet2d here -def test_predict_image_with_tiling_1(unet2d_nuclei_broad_model, tmp_path: Path): - 
_test_predict_image_with_tiling(unet2d_nuclei_broad_model, tmp_path, 0.012) - - -def test_predict_image_with_tiling_2(unet2d_diff_output_shape, tmp_path: Path): - _test_predict_image_with_tiling(unet2d_diff_output_shape, tmp_path, 0.06) - - -def test_predict_image_with_tiling_3(shape_change_model, tmp_path: Path): - _test_predict_image_with_tiling(shape_change_model, tmp_path, 0.012) - - -def test_predict_image_with_tiling_channel_last(stardist, tmp_path: Path): - _test_predict_image_with_tiling(stardist, tmp_path, 0.13) - - -def test_predict_image_with_tiling_fixed_output_shape(unet2d_fixed_shape, tmp_path: Path): - _test_predict_image_with_tiling(unet2d_fixed_shape, tmp_path, 0.025) - - -def test_predict_images(unet2d_nuclei_broad_model, tmp_path: Path): - from bioimageio.core.prediction import predict_images - - n_images = 5 - shape = (256, 256) - - in_paths = [] - out_paths = [] - for i in range(n_images): - in_path = tmp_path / f"in{i}.tif" - im = np.random.randint(0, 255, size=shape).astype("uint8") - imageio.imwrite(in_path, im) - in_paths.append(in_path) - out_paths.append(tmp_path / f"out{i}.tif") - predict_images(unet2d_nuclei_broad_model, in_paths, out_paths) - - for outp in out_paths: - assert outp.exists() - out = imageio.imread(outp) - assert out.shape == shape +# for outp in out_paths: +# assert outp.exists() +# out = imageio.imread(outp) +# assert out.shape == shape diff --git a/tests/test_prediction_pipeline.py b/tests/test_prediction_pipeline.py new file mode 100644 index 00000000..a0a85f5d --- /dev/null +++ b/tests/test_prediction_pipeline.py @@ -0,0 +1,51 @@ +from pathlib import Path + +from numpy.testing import assert_array_almost_equal + +from bioimageio.spec import load_description +from bioimageio.spec.model.v0_4 import ModelDescr as ModelDescr04 +from bioimageio.spec.model.v0_5 import ModelDescr, WeightsFormat + + +def _test_prediction_pipeline(model_package: Path, weights_format: WeightsFormat): + from bioimageio.core._prediction_pipeline 
import create_prediction_pipeline + from bioimageio.core.digest_spec import get_test_inputs, get_test_outputs + + bio_model = load_description(model_package) + assert isinstance( + bio_model, (ModelDescr, ModelDescr04) + ), bio_model.validation_summary.format() + pp = create_prediction_pipeline( + bioimageio_model=bio_model, weight_format=weights_format + ) + + inputs = get_test_inputs(bio_model) + outputs = pp.predict_sample_without_blocking(inputs) + + expected_outputs = get_test_outputs(bio_model) + assert len(outputs.shape) == len(expected_outputs.shape) + for m in expected_outputs.members: + out = outputs.members[m].data + assert out is not None + exp = expected_outputs.members[m].data + assert_array_almost_equal(out, exp, decimal=4) + + +def test_prediction_pipeline_torch(any_torch_model: Path): + _test_prediction_pipeline(any_torch_model, "pytorch_state_dict") + + +def test_prediction_pipeline_torchscript(any_torchscript_model: Path): + _test_prediction_pipeline(any_torchscript_model, "torchscript") + + +def test_prediction_pipeline_onnx(any_onnx_model: Path): + _test_prediction_pipeline(any_onnx_model, "onnx") + + +def test_prediction_pipeline_tensorflow(any_tensorflow_model: Path): + _test_prediction_pipeline(any_tensorflow_model, "tensorflow_saved_model_bundle") + + +def test_prediction_pipeline_keras(any_keras_model: Path): + _test_prediction_pipeline(any_keras_model, "keras_hdf5") diff --git a/tests/test_prediction_pipeline_device_management.py b/tests/test_prediction_pipeline_device_management.py new file mode 100644 index 00000000..447eb698 --- /dev/null +++ b/tests/test_prediction_pipeline_device_management.py @@ -0,0 +1,87 @@ +from pathlib import Path + +from numpy.testing import assert_array_almost_equal + +from bioimageio.core.utils.testing import skip_on +from bioimageio.spec.model.v0_4 import ModelDescr as ModelDescr04 +from bioimageio.spec.model.v0_5 import ModelDescr, WeightsFormat + + +class TooFewDevicesException(Exception): + pass + + +def 
_test_device_management(model_package: Path, weight_format: WeightsFormat): + import torch + + from bioimageio.core import load_description + from bioimageio.core._prediction_pipeline import create_prediction_pipeline + from bioimageio.core.digest_spec import get_test_inputs, get_test_outputs + + if not hasattr(torch, "cuda") or torch.cuda.device_count() == 0: + raise TooFewDevicesException("Need at least one cuda device for this test") + + bio_model = load_description(model_package) + assert isinstance(bio_model, (ModelDescr, ModelDescr04)) + pred_pipe = create_prediction_pipeline( + bioimageio_model=bio_model, weight_format=weight_format, devices=["cuda:0"] + ) + + inputs = get_test_inputs(bio_model) + with pred_pipe as pp: + outputs = pp.predict_sample_without_blocking(inputs) + + expected_outputs = get_test_outputs(bio_model) + + assert len(outputs.shape) == len(expected_outputs.shape) + for m in expected_outputs.members: + out = outputs.members[m].data + assert out is not None + exp = expected_outputs.members[m].data + assert_array_almost_equal(out, exp, decimal=4) + + # repeat inference with context manager to test load/predict/unload/load/predict + with pred_pipe as pp: + outputs = pp.predict_sample_without_blocking(inputs) + + assert len(outputs.shape) == len(expected_outputs.shape) + for m in expected_outputs.members: + out = outputs.members[m].data + assert out is not None + exp = expected_outputs.members[m].data + assert_array_almost_equal(out, exp, decimal=4) + + +@skip_on( + TooFewDevicesException, reason="Too few devices" +) # pyright: ignore[reportArgumentType] +def test_device_management_torch(any_torch_model: Path): + _test_device_management(any_torch_model, "pytorch_state_dict") + + +@skip_on( + TooFewDevicesException, reason="Too few devices" +) # pyright: ignore[reportArgumentType] +def test_device_management_torchscript(any_torchscript_model: Path): + _test_device_management(any_torchscript_model, "torchscript") + + +@skip_on( + 
TooFewDevicesException, reason="Too few devices" +) # pyright: ignore[reportArgumentType] +def test_device_management_onnx(any_onnx_model: Path): + _test_device_management(any_onnx_model, "onnx") + + +@skip_on( + TooFewDevicesException, reason="Too few devices" +) # pyright: ignore[reportArgumentType] +def test_device_management_tensorflow(any_tensorflow_model: Path): + _test_device_management(any_tensorflow_model, "tensorflow_saved_model_bundle") + + +@skip_on( + TooFewDevicesException, reason="Too few devices" +) # pyright: ignore[reportArgumentType] +def test_device_management_keras(any_keras_model: Path): + _test_device_management(any_keras_model, "keras_hdf5") diff --git a/tests/test_proc_ops.py b/tests/test_proc_ops.py new file mode 100644 index 00000000..7ef1a8fe --- /dev/null +++ b/tests/test_proc_ops.py @@ -0,0 +1,366 @@ +from typing import Iterable, Optional, Tuple, Type, TypeVar + +import numpy as np +import pytest +import xarray as xr +from typing_extensions import TypeGuard + +from bioimageio.core.axis import AxisId +from bioimageio.core.common import MemberId +from bioimageio.core.sample import Sample +from bioimageio.core.stat_calculators import compute_measures +from bioimageio.core.stat_measures import SampleMean, SampleQuantile, SampleStd +from bioimageio.core.tensor import Tensor + + +@pytest.fixture(scope="module") +def tid(): + return MemberId("data123") + + +def test_scale_linear(tid: MemberId): + from bioimageio.core.proc_ops import ScaleLinear + + offset = xr.DataArray([1, 2, 42], dims=("c")) + gain = xr.DataArray([1, 2, 3], dims=("c")) + data = xr.DataArray(np.arange(6).reshape((1, 2, 3)), dims=("x", "y", "c")) + sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None) + + op = ScaleLinear(input=tid, output=tid, offset=offset, gain=gain) + op(sample) + + expected = xr.DataArray(np.array([[[1, 4, 48], [4, 10, 57]]]), dims=("x", "y", "c")) + xr.testing.assert_allclose(expected, sample.members[tid].data) + + +def 
test_scale_linear_no_channel(tid: MemberId): + from bioimageio.core.proc_ops import ScaleLinear + + op = ScaleLinear(tid, tid, offset=1, gain=2) + data = xr.DataArray(np.arange(6).reshape(2, 3), dims=("x", "y")) + sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None) + op(sample) + + expected = xr.DataArray(np.array([[1, 3, 5], [7, 9, 11]]), dims=("x", "y")) + xr.testing.assert_allclose(expected, sample.members[tid].data) + + +T = TypeVar("T") + + +def is_iterable(val: Iterable[T], inner: Type[T]) -> TypeGuard[Iterable[T]]: + """Determines whether all objects in the list are strings""" + return all(isinstance(x, inner) for x in val) + + +def test_zero_mean_unit_variance(tid: MemberId): + from bioimageio.core.proc_ops import ZeroMeanUnitVariance + + data = xr.DataArray(np.arange(9).reshape(3, 3), dims=("x", "y")) + sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None) + m = SampleMean(tid) + std = SampleStd(tid) + op = ZeroMeanUnitVariance(tid, tid, m, std) + req = op.required_measures + sample.stat = compute_measures(req, [sample]) + op(sample) + + expected = xr.DataArray( + np.array( + [ + [-1.54919274, -1.16189455, -0.77459637], + [-0.38729818, 0.0, 0.38729818], + [0.77459637, 1.16189455, 1.54919274], + ] + ), + dims=("x", "y"), + ) + xr.testing.assert_allclose(expected, sample.members[tid].data) + + +def test_zero_mean_unit_variance_fixed(tid: MemberId): + from bioimageio.core.proc_ops import FixedZeroMeanUnitVariance + + op = FixedZeroMeanUnitVariance( + tid, + tid, + mean=xr.DataArray([3, 4, 5], dims=("c")), + std=xr.DataArray([2.44948974, 2.44948974, 2.44948974], dims=("c")), + ) + data = xr.DataArray(np.arange(9).reshape((1, 3, 3)), dims=("b", "c", "x")) + expected = xr.DataArray( + np.array( + [ + [ + [-1.22474487, -0.81649658, -0.40824829], + [-0.40824829, 0.0, 0.40824829], + [0.40824829, 0.81649658, 1.22474487], + ] + ] + ), + dims=("b", "c", "x"), + ) + sample = Sample(members={tid: Tensor.from_xarray(data)}, 
stat={}, id=None) + op(sample) + xr.testing.assert_allclose(expected, sample.members[tid].data) + + +def test_zero_mean_unit_across_axes(tid: MemberId): + from bioimageio.core.proc_ops import ZeroMeanUnitVariance + + data = xr.DataArray(np.arange(18).reshape((2, 3, 3)), dims=("c", "x", "y")) + + op = ZeroMeanUnitVariance( + tid, + tid, + SampleMean(tid, (AxisId("x"), AxisId("y"))), + SampleStd(tid, (AxisId("x"), AxisId("y"))), + ) + sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None) + sample.stat = compute_measures(op.required_measures, [sample]) + + expected = xr.concat( + [(data[i : i + 1] - data[i].mean()) / data[i].std() for i in range(2)], dim="c" + ) + op(sample) + xr.testing.assert_allclose(expected, sample.members[tid].data) + + +def test_zero_mean_unit_variance_fixed2(tid: MemberId): + from bioimageio.core.proc_ops import FixedZeroMeanUnitVariance + + np_data = np.arange(9).reshape(3, 3) + mean = float(np_data.mean()) + std = float(np_data.mean()) + eps = 1.0e-7 + op = FixedZeroMeanUnitVariance(tid, tid, mean=mean, std=std, eps=eps) + + data = xr.DataArray(np_data, dims=("x", "y")) + sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None) + expected = xr.DataArray((np_data - mean) / (std + eps), dims=("x", "y")) + op(sample) + xr.testing.assert_allclose(expected, sample.members[tid].data) + + +def test_binarize(tid: MemberId): + from bioimageio.core.proc_ops import Binarize + + op = Binarize(tid, tid, threshold=14) + data = xr.DataArray(np.arange(30).reshape((2, 3, 5)), dims=("x", "y", "c")) + sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None) + expected = xr.zeros_like(data) + expected[{"x": slice(1, None)}] = 1 + op(sample) + xr.testing.assert_allclose(expected, sample.members[tid].data) + + +def test_binarize2(tid: MemberId): + from bioimageio.core.proc_ops import Binarize + + shape = (3, 32, 32) + axes = ("c", "y", "x") + np_data = np.random.rand(*shape) + data = xr.DataArray(np_data, 
dims=axes) + + threshold = 0.5 + exp = xr.DataArray(np_data > threshold, dims=axes) + + sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None) + binarize = Binarize(tid, tid, threshold=threshold) + binarize(sample) + xr.testing.assert_allclose(exp, sample.members[tid].data) + + +def test_clip(tid: MemberId): + from bioimageio.core.proc_ops import Clip + + op = Clip(tid, tid, min=3, max=5) + data = xr.DataArray(np.arange(9).reshape(3, 3), dims=("x", "y")) + sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None) + + expected = xr.DataArray( + np.array([[3, 3, 3], [3, 4, 5], [5, 5, 5]]), dims=("x", "y") + ) + op(sample) + xr.testing.assert_equal(expected, sample.members[tid].data) + + +def test_combination_of_op_steps_with_dims_specified(tid: MemberId): + from bioimageio.core.proc_ops import ZeroMeanUnitVariance + + data = xr.DataArray(np.arange(18).reshape((2, 3, 3)), dims=("c", "x", "y")) + sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None) + op = ZeroMeanUnitVariance( + tid, + tid, + SampleMean( + tid, + (AxisId("x"), AxisId("y")), + ), + SampleStd( + tid, + (AxisId("x"), AxisId("y")), + ), + ) + sample.stat = compute_measures(op.required_measures, [sample]) + + expected = xr.DataArray( + np.array( + [ + [ + [-1.54919274, -1.16189455, -0.77459637], + [-0.38729818, 0.0, 0.38729818], + [0.77459637, 1.16189455, 1.54919274], + ], + [ + [-1.54919274, -1.16189455, -0.77459637], + [-0.38729818, 0.0, 0.38729818], + [0.77459637, 1.16189455, 1.54919274], + ], + ] + ), + dims=("c", "x", "y"), + ) + + op(sample) + xr.testing.assert_allclose(expected, sample.members[tid].data) + + +@pytest.mark.parametrize( + "axes", + [ + None, + tuple(map(AxisId, "cy")), + tuple(map(AxisId, "cyx")), + tuple(map(AxisId, "x")), + ], +) +def test_scale_mean_variance(tid: MemberId, axes: Optional[Tuple[AxisId, ...]]): + from bioimageio.core.proc_ops import ScaleMeanVariance + + shape = (3, 32, 46) + ipt_axes = ("c", "y", "x") + 
np_data = np.random.rand(*shape) + ipt_data = xr.DataArray(np_data, dims=ipt_axes) + ref_data = xr.DataArray((np_data * 2) + 3, dims=ipt_axes) + + op = ScaleMeanVariance(tid, tid, reference_tensor=MemberId("ref_name"), axes=axes) + sample = Sample( + members={ + tid: Tensor.from_xarray(ipt_data), + MemberId("ref_name"): Tensor.from_xarray(ref_data), + }, + stat={}, + id=None, + ) + sample.stat = compute_measures(op.required_measures, [sample]) + op(sample) + xr.testing.assert_allclose(ref_data, sample.members[tid].data) + + +@pytest.mark.parametrize( + "axes_str", + [None, "cy", "y", "yx"], +) +def test_scale_mean_variance_per_channel(tid: MemberId, axes_str: Optional[str]): + from bioimageio.core.proc_ops import ScaleMeanVariance + + axes = None if axes_str is None else tuple(map(AxisId, axes_str)) + + shape = (3, 32, 46) + ipt_axes = ("c", "y", "x") + np_data = np.random.rand(*shape) + ipt_data = xr.DataArray(np_data, dims=ipt_axes) + + # set different mean, std per channel + np_ref_data = np.stack([d * i + i for i, d in enumerate(np_data, start=2)]) + ref_data = xr.DataArray(np_ref_data, dims=ipt_axes) + + op = ScaleMeanVariance(tid, tid, reference_tensor=MemberId("ref_name"), axes=axes) + sample = Sample( + members={ + tid: Tensor.from_xarray(ipt_data), + MemberId("ref_name"): Tensor.from_xarray(ref_data), + }, + stat={}, + id=None, + ) + sample.stat = compute_measures(op.required_measures, [sample]) + op(sample) + + if axes is not None and AxisId("c") not in axes: + # mean,std per channel should match exactly + xr.testing.assert_allclose(ref_data, sample.members[tid].data) + else: + # mean,std across channels should not match + with pytest.raises(AssertionError): + xr.testing.assert_allclose(ref_data, sample.members[tid].data) + + +def test_scale_range(tid: MemberId): + from bioimageio.core.proc_ops import ScaleRange + + op = ScaleRange(tid, tid) + np_data = np.arange(9).reshape(3, 3).astype("float32") + data = xr.DataArray(np_data, dims=("x", "y")) + sample = 
Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None) + sample.stat = compute_measures(op.required_measures, [sample]) + + eps = 1.0e-6 + mi, ma = np_data.min(), np_data.max() + exp_data = (np_data - mi) / (ma - mi + eps) + expected = xr.DataArray(exp_data, dims=("x", "y")) + + op(sample) + # NOTE xarray.testing.assert_allclose compares irrelavant properties here and fails although the result is correct + np.testing.assert_allclose(expected, sample.members[tid].data) + + +def test_scale_range_axes(tid: MemberId): + from bioimageio.core.proc_ops import ScaleRange + + eps = 1.0e-6 + + lower_quantile = SampleQuantile(tid, 0.1, axes=(AxisId("x"), AxisId("y"))) + upper_quantile = SampleQuantile(tid, 0.9, axes=(AxisId("x"), AxisId("y"))) + op = ScaleRange(tid, tid, lower_quantile, upper_quantile, eps=eps) + + np_data = np.arange(18).reshape((2, 3, 3)).astype("float32") + data = Tensor.from_xarray(xr.DataArray(np_data, dims=("c", "x", "y"))) + sample = Sample(members={tid: data}, stat={}, id=None) + + p_low_direct = lower_quantile.compute(sample) + p_up_direct = upper_quantile.compute(sample) + + p_low_expected = np.quantile(np_data, lower_quantile.q, axis=(1, 2), keepdims=True) + p_up_expected = np.quantile(np_data, upper_quantile.q, axis=(1, 2), keepdims=True) + + np.testing.assert_allclose(p_low_expected.squeeze(), p_low_direct) + np.testing.assert_allclose(p_up_expected.squeeze(), p_up_direct) + + sample.stat = compute_measures(op.required_measures, [sample]) + + np.testing.assert_allclose(p_low_expected.squeeze(), sample.stat[lower_quantile]) + np.testing.assert_allclose(p_up_expected.squeeze(), sample.stat[upper_quantile]) + + exp_data = (np_data - p_low_expected) / (p_up_expected - p_low_expected + eps) + expected = xr.DataArray(exp_data, dims=("c", "x", "y")) + + op(sample) + # NOTE xarray.testing.assert_allclose compares irrelavant properties here and fails although the result is correct + np.testing.assert_allclose(expected, sample.members[tid].data) 
+ + +def test_sigmoid(tid: MemberId): + from bioimageio.core.proc_ops import Sigmoid + + shape = (3, 32, 32) + axes = ("c", "y", "x") + np_data = np.random.rand(*shape) + data = xr.DataArray(np_data, dims=axes) + sample = Sample(members={tid: Tensor.from_xarray(data)}, stat={}, id=None) + sigmoid = Sigmoid(tid, tid) + sigmoid(sample) + + exp = xr.DataArray(1.0 / (1 + np.exp(-np_data)), dims=axes) + xr.testing.assert_allclose(exp, sample.members[tid].data) diff --git a/tests/test_resource_tests.py b/tests/test_resource_tests.py new file mode 100644 index 00000000..b9d3cf66 --- /dev/null +++ b/tests/test_resource_tests.py @@ -0,0 +1,49 @@ +from bioimageio.spec import InvalidDescr + + +def test_error_for_wrong_shape(stardist_wrong_shape: str): + from bioimageio.core._resource_tests import test_model + + summary = test_model(stardist_wrong_shape) + expected_error_message = ( + "Shape (1, 512, 512, 33) of test output 0 'output' does not match output shape description: " + "ImplicitOutputShape(reference_tensor='input', " + "scale=[1.0, 1.0, 1.0, 0.0], offset=[1.0, 1.0, 1.0, 33.0])." + ) + assert summary.details[0].errors[0].msg == expected_error_message + + +def test_error_for_wrong_shape2(stardist_wrong_shape2: str): + from bioimageio.core._resource_tests import test_model + + summary = test_model(stardist_wrong_shape2) + expected_error_message = ( + "Shape (1, 512, 512, 1) of test input 0 'input' does not match input shape description: " + "ParameterizedInputShape(min=[1, 80, 80, 1], step=[0, 17, 17, 0])." 
+ ) + assert summary.details[0].errors[0].msg == expected_error_message + + +def test_test_model(any_model: str): + from bioimageio.core._resource_tests import test_model + + summary = test_model(any_model) + assert summary.status == "passed", summary.format() + + +def test_test_resource(any_model: str): + from bioimageio.core._resource_tests import test_description + + summary = test_description(any_model) + assert summary.status == "passed", summary.format() + + +def test_loading_description_multiple_times(unet2d_nuclei_broad_model: str): + from bioimageio.core import load_description + + model_descr = load_description(unet2d_nuclei_broad_model) + assert not isinstance(model_descr, InvalidDescr) + + # load again, which some users might end up doing + model_descr = load_description(model_descr) # pyright: ignore[reportArgumentType] + assert not isinstance(model_descr, InvalidDescr) diff --git a/tests/test_resource_tests/test_test_model.py b/tests/test_resource_tests/test_test_model.py deleted file mode 100644 index 1d168472..00000000 --- a/tests/test_resource_tests/test_test_model.py +++ /dev/null @@ -1,72 +0,0 @@ -import pathlib - -import pytest - - -def test_error_for_wrong_shape(stardist_wrong_shape): - from bioimageio.core.resource_tests import test_model - - summary = test_model(stardist_wrong_shape)[-1] - expected_error_message = ( - "Shape (1, 512, 512, 33) of test output 0 'output' does not match output shape description: " - "ImplicitOutputShape(reference_tensor='input', " - "scale=[1.0, 1.0, 1.0, 0.0], offset=[1.0, 1.0, 1.0, 33.0])." - ) - assert summary["error"] == expected_error_message - - -def test_error_for_wrong_shape2(stardist_wrong_shape2): - from bioimageio.core.resource_tests import test_model - - summary = test_model(stardist_wrong_shape2)[-1] - expected_error_message = ( - "Shape (1, 512, 512, 1) of test input 0 'input' does not match input shape description: " - "ParametrizedInputShape(min=[1, 80, 80, 1], step=[0, 17, 17, 0])." 
- ) - assert summary["error"] == expected_error_message - - -def test_test_model(any_model): - from bioimageio.core.resource_tests import test_model - - summary = test_model(any_model) - assert all([s["status"] for s in summary]) - - -def test_test_resource(any_model): - from bioimageio.core.resource_tests import test_resource - - summary = test_resource(any_model) - assert all([s["status"] for s in summary]) - - -def test_validation_section_warning(unet2d_nuclei_broad_model, tmp_path: pathlib.Path): - from bioimageio.core.resource_tests import test_resource - from bioimageio.core import load_resource_description - - model = load_resource_description(unet2d_nuclei_broad_model) - - summary = test_resource(model)[2] - assert summary["name"] == "Test documentation completeness." - assert summary["warnings"] == {"documentation": "No '# Validation' (sub)section found."} - assert summary["status"] == "passed" - - doc_with_validation = tmp_path / "doc.md" - doc_with_validation.write_text("# Validation\nThis is a section about how to validate the model on new data") - model.documentation = doc_with_validation - summary = test_resource(model)[2] - assert summary["name"] == "Test documentation completeness." 
- assert summary["warnings"] == {} - assert summary["status"] == "passed" - - -@pytest.mark.skipif(pytest.skip_torch, reason="requires torch") -def test_issue289(): - """test for failure case from https://github.com/bioimage-io/core-bioimage-io-python/issues/289""" - import bioimageio.core - from bioimageio.core.resource_tests import test_model - - doi = "10.5281/zenodo.6287342" - model_resource = bioimageio.core.load_resource_description(doi) - test_result = test_model(model_resource) - assert all([t["status"] == "passed" for t in test_result]) diff --git a/tests/test_stat_calculators.py b/tests/test_stat_calculators.py new file mode 100644 index 00000000..115b8556 --- /dev/null +++ b/tests/test_stat_calculators.py @@ -0,0 +1,66 @@ +from typing import Tuple, Union + +import numpy as np +import pytest +from xarray.testing import assert_allclose # pyright: ignore[reportUnknownVariableType] + +from bioimageio.core.axis import AxisId +from bioimageio.core.common import MemberId +from bioimageio.core.sample import Sample +from bioimageio.core.stat_calculators import MeanVarStdCalculator +from bioimageio.core.stat_measures import ( + DatasetMean, + DatasetStd, + DatasetVar, +) +from bioimageio.core.tensor import Tensor + + +def create_random_dataset(tid: MemberId, axes: Tuple[AxisId, ...]): + n = 3 + sizes = list(range(n, len(axes) + n)) + data = np.asarray(np.random.rand(*sizes)) + ds = [ + Sample(members={tid: Tensor(data[i : i + 1], dims=axes)}, stat={}, id=None) + for i in range(n) + ] + return Tensor(data, dims=axes), ds + + +@pytest.mark.parametrize( + "axes", + [ + None, + ("x", "y"), + ("channel", "y"), + ], +) +def test_mean_var_std_calculator(axes: Union[None, str, Tuple[str, ...]]): + tid = MemberId("tensor") + axes = tuple(map(AxisId, ("batch", "channel", "x", "y"))) + data, ds = create_random_dataset(tid, axes) + expected_mean = data.mean(axes) + expected_var = data.var(axes) + expected_std = data.std(axes) + + calc = MeanVarStdCalculator(tid, axes=axes) + 
for s in ds: + calc.update(s) + + actual = calc.finalize() + actual_mean = actual[DatasetMean(tid, axes=axes)] + actual_var = actual[DatasetVar(tid, axes=axes)] + actual_std = actual[DatasetStd(tid, axes=axes)] + + assert_allclose( + actual_mean if isinstance(actual_mean, (int, float)) else actual_mean.data, + expected_mean.data, + ) + assert_allclose( + actual_var if isinstance(actual_var, (int, float)) else actual_var.data, + expected_var.data, + ) + assert_allclose( + actual_std if isinstance(actual_std, (int, float)) else actual_std.data, + expected_std.data, + ) diff --git a/tests/test_stat_measures.py b/tests/test_stat_measures.py new file mode 100644 index 00000000..49c87609 --- /dev/null +++ b/tests/test_stat_measures.py @@ -0,0 +1,66 @@ +from itertools import product +from typing import Optional, Tuple + +import numpy as np +import pytest +import xarray as xr + +from bioimageio.core import stat_measures +from bioimageio.core.axis import AxisId +from bioimageio.core.common import MemberId +from bioimageio.core.sample import Sample +from bioimageio.core.stat_calculators import ( + SamplePercentilesCalculator, + get_measure_calculators, +) +from bioimageio.core.stat_measures import SampleQuantile +from bioimageio.core.tensor import Tensor + + +@pytest.mark.parametrize( + "name,axes", + product( + ["mean", "var", "std"], + [None, (AxisId("c"),), (AxisId("x"), AxisId("y"))], + ), +) +def test_individual_normal_measure( + name: str, + axes: Optional[Tuple[AxisId, AxisId]], +): + data_id = MemberId("test_data") + measure = getattr(stat_measures, "Sample" + name.title())( + axes=axes, member_id=data_id + ) + data = Tensor( + np.random.random((5, 6, 3)), dims=(AxisId("x"), AxisId("y"), AxisId("c")) + ) + + expected = getattr(data, name)(dim=axes) + sample = Sample(members={data_id: data}, stat={}, id=None) + actual = measure.compute(sample) + xr.testing.assert_allclose(expected.data, actual.data) + + +@pytest.mark.parametrize("axes", [None, (AxisId("x"), 
AxisId("y"))]) +def test_individual_percentile_measure(axes: Optional[Tuple[AxisId, ...]]): + qs = [0, 0.1, 0.5, 1.0] + tid = MemberId("tensor") + + measures = [SampleQuantile(member_id=tid, axes=axes, q=q) for q in qs] + calcs, _ = get_measure_calculators(measures) + assert len(calcs) == 1 + calc = calcs[0] + assert isinstance(calc, SamplePercentilesCalculator) + + data = Tensor( + np.random.random((5, 6, 3)), dims=(AxisId("x"), AxisId("y"), AxisId("c")) + ) + actual = calc.compute(Sample(members={tid: data}, stat={}, id=None)) + for m in measures: + expected = data.quantile(q=m.q, dim=m.axes) + actual_data = actual[m] + if isinstance(actual_data, Tensor): + actual_data = actual_data.data + + xr.testing.assert_allclose(expected.data, actual_data) diff --git a/tests/test_tensor.py b/tests/test_tensor.py new file mode 100644 index 00000000..33163077 --- /dev/null +++ b/tests/test_tensor.py @@ -0,0 +1,41 @@ +import numpy as np +import pytest +import xarray as xr +from xarray.testing import assert_equal # pyright: ignore[reportUnknownVariableType] + +from bioimageio.core import AxisId, Tensor + + +@pytest.mark.parametrize( + "axes", + ["yx", "xy", "cyx", "yxc", "bczyx", "xyz", "xyzc", "bzyxc"], +) +def test_transpose_tensor_2d(axes: str): + + tensor = Tensor.from_numpy(np.random.rand(256, 256), dims=None) + transposed = tensor.transpose([AxisId(a) for a in axes]) + assert transposed.ndim == len(axes) + + +@pytest.mark.parametrize( + "axes", + ["zyx", "cyzx", "yzixc", "bczyx", "xyz", "xyzc", "bzyxtc"], +) +def test_transpose_tensor_3d(axes: str): + tensor = Tensor.from_numpy(np.random.rand(64, 64, 64), dims=None) + transposed = tensor.transpose([AxisId(a) for a in axes]) + assert transposed.ndim == len(axes) + + +def test_crop_and_pad(): + tensor = Tensor.from_xarray( + xr.DataArray(np.random.rand(10, 20), dims=("x", "y"), name="id") + ) + padded = tensor.pad({AxisId("x"): 7, AxisId("y"): (3, 3)}) + cropped = padded.crop_to(tensor.sizes) + assert_equal(tensor.data, 
cropped.data) + + +def test_some_magic_ops(): + tensor = Tensor.from_numpy(np.random.rand(256, 256), dims=None) + assert tensor + 2 == 2 + tensor diff --git a/tests/weight_converter/keras/test_tensorflow.py b/tests/weight_converter/keras/test_tensorflow.py index 712263fa..65c93f60 100644 --- a/tests/weight_converter/keras/test_tensorflow.py +++ b/tests/weight_converter/keras/test_tensorflow.py @@ -1,24 +1,49 @@ +# type: ignore # TODO enable type checking import zipfile +from pathlib import Path +import pytest -def test_tensorflow_converter(any_keras_model, tmp_path): - from bioimageio.core.weight_converter.keras import convert_weights_to_tensorflow_saved_model_bundle +from bioimageio.spec import load_description +from bioimageio.spec.model.v0_5 import ModelDescr + + +@pytest.mark.skip( + "tensorflow converter not updated yet" +) # TODO: test tensorflow converter +def test_tensorflow_converter(any_keras_model: Path, tmp_path: Path): + from bioimageio.core.weight_converter.keras import ( + convert_weights_to_tensorflow_saved_model_bundle, + ) out_path = tmp_path / "weights" - ret_val = convert_weights_to_tensorflow_saved_model_bundle(any_keras_model, out_path) + model = load_description(any_keras_model) + assert isinstance(model, ModelDescr), model.validation_summary.format() + ret_val = convert_weights_to_tensorflow_saved_model_bundle(model, out_path) assert out_path.exists() assert (out_path / "variables").exists() assert (out_path / "saved_model.pb").exists() - assert ret_val == 0 # check for correctness is done in converter and returns 0 if it passes + assert ( + ret_val == 0 + ) # check for correctness is done in converter and returns 0 if it passes -def test_tensorflow_converter_zipped(any_keras_model, tmp_path): - from bioimageio.core.weight_converter.keras import convert_weights_to_tensorflow_saved_model_bundle +@pytest.mark.skip( + "tensorflow converter not updated yet" +) # TODO: test tensorflow converter +def 
test_tensorflow_converter_zipped(any_keras_model: Path, tmp_path: Path): + from bioimageio.core.weight_converter.keras import ( + convert_weights_to_tensorflow_saved_model_bundle, + ) out_path = tmp_path / "weights.zip" - ret_val = convert_weights_to_tensorflow_saved_model_bundle(any_keras_model, out_path) + model = load_description(any_keras_model) + assert isinstance(model, ModelDescr), model.validation_summary.format() + ret_val = convert_weights_to_tensorflow_saved_model_bundle(model, out_path) assert out_path.exists() - assert ret_val == 0 # check for correctness is done in converter and returns 0 if it passes + assert ( + ret_val == 0 + ) # check for correctness is done in converter and returns 0 if it passes # make sure that the zip package was created correctly expected_names = {"saved_model.pb", "variables/variables.index"} diff --git a/tests/weight_converter/test_add_weights.py b/tests/weight_converter/test_add_weights.py new file mode 100644 index 00000000..836353c7 --- /dev/null +++ b/tests/weight_converter/test_add_weights.py @@ -0,0 +1,48 @@ +# TODO: update add weights tests +# import os + + +# def _test_add_weights(model, tmp_path, base_weights, added_weights, **kwargs): +# from bioimageio.core.build_spec import add_weights + +# rdf = load_raw_resource_description(model) +# assert base_weights in rdf.weights +# assert added_weights in rdf.weights + +# weight_path = load_description(model).weights[added_weights].source +# assert weight_path.exists() + +# drop_weights = set(rdf.weights.keys()) - {base_weights} +# for drop in drop_weights: +# rdf.weights.pop(drop) +# assert tuple(rdf.weights.keys()) == (base_weights,) + +# in_path = tmp_path / "model1.zip" +# export_resource_package(rdf, output_path=in_path) + +# out_path = tmp_path / "model2.zip" +# add_weights(in_path, weight_path, weight_type=added_weights, output_path=out_path, **kwargs) + +# assert out_path.exists() +# new_rdf = load_description(out_path) +# assert set(new_rdf.weights.keys()) == 
{base_weights, added_weights} +# for weight in new_rdf.weights.values(): +# assert weight.source.exists() + +# test_res = _test_model(out_path, added_weights) +# failed = [s for s in test_res if s["status"] != "passed"] +# assert not failed, failed +# test_res = _test_model(out_path) +# failed = [s for s in test_res if s["status"] != "passed"] +# assert not failed, failed + +# # make sure the weights were cleaned from the cwd +# assert not os.path.exists(os.path.split(weight_path)[1]) + + +# def test_add_torchscript(unet2d_nuclei_broad_model, tmp_path): +# _test_add_weights(unet2d_nuclei_broad_model, tmp_path, "pytorch_state_dict", "torchscript") + + +# def test_add_onnx(unet2d_nuclei_broad_model, tmp_path): +# _test_add_weights(unet2d_nuclei_broad_model, tmp_path, "pytorch_state_dict", "onnx", opset_version=12) diff --git a/tests/weight_converter/torch/test_onnx.py b/tests/weight_converter/torch/test_onnx.py index 5a26c916..54f2cdf4 100644 --- a/tests/weight_converter/torch/test_onnx.py +++ b/tests/weight_converter/torch/test_onnx.py @@ -1,13 +1,18 @@ +# type: ignore # TODO enable type checking import os +from pathlib import Path + import pytest -# todo: test with 'any_torch_model' -def test_onnx_converter(convert_to_onnx, tmp_path): - from bioimageio.core.weight_converter.torch.onnx import convert_weights_to_onnx +@pytest.mark.skip("onnx converter not updated yet") # TODO: test onnx converter +def test_onnx_converter(convert_to_onnx: Path, tmp_path: Path): + from bioimageio.core.weight_converter.torch._onnx import convert_weights_to_onnx out_path = tmp_path / "weights.onnx" ret_val = convert_weights_to_onnx(convert_to_onnx, out_path, test_decimal=3) assert os.path.exists(out_path) if not pytest.skip_onnx: - assert ret_val == 0 # check for correctness is done in converter and returns 0 if it passes + assert ( + ret_val == 0 + ) # check for correctness is done in converter and returns 0 if it passes diff --git a/tests/weight_converter/torch/test_torchscript.py 
b/tests/weight_converter/torch/test_torchscript.py index 5c879577..e0cee3d8 100644 --- a/tests/weight_converter/torch/test_torchscript.py +++ b/tests/weight_converter/torch/test_torchscript.py @@ -1,7 +1,22 @@ -def test_torchscript_converter(any_torch_model, tmp_path): +# type: ignore # TODO enable type checking +from pathlib import Path + +import pytest + +from bioimageio.spec.model import v0_4, v0_5 + + +@pytest.mark.skip( + "torchscript converter not updated yet" +) # TODO: test torchscript converter +def test_torchscript_converter( + any_torch_model: "v0_4.ModelDescr | v0_5.ModelDescr", tmp_path: Path +): from bioimageio.core.weight_converter.torch import convert_weights_to_torchscript out_path = tmp_path / "weights.pt" ret_val = convert_weights_to_torchscript(any_torch_model, out_path) assert out_path.exists() - assert ret_val == 0 # check for correctness is done in converter and returns 0 if it passes + assert ( + ret_val == 0 + ) # check for correctness is done in converter and returns 0 if it passes diff --git a/tox.ini b/tox.ini deleted file mode 100644 index 7fea938f..00000000 --- a/tox.ini +++ /dev/null @@ -1,8 +0,0 @@ -[tox] -envlist = py38 - -[testenv] -deps = - pytest -commands = - pytest