Skip to content

Commit

Permalink
torch is optional dep
Browse files Browse the repository at this point in the history
  • Loading branch information
FynnBe committed Apr 23, 2024
1 parent 2be9913 commit b84d33b
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 8 deletions.
7 changes: 6 additions & 1 deletion bioimageio/core/weight_converter/torch/_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from typing import Any, List, Sequence, cast

import numpy as np
import torch
from numpy.testing import assert_array_almost_equal

from bioimageio.spec import load_description
Expand All @@ -14,6 +13,11 @@
from ...digest_spec import get_member_id, get_test_inputs
from ...weight_converter.torch._utils import load_torch_model

try:
import torch
except ImportError:
torch = None


def add_onnx_weights(
model_spec: "str | Path | v0_4.ModelDescr | v0_5.ModelDescr",
Expand Down Expand Up @@ -48,6 +52,7 @@ def add_onnx_weights(
"The provided model does not have weights in the pytorch state dict format"
)

assert torch is not None
with torch.no_grad():

sample = get_test_inputs(model_spec)
Expand Down
10 changes: 8 additions & 2 deletions bioimageio/core/weight_converter/torch/_torchscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import List, Sequence, Union

import numpy as np
import torch
from numpy.testing import assert_array_almost_equal
from typing_extensions import Any, assert_never

Expand All @@ -12,14 +11,21 @@

from ._utils import load_torch_model

try:
import torch
except ImportError:
torch = None


# FIXME: remove Any
def _check_predictions(
model: Any,
scripted_model: Any,
model_spec: "v0_4.ModelDescr | v0_5.ModelDescr",
input_data: Sequence[torch.Tensor],
input_data: Sequence["torch.Tensor"],
):
assert torch is not None

def _check(input_: Sequence[torch.Tensor]) -> None:
expected_tensors = model(*input_)
if isinstance(expected_tensors, torch.Tensor):
Expand Down
12 changes: 7 additions & 5 deletions bioimageio/core/weight_converter/torch/_utils.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,24 @@
from typing import Union

import torch

from bioimageio.core.model_adapters._pytorch_model_adapter import PytorchModelAdapter
from bioimageio.spec.model import v0_4, v0_5
from bioimageio.spec.utils import download

try:
import torch
except ImportError:
torch = None


# additional convenience for pytorch state dict, eventually we want this in python-bioimageio too
# and for each weight format
def load_torch_model( # pyright: ignore[reportUnknownParameterType]
node: Union[v0_4.PytorchStateDictWeightsDescr, v0_5.PytorchStateDictWeightsDescr],
):
assert torch is not None
model = ( # pyright: ignore[reportUnknownVariableType]
PytorchModelAdapter.get_network(node)
)
state = torch.load( # pyright: ignore[reportUnknownVariableType]
download(node.source).path, map_location="cpu"
)
state = torch.load(download(node.source).path, map_location="cpu")
model.load_state_dict(state) # FIXME: check incompatible keys?
return model.eval() # pyright: ignore[reportUnknownVariableType]

0 comments on commit b84d33b

Please sign in to comment.