Skip to content

Commit

Permalink
Merge pull request #575 from bioimage-io/improve_unions
Browse files Browse the repository at this point in the history
Improve unions
  • Loading branch information
FynnBe authored Mar 22, 2024
2 parents 703cd8a + 710be14 commit b622563
Show file tree
Hide file tree
Showing 13 changed files with 171 additions and 36 deletions.
4 changes: 3 additions & 1 deletion bioimageio/spec/_internal/field_warning.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,11 @@ def issue_warning(
value: Any,
severity: WarningSeverity = WARNING,
msg_context: Optional[Dict[str, Any]] = None,
field: Optional[str] = None,
):
msg_context = {"value": value, "severity": severity, **(msg_context or {})}
if severity >= validation_context_var.get().warning_level:
raise PydanticCustomError("warning", msg, msg_context)
elif validation_context_var.get().log_warnings:
logger.log(severity, msg.format(**msg_context))
log_msg = (field + ": " if field else "") + (msg.format(**msg_context))
logger.log(severity, log_msg)
27 changes: 21 additions & 6 deletions bioimageio/spec/_internal/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from pydantic import (
AnyUrl,
DirectoryPath,
Field,
FilePath,
GetCoreSchemaHandler,
PlainSerializer,
Expand Down Expand Up @@ -111,6 +112,9 @@ def model_post_init(self, __context: Any) -> None:
if self.root.is_absolute():
raise ValueError(f"{self.root} is an absolute path.")

if self.root.parts and self.root.parts[0] in ("http:", "https:"):
raise ValueError(f"{self.root} looks like an http url.")

self._absolute = ( # pyright: ignore[reportAttributeAccessIssue]
self.get_absolute(validation_context_var.get().root)
)
Expand Down Expand Up @@ -195,6 +199,12 @@ def get_absolute(


class RelativeFilePath(RelativePathBase[Union[AbsoluteFilePath, HttpUrl]], frozen=True):
def model_post_init(self, __context: Any) -> None:
if not self.root.parts: # an empty path can only be a directory
raise ValueError(f"{self.root} is not a valid file path.")

super().model_post_init(__context)

def get_absolute(
self, root: "RootHttpUrl | Path | AnyUrl"
) -> "AbsoluteFilePath | HttpUrl":
Expand Down Expand Up @@ -226,7 +236,10 @@ def get_absolute(
return absolute


FileSource = Union[FilePath, RelativeFilePath, HttpUrl, pydantic.HttpUrl]
FileSource = Annotated[
Union[FilePath, RelativeFilePath, HttpUrl, pydantic.HttpUrl],
Field(union_mode="left_to_right"),
]
PermissiveFileSource = Union[FileSource, str]

V_suffix = TypeVar("V_suffix", bound=FileSource)
Expand Down Expand Up @@ -503,18 +516,20 @@ class HashKwargs(TypedDict):
sha256: NotRequired[Optional[Sha256]]


StrictFileSource = Union[HttpUrl, FilePath, RelativeFilePath]
StrictFileSource = Annotated[
Union[HttpUrl, FilePath, RelativeFilePath], Field(union_mode="left_to_right")
]
_strict_file_source_adapter = TypeAdapter(StrictFileSource)


def interprete_file_source(file_source: PermissiveFileSource) -> StrictFileSource:
if isinstance(file_source, (HttpUrl, Path)):
return file_source

if isinstance(file_source, pydantic.AnyUrl):
file_source = str(file_source)

if isinstance(file_source, str):
return _strict_file_source_adapter.validate_python(file_source)
else:
return file_source
return _strict_file_source_adapter.validate_python(file_source)


def _get_known_hash(hash_kwargs: HashKwargs):
Expand Down
12 changes: 7 additions & 5 deletions bioimageio/spec/_package.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from ._internal.packaging_context import PackagingContext
from ._internal.url import HttpUrl
from ._internal.validation_context import validation_context_var
from ._internal.warning_levels import ERROR
from ._io import load_description
from .model.v0_4 import ModelDescr as ModelDescr04
from .model.v0_4 import WeightsFormat
Expand Down Expand Up @@ -253,11 +254,12 @@ def save_bioimageio_package(
compression=compression,
compression_level=compression_level,
)
if isinstance((exported := load_description(output_path)), InvalidDescr):
raise ValueError(
f"Exported package '{output_path}' is invalid:"
+ f" {exported.validation_summary}"
)
with validation_context_var.get().replace(warning_level=ERROR):
if isinstance((exported := load_description(output_path)), InvalidDescr):
raise ValueError(
f"Exported package '{output_path}' is invalid:"
+ f" {exported.validation_summary}"
)

return output_path

Expand Down
3 changes: 2 additions & 1 deletion bioimageio/spec/collection/v0_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,10 @@ def finalize_entries(self) -> Self:
if entry.rdf_source is not None:
if not context.perform_io_checks:
issue_warning(
"Skipping IO relying validation for collection[{i}]",
"Skipping IO dependent validation (perform_io_checks=False)",
value=entry.rdf_source,
msg_context=dict(i=i),
field=f"collection[{i}]",
)
continue

Expand Down
2 changes: 2 additions & 0 deletions bioimageio/spec/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
from ._internal.io import FileDescr as FileDescr
from ._internal.io import Sha256 as Sha256
from ._internal.io import YamlValue as YamlValue
from ._internal.io_basics import AbsoluteDirectory as AbsoluteDirectory
from ._internal.io_basics import AbsoluteFilePath as AbsoluteFilePath
from ._internal.io_basics import FileName as FileName
from ._internal.root_url import RootHttpUrl as RootHttpUrl
from ._internal.types import FileSource as FileSource
Expand Down
36 changes: 35 additions & 1 deletion bioimageio/spec/generic/_v0_3_converter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,15 @@
import collections.abc
import string
from pathlib import Path

from .._internal.io import BioimageioYamlContent
import imageio
from loguru import logger

from .._internal.io import (
BioimageioYamlContent,
extract_file_name,
interprete_file_source,
)
from ._v0_2_converter import convert_from_older_format as convert_from_older_format_v0_2


Expand All @@ -19,6 +27,7 @@ def convert_from_older_format(data: BioimageioYamlContent) -> None:
convert_from_older_format_v0_2(data)

convert_attachments(data)
convert_cover_images(data)

_ = data.pop("download_url", None)
_ = data.pop("rdf_source", None)
Expand All @@ -36,3 +45,28 @@ def convert_attachments(data: BioimageioYamlContent) -> None:
a = data.get("attachments")
if isinstance(a, collections.abc.Mapping):
data["attachments"] = tuple({"source": file} for file in a.get("files", [])) # type: ignore


def convert_cover_images(data: BioimageioYamlContent) -> None:
covers = data.get("covers")
if not isinstance(covers, list):
return

for i in range(len(covers)):
c = covers[i]
if not isinstance(c, str):
continue

src = interprete_file_source(c)
fname = extract_file_name(src)

if not (fname.endswith(".tif") or fname.endswith(".tiff")):
continue

try:
image = imageio.imread(c)
c_path = (Path(".bioimageio_converter_cache") / fname).with_suffix(".png")
imageio.imwrite(c_path, image)
covers[i] = str(c_path.absolute())
except Exception as e:
logger.warning("failed to convert tif cover image: {}", e)
27 changes: 18 additions & 9 deletions bioimageio/spec/generic/v0_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,14 @@ class ResourceId(ValidatedString):
".jpg",
".png",
".svg",
".tif",
".tiff",
)

_WithImageSuffix = WithSuffix(VALID_COVER_IMAGE_EXTENSIONS, case_sensitive=False)
CoverImageSource = Annotated[
Union[HttpUrl, AbsoluteFilePath, RelativeFilePath],
Union[AbsoluteFilePath, RelativeFilePath, HttpUrl],
Field(union_mode="left_to_right"),
_WithImageSuffix,
include_in_package_serializer,
]
Expand Down Expand Up @@ -234,7 +237,7 @@ def accept_author_strings(cls, authors: Union[Any, Sequence[Any]]) -> Any:
authors = [{"name": a} if isinstance(a, str) else a for a in authors]

if not authors:
issue_warning("No author specified.", value=authors)
issue_warning("missing", value=authors, field="authors")

return authors

Expand All @@ -248,7 +251,7 @@ def accept_author_strings(cls, authors: Union[Any, Sequence[Any]]) -> Any:
@classmethod
def _warn_empty_cite(cls, value: Any):
if not value:
issue_warning("No cite entry specified.", value=value)
issue_warning("missing", value=value, field="cite")

return value

Expand Down Expand Up @@ -396,9 +399,7 @@ def _convert_from_older_format(

license: Annotated[
Union[LicenseId, DeprecatedLicenseId, str, None],
Field(
union_mode="left_to_right", examples=["CC0-1.0", "MIT", "BSD-2-Clause"]
),
Field(union_mode="left_to_right", examples=["CC0-1.0", "MIT", "BSD-2-Clause"]),
] = None
"""A [SPDX license identifier](https://spdx.org/licenses/).
We do not support custom license beyond the SPDX license list, if you need that please
Expand All @@ -413,11 +414,19 @@ def deprecated_spdx_license(
if isinstance(value, LicenseId):
pass
elif value is None:
issue_warning("missing license.", value=value)
issue_warning("missing", value=value, field="license")
elif isinstance(value, DeprecatedLicenseId):
issue_warning("'{value}' is a deprecated license identifier.", value=value)
issue_warning(
"'{value}' is a deprecated license identifier.",
value=value,
field="license",
)
elif isinstance(value, str):
issue_warning("'{value}' is an unknown license identifier.", value=value)
issue_warning(
"'{value}' is an unknown license identifier.",
value=value,
field="license",
)
else:
assert_never(value)

Expand Down
10 changes: 9 additions & 1 deletion bioimageio/spec/generic/v0_3.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,9 +57,9 @@
from .._internal.version_type import Version as Version
from .._internal.warning_levels import ALERT, INFO
from ._v0_3_converter import convert_from_older_format
from .v0_2 import VALID_COVER_IMAGE_EXTENSIONS, CoverImageSource
from .v0_2 import Author as _Author_v0_2
from .v0_2 import BadgeDescr as BadgeDescr
from .v0_2 import CoverImageSource
from .v0_2 import Doi as Doi
from .v0_2 import Maintainer as _Maintainer_v0_2
from .v0_2 import OrcidId as OrcidId
Expand All @@ -72,6 +72,13 @@
"model",
"notebook",
)
VALID_COVER_IMAGE_EXTENSIONS = (
".gif",
".jpeg",
".jpg",
".png",
".svg",
)


class ResourceId(ValidatedString):
Expand All @@ -92,6 +99,7 @@ def _validate_md_suffix(value: V_suffix) -> V_suffix:

DocumentationSource = Annotated[
Union[AbsoluteFilePath, RelativeFilePath, HttpUrl],
Field(union_mode="left_to_right"),
AfterValidator(_validate_md_suffix),
include_in_package_serializer,
]
Expand Down
25 changes: 17 additions & 8 deletions bioimageio/spec/model/v0_4.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ def _get_data(cls, valid_string_data: str):
class CallableFromFile(StringNode):
_pattern = r"^.+:.+$"
source_file: Annotated[
Union[HttpUrl, RelativeFilePath],
Union[RelativeFilePath, HttpUrl],
Field(union_mode="left_to_right"),
include_in_package_serializer,
]
"""∈📦 Python module that implements `callable_name`"""
Expand All @@ -144,7 +145,9 @@ def _get_data(cls, valid_string_data: str):
return dict(source_file=":".join(file_parts), callable_name=callname)


CustomCallable = Union[CallableFromFile, CallableFromDepencency]
CustomCallable = Annotated[
Union[CallableFromFile, CallableFromDepencency], Field(union_mode="left_to_right")
]


class Dependencies(StringNode):
Expand Down Expand Up @@ -268,10 +271,11 @@ class KerasHdf5WeightsDescr(WeightsEntryDescrBase):
def _tfv(cls, value: Any):
if value is None:
issue_warning(
"Missing TensorFlow version. Please specify the TensorFlow version"
"missing. Please specify the TensorFlow version"
+ " these weights were created with.",
value=value,
severity=ALERT,
field="tensorflow_version",
)
return value

Expand All @@ -292,6 +296,7 @@ def _ov(cls, value: Any):
+ " with.",
value=value,
severity=ALERT,
field="opset_version",
)
return value

Expand Down Expand Up @@ -348,10 +353,11 @@ def check_architecture_sha256(self) -> Self:
def _ptv(cls, value: Any):
if value is None:
issue_warning(
"Missing PyTorch version. Please specify the PyTorch version these"
"missing. Please specify the PyTorch version these"
+ " PyTorch state dict weights were created with.",
value=value,
severity=ALERT,
field="pytorch_version",
)
return value

Expand All @@ -367,10 +373,11 @@ class TorchscriptWeightsDescr(WeightsEntryDescrBase):
def _ptv(cls, value: Any):
if value is None:
issue_warning(
"Missing PyTorch version. Please specify the PyTorch version these"
"missing. Please specify the PyTorch version these"
+ " Torchscript weights were created with.",
value=value,
severity=ALERT,
field="pytorch_version",
)
return value

Expand All @@ -386,10 +393,11 @@ class TensorflowJsWeightsDescr(WeightsEntryDescrBase):
def _tfv(cls, value: Any):
if value is None:
issue_warning(
"Missing TensorFlow version. Please specify the TensorFlow version"
"missing. Please specify the TensorFlow version"
+ " these TensorflowJs weights were created with.",
value=value,
severity=ALERT,
field="tensorflow_version",
)
return value

Expand All @@ -409,10 +417,11 @@ class TensorflowSavedModelBundleWeightsDescr(WeightsEntryDescrBase):
def _tfv(cls, value: Any):
if value is None:
issue_warning(
"Missing TensorFlow version. Please specify the TensorFlow version"
"missing. Please specify the TensorFlow version"
+ " these Tensorflow saved model bundle weights were created with.",
value=value,
severity=ALERT,
field="tensorflow_version",
)
return value

Expand Down Expand Up @@ -444,7 +453,7 @@ class ImplicitOutputShape(Node):
reference_tensor: TensorName
"""Name of the reference tensor."""

scale: NotEmpty[List[Union[float, None]]]
scale: NotEmpty[List[Optional[float]]]
"""output_pix/input_pix for each dimension.
'null' values indicate new dimensions, whose length is defined by 2*`offset`"""

Expand Down
Loading

0 comments on commit b622563

Please sign in to comment.