
Commit

Merge pull request #412 from bioimage-io/update_parent_field
update parent field
FynnBe authored Mar 2, 2022
2 parents 884d778 + e1ac519 commit 300814e
Showing 9 changed files with 138 additions and 33 deletions.
4 changes: 4 additions & 0 deletions README.md
@@ -156,6 +156,10 @@ bioimageio update-format <MY-MODEL-SOURCE> <OUTPUT-PATH>
### RDF Format Versions
#### model RDF 0.4.5
- Breaking changes that are fully auto-convertible
- `parent` field changed to hold a string that is a BioImage.IO ID, a URL or a local relative path (and not subfields `uri` and `sha256`)
#### model RDF 0.4.4
- Non-breaking changes
- new optional field `training_data`
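
At the data level, the 0.4.5 `parent` change above looks roughly like this. A minimal sketch with placeholder values (the ID is borrowed from the tests further below), not output of the library:

# model RDF 0.4.4 and earlier: `parent` is a mapping with `uri` and `sha256` subfields
parent_0_4_4 = {"uri": "https://example.com/other_model/rdf.yaml", "sha256": "<hash of the parent RDF>"}

# model RDF 0.4.5: `parent` references the parent model directly,
# e.g. by its BioImage.IO ID, a URL, or a local relative path
parent_0_4_5 = "10.5281/zenodo.5764892"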
23 changes: 20 additions & 3 deletions bioimageio/spec/model/v0_4/converters.py
@@ -38,7 +38,7 @@ def convert_model_from_v0_3_to_0_4_0(data: Dict[str, Any]) -> Dict[str, Any]:
return data


def convert_model_from_v0_4_0_to_0_4_4(data: Dict[str, Any]) -> Dict[str, Any]:
def convert_model_from_v0_4_0_to_0_4_1(data: Dict[str, Any]) -> Dict[str, Any]:
data = dict(data)

# move dependencies from root to pytorch_state_dict weights entry
@@ -49,7 +49,18 @@ def convert_model_from_v0_4_0_to_0_4_4(data: Dict[str, Any]) -> Dict[str, Any]:
if entry and isinstance(entry, dict):
entry["dependencies"] = deps

data["format_version"] = "0.4.4"
data["format_version"] = "0.4.1"
return data


def convert_model_from_v0_4_4_to_0_4_5(data: Dict[str, Any]) -> Dict[str, Any]:
data = dict(data)

parent = data.pop("parent", None)
if parent and "uri" in parent:
data["parent"] = parent["uri"]

data["format_version"] = "0.4.5"
return data


@@ -60,7 +71,13 @@ def maybe_convert(data: Dict[str, Any]) -> Dict[str, Any]:
data = convert_model_from_v0_3_to_0_4_0(data)

if data["format_version"] == "0.4.0":
data = convert_model_from_v0_4_0_to_0_4_4(data)
data = convert_model_from_v0_4_0_to_0_4_1(data)

if data["format_version"] in ("0.4.1", "0.4.2", "0.4.3"):
data["format_version"] = "0.4.4"

if data["format_version"] == "0.4.4":
data = convert_model_from_v0_4_4_to_0_4_5(data)

# remove 'future' from config if no other than the used future entries exist
config = data.get("config", {})
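A hedged sketch of how the extended conversion chain behaves end to end, assuming the module layout shown in this diff. The input dict is reduced to the fields the converters touch; a real model RDF carries many more required entries:

from bioimageio.spec.model.v0_4.converters import maybe_convert

data = {
    "format_version": "0.4.4",
    "parent": {"uri": "https://example.com/parent/rdf.yaml", "sha256": "<hash>"},
}
converted = maybe_convert(data)
# the 0.4.4 -> 0.4.5 step pops the old `parent` mapping and keeps only its `uri` value
assert converted["format_version"] == "0.4.5"
assert converted["parent"] == "https://example.com/parent/rdf.yaml"
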
13 changes: 10 additions & 3 deletions bioimageio/spec/model/v0_4/raw_nodes.py
@@ -11,7 +11,6 @@
from bioimageio.spec.model.v0_3.raw_nodes import (
InputTensor,
KerasHdf5WeightsEntry as KerasHdf5WeightsEntry03,
ModelParent,
OnnxWeightsEntry as OnnxWeightsEntry03,
OutputTensor,
Postprocessing,
@@ -30,6 +29,7 @@
ImportableModule,
ImportableSourceFile,
ParametrizedInputShape,
RawNode,
URI,
)

@@ -48,7 +48,7 @@
PreprocessingName = PreprocessingName

FormatVersion = Literal[
"0.4.0", "0.4.1", "0.4.2", "0.4.3", "0.4.4"
"0.4.0", "0.4.1", "0.4.2", "0.4.3", "0.4.4", "0.4.5"
] # newest format needs to be last (used in __init__.py)
WeightsFormat = Literal[
"pytorch_state_dict", "torchscript", "keras_hdf5", "tensorflow_js", "tensorflow_saved_model_bundle", "onnx"
@@ -108,10 +108,17 @@ class TorchscriptWeightsEntry(_WeightsEntryBase):


@dataclass
class LinkedDataset:
class LinkedDataset(RawNode):
id: str


@dataclass
class ModelParent(RawNode):
id: Union[_Missing, str] = missing
uri: Union[_Missing, URI, Path] = missing
sha256: Union[_Missing, str] = missing


@dataclass
class Model(_RDF):
_include_in_package = ("covers", "documentation", "test_inputs", "test_outputs", "sample_inputs", "sample_outputs")
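For orientation, a minimal sketch of the new `ModelParent` raw node; the ID value is a placeholder borrowed from the tests at the bottom of this diff:

from bioimageio.spec.model.v0_4.raw_nodes import ModelParent

# parent referenced by BioImage.IO ID; `uri` and `sha256` are left `missing`
by_id = ModelParent(id="10.5281/zenodo.5764892")

# alternatively a `uri` (URI or pathlib.Path) plus an optional, unvalidated
# `sha256` of the parent RDF can be given instead of the `id`
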
18 changes: 14 additions & 4 deletions bioimageio/spec/model/v0_4/schema.py
@@ -7,7 +7,6 @@
from bioimageio.spec.dataset.v0_2.schema import Dataset as _Dataset
from bioimageio.spec.model.v0_3.schema import (
KerasHdf5WeightsEntry as KerasHdf5WeightsEntry03,
ModelParent,
OnnxWeightsEntry as OnnxWeightsEntry03,
Postprocessing as Postprocessing03,
Preprocessing as Preprocessing03,
@@ -281,6 +280,19 @@ class LinkedDataset(_BioImageIOSchema):
id = fields.String(bioimageio_description="dataset id")


class ModelParent(_BioImageIOSchema):
id = fields.BioImageIO_ID(resource_type="model")
uri = fields.Union(
[fields.URI(), fields.RelativeLocalPath()], bioimageio_description="URL or local relative path of a model RDF"
)
sha256 = fields.SHA256(bioimageio_description="Hash of the parent model RDF. Note: the hash is not validated")

@validates_schema()
def id_xor_uri(self, data, **kwargs):
if ("id" in data) == ("uri" in data):
raise ValidationError("Either 'id' or 'uri' are required (not both).")


class Model(rdf.schema.RDF):
raw_nodes = raw_nodes

@@ -481,9 +493,7 @@ def get_min_shape(t) -> numpy.ndarray:

parent = fields.Nested(
ModelParent(),
bioimageio_description="Parent model from which the trained weights of this model have been derived, e.g. by "
"finetuning the weights of this model on a different dataset. For format changes of the same trained model "
"checkpoint, see `weights`.",
bioimageio_description="The model from which this model is derived, e.g. by fine-tuning the weights.",
)

run_mode = fields.Nested(
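The `id_xor_uri` validator above enforces that exactly one of `id` and `uri` is given. A hedged sketch against the schema — `model_dict` stands for an otherwise valid model RDF dict (as in the tests at the bottom of this diff), and note that `id` values are additionally checked against the online collection via `BioImageIO_ID`:

from marshmallow import ValidationError
from bioimageio.spec.model.schema import Model

model_dict["parent"] = {"uri": "https://doi.org/10.5281/zenodo.5744489"}  # ok: uri only
Model().load(model_dict)

model_dict["parent"] = {
    "id": "10.5281/zenodo.5764892",
    "uri": "https://doi.org/10.5281/zenodo.5744489",
}  # id and uri together
try:
    Model().load(model_dict)
except ValidationError as e:
    print(e.messages)  # mentions: Either 'id' or 'uri' are required (not both).
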
31 changes: 23 additions & 8 deletions bioimageio/spec/shared/_resolve_source.py
@@ -67,14 +67,7 @@ def resolve_rdf_source(
if isinstance(source, str):
# source might be bioimageio id, doi, url or file path -> resolve to pathlib.Path

if BIOIMAGEIO_COLLECTION is None:
bioimageio_rdf_source = None
else:
bioimageio_collection = {
c.get("id", f"missind_id_{i}"): c.get("rdf_source")
for i, c in enumerate(BIOIMAGEIO_COLLECTION.get("collection", []))
}
bioimageio_rdf_source = bioimageio_collection.get(source) or bioimageio_collection.get(source + "/latest")
bioimageio_rdf_source: typing.Optional[str] = (BIOIMAGEIO_COLLECTION_ENTRIES or {}).get(source, (None, None))[1]

if bioimageio_rdf_source is not None:
# source is bioimageio id
@@ -412,3 +405,25 @@ def _resolve_json_from_url(

BIOIMAGEIO_SITE_CONFIG, BIOIMAGEIO_SITE_CONFIG_ERROR = _resolve_json_from_url(BIOIMAGEIO_SITE_CONFIG_URL)
BIOIMAGEIO_COLLECTION, BIOIMAGEIO_COLLECTION_ERROR = _resolve_json_from_url(BIOIMAGEIO_COLLECTION_URL)
if BIOIMAGEIO_COLLECTION is None:
BIOIMAGEIO_COLLECTION_ENTRIES: typing.Optional[typing.Dict[str, typing.Tuple[str, str]]] = None
else:
BIOIMAGEIO_COLLECTION_ENTRIES = {
cr["id"]: (cr["type"], cr["rdf_source"])
for cr in BIOIMAGEIO_COLLECTION.get("collection", [])
if "id" in cr and "rdf_source" in cr and "type" in cr
}
# add resource versions explicitly
BIOIMAGEIO_COLLECTION_ENTRIES.update(
{
f"{cr['id']}/{cv}": (
cr["type"],
cr["rdf_source"].replace(
f"/{cr['versions'][0]}", f"/{cv}"
), # todo: improve this replace-version-monkeypatch
)
for cr in BIOIMAGEIO_COLLECTION.get("collection", [])
for cv in cr.get("versions", [])
if "id" in cr and "rdf_source" in cr and "type" in cr
}
)
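
The resulting lookup maps both plain and versioned resource IDs to `(type, rdf_source)` tuples, which `resolve_rdf_source` above consumes. A rough, illustrative sketch of its use — keys and URLs are placeholders, the real dict is built from the online collection JSON:

# e.g. BIOIMAGEIO_COLLECTION_ENTRIES might contain, per resource and per version:
#   "10.5281/zenodo.5764892":         ("model", "https://.../rdf.yaml")
#   "10.5281/zenodo.5764892/5764893": ("model", "https://.../rdf.yaml")
resource_type, rdf_source = (BIOIMAGEIO_COLLECTION_ENTRIES or {}).get(
    "10.5281/zenodo.5764892", (None, None)
)  # (None, None) when the ID is unknown or the collection could not be fetched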
54 changes: 44 additions & 10 deletions bioimageio/spec/shared/fields.py
@@ -48,11 +48,6 @@ def __init__(
super().__init__(*super_args, **super_kwargs) # type: ignore


#################################################
# fields directly derived from marshmallow fields
#################################################


class Array(DocumentedField, marshmallow_fields.Field):
def __init__(self, inner: marshmallow_fields.Field, **kwargs):
self.inner = inner
@@ -266,11 +261,6 @@ def _deserialize(self, value, attr=None, data=None, **kwargs):
raise ValidationError(message=messages, field_name=attr) from e


#########################
# more specialized fields
#########################


class Axes(String):
def _deserialize(self, *args, **kwargs) -> str:
axes_str = super()._deserialize(*args, **kwargs)
@@ -436,6 +426,50 @@ def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[str]:
return super()._serialize(value, attr, obj, **kwargs)


class BioImageIO_ID(String):
def __init__(
self,
*super_args,
bioimageio_description: typing.Union[
str, typing.Callable[[], str]
] = "ID as shown on resource card on bioimage.io",
resource_type: typing.Optional[str] = None,
validate: typing.Optional[
typing.Union[
typing.Callable[[typing.Any], typing.Any], typing.Iterable[typing.Callable[[typing.Any], typing.Any]]
]
] = None,
**super_kwargs,
):
from ._resolve_source import BIOIMAGEIO_COLLECTION_ENTRIES

if validate is None:
validate = []

if isinstance(validate, typing.Iterable):
validate = list(validate)
else:
validate = [validate]

if BIOIMAGEIO_COLLECTION_ENTRIES is not None:
error_msg = "'{input}' is not a valid BioImage.IO ID"
if resource_type is not None:
error_msg += f" of type {resource_type}"

validate.append(
field_validators.OneOf(
{
k
for k, (v_type, _) in BIOIMAGEIO_COLLECTION_ENTRIES.items()
if resource_type is None or resource_type == v_type
},
error=error_msg,
)
)

super().__init__(*super_args, bioimageio_description=bioimageio_description, validate=validate, **super_kwargs)


class ProcMode(String):
all_modes = ("fixed", "per_dataset", "per_sample")
explanations = {
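A short, hedged sketch of what the new `BioImageIO_ID` field adds; purely illustrative, and whether the `OneOf` check is attached depends on the online collection being reachable at import time:

from bioimageio.spec.shared.fields import BioImageIO_ID

# as used in the model schema above: parent id restricted to known model resources
id_field = BioImageIO_ID(resource_type="model")

# If BIOIMAGEIO_COLLECTION_ENTRIES was populated, the field carries a OneOf
# validator over all collection IDs of type "model" and rejects anything else
# with "'<input>' is not a valid BioImage.IO ID of type model";
# otherwise no ID check is added.
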
2 changes: 1 addition & 1 deletion example_specs/models/unet2d_nuclei_broad/rdf.yaml
@@ -1,5 +1,5 @@
# TODO physical scale of the data
format_version: 0.4.4
format_version: 0.4.5

name: UNet 2D Nuclei Broad
description: A 2d U-Net trained on the nuclei broad dataset.
8 changes: 4 additions & 4 deletions tests/conftest.py
@@ -11,7 +11,7 @@ def unet2d_nuclei_broad_base_path():


def get_unet2d_nuclei_broad(unet2d_nuclei_broad_base_path, request) -> dict:
if request.param == "v0_4_4":
if request.param == "v0_4_5":
v = ""
else:
v = f"_{request.param}"
@@ -21,7 +21,7 @@ def get_unet2d_nuclei_broad(unet2d_nuclei_broad_base_path, request) -> dict:
return yaml.load(path)


@pytest.fixture(params=["v0_3_0", "v0_3_1", "v0_3_2", "v0_3_3", "v0_3_6", "v0_4_0", "v0_4_4"])
@pytest.fixture(params=["v0_3_0", "v0_3_1", "v0_3_2", "v0_3_3", "v0_3_6", "v0_4_0", "v0_4_5"])
def unet2d_nuclei_broad_any(unet2d_nuclei_broad_base_path, request):
yield get_unet2d_nuclei_broad(unet2d_nuclei_broad_base_path, request)

@@ -31,12 +31,12 @@ def unet2d_nuclei_broad_before_latest(unet2d_nuclei_broad_base_path, request):
yield get_unet2d_nuclei_broad(unet2d_nuclei_broad_base_path, request)


@pytest.fixture(params=["v0_4_4"])
@pytest.fixture(params=["v0_4_5"])
def unet2d_nuclei_broad_latest(unet2d_nuclei_broad_base_path, request):
yield get_unet2d_nuclei_broad(unet2d_nuclei_broad_base_path, request)


@pytest.fixture(params=["v0_3_6", "v0_4_4"])
@pytest.fixture(params=["v0_3_6", "v0_4_5"])
def unet2d_nuclei_broad_any_minor(unet2d_nuclei_broad_base_path, request):
yield get_unet2d_nuclei_broad(unet2d_nuclei_broad_base_path, request)

18 changes: 18 additions & 0 deletions tests/test_schema_model.py
@@ -174,3 +174,21 @@ def test_output_ref_shape_too_small(model_dict):
assert e.value.messages == {
"_schema": ["Minimal shape [128. 256. 9.] of output output_1 is too small for halo [256, 128, 0]."]
}


def test_model_has_parent_with_uri(model_dict):
from bioimageio.spec.model.schema import Model

model_dict["parent"] = dict(uri="https://doi.org/10.5281/zenodo.5744489")

valid_data = Model().load(model_dict)
assert valid_data


def test_model_has_parent_with_id(model_dict):
from bioimageio.spec.model.schema import Model

model_dict["parent"] = dict(id="10.5281/zenodo.5764892")

valid_data = Model().load(model_dict)
assert valid_data
