diff --git a/environment-dev.yaml b/environment-dev.yaml new file mode 100644 index 00000000..c3d45c33 --- /dev/null +++ b/environment-dev.yaml @@ -0,0 +1,22 @@ +name: plant-seg-dev +channels: + - pytorch + - nvidia + - conda-forge +dependencies: + - python + - h5py + - zarr + - requests + - pyyaml + - scikit-image + - tifffile + - vigra + - cudnn + - pytorch + - pytorch-cuda=12.1 + - python-elf + - pyqt + - napari + - python-graphviz +# then `pip install -e .` diff --git a/examples/config_advanced.yaml b/examples/config_advanced.yaml index 4b684d2f..1356d040 100644 --- a/examples/config_advanced.yaml +++ b/examples/config_advanced.yaml @@ -36,7 +36,7 @@ cnn_prediction: # channel to use if input image has shape CZYX or CYX, otherwise set to null; null is recommended if the previous steps has state True channel: Null # Trained model name, more info on available models and custom models in the README - model_name: 'RIKEN20230824v160' + model_name: 'generic_plant_nuclei_3D' # If a CUDA capable gpu is available and corrected setup use "cuda", if not you can use "cpu" for cpu only inference (slower) device: 'cuda' # how many subprocesses to use for data loading diff --git a/plantseg/dataprocessing/functional/dataprocessing.py b/plantseg/dataprocessing/functional/dataprocessing.py index ce2b5ba5..a08b4d8f 100644 --- a/plantseg/dataprocessing/functional/dataprocessing.py +++ b/plantseg/dataprocessing/functional/dataprocessing.py @@ -77,7 +77,15 @@ def image_crop(image: np.array, crop_str: str) -> np.array: return image[slices] -def fix_input_shape(data: np.array) -> np.array: +def fix_input_shape(data: np.array, ndim=3) -> np.array: + assert ndim in [3, 4] + if ndim == 3: + return fix_input_shape_to_3D(data) + else: + return fix_input_shape_to_4D(data) + + +def fix_input_shape_to_3D(data: np.array) -> np.array: """ fix array ndim to be always 3 """ @@ -94,6 +102,25 @@ def fix_input_shape(data: np.array) -> np.array: raise RuntimeError(f"Expected input data to be 2d, 3d or 4d, but got {data.ndim}d input") +def fix_input_shape_to_4D(data: np.array) -> np.array: + """ + Fix array ndim to be 4 and return it in (C x Z x Y x X) e.g. 2 x 1 x 512 x 512 + + Only used for multi-channel network output, e.g. 2-channel 3D pmaps. + """ + if data.ndim == 4: + return data + + elif data.ndim == 3: + return data.reshape(data.shape[0], 1, data.shape[1], data.shape[2]) + + # elif data.ndim == 2: + # return data.reshape(1, 1, data.shape[0], data.shape[1]) + + else: + raise RuntimeError(f"Expected input data to be 3d or 4d, but got {data.ndim}d input") + + def normalize_01(data: np.array) -> np.array: """ normalize a numpy array between 0 and 1 diff --git a/plantseg/legacy_gui/gui_widgets.py b/plantseg/legacy_gui/gui_widgets.py index b59e22da..bfe4d40a 100644 --- a/plantseg/legacy_gui/gui_widgets.py +++ b/plantseg/legacy_gui/gui_widgets.py @@ -9,8 +9,8 @@ class ModuleFramePrototype: """ Prototype for the main keys field. - Every process is in the pipeline is represented by a single instance of it. - """ + Every process is in the pipeline is represented by a single instance of it. + """ def __init__(self, frame, module_name="processing", font=None): self.frame = frame @@ -24,13 +24,19 @@ def __init__(self, frame, module_name="processing", font=None): self.place_module(module_name=module_name) def place_module(self, module_name): - self.checkbox = tkinter.Checkbutton(self.frame, bg=convert_rgb((208, 240, 192)), - text=module_name, font=self.font) - self.checkbox.grid(column=0, - row=0, - padx=self.style["padx"], - pady=self.style["pady"], - sticky=stick_all) + self.checkbox = tkinter.Checkbutton( + self.frame, + bg=convert_rgb((208, 240, 192)), + text=module_name, + font=self.font, + ) + self.checkbox.grid( + column=0, + row=0, + padx=self.style["padx"], + pady=self.style["pady"], + sticky=stick_all, + ) def _show_options(self, config, module): if self.show.get(): @@ -103,25 +109,26 @@ def show_options(self): class PreprocessingFrame(ModuleFramePrototype): def __init__(self, frame, config, col=0, module_name="preprocessing", font=None, show_all=True): self.preprocessing_frame = tkinter.Frame(frame) - self.preprocessing_style = {"bg": "white", - "padx": 10, - "pady": 10, - "row_weights": [2, 1, 1, 1, 1], - "columns_weights": [1], - "height": 4, - } + self.preprocessing_style = { + "bg": "white", + "padx": 10, + "pady": 10, + "row_weights": [2, 1, 1, 1, 1], + "columns_weights": [1], + "height": 4, + } self.preprocessing_frame["bg"] = self.preprocessing_style["bg"] - self.preprocessing_frame.grid(column=col, - row=0, - padx=self.preprocessing_style["padx"], - pady=self.preprocessing_style["pady"], - sticky=stick_new) + self.preprocessing_frame.grid( + column=col, + row=0, + padx=self.preprocessing_style["padx"], + pady=self.preprocessing_style["pady"], + sticky=stick_new, + ) - [tkinter.Grid.rowconfigure(self.preprocessing_frame, i, weight=w) - for i, w in enumerate(self.preprocessing_style["row_weights"])] - [tkinter.Grid.columnconfigure(self.preprocessing_frame, i, weight=w) - for i, w in enumerate(self.preprocessing_style["columns_weights"])] + [tkinter.Grid.rowconfigure(self.preprocessing_frame, i, weight=w) for i, w in enumerate(self.preprocessing_style["row_weights"])] + [tkinter.Grid.columnconfigure(self.preprocessing_frame, i, weight=w) for i, w in enumerate(self.preprocessing_style["columns_weights"])] super().__init__(self.preprocessing_frame, module_name, font=font) self.module = "preprocessing" @@ -137,36 +144,63 @@ def __init__(self, frame, config, col=0, module_name="preprocessing", font=None, self.checkbox["command"] = self.show_options self.obj_collection = [] - self.custom_key = {"save_directory": SimpleEntry(self.preprocessing_frame, - text="Save Directory: ", - row=1, - column=0, - _type=str, - _font=font), - "factor": RescaleEntry(self.preprocessing_frame, - text="Rescaling (z,x,y):", - row=2, - column=0, - font=font), - "order": MenuEntry(self.preprocessing_frame, - text="Interpolation: ", - row=3, - column=0, - menu=[0, 1, 2], - default=2, - font=font), - "filter": FilterEntry(self.preprocessing_frame, - text="Filter (Optional): ", - row=4, - column=0, - font=font), - "crop_volume": SimpleEntry(self.preprocessing_frame, - text="Crop Volume: ", - row=5, - column=0, - _type=str, - _font=font), - } + self.custom_key = { + "key": SimpleEntry( + self.preprocessing_frame, + text="Key (HDF5 only): ", + row=1, + column=0, + _type=str, + _font=font, + ), + "channel": SimpleEntry( + self.preprocessing_frame, + text="Channel (HDF5 only): ", + row=2, + column=0, + _type=int, + _font=font, + ), + "save_directory": SimpleEntry( + self.preprocessing_frame, + text="Save Directory: ", + row=3, + column=0, + _type=str, + _font=font, + ), + "factor": RescaleEntry( + self.preprocessing_frame, + text="Rescaling (z,x,y):", + row=4, + column=0, + font=font, + ), + "order": MenuEntry( + self.preprocessing_frame, + text="Interpolation: ", + row=5, + column=0, + menu=[0, 1, 2], + default=2, + font=font, + ), + "filter": FilterEntry( + self.preprocessing_frame, + text="Filter (Optional): ", + row=6, + column=0, + font=font, + ), + "crop_volume": SimpleEntry( + self.preprocessing_frame, + text="Crop Volume: ", + row=7, + column=0, + _type=str, + _font=font, + ), + } self.show_options() @@ -174,25 +208,26 @@ def __init__(self, frame, config, col=0, module_name="preprocessing", font=None, class UnetPredictionFrame(ModuleFramePrototype): def __init__(self, frame, config, col=0, module_name="preprocessing", font=None, show_all=True): self.prediction_frame = tkinter.Frame(frame) - self.prediction_style = {"bg": "white", - "padx": 10, - "pady": 10, - "row_weights": [2, 1, 1, 1, 1], - "columns_weights": [1], - "height": 4, - } + self.prediction_style = { + "bg": "white", + "padx": 10, + "pady": 10, + "row_weights": [2, 1, 1, 1, 1], + "columns_weights": [1], + "height": 4, + } self.prediction_frame["bg"] = self.prediction_style["bg"] - self.prediction_frame.grid(column=col, - row=0, - padx=self.prediction_style["padx"], - pady=self.prediction_style["pady"], - sticky=stick_new) + self.prediction_frame.grid( + column=col, + row=0, + padx=self.prediction_style["padx"], + pady=self.prediction_style["pady"], + sticky=stick_new, + ) - [tkinter.Grid.rowconfigure(self.prediction_frame, i, weight=w) - for i, w in enumerate(self.prediction_style["row_weights"])] - [tkinter.Grid.columnconfigure(self.prediction_frame, i, weight=w) - for i, w in enumerate(self.prediction_style["columns_weights"])] + [tkinter.Grid.rowconfigure(self.prediction_frame, i, weight=w) for i, w in enumerate(self.prediction_style["row_weights"])] + [tkinter.Grid.columnconfigure(self.prediction_frame, i, weight=w) for i, w in enumerate(self.prediction_style["columns_weights"])] super().__init__(self.prediction_frame, module_name, font=font) self.module = "cnn_prediction" @@ -208,28 +243,51 @@ def __init__(self, frame, config, col=0, module_name="preprocessing", font=None, self.checkbox["command"] = self.show_options self.obj_collection = [] - self.custom_key = {"model_name": MenuEntry(self.prediction_frame, - text="Model Name: ", - row=1, - column=0, - menu=list_models(), - default=config[self.module]["model_name"], - is_model=True, - font=font), - "patch": ListEntry(self.prediction_frame, - text="Patch Size: ", - row=2, - column=0, - type=int, - font=font), - "device": MenuEntry(self.prediction_frame, - text="Device Type: ", - row=4, - column=0, - menu=["cuda", "cpu"], - default=config[self.module]["device"], - font=font), - } + self.custom_key = { + "key": SimpleEntry( + self.prediction_frame, + text="Key (HDF5 only): ", + row=1, + column=0, + _type=str, + _font=font, + ), + "channel": SimpleEntry( + self.prediction_frame, + text="Channel (HDF5 only): ", + row=2, + column=0, + _type=int, + _font=font, + ), + "model_name": MenuEntry( + self.prediction_frame, + text="Model Name: ", + row=3, + column=0, + menu=list_models(), + default=config[self.module]["model_name"], + is_model=True, + font=font, + ), + "patch": ListEntry( + self.prediction_frame, + text="Patch Size: ", + row=4, + column=0, + type=int, + font=font, + ), + "device": MenuEntry( + self.prediction_frame, + text="Device Type: ", + row=5, + column=0, + menu=["cuda", "cpu"], + default=config[self.module]["device"], + font=font, + ), + } self.show_options() @@ -237,25 +295,26 @@ def __init__(self, frame, config, col=0, module_name="preprocessing", font=None, class SegmentationFrame(ModuleFramePrototype): def __init__(self, frame, config, col=0, module_name="segmentation", font=None, show_all=True): self.segmentation_frame = tkinter.Frame(frame) - self.segmentation_style = {"bg": "white", - "padx": 10, - "pady": 10, - "row_weights": [2, 1, 1, 1, 1, 1, 1, 1, 1, 1], - "columns_weights": [1], - "height": 4, - } + self.segmentation_style = { + "bg": "white", + "padx": 10, + "pady": 10, + "row_weights": [2, 1, 1, 1, 1, 1, 1, 1, 1, 1], + "columns_weights": [1], + "height": 4, + } self.segmentation_frame["bg"] = self.segmentation_style["bg"] - self.segmentation_frame.grid(column=col, - row=0, - padx=self.segmentation_style["padx"], - pady=self.segmentation_style["pady"], - sticky=stick_new) + self.segmentation_frame.grid( + column=col, + row=0, + padx=self.segmentation_style["padx"], + pady=self.segmentation_style["pady"], + sticky=stick_new, + ) - [tkinter.Grid.rowconfigure(self.segmentation_frame, i, weight=w) - for i, w in enumerate(self.segmentation_style["row_weights"])] - [tkinter.Grid.columnconfigure(self.segmentation_frame, i, weight=w) - for i, w in enumerate(self.segmentation_style["columns_weights"])] + [tkinter.Grid.rowconfigure(self.segmentation_frame, i, weight=w) for i, w in enumerate(self.segmentation_style["row_weights"])] + [tkinter.Grid.columnconfigure(self.segmentation_frame, i, weight=w) for i, w in enumerate(self.segmentation_style["columns_weights"])] super().__init__(self.segmentation_frame, module_name, font) self.module = "segmentation" @@ -271,70 +330,136 @@ def __init__(self, frame, config, col=0, module_name="segmentation", font=None, self.checkbox["command"] = self.show_options self.obj_collection = [] - self.custom_key = {"name": MenuEntry(self.segmentation_frame, - text="Algorithm: ", - row=1, - column=0, - menu={"MultiCut", "GASP", "MutexWS", "DtWatershed", "SimpleITK"}, - is_segmentation=True, - default=config[self.module]["name"], - font=font), - "save_directory": SimpleEntry(self.segmentation_frame, - text="Save Directory: ", - row=2, - column=0, - _type=str, - _font=font), - "beta": SliderEntry(self.segmentation_frame, - text="Under-/Over-segmentation factor: ", - row=3, - column=0, - is_not_in_dtws=True, - _type=float, - _font=font), - "ws_2D": MenuEntry(self.segmentation_frame, - text="Run Watershed in 2D: ", - row=4, - column=0, - menu={"True", "False"}, - default=config[self.module]["ws_2D"], - font=font), - - "ws_threshold": SliderEntry(self.segmentation_frame, - text="CNN Predictions Threshold: ", - row=5, - column=0, - data_range=(0, 1, 0.001), - _type=float, - _font=font), - - "ws_sigma": SliderEntry(self.segmentation_frame, - text="Watershed Seeds Sigma: ", - row=6, - column=0, - data_range=(0, 5, 0.2), - _type=float, - _font=font), - "ws_w_sigma": SliderEntry(self.segmentation_frame, - text="Watershed Boundary Sigma: ", - row=7, - column=0, - data_range=(0, 5, 0.2), - _type=float, - _font=font), - "ws_minsize": SimpleEntry(self.segmentation_frame, - text="Superpixels Minimum Size (voxels): ", - row=8, - column=0, - _type=int, - _font=font), - "post_minsize": SimpleEntry(self.segmentation_frame, - text="Cell Minimum Size (voxels): ", - row=9, - column=0, - _type=int, - _font=font), - } + self.custom_key = { + "key": SimpleEntry( + self.segmentation_frame, + text="Key (HDF5 only): ", + row=1, + column=0, + _type=str, + _font=font, + ), + "channel": SimpleEntry( + self.segmentation_frame, + text="Channel (HDF5 only): ", + row=2, + column=0, + _type=int, + _font=font, + ), + "name": MenuEntry( + self.segmentation_frame, + text="Algorithm: ", + row=3, + column=0, + menu={"MultiCut", "LiftedMulticut", "GASP", "MutexWS", "DtWatershed", "SimpleITK"}, + is_segmentation=True, + default=config[self.module]["name"], + font=font, + ), + "save_directory": SimpleEntry( + self.segmentation_frame, + text="Save Directory: ", + row=4, + column=0, + _type=str, + _font=font, + ), + "beta": SliderEntry( + self.segmentation_frame, + text="Under-/Over-segmentation factor: ", + row=5, + column=0, + is_not_in_dtws=True, + _type=float, + _font=font, + ), + "ws_2D": MenuEntry( + self.segmentation_frame, + text="Run Watershed in 2D: ", + row=6, + column=0, + menu={"True", "False"}, + default=config[self.module]["ws_2D"], + font=font, + ), + "ws_threshold": SliderEntry( + self.segmentation_frame, + text="CNN Predictions Threshold: ", + row=7, + column=0, + data_range=(0, 1, 0.001), + _type=float, + _font=font, + ), + "ws_sigma": SliderEntry( + self.segmentation_frame, + text="Watershed Seeds Sigma: ", + row=8, + column=0, + data_range=(0, 5, 0.2), + _type=float, + _font=font, + ), + "ws_w_sigma": SliderEntry( + self.segmentation_frame, + text="Watershed Boundary Sigma: ", + row=9, + column=0, + data_range=(0, 5, 0.2), + _type=float, + _font=font, + ), + "ws_minsize": SimpleEntry( + self.segmentation_frame, + text="Superpixels Minimum Size (voxels): ", + row=10, + column=0, + _type=int, + _font=font, + ), + "post_minsize": SimpleEntry( + self.segmentation_frame, + text="Cell Minimum Size (voxels): ", + row=11, + column=0, + _type=int, + _font=font, + ), + "nuclei_predictions_path": SimpleEntry( + self.segmentation_frame, + text="Nuclei Path: ", + row=12, + column=0, + _type=str, + _font=font, + ), + "key_nuclei": SimpleEntry( + self.segmentation_frame, + text="Nuclei Key (HDF5 only): ", + row=13, + column=0, + _type=str, + _font=font, + ), + "channel_nuclei": SimpleEntry( + self.segmentation_frame, + text="Nuclei Channel (HDF5 only): ", + row=14, + column=0, + _type=int, + _font=font, + ), + "is_segmentation": MenuEntry( + self.segmentation_frame, + text="Nuclei are Labels (or Probability): ", + row=15, + column=0, + menu={"True", "False"}, + default=config[self.module]["is_segmentation"], + font=font, + ), + } self.show_options() @@ -342,26 +467,21 @@ def __init__(self, frame, config, col=0, module_name="segmentation", font=None, class PostSegmentationFrame(ModuleFramePrototype): def __init__(self, frame, config, row=0, module_name="Segmentation Post Processing", font=None, show_all=True): self.post_frame = tkinter.Frame(frame) - self.post_style = {"bg": "white", - "padx": 0, - "pady": 0, - "row_weights": [1, 1, 1, 1], - "columns_weights": [1], - "height": 4, - } + self.post_style = { + "bg": "white", + "padx": 0, + "pady": 0, + "row_weights": [1, 1, 1, 1], + "columns_weights": [1], + "height": 4, + } self.post_frame["bg"] = self.post_style["bg"] self.font = font - self.post_frame.grid(column=0, - row=row, - padx=self.post_style["padx"], - pady=self.post_style["pady"], - sticky=stick_all) + self.post_frame.grid(column=0, row=row, padx=self.post_style["padx"], pady=self.post_style["pady"], sticky=stick_all) - [tkinter.Grid.rowconfigure(self.post_frame, i, weight=w) - for i, w in enumerate(self.post_style["row_weights"])] - [tkinter.Grid.columnconfigure(self.post_frame, i, weight=w) - for i, w in enumerate(self.post_style["columns_weights"])] + [tkinter.Grid.rowconfigure(self.post_frame, i, weight=w) for i, w in enumerate(self.post_style["row_weights"])] + [tkinter.Grid.columnconfigure(self.post_frame, i, weight=w) for i, w in enumerate(self.post_style["columns_weights"])] super().__init__(self.post_frame, module_name, font=font) self.module = "segmentation_postprocessing" @@ -377,21 +497,42 @@ def __init__(self, frame, config, row=0, module_name="Segmentation Post Processi self.checkbox["command"] = self.show_options self.obj_collection = [] - self.custom_key = {"tiff": MenuEntry(self.post_frame, - text="Convert to tiff: ", - row=1, - column=0, - menu=["True", "False"], - default=self.config[self.module]["tiff"], - font=font), - "save_raw": MenuEntry(self.post_frame, - text="Save raw data: ", - row=4, - column=0, - menu=["True", "False"], - default=self.config[self.module].get("save_raw", "False"), - font=font), - } + self.custom_key = { + "key": SimpleEntry( + self.post_frame, + text="Key (HDF5 only): ", + row=1, + column=0, + _type=str, + _font=font, + ), + "channel": SimpleEntry( + self.post_frame, + text="Channel (HDF5 only): ", + row=2, + column=0, + _type=int, + _font=font, + ), + "tiff": MenuEntry( + self.post_frame, + text="Convert to tiff: ", + row=3, + column=0, + menu=["True", "False"], + default=self.config[self.module]["tiff"], + font=font, + ), + "save_raw": MenuEntry( + self.post_frame, + text="Save raw data: ", + row=4, + column=0, + menu=["True", "False"], + default=self.config[self.module].get("save_raw", "False"), + font=font, + ), + } self.show_options() @@ -399,27 +540,28 @@ def __init__(self, frame, config, row=0, module_name="Segmentation Post Processi class PostPredictionsFrame(ModuleFramePrototype): def __init__(self, frame, config, row=0, module_name="Prediction Post Processing", font=None, show_all=True): self.post_frame = tkinter.Frame(frame) - self.post_style = {"bg": "white", - "padx": 0, - "pady": 0, - "row_weights": [1, 1, 1, 1], - "columns_weights": [1], - "height": 4, - } + self.post_style = { + "bg": "white", + "padx": 0, + "pady": 0, + "row_weights": [1, 1, 1, 1], + "columns_weights": [1], + "height": 4, + } self.post_frame["bg"] = self.post_style["bg"] self.font = font - self.post_frame.grid(column=0, - row=row, - padx=self.post_style["padx"], - pady=self.post_style["pady"], - sticky=stick_new) + self.post_frame.grid( + column=0, + row=row, + padx=self.post_style["padx"], + pady=self.post_style["pady"], + sticky=stick_new, + ) - [tkinter.Grid.rowconfigure(self.post_frame, i, weight=w) - for i, w in enumerate(self.post_style["row_weights"])] - [tkinter.Grid.columnconfigure(self.post_frame, i, weight=w) - for i, w in enumerate(self.post_style["columns_weights"])] + [tkinter.Grid.rowconfigure(self.post_frame, i, weight=w) for i, w in enumerate(self.post_style["row_weights"])] + [tkinter.Grid.columnconfigure(self.post_frame, i, weight=w) for i, w in enumerate(self.post_style["columns_weights"])] super().__init__(self.post_frame, module_name, font=font) self.module = "cnn_postprocessing" @@ -435,28 +577,51 @@ def __init__(self, frame, config, row=0, module_name="Prediction Post Processing self.checkbox["command"] = self.show_options self.obj_collection = [] - self.custom_key = {"tiff": MenuEntry(self.post_frame, - text="Convert to tiff: ", - row=1, - column=0, - menu=["True", "False"], - default=self.config[self.module]["tiff"], - font=font), - "output_type": MenuEntry(self.post_frame, - text="Cast Predictions: ", - row=4, - column=0, - menu=["data_uint8", "data_float32"], - default=config[self.module]["output_type"], - font=font), - "save_raw": MenuEntry(self.post_frame, - text="Save raw data: ", - row=7, - column=0, - menu=["True", "False"], - default=self.config[self.module].get("save_raw", "False"), - font=font), - } + self.custom_key = { + "key": SimpleEntry( + self.post_frame, + text="Key (HDF5 only): ", + row=1, + column=0, + _type=str, + _font=font, + ), + "channel": SimpleEntry( + self.post_frame, + text="Channel (HDF5 only): ", + row=2, + column=0, + _type=int, + _font=font, + ), + "tiff": MenuEntry( + self.post_frame, + text="Convert to tiff: ", + row=3, + column=0, + menu=["True", "False"], + default=self.config[self.module]["tiff"], + font=font, + ), + "output_type": MenuEntry( + self.post_frame, + text="Cast Predictions: ", + row=4, + column=0, + menu=["data_uint8", "data_float32"], + default=config[self.module]["output_type"], + font=font, + ), + "save_raw": MenuEntry( + self.post_frame, + text="Save raw data: ", + row=5, + column=0, + menu=["True", "False"], + default=self.config[self.module].get("save_raw", "False"), + font=font, + ), + } self.show_options() @@ -464,26 +629,21 @@ def __init__(self, frame, config, row=0, module_name="Prediction Post Processing class PostFrame: def __init__(self, frame, config, col=0, font=None, show_all=True): self.post_frame = tkinter.Frame(frame) - self.post_style = {"bg": "white", - "padx": 10, - "pady": 10, - "row_weights": [1, 1], - "columns_weights": [1], - "height": 4, - } + self.post_style = { + "bg": "white", + "padx": 10, + "pady": 10, + "row_weights": [1, 1], + "columns_weights": [1], + "height": 4, + } self.post_frame["bg"] = self.post_style["bg"] self.font = font - self.post_frame.grid(column=col, - row=0, - padx=self.post_style["padx"], - pady=self.post_style["pady"], - sticky=stick_new) - - [tkinter.Grid.rowconfigure(self.post_frame, i, weight=w) - for i, w in enumerate(self.post_style["row_weights"])] - [tkinter.Grid.columnconfigure(self.post_frame, i, weight=w) - for i, w in enumerate(self.post_style["columns_weights"])] + self.post_frame.grid(column=col, row=0, padx=self.post_style["padx"], pady=self.post_style["pady"], sticky=stick_new) + + [tkinter.Grid.rowconfigure(self.post_frame, i, weight=w) for i, w in enumerate(self.post_style["row_weights"])] + [tkinter.Grid.columnconfigure(self.post_frame, i, weight=w) for i, w in enumerate(self.post_style["columns_weights"])] # init frames self.post_pred_obj = PostPredictionsFrame(self.post_frame, config, row=0, font=font, show_all=True) diff --git a/plantseg/legacy_gui/plantsegapp.py b/plantseg/legacy_gui/plantsegapp.py index b8692707..8dcd814b 100644 --- a/plantseg/legacy_gui/plantsegapp.py +++ b/plantseg/legacy_gui/plantsegapp.py @@ -38,7 +38,7 @@ def __init__(self): # Load app config self.app_config = self.load_app_config() - self.plant_config_path, self.plantseg_config = self.load_config() + self.plant_config_path, self.plantseg_config = self.load_core_config() # Init main app and configure self.plant_segapp = tkinter.Tk() @@ -312,8 +312,8 @@ def get_icon_path(name="FOR2581_Logo_FINAL_no_text.png"): icon_path = os.path.join(plantseg_global_path, RESOURCES_DIR, name) return icon_path - def load_config(self, name="config_gui_last.yaml"): - """Load the last (or if not possible a standard) config""" + def load_core_config(self, name="config_gui_last.yaml"): + """Load the last (or if not possible a standard) config. Used only once at startup.""" plant_config_path = self.get_last_config_path(name) if os.path.exists(plant_config_path): @@ -396,28 +396,28 @@ def config_row_column(frame, config): @staticmethod def open_documentation_index(): - """Open git page on the default browser""" - webbrowser.open("https://github.com/hci-unihd/plant-seg/wiki") + """Open documentation homepage on the default browser""" + webbrowser.open("https://hci-unihd.github.io/plant-seg/intro.html") @staticmethod def open_documentation_preprocessing(): - """Open git page on the default browser""" - webbrowser.open("https://github.com/hci-unihd/plant-seg/wiki/Classic-Data-Processing") + """Open Classic Data Processing documentation on the default browser""" + webbrowser.open("https://hci-unihd.github.io/plant-seg/chapters/plantseg_classic_gui/data_processing.html") @staticmethod def open_documentation_3dunet(): - """Open git page on the default browser""" - webbrowser.open("https://github.com/hci-unihd/plant-seg/wiki/CNN-Predictions") + """Open CNN Predictions documentation on the default browser""" + webbrowser.open("https://hci-unihd.github.io/plant-seg/chapters/plantseg_classic_gui/cnn_predictions.html") @staticmethod def open_documentation_segmentation(): - """Open git page on the default browser""" - webbrowser.open("https://github.com/hci-unihd/plant-seg/wiki/Segmentation") + """Open Segmentation documentation on the default browser""" + webbrowser.open("https://hci-unihd.github.io/plant-seg/chapters/plantseg_classic_gui/segmentation.html") @staticmethod def open_postprocessing(): - """Open git page on the default browser""" - webbrowser.open("https://github.com/hci-unihd/plant-seg/wiki/Classic-Data-Processing") + """Open Classic Data Processing documentation on the default browser""" + webbrowser.open("https://hci-unihd.github.io/plant-seg/chapters/plantseg_classic_gui/data_processing.html") def size_up(self): """ adjust font size in the main widget""" diff --git a/plantseg/pipeline/raw2seg.py b/plantseg/pipeline/raw2seg.py index b0b46a3d..6d0b9ee4 100644 --- a/plantseg/pipeline/raw2seg.py +++ b/plantseg/pipeline/raw2seg.py @@ -34,6 +34,7 @@ def configure_preprocessing_step(input_paths, config): def configure_cnn_step(input_paths, config): input_key = config.get('key', None) input_channel = config.get('channel', None) + model_name = config['model_name'] patch = config.get('patch', (80, 160, 160)) stride_ratio = config.get('stride_ratio', 0.75) @@ -55,6 +56,7 @@ def configure_segmentation_postprocessing_step(input_paths, config): def _create_postprocessing_step(input_paths, input_type, config): input_key = config.get('key', None) input_channel = config.get('channel', None) + output_type = config.get('output_type', input_type) save_directory = config.get('save_directory', 'PostProcessing') factor = config.get('factor', [1, 1, 1]) @@ -71,6 +73,7 @@ def _create_postprocessing_step(input_paths, input_type, config): def _validate_cnn_postprocessing_rescaling(input_paths, config): input_key = config["preprocessing"].get('key', None) input_channel = config["preprocessing"].get('channel', None) + input_shapes = [load_shape(input_path, key=input_key) for input_path in input_paths] if input_channel is not None: input_shapes = [input_shape[input_channel] for input_shape in input_shapes] @@ -92,10 +95,20 @@ def raw2seg(config): ('cnn_prediction', configure_cnn_step), ('cnn_postprocessing', configure_cnn_postprocessing_step), ('segmentation', configure_segmentation_step), - ('segmentation_postprocessing', configure_segmentation_postprocessing_step) + ('segmentation_postprocessing', configure_segmentation_postprocessing_step), ] - for pipeline_step_name, pipeline_step_setup in all_pipeline_steps: + for pipeline_step_name, pipeline_step_setup in all_pipeline_steps: # Common section for all steps + # In Tk GUI, entries have fixed types. All steps are fixed here including LMC. TODO: better solution? + if config[pipeline_step_name].get('key', None) == 'None': # in Tk GUI key is str + config[pipeline_step_name]['key'] = None + if config[pipeline_step_name].get('channel', None) == -1: # in Tk GUI channel is int + config[pipeline_step_name]['channel'] = None + if config[pipeline_step_name].get('key_nuclei', None) == 'None': # in Tk GUI key is str + config[pipeline_step_name]['key_nuclei'] = None + if config[pipeline_step_name].get('channel_nuclei', None) == -1: # in Tk GUI channel is int + config[pipeline_step_name]['channel_nuclei'] = None + if pipeline_step_name == 'preprocessing': _validate_cnn_postprocessing_rescaling(input_paths, config) @@ -108,4 +121,4 @@ def raw2seg(config): if not isinstance(pipeline_step, DataPostProcessing3D): input_paths = output_paths - gui_logger.info(f"Pipeline execution finished!") + gui_logger.info("Pipeline execution finished!") diff --git a/plantseg/predictions/functional/predictions.py b/plantseg/predictions/functional/predictions.py index f61d9bce..457c4b14 100644 --- a/plantseg/predictions/functional/predictions.py +++ b/plantseg/predictions/functional/predictions.py @@ -3,8 +3,9 @@ import numpy as np import torch +from plantseg.viewer.logging import napari_formatted_logging from plantseg.augment.transforms import get_test_augmentations -from plantseg.dataprocessing.functional.dataprocessing import fix_input_shape +from plantseg.dataprocessing.functional.dataprocessing import fix_input_shape_to_3D, fix_input_shape_to_4D from plantseg.predictions.functional.array_dataset import ArrayDataset from plantseg.predictions.functional.array_predictor import ArrayPredictor from plantseg.predictions.functional.slice_builder import SliceBuilder @@ -14,9 +15,11 @@ def unet_predictions(raw: np.array, model_name: str, patch: Tuple[int, int, int] = (80, 160, 160), single_batch_mode: bool = True, device: str = 'cuda', model_update: bool = False, - disable_tqdm: bool = False, **kwargs) -> np.array: + disable_tqdm: bool = False, handle_multichannel = False, **kwargs) -> np.array: """ Predict boundaries predictions from raw data using a 3D U-Net model. + If the model has single-channel output, then return a 3D array of shape (Z, Y, X). + If the model has multi-channel output, then return a 4D array of shape (C, Z, Y, X). Args: raw (np.array): raw data, must be a 3D array of shape (Z, Y, X) normalized between 0 and 1. @@ -27,11 +30,13 @@ def unet_predictions(raw: np.array, model_name: str, patch: Tuple[int, int, int] Defaults to 'cuda'. model_update (bool, optional): if True will update the model to the latest version. Defaults to False. disable_tqdm (bool, optional): if True will disable tqdm progress bar. Defaults to False. + output_ndim (int, optional): output ndim, must be one of [3, 4]. Only use `4` if network output is + multi-channel 3D pmap. Now `4` only used in `widget_unet_predictions()`. Returns: - np.array: predictions, 3D array of shape (Z, Y, X) with values between 0 and 1. + np.array: predictions, 4D array of shape (C, Z, Y, X) or 3D array of shape (Z, Y, X) with values between 0 and 1. + if `out_channels` in model config is greater than 1, then output will be 4D array. :param single_batch_mode: - """ model, model_config, model_path = get_model_config(model_name, model_update=model_update) state = torch.load(model_path, map_location='cpu') @@ -46,13 +51,20 @@ def unet_predictions(raw: np.array, model_name: str, patch: Tuple[int, int, int] patch_halo=patch_halo, single_batch_mode=single_batch_mode, headless=False, verbose_logging=False, disable_tqdm=disable_tqdm) - raw = fix_input_shape(raw) + raw = fix_input_shape_to_3D(raw) raw = raw.astype('float32') stride = get_stride_shape(patch) augs = get_test_augmentations(raw) slice_builder = SliceBuilder(raw, label_dataset=None, patch_shape=patch, stride_shape=stride) test_dataset = ArrayDataset(raw, slice_builder, augs, verbose_logging=False) - pmaps = predictor(test_dataset) - pmaps = fix_input_shape(pmaps[0]) + pmaps = predictor(test_dataset) # pmaps either (C, Z, Y, X) or (C, Y, X) + out_channel = int(model_config['out_channels']) + + if out_channel > 1 and handle_multichannel: # if multi-channel output and who called this function can handle it + napari_formatted_logging(f'`unet_predictions()` has `handle_multichannel`={handle_multichannel}', + thread="unet_predictions", level='warning') + pmaps = fix_input_shape_to_4D(pmaps) # then make (C, Y, X) to (C, 1, Y, X) and keep (C, Z, Y, X) unchanged + else: # otherwise use old mechanism + pmaps = fix_input_shape_to_3D(pmaps[0]) return pmaps diff --git a/plantseg/resources/config_gui_template.yaml b/plantseg/resources/config_gui_template.yaml index a8e395a0..f0426044 100644 --- a/plantseg/resources/config_gui_template.yaml +++ b/plantseg/resources/config_gui_template.yaml @@ -4,6 +4,10 @@ path: preprocessing: # enable/disable preprocessing state: True + # key for H5 or ZARR, can be set to null if only one key exists in each file + key: Null + # channel to use if input image has shape CZYX or CYX, otherwise set to null + channel: -1 # create a new sub folder where all results will be stored save_directory: "PreProcessing" # rescaling the volume is essential for the generalization of the networks. The rescaling factor can be computed as the resolution @@ -27,6 +31,10 @@ preprocessing: cnn_prediction: # enable/disable UNet prediction state: True + # key for H5 or ZARR, can be set to null if only one key exists in each file; null is recommended if the previous steps has state True + key: Null + # channel to use if input image has shape CZYX or CYX, otherwise set to null; null is recommended if the previous steps has state True + channel: -1 # Trained model name, more info on available models and custom models in the README model_name: "generic_confocal_3D_unet" # If a CUDA capable gpu is available and corrected setup use "cuda", if not you can use "cpu" for cpu only inference (slower) @@ -45,6 +53,10 @@ cnn_prediction: cnn_postprocessing: # enable/disable cnn post processing state: True + # key for H5 or ZARR, can be set to null if only one key exists in each file; null is recommended if the previous steps has state True + key: Null + # channel to use if input image has shape CZYX or CYX, otherwise set to null; null is recommended if the previous steps has state True + channel: -1 # if True convert to result to tiff tiff: False output_type: "data_float32" @@ -55,12 +67,26 @@ cnn_postprocessing: # save raw input in the output prediction file h5 file save_raw: False - segmentation: # enable/disable segmentation state: True + # key for H5 or ZARR, can be set to null if only one key exists in each file; null is recommended if the previous steps has state True + key: Null + # channel to use if prediction has shape CZYX or CYX, otherwise set to null; null is recommended if the previous steps has state True + channel: -1 # Name of the algorithm to use for inferences name: "GASP" + + # For 'LiftedMulticut' algorithm only + # path to the directory containing the nuclei predictions (either probability maps or segmentation) + nuclei_predictions_path: + # whether the `nuclei_predictions_path` contains the probability maps (is_segmentation=False) or segmentation (is_segmentation=True) + is_segmentation: True + # key for H5 or ZARR, can be set to null if only one key exists in each file; null is recommended if the previous steps has state True + key_nuclei: Null + # channel to use if prediction has shape CZYX or CYX, otherwise set to null; null is recommended if the previous steps has state True + channel_nuclei: -1 + # Segmentation specific parameters here # balance under-/over-segmentation; 0 - aim for undersegmentation, 1 - aim for oversegmentation beta: 0.6 @@ -84,6 +110,10 @@ segmentation: segmentation_postprocessing: # enable/disable segmentation post processing state: True + # key for H5 or ZARR, can be set to null if only one key exists in each file; null is recommended if the previous steps has state True + key: Null + # channel to use if input image has shape CZYX or CYX, otherwise set to null; null is recommended if the previous steps has state True + channel: -1 # if True convert to result to tiff tiff: False # rescaling factor diff --git a/plantseg/viewer/containers.py b/plantseg/viewer/containers.py index 3c1bb1b5..8d798f33 100644 --- a/plantseg/viewer/containers.py +++ b/plantseg/viewer/containers.py @@ -38,7 +38,7 @@ def get_main(): widget_filter_segmentation, ], labels=False) - container = setup_menu(container, path='https://github.com/hci-unihd/plant-seg/wiki/Napari-Main') + container = setup_menu(container, path='https://hci-unihd.github.io/plant-seg/chapters/plantseg_interactive_napari/index.html') return container @@ -49,7 +49,7 @@ def get_preprocessing_workflow(): widget_add_layers, widget_label_processing], labels=False) - container = setup_menu(container, path='https://github.com/hci-unihd/plant-seg/wiki/Data-Processing') + container = setup_menu(container, path='https://hci-unihd.github.io/plant-seg/chapters/plantseg_interactive_napari/data_processing.html') return container @@ -58,7 +58,7 @@ def get_gasp_workflow(): widget_simple_dt_ws, widget_agglomeration], labels=False) - container = setup_menu(container, path='https://github.com/hci-unihd/plant-seg/wiki/UNet-GASP-Workflow') + container = setup_menu(container, path='https://hci-unihd.github.io/plant-seg/chapters/plantseg_interactive_napari/unet_gasp_workflow.html') return container @@ -67,7 +67,7 @@ def get_extra_seg(): widget_lifted_multicut, widget_fix_over_under_segmentation_from_nuclei], labels=False) - container = setup_menu(container, path='https://github.com/hci-unihd/plant-seg/wiki/Extra-Seg') + container = setup_menu(container, path='https://hci-unihd.github.io/plant-seg/chapters/plantseg_interactive_napari/extra_seg.html') return container @@ -76,5 +76,5 @@ def get_extra_pred(): widget_iterative_unet_predictions, widget_add_custom_model], labels=False) - container = setup_menu(container, path='https://github.com/hci-unihd/plant-seg/wiki/Extra-Pred') + container = setup_menu(container, path='https://hci-unihd.github.io/plant-seg/chapters/plantseg_interactive_napari/extra_pred.html') return container diff --git a/plantseg/viewer/widget/predictions.py b/plantseg/viewer/widget/predictions.py index d5dbdea1..96ede2cc 100644 --- a/plantseg/viewer/widget/predictions.py +++ b/plantseg/viewer/widget/predictions.py @@ -18,7 +18,7 @@ from plantseg.viewer.widget.proofreading.proofreading import widget_split_and_merge_from_scribbles from plantseg.viewer.widget.segmentation import widget_agglomeration, widget_lifted_multicut, widget_simple_dt_ws from plantseg.viewer.widget.utils import return_value_if_widget -from plantseg.viewer.widget.utils import start_threading_process, create_layer_name, layer_properties +from plantseg.viewer.widget.utils import start_threading_process, start_prediction_threading_process, create_layer_name, layer_properties ALL_CUDA_DEVICES = [f'cuda:{i}' for i in range(torch.cuda.device_count())] MPS = ['mps'] if torch.backends.mps.is_available() else [] @@ -87,8 +87,10 @@ def widget_unet_predictions(viewer: Viewer, layer_type = 'image' step_kwargs = dict(model_name=model_name, patch=patch_size, single_batch_mode=single_patch) - return start_threading_process(unet_predictions_wrapper, - runtime_kwargs={'raw': image.data, 'device': device}, + return start_prediction_threading_process(unet_predictions_wrapper, + runtime_kwargs={'raw': image.data, + 'device': device, + 'handle_multichannel': True}, statics_kwargs=step_kwargs, out_name=out_name, input_keys=inputs_names, diff --git a/plantseg/viewer/widget/utils.py b/plantseg/viewer/widget/utils.py index 40d4fdd1..896bfa35 100644 --- a/plantseg/viewer/widget/utils.py +++ b/plantseg/viewer/widget/utils.py @@ -70,6 +70,54 @@ def on_done(result): return future +def start_prediction_process(func: Callable, + runtime_kwargs: dict, + statics_kwargs: dict, + out_name: str, + input_keys: Tuple[str, ...], + layer_kwarg: dict, + layer_type: str = 'image', + step_name: str = '', + skip_dag: bool = False, + viewer: Viewer = None, + widgets_to_update: list = None) -> Future: + runtime_kwargs.update(statics_kwargs) + thread_func = thread_worker(partial(func, **runtime_kwargs)) + future = Future() + timer_start = timeit.default_timer() + + def on_done(result): + timer = timeit.default_timer() - timer_start + napari_formatted_logging(f'Widget {step_name} computation complete in {timer:.2f}s', thread=step_name) + _func = func if not skip_dag else identity + dag_manager.add_step(_func, input_keys=input_keys, + output_key=out_name, + static_params=statics_kwargs, + step_name=step_name) + if result.ndim == 4: # then we have a 2-channel output + pmap_layers = [] + for i, pmap in enumerate(result): + temp_layer_kwarg = layer_kwarg.copy() + temp_layer_kwarg['name'] = layer_kwarg['name'] + f'_{i}' + pmap_layers.append((pmap, temp_layer_kwarg, layer_type)) + result = pmap_layers + + # Only widget_unet_predictions() invokes and handles 4D UNet output for now, but headless mode can also invoke this part, thus warn: + napari_formatted_logging(f'Widget {step_name}: Headless mode is not supported for 2-channel output predictions', thread=step_name, level='warning') + else: # then we have a 1-channel output + result = result, layer_kwarg, layer_type + future.set_result(result) + + if viewer is not None and widgets_to_update is not None: + setup_layers_suggestions(viewer, out_name, widgets_to_update) + + worker = thread_func() + worker.returned.connect(on_done) + worker.start() + napari_formatted_logging(f'Widget {step_name} computation started', thread=step_name) + return future + + def layer_properties(name, scale, metadata: dict = None): keys_to_save = {'original_voxel_size', 'voxel_size_unit', 'root_name'} if metadata is not None: