diff --git a/README.md b/README.md
index b3cbe209a..919bd82a1 100644
--- a/README.md
+++ b/README.md
@@ -156,6 +156,10 @@ bioimageio update-format
 
 ### RDF Format Versions
 
+#### model RDF 0.4.5
+- Breaking changes that are fully auto-convertible
+  - `parent` field now holds either the `id` of another BioImage.IO model or a `uri` (URL or local relative path) of a model RDF, not both; the old `sha256` of the parent weights is dropped during auto-conversion
+
 #### model RDF 0.4.4
 - Non-breaking changes
   - new optional field `training_data`
diff --git a/bioimageio/spec/model/v0_4/converters.py b/bioimageio/spec/model/v0_4/converters.py
index 0ac85dfb7..afa373ecb 100644
--- a/bioimageio/spec/model/v0_4/converters.py
+++ b/bioimageio/spec/model/v0_4/converters.py
@@ -38,7 +38,7 @@ def convert_model_from_v0_3_to_0_4_0(data: Dict[str, Any]) -> Dict[str, Any]:
     return data
 
 
-def convert_model_from_v0_4_0_to_0_4_4(data: Dict[str, Any]) -> Dict[str, Any]:
+def convert_model_from_v0_4_0_to_0_4_1(data: Dict[str, Any]) -> Dict[str, Any]:
     data = dict(data)
 
     # move dependencies from root to pytorch_state_dict weights entry
@@ -49,7 +49,18 @@
         if entry and isinstance(entry, dict):
             entry["dependencies"] = deps
 
-    data["format_version"] = "0.4.4"
+    data["format_version"] = "0.4.1"
+    return data
+
+
+def convert_model_from_v0_4_4_to_0_4_5(data: Dict[str, Any]) -> Dict[str, Any]:
+    data = dict(data)
+
+    parent = data.pop("parent", None)
+    if parent and "uri" in parent:
+        data["parent"] = dict(uri=parent["uri"])
+
+    data["format_version"] = "0.4.5"
     return data
 
 
@@ -60,7 +71,13 @@ def maybe_convert(data: Dict[str, Any]) -> Dict[str, Any]:
         data = convert_model_from_v0_3_to_0_4_0(data)
 
     if data["format_version"] == "0.4.0":
-        data = convert_model_from_v0_4_0_to_0_4_4(data)
+        data = convert_model_from_v0_4_0_to_0_4_1(data)
+
+    if data["format_version"] in ("0.4.1", "0.4.2", "0.4.3"):
+        data["format_version"] = "0.4.4"
+
+    if data["format_version"] == "0.4.4":
+        data = convert_model_from_v0_4_4_to_0_4_5(data)
 
     # remove 'future' from config if no other than the used future entries exist
     config = data.get("config", {})
diff --git a/bioimageio/spec/model/v0_4/raw_nodes.py b/bioimageio/spec/model/v0_4/raw_nodes.py
index 06ae14887..a5634a802 100644
--- a/bioimageio/spec/model/v0_4/raw_nodes.py
+++ b/bioimageio/spec/model/v0_4/raw_nodes.py
@@ -11,7 +11,6 @@
 from bioimageio.spec.model.v0_3.raw_nodes import (
     InputTensor,
     KerasHdf5WeightsEntry as KerasHdf5WeightsEntry03,
-    ModelParent,
     OnnxWeightsEntry as OnnxWeightsEntry03,
     OutputTensor,
     Postprocessing,
@@ -30,6 +29,7 @@
     ImportableModule,
     ImportableSourceFile,
     ParametrizedInputShape,
+    RawNode,
     URI,
 )
 
@@ -48,7 +48,7 @@ PreprocessingName = PreprocessingName
 
 FormatVersion = Literal[
-    "0.4.0", "0.4.1", "0.4.2", "0.4.3", "0.4.4"
+    "0.4.0", "0.4.1", "0.4.2", "0.4.3", "0.4.4", "0.4.5"
 ]  # newest format needs to be last (used in __init__.py)
 WeightsFormat = Literal[
     "pytorch_state_dict", "torchscript", "keras_hdf5", "tensorflow_js", "tensorflow_saved_model_bundle", "onnx"
 ]
@@ -108,10 +108,17 @@ class TorchscriptWeightsEntry(_WeightsEntryBase):
 
 
 @dataclass
-class LinkedDataset:
+class LinkedDataset(RawNode):
     id: str
 
 
+@dataclass
+class ModelParent(RawNode):
+    id: Union[_Missing, str] = missing
+    uri: Union[_Missing, URI, Path] = missing
+    sha256: Union[_Missing, str] = missing
+
+
 @dataclass
 class Model(_RDF):
     _include_in_package = ("covers", "documentation", "test_inputs", "test_outputs", "sample_inputs", "sample_outputs")
diff --git a/bioimageio/spec/model/v0_4/schema.py b/bioimageio/spec/model/v0_4/schema.py
index e4c23c661..9c09a9aab 100644
--- a/bioimageio/spec/model/v0_4/schema.py
+++ b/bioimageio/spec/model/v0_4/schema.py
@@ -7,7 +7,6 @@
 from bioimageio.spec.dataset.v0_2.schema import Dataset as _Dataset
 from bioimageio.spec.model.v0_3.schema import (
     KerasHdf5WeightsEntry as KerasHdf5WeightsEntry03,
-    ModelParent,
     OnnxWeightsEntry as OnnxWeightsEntry03,
     Postprocessing as Postprocessing03,
     Preprocessing as Preprocessing03,
@@ -281,6 +280,19 @@ class LinkedDataset(_BioImageIOSchema):
     id = fields.String(bioimageio_description="dataset id")
 
 
+class ModelParent(_BioImageIOSchema):
+    id = fields.BioImageIO_ID(resource_type="model")
+    uri = fields.Union(
+        [fields.URI(), fields.RelativeLocalPath()], bioimageio_description="URL or local relative path of a model RDF"
+    )
+    sha256 = fields.SHA256(bioimageio_description="Hash of the parent model RDF. Note: the hash is not validated")
+
+    @validates_schema()
+    def id_xor_uri(self, data, **kwargs):
+        if ("id" in data) == ("uri" in data):
+            raise ValidationError("Either 'id' or 'uri' are required (not both).")
+
+
 class Model(rdf.schema.RDF):
     raw_nodes = raw_nodes
 
@@ -481,9 +493,7 @@ def get_min_shape(t) -> numpy.ndarray:
 
     parent = fields.Nested(
         ModelParent(),
-        bioimageio_description="Parent model from which the trained weights of this model have been derived, e.g. by "
-        "finetuning the weights of this model on a different dataset. For format changes of the same trained model "
-        "checkpoint, see `weights`.",
+        bioimageio_description="The model from which this model is derived, e.g. by fine-tuning the weights.",
     )
 
     run_mode = fields.Nested(
diff --git a/bioimageio/spec/shared/_resolve_source.py b/bioimageio/spec/shared/_resolve_source.py
index e785885d2..925570934 100644
--- a/bioimageio/spec/shared/_resolve_source.py
+++ b/bioimageio/spec/shared/_resolve_source.py
@@ -67,14 +67,7 @@ def resolve_rdf_source(
 
     if isinstance(source, str):
         # source might be bioimageio id, doi, url or file path -> resolve to pathlib.Path
-        if BIOIMAGEIO_COLLECTION is None:
-            bioimageio_rdf_source = None
-        else:
-            bioimageio_collection = {
-                c.get("id", f"missind_id_{i}"): c.get("rdf_source")
-                for i, c in enumerate(BIOIMAGEIO_COLLECTION.get("collection", []))
-            }
-            bioimageio_rdf_source = bioimageio_collection.get(source) or bioimageio_collection.get(source + "/latest")
+        bioimageio_rdf_source: typing.Optional[str] = (BIOIMAGEIO_COLLECTION_ENTRIES or {}).get(source, (None, None))[1]
 
         if bioimageio_rdf_source is not None:
             # source is bioimageio id
@@ -412,3 +405,25 @@ def _resolve_json_from_url(
 
 BIOIMAGEIO_SITE_CONFIG, BIOIMAGEIO_SITE_CONFIG_ERROR = _resolve_json_from_url(BIOIMAGEIO_SITE_CONFIG_URL)
 BIOIMAGEIO_COLLECTION, BIOIMAGEIO_COLLECTION_ERROR = _resolve_json_from_url(BIOIMAGEIO_COLLECTION_URL)
+if BIOIMAGEIO_COLLECTION is None:
+    BIOIMAGEIO_COLLECTION_ENTRIES: typing.Optional[typing.Dict[str, typing.Tuple[str, str]]] = None
+else:
+    BIOIMAGEIO_COLLECTION_ENTRIES = {
+        cr["id"]: (cr["type"], cr["rdf_source"])
+        for cr in BIOIMAGEIO_COLLECTION.get("collection", [])
+        if "id" in cr and "rdf_source" in cr and "type" in cr
+    }
+    # add resource versions explicitly
+    BIOIMAGEIO_COLLECTION_ENTRIES.update(
+        {
+            f"{cr['id']}/{cv}": (
+                cr["type"],
+                cr["rdf_source"].replace(
+                    f"/{cr['versions'][0]}", f"/{cv}"
+                ),  # todo: improve this replace-version-monkeypatch
+            )
+            for cr in BIOIMAGEIO_COLLECTION.get("collection", [])
+            for cv in cr.get("versions", [])
+            if "id" in cr and "rdf_source" in cr and "type" in cr
+        }
+    )
diff --git a/bioimageio/spec/shared/fields.py b/bioimageio/spec/shared/fields.py
index 4c1be48c1..346cbca63 100644
--- a/bioimageio/spec/shared/fields.py
+++ b/bioimageio/spec/shared/fields.py
@@ -48,11 +48,6 @@ def __init__(
         super().__init__(*super_args, **super_kwargs)  # type: ignore
 
 
-#################################################
-# fields directly derived from marshmallow fields
-#################################################
-
-
 class Array(DocumentedField, marshmallow_fields.Field):
     def __init__(self, inner: marshmallow_fields.Field, **kwargs):
         self.inner = inner
@@ -266,11 +261,6 @@ def _deserialize(self, value, attr=None, data=None, **kwargs):
             raise ValidationError(message=messages, field_name=attr) from e
 
 
-#########################
-# more specialized fields
-#########################
-
-
 class Axes(String):
     def _deserialize(self, *args, **kwargs) -> str:
         axes_str = super()._deserialize(*args, **kwargs)
@@ -436,6 +426,50 @@ def _serialize(self, value, attr, obj, **kwargs) -> typing.Optional[str]:
         return super()._serialize(value, attr, obj, **kwargs)
 
 
+class BioImageIO_ID(String):
+    def __init__(
+        self,
+        *super_args,
+        bioimageio_description: typing.Union[
+            str, typing.Callable[[], str]
+        ] = "ID as shown on resource card on bioimage.io",
+        resource_type: typing.Optional[str] = None,
+        validate: typing.Optional[
+            typing.Union[
+                typing.Callable[[typing.Any], typing.Any], typing.Iterable[typing.Callable[[typing.Any], typing.Any]]
+            ]
+        ] = None,
+        **super_kwargs,
+    ):
+        from ._resolve_source import BIOIMAGEIO_COLLECTION_ENTRIES
+
+        if validate is None:
+            validate = []
+
+        if isinstance(validate, typing.Iterable):
+            validate = list(validate)
+        else:
+            validate = [validate]
+
+        if BIOIMAGEIO_COLLECTION_ENTRIES is not None:
+            error_msg = "'{input}' is not a valid BioImage.IO ID"
+            if resource_type is not None:
+                error_msg += f" of type {resource_type}"
+
+            validate.append(
+                field_validators.OneOf(
+                    {
+                        k
+                        for k, (v_type, _) in BIOIMAGEIO_COLLECTION_ENTRIES.items()
+                        if resource_type is None or resource_type == v_type
+                    },
+                    error=error_msg,
+                )
+            )
+
+        super().__init__(*super_args, bioimageio_description=bioimageio_description, validate=validate, **super_kwargs)
+
+
 class ProcMode(String):
     all_modes = ("fixed", "per_dataset", "per_sample")
     explanations = {
diff --git a/example_specs/models/unet2d_nuclei_broad/rdf.yaml b/example_specs/models/unet2d_nuclei_broad/rdf.yaml
index f25a0bfaa..c2e0c50f0 100644
--- a/example_specs/models/unet2d_nuclei_broad/rdf.yaml
+++ b/example_specs/models/unet2d_nuclei_broad/rdf.yaml
@@ -1,5 +1,5 @@
 # TODO physical scale of the data
-format_version: 0.4.4
+format_version: 0.4.5
 name: UNet 2D Nuclei Broad
 description: A 2d U-Net trained on the nuclei broad dataset.
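
For orientation only (not part of the diff): a minimal sketch of how the extended `maybe_convert` chain in `bioimageio/spec/model/v0_4/converters.py` is expected to walk an old 0.4.0 RDF up to 0.4.5. The module path and function names come from the diff above; the RDF fragment and its URL are made up for illustration.

```python
# Sketch only, assuming the converters.py changes above; the RDF fragment below is hypothetical.
from bioimageio.spec.model.v0_4.converters import maybe_convert

old_rdf = {
    "format_version": "0.4.0",
    # old-style parent: uri plus sha256 of the parent weights
    "parent": {"uri": "https://example.com/parent_rdf.yaml", "sha256": "d0e1..."},
}

new_rdf = maybe_convert(dict(old_rdf))
# expected walk: 0.4.0 -> 0.4.1 (dependencies moved; none here) -> 0.4.4 (version bump) -> 0.4.5
print(new_rdf["format_version"])  # expected: "0.4.5"
print(new_rdf["parent"])          # expected: only the parent uri is kept; the old weights sha256 is dropped
```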
diff --git a/tests/conftest.py b/tests/conftest.py
index 7f643c785..177a3dfc8 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -11,7 +11,7 @@ def unet2d_nuclei_broad_base_path():
 
 
 def get_unet2d_nuclei_broad(unet2d_nuclei_broad_base_path, request) -> dict:
-    if request.param == "v0_4_4":
+    if request.param == "v0_4_5":
         v = ""
     else:
         v = f"_{request.param}"
@@ -21,7 +21,7 @@ def get_unet2d_nuclei_broad(unet2d_nuclei_broad_base_path, request) -> dict:
     return yaml.load(path)
 
 
-@pytest.fixture(params=["v0_3_0", "v0_3_1", "v0_3_2", "v0_3_3", "v0_3_6", "v0_4_0", "v0_4_4"])
+@pytest.fixture(params=["v0_3_0", "v0_3_1", "v0_3_2", "v0_3_3", "v0_3_6", "v0_4_0", "v0_4_5"])
 def unet2d_nuclei_broad_any(unet2d_nuclei_broad_base_path, request):
     yield get_unet2d_nuclei_broad(unet2d_nuclei_broad_base_path, request)
 
@@ -31,12 +31,12 @@ def unet2d_nuclei_broad_before_latest(unet2d_nuclei_broad_base_path, request):
     yield get_unet2d_nuclei_broad(unet2d_nuclei_broad_base_path, request)
 
 
-@pytest.fixture(params=["v0_4_4"])
+@pytest.fixture(params=["v0_4_5"])
 def unet2d_nuclei_broad_latest(unet2d_nuclei_broad_base_path, request):
     yield get_unet2d_nuclei_broad(unet2d_nuclei_broad_base_path, request)
 
 
-@pytest.fixture(params=["v0_3_6", "v0_4_4"])
+@pytest.fixture(params=["v0_3_6", "v0_4_5"])
 def unet2d_nuclei_broad_any_minor(unet2d_nuclei_broad_base_path, request):
     yield get_unet2d_nuclei_broad(unet2d_nuclei_broad_base_path, request)
diff --git a/tests/test_schema_model.py b/tests/test_schema_model.py
index 92a68b2e4..da76958e9 100644
--- a/tests/test_schema_model.py
+++ b/tests/test_schema_model.py
@@ -174,3 +174,21 @@ def test_output_ref_shape_too_small(model_dict):
     assert e.value.messages == {
         "_schema": ["Minimal shape [128. 256. 9.] of output output_1 is too small for halo [256, 128, 0]."]
     }
+
+
+def test_model_has_parent_with_uri(model_dict):
+    from bioimageio.spec.model.schema import Model
+
+    model_dict["parent"] = dict(uri="https://doi.org/10.5281/zenodo.5744489")
+
+    valid_data = Model().load(model_dict)
+    assert valid_data
+
+
+def test_model_has_parent_with_id(model_dict):
+    from bioimageio.spec.model.schema import Model
+
+    model_dict["parent"] = dict(id="10.5281/zenodo.5764892")
+
+    valid_data = Model().load(model_dict)
+    assert valid_data
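
The two new tests only cover the valid cases (`parent` given by `uri` or by `id`). A complementary sketch (not part of the diff) of the `id_xor_uri` rule from the user side, assuming an otherwise valid `model_dict` such as the `unet2d_nuclei_broad` example and that `Model().load` raises marshmallow's `ValidationError` on invalid input, as in the existing tests; the helper name `check_parent` is made up.

```python
# Sketch only: probes the id_xor_uri check of the new ModelParent schema.
from marshmallow import ValidationError

from bioimageio.spec.model.schema import Model


def check_parent(model_dict: dict, parent: dict) -> None:
    """try to validate model_dict with the given parent and report the outcome"""
    try:
        Model().load(dict(model_dict, parent=parent))
        print("accepted parent:", parent)
    except ValidationError as e:
        print("rejected parent:", parent, "->", e.messages)


# exactly one of 'id' and 'uri' is required, e.g.:
# check_parent(model_dict, {"uri": "https://doi.org/10.5281/zenodo.5744489"})  # accepted
# check_parent(model_dict, {"id": "10.5281/zenodo.5764892"})                   # accepted
# check_parent(model_dict, {})                                                 # rejected (neither given)
# check_parent(model_dict, {"id": "10.5281/zenodo.5764892",
#                           "uri": "https://doi.org/10.5281/zenodo.5744489"})  # rejected (both given)
```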