Skip to content

Commit

Permalink
Merge pull request #355 from kreshuklab/proofreading_improvements
Browse files Browse the repository at this point in the history
Proofreading improvements
  • Loading branch information
lorenzocerrone authored Oct 21, 2024
2 parents 2195417 + f8bfe4b commit fd7ff08
Show file tree
Hide file tree
Showing 16 changed files with 699 additions and 375 deletions.
86 changes: 70 additions & 16 deletions plantseg/core/image.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import logging
from enum import Enum
from pathlib import Path
from typing import Literal
from uuid import UUID, uuid4

import h5py
import numpy as np
from napari.layers import Image, Labels
from napari.types import LayerDataTuple
from pydantic import BaseModel

import plantseg.functionals.dataprocessing as dp
from plantseg.io.h5 import create_h5
from plantseg.io.h5 import H5_EXTENSIONS, create_h5
from plantseg.io.io import smart_load_with_vs
from plantseg.io.tiff import create_tiff
from plantseg.io.voxelsize import VoxelSize
Expand Down Expand Up @@ -110,7 +112,7 @@ class ImageProperties(BaseModel):
voxel_size: VoxelSize
image_layout: ImageLayout
original_voxel_size: VoxelSize
original_name: str | None = None
source_file_name: str | None = None

@property
def dimensionality(self) -> ImageDimensionality:
Expand Down Expand Up @@ -264,7 +266,7 @@ def from_napari_layer(cls, layer: Image | Labels) -> "PlantSegImage":
raise ValueError("Voxel size not found in metadata")
new_voxel_size = VoxelSize(**metadata["voxel_size"])

original_name = metadata.get("original_name", None)
source_file_name = metadata.get("source_file_name", None)

# Loading from napari layer, the id needs to be present in the metadata
# If not present, the layer is corrupted
Expand All @@ -279,13 +281,13 @@ def from_napari_layer(cls, layer: Image | Labels) -> "PlantSegImage":
voxel_size=new_voxel_size,
image_layout=image_layout,
original_voxel_size=original_voxel_size,
original_name=original_name,
source_file_name=source_file_name,
)

if image_type != properties.image_type:
raise ValueError(f"Image type {image_type} does not match semantic type {properties.semantic_type}")

ps_image = cls(layer.data, properties)
ps_image = cls(layer.data, properties) # type: ignore
ps_image._id = id
return ps_image

Expand Down Expand Up @@ -317,7 +319,57 @@ def to_napari_layer_tuple(self) -> LayerDataTuple:

return LayerDataTuple(layer_data_tuple)

def _check_ndim(self, data: np.ndarray) -> None:
def to_h5(self, path: Path | str, key: str | None, mode: Literal["a", "w", "w-"] = "a") -> None:
"""Save the image with all metadata to an h5 file.
Args:
path (Path): Path to the h5 file
key (str): Key to save the data in the h5 file
mode (str): Mode to open the h5 file ['a', 'w', 'w-']
"""

if isinstance(path, str):
path = Path(path)

if path.suffix not in H5_EXTENSIONS:
raise ValueError(f"File format {path.suffix} not supported, should be one of {H5_EXTENSIONS}")

key = key if key is not None else self.name

data = self._data
voxel_size = self.voxel_size
metadata = self.properties.model_dump_json()

with h5py.File(path, mode=mode) as f:
f.create_dataset(key, data=data)
if voxel_size.voxels_size is not None:
f[key].attrs["element_size_um"] = voxel_size.voxels_size
f[key].attrs["plantseg_image_metadata_json"] = metadata

@classmethod
def from_h5(cls, path: Path | str, key: str) -> "PlantSegImage":
"""Build an instance of PlantSegImage from an h5 file."""

if isinstance(path, str):
path = Path(path)

if not path.exists():
raise ValueError(f"File {path} not found")

with h5py.File(path, "r") as f:
if key not in f:
raise ValueError(f"Key {key} not found in the h5 file")

data: np.ndarray = f[key][...] # type: ignore
metadata = f[key].attrs.get("plantseg_image_metadata_json", None)

if metadata is None:
raise ValueError("PlantSeg metadata not found in the h5 file")

properties = ImageProperties.model_validate_json(metadata)
return cls(data, properties)

def _check_ndim(self, data: np.ndarray) -> np.ndarray:
if self.image_layout in (ImageLayout.CYX, ImageLayout.ZYX):
if data.ndim != 3:
raise ValueError(
Expand All @@ -341,7 +393,7 @@ def _check_ndim(self, data: np.ndarray) -> None:

return data

def _check_shape(self, data: np.ndarray, properties: ImageProperties) -> None:
def _check_shape(self, data: np.ndarray, properties: ImageProperties) -> tuple[np.ndarray, ImageProperties]:
if self.image_layout == ImageLayout.ZYX:
if data.shape[0] == 1:
logger.warning("Image layout is ZYX but data has only one z slice, casting to YX")
Expand All @@ -362,7 +414,7 @@ def _check_shape(self, data: np.ndarray, properties: ImageProperties) -> None:
elif data.shape[0] > 1 and data.shape[1] == 1:
logger.warning("Image layout is CZYX but data has only one z slice, casting to CYX")
properties.image_layout = ImageLayout.CYX
return data[:, 0], properties.image_layout
return data[:, 0], properties

elif self.image_layout == ImageLayout.ZCYX:
raise ValueError(f"Image layout {self.image_layout} not supported, should have been converted to CZYX")
Expand All @@ -386,12 +438,14 @@ def _get_data_channel_layout(self, channel: int | None = None, normalize_01: boo
if channel is None:
data = self._data
if normalize_01:
assert self.channel_axis is not None
data = dp.normalize_01_channel_wise(data, self.channel_axis)
return data

if channel < 0:
raise ValueError(f"Channel should be a positive integer, but got {channel}")

assert self.channel_axis is not None
if channel > self._data.shape[self.channel_axis]:
raise ValueError(
f"Channel {channel} is out of bounds, the image has {self._data.shape[self.channel_axis]} channels"
Expand Down Expand Up @@ -463,8 +517,8 @@ def original_voxel_size(self) -> VoxelSize:
return self._properties.original_voxel_size

@property
def original_name(self) -> str | None:
return self._properties.original_name
def source_file_name(self) -> str | None:
return self._properties.source_file_name

@property
def name(self) -> str:
Expand Down Expand Up @@ -553,7 +607,7 @@ def import_image(
voxel_size=voxel_size,
image_layout=image_layout,
original_voxel_size=voxel_size,
original_name=image_name,
source_file_name=path.stem,
)

return PlantSegImage(data=data, properties=image_properties)
Expand All @@ -565,8 +619,8 @@ def _image_postprocessing(
if scale_to_origin and image.requires_scaling:
data = dp.scale_image_to_voxelsize(
image.get_data(),
input_voxel_size=image.voxel_size,
output_voxel_size=image.original_voxel_size,
input_voxel_size=image.voxel_size.as_tuple(),
output_voxel_size=image.original_voxel_size.as_tuple(),
order=image.interpolation_order(),
)
new_voxel_size = image.original_voxel_size
Expand Down Expand Up @@ -608,7 +662,7 @@ def save_image(
Args:
image (PlantSegImage): input image to be saved to disk
export_directory (Path): output directory path where the image will be saved
name_pattern (str): output file name pattern, can contain the {image_name} or {original_name} tokens
name_pattern (str): output file name pattern, can contain the {image_name} or {file_name} tokens
to be replaced in the final file name.
key (str | None): key for the image (used only for h5 and zarr formats).
scale_to_origin (bool): scale the voxel size to the original one
Expand All @@ -623,8 +677,8 @@ def save_image(

name_pattern = name_pattern.replace("{image_name}", image.name)

if image.original_name is not None:
name_pattern = name_pattern.replace("{original_name}", image.original_name)
if image.source_file_name is not None:
name_pattern = name_pattern.replace("{file_name}", image.source_file_name)

if export_format == "tiff":
file_path_name = directory / f"{name_pattern}.tiff"
Expand Down
2 changes: 1 addition & 1 deletion plantseg/io/h5.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def create_h5(
del f[key]
f.create_dataset(key, data=stack, compression="gzip")
# save voxel_size
if voxel_size is not None:
if voxel_size is not None and voxel_size.voxels_size is not None:
f[key].attrs["element_size_um"] = voxel_size.voxels_size


Expand Down
29 changes: 26 additions & 3 deletions plantseg/tasks/dataprocessing_tasks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging

from plantseg.core.image import ImageLayout, PlantSegImage, SemanticType
from plantseg.core.image import ImageDimensionality, ImageLayout, PlantSegImage, SemanticType
from plantseg.functionals.dataprocessing import (
fix_over_under_segmentation_from_nuclei,
image_gaussian_smoothing,
Expand Down Expand Up @@ -34,14 +34,17 @@ def gaussian_smoothing_task(image: PlantSegImage, sigma: float) -> PlantSegImage
return new_image


def _compute_slices(rectangle, crop_z: tuple[int, int], shape):
def _compute_slices_3d(rectangle, crop_z: tuple[int, int], shape):
"""
Compute slices for cropping based on a given rectangle and z-slices.
"""
z_slice = slice(*crop_z)
if rectangle is None:
return z_slice, slice(0, shape[1]), slice(0, shape[2])

if (rectangle[2, 0] - rectangle[0, 0]) > 0:
raise ValueError("Invalid crop, the rextangle must be drawn in the XY plane.")

x_start = max(rectangle[0, 1], 0)
x_end = min(rectangle[2, 1], shape[1])
x_slice = slice(x_start, x_end)
Expand All @@ -52,6 +55,23 @@ def _compute_slices(rectangle, crop_z: tuple[int, int], shape):
return z_slice, x_slice, y_slice


def _compute_slices_2d(rectangle, shape):
"""
Compute slices for cropping based on a given rectangle.
"""
if rectangle is None:
return slice(0, shape[0]), slice(0, shape[1])

x_start = max(rectangle[0, 0], 0)
x_end = min(rectangle[2, 0], shape[0])
x_slice = slice(x_start, x_end)

y_start = max(rectangle[0, 1], 0)
y_end = min(rectangle[2, 1], shape[1])
y_slice = slice(y_start, y_end)
return x_slice, y_slice


def _cropping(data, crop_slices):
"""
Apply cropping on the provided data based on the computed slices.
Expand All @@ -75,7 +95,10 @@ def image_cropping_task(image: PlantSegImage, rectangle=None, crop_z: tuple[int,
data = image.get_data()

# Compute crop slices
crop_slices = _compute_slices(rectangle, crop_z, data.shape)
if image.dimensionality == ImageDimensionality.TWO:
crop_slices = _compute_slices_2d(rectangle, data.shape)
else:
crop_slices = _compute_slices_3d(rectangle, crop_z, data.shape)

# Perform cropping on the data
cropped_data = _cropping(data, crop_slices)
Expand Down
4 changes: 2 additions & 2 deletions plantseg/tasks/io_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def import_image_task(
def export_image_task(
image: PlantSegImage,
export_directory: Path,
name_pattern: str = "{original_name}_export",
name_pattern: str = "{file_name}_export",
key: str | None = None,
scale_to_origin: bool = True,
export_format: str = "tiff",
Expand All @@ -85,7 +85,7 @@ def export_image_task(
Args:
image (PlantSegImage): input image to be saved to disk
export_directory (Path): output directory path where the image will be saved
name_pattern (str): output file name pattern, can contain the {image_name} or {original_name} tokens
name_pattern (str): output file name pattern, can contain the {image_name} or {file_name} tokens
to be replaced in the final file name.
key (str | None): key for the image (used only for h5 and zarr formats).
scale_to_origin (bool): scale the voxel size to the original one
Expand Down
15 changes: 5 additions & 10 deletions plantseg/viewer_napari/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from plantseg.viewer_napari.widgets import (
widget_add_custom_model,
widget_add_custom_model_toggl,
widget_agglomeration,
widget_clean_scribble,
widget_cropping,
Expand All @@ -21,6 +22,7 @@
widget_rescaling,
widget_save_state,
widget_set_biggest_instance_to_zero,
widget_set_voxel_size,
widget_show_info,
widget_split_and_merge_from_scribbles,
widget_undo,
Expand All @@ -37,6 +39,7 @@ def get_data_io_tab():
widget_open_file,
widget_export_image,
widget_export_headless_workflow,
widget_set_voxel_size,
widget_show_info,
widget_infos,
],
Expand All @@ -62,6 +65,8 @@ def get_segmentation_tab():
container = Container(
widgets=[
widget_unet_prediction,
widget_add_custom_model,
widget_add_custom_model_toggl,
widget_dt_ws,
widget_agglomeration,
],
Expand All @@ -83,16 +88,6 @@ def get_postprocessing_tab():
return container


def get_extras_tab():
container = Container(
widgets=[
widget_add_custom_model,
],
labels=False,
)
return container


def get_proofreading_tab():
widget_fix_over_under_segmentation_from_nuclei.threshold.native.setStyleSheet(STYLE_SLIDER)
widget_fix_over_under_segmentation_from_nuclei.quantile.native.setStyleSheet(STYLE_SLIDER)
Expand Down
8 changes: 3 additions & 5 deletions plantseg/viewer_napari/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from plantseg.viewer_napari import log
from plantseg.viewer_napari.containers import (
get_data_io_tab,
get_extras_tab,
get_postprocessing_tab,
get_preprocessing_tab,
get_proofreading_tab,
Expand All @@ -14,16 +13,15 @@

def run_viewer():
viewer = napari.Viewer(title='PlantSeg v2')
setup_proofreading_keybindings(viewer)
setup_proofreading_keybindings(viewer=viewer)

# Create and add tabs
for _containers, name in [
(get_data_io_tab(), 'Input/Output'),
(get_preprocessing_tab(), 'Image Processing'),
(get_preprocessing_tab(), 'Preprocessing'),
(get_segmentation_tab(), 'Segmentation'),
(get_postprocessing_tab(), 'Label Processing'),
(get_postprocessing_tab(), 'Postprocessing'),
(get_proofreading_tab(), 'Proofreading'),
(get_extras_tab(), 'Models'),
]:
viewer.window.add_dock_widget(_containers, name=name, tabify=True)

Expand Down
9 changes: 8 additions & 1 deletion plantseg/viewer_napari/widgets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,14 @@
widget_export_image,
widget_infos,
widget_open_file,
widget_set_voxel_size,
widget_show_info,
)
from plantseg.viewer_napari.widgets.prediction import widget_add_custom_model, widget_unet_prediction
from plantseg.viewer_napari.widgets.prediction import (
widget_add_custom_model,
widget_add_custom_model_toggl,
widget_unet_prediction,
)
from plantseg.viewer_napari.widgets.proofreading import (
widget_add_label_to_corrected,
widget_clean_scribble,
Expand All @@ -41,13 +46,15 @@
"widget_export_headless_workflow",
"widget_show_info",
"widget_infos",
"widget_set_voxel_size",
# Main - Prediction
"widget_unet_prediction",
# Main - Segmentation
"widget_dt_ws",
"widget_agglomeration",
# Extra
"widget_add_custom_model",
"widget_add_custom_model_toggl",
"widget_relabel",
"widget_set_biggest_instance_to_zero",
# Proofreading
Expand Down
Loading

0 comments on commit fd7ff08

Please sign in to comment.