Skip to content

Commit

Permalink
Merge PR #184 from hci-unihd/fixflowGUI | fix #173, fix #183
Browse files Browse the repository at this point in the history
Fix Flow in GUI & Napari: key, channel, LMC in GUI; multi-channel UNet layers in Napari
  • Loading branch information
qin-yu authored Feb 5, 2024
2 parents 441b724 + e80dc25 commit 2e8d551
Show file tree
Hide file tree
Showing 11 changed files with 607 additions and 293 deletions.
22 changes: 22 additions & 0 deletions environment-dev.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
name: plant-seg-dev
channels:
- pytorch
- nvidia
- conda-forge
dependencies:
- python
- h5py
- zarr
- requests
- pyyaml
- scikit-image
- tifffile
- vigra
- cudnn
- pytorch
- pytorch-cuda=12.1
- python-elf
- pyqt
- napari
- python-graphviz
# then `pip install -e .`
2 changes: 1 addition & 1 deletion examples/config_advanced.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ cnn_prediction:
# channel to use if input image has shape CZYX or CYX, otherwise set to null; null is recommended if the previous steps has state True
channel: Null
# Trained model name, more info on available models and custom models in the README
model_name: 'RIKEN20230824v160'
model_name: 'generic_plant_nuclei_3D'
# If a CUDA capable gpu is available and corrected setup use "cuda", if not you can use "cpu" for cpu only inference (slower)
device: 'cuda'
# how many subprocesses to use for data loading
Expand Down
29 changes: 28 additions & 1 deletion plantseg/dataprocessing/functional/dataprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,15 @@ def image_crop(image: np.array, crop_str: str) -> np.array:
return image[slices]


def fix_input_shape(data: np.array) -> np.array:
def fix_input_shape(data: np.array, ndim=3) -> np.array:
assert ndim in [3, 4]
if ndim == 3:
return fix_input_shape_to_3D(data)
else:
return fix_input_shape_to_4D(data)


def fix_input_shape_to_3D(data: np.array) -> np.array:
"""
fix array ndim to be always 3
"""
Expand All @@ -94,6 +102,25 @@ def fix_input_shape(data: np.array) -> np.array:
raise RuntimeError(f"Expected input data to be 2d, 3d or 4d, but got {data.ndim}d input")


def fix_input_shape_to_4D(data: np.array) -> np.array:
"""
Fix array ndim to be 4 and return it in (C x Z x Y x X) e.g. 2 x 1 x 512 x 512
Only used for multi-channel network output, e.g. 2-channel 3D pmaps.
"""
if data.ndim == 4:
return data

elif data.ndim == 3:
return data.reshape(data.shape[0], 1, data.shape[1], data.shape[2])

# elif data.ndim == 2:
# return data.reshape(1, 1, data.shape[0], data.shape[1])

else:
raise RuntimeError(f"Expected input data to be 3d or 4d, but got {data.ndim}d input")


def normalize_01(data: np.array) -> np.array:
"""
normalize a numpy array between 0 and 1
Expand Down
Loading

0 comments on commit 2e8d551

Please sign in to comment.