Skip to content

Commit

Permalink
WIP add validation of numpy arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
FynnBe committed Nov 17, 2023
1 parent 698aa59 commit c4d1f79
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 21 deletions.
154 changes: 133 additions & 21 deletions bioimageio/spec/model/v0_5.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,28 @@
from __future__ import annotations

import collections.abc
from abc import ABC
from typing import Any, ClassVar, Dict, FrozenSet, List, Literal, NewType, Optional, Sequence, Set, Tuple, Union
from typing import (
Any,
ClassVar,
Dict,
FrozenSet,
Generic,
List,
Literal,
Mapping,
NewType,
Optional,
Sequence,
Set,
Tuple,
TypeVar,
Union,
)

import numpy as np
from annotated_types import Ge, Gt, Interval, MaxLen, MinLen, Predicate
from numpy.typing import NDArray
from pydantic import (
Field,
StringConstraints,
Expand All @@ -11,11 +31,12 @@
field_validator,
model_validator,
)
from typing_extensions import Annotated, LiteralString, Self
from typing_extensions import Annotated, LiteralString, Self, assert_never

from bioimageio.spec._internal.base_nodes import Node, NodeWithExplicitlySetFields
from bioimageio.spec._internal.constants import DTYPE_LIMITS, INFO
from bioimageio.spec._internal.field_warning import issue_warning, warn
from bioimageio.spec._internal.io_utils import download
from bioimageio.spec._internal.types import AbsoluteFilePath as AbsoluteFilePath
from bioimageio.spec._internal.types import BioimageioYamlContent as BioimageioYamlContent
from bioimageio.spec._internal.types import Datetime as Datetime
Expand All @@ -24,12 +45,11 @@
from bioimageio.spec._internal.types import HttpUrl as HttpUrl
from bioimageio.spec._internal.types import Identifier as Identifier
from bioimageio.spec._internal.types import LicenseId as LicenseId
from bioimageio.spec._internal.types import LowerCaseIdentifierStr
from bioimageio.spec._internal.types import LowerCaseIdentifierStr, SiUnit
from bioimageio.spec._internal.types import NotEmpty as NotEmpty
from bioimageio.spec._internal.types import RelativeFilePath as RelativeFilePath
from bioimageio.spec._internal.types import ResourceId as ResourceId
from bioimageio.spec._internal.types import Sha256 as Sha256
from bioimageio.spec._internal.types import Unit as Unit
from bioimageio.spec._internal.types import Version as Version
from bioimageio.spec._internal.types.field_validation import AfterValidator, WithSuffix
from bioimageio.spec._internal.validation_context import InternalValidationContext, get_internal_validation_context
Expand Down Expand Up @@ -125,6 +145,7 @@ def validate_tensor_axis_id(s: str):
return s


# TODO: make TensorAxisId a Node?
TensorAxisId = Annotated[
str, StringConstraints(min_length=1, max_length=33, pattern=r"^.+\..+$"), AfterValidator(validate_tensor_axis_id)
]
Expand Down Expand Up @@ -159,9 +180,17 @@ class ParametrizedSize(Node):
step_with: Optional[Union[AxisId, TensorAxisId]] = None
"""ID of another axis with parametrixed size to resize jointly,
i.e. `n=n_other` for `size = min + n*step`, `size_other = min_other + n_other*step_other`.
To step with an axis of another tensor, use `step_with = <tensor id>.<axis id>`.
To step jointly with an axis of another tensor, use `step_with = <tensor id>.<axis id>`.
"""

def validate_size(self, size: int) -> int:
    """Check that `size` can be written as `min + n*step` for an integer `n >= 0`.

    Args:
        size: The axis size to validate.

    Returns:
        The unchanged `size` if it is valid.

    Raises:
        ValueError: If `size` is smaller than `min`, or if the offset from `min`
            is not a whole multiple of `step`.
    """
    if size < self.min:
        raise ValueError(f"size {size} < {self.min}")
    # Parenthesize the offset: `%` binds tighter than `-`, so without the
    # parentheses this would test `size - (min % step)` instead.
    if (size - self.min) % self.step != 0:
        raise ValueError(f"axis of size {size} is not parametrized by min + n*step = {self.min} + n*{self.step}")

    return size


class SizeReference(Node):
"""A tensor axis size defined in relation to another reference tensor axis.
Expand Down Expand Up @@ -369,7 +398,7 @@ def validate_values_match_type(

return self

unit: Optional[Unit] = None
unit: Optional[Union[Literal["arbitrary unit"], SiUnit]] = None

@property
def range(self):
Expand Down Expand Up @@ -397,7 +426,7 @@ class IntervalOrRatioData(Node):
)
"""Tuple `(minimum, maximum)` specifying the allowed range of the data in this tensor.
`None` corresponds to min/max of what can be expressed by `data_type`."""
unit: Optional[Unit] = "arbitrary unit"
unit: Union[Literal["arbitrary unit"], SiUnit] = "arbitrary unit"
scale: float = 1.0
"""Scale for data on an interval (or ratio) scale."""
offset: Optional[float] = None
Expand Down Expand Up @@ -636,18 +665,22 @@ class ScaleMeanVariance(ProcessingBase):
Field(discriminator="id"),
]

AxisVar = TypeVar("AxisVar", InputAxis, OutputAxis)


class TensorBase(Node):
class TensorBase(Node, Generic[AxisVar]):
id: TensorId
"""Tensor id. No duplicates are allowed."""

description: Annotated[str, MaxLen(128)] = ""
"""free text description"""

# axes: List[AnyAxis] # TODO: make abstract
axes: NotEmpty[Sequence[AxisVar]]
"""tensor axes"""

@field_validator("axes", mode="after", check_fields=False)
@classmethod
def validate_axes(cls, axes: List[AnyAxis]) -> List[AnyAxis]:
def validate_axes(cls, axes: Sequence[AnyAxis]) -> Sequence[AnyAxis]:
seen_types: Set[str] = set()
duplicate_axes_types: Set[str] = set()
for a in axes:
Expand Down Expand Up @@ -681,15 +714,27 @@ def validate_axes(cls, axes: List[AnyAxis]) -> List[AnyAxis]:
The sample files primarily serve to inform a human user about an example use case
and are typically stored as HDF5, PNG or TIFF images."""

data: Union[TensorData, NotEmpty[List[TensorData]]] = IntervalOrRatioData()
data: Union[TensorData, NotEmpty[Sequence[TensorData]]] = IntervalOrRatioData()
"""Description of the tensor's data values, optionally per channel.
If specified per channel, the data `type` needs to match across channels."""

@property
def dtype(
    self,
) -> Literal[
    "float32", "float64", "uint8", "int8", "uint16", "int16", "uint32", "int32", "uint64", "int64", "bool"
]:
    """Data type taken from `data.type`, or from `data[0].type` if `data` is a sequence."""
    described = self.data
    if not isinstance(described, collections.abc.Sequence):
        return described.type

    # for a sequence of data descriptions only the first entry is consulted
    return described[0].type

@field_validator("data", mode="after")
@classmethod
def check_data_type_across_channels(
cls, value: Union[TensorData, NotEmpty[List[TensorData]]]
) -> Union[TensorData, NotEmpty[List[TensorData]]]:
cls, value: Union[TensorData, NotEmpty[Sequence[TensorData]]]
) -> Union[TensorData, NotEmpty[Sequence[TensorData]]]:
if not isinstance(value, list):
return value

Expand All @@ -702,7 +747,7 @@ def check_data_type_across_channels(
return value

@model_validator(mode="after")
def check_data_matches_channelaxis(self) -> Union[Self, "InputTensor", "OutputTensor"]:
def check_data_matches_channelaxis(self) -> Union[Self, InputTensor, OutputTensor]:
if not isinstance(self.data, list) or not isinstance(self, (InputTensor, OutputTensor)):
return self

Expand All @@ -721,14 +766,63 @@ def check_data_matches_channelaxis(self) -> Union[Self, "InputTensor", "OutputTe

return self


class InputTensor(TensorBase):
def get_axis_sizes(self, array: NDArray[Any]) -> Dict[AxisId, int]:
    """Map the id of each axis in `self.axes` to the matching entry of `array.shape`.

    Raises:
        ValueError: If the number of array dimensions differs from the number of axes.
    """
    shape = array.shape
    n_dims = len(shape)
    n_axes = len(self.axes)
    if n_dims == n_axes:
        # axes and shape entries are paired by position
        return {axis.id: extent for axis, extent in zip(self.axes, shape)}

    raise ValueError(
        f"Dimension mismatch: array shape {shape} (#{n_dims}) "
        f"incompatible with {n_axes} axes."
    )

def validate_array(
    self, array: NDArray[Any], *, other_known_tensor_sizes: Optional[Mapping[TensorId, Mapping[AxisId, int]]] = None
) -> NDArray[Any]:
    """Validate dtype and shape of `array` against this tensor description.

    Args:
        array: The array to check.
        other_known_tensor_sizes: Axis sizes of other tensors, keyed by tensor id
            and then axis id. Used to resolve axis sizes given as a reference
            (`<axis id>` or `<tensor id>.<axis id>`).

    Returns:
        The unchanged `array`.

    Raises:
        ValueError: If the dtype does not match, an axis size does not match its
            description, or a referenced tensor/axis size is unknown. Errors
            raised by `get_axis_sizes` and `ParametrizedSize.validate_size`
            propagate unchanged.
    """
    # Copy so that the caller's mapping is not modified when adding our own sizes.
    known_tensor_sizes = dict(other_known_tensor_sizes or {})
    known_tensor_sizes[self.id] = self.get_axis_sizes(array)

    if array.dtype.name != self.dtype:
        raise ValueError(f"tensor with dtype {array.dtype.name} does not match specified dtype {self.dtype}")

    shape = list(array.shape)
    for i, a in enumerate(self.axes):
        if isinstance(a.size, int):
            if shape[i] != a.size:
                raise ValueError(f"incompatible shape: array.shape[{i}] = {shape[i]} != {a.size}")
        elif isinstance(a.size, ParametrizedSize):
            _ = a.size.validate_size(shape[i])
            # TODO: remove step_with
        elif isinstance(a.size, str):
            # A string size refers to another axis: `<tensor id>.<axis id>`,
            # or just `<axis id>` for an axis of this same tensor.
            if "." in a.size:
                # NOTE(review): presumably already ensured when the id was validated -- confirm
                assert a.size.count(".") == 1
                _tensor_id, _axis_id = a.size.split(".")
                tensor_id = TensorId(_tensor_id)
                axis_id = AxisId(_axis_id)
            else:
                tensor_id = self.id
                axis_id = AxisId(a.size)

            if tensor_id not in known_tensor_sizes:
                raise ValueError(f"tensor sizes of '{tensor_id}' are unknown.")

            other_size = known_tensor_sizes[tensor_id].get(axis_id)
            if other_size is None:
                raise ValueError(f"axis size '{a.size}' is unknown.")

            if shape[i] != other_size:
                raise ValueError(
                    f"axis size mismatch: array axis {i} of size " f"{shape[i]} != {other_size} given by {a.size}."
                )
        else:
            assert_never(a.size)

    return array


class InputTensor(TensorBase[InputAxis]):
id: TensorId = TensorId("input")
"""Input tensor id.
No duplicates are allowed across all inputs and outputs."""

axes: NotEmpty[List[InputAxis]]

preprocessing: List[Preprocessing] = Field(default_factory=list)
"""Description of how this input should be preprocessed."""

Expand All @@ -746,13 +840,11 @@ def validate_preprocessing_kwargs(self) -> Self:
return self


class OutputTensor(TensorBase):
class OutputTensor(TensorBase[OutputAxis]):
id: TensorId = TensorId("output")
"""Output tensor id.
No duplicates are allowed across all inputs and outputs."""

axes: List[OutputAxis]

postprocessing: List[Postprocessing] = Field(default_factory=list)
"""Description of how this output should be postprocessed."""

Expand Down Expand Up @@ -1092,6 +1184,26 @@ def _validate_axis(
]
axis.channel_names = generated_channel_names

@model_validator(mode="after")
def validate_test_tensors(self, info: ValidationInfo) -> Self:
    """Load the test tensors of all inputs and outputs and validate them against their descriptions.

    Does nothing unless the validation context has `perform_io_checks` set.
    Inputs are validated first, knowing only the input axis sizes; outputs are
    validated afterwards with both input and output axis sizes known.
    """
    if not get_internal_validation_context(info.context)["perform_io_checks"]:
        return self

    input_arrays = [np.load(download(tensor.test_tensor).path) for tensor in self.inputs]
    known_sizes = {}
    for tensor, arr in zip(self.inputs, input_arrays):
        known_sizes[tensor.id] = tensor.get_axis_sizes(arr)

    for tensor, arr in zip(self.inputs, input_arrays):
        _ = tensor.validate_array(arr, other_known_tensor_sizes=known_sizes)

    output_arrays = [np.load(download(tensor.test_tensor).path) for tensor in self.outputs]
    # collect all output sizes before registering any of them, then add them to the known sizes
    output_sizes = {tensor.id: tensor.get_axis_sizes(arr) for tensor, arr in zip(self.outputs, output_arrays)}
    known_sizes.update(output_sizes)

    for tensor, arr in zip(self.outputs, output_arrays):
        _ = tensor.validate_array(arr, other_known_tensor_sizes=known_sizes)

    return self

license: Annotated[
Union[LicenseId, DeprecatedLicenseId],
warn(LicenseId, "{value} is deprecated, see https://spdx.org/licenses/{value}.html"),
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
install_requires=[
"annotated-types>=0.5.0",
"email_validator",
"numpy>=1.21",
"packaging>=17.0",
"pooch",
"pydantic[email]>=2.0.1",
Expand Down

0 comments on commit c4d1f79

Please sign in to comment.