From b84d33b321b1f4e755aebe900f0bd74ed74eb012 Mon Sep 17 00:00:00 2001 From: fynnbe Date: Wed, 24 Apr 2024 01:16:22 +0200 Subject: [PATCH] torch is optional dep --- bioimageio/core/weight_converter/torch/_onnx.py | 7 ++++++- .../core/weight_converter/torch/_torchscript.py | 10 ++++++++-- bioimageio/core/weight_converter/torch/_utils.py | 12 +++++++----- 3 files changed, 21 insertions(+), 8 deletions(-) diff --git a/bioimageio/core/weight_converter/torch/_onnx.py b/bioimageio/core/weight_converter/torch/_onnx.py index 1e1e68ae..3935e1d1 100644 --- a/bioimageio/core/weight_converter/torch/_onnx.py +++ b/bioimageio/core/weight_converter/torch/_onnx.py @@ -4,7 +4,6 @@ from typing import Any, List, Sequence, cast import numpy as np -import torch from numpy.testing import assert_array_almost_equal from bioimageio.spec import load_description @@ -14,6 +13,11 @@ from ...digest_spec import get_member_id, get_test_inputs from ...weight_converter.torch._utils import load_torch_model +try: + import torch +except ImportError: + torch = None + def add_onnx_weights( model_spec: "str | Path | v0_4.ModelDescr | v0_5.ModelDescr", @@ -48,6 +52,7 @@ def add_onnx_weights( "The provided model does not have weights in the pytorch state dict format" ) + assert torch is not None with torch.no_grad(): sample = get_test_inputs(model_spec) diff --git a/bioimageio/core/weight_converter/torch/_torchscript.py b/bioimageio/core/weight_converter/torch/_torchscript.py index 0d226563..5ca16069 100644 --- a/bioimageio/core/weight_converter/torch/_torchscript.py +++ b/bioimageio/core/weight_converter/torch/_torchscript.py @@ -3,7 +3,6 @@ from typing import List, Sequence, Union import numpy as np -import torch from numpy.testing import assert_array_almost_equal from typing_extensions import Any, assert_never @@ -12,14 +11,21 @@ from ._utils import load_torch_model +try: + import torch +except ImportError: + torch = None + # FIXME: remove Any def _check_predictions( model: Any, scripted_model: Any, model_spec: "v0_4.ModelDescr | v0_5.ModelDescr", - input_data: Sequence[torch.Tensor], + input_data: Sequence["torch.Tensor"], ): + assert torch is not None + def _check(input_: Sequence[torch.Tensor]) -> None: expected_tensors = model(*input_) if isinstance(expected_tensors, torch.Tensor): diff --git a/bioimageio/core/weight_converter/torch/_utils.py b/bioimageio/core/weight_converter/torch/_utils.py index d3908f61..01df0747 100644 --- a/bioimageio/core/weight_converter/torch/_utils.py +++ b/bioimageio/core/weight_converter/torch/_utils.py @@ -1,22 +1,24 @@ from typing import Union -import torch - from bioimageio.core.model_adapters._pytorch_model_adapter import PytorchModelAdapter from bioimageio.spec.model import v0_4, v0_5 from bioimageio.spec.utils import download +try: + import torch +except ImportError: + torch = None + # additional convenience for pytorch state dict, eventually we want this in python-bioimageio too # and for each weight format def load_torch_model( # pyright: ignore[reportUnknownParameterType] node: Union[v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr], ): + assert torch is not None model = ( # pyright: ignore[reportUnknownVariableType] PytorchModelAdapter.get_network(node) ) - state = torch.load( # pyright: ignore[reportUnknownVariableType] - download(node.source).path, map_location="cpu" - ) + state = torch.load(download(node.source).path, map_location="cpu") model.load_state_dict(state) # FIXME: check incompatible keys? return model.eval() # pyright: ignore[reportUnknownVariableType]