diff --git a/plantseg/core/zoo.py b/plantseg/core/zoo.py index 69e3fa50..554c3588 100644 --- a/plantseg/core/zoo.py +++ b/plantseg/core/zoo.py @@ -457,23 +457,26 @@ def _is_plantseg_model(self, collection_entry: dict) -> bool: normalized_tags = ["".join(filter(str.isalnum, tag.lower())) for tag in tags] return 'plantseg' in normalized_tags - def get_bioimageio_zoo_plantseg_model_names(self) -> list[str]: - """Return a list of model names in the BioImage.IO Model Zoo tagged with 'plantseg'.""" + def get_bioimageio_zoo_all_model_names(self) -> list[tuple[str, str]]: + """Return a list of (model id, model display name) in the BioImage.IO Model Zoo.""" if not hasattr(self, 'models_bioimageio'): self.refresh_bioimageio_zoo_urls() - return sorted(model_zoo.models_bioimageio[model_zoo.models_bioimageio["supported"]].index.to_list()) + id_name = self.models_bioimageio[['name_display']] + return sorted([(name, id) for id, name in id_name.itertuples()]) - def get_bioimageio_zoo_all_model_names(self) -> list[str]: - """Return a list of all model names in the BioImage.IO Model Zoo.""" + def get_bioimageio_zoo_plantseg_model_names(self) -> list[tuple[str, str]]: + """Return a list of (model id, model display name) in the BioImage.IO Model Zoo tagged with 'plantseg'.""" if not hasattr(self, 'models_bioimageio'): self.refresh_bioimageio_zoo_urls() - return sorted(model_zoo.models_bioimageio.index.to_list()) + id_name = self.models_bioimageio[self.models_bioimageio["supported"]][['name_display']] + return sorted([(name, id) for id, name in id_name.itertuples()]) - def get_bioimageio_zoo_other_model_names(self) -> list[str]: - """Return a list of model names in the BioImage.IO Model Zoo not tagged with 'plantseg'.""" - return sorted( - list(set(self.get_bioimageio_zoo_all_model_names()) - set(self.get_bioimageio_zoo_plantseg_model_names())) - ) + def get_bioimageio_zoo_other_model_names(self) -> list[tuple[str, str]]: + """Return a list of (model id, model display name) in the BioImage.IO Model Zoo not tagged with 'plantseg'.""" + if not hasattr(self, 'models_bioimageio'): + self.refresh_bioimageio_zoo_urls() + id_name = self.models_bioimageio[~self.models_bioimageio["supported"]][['name_display']] + return sorted([(name, id) for id, name in id_name.itertuples()]) def _flatten_module(self, module: Module) -> list[Module]: """Recursively flatten a PyTorch nn.Module into a list of its elemental layers.""" diff --git a/plantseg/viewer_napari/widgets/prediction.py b/plantseg/viewer_napari/widgets/prediction.py index 28e5164f..311dd5fa 100644 --- a/plantseg/viewer_napari/widgets/prediction.py +++ b/plantseg/viewer_napari/widgets/prediction.py @@ -106,6 +106,7 @@ def to_choices(cls): 'label': 'BioImage.IO model', 'tooltip': 'Select a model from BioImage.IO model zoo.', 'choices': model_zoo.get_bioimageio_zoo_plantseg_model_names(), + 'value': model_zoo.get_bioimageio_zoo_plantseg_model_names()[0][1], }, advanced={ 'label': 'Show advanced parameters', @@ -131,7 +132,7 @@ def widget_unet_prediction( mode: UNetPredictionMode = UNetPredictionMode.PLANTSEG, plantseg_filter: bool = True, model_name: Optional[str] = None, - model_id: Optional[str] = None, + model_id: Optional[str] = model_zoo.get_bioimageio_zoo_plantseg_model_names()[0][1], device: str = ALL_DEVICES[0], advanced: bool = False, patch_size: tuple[int, int, int] = (128, 128, 128), @@ -139,14 +140,15 @@ def widget_unet_prediction( single_patch: bool = False, ) -> None: ps_image = PlantSegImage.from_napari_layer(image) - widgets_to_update = [ - widget_dt_ws.image, - widget_agglomeration.image, - widget_split_and_merge_from_scribbles.image, - ] + if mode is UNetPredictionMode.PLANTSEG: suffix = model_name model_id = None + widgets_to_update = [ + widget_dt_ws.image, + widget_agglomeration.image, + widget_split_and_merge_from_scribbles.image, + ] return schedule_task( unet_prediction_task, task_kwargs={ @@ -164,6 +166,10 @@ def widget_unet_prediction( elif mode is UNetPredictionMode.BIOIMAGEIO: suffix = model_id model_name = None + widgets_to_update = [ + # BioImage.IO models may output multi-channel 3D image or even multi-channel scalar in CZYX format. + # So PlantSeg widgets, which all take ZYX or YX, are better not to be updated. + ] return schedule_task( biio_prediction_task, task_kwargs={ @@ -210,17 +216,11 @@ def update_halo(): widget_unet_prediction.patch_size[0].enabled = True widget_unet_prediction.patch_halo[0].enabled = True elif widget_unet_prediction.mode.value is UNetPredictionMode.BIOIMAGEIO: - widget_unet_prediction.patch_halo.value = model_zoo.compute_3D_halo_for_bioimageio_models( - widget_unet_prediction.model_id.value + log( + 'Automatic halo not implemented for BioImage.IO models yet because they are handled by BioImage.IO Core.', + thread='BioImage.IO Core prediction', + level='info', ) - if model_zoo.is_2D_bioimageio_model(widget_unet_prediction.model_id.value): - widget_unet_prediction.patch_size[0].value = 0 - widget_unet_prediction.patch_size[0].enabled = False - widget_unet_prediction.patch_halo[0].enabled = False - else: - widget_unet_prediction.patch_size[0].value = widget_unet_prediction.patch_size[1].value - widget_unet_prediction.patch_size[0].enabled = True - widget_unet_prediction.patch_halo[0].enabled = True else: raise NotImplementedError(f'Automatic halo not implemented for {widget_unet_prediction.mode.value} mode.') @@ -270,7 +270,7 @@ def _on_widget_unet_prediction_plantseg_filter_change(plantseg_filter: bool): else: widget_unet_prediction.model_id.choices = ( model_zoo.get_bioimageio_zoo_plantseg_model_names() - + [Separator] + + [('', Separator)] # `[('', Separator)]` for list[tuple[str, str]], [Separator] for list[str] + model_zoo.get_bioimageio_zoo_other_model_names() )