From d3309f5fa2c9ed1ea3cd2b60ea00e8f23c34e125 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 20 Jun 2024 15:08:20 -0700 Subject: [PATCH 01/42] first pass on VSCyto2D demo --- .../demos/VSCyto2d_a549cells/demo_vscyto2d.py | 113 ++++++++++++++++++ 1 file changed, 113 insertions(+) create mode 100644 examples/demos/VSCyto2d_a549cells/demo_vscyto2d.py diff --git a/examples/demos/VSCyto2d_a549cells/demo_vscyto2d.py b/examples/demos/VSCyto2d_a549cells/demo_vscyto2d.py new file mode 100644 index 00000000..4eca98f0 --- /dev/null +++ b/examples/demos/VSCyto2d_a549cells/demo_vscyto2d.py @@ -0,0 +1,113 @@ +# %% [markdown] +""" +# 2D Virtual Staining of A549 Cells +--- +This example shows how to virtually stain A549 cells using the _VSCyto2D_ model. + +First we import the necessary libraries and set the random seed for reproducibility. +""" +# %% Imports and paths +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import torch +import torchview +import torchvision +from iohub import open_ome_zarr +from lightning.pytorch import seed_everything + +# from rich.pretty import pprint #TODO: add pretty print(?) + +from napari.utils.notebook_display import nbscreenshot +import napari + +# %% Imports and paths +from viscy.data.hcs import HCSDataModule + +# Trainer class and UNet. +from viscy.light.engine import FcmaeUNet +from viscy.light.trainer import VSTrainer +from viscy.transforms import NormalizeSampled +from viscy.light.predict_writer import HCSPredictionWriter +from viscy.data.hcs import HCSDataModule + +# %% [markdown] +""" +## Prediction using the 2D U-Net model to predict nuclei and membrane from phase. + +### Construct a 2D U-Net +See ``viscy.unet.networks.Unet2D.Unet2d`` ([source code](https://github.com/mehta-lab/VisCy/blob/7c5e4c1d68e70163cf514d22c475da8ea7dc3a88/viscy/unet/networks/Unet2D.py#L7)) for configuration details. +""" + +# %% +input_data_path = "/hpc/projects/comp.micro/virtual_staining/datasets/test/cell_types_20x/a549_sliced/a549_hoechst_cellmask_test.zarr/0/0/0" +model_ckpt_path = "/hpc/projects/comp.micro/virtual_staining/models/hek-a549-bj5a-20x/lightning_logs/tiny-2x2-finetune-e2e-amp-hek-a549-bj5a-nucleus-membrane-400ep/checkpoints/last.ckpt" +output_path = "./test_a549_demo.zarr" + +# %% +# Create a the VSCyto2D + +GPU_ID = 0 +BATCH_SIZE = 10 +YX_PATCH_SIZE = (384, 384) +phase_channel_name = "Phase3D" + + +# %% +# Setup the data module. +data_module = HCSDataModule( + data_path=input_data_path, + source_channel=phase_channel_name, + target_channel=["Membrane", "Nuclei"], + z_window_size=1, + split_ratio=0.8, + batch_size=BATCH_SIZE, + num_workers=8, + architecture="2D", + yx_patch_size=YX_PATCH_SIZE, + normalizations=[ + NormalizeSampled( + [phase_channel_name], + level="fov_statistics", + subtrahend="median", + divisor="iqr", + ) + ], +) +data_module.prepare_data() +data_module.setup(stage="predict") +# %% +# Setup the model. +# Dictionary that specifies key parameters of the model. +config_VSCyto2D = { + "in_channels": 1, + "out_channels": 2, + "encoder_blocks": [3, 3, 9, 3], + "dims": [96, 192, 384, 768], + "decoder_conv_blocks": 2, + "stem_kernel_size": [1, 2, 2], + "in_stack_depth": 1, + "pretraining": False, +} + +model_VSCyto2D = FcmaeUNet.load_from_checkpoint( + model_ckpt_path, model_config=config_VSCyto2D +) +model_VSCyto2D.eval() + +# %% +trainer = VSTrainer( + accelerator="gpu", + callbacks=[HCSPredictionWriter(output_path)], +) + +# Start the predictions +trainer.predict( + model=model_VSCyto2D, + datamodule=data_module, + return_predictions=False, +) + +# %% From 6a106408c7d349a011015a266b902ee17808cc1d Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 20 Jun 2024 15:15:30 -0700 Subject: [PATCH 02/42] updating notebooks --- examples/demo_dlmbl/solution.py | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index 2c81aa6f..335bd9ca 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -3,26 +3,37 @@ # Image translation --- -Written by Ziwen Liu and Shalin Mehta, CZ Biohub San Francisco. +### Overview +In this exercise, we will solve an image translation task to predict fluorescence images of nuclei and membrane markers from quantitative phase images of cells. In other words, we will _virtually stain_ the nuclei and membrane visible in the phase image. -In this exercise, we will solve an image translation task to predict fluorescence images of nuclei and membrane markers from quantitative phase images of cells. In other words, we will _virtually stain_ the nuclei and membrane visible in the phase image. +### Goal +- Here, the source domain is label-free microscopy (material density) and the target domain is fluorescence microscopy (fluorophore density). +- The goal is to learn a mapping from the source domain to the target domain. We will use a _purely convolutional architecture_ that draws on the design principles of transformer models. +- Here we will use a UNeXt2, an efficient image translation architecture inspired by ConvNeXt v2, SparK. UNeXt2. +- We will perform the preprocessing, training, prediction, evaluation, and deployment steps that are unified in a computer vision pipeline for single-cell analysis in our pipeline called [VisCy](https://github.com/mehta-lab/VisCy). -Here, the source domain is label-free microscopy (material density) and the target domain is fluorescence microscopy (fluorophore density). The goal is to learn a mapping from the source domain to the target domain. We will use a deep convolutional neural network (CNN), specifically, a U-Net model with residual connections to learn the mapping. The preprocessing, training, prediction, evaluation, and deployment steps are unified in a computer vision pipeline for single-cell analysis that we call [VisCy](https://github.com/mehta-lab/VisCy). +We will train a 2D image translation model using a 2D U-Net with residual connections. We will use a dataset of 301 fields of view (FOVs) of Human Embryonic Kidney (HEK) cells, each FOV has 3 channels (phase, membrane, and nuclei). The cells were labeled with CRISPR editing. Intrestingly, not all cells during this experiment were labeled due to the stochastic nature of CRISPR editing. In such situations, virtual staining rescues missing labels. +![HEK](https://github.com/mehta-lab/VisCy/blob/dlmbl2023/docs/figures/phase_to_nuclei_membrane.svg?raw=true) + +# Extra information +--- +Written by Eduardo Hirata-Miyasaki, Ziwen Liu and Shalin Mehta, CZ Biohub San Francisco. + VisCy evolved from our previous work on virtual staining of cellular components from their density and anisotropy. ![](https://iiif.elifesciences.org/lax/55502%2Felife-55502-fig1-v2.tif/full/1500,/0/default.jpg) +## References +--- +[Liu,Z. and Hirata-Miyasaki,E. et al.(2024) Robust Virtual Staining of Cellular Landmarks](https://www.biorxiv.org/content/10.1101/2024.05.31.596901v1.full.pdf) + [Guo et al. (2020) Revealing architectural order with quantitative label-free imaging and deep learning -. eLife](https://elifesciences.org/articles/55502). +. eLife](https://elifesciences.org/articles/55502). VisCy exploits recent advances in the data and metadata formats ([OME-zarr](https://www.nature.com/articles/s41592-021-01326-w)) and DL frameworks, [PyTorch Lightning](https://lightning.ai/) and [MONAI](https://monai.io/). -""" -# %% [markdown] -""" -Today, we will train a 2D image translation model using a 2D U-Net with residual connections. We will use a dataset of 301 fields of view (FOVs) of Human Embryonic Kidney (HEK) cells, each FOV has 3 channels (phase, membrane, and nuclei). The cells were labeled with CRISPR editing. Intrestingly, not all cells during this experiment were labeled due to the stochastic nature of CRISPR editing. In such situations, virtual staining rescues missing labels. -![HEK](https://github.com/mehta-lab/VisCy/blob/dlmbl2023/docs/figures/phase_to_nuclei_membrane.svg?raw=true) """ + # %% [markdown] """
From 454e5fa91f91f2ca649775be6a5c75e1e718c695 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 3 Jul 2024 18:45:00 -0700 Subject: [PATCH 03/42] draft modifications updating 2DUnet to UNeXT2 --- examples/demo_dlmbl/solution.py | 389 +++++++++++++++++++++----------- 1 file changed, 256 insertions(+), 133 deletions(-) diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index 335bd9ca..fac6ce1b 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -1,64 +1,59 @@ # %% [markdown] """ -# Image translation +# Image translation (Virtual Staining) --- ### Overview In this exercise, we will solve an image translation task to predict fluorescence images of nuclei and membrane markers from quantitative phase images of cells. In other words, we will _virtually stain_ the nuclei and membrane visible in the phase image. +Additionally, we will apply the inverse process of predicting a phase image from a fluorescence membrane label. -### Goal +### Goals - Here, the source domain is label-free microscopy (material density) and the target domain is fluorescence microscopy (fluorophore density). - The goal is to learn a mapping from the source domain to the target domain. We will use a _purely convolutional architecture_ that draws on the design principles of transformer models. -- Here we will use a UNeXt2, an efficient image translation architecture inspired by ConvNeXt v2, SparK. UNeXt2. +- Here we will use a UNeXt2, an efficient image translation architecture inspired by ConvNeXt v2, SparK. - We will perform the preprocessing, training, prediction, evaluation, and deployment steps that are unified in a computer vision pipeline for single-cell analysis in our pipeline called [VisCy](https://github.com/mehta-lab/VisCy). +- We will train a 2D image translation model using a 2D-Unet with residual connections. We will use a dataset of 34 FOVs of adenocarcinomic human alveolar basal epithelial cells (A549), each FOV has 3 channels (phase, nuclei, and cell membrane). The nuclei were stained with DAPI and the cell membrane with Cellmask.![](https://github.com/mehta-lab/VisCy/blob/main/docs/figures/svideo_1.png) +- The last exercise is trying out the inverse going from fluorescence to label-free. +![](https://github.com/mehta-lab/VisCy/blob/main/docs/figures/svideo_1.png) -We will train a 2D image translation model using a 2D U-Net with residual connections. We will use a dataset of 301 fields of view (FOVs) of Human Embryonic Kidney (HEK) cells, each FOV has 3 channels (phase, membrane, and nuclei). The cells were labeled with CRISPR editing. Intrestingly, not all cells during this experiment were labeled due to the stochastic nature of CRISPR editing. In such situations, virtual staining rescues missing labels. -![HEK](https://github.com/mehta-lab/VisCy/blob/dlmbl2023/docs/figures/phase_to_nuclei_membrane.svg?raw=true) +📖 As you work through parts 2 and 3, please share the layouts of your models (output of torchview) and their performance with everyone via [this google doc](https://docs.google.com/document/d/1Mq-yV8FTG02xE46Mii2vzPJVYSRNdeOXkeU-EKu-irE/edit?usp=sharing) 📖. -# Extra information ---- -Written by Eduardo Hirata-Miyasaki, Ziwen Liu and Shalin Mehta, CZ Biohub San Francisco. - -VisCy evolved from our previous work on virtual staining of cellular components from their density and anisotropy. -![](https://iiif.elifesciences.org/lax/55502%2Felife-55502-fig1-v2.tif/full/1500,/0/default.jpg) +Our guesstimate is that each of the three parts will take ~1.5 hours. A reasonable 2D UNet can be trained in ~30 min on a typical AWS node. -## References ---- -[Liu,Z. and Hirata-Miyasaki,E. et al.(2024) Robust Virtual Staining of Cellular Landmarks](https://www.biorxiv.org/content/10.1101/2024.05.31.596901v1.full.pdf) +The focus of the exercise is on understanding information content of the data, how to train and evaluate 2D image translation model, and explore some hyperparameters of the model. If you complete this exercise and have time to spare, try the bonus exercise on 3D image translation. -[Guo et al. (2020) Revealing architectural order with quantitative label-free imaging and deep learning -. eLife](https://elifesciences.org/articles/55502). +Checkout VisCy!! Our deep learning pipeline for training and deploying computer vision models for image-based phenotyping in including the robust virtual staining of landmark organelles. VisCy exploits recent advances in the data and metadata formats ([OME-zarr](https://www.nature.com/articles/s41592-021-01326-w)) and DL frameworks, [PyTorch Lightning](https://lightning.ai/) and [MONAI](https://monai.io/). -""" +## References +--- +- [Liu,Z. and Hirata-Miyasaki,E. et al.(2024) Robust Virtual Staining of Cellular Landmarks](https://www.biorxiv.org/content/10.1101/2024.05.31.596901v2.full.pdf) -# %% [markdown] -""" -
-The exercise is organized in 3 parts. +- [Guo et al. (2020) Revealing architectural order with quantitative label-free imaging and deep learning +. eLife](https://elifesciences.org/articles/55502). + + +Written by Eduardo Hirata-Miyasaki, Ziwen Liu and Shalin Mehta, CZ Biohub San Francisco. -* **Part 1** - Explore the data using tensorboard. Launch the training before lunch. -* Lunch break - The model will continue training during lunch. -* **Part 2** - Evaluate the training with tensorboard. Train another model. -* **Part 3** - Tune the models to improve performance. -
""" # %% [markdown] """ -📖 As you work through parts 2 and 3, please share the layouts of your models (output of torchview) and their performance with everyone via [this google doc](https://docs.google.com/document/d/1hZWSVRvt9KJEdYu7ib-vFBqAVQRYL8cWaP_vFznu7D8/edit#heading=h.n5u485pmzv2z) 📖. - +
+The exercise is organized in 3 parts + Extra part. -Our guesstimate is that each of the three parts will take ~1.5 hours. A reasonable 2D UNet can be trained in ~20 min on a typical AWS node. -We will discuss your observations on google doc after checkpoints 2 and 3. +* **Part 1** - Explore the data using tensorboard.Explore augmentations. Train a phase to fluroesence model. +* **Part 2** - Evaluate the training with tensorboard. Evaluate the trained model. +* **Part 3** - Train the fluorescence to phase model. +* **Extra task** - Tune the models to improve performance. -The focus of the exercise is on understanding information content of the data, how to train and evaluate 2D image translation model, and explore some hyperparameters of the model. If you complete this exercise and have time to spare, try the bonus exercise on 3D image translation. +
""" # %% [markdown] """
-Set your python kernel to 04_image_translation +Set your python kernel to 06_image_translation
""" # %% @@ -103,21 +98,23 @@ RandGaussianNoised, RandGaussianSmoothd, RandScaleIntensityd, - RandWeightedCropd, + RandWeightedCropd, + NormalizeSampled ) # Trainer class and UNet. -from viscy.light.engine import VSUNet +from viscy.light.engine import VSUNet,MixedLoss from viscy.light.trainer import VSTrainer seed_everything(42, workers=True) # Paths to data and log directory -data_path = Path( - Path("~/data/04_image_translation/HEK_nuclei_membrane_pyramid.zarr/") -).expanduser() +# data_path = Path( +# Path("~/data/06_image_translation/HEK_nuclei_membrane_pyramid.zarr/") +# ).expanduser() -log_dir = Path("~/data/04_image_translation/logs/").expanduser() +data_path = Path("/hpc/projects/comp.micro/virtual_staining/datasets/training/a549-hoechst-cellmask-20x/a549_hoechst_cellmask_train_val.zarr") +log_dir = Path("~/mydata/tmp/06_image_translation/logs/").expanduser() # Create log directory if needed, and launch tensorboard log_dir.mkdir(parents=True, exist_ok=True) @@ -150,7 +147,6 @@ (OME-NGFF). The layout on the disk is: row/col/field/pyramid_level/timepoint/channel/z/y/x. -Notice that labelling of nuclei channel is not complete - some cells are not expressing the fluorescent protein. """ # %% @@ -194,7 +190,7 @@ # # ### Task 1.1 # -# Look at a couple different fields of view by changing the value in the cell above. See if you notice any missing or inconsistent staining. +# Look at a couple different fields of view by changing the value in the cell above. Check the cell density, the cell morphologies, and fluorescence signal. #
# %% [markdown] @@ -314,25 +310,26 @@ def log_batch_jupyter(batch): # %% - # Initialize the data module. -BATCH_SIZE = 4 -# 42 is a perfectly reasonable batch size. After all, it is the answer to the ultimate question of life, the universe and everything. +BATCH_SIZE = 6 + +# 6 is a perfectly reasonable batch size. After all, it is the answer to the ultimate question of life, the universe and everything. # More seriously, batch size does not have to be a power of 2. # See: https://sebastianraschka.com/blog/2022/batch-size-2.html data_module = HCSDataModule( data_path, - source_channel="Phase", - target_channel=["Membrane", "Nuclei"], z_window_size=1, + architecture='fcmae', + source_channel=["Phase3D"], + target_channel=['Nucl','Mem'], split_ratio=0.8, batch_size=BATCH_SIZE, num_workers=8, - architecture="2D", - yx_patch_size=(512, 512), # larger patch size makes it easy to see augmentations. - augmentations=None, # Turn off augmentation for now. + yx_patch_size=(256, 256), # larger patch size makes it easy to see augmentations. + augmentations=[], # Turn off augmentation for now. + normalizations=[], #Turn off normalization for now. ) data_module.setup("fit") @@ -363,25 +360,45 @@ def log_batch_jupyter(batch): """ # %% # Here we turn on data augmentation and rerun setup +source_channel= ['Phase3D'] +target_channel=['Nucl','Mem'] + augmentations = [ RandWeightedCropd( - keys=["Phase", "Membrane", "Nuclei"], w_key="Nuclei", spatial_size=[512, 512] + keys=source_channel + target_channel, + spatial_size=(1, 384, 384), + num_samples=2, + w_key=target_channel[0], ), RandAffined( - keys=["Phase", "Membrane", "Nuclei"], - prob=0.5, + keys=source_channel + target_channel, rotate_range=[3.14, 0.0, 0.0], - shear_range=[0.0, 0.05, 0.05], scale_range=[0.0, 0.3, 0.3], + prob=0.8, + padding_mode="zeros", + shear_range=[0.0, 0.01, 0.01], ), - RandAdjustContrastd(keys=["Phase"], prob=0.3, gamma=[0.5, 1.5]), - RandScaleIntensityd(keys=["Phase"], prob=0.5, factors=0.5), - RandGaussianNoised(keys=["Phase"], prob=0.5, std=1), + RandAdjustContrastd(keys=source_channel, prob=0.5, gamma=(0.8, 1.2)), + RandScaleIntensityd(keys=source_channel, factors=0.5, prob=0.5), + RandGaussianNoised(keys=source_channel, prob=0.5, mean=0.0, std=0.3), RandGaussianSmoothd( - keys=["Phase"], prob=0.5, sigma_x=[0.25, 1.5], sigma_y=[0.25, 1.5] + keys=source_channel, + sigma_x=(0.25, 0.75), + sigma_y=(0.25, 0.75), + sigma_z=(0.0, 0.0), + prob=0.5, ), ] +normalizations=[ + NormalizeSampled( + keys=source_channel+target_channel, + level="fov_statistics", + subtrahend="median", + divisor="std", + ) +] + data_module.augmentations = augmentations data_module.setup("fit") @@ -405,7 +422,8 @@ def log_batch_jupyter(batch): # # ### Task 1.3 # Can you tell what augmentation were applied from looking at the augmented images in Tensorboard? -# +# Why are these augmentations important? How do can they make the model more robust? + # Check your answer using the source code [here](https://github.com/mehta-lab/VisCy/blob/b89f778b34735553cf155904eef134c756708ff2/viscy/light/data.py#L529). # @@ -420,27 +438,51 @@ def log_batch_jupyter(batch): # Create a 2D UNet. GPU_ID = 0 BATCH_SIZE = 10 -YX_PATCH_SIZE = (512, 512) +YX_PATCH_SIZE = (256, 256) # Dictionary that specifies key parameters of the model. -phase2fluor_config = { - "architecture": "2D", - "num_filters": [24, 48, 96, 192, 384], - "in_channels": 1, - "out_channels": 2, - "residual": True, - "dropout": 0.1, # dropout randomly turns off weights to avoid overfitting of the model to data. - "task": "reg", # reg = regression task. -} +# phase2fluor_config=dict( +# in_channels=1, +# out_channels=2, +# encoder_blocks=[3, 3, 9, 3], +# dims=[96, 192, 384, 768], +# decoder_conv_blocks=2, +# stem_kernel_size=(1, 2, 2), +# in_stack_depth=1, +# pretraining=False, +# ) + +# phase2fluor_model = VSUNet( +# architecture='fcmae', +# model_config=phase2fluor_config.copy(), +# loss_function=MixedLoss(l1_alpha=0.5, l2_alpha=0.0, ms_dssim_alpha=0.5), +# schedule="WarmupCosine", +# lr=2e-4, +# log_batches_per_epoch=5, # Number of samples from each batch to log to tensorboard. +# freeze_encoder=False, +# ) +phase2fluor_config=dict( + in_channels=1, + out_channels=2, + in_stack_depth=1, + backbone="convnextv2_tiny", + pretrained=False, + stem_kernel_size=[1, 4, 4], + decoder_mode="pixelshuffle", + decoder_conv_blocks=2, + head_pool=True, + head_expansion_ratio=4, + drop_path_rate=0.0, + ) phase2fluor_model = VSUNet( + architecture='UNeXt2', model_config=phase2fluor_config.copy(), - batch_size=BATCH_SIZE, - loss_function=torch.nn.functional.l1_loss, + loss_function=MixedLoss(l1_alpha=0.5, l2_alpha=0.0, ms_dssim_alpha=0.5), schedule="WarmupCosine", - log_num_samples=5, # Number of samples from each batch to log to tensorboard. - example_input_yx_shape=YX_PATCH_SIZE, + lr=2e-4, + log_batches_per_epoch=5, # Number of samples from each batch to log to tensorboard. ) @@ -452,15 +494,16 @@ def log_batch_jupyter(batch): # Setup the data module. phase2fluor_data = HCSDataModule( data_path, - source_channel="Phase", - target_channel=["Membrane", "Nuclei"], + architecture="UNeXt2", + source_channel=["Phase3D"], + target_channel=["Nucl", "Mem"], z_window_size=1, split_ratio=0.8, batch_size=BATCH_SIZE, num_workers=8, - architecture="2D", yx_patch_size=YX_PATCH_SIZE, augmentations=augmentations, + normalizations=normalizations ) phase2fluor_data.setup("fit") # fast_dev_run runs a single batch of data through the model to check for errors. @@ -486,9 +529,11 @@ def log_batch_jupyter(batch): # visualize graph of phase2fluor model as image. model_graph_phase2fluor = torchview.draw_graph( phase2fluor_model, - phase2fluor_data.train_dataset[0]["source"], - depth=2, # adjust depth to zoom in. + phase2fluor_data.train_dataset[0]["source"][0].unsqueeze(dim=0), + roll=True, + depth=3, # adjust depth to zoom in. device="cpu", + expand_nested=True, ) # Print the image of the model. model_graph_phase2fluor.visual_graph @@ -575,18 +620,20 @@ def log_batch_jupyter(batch): # - Structural similarity: # %% Compute metrics directly and plot here. -test_data_path = Path( - "~/data/04_image_translation/HEK_nuclei_membrane_test.zarr" -).expanduser() +# test_data_path = Path( +# "~/data/06_image_translation/HEK_nuclei_membrane_test.zarr" +# ).expanduser() +#TODO: replace this path with relative_test_dataset uncommenting above. +test_data_path = Path("/hpc/projects/comp.micro/virtual_staining/datasets/test/cell_types_20x/a549_sliced/a549_hoechst_cellmask_test.zarr") test_data = HCSDataModule( test_data_path, - source_channel="Phase", - target_channel=["Membrane", "Nuclei"], + source_channel=source_channel, + target_channel=target_channel, z_window_size=1, batch_size=1, num_workers=8, - architecture="2D", + architecture="UNeXt2", ) test_data.setup("test") @@ -633,7 +680,37 @@ def min_max_scale(input): column=["pearson_nuc", "SSIM_nuc", "pearson_mem", "SSIM_mem"], rot=30, ) - +#%% +#Plot the predicted image +channel_titles=['Phase','VS Nuclei','VS Membrane'] +fig,axes = plt.subplots(1,3,figsize=(12,4)) + +channel_image=phase_image[0,0] +p_low, p_high = np.percentile(channel_image, (0.5, 99.5)) +channel_image = np.clip(channel_image, p_low, p_high) +axes[0].imshow(channel_image) +axes[0].axis("off") +axes[0].set_title(channel_titles[0]) + +for i in range(predicted_image.shape[-4]): + # Adjust contrast to 0.5th and 99.5th percentile of pixel values. + channel_image = predicted_image[i,0] + p_low, p_high = np.percentile(channel_image, (0.5, 99.5)) + channel_image = np.clip(channel_image, 0, p_high) + axes[i+1].imshow(channel_image, cmap="gray") + axes[i+1].axis("off") + axes[i+1].set_title(channel_titles[i+1]) +plt.tight_layout() +#%% +#Plot the target images +fig,axes = plt.subplots(1,2,figsize=(10,10)) +for i in range(target_image.shape[-4]): + channel_image=target_image[i,0] + p_low, p_high = np.percentile(channel_image, (0.5, 99.5)) + channel_image = np.clip(channel_image, p_low, p_high) + axes[i].imshow(channel_image, cmap="gray") + axes[i].axis("off") + axes[i].set_title(dataset.channel_names[i]) # %% [markdown] tags=[] """ @@ -685,42 +762,86 @@ def min_max_scale(input): ########################## # The entire training loop is contained in this cell. +source_channel=['Mem'] # or 'Nuc' depending on choice +target_channel = ["Phase3D"] + +#Setup the new augmentations +augmentations = [ + RandWeightedCropd( + keys=source_channel + target_channel, + spatial_size=(1, 384, 384), + num_samples=2, + w_key=target_channel[0], + ), + RandAffined( + keys=source_channel + target_channel, + rotate_range=[3.14, 0.0, 0.0], + scale_range=[0.0, 0.3, 0.3], + prob=0.8, + padding_mode="zeros", + shear_range=[0.0, 0.01, 0.01], + ), + RandAdjustContrastd(keys=source_channel, prob=0.5, gamma=(0.8, 1.2)), + RandScaleIntensityd(keys=source_channel, factors=0.5, prob=0.5), + RandGaussianNoised(keys=source_channel, prob=0.5, mean=0.0, std=0.3), + RandGaussianSmoothd( + keys=source_channel, + sigma_x=(0.25, 0.75), + sigma_y=(0.25, 0.75), + sigma_z=(0.0, 0.0), + prob=0.5, + ), +] +normalizations=[ + NormalizeSampled( + keys=source_channel+target_channel, + level="fov_statistics", + subtrahend="median", + divisor="std", + ) +] + +#Setup the dataloader fluor2phase_data = HCSDataModule( data_path, - source_channel="Membrane", - target_channel="Phase", + architecture="UNeXt2", + source_channel=source_channel, + target_channel=target_channel, z_window_size=1, split_ratio=0.8, batch_size=BATCH_SIZE, num_workers=8, - architecture="2D", yx_patch_size=YX_PATCH_SIZE, augmentations=augmentations, + normalizations=normalizations ) fluor2phase_data.setup("fit") # Dictionary that specifies key parameters of the model. -fluor2phase_config = { - "architecture": "2D", - "in_channels": 1, - "out_channels": 1, - "residual": True, - "dropout": 0.1, # dropout randomly turns off weights to avoid overfitting of the model to data. - "task": "reg", # reg = regression task. - "num_filters": [24, 48, 96, 192, 384], -} +fluor2phase_config = dict( + in_channels=1, + out_channels=2, + in_stack_depth=1, + backbone="convnextv2_tiny", + pretrained=False, + stem_kernel_size=[1, 4, 4], + decoder_mode="pixelshuffle", + decoder_conv_blocks=2, + head_pool=True, + head_expansion_ratio=4, + drop_path_rate=0.0, + ) fluor2phase_model = VSUNet( + architecture='UNeXt2', model_config=fluor2phase_config.copy(), - batch_size=BATCH_SIZE, - loss_function=torch.nn.functional.mse_loss, + loss_function=MixedLoss(l1_alpha=0.5, l2_alpha=0.0, ms_dssim_alpha=0.5), schedule="WarmupCosine", - log_num_samples=5, - example_input_yx_shape=YX_PATCH_SIZE, + lr=2e-4, + log_batches_per_epoch=5, # Number of samples from each batch to log to tensorboard. ) - trainer = VSTrainer( accelerator="gpu", devices=[GPU_ID], @@ -735,7 +856,7 @@ def min_max_scale(input): ) trainer.fit(fluor2phase_model, datamodule=fluor2phase_data) - +#%% # Visualize the graph of fluor2phase model as image. model_graph_fluor2phase = torchview.draw_graph( fluor2phase_model, @@ -759,17 +880,17 @@ def min_max_scale(input): """ # %% test_data_path = Path( - "~/data/04_image_translation/HEK_nuclei_membrane_test.zarr" + "~/data/06_image_translation/HEK_nuclei_membrane_test.zarr" ).expanduser() test_data = HCSDataModule( test_data_path, - source_channel="Nuclei", # or Membrane, depending on your choice of source + source_channel="Mem", # or Nuc, depending on your choice of source target_channel="Phase", z_window_size=1, batch_size=1, num_workers=8, - architecture="2D", + architecture="UNeXt2", ) test_data.setup("test") @@ -816,7 +937,7 @@ def min_max_scale(input):
## Checkpoint 2 -When your model finishes training, please summarize hyperparameters and performance of your models in the [this google doc](https://docs.google.com/document/d/1hZWSVRvt9KJEdYu7ib-vFBqAVQRYL8cWaP_vFznu7D8/edit#heading=h.n5u485pmzv2z) +When your model finishes training, please summarize hyperparameters and performance of your models in the [this google doc](https://docs.google.com/document/d/1Mq-yV8FTG02xE46Mii2vzPJVYSRNdeOXkeU-EKu-irE/edit?usp=sharing)
""" @@ -842,7 +963,7 @@ def min_max_scale(input): - Use training loop illustrated in previous cells to train phase2fluor and fluor2phase models to prototype your own training loop. - Add code to evaluate the model using Pearson Correlation and SSIM -As your model is training, please document hyperparameters, snapshots of predictions on validation set, and loss curves for your models in [this google doc](https://docs.google.com/document/d/1hZWSVRvt9KJEdYu7ib-vFBqAVQRYL8cWaP_vFznu7D8/edit#heading=h.n5u485pmzv2z) +As your model is training, please document hyperparameters, snapshots of predictions on validation set, and loss curves for your models in [this google doc](https://docs.google.com/document/d/1Mq-yV8FTG02xE46Mii2vzPJVYSRNdeOXkeU-EKu-irE/edit?usp=sharing) """ @@ -886,25 +1007,27 @@ def min_max_scale(input): ########################## ######## Solution ######## ########################## +phase2fluor_config=dict( + in_channels=1, + out_channels=2, + in_stack_depth=1, + backbone="convnextv2_base", + pretrained=False, + stem_kernel_size=[1, 4, 4], + decoder_mode="pixelshuffle", + decoder_conv_blocks=2, + head_pool=True, + head_expansion_ratio=4, + drop_path_rate=0.0, + ) -phase2fluor_wider_config = { - "architecture": "2D", - # double the number of filters at each stage - "num_filters": [48, 96, 192, 384, 768], - "in_channels": 1, - "out_channels": 2, - "residual": True, - "dropout": 0.1, - "task": "reg", -} - -phase2fluor_wider_model = VSUNet( - model_config=phase2fluor_wider_config.copy(), - batch_size=BATCH_SIZE, - loss_function=torch.nn.functional.l1_loss, +phase2fluor_model = VSUNet( + architecture='UNeXt2', + model_config=phase2fluor_config.copy(), + loss_function=MixedLoss(l1_alpha=0.5, l2_alpha=0.0, ms_dssim_alpha=0.5), schedule="WarmupCosine", - log_num_samples=5, - example_input_yx_shape=YX_PATCH_SIZE, + lr=2e-4, + log_batches_per_epoch=5, # Number of samples from each batch to log to tensorboard. ) @@ -916,28 +1039,26 @@ def min_max_scale(input): logger=TensorBoardLogger( save_dir=log_dir, name="phase2fluor", - version="wider", + version="base_model", log_graph=True, ), fast_dev_run=True, ) # Set fast_dev_run to False to train the model. -trainer.fit(phase2fluor_wider_model, datamodule=phase2fluor_data) +trainer.fit(phase2fluor_model, datamodule=phase2fluor_data) # %% tags=["solution"] ########################## ######## Solution ######## ########################## - phase2fluor_slow_model = VSUNet( + architecture='UNeXt2', model_config=phase2fluor_config.copy(), - batch_size=BATCH_SIZE, - loss_function=torch.nn.functional.l1_loss, - # lower learning rate by 5 times - lr=2e-4, + loss_function=MixedLoss(l1_alpha=0.5, l2_alpha=0.0, ms_dssim_alpha=0.5), + # lower learning rate by 10 times + lr=2e-5, schedule="WarmupCosine", - log_num_samples=5, - example_input_yx_shape=YX_PATCH_SIZE, + log_batches_per_epoch=5, ) trainer = VSTrainer( @@ -963,6 +1084,8 @@ def min_max_scale(input): ## Checkpoint 3 Congratulations! You have trained several image translation models now! -Please document hyperparameters, snapshots of predictions on validation set, and loss curves for your models and add the final perforance in [this google doc](https://docs.google.com/document/d/1hZWSVRvt9KJEdYu7ib-vFBqAVQRYL8cWaP_vFznu7D8/edit#heading=h.n5u485pmzv2z). We'll discuss our combined results as a group. +Please document hyperparameters, snapshots of predictions on validation set, and loss curves for your models and add the final perforance in [this google doc](https://docs.google.com/document/d/1Mq-yV8FTG02xE46Mii2vzPJVYSRNdeOXkeU-EKu-irE/edit?usp=sharing). We'll discuss our combined results as a group. """ + +# %% From c07136492abcf85b6ab638feda9939f756b92916 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 4 Jul 2024 19:10:02 -0700 Subject: [PATCH 04/42] editing solution --- examples/demo_dlmbl/README.md | 14 +- examples/demo_dlmbl/solution.py | 367 ++++++++++-------- .../demos/VSCyto2d_a549cells/demo_vscyto2d.py | 113 ------ 3 files changed, 204 insertions(+), 290 deletions(-) delete mode 100644 examples/demos/VSCyto2d_a549cells/demo_vscyto2d.py diff --git a/examples/demo_dlmbl/README.md b/examples/demo_dlmbl/README.md index eb7917c8..d7a4e9f6 100644 --- a/examples/demo_dlmbl/README.md +++ b/examples/demo_dlmbl/README.md @@ -1,6 +1,6 @@ -# Exercise 4: Image translation +# Exercise 6: Image translation - Part 1 -This demo script was developed for the DL@MBL 2023 course by Ziwen Liu and Shalin Mehta, with many inputs and bugfixes by [Morgan Schwartz](https://github.com/msschwartz21), [Caroline Malin-Mayor](https://github.com/cmalinmayor), and [Peter Park](https://github.com/peterhpark). +This demo script was developed for the DL@MBL 2023 course by Eduardo Hirata-Miyasaki, Ziwen Liu and Shalin Mehta, with many inputs and bugfixes by [Morgan Schwartz](https://github.com/msschwartz21), [Caroline Malin-Mayor](https://github.com/cmalinmayor), and [Peter Park](https://github.com/peterhpark). @@ -9,10 +9,10 @@ This demo script was developed for the DL@MBL 2023 course by Ziwen Liu and Shali Make sure that you are inside of the `image_translation` folder by using the `cd` command to change directories if needed. -Make sure that you can use mamba to switch environments. +Make sure that you can use conda to switch environments. ```bash -mamba init +conda init ``` **Close your shell, and login again.** @@ -23,7 +23,7 @@ sh setup.sh ``` Activate your environment ```bash -mamba activate 04_image_translation +conda activate 06_image_translation ``` ## Use vscode @@ -42,7 +42,7 @@ jupyter notebook ...and continue with the instructions in the notebook. -If 04_image_translation is not available as a kernel in jupyter, run +If 06_image_translation is not available as a kernel in jupyter, run ``` -python -m ipykernel install --user --name=04_image_translation +python -m ipykernel install --user --name=06_image_translation ``` diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index fac6ce1b..701c0032 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -1,55 +1,76 @@ # %% [markdown] """ -# Image translation (Virtual Staining) ---- - -### Overview -In this exercise, we will solve an image translation task to predict fluorescence images of nuclei and membrane markers from quantitative phase images of cells. In other words, we will _virtually stain_ the nuclei and membrane visible in the phase image. -Additionally, we will apply the inverse process of predicting a phase image from a fluorescence membrane label. +
+

Image translation (Virtual Staining)

+
+

Overview

+

In this exercise, we will solve an image translation task to predict fluorescence images of nuclei and membrane markers from quantitative phase images of cells. In other words, we will virtually stain the nuclei and cell membrane visible in the phase image. We will apply a series of spatial and intensity augmentations to train robust models and evaluate their performance. Finally, we will apply the inverse process of predicting a phase image from a fluorescence membrane label.

+ +
+

+
+ + Virtual Staining + +
(click image to play)
+
+
+
+""" +# %% [markdown] +""" ### Goals -- Here, the source domain is label-free microscopy (material density) and the target domain is fluorescence microscopy (fluorophore density). -- The goal is to learn a mapping from the source domain to the target domain. We will use a _purely convolutional architecture_ that draws on the design principles of transformer models. -- Here we will use a UNeXt2, an efficient image translation architecture inspired by ConvNeXt v2, SparK. -- We will perform the preprocessing, training, prediction, evaluation, and deployment steps that are unified in a computer vision pipeline for single-cell analysis in our pipeline called [VisCy](https://github.com/mehta-lab/VisCy). -- We will train a 2D image translation model using a 2D-Unet with residual connections. We will use a dataset of 34 FOVs of adenocarcinomic human alveolar basal epithelial cells (A549), each FOV has 3 channels (phase, nuclei, and cell membrane). The nuclei were stained with DAPI and the cell membrane with Cellmask.![](https://github.com/mehta-lab/VisCy/blob/main/docs/figures/svideo_1.png) -- The last exercise is trying out the inverse going from fluorescence to label-free. -![](https://github.com/mehta-lab/VisCy/blob/main/docs/figures/svideo_1.png) +**Part 1: Familiarization with iohub (I/O library), VisCy dataloaders, and tensorboard.** -📖 As you work through parts 2 and 3, please share the layouts of your models (output of torchview) and their performance with everyone via [this google doc](https://docs.google.com/document/d/1Mq-yV8FTG02xE46Mii2vzPJVYSRNdeOXkeU-EKu-irE/edit?usp=sharing) 📖. + - Use a `ome-zarr` dataset of 34 FOVs of adenocarcinomic human alveolar basal epithelial cells (A549), each FOV has 3 channels (phase, nuclei, and cell membrane). The nuclei were stained with DAPI and the cell membrane with Cellmask. + - Explore the OME-zarr using [iohub](https://czbiohub-sf.github.io/iohub/main/index.html) and the high-content-screen (HCS) format. + - Use [MONAI](https://monai.io/) to implement data augmentations. -Our guesstimate is that each of the three parts will take ~1.5 hours. A reasonable 2D UNet can be trained in ~30 min on a typical AWS node. +**Part 2: Train a phase to fluorescence model using a UNeXt2 and evaluate it and vice versa.** -The focus of the exercise is on understanding information content of the data, how to train and evaluate 2D image translation model, and explore some hyperparameters of the model. If you complete this exercise and have time to spare, try the bonus exercise on 3D image translation. + - Create a model for image translation mapping from source domain to target domain where the source domain is label-free microscopy (material density) and the target domain is fluorescence microscopy (fluorophore density). + - Use the UNeXt2 architecture, a _purely convolutional architecture_ that draws on the design principles of transformer models to complete this task. Here we will use a *UNeXt2*, an efficient image translation architecture inspired by ConvNeXt v2, SparK. + - We will perform the preprocessing, training, prediction, evaluation, and deployment steps that borrow from our computer vision pipeline for single-cell analysis in our pipeline called [VisCy](https://github.com/mehta-lab/VisCy). + - Reuse the same architecture as above and create a similar model doing the inverse task (fluorescence to phase). + - Evaluate the model. +**(Extra) Play with the hyperparameters to improve the models or train a 3D UNeXt2** -Checkout VisCy!! Our deep learning pipeline for training and deploying computer vision models for image-based phenotyping in including the robust virtual staining of landmark organelles. -VisCy exploits recent advances in the data and metadata formats ([OME-zarr](https://www.nature.com/articles/s41592-021-01326-w)) and DL frameworks, [PyTorch Lightning](https://lightning.ai/) and [MONAI](https://monai.io/). +Our guesstimate is that each of the three parts will take ~1-1.5 hours. A reasonable 2D UNet can be trained in ~30 min on a typical AWS node. +The focus of the exercise is on understanding the information content of the data, how to train and evaluate 2D image translation models, and exploring some hyperparameters of the model. If you complete this exercise and have time to spare, try the bonus exercise on 3D image translation. + +Checkout [VisCy](https://github.com/mehta-lab/VisCy/tree/main/examples/demos), our deep learning pipeline for training and deploying computer vision models for image-based phenotyping including the robust virtual staining of landmark organelles. VisCy exploits recent advances in data and metadata formats ([OME-zarr](https://www.nature.com/articles/s41592-021-01326-w)) and DL frameworks, [PyTorch Lightning](https://lightning.ai/) and [MONAI](https://monai.io/). ## References --- -- [Liu,Z. and Hirata-Miyasaki,E. et al.(2024) Robust Virtual Staining of Cellular Landmarks](https://www.biorxiv.org/content/10.1101/2024.05.31.596901v2.full.pdf) +- [Liu, Z. and Hirata-Miyasaki, E. et al. (2024) Robust Virtual Staining of Cellular Landmarks](https://www.biorxiv.org/content/10.1101/2024.05.31.596901v2.full.pdf) +- [Guo et al. (2020) Revealing architectural order with quantitative label-free imaging and deep learning. eLife](https://elifesciences.org/articles/55502) -- [Guo et al. (2020) Revealing architectural order with quantitative label-free imaging and deep learning -. eLife](https://elifesciences.org/articles/55502). +Written by Eduardo Hirata-Miyasaki, Ziwen Liu, and Shalin Mehta, CZ Biohub San Francisco. +""" -Written by Eduardo Hirata-Miyasaki, Ziwen Liu and Shalin Mehta, CZ Biohub San Francisco. +#%% [markdown] +""" +📖 As you work through parts 2 and 3, please share the layouts of your models (output of torchview) and their performance with everyone via [this Google Doc](https://docs.google.com/document/d/1Mq-yV8FTG02xE46Mii2vzPJVYSRNdeOXkeU-EKu-irE/edit?usp=sharing). 📖 """ # %% [markdown] """
The exercise is organized in 3 parts + Extra part. -* **Part 1** - Explore the data using tensorboard.Explore augmentations. Train a phase to fluroesence model. -* **Part 2** - Evaluate the training with tensorboard. Evaluate the trained model. -* **Part 3** - Train the fluorescence to phase model. -* **Extra task** - Tune the models to improve performance. +
    +
  • Part 1 - Familiarization with iohub (I/O library), VisCy dataloaders, and tensorboard.
  • +
  • Part 2 - Train and evaluate the phase to fluorescence model and viceversa.
  • +
  • Extra task - Tune the models to improve performance.
  • +
""" + # %% [markdown] """
@@ -63,10 +84,10 @@ Learning goals: -- Load the OME-zarr dataset and examine the channels. +- Load the OME-zarr dataset and examine the channels (A54). - Configure and understand the data loader. - Log some patches to tensorboard. -- Initialize a 2D U-Net model for virtual staining +- Initialize a 2D UNeXt2 model for virtual staining of nuclei and membrane from phase. - Start training the model to predict nuclei and membrane from phase. """ @@ -136,9 +157,9 @@ # %% [markdown] """ -## Load Dataset. +## Load OME-Zarr Dataset. -There should be 301 FOVs in the dataset (12 GB compressed). +There should be 34 FOVs in the dataset. Each FOV consists of 3 channels of 2048x2048 images, saved in the @@ -157,7 +178,7 @@ # Use the field and pyramid_level below to visualize data. row = 0 col = 0 -field = 23 # TODO: Change this to explore data. +field = 10 # TODO: Change this to explore data. # This dataset contains images at 3 resolutions. # '0' is the highest resolution @@ -190,7 +211,8 @@ # # ### Task 1.1 # -# Look at a couple different fields of view by changing the value in the cell above. Check the cell density, the cell morphologies, and fluorescence signal. +# Look at a couple different fields of view by changing the value in the cell above. +# Check the cell density, the cell morphologies, and fluorescence signal. #
# %% [markdown] @@ -355,9 +377,28 @@ def log_batch_jupyter(batch): log_batch_jupyter(batch) # %% [markdown] -""" -## View augmentations using tensorboard. -""" +#
+# +# ### Task 1.3 +# Add augmentations to the datamodule and rerun the setup. +# +# What kind of augmentations do you think are important for this task? +# +# How do they make the model more robust? +# +# Add augmentations to rotate about `pi` along z-axis, 30% scale in y,x, shearing of 10% and no padding with zeros with a probablity of 80%. +# +# Add a Gaussian noise with a mean of 0.0 and standard deviation of 0.3 with a probability of 50%. +# +# Hint: `RandAffined()` and `RandGaussianNoised()` from `viscy.transforms` [here](https://github.com/mehta-lab/VisCy/blob/main/viscy/transforms.py). +# *Note* these are MONAI transforms that have been redefined for VisCy. + +# Can you tell what augmentation were applied from looking at the augmented images in Tensorboard? + +# +# HINT:[Check your augmentations here](https://github.com/mehta-lab/VisCy/blob/b89f778b34735553cf155904eef134c756708ff2/viscy/light/data.py#L529). +#
+ # %% # Here we turn on data augmentation and rerun setup source_channel= ['Phase3D'] @@ -366,21 +407,13 @@ def log_batch_jupyter(batch): augmentations = [ RandWeightedCropd( keys=source_channel + target_channel, - spatial_size=(1, 384, 384), + spatial_size=(1, 256, 256), num_samples=2, w_key=target_channel[0], ), - RandAffined( - keys=source_channel + target_channel, - rotate_range=[3.14, 0.0, 0.0], - scale_range=[0.0, 0.3, 0.3], - prob=0.8, - padding_mode="zeros", - shear_range=[0.0, 0.01, 0.01], - ), + RandAdjustContrastd(keys=source_channel, prob=0.5, gamma=(0.8, 1.2)), RandScaleIntensityd(keys=source_channel, factors=0.5, prob=0.5), - RandGaussianNoised(keys=source_channel, prob=0.5, mean=0.0, std=0.3), RandGaussianSmoothd( keys=source_channel, sigma_x=(0.25, 0.75), @@ -388,6 +421,11 @@ def log_batch_jupyter(batch): sigma_z=(0.0, 0.0), prob=0.5, ), + ##TODO: Add rotation agumentations + ## Write code below + + ## TODO: Add Random Gaussian Noise + ## Write code below ] normalizations=[ @@ -417,84 +455,51 @@ def log_batch_jupyter(batch): # %% log_batch_jupyter(augmented_batch) -# %% [markdown] -#
-# -# ### Task 1.3 -# Can you tell what augmentation were applied from looking at the augmented images in Tensorboard? -# Why are these augmentations important? How do can they make the model more robust? - -# Check your answer using the source code [here](https://github.com/mehta-lab/VisCy/blob/b89f778b34735553cf155904eef134c756708ff2/viscy/light/data.py#L529). -#
- # %% [markdown] """ ## Train a 2D U-Net model to predict nuclei and membrane from phase. -### Construct a 2D U-Net +### Construct a 2D UNeXt2 using VisCy See ``viscy.unet.networks.Unet2D.Unet2d`` ([source code](https://github.com/mehta-lab/VisCy/blob/7c5e4c1d68e70163cf514d22c475da8ea7dc3a88/viscy/unet/networks/Unet2D.py#L7)) for configuration details. """ # %% # Create a 2D UNet. GPU_ID = 0 -BATCH_SIZE = 10 +BATCH_SIZE = 2 YX_PATCH_SIZE = (256, 256) # Dictionary that specifies key parameters of the model. -# phase2fluor_config=dict( -# in_channels=1, -# out_channels=2, -# encoder_blocks=[3, 3, 9, 3], -# dims=[96, 192, 384, 768], -# decoder_conv_blocks=2, -# stem_kernel_size=(1, 2, 2), -# in_stack_depth=1, -# pretraining=False, -# ) - -# phase2fluor_model = VSUNet( -# architecture='fcmae', -# model_config=phase2fluor_config.copy(), -# loss_function=MixedLoss(l1_alpha=0.5, l2_alpha=0.0, ms_dssim_alpha=0.5), -# schedule="WarmupCosine", -# lr=2e-4, -# log_batches_per_epoch=5, # Number of samples from each batch to log to tensorboard. -# freeze_encoder=False, -# ) phase2fluor_config=dict( - in_channels=1, - out_channels=2, - in_stack_depth=1, - backbone="convnextv2_tiny", - pretrained=False, - stem_kernel_size=[1, 4, 4], - decoder_mode="pixelshuffle", - decoder_conv_blocks=2, - head_pool=True, - head_expansion_ratio=4, - drop_path_rate=0.0, - ) + in_channels=1, + out_channels=2, + encoder_blocks=[3, 3, 9, 3], + dims=[96, 192, 384, 768], + decoder_conv_blocks=2, + stem_kernel_size=(1, 2, 2), + in_stack_depth=1, + pretraining=False, +) phase2fluor_model = VSUNet( - architecture='UNeXt2', + architecture='fcmae', #2D UNeXt2 architecture model_config=phase2fluor_config.copy(), loss_function=MixedLoss(l1_alpha=0.5, l2_alpha=0.0, ms_dssim_alpha=0.5), schedule="WarmupCosine", lr=2e-4, log_batches_per_epoch=5, # Number of samples from each batch to log to tensorboard. + freeze_encoder=False, ) - # %% [markdown] """ ### Instantiate data module and trainer, test that we are setup to launch training. """ # %% # Setup the data module. -phase2fluor_data = HCSDataModule( +phase2fluor_2D_data = HCSDataModule( data_path, - architecture="UNeXt2", + architecture="fcmae", source_channel=["Phase3D"], target_channel=["Nucl", "Mem"], z_window_size=1, @@ -505,12 +510,12 @@ def log_batch_jupyter(batch): augmentations=augmentations, normalizations=normalizations ) -phase2fluor_data.setup("fit") +phase2fluor_2D_data.setup("fit") # fast_dev_run runs a single batch of data through the model to check for errors. trainer = VSTrainer(accelerator="gpu", devices=[GPU_ID], fast_dev_run=True) # trainer class takes the model and the data module as inputs. -trainer.fit(phase2fluor_model, datamodule=phase2fluor_data) +trainer.fit(phase2fluor_model, datamodule=phase2fluor_2D_data) # %% [markdown] @@ -529,15 +534,16 @@ def log_batch_jupyter(batch): # visualize graph of phase2fluor model as image. model_graph_phase2fluor = torchview.draw_graph( phase2fluor_model, - phase2fluor_data.train_dataset[0]["source"][0].unsqueeze(dim=0), + phase2fluor_2D_data.train_dataset[0]["source"][0].unsqueeze(dim=0), roll=True, depth=3, # adjust depth to zoom in. device="cpu", - expand_nested=True, + # expand_nested=True, ) # Print the image of the model. model_graph_phase2fluor.visual_graph + # %% [markdown] """
@@ -547,11 +553,9 @@ def log_batch_jupyter(batch):
""" - # %% - GPU_ID = 0 -n_samples = len(phase2fluor_data.train_dataset) +n_samples = len(phase2fluor_2D_data.train_dataset) steps_per_epoch = n_samples // BATCH_SIZE # steps per epoch. n_epochs = 50 # Set this to 50 or the number of epochs you want to train for. @@ -569,7 +573,7 @@ def log_batch_jupyter(batch): ), ) # Launch training and check that loss and images are being logged on tensorboard. -trainer.fit(phase2fluor_model, datamodule=phase2fluor_data) +trainer.fit(phase2fluor_model, datamodule=phase2fluor_2D_data) # %% [markdown] """ @@ -579,6 +583,7 @@ def log_batch_jupyter(batch): Now the training has started, we can come back after a while and evaluate the performance! + """ @@ -682,36 +687,41 @@ def min_max_scale(input): ) #%% #Plot the predicted image -channel_titles=['Phase','VS Nuclei','VS Membrane'] -fig,axes = plt.subplots(1,3,figsize=(12,4)) +channel_titles = ['Phase', 'Nuclei', 'Membrane'] +fig, axes = plt.subplots(2, 3, figsize=(30, 20)) -channel_image=phase_image[0,0] +# Plot the phase image +channel_image = phase_image[0, 0] p_low, p_high = np.percentile(channel_image, (0.5, 99.5)) channel_image = np.clip(channel_image, p_low, p_high) -axes[0].imshow(channel_image) -axes[0].axis("off") -axes[0].set_title(channel_titles[0]) +axes[0, 0].imshow(channel_image) +axes[0, 0].axis("off") +axes[0, 0].set_title(channel_titles[0]) +# Plot the predicted images for i in range(predicted_image.shape[-4]): - # Adjust contrast to 0.5th and 99.5th percentile of pixel values. - channel_image = predicted_image[i,0] + channel_image = predicted_image[i, 0] p_low, p_high = np.percentile(channel_image, (0.5, 99.5)) - channel_image = np.clip(channel_image, 0, p_high) - axes[i+1].imshow(channel_image, cmap="gray") - axes[i+1].axis("off") - axes[i+1].set_title(channel_titles[i+1]) -plt.tight_layout() -#%% -#Plot the target images -fig,axes = plt.subplots(1,2,figsize=(10,10)) + channel_image = np.clip(channel_image, p_low, p_high) + axes[0, i + 1].imshow(channel_image, cmap="gray") + axes[0, i + 1].axis("off") + axes[0, i + 1].set_title(f"VS {channel_titles[i + 1]}") + +# Plot the target images for i in range(target_image.shape[-4]): - channel_image=target_image[i,0] + channel_image = target_image[i, 0] p_low, p_high = np.percentile(channel_image, (0.5, 99.5)) channel_image = np.clip(channel_image, p_low, p_high) - axes[i].imshow(channel_image, cmap="gray") - axes[i].axis("off") - axes[i].set_title(dataset.channel_names[i]) + axes[1, i+1].imshow(channel_image, cmap="gray") + axes[1, i+1].axis("off") + axes[1, i+1].set_title(f"Target {dataset.channel_names[i+1]}") +# Remove any unused subplots +for j in range(i + 1, 3): + fig.delaxes(axes[1, j]) + +plt.tight_layout() +plt.show() # %% [markdown] tags=[] """
@@ -805,7 +815,7 @@ def min_max_scale(input): #Setup the dataloader fluor2phase_data = HCSDataModule( data_path, - architecture="UNeXt2", + architecture="fcmae", source_channel=source_channel, target_channel=target_channel, z_window_size=1, @@ -819,27 +829,25 @@ def min_max_scale(input): fluor2phase_data.setup("fit") # Dictionary that specifies key parameters of the model. -fluor2phase_config = dict( - in_channels=1, - out_channels=2, - in_stack_depth=1, - backbone="convnextv2_tiny", - pretrained=False, - stem_kernel_size=[1, 4, 4], - decoder_mode="pixelshuffle", - decoder_conv_blocks=2, - head_pool=True, - head_expansion_ratio=4, - drop_path_rate=0.0, - ) +fluor2phase_config=dict( + in_channels=1, + out_channels=2, + encoder_blocks=[3, 3, 9, 3], + dims=[96, 192, 384, 768], + decoder_conv_blocks=2, + stem_kernel_size=(1, 2, 2), + in_stack_depth=1, + pretraining=False, +) fluor2phase_model = VSUNet( - architecture='UNeXt2', - model_config=fluor2phase_config.copy(), + architecture='fcmae', + model_config=phase2fluor_config.copy(), loss_function=MixedLoss(l1_alpha=0.5, l2_alpha=0.0, ms_dssim_alpha=0.5), schedule="WarmupCosine", lr=2e-4, log_batches_per_epoch=5, # Number of samples from each batch to log to tensorboard. + freeze_encoder=False, ) trainer = VSTrainer( @@ -936,18 +944,20 @@ def min_max_scale(input): """ """ -# %% tags=[] + +# %% tags=[] """ -# Part 3: Tune the models. +# (Extra)Tune the models and explore other architectures from [VisCy](https://github.com/mehta-lab/VisCy/tree/main/examples/demos) -------------------------------------------------- - -Learning goals: Understand how data, model capacity, and training parameters control the performance of the model. Your goal is to try to underfit or overfit the model. +Learning goals: +- Understand how data, model capacity, and training parameters control the performance of the model. Your goal is to try to underfit or overfit the model. +- How can we scale it up from 2D to 3D training and predictions? """ @@ -955,7 +965,7 @@ def min_max_scale(input): """
-### Task 3.1 +### Extra Part - Choose a model you want to train (phase2fluor or fluor2phase). - Set up a configuration that you think will improve the performance of the model @@ -1008,29 +1018,25 @@ def min_max_scale(input): ######## Solution ######## ########################## phase2fluor_config=dict( - in_channels=1, - out_channels=2, - in_stack_depth=1, - backbone="convnextv2_base", - pretrained=False, - stem_kernel_size=[1, 4, 4], - decoder_mode="pixelshuffle", - decoder_conv_blocks=2, - head_pool=True, - head_expansion_ratio=4, - drop_path_rate=0.0, - ) + in_channels=1, + out_channels=2, + encoder_blocks=[3, 3, 9, 3], + dims=[96, 192, 384, 768], + decoder_conv_blocks=2, + stem_kernel_size=(1, 2, 2), + in_stack_depth=1, + pretraining=False, +) -phase2fluor_model = VSUNet( - architecture='UNeXt2', +phase2fluor_model_low_lr= VSUNet( + architecture='fcmae', model_config=phase2fluor_config.copy(), - loss_function=MixedLoss(l1_alpha=0.5, l2_alpha=0.0, ms_dssim_alpha=0.5), + loss_function=MixedLoss(l1_alpha=0.5, l2_alpha=0.0, ms_dssim_alpha=0.5), #Changed the loss function to MixedLoss L1 and MS-SSIM schedule="WarmupCosine", - lr=2e-4, + lr=2e-5, #lower learning rate by factor of 10 log_batches_per_epoch=5, # Number of samples from each batch to log to tensorboard. ) - trainer = VSTrainer( accelerator="gpu", devices=[GPU_ID], @@ -1039,24 +1045,45 @@ def min_max_scale(input): logger=TensorBoardLogger( save_dir=log_dir, name="phase2fluor", - version="base_model", + version="phase2fluor_low_lr", log_graph=True, ), fast_dev_run=True, ) # Set fast_dev_run to False to train the model. -trainer.fit(phase2fluor_model, datamodule=phase2fluor_data) +trainer.fit(phase2fluor_model_low_lr, datamodule=phase2fluor_2D_data) # %% tags=["solution"] ########################## ######## Solution ######## ########################## -phase2fluor_slow_model = VSUNet( +data_path = Path() #TODO: Point to a 3D dataset (HEK, Neuromast) +BATCH_SIZE = 4 +YX_PATCH_SIZE = (384, 384) + +## For 3D training - VSCyto3D +source_channel=['Phase3D'] +target_channel=['Nucl','Mem'] + +phase2fluor_3D_data = HCSDataModule( + data_path, + architecture="UNeXt2", + source_channel=source_channel, + target_channel=target_channel, + z_window_size=5, + split_ratio=0.8, + batch_size=BATCH_SIZE, + num_workers=8, + yx_patch_size=YX_PATCH_SIZE, + augmentations=augmentations, + normalizations=normalizations +) + +phase2fluor_3D = VSUNet( architecture='UNeXt2', model_config=phase2fluor_config.copy(), loss_function=MixedLoss(l1_alpha=0.5, l2_alpha=0.0, ms_dssim_alpha=0.5), - # lower learning rate by 10 times - lr=2e-5, + lr=2e-4, schedule="WarmupCosine", log_batches_per_epoch=5, ) @@ -1069,22 +1096,22 @@ def min_max_scale(input): logger=TensorBoardLogger( save_dir=log_dir, name="phase2fluor", - version="low_lr", + version="3D_UNeXt2", log_graph=True, ), fast_dev_run=True, ) -trainer.fit(phase2fluor_slow_model, datamodule=phase2fluor_data) +trainer.fit(phase2fluor_3D, datamodule=phase2fluor_3D_data) # %% [markdown] tags=[] """
-## Checkpoint 3 +## 🎉 The end of the notebook 🎉 Congratulations! You have trained several image translation models now! -Please document hyperparameters, snapshots of predictions on validation set, and loss curves for your models and add the final perforance in [this google doc](https://docs.google.com/document/d/1Mq-yV8FTG02xE46Mii2vzPJVYSRNdeOXkeU-EKu-irE/edit?usp=sharing). We'll discuss our combined results as a group. +Please remember to document the hyperparameters, snapshots of predictions on validation set, and loss curves for your models and add the final perforance in [this google doc](https://docs.google.com/document/d/1Mq-yV8FTG02xE46Mii2vzPJVYSRNdeOXkeU-EKu-irE/edit?usp=sharing). We'll discuss our combined results as a group.
""" diff --git a/examples/demos/VSCyto2d_a549cells/demo_vscyto2d.py b/examples/demos/VSCyto2d_a549cells/demo_vscyto2d.py deleted file mode 100644 index 4eca98f0..00000000 --- a/examples/demos/VSCyto2d_a549cells/demo_vscyto2d.py +++ /dev/null @@ -1,113 +0,0 @@ -# %% [markdown] -""" -# 2D Virtual Staining of A549 Cells ---- -This example shows how to virtually stain A549 cells using the _VSCyto2D_ model. - -First we import the necessary libraries and set the random seed for reproducibility. -""" -# %% Imports and paths -from pathlib import Path - -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -import torch -import torchview -import torchvision -from iohub import open_ome_zarr -from lightning.pytorch import seed_everything - -# from rich.pretty import pprint #TODO: add pretty print(?) - -from napari.utils.notebook_display import nbscreenshot -import napari - -# %% Imports and paths -from viscy.data.hcs import HCSDataModule - -# Trainer class and UNet. -from viscy.light.engine import FcmaeUNet -from viscy.light.trainer import VSTrainer -from viscy.transforms import NormalizeSampled -from viscy.light.predict_writer import HCSPredictionWriter -from viscy.data.hcs import HCSDataModule - -# %% [markdown] -""" -## Prediction using the 2D U-Net model to predict nuclei and membrane from phase. - -### Construct a 2D U-Net -See ``viscy.unet.networks.Unet2D.Unet2d`` ([source code](https://github.com/mehta-lab/VisCy/blob/7c5e4c1d68e70163cf514d22c475da8ea7dc3a88/viscy/unet/networks/Unet2D.py#L7)) for configuration details. -""" - -# %% -input_data_path = "/hpc/projects/comp.micro/virtual_staining/datasets/test/cell_types_20x/a549_sliced/a549_hoechst_cellmask_test.zarr/0/0/0" -model_ckpt_path = "/hpc/projects/comp.micro/virtual_staining/models/hek-a549-bj5a-20x/lightning_logs/tiny-2x2-finetune-e2e-amp-hek-a549-bj5a-nucleus-membrane-400ep/checkpoints/last.ckpt" -output_path = "./test_a549_demo.zarr" - -# %% -# Create a the VSCyto2D - -GPU_ID = 0 -BATCH_SIZE = 10 -YX_PATCH_SIZE = (384, 384) -phase_channel_name = "Phase3D" - - -# %% -# Setup the data module. -data_module = HCSDataModule( - data_path=input_data_path, - source_channel=phase_channel_name, - target_channel=["Membrane", "Nuclei"], - z_window_size=1, - split_ratio=0.8, - batch_size=BATCH_SIZE, - num_workers=8, - architecture="2D", - yx_patch_size=YX_PATCH_SIZE, - normalizations=[ - NormalizeSampled( - [phase_channel_name], - level="fov_statistics", - subtrahend="median", - divisor="iqr", - ) - ], -) -data_module.prepare_data() -data_module.setup(stage="predict") -# %% -# Setup the model. -# Dictionary that specifies key parameters of the model. -config_VSCyto2D = { - "in_channels": 1, - "out_channels": 2, - "encoder_blocks": [3, 3, 9, 3], - "dims": [96, 192, 384, 768], - "decoder_conv_blocks": 2, - "stem_kernel_size": [1, 2, 2], - "in_stack_depth": 1, - "pretraining": False, -} - -model_VSCyto2D = FcmaeUNet.load_from_checkpoint( - model_ckpt_path, model_config=config_VSCyto2D -) -model_VSCyto2D.eval() - -# %% -trainer = VSTrainer( - accelerator="gpu", - callbacks=[HCSPredictionWriter(output_path)], -) - -# Start the predictions -trainer.predict( - model=model_VSCyto2D, - datamodule=data_module, - return_predictions=False, -) - -# %% From 3ea1df6756b852dd5f10ded497c7207b36e8692b Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Fri, 5 Jul 2024 19:26:40 -0700 Subject: [PATCH 05/42] modifying to relative paths based on setup.sh download paths. --- examples/demo_dlmbl/setup.sh | 33 +++++++++++-------- examples/demo_dlmbl/solution.py | 58 +++++++++++++++++++++++---------- 2 files changed, 59 insertions(+), 32 deletions(-) diff --git a/examples/demo_dlmbl/setup.sh b/examples/demo_dlmbl/setup.sh index e502ceee..1ad21d9c 100644 --- a/examples/demo_dlmbl/setup.sh +++ b/examples/demo_dlmbl/setup.sh @@ -2,31 +2,36 @@ START_DIR=$(pwd) -# Create mamba environment -mamba create -y --name 04_image_translation python=3.10 +# Create conda environment +conda create -y --name 06_image_translation python=3.10 # Install ipykernel in the environment. -mamba install -y ipykernel nbformat nbconvert black jupytext ipywidgets --name 04_image_translation +conda install -y ipykernel nbformat nbconvert black jupytext ipywidgets --name 06_image_translation # Specifying the environment explicitly. -# mamba activate sometimes doesn't work from within shell scripts. +# conda activate sometimes doesn't work from within shell scripts. # install viscy and its dependencies`s in the environment using pip. mkdir -p ~/code/ cd ~/code/ git clone https://github.com/mehta-lab/viscy.git cd viscy -git checkout 7c5e4c1d68e70163cf514d22c475da8ea7dc3a88 # Exercise is tested with this commit of viscy -# Find path to the environment - mamba activate doesn't work from within shell scripts. -ENV_PATH=$(conda info --envs | grep 04_image_translation | awk '{print $NF}') +git checkout main # Exercise is tested with this commit of viscy +# Find path to the environment - conda activate doesn't work from within shell scripts. +ENV_PATH=$(conda info --envs | grep 06_image_translation | awk '{print $NF}') $ENV_PATH/bin/pip install ".[metrics]" -# Create data directory -mkdir -p ~/data/04_image_translation -cd ~/data/04_image_translation -wget https://dl-at-mbl-2023-data.s3.us-east-2.amazonaws.com/DLMBL2023_image_translation_data_pyramid.tar.gz -wget https://dl-at-mbl-2023-data.s3.us-east-2.amazonaws.com/DLMBL2023_image_translation_test.tar.gz -tar -xzf DLMBL2023_image_translation_data_pyramid.tar.gz -tar -xzf DLMBL2023_image_translation_test.tar.gz +# Create the directory structure +mkdir -p ~/data/06_image_translation/training +mkdir -p ~/data/06_image_translation/test + +# Change to the target directory +cd ~/data/06_image_translation/training +# Download the Zarr dataset recursively (if the server supports it) +wget -m -np -nH --cut-dirs=4 -R "index.html*" "https://public.czbiohub.org/comp.micro/viscy/VSCyto2D/training/a549_hoechst_cellmask_train_val.zarr/" + +cd ~/data/06_image_translation/test +# Download the Zarr dataset recursively (if the server supports it) +wget -m -np -nH --cut-dirs=4 -R "index.html*" "https://public.czbiohub.org/comp.micro/viscy/VSCyto2D/test/a549_hoechst_cellmask_test.zarr/" # Change back to the starting directory cd $START_DIR diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index 701c0032..a13ca02d 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -101,6 +101,7 @@ import torchview import torchvision from iohub import open_ome_zarr +from iohub.reader import print_info from lightning.pytorch import seed_everything from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger from skimage import metrics # for metrics. @@ -130,12 +131,11 @@ seed_everything(42, workers=True) # Paths to data and log directory -# data_path = Path( -# Path("~/data/06_image_translation/HEK_nuclei_membrane_pyramid.zarr/") -# ).expanduser() +data_path = Path( + Path("~/data/06_image_translation/training/a549_hoechst_cellmask_train_val.zarr") +).expanduser() -data_path = Path("/hpc/projects/comp.micro/virtual_staining/datasets/training/a549-hoechst-cellmask-20x/a549_hoechst_cellmask_train_val.zarr") -log_dir = Path("~/mydata/tmp/06_image_translation/logs/").expanduser() +log_dir = Path("~/data/06_image_translation/logs/").expanduser() # Create log directory if needed, and launch tensorboard log_dir.mkdir(parents=True, exist_ok=True) @@ -170,15 +170,26 @@ The layout on the disk is: row/col/field/pyramid_level/timepoint/channel/z/y/x. """ -# %% -dataset = open_ome_zarr(data_path) +# %%[markdown] +""" +You can inspect the tree structure by using your terminal: +`iohub info -v ` -print(f"Number of positions: {len(list(dataset.positions()))}") +More info on the CLI: +`iohub info -h` to see the help menu. +""" +#%% +# This is the python function called by `iohub info` CLI command +print_info(data_path,verbose=True) + +# Open and inspect the dataset. +dataset = open_ome_zarr(data_path) +#%% # Use the field and pyramid_level below to visualize data. row = 0 col = 0 -field = 10 # TODO: Change this to explore data. +field = 9 # TODO: Change this to explore data. # This dataset contains images at 3 resolutions. # '0' is the highest resolution @@ -464,8 +475,8 @@ def log_batch_jupyter(batch): """ # %% # Create a 2D UNet. -GPU_ID = 0 -BATCH_SIZE = 2 +GPU_ID = 1 +BATCH_SIZE = 12 YX_PATCH_SIZE = (256, 256) @@ -554,7 +565,7 @@ def log_batch_jupyter(batch): """ # %% -GPU_ID = 0 +GPU_ID = 1 n_samples = len(phase2fluor_2D_data.train_dataset) steps_per_epoch = n_samples // BATCH_SIZE # steps per epoch. n_epochs = 50 # Set this to 50 or the number of epochs you want to train for. @@ -625,12 +636,10 @@ def log_batch_jupyter(batch): # - Structural similarity: # %% Compute metrics directly and plot here. -# test_data_path = Path( -# "~/data/06_image_translation/HEK_nuclei_membrane_test.zarr" -# ).expanduser() +test_data_path = Path( + "~/data/06_image_translation/test/a549_hoechst_cellmask_test.zarr" +).expanduser() -#TODO: replace this path with relative_test_dataset uncommenting above. -test_data_path = Path("/hpc/projects/comp.micro/virtual_staining/datasets/test/cell_types_20x/a549_sliced/a549_hoechst_cellmask_test.zarr") test_data = HCSDataModule( test_data_path, source_channel=source_channel, @@ -888,7 +897,7 @@ def min_max_scale(input): """ # %% test_data_path = Path( - "~/data/06_image_translation/HEK_nuclei_membrane_test.zarr" + "~/data/06_image_translation/test/a549_hoechst_cellmask_test.zarr" ).expanduser() test_data = HCSDataModule( @@ -1057,6 +1066,19 @@ def min_max_scale(input): ########################## ######## Solution ######## ########################## +""" +You can download the file and place it in the data folder. +https://public.czbiohub.org/comp.micro/viscy/VSCyto3D/train/raw-and-reconstructed.zarr/ + +You can run the following shell script: +``` +cd ~/data/hek3d/training +# Download the Zarr dataset recursively (if the server supports it) +wget -m -np -nH --cut-dirs=4 -R "index.html*" "https://public.czbiohub.org/comp.micro/viscy/VSCyto3D/train/raw-and-reconstructed.zarr/" +``` + +""" + data_path = Path() #TODO: Point to a 3D dataset (HEK, Neuromast) BATCH_SIZE = 4 YX_PATCH_SIZE = (384, 384) From a74662cdceb72b18f5b389d2a9f22218cd8bc8b3 Mon Sep 17 00:00:00 2001 From: Shalin Mehta Date: Sun, 7 Jul 2024 22:07:13 -0700 Subject: [PATCH 06/42] small edits --- examples/demo_dlmbl/solution.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index a13ca02d..0d9d65fd 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -4,7 +4,7 @@

Image translation (Virtual Staining)


Overview

-

In this exercise, we will solve an image translation task to predict fluorescence images of nuclei and membrane markers from quantitative phase images of cells. In other words, we will virtually stain the nuclei and cell membrane visible in the phase image. We will apply a series of spatial and intensity augmentations to train robust models and evaluate their performance. Finally, we will apply the inverse process of predicting a phase image from a fluorescence membrane label.

+

In this exercise, we will predict fluorescence images of nuclei and membrane markers from quantitative phase images of cells, i.e., we will virtually stain the nuclei and cell membrane visible in the phase image. This is an example of an image translation task. We will apply spatial and intensity augmentations to train robust models and evaluate their performance. Finally, we will explore the opposite process of predicting a phase image from a fluorescence membrane label.



@@ -22,13 +22,13 @@ """ ### Goals -**Part 1: Familiarization with iohub (I/O library), VisCy dataloaders, and tensorboard.** +#### Part 1: Familiarization with iohub (I/O library), VisCy dataloaders, and tensorboard. - Use a `ome-zarr` dataset of 34 FOVs of adenocarcinomic human alveolar basal epithelial cells (A549), each FOV has 3 channels (phase, nuclei, and cell membrane). The nuclei were stained with DAPI and the cell membrane with Cellmask. - Explore the OME-zarr using [iohub](https://czbiohub-sf.github.io/iohub/main/index.html) and the high-content-screen (HCS) format. - Use [MONAI](https://monai.io/) to implement data augmentations. - -**Part 2: Train a phase to fluorescence model using a UNeXt2 and evaluate it and vice versa.** + +#### Part 2: Train a model that predicts fluorescence from phase, and vice versa, using the UNeXt2 architecture. - Create a model for image translation mapping from source domain to target domain where the source domain is label-free microscopy (material density) and the target domain is fluorescence microscopy (fluorophore density). - Use the UNeXt2 architecture, a _purely convolutional architecture_ that draws on the design principles of transformer models to complete this task. Here we will use a *UNeXt2*, an efficient image translation architecture inspired by ConvNeXt v2, SparK. @@ -36,7 +36,7 @@ - Reuse the same architecture as above and create a similar model doing the inverse task (fluorescence to phase). - Evaluate the model. -**(Extra) Play with the hyperparameters to improve the models or train a 3D UNeXt2** +#### (Extra) Play with the hyperparameters to improve the models or train a 3D UNeXt2** Our guesstimate is that each of the three parts will take ~1-1.5 hours. A reasonable 2D UNet can be trained in ~30 min on a typical AWS node. The focus of the exercise is on understanding the information content of the data, how to train and evaluate 2D image translation models, and exploring some hyperparameters of the model. If you complete this exercise and have time to spare, try the bonus exercise on 3D image translation. @@ -48,6 +48,7 @@ - [Liu, Z. and Hirata-Miyasaki, E. et al. (2024) Robust Virtual Staining of Cellular Landmarks](https://www.biorxiv.org/content/10.1101/2024.05.31.596901v2.full.pdf) - [Guo et al. (2020) Revealing architectural order with quantitative label-free imaging and deep learning. eLife](https://elifesciences.org/articles/55502) + Written by Eduardo Hirata-Miyasaki, Ziwen Liu, and Shalin Mehta, CZ Biohub San Francisco. """ @@ -131,11 +132,9 @@ seed_everything(42, workers=True) # Paths to data and log directory -data_path = Path( - Path("~/data/06_image_translation/training/a549_hoechst_cellmask_train_val.zarr") -).expanduser() - -log_dir = Path("~/data/06_image_translation/logs/").expanduser() +top_dir = Path(f"/hpc/mydata/{os.environ['USER']},data/") +data_path = top_dir / "06_image_translation/training/a549_hoechst_cellmask_train_val.zarr" +log_dir = top_dir / "06_image_translation/logs/" # Create log directory if needed, and launch tensorboard log_dir.mkdir(parents=True, exist_ok=True) @@ -636,9 +635,7 @@ def log_batch_jupyter(batch): # - Structural similarity: # %% Compute metrics directly and plot here. -test_data_path = Path( - "~/data/06_image_translation/test/a549_hoechst_cellmask_test.zarr" -).expanduser() +test_data_path = top_dir / "06_image_translation/test/a549_hoechst_cellmask_test.zarr" test_data = HCSDataModule( test_data_path, From 0aa51bca12f1a3b917e0dce36c0b46cb156073a5 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 8 Jul 2024 09:37:50 -0700 Subject: [PATCH 07/42] fixing video and adding todo comment for students to switch the top_dir. --- examples/demo_dlmbl/solution.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index 0d9d65fd..96b2e4b8 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -9,7 +9,7 @@


- + Virtual Staining
(click image to play)
@@ -132,10 +132,13 @@ seed_everything(42, workers=True) # Paths to data and log directory -top_dir = Path(f"/hpc/mydata/{os.environ['USER']},data/") +top_dir = Path(f"/hpc/mydata/{os.environ['USER']},data/") #TODO: Change this to point to your data directory. data_path = top_dir / "06_image_translation/training/a549_hoechst_cellmask_train_val.zarr" log_dir = top_dir / "06_image_translation/logs/" +if not data_path.exists(): + raise FileNotFoundError(f"Data not found at {data_path}. Please check the top_dir and data_path variables.") +#%% # Create log directory if needed, and launch tensorboard log_dir.mkdir(parents=True, exist_ok=True) From add21bfaee6aec7bfc64847f39c1d553665d190e Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 8 Jul 2024 09:54:31 -0700 Subject: [PATCH 08/42] updating the readme and setup.sh --- examples/demo_dlmbl/README.md | 5 +++-- examples/demo_dlmbl/setup.sh | 8 ++++---- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/examples/demo_dlmbl/README.md b/examples/demo_dlmbl/README.md index d7a4e9f6..b16378e5 100644 --- a/examples/demo_dlmbl/README.md +++ b/examples/demo_dlmbl/README.md @@ -1,6 +1,6 @@ # Exercise 6: Image translation - Part 1 -This demo script was developed for the DL@MBL 2023 course by Eduardo Hirata-Miyasaki, Ziwen Liu and Shalin Mehta, with many inputs and bugfixes by [Morgan Schwartz](https://github.com/msschwartz21), [Caroline Malin-Mayor](https://github.com/cmalinmayor), and [Peter Park](https://github.com/peterhpark). +This demo script was developed for the DL@MBL 2024 course by Eduardo Hirata-Miyasaki, Ziwen Liu and Shalin Mehta, with many inputs and bugfixes by [Morgan Schwartz](https://github.com/msschwartz21), [Caroline Malin-Mayor](https://github.com/cmalinmayor), and [Peter Park](https://github.com/peterhpark). @@ -42,7 +42,8 @@ jupyter notebook ...and continue with the instructions in the notebook. -If 06_image_translation is not available as a kernel in jupyter, run +If `06_image_translation` is not available as a kernel in jupyter, run: + ``` python -m ipykernel install --user --name=06_image_translation ``` diff --git a/examples/demo_dlmbl/setup.sh b/examples/demo_dlmbl/setup.sh index 1ad21d9c..4b46f23f 100644 --- a/examples/demo_dlmbl/setup.sh +++ b/examples/demo_dlmbl/setup.sh @@ -15,7 +15,8 @@ mkdir -p ~/code/ cd ~/code/ git clone https://github.com/mehta-lab/viscy.git cd viscy -git checkout main # Exercise is tested with this commit of viscy +git checkout main #FIXME: change after merging this PR # Exercise is tested with this commit of viscy + # Find path to the environment - conda activate doesn't work from within shell scripts. ENV_PATH=$(conda info --envs | grep 06_image_translation | awk '{print $NF}') $ENV_PATH/bin/pip install ".[metrics]" @@ -26,11 +27,10 @@ mkdir -p ~/data/06_image_translation/test # Change to the target directory cd ~/data/06_image_translation/training -# Download the Zarr dataset recursively (if the server supports it) -wget -m -np -nH --cut-dirs=4 -R "index.html*" "https://public.czbiohub.org/comp.micro/viscy/VSCyto2D/training/a549_hoechst_cellmask_train_val.zarr/" +# Download the OME-Zarr dataset recursively +wget -m -np -nH --cut-dirs=4 -R "index.html*" "https://public.czbiohub.org/comp.micro/viscy/VSCyto2D/training/a549_hoechst_cellmask_train_val.zarr/" cd ~/data/06_image_translation/test -# Download the Zarr dataset recursively (if the server supports it) wget -m -np -nH --cut-dirs=4 -R "index.html*" "https://public.czbiohub.org/comp.micro/viscy/VSCyto2D/test/a549_hoechst_cellmask_test.zarr/" # Change back to the starting directory From 0b36a85a36c36e2b0a76ce580c30fa2c5aa3c254 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Mon, 8 Jul 2024 12:02:27 -0700 Subject: [PATCH 09/42] addressing some typos --- examples/demo_dlmbl/solution.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index 96b2e4b8..27d6c04f 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -85,7 +85,7 @@ Learning goals: -- Load the OME-zarr dataset and examine the channels (A54). +- Load the OME-zarr dataset and examine the channels (A549). - Configure and understand the data loader. - Log some patches to tensorboard. - Initialize a 2D UNeXt2 model for virtual staining of nuclei and membrane from phase. @@ -567,7 +567,13 @@ def log_batch_jupyter(batch): """ # %% -GPU_ID = 1 +# Check if GPU is available +if torch.cuda.is_available(): + # Get the GPU ID (you can change the logic to select the appropriate GPU if you have multiple) + GPU_ID = torch.cuda.current_device() +else: + raise ValueError("No GPU available") + n_samples = len(phase2fluor_2D_data.train_dataset) steps_per_epoch = n_samples // BATCH_SIZE # steps per epoch. n_epochs = 50 # Set this to 50 or the number of epochs you want to train for. From 496578d45a0a0dad28d4b0d968e9168f568d903a Mon Sep 17 00:00:00 2001 From: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com> Date: Tue, 9 Jul 2024 09:18:36 -0700 Subject: [PATCH 10/42] html to md (#106) --- examples/demo_dlmbl/solution.py | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index 27d6c04f..d43908ea 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -1,21 +1,21 @@ # %% [markdown] """ -
-

Image translation (Virtual Staining)

-
-

Overview

-

In this exercise, we will predict fluorescence images of nuclei and membrane markers from quantitative phase images of cells, i.e., we will virtually stain the nuclei and cell membrane visible in the phase image. This is an example of an image translation task. We will apply spatial and intensity augmentations to train robust models and evaluate their performance. Finally, we will explore the opposite process of predicting a phase image from a fluorescence membrane label.

- -
-

-
- - Virtual Staining - -
(click image to play)
-
-
-
+# Image translation (Virtual Staining) + +## Overview + +In this exercise, we will predict fluorescence images of +nuclei and plasma membrane markers from quantitative phase images of cells, +i.e., we will _virtually stain_ the nuclei and plasma membrane +visible in the phase image. +This is an example of an image translation task. +We will apply spatial and intensity augmentations to train robust models +and evaluate their performance. +Finally, we will explore the opposite process of predicting a phase image +from a fluorescence membrane label. + +[![HEK293T](https://raw.githubusercontent.com/mehta-lab/VisCy/main/docs/figures/svideo_1.png)](https://github.com/mehta-lab/VisCy/assets/67518483/d53a81eb-eb37-44f3-b522-8bd7bddc7755) +(Click on image to play video) """ # %% [markdown] From 5189378815dcd59465e4e79d7f1bd6e18ffa216a Mon Sep 17 00:00:00 2001 From: Shalin Mehta Date: Tue, 9 Jul 2024 15:38:28 -0700 Subject: [PATCH 11/42] typographical --- examples/demo_dlmbl/solution.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index d43908ea..cf505130 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -64,8 +64,8 @@ The exercise is organized in 3 parts + Extra part.
    -
  • Part 1 - Familiarization with iohub (I/O library), VisCy dataloaders, and tensorboard.
  • -
  • Part 2 - Train and evaluate the phase to fluorescence model and viceversa.
  • +
  • Part 1 - Learn to use iohub (I/O library), VisCy dataloaders, and tensorboard.
  • +
  • Part 2 - Train and evaluate the model to translate phase into fluorescence, and viceversa.
  • Extra task - Tune the models to improve performance.
@@ -82,7 +82,6 @@ """ # Part 1: Log training data to tensorboard, start training a model. --------- - Learning goals: - Load the OME-zarr dataset and examine the channels (A549). @@ -132,7 +131,7 @@ seed_everything(42, workers=True) # Paths to data and log directory -top_dir = Path(f"/hpc/mydata/{os.environ['USER']},data/") #TODO: Change this to point to your data directory. +top_dir = Path(f"/hpc/mydata/{os.environ['USER']}/data/") #TODO: Change this to point to your data directory. data_path = top_dir / "06_image_translation/training/a549_hoechst_cellmask_train_val.zarr" log_dir = top_dir / "06_image_translation/logs/" From 091e9d2880656cfcc8c35ecb448349e4b093afdf Mon Sep 17 00:00:00 2001 From: Shalin Mehta Date: Tue, 9 Jul 2024 16:31:01 -0700 Subject: [PATCH 12/42] nits --- examples/demo_dlmbl/solution.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index cf505130..dfb2e10d 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -22,7 +22,7 @@ """ ### Goals -#### Part 1: Familiarization with iohub (I/O library), VisCy dataloaders, and tensorboard. +#### Part 1: Learn to use iohub (I/O library), VisCy dataloaders, and tensorboard. - Use a `ome-zarr` dataset of 34 FOVs of adenocarcinomic human alveolar basal epithelial cells (A549), each FOV has 3 channels (phase, nuclei, and cell membrane). The nuclei were stained with DAPI and the cell membrane with Cellmask. - Explore the OME-zarr using [iohub](https://czbiohub-sf.github.io/iohub/main/index.html) and the high-content-screen (HCS) format. @@ -65,7 +65,7 @@
  • Part 1 - Learn to use iohub (I/O library), VisCy dataloaders, and tensorboard.
  • -
  • Part 2 - Train and evaluate the model to translate phase into fluorescence, and viceversa.
  • +
  • Part 2 - Train and evaluate the model to translate phase into fluorescence, and vice versa.
  • Extra task - Tune the models to improve performance.
@@ -476,7 +476,12 @@ def log_batch_jupyter(batch): """ # %% # Create a 2D UNet. -GPU_ID = 1 +if torch.cuda.is_available(): + # Get the GPU ID (you can change the logic to select the appropriate GPU if you have multiple) + GPU_ID = torch.cuda.current_device() +else: + raise ValueError("No GPU available") + BATCH_SIZE = 12 YX_PATCH_SIZE = (256, 256) From 7f0e8aff5db390abd719e513b15316aac3f4deba Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 11 Jul 2024 09:44:50 -0700 Subject: [PATCH 13/42] -adding alias for UNeXt2_2D -fixing bug fluor-phase -addingtensorboard instructions to open on browser -fixing the markdowns to render properly on vscode jupyter --- examples/demo_dlmbl/solution.py | 225 +++++++++++++++++++++----------- viscy/light/engine.py | 3 +- 2 files changed, 153 insertions(+), 75 deletions(-) diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index dfb2e10d..84991b16 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -78,7 +78,7 @@ Set your python kernel to 06_image_translation
""" -# %% +# %% [markdown] """ # Part 1: Log training data to tensorboard, start training a model. --------- @@ -143,41 +143,72 @@ # %% [markdown] tags=[] """ -The next cell starts tensorboard within the notebook. +The next cell starts tensorboard.
If you launched jupyter lab from ssh terminal, add --host <your-server-name> to the tensorboard command below. <your-server-name> is the address of your compute node that ends in amazonaws.com. -You can also launch tensorboard in an independent tab (instead of in the notebook) by changing the `%` to `!`
""" # %% Imports and paths tags=[] -%reload_ext tensorboard -%tensorboard --logdir {log_dir} +# Function to find an available port +def find_free_port(): + import socket + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(('', 0)) + return s.getsockname()[1] + +# Launch TensorBoard on the browser +def launch_tensorboard(log_dir): + import subprocess + port = find_free_port() + tensorboard_cmd = f"tensorboard --logdir={log_dir} --port={port}" + process = subprocess.Popen(tensorboard_cmd, shell=True) + print(f"TensorBoard started at http://localhost:{port}") + return process +# Launch tensorboard and click on the link to view the logs. +tensorboard_process = launch_tensorboard(log_dir) + +#%%[markdown] +""" +
+If you are using VSCode and a remote server, you will need to forward the port to view the tensorboard.
+Take note of the port number was assigned in the previous cell.(i.e http://localhost:{port_number_assigned})
+ +Locate the your VSCode terminal and select the Ports tab
+
    +
  • Add a new port with the port_number_assigned +
  • Change the port to 4000 and ensure that the forwarded Adress: localhost:{port_number_assigned} +
+Click on the link to view the tensorboard and it should open in your browser. +
+""" # %% [markdown] """ -## Load OME-Zarr Dataset. +## Load OME-Zarr Dataset There should be 34 FOVs in the dataset. Each FOV consists of 3 channels of 2048x2048 images, -saved in the -High-Content Screening (HCS) layout +saved in the [High-Content Screening (HCS) layout](https://ngff.openmicroscopy.org/latest/#hcs-layout) specified by the Open Microscopy Environment Next Generation File Format (OME-NGFF). -The layout on the disk is: row/col/field/pyramid_level/timepoint/channel/z/y/x. +- The layout on the disk is: `row/col/field/pyramid_level/timepoint/channel/z/y/x.` """ # %%[markdown] """ +
You can inspect the tree structure by using your terminal: -`iohub info -v ` + iohub info -v "path-to-ome-zarr" +
More info on the CLI: -`iohub info -h` to see the help menu. +iohub info --help to see the help menu. +
""" #%% # This is the python function called by `iohub info` CLI command @@ -355,7 +386,7 @@ def log_batch_jupyter(batch): data_module = HCSDataModule( data_path, z_window_size=1, - architecture='fcmae', + architecture='UNeXt2_2D', source_channel=["Phase3D"], target_channel=['Nucl','Mem'], split_ratio=0.8, @@ -382,7 +413,7 @@ def log_batch_jupyter(batch): # %% [markdown] -# Visualize directly on Jupyter ☄️, if your tensorboard is causing issues. +# If your tensorboard is causing issues, you can visualize directly on Jupyter ☄️/VSCode # %% %matplotlib inline @@ -402,15 +433,15 @@ def log_batch_jupyter(batch): # # Add a Gaussian noise with a mean of 0.0 and standard deviation of 0.3 with a probability of 50%. # -# Hint: `RandAffined()` and `RandGaussianNoised()` from `viscy.transforms` [here](https://github.com/mehta-lab/VisCy/blob/main/viscy/transforms.py). +# HINT: `RandAffined()` and `RandGaussianNoised()` from `viscy.transforms` [here](https://github.com/mehta-lab/VisCy/blob/main/viscy/transforms.py). # *Note* these are MONAI transforms that have been redefined for VisCy. # Can you tell what augmentation were applied from looking at the augmented images in Tensorboard? # -# HINT:[Check your augmentations here](https://github.com/mehta-lab/VisCy/blob/b89f778b34735553cf155904eef134c756708ff2/viscy/light/data.py#L529). +# +# HINT:[Compare your choice of augmentations here](https://github.com/mehta-lab/VisCy/blob/b89f778b34735553cf155904eef134c756708ff2/viscy/light/data.py#L529). #
- # %% # Here we turn on data augmentation and rerun setup source_channel= ['Phase3D'] @@ -476,17 +507,13 @@ def log_batch_jupyter(batch): """ # %% # Create a 2D UNet. -if torch.cuda.is_available(): - # Get the GPU ID (you can change the logic to select the appropriate GPU if you have multiple) - GPU_ID = torch.cuda.current_device() -else: - raise ValueError("No GPU available") +GPU_ID = 0 BATCH_SIZE = 12 YX_PATCH_SIZE = (256, 256) - # Dictionary that specifies key parameters of the model. + phase2fluor_config=dict( in_channels=1, out_channels=2, @@ -499,7 +526,7 @@ def log_batch_jupyter(batch): ) phase2fluor_model = VSUNet( - architecture='fcmae', #2D UNeXt2 architecture + architecture='UNeXt2_2D', #2D UNeXt2 architecture model_config=phase2fluor_config.copy(), loss_function=MixedLoss(l1_alpha=0.5, l2_alpha=0.0, ms_dssim_alpha=0.5), schedule="WarmupCosine", @@ -513,12 +540,14 @@ def log_batch_jupyter(batch): ### Instantiate data module and trainer, test that we are setup to launch training. """ # %% +source_channel= ['Phase3D'] +target_channel=['Nucl','Mem'] # Setup the data module. -phase2fluor_2D_data = HCSDataModule( +phase2fluor_2D_data= HCSDataModule( data_path, - architecture="fcmae", - source_channel=["Phase3D"], - target_channel=["Nucl", "Mem"], + architecture="UNeXt2_2D", + source_channel=source_channel, + target_channel=target_channel, z_window_size=1, split_ratio=0.8, batch_size=BATCH_SIZE, @@ -572,11 +601,7 @@ def log_batch_jupyter(batch): # %% # Check if GPU is available -if torch.cuda.is_available(): - # Get the GPU ID (you can change the logic to select the appropriate GPU if you have multiple) - GPU_ID = torch.cuda.current_device() -else: - raise ValueError("No GPU available") +GPU_ID=0 n_samples = len(phase2fluor_2D_data.train_dataset) steps_per_epoch = n_samples // BATCH_SIZE # steps per epoch. @@ -610,14 +635,11 @@ def log_batch_jupyter(batch):
""" -# %% +# %% [markdown] """ # Part 2: Assess previous model, train fluorescence to phase contrast translation model. -------------------------------------------------- -""" -# %% [markdown] -""" We now look at some metrics of performance of previous model. We typically evaluate the model performance on a held out test data. We will use the following metrics to evaluate the accuracy of regression of the model: - [Person Correlation](https://en.wikipedia.org/wiki/Pearson_correlation_coefficient). - [Structural similarity](https://en.wikipedia.org/wiki/Structural_similarity) (SSIM). @@ -629,7 +651,7 @@ def log_batch_jupyter(batch): """
-### Task 2.1 Define metrics + Task 2.1 Define metrics
For each of the above metrics, write a brief definition of what they are and what they mean for this image translation task. @@ -649,7 +671,8 @@ def log_batch_jupyter(batch): # %% Compute metrics directly and plot here. test_data_path = top_dir / "06_image_translation/test/a549_hoechst_cellmask_test.zarr" - +source_channel = ['Phase3D'] +target_channel = ['Nucl','Mem'] test_data = HCSDataModule( test_data_path, source_channel=source_channel, @@ -709,43 +732,52 @@ def min_max_scale(input): channel_titles = ['Phase', 'Nuclei', 'Membrane'] fig, axes = plt.subplots(2, 3, figsize=(30, 20)) -# Plot the phase image -channel_image = phase_image[0, 0] -p_low, p_high = np.percentile(channel_image, (0.5, 99.5)) -channel_image = np.clip(channel_image, p_low, p_high) -axes[0, 0].imshow(channel_image) -axes[0, 0].axis("off") -axes[0, 0].set_title(channel_titles[0]) - -# Plot the predicted images -for i in range(predicted_image.shape[-4]): - channel_image = predicted_image[i, 0] +for i, sample in enumerate(test_data.test_dataloader()): + # Plot the phase image + phase_image = sample["source"] + channel_image = phase_image[0, 0,0] p_low, p_high = np.percentile(channel_image, (0.5, 99.5)) channel_image = np.clip(channel_image, p_low, p_high) - axes[0, i + 1].imshow(channel_image, cmap="gray") - axes[0, i + 1].axis("off") - axes[0, i + 1].set_title(f"VS {channel_titles[i + 1]}") + axes[0, 0].imshow(channel_image,cmap="gray") + axes[0, 0].axis("off") + axes[0, 0].set_title(channel_titles[0]) -# Plot the target images -for i in range(target_image.shape[-4]): - channel_image = target_image[i, 0] - p_low, p_high = np.percentile(channel_image, (0.5, 99.5)) - channel_image = np.clip(channel_image, p_low, p_high) - axes[1, i+1].imshow(channel_image, cmap="gray") - axes[1, i+1].axis("off") - axes[1, i+1].set_title(f"Target {dataset.channel_names[i+1]}") + with torch.inference_mode(): # turn off gradient computation. + predicted_image = phase2fluor_model(phase_image).cpu().numpy().squeeze(0) + + target_image = ( + sample["target"].cpu().numpy().squeeze(0) + ) + # Plot the predicted images + for i in range(predicted_image.shape[-4]): + channel_image = predicted_image[i, 0] + p_low, p_high = np.percentile(channel_image, (0.1, 99.5)) + channel_image = np.clip(channel_image, p_low, p_high) + axes[0, i + 1].imshow(channel_image, cmap="gray") + axes[0, i + 1].axis("off") + axes[0, i + 1].set_title(f"VS {channel_titles[i + 1]}") -# Remove any unused subplots -for j in range(i + 1, 3): - fig.delaxes(axes[1, j]) + # Plot the target images + for i in range(target_image.shape[-4]): + channel_image = target_image[i, 0] + p_low, p_high = np.percentile(channel_image, (0.5, 99.5)) + channel_image = np.clip(channel_image, p_low, p_high) + axes[1, i].imshow(channel_image, cmap="gray") + axes[1, i].axis("off") + axes[1, i].set_title(f"Target {dataset.channel_names[i+1]}") -plt.tight_layout() -plt.show() + # Remove any unused subplots + for j in range(i + 1, 3): + fig.delaxes(axes[1, j]) + + plt.tight_layout() + plt.show() + break # %% [markdown] tags=[] """
-### Task 2.2 Train fluorescence to phase contrast translation model +Task 2.2 Train fluorescence to phase contrast translation model
Instantiate a data module, model, and trainer for fluorescence to phase contrast translation. Copy over the code from previous cells and update the parameters. Give the variables and paths a different name/suffix (fluor2phase) to avoid overwriting objects used to train phase2fluor models.
@@ -791,8 +823,12 @@ def min_max_scale(input): ########################## # The entire training loop is contained in this cell. -source_channel=['Mem'] # or 'Nuc' depending on choice +source_channel=["Mem"] # or 'Nuc' depending on choice target_channel = ["Phase3D"] +YX_PATCH_SIZE = (256,256) +BATCH_SIZE = 12 +n_epochs= 50 +steps_per_epoch = n_samples // BATCH_SIZE # steps per epoch. #Setup the new augmentations augmentations = [ @@ -834,7 +870,7 @@ def min_max_scale(input): #Setup the dataloader fluor2phase_data = HCSDataModule( data_path, - architecture="fcmae", + architecture="UNeXt2_2D", source_channel=source_channel, target_channel=target_channel, z_window_size=1, @@ -850,7 +886,7 @@ def min_max_scale(input): # Dictionary that specifies key parameters of the model. fluor2phase_config=dict( in_channels=1, - out_channels=2, + out_channels=1, encoder_blocks=[3, 3, 9, 3], dims=[96, 192, 384, 768], decoder_conv_blocks=2, @@ -860,8 +896,8 @@ def min_max_scale(input): ) fluor2phase_model = VSUNet( - architecture='fcmae', - model_config=phase2fluor_config.copy(), + architecture='UNeXt2_2D', + model_config=fluor2phase_config.copy(), loss_function=MixedLoss(l1_alpha=0.5, l2_alpha=0.0, ms_dssim_alpha=0.5), schedule="WarmupCosine", lr=2e-4, @@ -913,7 +949,7 @@ def min_max_scale(input): test_data = HCSDataModule( test_data_path, source_channel="Mem", # or Nuc, depending on your choice of source - target_channel="Phase", + target_channel="Phase3D", z_window_size=1, batch_size=1, num_workers=8, @@ -958,6 +994,47 @@ def min_max_scale(input): column=["pearson_phase", "SSIM_phase"], rot=30, ) +#%% +#Plot the predicted image +channel_titles = ['Membrane','Target Phase','Predicted_Phase',] +fig, axes = plt.subplots(1, 3, figsize=(30, 20)) + +for i, sample in enumerate(test_data.test_dataloader()): + # Plot the phase image + mem_image = sample["source"] + channel_image = mem_image[0,0,0] + p_low, p_high = np.percentile(channel_image, (0.5, 99.5)) + channel_image = np.clip(channel_image, p_low, p_high) + axes[0].imshow(channel_image,cmap="gray") + axes[0].axis("off") + axes[0].set_title(channel_titles[0]) + + with torch.inference_mode(): # turn off gradient computation. + predicted_image = phase2fluor_model(phase_image).cpu().numpy().squeeze(0) + + target_image = ( + sample["target"].cpu().numpy().squeeze(0) + ) + # Plot the predicted images + channel_image = target_image[0,0] + p_low, p_high = np.percentile(channel_image, (0.5, 99.5)) + channel_image = np.clip(channel_image, p_low, p_high) + axes[1].imshow(channel_image,cmap="gray") + axes[1].axis("off") + axes[1].set_title(channel_titles[1]) + + + channel_image = predicted_image[1, 0] + p_low, p_high = np.percentile(channel_image, (0.1, 99.5)) + channel_image = np.clip(channel_image, p_low, p_high) + axes[2].imshow(channel_image, cmap="gray") + axes[2].axis("off") + axes[2].set_title(f"VS {channel_titles[2]}") + + + plt.tight_layout() + plt.show() + break # %% [markdown] tags=[] """ @@ -1048,7 +1125,7 @@ def min_max_scale(input): ) phase2fluor_model_low_lr= VSUNet( - architecture='fcmae', + architecture='UNeXt2_2D', model_config=phase2fluor_config.copy(), loss_function=MixedLoss(l1_alpha=0.5, l2_alpha=0.0, ms_dssim_alpha=0.5), #Changed the loss function to MixedLoss L1 and MS-SSIM schedule="WarmupCosine", @@ -1091,7 +1168,7 @@ def min_max_scale(input): data_path = Path() #TODO: Point to a 3D dataset (HEK, Neuromast) BATCH_SIZE = 4 -YX_PATCH_SIZE = (384, 384) +YX_PATCH_SIZE = (256, 256) ## For 3D training - VSCyto3D source_channel=['Phase3D'] diff --git a/viscy/light/engine.py b/viscy/light/engine.py index 33da3552..1758f9e9 100644 --- a/viscy/light/engine.py +++ b/viscy/light/engine.py @@ -43,6 +43,7 @@ "UNeXt2": UNeXt2, "2.5D": Unet25d, "fcmae": FullyConvolutionalMAE, + "UNeXt2_2D": FullyConvolutionalMAE, } @@ -117,7 +118,7 @@ class VSUNet(LightningModule): def __init__( self, - architecture: Literal["2D", "UNeXt2", "2.5D", "3D", "fcmae"], + architecture: Literal["2D", "UNeXt2", "2.5D", "3D", "fcmae", "UNeXt2_2D"], model_config: dict = {}, loss_function: Union[nn.Module, MixedLoss] = None, lr: float = 1e-3, From e6aacc794593347b8c380ed2c22ebe84273ee49b Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 11 Jul 2024 09:48:04 -0700 Subject: [PATCH 14/42] ruff and black --- examples/demo_dlmbl/convert-solution.py | 16 +-- examples/demo_dlmbl/solution.py | 175 +++++++++++++----------- 2 files changed, 100 insertions(+), 91 deletions(-) diff --git a/examples/demo_dlmbl/convert-solution.py b/examples/demo_dlmbl/convert-solution.py index 279f7874..91d7e46c 100644 --- a/examples/demo_dlmbl/convert-solution.py +++ b/examples/demo_dlmbl/convert-solution.py @@ -1,15 +1,15 @@ import argparse -from traitlets.config import Config -import nbformat as nbf -from nbconvert.preprocessors import TagRemovePreprocessor, ClearOutputPreprocessor + from nbconvert.exporters import NotebookExporter +from nbconvert.preprocessors import ClearOutputPreprocessor, TagRemovePreprocessor +from traitlets.config import Config def get_arg_parser(): parser = argparse.ArgumentParser() - parser.add_argument('input_file') - parser.add_argument('output_file') + parser.add_argument("input_file") + parser.add_argument("output_file") return parser @@ -21,7 +21,7 @@ def convert(input_file, output_file): c.ClearOutputPreprocesser.enabled = True c.NotebookExporter.preprocessors = [ "nbconvert.preprocessors.TagRemovePreprocessor", - "nbconvert.preprocessors.ClearOutputPreprocessor" + "nbconvert.preprocessors.ClearOutputPreprocessor", ] exporter = NotebookExporter(config=c) @@ -29,7 +29,7 @@ def convert(input_file, output_file): exporter.register_preprocessor(ClearOutputPreprocessor(), True) output = NotebookExporter(config=c).from_filename(input_file) - with open(output_file, 'w') as f: + with open(output_file, "w") as f: f.write(output[0]) @@ -38,4 +38,4 @@ def convert(input_file, output_file): args = parser.parse_args() convert(args.input_file, args.output_file) - print(f'Converted {args.input_file} to {args.output_file}') + print(f"Converted {args.input_file} to {args.output_file}") diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index 84991b16..9aca4d6a 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -53,8 +53,7 @@ """ - -#%% [markdown] +# %% [markdown] """ 📖 As you work through parts 2 and 3, please share the layouts of your models (output of torchview) and their performance with everyone via [this Google Doc](https://docs.google.com/document/d/1Mq-yV8FTG02xE46Mii2vzPJVYSRNdeOXkeU-EKu-irE/edit?usp=sharing). 📖 """ @@ -92,6 +91,7 @@ """ # %% Imports and paths +import os from pathlib import Path import matplotlib.pyplot as plt @@ -103,7 +103,7 @@ from iohub import open_ome_zarr from iohub.reader import print_info from lightning.pytorch import seed_everything -from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger +from lightning.pytorch.loggers import TensorBoardLogger from skimage import metrics # for metrics. # %% Imports and paths @@ -113,31 +113,37 @@ # HCSDataModule makes it easy to load data during training. from viscy.data.hcs import HCSDataModule +# Trainer class and UNet. +from viscy.light.engine import MixedLoss, VSUNet +from viscy.light.trainer import VSTrainer + # training augmentations from viscy.transforms import ( + NormalizeSampled, RandAdjustContrastd, RandAffined, RandGaussianNoised, RandGaussianSmoothd, RandScaleIntensityd, - RandWeightedCropd, - NormalizeSampled + RandWeightedCropd, ) -# Trainer class and UNet. -from viscy.light.engine import VSUNet,MixedLoss -from viscy.light.trainer import VSTrainer - seed_everything(42, workers=True) # Paths to data and log directory -top_dir = Path(f"/hpc/mydata/{os.environ['USER']}/data/") #TODO: Change this to point to your data directory. -data_path = top_dir / "06_image_translation/training/a549_hoechst_cellmask_train_val.zarr" +top_dir = Path( + f"/hpc/mydata/{os.environ['USER']}/data/" +) # TODO: Change this to point to your data directory. +data_path = ( + top_dir / "06_image_translation/training/a549_hoechst_cellmask_train_val.zarr" +) log_dir = top_dir / "06_image_translation/logs/" if not data_path.exists(): - raise FileNotFoundError(f"Data not found at {data_path}. Please check the top_dir and data_path variables.") -#%% + raise FileNotFoundError( + f"Data not found at {data_path}. Please check the top_dir and data_path variables." + ) +# %% # Create log directory if needed, and launch tensorboard log_dir.mkdir(parents=True, exist_ok=True) @@ -153,25 +159,31 @@ # %% Imports and paths tags=[] + # Function to find an available port def find_free_port(): import socket + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: - s.bind(('', 0)) + s.bind(("", 0)) return s.getsockname()[1] -# Launch TensorBoard on the browser + +# Launch TensorBoard on the browser def launch_tensorboard(log_dir): import subprocess + port = find_free_port() tensorboard_cmd = f"tensorboard --logdir={log_dir} --port={port}" process = subprocess.Popen(tensorboard_cmd, shell=True) print(f"TensorBoard started at http://localhost:{port}") return process + + # Launch tensorboard and click on the link to view the logs. tensorboard_process = launch_tensorboard(log_dir) -#%%[markdown] +# %%[markdown] """
If you are using VSCode and a remote server, you will need to forward the port to view the tensorboard.
@@ -210,14 +222,14 @@ def launch_tensorboard(log_dir): iohub info --help to see the help menu.
""" -#%% +# %% # This is the python function called by `iohub info` CLI command -print_info(data_path,verbose=True) +print_info(data_path, verbose=True) -# Open and inspect the dataset. +# Open and inspect the dataset. dataset = open_ome_zarr(data_path) -#%% +# %% # Use the field and pyramid_level below to visualize data. row = 0 col = 0 @@ -254,7 +266,7 @@ def launch_tensorboard(log_dir): # # ### Task 1.1 # -# Look at a couple different fields of view by changing the value in the cell above. +# Look at a couple different fields of view by changing the value in the cell above. # Check the cell density, the cell morphologies, and fluorescence signal. #
@@ -386,15 +398,15 @@ def log_batch_jupyter(batch): data_module = HCSDataModule( data_path, z_window_size=1, - architecture='UNeXt2_2D', + architecture="UNeXt2_2D", source_channel=["Phase3D"], - target_channel=['Nucl','Mem'], + target_channel=["Nucl", "Mem"], split_ratio=0.8, batch_size=BATCH_SIZE, num_workers=8, yx_patch_size=(256, 256), # larger patch size makes it easy to see augmentations. augmentations=[], # Turn off augmentation for now. - normalizations=[], #Turn off normalization for now. + normalizations=[], # Turn off normalization for now. ) data_module.setup("fit") @@ -416,7 +428,6 @@ def log_batch_jupyter(batch): # If your tensorboard is causing issues, you can visualize directly on Jupyter ☄️/VSCode # %% -%matplotlib inline log_batch_jupyter(batch) # %% [markdown] @@ -430,7 +441,7 @@ def log_batch_jupyter(batch): # How do they make the model more robust? # # Add augmentations to rotate about `pi` along z-axis, 30% scale in y,x, shearing of 10% and no padding with zeros with a probablity of 80%. -# +# # Add a Gaussian noise with a mean of 0.0 and standard deviation of 0.3 with a probability of 50%. # # HINT: `RandAffined()` and `RandGaussianNoised()` from `viscy.transforms` [here](https://github.com/mehta-lab/VisCy/blob/main/viscy/transforms.py). @@ -444,8 +455,8 @@ def log_batch_jupyter(batch): #
# %% # Here we turn on data augmentation and rerun setup -source_channel= ['Phase3D'] -target_channel=['Nucl','Mem'] +source_channel = ["Phase3D"] +target_channel = ["Nucl", "Mem"] augmentations = [ RandWeightedCropd( @@ -454,7 +465,6 @@ def log_batch_jupyter(batch): num_samples=2, w_key=target_channel[0], ), - RandAdjustContrastd(keys=source_channel, prob=0.5, gamma=(0.8, 1.2)), RandScaleIntensityd(keys=source_channel, factors=0.5, prob=0.5), RandGaussianSmoothd( @@ -466,14 +476,13 @@ def log_batch_jupyter(batch): ), ##TODO: Add rotation agumentations ## Write code below - ## TODO: Add Random Gaussian Noise ## Write code below ] -normalizations=[ +normalizations = [ NormalizeSampled( - keys=source_channel+target_channel, + keys=source_channel + target_channel, level="fov_statistics", subtrahend="median", divisor="std", @@ -514,7 +523,7 @@ def log_batch_jupyter(batch): # Dictionary that specifies key parameters of the model. -phase2fluor_config=dict( +phase2fluor_config = dict( in_channels=1, out_channels=2, encoder_blocks=[3, 3, 9, 3], @@ -526,7 +535,7 @@ def log_batch_jupyter(batch): ) phase2fluor_model = VSUNet( - architecture='UNeXt2_2D', #2D UNeXt2 architecture + architecture="UNeXt2_2D", # 2D UNeXt2 architecture model_config=phase2fluor_config.copy(), loss_function=MixedLoss(l1_alpha=0.5, l2_alpha=0.0, ms_dssim_alpha=0.5), schedule="WarmupCosine", @@ -540,10 +549,10 @@ def log_batch_jupyter(batch): ### Instantiate data module and trainer, test that we are setup to launch training. """ # %% -source_channel= ['Phase3D'] -target_channel=['Nucl','Mem'] +source_channel = ["Phase3D"] +target_channel = ["Nucl", "Mem"] # Setup the data module. -phase2fluor_2D_data= HCSDataModule( +phase2fluor_2D_data = HCSDataModule( data_path, architecture="UNeXt2_2D", source_channel=source_channel, @@ -554,7 +563,7 @@ def log_batch_jupyter(batch): num_workers=8, yx_patch_size=YX_PATCH_SIZE, augmentations=augmentations, - normalizations=normalizations + normalizations=normalizations, ) phase2fluor_2D_data.setup("fit") # fast_dev_run runs a single batch of data through the model to check for errors. @@ -601,8 +610,8 @@ def log_batch_jupyter(batch): # %% # Check if GPU is available -GPU_ID=0 - +GPU_ID = 0 + n_samples = len(phase2fluor_2D_data.train_dataset) steps_per_epoch = n_samples // BATCH_SIZE # steps per epoch. n_epochs = 50 # Set this to 50 or the number of epochs you want to train for. @@ -671,8 +680,8 @@ def log_batch_jupyter(batch): # %% Compute metrics directly and plot here. test_data_path = top_dir / "06_image_translation/test/a549_hoechst_cellmask_test.zarr" -source_channel = ['Phase3D'] -target_channel = ['Nucl','Mem'] +source_channel = ["Phase3D"] +target_channel = ["Nucl", "Mem"] test_data = HCSDataModule( test_data_path, source_channel=source_channel, @@ -727,27 +736,25 @@ def min_max_scale(input): column=["pearson_nuc", "SSIM_nuc", "pearson_mem", "SSIM_mem"], rot=30, ) -#%% -#Plot the predicted image -channel_titles = ['Phase', 'Nuclei', 'Membrane'] +# %% +# Plot the predicted image +channel_titles = ["Phase", "Nuclei", "Membrane"] fig, axes = plt.subplots(2, 3, figsize=(30, 20)) for i, sample in enumerate(test_data.test_dataloader()): # Plot the phase image phase_image = sample["source"] - channel_image = phase_image[0, 0,0] + channel_image = phase_image[0, 0, 0] p_low, p_high = np.percentile(channel_image, (0.5, 99.5)) channel_image = np.clip(channel_image, p_low, p_high) - axes[0, 0].imshow(channel_image,cmap="gray") + axes[0, 0].imshow(channel_image, cmap="gray") axes[0, 0].axis("off") axes[0, 0].set_title(channel_titles[0]) with torch.inference_mode(): # turn off gradient computation. predicted_image = phase2fluor_model(phase_image).cpu().numpy().squeeze(0) - target_image = ( - sample["target"].cpu().numpy().squeeze(0) - ) + target_image = sample["target"].cpu().numpy().squeeze(0) # Plot the predicted images for i in range(predicted_image.shape[-4]): channel_image = predicted_image[i, 0] @@ -823,14 +830,14 @@ def min_max_scale(input): ########################## # The entire training loop is contained in this cell. -source_channel=["Mem"] # or 'Nuc' depending on choice +source_channel = ["Mem"] # or 'Nuc' depending on choice target_channel = ["Phase3D"] -YX_PATCH_SIZE = (256,256) +YX_PATCH_SIZE = (256, 256) BATCH_SIZE = 12 -n_epochs= 50 +n_epochs = 50 steps_per_epoch = n_samples // BATCH_SIZE # steps per epoch. -#Setup the new augmentations +# Setup the new augmentations augmentations = [ RandWeightedCropd( keys=source_channel + target_channel, @@ -858,16 +865,16 @@ def min_max_scale(input): ), ] -normalizations=[ +normalizations = [ NormalizeSampled( - keys=source_channel+target_channel, + keys=source_channel + target_channel, level="fov_statistics", subtrahend="median", divisor="std", ) ] -#Setup the dataloader +# Setup the dataloader fluor2phase_data = HCSDataModule( data_path, architecture="UNeXt2_2D", @@ -879,12 +886,12 @@ def min_max_scale(input): num_workers=8, yx_patch_size=YX_PATCH_SIZE, augmentations=augmentations, - normalizations=normalizations + normalizations=normalizations, ) fluor2phase_data.setup("fit") # Dictionary that specifies key parameters of the model. -fluor2phase_config=dict( +fluor2phase_config = dict( in_channels=1, out_channels=1, encoder_blocks=[3, 3, 9, 3], @@ -896,7 +903,7 @@ def min_max_scale(input): ) fluor2phase_model = VSUNet( - architecture='UNeXt2_2D', + architecture="UNeXt2_2D", model_config=fluor2phase_config.copy(), loss_function=MixedLoss(l1_alpha=0.5, l2_alpha=0.0, ms_dssim_alpha=0.5), schedule="WarmupCosine", @@ -919,7 +926,7 @@ def min_max_scale(input): ) trainer.fit(fluor2phase_model, datamodule=fluor2phase_data) -#%% +# %% # Visualize the graph of fluor2phase model as image. model_graph_fluor2phase = torchview.draw_graph( fluor2phase_model, @@ -994,36 +1001,37 @@ def min_max_scale(input): column=["pearson_phase", "SSIM_phase"], rot=30, ) -#%% -#Plot the predicted image -channel_titles = ['Membrane','Target Phase','Predicted_Phase',] +# %% +# Plot the predicted image +channel_titles = [ + "Membrane", + "Target Phase", + "Predicted_Phase", +] fig, axes = plt.subplots(1, 3, figsize=(30, 20)) for i, sample in enumerate(test_data.test_dataloader()): # Plot the phase image mem_image = sample["source"] - channel_image = mem_image[0,0,0] + channel_image = mem_image[0, 0, 0] p_low, p_high = np.percentile(channel_image, (0.5, 99.5)) channel_image = np.clip(channel_image, p_low, p_high) - axes[0].imshow(channel_image,cmap="gray") + axes[0].imshow(channel_image, cmap="gray") axes[0].axis("off") axes[0].set_title(channel_titles[0]) with torch.inference_mode(): # turn off gradient computation. predicted_image = phase2fluor_model(phase_image).cpu().numpy().squeeze(0) - target_image = ( - sample["target"].cpu().numpy().squeeze(0) - ) + target_image = sample["target"].cpu().numpy().squeeze(0) # Plot the predicted images - channel_image = target_image[0,0] + channel_image = target_image[0, 0] p_low, p_high = np.percentile(channel_image, (0.5, 99.5)) channel_image = np.clip(channel_image, p_low, p_high) - axes[1].imshow(channel_image,cmap="gray") + axes[1].imshow(channel_image, cmap="gray") axes[1].axis("off") axes[1].set_title(channel_titles[1]) - channel_image = predicted_image[1, 0] p_low, p_high = np.percentile(channel_image, (0.1, 99.5)) channel_image = np.clip(channel_image, p_low, p_high) @@ -1031,7 +1039,6 @@ def min_max_scale(input): axes[2].axis("off") axes[2].set_title(f"VS {channel_titles[2]}") - plt.tight_layout() plt.show() break @@ -1113,7 +1120,7 @@ def min_max_scale(input): ########################## ######## Solution ######## ########################## -phase2fluor_config=dict( +phase2fluor_config = dict( in_channels=1, out_channels=2, encoder_blocks=[3, 3, 9, 3], @@ -1124,12 +1131,14 @@ def min_max_scale(input): pretraining=False, ) -phase2fluor_model_low_lr= VSUNet( - architecture='UNeXt2_2D', +phase2fluor_model_low_lr = VSUNet( + architecture="UNeXt2_2D", model_config=phase2fluor_config.copy(), - loss_function=MixedLoss(l1_alpha=0.5, l2_alpha=0.0, ms_dssim_alpha=0.5), #Changed the loss function to MixedLoss L1 and MS-SSIM + loss_function=MixedLoss( + l1_alpha=0.5, l2_alpha=0.0, ms_dssim_alpha=0.5 + ), # Changed the loss function to MixedLoss L1 and MS-SSIM schedule="WarmupCosine", - lr=2e-5, #lower learning rate by factor of 10 + lr=2e-5, # lower learning rate by factor of 10 log_batches_per_epoch=5, # Number of samples from each batch to log to tensorboard. ) @@ -1166,13 +1175,13 @@ def min_max_scale(input): """ -data_path = Path() #TODO: Point to a 3D dataset (HEK, Neuromast) +data_path = Path() # TODO: Point to a 3D dataset (HEK, Neuromast) BATCH_SIZE = 4 YX_PATCH_SIZE = (256, 256) ## For 3D training - VSCyto3D -source_channel=['Phase3D'] -target_channel=['Nucl','Mem'] +source_channel = ["Phase3D"] +target_channel = ["Nucl", "Mem"] phase2fluor_3D_data = HCSDataModule( data_path, @@ -1185,11 +1194,11 @@ def min_max_scale(input): num_workers=8, yx_patch_size=YX_PATCH_SIZE, augmentations=augmentations, - normalizations=normalizations + normalizations=normalizations, ) phase2fluor_3D = VSUNet( - architecture='UNeXt2', + architecture="UNeXt2", model_config=phase2fluor_config.copy(), loss_function=MixedLoss(l1_alpha=0.5, l2_alpha=0.0, ms_dssim_alpha=0.5), lr=2e-4, From cda31bbb06dfe0d246f3d97e1f12c7b378fb62eb Mon Sep 17 00:00:00 2001 From: Shalin Mehta Date: Tue, 16 Jul 2024 14:59:22 -0700 Subject: [PATCH 15/42] edit tensorboard instructions --- examples/demo_dlmbl/solution.py | 31 ++++++++++++++----------------- 1 file changed, 14 insertions(+), 17 deletions(-) diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index 9aca4d6a..59501999 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -151,9 +151,21 @@ """ The next cell starts tensorboard. -
+
If you launched jupyter lab from ssh terminal, add --host <your-server-name> to the tensorboard command below. <your-server-name> is the address of your compute node that ends in amazonaws.com. +
+ +
+If you are using VSCode and a remote server, you will need to forward the port to view the tensorboard.
+Take note of the port number was assigned in the previous cell.(i.e http://localhost:{port_number_assigned})
+ +Locate the your VSCode terminal and select the Ports tab
+
    +
  • Add a new port with the port_number_assigned +
  • Change the port to 4000 and ensure that the forwarded Adress: localhost:{port_number_assigned} +
+Click on the link to view the tensorboard and it should open in your browser.
""" @@ -172,31 +184,16 @@ def find_free_port(): # Launch TensorBoard on the browser def launch_tensorboard(log_dir): import subprocess - port = find_free_port() tensorboard_cmd = f"tensorboard --logdir={log_dir} --port={port}" process = subprocess.Popen(tensorboard_cmd, shell=True) - print(f"TensorBoard started at http://localhost:{port}") + print(f"TensorBoard started at http://localhost:{port}. \n If you are using VSCode remote session, forward the port using the PORTS tab next to TERMINAL.") return process # Launch tensorboard and click on the link to view the logs. tensorboard_process = launch_tensorboard(log_dir) -# %%[markdown] -""" -
-If you are using VSCode and a remote server, you will need to forward the port to view the tensorboard.
-Take note of the port number was assigned in the previous cell.(i.e http://localhost:{port_number_assigned})
- -Locate the your VSCode terminal and select the Ports tab
-
    -
  • Add a new port with the port_number_assigned -
  • Change the port to 4000 and ensure that the forwarded Adress: localhost:{port_number_assigned} -
-Click on the link to view the tensorboard and it should open in your browser. -
-""" # %% [markdown] """ ## Load OME-Zarr Dataset From 8c192217a709e017b13f148a4072237968dd014a Mon Sep 17 00:00:00 2001 From: Shalin Mehta Date: Tue, 16 Jul 2024 15:59:50 -0700 Subject: [PATCH 16/42] clean up the order for Part 2: fluor2phase, training call follows model construction --- examples/demo_dlmbl/solution.py | 43 +++++++++++++++++++++------------ 1 file changed, 27 insertions(+), 16 deletions(-) diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index 59501999..5010b9e9 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -184,10 +184,13 @@ def find_free_port(): # Launch TensorBoard on the browser def launch_tensorboard(log_dir): import subprocess + port = find_free_port() tensorboard_cmd = f"tensorboard --logdir={log_dir} --port={port}" process = subprocess.Popen(tensorboard_cmd, shell=True) - print(f"TensorBoard started at http://localhost:{port}. \n If you are using VSCode remote session, forward the port using the PORTS tab next to TERMINAL.") + print( + f"TensorBoard started at http://localhost:{port}. \n If you are using VSCode remote session, forward the port using the PORTS tab next to TERMINAL." + ) return process @@ -805,12 +808,6 @@ def min_max_scale(input): # Your code here (copy from above and modify as needed) ) -trainer = VSTrainer( - # Your code here (copy from above and modify as needed) -) -trainer.fit(fluor2phase_model, datamodule=fluor2phase_data) - - # Visualize the graph of fluor2phase model as image. model_graph_fluor2phase = torchview.draw_graph( fluor2phase_model, @@ -909,6 +906,27 @@ def min_max_scale(input): freeze_encoder=False, ) +# Visualize the graph of fluor2phase model as image. +model_graph_fluor2phase = torchview.draw_graph( + fluor2phase_model, + fluor2phase_data.train_dataset[0]["source"], + depth=2, # adjust depth to zoom in. + device="cpu", +) +model_graph_fluor2phase.visual_graph + +# %% tags=[] +########################## +######## TODO ######## +########################## + +trainer = VSTrainer( + # Your code here (copy from above and modify as needed) +) +trainer.fit(fluor2phase_model, datamodule=fluor2phase_data) + + +# %% tags=["solution"] trainer = VSTrainer( accelerator="gpu", devices=[GPU_ID], @@ -924,20 +942,13 @@ def min_max_scale(input): trainer.fit(fluor2phase_model, datamodule=fluor2phase_data) # %% -# Visualize the graph of fluor2phase model as image. -model_graph_fluor2phase = torchview.draw_graph( - fluor2phase_model, - fluor2phase_data.train_dataset[0]["source"], - depth=2, # adjust depth to zoom in. - device="cpu", -) -model_graph_fluor2phase.visual_graph + # %% [markdown] tags=[] """
-### Task 2.3 +Task 2.3
While your model is training, let's think about the following questions: - What is the information content of each channel in the dataset? From 6db7aab1146bc2b287a43a223352c0d8c65805e3 Mon Sep 17 00:00:00 2001 From: Shalin Mehta Date: Tue, 16 Jul 2024 17:26:20 -0700 Subject: [PATCH 17/42] restructured fluor2phase sections --- examples/demo_dlmbl/solution.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index 5010b9e9..03106bc5 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -941,8 +941,6 @@ def min_max_scale(input): ) trainer.fit(fluor2phase_model, datamodule=fluor2phase_data) -# %% - # %% [markdown] tags=[] """ From 293ad171a1d18a5624dc21fd30b6e4948e380b80 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 17 Jul 2024 13:11:35 -0700 Subject: [PATCH 18/42] Text edits --- examples/demo_dlmbl/solution.py | 44 ++++++++++++++++++++++----------- 1 file changed, 30 insertions(+), 14 deletions(-) diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index 03106bc5..c0f0452f 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -2,6 +2,8 @@ """ # Image translation (Virtual Staining) +Written by Eduardo Hirata-Miyasaki, Ziwen Liu, and Shalin Mehta, CZ Biohub San Francisco. + ## Overview In this exercise, we will predict fluorescence images of @@ -22,34 +24,48 @@ """ ### Goals -#### Part 1: Learn to use iohub (I/O library), VisCy dataloaders, and tensorboard. +#### Part 1: Learn to use iohub (I/O library), VisCy dataloaders, and TensorBoard. - - Use a `ome-zarr` dataset of 34 FOVs of adenocarcinomic human alveolar basal epithelial cells (A549), each FOV has 3 channels (phase, nuclei, and cell membrane). The nuclei were stained with DAPI and the cell membrane with Cellmask. - - Explore the OME-zarr using [iohub](https://czbiohub-sf.github.io/iohub/main/index.html) and the high-content-screen (HCS) format. + - Use a OME-Zarr dataset of 34 FOVs of adenocarcinomic human alveolar basal epithelial cells (A549), + each FOV has 3 channels (phase, nuclei, and cell membrane). + The nuclei were stained with DAPI and the cell membrane with Cellmask. + - Explore OME-Zarr using [iohub](https://czbiohub-sf.github.io/iohub/main/index.html) + and the high-content-screen (HCS) format. - Use [MONAI](https://monai.io/) to implement data augmentations. #### Part 2: Train a model that predicts fluorescence from phase, and vice versa, using the UNeXt2 architecture. - - Create a model for image translation mapping from source domain to target domain where the source domain is label-free microscopy (material density) and the target domain is fluorescence microscopy (fluorophore density). - - Use the UNeXt2 architecture, a _purely convolutional architecture_ that draws on the design principles of transformer models to complete this task. Here we will use a *UNeXt2*, an efficient image translation architecture inspired by ConvNeXt v2, SparK. - - We will perform the preprocessing, training, prediction, evaluation, and deployment steps that borrow from our computer vision pipeline for single-cell analysis in our pipeline called [VisCy](https://github.com/mehta-lab/VisCy). + - Create a model for image translation mapping from source domain to target domain + where the source domain is label-free microscopy (material density) + and the target domain is fluorescence microscopy (fluorophore density). + - Use the UNeXt2 architecture, a _purely convolutional architecture_ + that draws on the design principles of transformer models to complete this task. + Here we will use a *UNeXt2*, an efficient image translation architecture inspired by ConvNeXt v2 and SparK. + - We will perform the preprocessing, training, prediction, evaluation, and deployment steps + that borrow from our computer vision pipeline for single-cell analysis in + our pipeline called [VisCy](https://github.com/mehta-lab/VisCy). - Reuse the same architecture as above and create a similar model doing the inverse task (fluorescence to phase). - Evaluate the model. -#### (Extra) Play with the hyperparameters to improve the models or train a 3D UNeXt2** +#### (Extra) Play with the hyperparameters to improve the models or train a 3D UNeXt2 -Our guesstimate is that each of the three parts will take ~1-1.5 hours. A reasonable 2D UNet can be trained in ~30 min on a typical AWS node. -The focus of the exercise is on understanding the information content of the data, how to train and evaluate 2D image translation models, and exploring some hyperparameters of the model. If you complete this exercise and have time to spare, try the bonus exercise on 3D image translation. +Our guesstimate is that each of the three parts will take ~1-1.5 hours. +A reasonable 2D UNet can be trained in ~30 min on a typical AWS node. +The focus of the exercise is on understanding the information content of the data, +how to train and evaluate 2D image translation models, and exploring some hyperparameters of the model. +If you complete this exercise and have time to spare, try the bonus exercise on 3D image translation. -Checkout [VisCy](https://github.com/mehta-lab/VisCy/tree/main/examples/demos), our deep learning pipeline for training and deploying computer vision models for image-based phenotyping including the robust virtual staining of landmark organelles. VisCy exploits recent advances in data and metadata formats ([OME-zarr](https://www.nature.com/articles/s41592-021-01326-w)) and DL frameworks, [PyTorch Lightning](https://lightning.ai/) and [MONAI](https://monai.io/). +Checkout [VisCy](https://github.com/mehta-lab/VisCy/tree/main/examples/demos), +our deep learning pipeline for training and deploying computer vision models +for image-based phenotyping including the robust virtual staining of landmark organelles. +VisCy exploits recent advances in data and metadata formats +([OME-zarr](https://www.nature.com/articles/s41592-021-01326-w)) and DL frameworks, +[PyTorch Lightning](https://lightning.ai/) and [MONAI](https://monai.io/). ## References --- - [Liu, Z. and Hirata-Miyasaki, E. et al. (2024) Robust Virtual Staining of Cellular Landmarks](https://www.biorxiv.org/content/10.1101/2024.05.31.596901v2.full.pdf) -- [Guo et al. (2020) Revealing architectural order with quantitative label-free imaging and deep learning. eLife](https://elifesciences.org/articles/55502) - - -Written by Eduardo Hirata-Miyasaki, Ziwen Liu, and Shalin Mehta, CZ Biohub San Francisco. +- [Guo et al. (2020) Revealing architectural order with quantitative label-free imaging and deep learning. eLife](https://elifesciences.org/articles/55502) """ From c067d53370af232665749d0f15360013ad0fdb45 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 17 Jul 2024 13:19:30 -0700 Subject: [PATCH 19/42] markdown and html fixes --- examples/demo_dlmbl/solution.py | 26 ++++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index c0f0452f..e9c6b6b8 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -62,7 +62,7 @@ ([OME-zarr](https://www.nature.com/articles/s41592-021-01326-w)) and DL frameworks, [PyTorch Lightning](https://lightning.ai/) and [MONAI](https://monai.io/). -## References +### References --- - [Liu, Z. and Hirata-Miyasaki, E. et al. (2024) Robust Virtual Staining of Cellular Landmarks](https://www.biorxiv.org/content/10.1101/2024.05.31.596901v2.full.pdf) - [Guo et al. (2020) Revealing architectural order with quantitative label-free imaging and deep learning. eLife](https://elifesciences.org/articles/55502) @@ -71,7 +71,9 @@ # %% [markdown] """ -📖 As you work through parts 2 and 3, please share the layouts of your models (output of torchview) and their performance with everyone via [this Google Doc](https://docs.google.com/document/d/1Mq-yV8FTG02xE46Mii2vzPJVYSRNdeOXkeU-EKu-irE/edit?usp=sharing). 📖 +📖 As you work through parts 2 and 3, please share the layouts of your models (output of torchview) +and their performance with everyone via +[this Google Doc](https://docs.google.com/document/d/1Mq-yV8FTG02xE46Mii2vzPJVYSRNdeOXkeU-EKu-irE/edit?usp=sharing). 📖 """ # %% [markdown] """ @@ -95,7 +97,7 @@ """ # %% [markdown] """ -# Part 1: Log training data to tensorboard, start training a model. +## Part 1: Log training data to tensorboard, start training a model. --------- Learning goals: @@ -662,7 +664,7 @@ def log_batch_jupyter(batch): # %% [markdown] """ -# Part 2: Assess previous model, train fluorescence to phase contrast translation model. +## Part 2: Assess previous model, train fluorescence to phase contrast translation model. -------------------------------------------------- We now look at some metrics of performance of previous model. We typically evaluate the model performance on a held out test data. We will use the following metrics to evaluate the accuracy of regression of the model: @@ -1078,7 +1080,7 @@ def min_max_scale(input): # %% tags=[] """ -# (Extra)Tune the models and explore other architectures from [VisCy](https://github.com/mehta-lab/VisCy/tree/main/examples/demos) +## (Extra)Tune the models and explore other architectures from [VisCy](https://github.com/mehta-lab/VisCy/tree/main/examples/demos) -------------------------------------------------- Learning goals: - Understand how data, model capacity, and training parameters control the performance of the model. Your goal is to try to underfit or overfit the model. @@ -1247,11 +1249,19 @@ def min_max_scale(input): # %% [markdown] tags=[] """
- -## 🎉 The end of the notebook 🎉 + +

+🎉 The end of the notebook 🎉 +

Congratulations! You have trained several image translation models now! -Please remember to document the hyperparameters, snapshots of predictions on validation set, and loss curves for your models and add the final perforance in [this google doc](https://docs.google.com/document/d/1Mq-yV8FTG02xE46Mii2vzPJVYSRNdeOXkeU-EKu-irE/edit?usp=sharing). We'll discuss our combined results as a group. +
+Please remember to document the hyperparameters, +snapshots of predictions on validation set, +and loss curves for your models and add the final performance in + +this google doc. +We'll discuss our combined results as a group.
""" From 1920ffcb2218f2885d628cb82c3eb8a310ce50b9 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 17 Jul 2024 13:31:09 -0700 Subject: [PATCH 20/42] merge import blocks --- examples/demo_dlmbl/solution.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index e9c6b6b8..fdbd1b5a 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -108,7 +108,7 @@ - Start training the model to predict nuclei and membrane from phase. """ -# %% Imports and paths +# %% Imports import os from pathlib import Path @@ -124,7 +124,6 @@ from lightning.pytorch.loggers import TensorBoardLogger from skimage import metrics # for metrics. -# %% Imports and paths # pytorch lightning wrapper for Tensorboard. from torch.utils.tensorboard import SummaryWriter # for logging to tensorboard @@ -146,6 +145,8 @@ RandWeightedCropd, ) +# %% +# seed random number generators for reproducibility. seed_everything(42, workers=True) # Paths to data and log directory @@ -161,6 +162,7 @@ raise FileNotFoundError( f"Data not found at {data_path}. Please check the top_dir and data_path variables." ) + # %% # Create log directory if needed, and launch tensorboard log_dir.mkdir(parents=True, exist_ok=True) From db68932b082a69562f7d6f64af5dbe8f52001cb1 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 17 Jul 2024 13:34:34 -0700 Subject: [PATCH 21/42] add line break --- examples/demo_dlmbl/solution.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index fdbd1b5a..cecf24ac 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -209,7 +209,8 @@ def launch_tensorboard(log_dir): tensorboard_cmd = f"tensorboard --logdir={log_dir} --port={port}" process = subprocess.Popen(tensorboard_cmd, shell=True) print( - f"TensorBoard started at http://localhost:{port}. \n If you are using VSCode remote session, forward the port using the PORTS tab next to TERMINAL." + f"TensorBoard started at http://localhost:{port}. \n" + "If you are using VSCode remote session, forward the port using the PORTS tab next to TERMINAL." ) return process From 2384767b7422235e78a2f157908e4dd9b6d377dd Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 17 Jul 2024 13:48:53 -0700 Subject: [PATCH 22/42] tweak typesetting --- examples/demo_dlmbl/solution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index cecf24ac..2df3636c 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -466,7 +466,7 @@ def log_batch_jupyter(batch): # Add a Gaussian noise with a mean of 0.0 and standard deviation of 0.3 with a probability of 50%. # # HINT: `RandAffined()` and `RandGaussianNoised()` from `viscy.transforms` [here](https://github.com/mehta-lab/VisCy/blob/main/viscy/transforms.py). -# *Note* these are MONAI transforms that have been redefined for VisCy. +# *Note these are MONAI transforms that have been redefined for VisCy.* # Can you tell what augmentation were applied from looking at the augmented images in Tensorboard? From ed6c21ddfd6ef4f951af74d4e288d4442370645e Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 17 Jul 2024 13:51:57 -0700 Subject: [PATCH 23/42] use the actual pi and more line breaks --- examples/demo_dlmbl/solution.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index 2df3636c..867519f2 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -461,18 +461,18 @@ def log_batch_jupyter(batch): # # How do they make the model more robust? # -# Add augmentations to rotate about `pi` along z-axis, 30% scale in y,x, shearing of 10% and no padding with zeros with a probablity of 80%. +# Add augmentations to rotate about $\pi$ along z-axis, 30% scale in y,x, +# shearing of 10% and no padding with zeros with a probablity of 80%. # # Add a Gaussian noise with a mean of 0.0 and standard deviation of 0.3 with a probability of 50%. # -# HINT: `RandAffined()` and `RandGaussianNoised()` from `viscy.transforms` [here](https://github.com/mehta-lab/VisCy/blob/main/viscy/transforms.py). +# HINT: `RandAffined()` and `RandGaussianNoised()` are from +# `viscy.transforms` [here](https://github.com/mehta-lab/VisCy/blob/main/viscy/transforms.py). # *Note these are MONAI transforms that have been redefined for VisCy.* - # Can you tell what augmentation were applied from looking at the augmented images in Tensorboard? - -# # -# HINT:[Compare your choice of augmentations here](https://github.com/mehta-lab/VisCy/blob/b89f778b34735553cf155904eef134c756708ff2/viscy/light/data.py#L529). +# HINT: +# [Compare your choice of augmentations here](https://github.com/mehta-lab/VisCy/blob/b89f778b34735553cf155904eef134c756708ff2/viscy/light/data.py#L529). #
# %% # Here we turn on data augmentation and rerun setup From 15199f4c172fff9d4d8a7bc965eb4e109bf10040 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 17 Jul 2024 13:52:25 -0700 Subject: [PATCH 24/42] rotate 'around' --- examples/demo_dlmbl/solution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index 867519f2..ed45f9c4 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -461,7 +461,7 @@ def log_batch_jupyter(batch): # # How do they make the model more robust? # -# Add augmentations to rotate about $\pi$ along z-axis, 30% scale in y,x, +# Add augmentations to rotate about $\pi$ around z-axis, 30% scale in y,x, # shearing of 10% and no padding with zeros with a probablity of 80%. # # Add a Gaussian noise with a mean of 0.0 and standard deviation of 0.3 with a probability of 50%. From 3fb5c21870ac5efeb520ce32325e3e1e6686c6c4 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 17 Jul 2024 13:54:57 -0700 Subject: [PATCH 25/42] addaptive figure size --- examples/demo_dlmbl/solution.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index ed45f9c4..e02d0e83 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -393,7 +393,9 @@ def log_batch_jupyter(batch): batch_phase = np.clip((batch_phase - p1) / (p99 - p1), 0, 1) plt.figure() - fig, axes = plt.subplots(batch_size, n_channels, figsize=(10, 10)) + fig, axes = plt.subplots( + batch_size, n_channels, figsize=(n_channels * 2, batch_size * 2) + ) [N, C, H, W] = batch_phase.shape for sample_id in range(batch_size): axes[sample_id, 0].imshow(batch_phase[sample_id, 0]) From 9cc5933fdb6d3c0831d26ccd05317e3a406d871f Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 17 Jul 2024 14:15:29 -0700 Subject: [PATCH 26/42] remove division --- examples/demo_dlmbl/solution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index e02d0e83..127b6563 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -63,7 +63,7 @@ [PyTorch Lightning](https://lightning.ai/) and [MONAI](https://monai.io/). ### References ---- + - [Liu, Z. and Hirata-Miyasaki, E. et al. (2024) Robust Virtual Staining of Cellular Landmarks](https://www.biorxiv.org/content/10.1101/2024.05.31.596901v2.full.pdf) - [Guo et al. (2020) Revealing architectural order with quantitative label-free imaging and deep learning. eLife](https://elifesciences.org/articles/55502) """ From 46222a31cc00cc68862b4d2564a140546374972d Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 17 Jul 2024 14:19:42 -0700 Subject: [PATCH 27/42] FOVs -> samples --- examples/demo_dlmbl/solution.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index 127b6563..764e7554 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -434,7 +434,8 @@ def log_batch_jupyter(batch): data_module.setup("fit") print( - f"FOVs in training set: {len(data_module.train_dataset)}, FOVs in validation set:{len(data_module.val_dataset)}" + f"Samples in training set: {len(data_module.train_dataset)}, " + f"samples in validation set:{len(data_module.val_dataset)}" ) train_dataloader = data_module.train_dataloader() From 8e86fa3e094c60ae6d8050ed5ac86b32f9cbce76 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 17 Jul 2024 14:21:06 -0700 Subject: [PATCH 28/42] line breaks --- examples/demo_dlmbl/solution.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index 764e7554..da1653a9 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -600,7 +600,15 @@ def log_batch_jupyter(batch): # %% [markdown] # ## View model graph. # -# PyTorch uses dynamic graphs under the hood. The graphs are constructed on the fly. This is in contrast to TensorFlow, where the graph is constructed before the training loop and remains static. In other words, the graph of the network can change with every forward pass. Therefore, we need to supply an input tensor to construct the graph. The input tensor can be a random tensor of the correct shape and type. We can also supply a real image from the dataset. The latter is more useful for debugging. +# PyTorch uses dynamic graphs under the hood. +# The graphs are constructed on the fly. +# This is in contrast to TensorFlow, +# where the graph is constructed before the training loop and remains static. +# In other words, the graph of the network can change with every forward pass. +# Therefore, we need to supply an input tensor to construct the graph. +# The input tensor can be a random tensor of the correct shape and type. +# We can also supply a real image from the dataset. +# The latter is more useful for debugging. # %% [markdown] #
From b6646f5acbdd8c1c2ceb20f5b50203994aa2aea4 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 17 Jul 2024 14:41:16 -0700 Subject: [PATCH 29/42] remove duplicate function and more line breaks --- examples/demo_dlmbl/solution.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index da1653a9..c15ad1a4 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -614,7 +614,8 @@ def log_batch_jupyter(batch): #
# # ### Task 1.4 -# Run the next cell to generate a graph representation of the model architecture. Can you recognize the UNet structure and skip connections in this graph visualization? +# Run the next cell to generate a graph representation of the model architecture. +# Can you recognize the UNet structure and skip connections in this graph visualization? #
# %% @@ -679,13 +680,16 @@ def log_batch_jupyter(batch): # %% [markdown] """ ## Part 2: Assess previous model, train fluorescence to phase contrast translation model. --------------------------------------------------- -We now look at some metrics of performance of previous model. We typically evaluate the model performance on a held out test data. We will use the following metrics to evaluate the accuracy of regression of the model: +We now look at some metrics of performance of previous model. +We typically evaluate the model performance on a held out test data. +We will use the following metrics to evaluate the accuracy of regression of the model: + - [Person Correlation](https://en.wikipedia.org/wiki/Pearson_correlation_coefficient). - [Structural similarity](https://en.wikipedia.org/wiki/Structural_similarity) (SSIM). -You should also look at the validation samples on tensorboard (hint: the experimental data in nuclei channel is imperfect.) +You should also look at the validation samples on tensorboard +(hint: the experimental data in nuclei channel is imperfect.) """ # %% [markdown] @@ -694,7 +698,8 @@ def log_batch_jupyter(batch): Task 2.1 Define metrics
-For each of the above metrics, write a brief definition of what they are and what they mean for this image translation task. +For each of the above metrics, write a brief definition of what they are and what they mean +for this image translation task.
""" @@ -729,11 +734,6 @@ def log_batch_jupyter(batch): columns=["pearson_nuc", "SSIM_nuc", "pearson_mem", "SSIM_mem"] ) - -def min_max_scale(input): - return (input - np.min(input)) / (np.max(input) - np.min(input)) - - # %% Compute metrics directly and plot here. for i, sample in enumerate(test_data.test_dataloader()): phase_image = sample["source"] From 74172774b6ea3e86c4fae680fd117b6cf79f54c9 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 17 Jul 2024 14:42:36 -0700 Subject: [PATCH 30/42] send data to device --- examples/demo_dlmbl/solution.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index c15ad1a4..f716e7aa 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -784,7 +784,12 @@ def log_batch_jupyter(batch): axes[0, 0].set_title(channel_titles[0]) with torch.inference_mode(): # turn off gradient computation. - predicted_image = phase2fluor_model(phase_image).cpu().numpy().squeeze(0) + predicted_image = ( + phase2fluor_model(phase_image.to(phase2fluor_model.device)) + .cpu() + .numpy() + .squeeze(0) + ) target_image = sample["target"].cpu().numpy().squeeze(0) # Plot the predicted images From 33caa8716d0b63bedbcf9582b9aa965c2fdb047a Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 17 Jul 2024 14:48:01 -0700 Subject: [PATCH 31/42] fix multisamping in example data for graphing --- examples/demo_dlmbl/solution.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index f716e7aa..9a8b042c 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -785,11 +785,7 @@ def log_batch_jupyter(batch): with torch.inference_mode(): # turn off gradient computation. predicted_image = ( - phase2fluor_model(phase_image.to(phase2fluor_model.device)) - .cpu() - .numpy() - .squeeze(0) - ) + phase2fluor_model(phase_image).cpu().numpy().squeeze(0) target_image = sample["target"].cpu().numpy().squeeze(0) # Plot the predicted images @@ -946,7 +942,7 @@ def log_batch_jupyter(batch): # Visualize the graph of fluor2phase model as image. model_graph_fluor2phase = torchview.draw_graph( fluor2phase_model, - fluor2phase_data.train_dataset[0]["source"], + next(iter(fluor2phase_data.train_dataloader()))["source"], depth=2, # adjust depth to zoom in. device="cpu", ) From 818965ee92be65ddba9ce91b50eac8388c27526b Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 17 Jul 2024 14:52:29 -0700 Subject: [PATCH 32/42] fix definition location --- examples/demo_dlmbl/solution.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index 9a8b042c..7616e06d 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -735,6 +735,12 @@ def log_batch_jupyter(batch): ) # %% Compute metrics directly and plot here. + +def min_max_scale(input): + return (input - np.min(input)) / (np.max(input) - np.min(input)) + + + for i, sample in enumerate(test_data.test_dataloader()): phase_image = sample["source"] with torch.inference_mode(): # turn off gradient computation. @@ -1006,10 +1012,6 @@ def log_batch_jupyter(batch): test_metrics = pd.DataFrame(columns=["pearson_phase", "SSIM_phase"]) -def min_max_scale(input): - return (input - np.min(input)) / (np.max(input) - np.min(input)) - - # %% for i, sample in enumerate(test_data.test_dataloader()): source_image = sample["source"] From 0a387f2136af76d11c339380fe34b9d72bca7f50 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 17 Jul 2024 14:59:44 -0700 Subject: [PATCH 33/42] fix device and markdown block --- examples/demo_dlmbl/solution.py | 44 +++++++++++++++++++-------------- 1 file changed, 26 insertions(+), 18 deletions(-) diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index 7616e06d..1227eaa4 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -736,11 +736,11 @@ def log_batch_jupyter(batch): # %% Compute metrics directly and plot here. + def min_max_scale(input): return (input - np.min(input)) / (np.max(input) - np.min(input)) - for i, sample in enumerate(test_data.test_dataloader()): phase_image = sample["source"] with torch.inference_mode(): # turn off gradient computation. @@ -791,7 +791,11 @@ def min_max_scale(input): with torch.inference_mode(): # turn off gradient computation. predicted_image = ( - phase2fluor_model(phase_image).cpu().numpy().squeeze(0) + phase2fluor_model(phase_image.to(phase2fluor_model.device)) + .cpu() + .numpy() + .squeeze(0) + ) target_image = sample["target"].cpu().numpy().squeeze(0) # Plot the predicted images @@ -1016,7 +1020,7 @@ def min_max_scale(input): for i, sample in enumerate(test_data.test_dataloader()): source_image = sample["source"] with torch.inference_mode(): # turn off gradient computation. - predicted_image = fluor2phase_model(source_image) + predicted_image = fluor2phase_model(source_image.to(fluor2phase_model.device)) target_image = ( sample["target"].cpu().numpy().squeeze(0) @@ -1062,7 +1066,12 @@ def min_max_scale(input): axes[0].set_title(channel_titles[0]) with torch.inference_mode(): # turn off gradient computation. - predicted_image = phase2fluor_model(phase_image).cpu().numpy().squeeze(0) + predicted_image = ( + phase2fluor_model(phase_image.to(phase2fluor_model.device)) + .cpu() + .numpy() + .squeeze(0) + ) target_image = sample["target"].cpu().numpy().squeeze(0) # Plot the predicted images @@ -1106,21 +1115,20 @@ def min_max_scale(input): # %% [markdown] tags=[] -""" -
- -### Extra Part - -- Choose a model you want to train (phase2fluor or fluor2phase). -- Set up a configuration that you think will improve the performance of the model -- Consider modifying the learning rate and see how it changes performance -- Use training loop illustrated in previous cells to train phase2fluor and fluor2phase models to prototype your own training loop. -- Add code to evaluate the model using Pearson Correlation and SSIM - -As your model is training, please document hyperparameters, snapshots of predictions on validation set, and loss curves for your models in [this google doc](https://docs.google.com/document/d/1Mq-yV8FTG02xE46Mii2vzPJVYSRNdeOXkeU-EKu-irE/edit?usp=sharing) +#
+# +# ### Extra Part +# +# - Choose a model you want to train (phase2fluor or fluor2phase). +# - Set up a configuration that you think will improve the performance of the model +# - Consider modifying the learning rate and see how it changes performance +# - Use training loop illustrated in previous cells to train phase2fluor and fluor2phase models to prototype your own training loop. +# - Add code to evaluate the model using Pearson Correlation and SSIM +# As your model is training, please document hyperparameters, snapshots of predictions on validation set, +# and loss curves for your models in +# [this google doc](https://docs.google.com/document/d/1Mq-yV8FTG02xE46Mii2vzPJVYSRNdeOXkeU-EKu-irE/edit?usp=sharing) +#
-
-""" # %% tags=[] ########################## ######## TODO ######## From 97937e96e9b5f45ec2c87cdf17feece8fee9024f Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Wed, 17 Jul 2024 16:47:56 -0700 Subject: [PATCH 34/42] reordering variable that prevented fluor2phase not to run if previous cells were not run --- examples/demo_dlmbl/solution.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index 1227eaa4..bdadf8b2 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -872,7 +872,6 @@ def min_max_scale(input): YX_PATCH_SIZE = (256, 256) BATCH_SIZE = 12 n_epochs = 50 -steps_per_epoch = n_samples // BATCH_SIZE # steps per epoch. # Setup the new augmentations augmentations = [ @@ -927,6 +926,10 @@ def min_max_scale(input): ) fluor2phase_data.setup("fit") +n_samples = len(fluor2phase_data.train_dataset) + +steps_per_epoch = n_samples // BATCH_SIZE # steps per epoch. + # Dictionary that specifies key parameters of the model. fluor2phase_config = dict( in_channels=1, @@ -953,7 +956,7 @@ def min_max_scale(input): model_graph_fluor2phase = torchview.draw_graph( fluor2phase_model, next(iter(fluor2phase_data.train_dataloader()))["source"], - depth=2, # adjust depth to zoom in. + depth=3, # adjust depth to zoom in. device="cpu", ) model_graph_fluor2phase.visual_graph From 1efa4332bade7a24f9396b85b190bbb58c4bcf62 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Thu, 18 Jul 2024 11:21:02 -0700 Subject: [PATCH 35/42] fixing vscode and juypter readability --- examples/demo_dlmbl/solution.py | 167 ++++++++++++++++++++++++++------ 1 file changed, 136 insertions(+), 31 deletions(-) diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index bdadf8b2..0867beef 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -183,7 +183,6 @@ Locate the your VSCode terminal and select the Ports tab
  • Add a new port with the port_number_assigned -
  • Change the port to 4000 and ensure that the forwarded Adress: localhost:{port_number_assigned}
Click on the link to view the tensorboard and it should open in your browser.
@@ -362,8 +361,6 @@ def log_batch_tensorboard(batch, batchno, writer, card_name): # %% # Define a function to visualize a batch on jupyter, in case tensorboard is finicky - - def log_batch_jupyter(batch): """ Logs a batch of images on jupyter using ipywidget. @@ -412,7 +409,7 @@ def log_batch_jupyter(batch): # %% # Initialize the data module. -BATCH_SIZE = 6 +BATCH_SIZE = 4 # 6 is a perfectly reasonable batch size. After all, it is the answer to the ultimate question of life, the universe and everything. # More seriously, batch size does not have to be a power of 2. @@ -449,12 +446,12 @@ def log_batch_jupyter(batch): # %% [markdown] -# If your tensorboard is causing issues, you can visualize directly on Jupyter ☄️/VSCode +# If your tensorboard is causing issues, you can visualize directly on Jupyter /VSCode # %% log_batch_jupyter(batch) -# %% [markdown] +# %% [markdown] tags=[] #
# # ### Task 1.3 @@ -498,6 +495,10 @@ def log_batch_jupyter(batch): sigma_z=(0.0, 0.0), prob=0.5, ), + + # ####################### + # ##### TODO ######## + # ####################### ##TODO: Add rotation agumentations ## Write code below ## TODO: Add Random Gaussian Noise @@ -508,7 +509,7 @@ def log_batch_jupyter(batch): NormalizeSampled( keys=source_channel + target_channel, level="fov_statistics", - subtrahend="median", + subtrahend="mean", divisor="std", ) ] @@ -525,6 +526,63 @@ def log_batch_jupyter(batch): log_batch_tensorboard(augmented_batch, 0, writer, "augmentation/some") writer.close() +#%% tags=["solution"] +# ####################### +# ##### SOLUTION ######## +# ####################### +source_channel = ["Phase3D"] +target_channel = ["Nucl", "Mem"] + +augmentations = [ + RandWeightedCropd( + keys=source_channel + target_channel, + spatial_size=(1, 384, 384), + num_samples=2, + w_key=target_channel[0], + ), + RandAffined( + keys=source_channel + target_channel, + rotate_range=[3.14, 0.0, 0.0], + scale_range=[0.0, 0.3, 0.3], + prob=0.8, + padding_mode="zeros", + shear_range=[0.0, 0.01, 0.01], + ), + RandAdjustContrastd(keys=source_channel, prob=0.5, gamma=(0.8, 1.2)), + RandScaleIntensityd(keys=source_channel, factors=0.5, prob=0.5), + RandGaussianNoised(keys=source_channel, prob=0.5, mean=0.0, std=0.3), + RandGaussianSmoothd( + keys=source_channel, + sigma_x=(0.25, 0.75), + sigma_y=(0.25, 0.75), + sigma_z=(0.0, 0.0), + prob=0.5, + ), +] + +normalizations = [ + NormalizeSampled( + keys=source_channel + target_channel, + level="fov_statistics", + subtrahend="mean", + divisor="std", + ) +] + +data_module.augmentations = augmentations +data_module.setup("fit") + +# get the new data loader with augmentation turned on +augmented_train_dataloader = data_module.train_dataloader() + +# Draw batches and write to tensorboard +writer = SummaryWriter(log_dir=f"{log_dir}/view_batch") +augmented_batch = next(iter(augmented_train_dataloader)) +log_batch_tensorboard(augmented_batch, 0, writer, "augmentation/some") +writer.close() + + + # %% [markdown] # Visualize directly on Jupyter ☄️ @@ -636,13 +694,14 @@ def log_batch_jupyter(batch): """
-### Task 1.5 +

Task 1.5

Start training by running the following cell. Check the new logs on the tensorboard.
""" # %% # Check if GPU is available +# You can check by typing `nvidia-smi` GPU_ID = 0 n_samples = len(phase2fluor_2D_data.train_dataset) @@ -669,7 +728,7 @@ def log_batch_jupyter(batch): """
-## Checkpoint 1 +

Checkpoint 1

Now the training has started, we can come back after a while and evaluate the performance! @@ -696,10 +755,10 @@ def log_batch_jupyter(batch): """
- Task 2.1 Define metrics
+

Task 2.1 Define metrics

For each of the above metrics, write a brief definition of what they are and what they mean -for this image translation task. +for this image translation task. Use your favorite search engine and/or resources.
""" @@ -709,16 +768,23 @@ def log_batch_jupyter(batch): # ####################### # ##### Todo ############ # ####################### +# # ``` # # - Pearson Correlation: # # - Structural similarity: -# %% Compute metrics directly and plot here. +# %% [markdown] +""" +Let's compute metrics directly and plot below. +""" +# %% +# Setup the test data module. test_data_path = top_dir / "06_image_translation/test/a549_hoechst_cellmask_test.zarr" source_channel = ["Phase3D"] target_channel = ["Nucl", "Mem"] + test_data = HCSDataModule( test_data_path, source_channel=source_channel, @@ -827,7 +893,7 @@ def min_max_scale(input): """
-Task 2.2 Train fluorescence to phase contrast translation model
+

Task 2.2 Train fluorescence to phase contrast translation model

Instantiate a data module, model, and trainer for fluorescence to phase contrast translation. Copy over the code from previous cells and update the parameters. Give the variables and paths a different name/suffix (fluor2phase) to avoid overwriting objects used to train phase2fluor models.
@@ -905,7 +971,7 @@ def min_max_scale(input): NormalizeSampled( keys=source_channel + target_channel, level="fov_statistics", - subtrahend="median", + subtrahend="mean", divisor="std", ) ] @@ -992,7 +1058,7 @@ def min_max_scale(input): """
-Task 2.3
+

Task 2.3

While your model is training, let's think about the following questions: - What is the information content of each channel in the dataset? @@ -1107,20 +1173,27 @@ def min_max_scale(input): """ -# %% tags=[] +# %% [markdown] tags=[] """ -## (Extra)Tune the models and explore other architectures from [VisCy](https://github.com/mehta-lab/VisCy/tree/main/examples/demos) --------------------------------------------------- -Learning goals: -- Understand how data, model capacity, and training parameters control the performance of the model. Your goal is to try to underfit or overfit the model. -- How can we scale it up from 2D to 3D training and predictions? +
+ +

Extra exercises

+Tune the models and explore other architectures from VisCy +
+

Learning goals:

+
    +
  • Understand how data, model capacity, and training parameters control the performance of the model. Your goal is to try to underfit or overfit the model.
  • +
  • How can we scale it up from 2D to 3D training and predictions?
  • +
+
+ """ # %% [markdown] tags=[] #
# -# ### Extra Part +# ### Extra Example 1: Hyperparameter tuning # # - Choose a model you want to train (phase2fluor or fluor2phase). # - Set up a configuration that you think will improve the performance of the model @@ -1208,6 +1281,33 @@ def min_max_scale(input): fast_dev_run=True, ) # Set fast_dev_run to False to train the model. trainer.fit(phase2fluor_model_low_lr, datamodule=phase2fluor_2D_data) +# %% [markdown] +""" +
+

+Extra Example 2: 3D Virtual Staining +

+Now, let's implement a 3D virtual staining model(Phase->Fluorescence)
+Note: This task might take longer to train +1 hr. Try it out in your free-time. + +
+""" + +# %% tags=["task"] +data_path = Path() # TODO: Point to a 3D dataset (HEK, Neuromast) +BATCH_SIZE = 4 +YX_PATCH_SIZE = (256, 256) + +phase2fluor_3D_config = ... + +phase2fluor_3D_data = HCSDataModule(...) + +phase2fluor_3D = VSUNet(...) + +trainer = VSTrainer(...) + +# Start the training +trainer.fit(...) # %% tags=["solution"] @@ -1226,8 +1326,8 @@ def min_max_scale(input): ``` """ - -data_path = Path() # TODO: Point to a 3D dataset (HEK, Neuromast) +# TODO: Point to a 3D dataset (HEK, Neuromast) +data_path = Path("./raw-and-reconstructed.zarr") BATCH_SIZE = 4 YX_PATCH_SIZE = (256, 256) @@ -1235,6 +1335,14 @@ def min_max_scale(input): source_channel = ["Phase3D"] target_channel = ["Nucl", "Mem"] +phase2fluor_3D_config = dict( + in_channels=1, + out_channels=2, + in_stack_depth=5, + backbone="convnextv2_tiny", + deconder_conv_blocks=2, + head_expansion_ratio=4, +) phase2fluor_3D_data = HCSDataModule( data_path, architecture="UNeXt2", @@ -1251,7 +1359,7 @@ def min_max_scale(input): phase2fluor_3D = VSUNet( architecture="UNeXt2", - model_config=phase2fluor_config.copy(), + model_config=phase2fluor_3D_config.copy(), loss_function=MixedLoss(l1_alpha=0.5, l2_alpha=0.0, ms_dssim_alpha=0.5), lr=2e-4, schedule="WarmupCosine", @@ -1286,11 +1394,8 @@ def min_max_scale(input):
Please remember to document the hyperparameters, snapshots of predictions on validation set, -and loss curves for your models and add the final performance in - -this google doc. +and loss curves for your models and add the final performance in +this google doc . We'll discuss our combined results as a group.
-""" - -# %% +""" \ No newline at end of file From 9e317dfa357a60ee2421b5cfa3b73b4bd0db7bde Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 18 Jul 2024 14:13:22 -0700 Subject: [PATCH 36/42] remove old debug script --- examples/demo_dlmbl/debug_log_graph.py | 97 -------------------------- 1 file changed, 97 deletions(-) delete mode 100644 examples/demo_dlmbl/debug_log_graph.py diff --git a/examples/demo_dlmbl/debug_log_graph.py b/examples/demo_dlmbl/debug_log_graph.py deleted file mode 100644 index ec987118..00000000 --- a/examples/demo_dlmbl/debug_log_graph.py +++ /dev/null @@ -1,97 +0,0 @@ - -# %% -# %% Imports and paths - -from pathlib import Path - -import matplotlib.pyplot as plt -import numpy as np -import pandas as pd -import torch -import torchview -import torchvision -from iohub import open_ome_zarr -from lightning.pytorch import seed_everything -from lightning.pytorch.loggers import CSVLogger - -# pytorch lightning wrapper for Tensorboard. -from tensorboard import notebook # for viewing tensorboard in notebook -from torch.utils.tensorboard import SummaryWriter # for logging to tensorboard - -# HCSDataModule makes it easy to load data during training. -from viscy.data.hcs import HCSDataModule - -# Trainer class and UNet. -from viscy.light.engine import VSUNet -from viscy.light.trainer import VSTrainer - -seed_everything(42, workers=True) - -# Paths to data and log directory -data_path = Path( - Path("~/data/04_image_translation/HEK_nuclei_membrane_pyramid.zarr/") -).expanduser() - -log_dir = Path("~/data/04_image_translation/logs/").expanduser() - -# Create log directory if needed, and launch tensorboard -log_dir.mkdir(parents=True, exist_ok=True) - -# fmt: off -%reload_ext tensorboard -%tensorboard --logdir {log_dir} --port 6007 --bind_all -# fmt: on - -# %% The entire training loop is contained in this cell. - -GPU_ID = 0 -BATCH_SIZE = 10 -YX_PATCH_SIZE = (512, 512) - - -# Dictionary that specifies key parameters of the model. -phase2fluor_config = { - "architecture": "2D", - "num_filters": [24, 48, 96, 192, 384], - "in_channels": 1, - "out_channels": 2, - "residual": True, - "dropout": 0.1, # dropout randomly turns off weights to avoid overfitting of the model to data. - "task": "reg", # reg = regression task. -} - -phase2fluor_model = VSUNet( - model_config=phase2fluor_config.copy(), - batch_size=BATCH_SIZE, - loss_function=torch.nn.functional.l1_loss, - schedule="WarmupCosine", - log_num_samples=10, # Number of samples from each batch to log to tensorboard. - example_input_yx_shape=YX_PATCH_SIZE, -) - -# Reinitialize the data module. -phase2fluor_data = HCSDataModule( - data_path, - source_channel="Phase", - target_channel=["Nuclei", "Membrane"], - z_window_size=1, - split_ratio=0.8, - batch_size=BATCH_SIZE, - num_workers=8, - architecture="2D", - yx_patch_size=YX_PATCH_SIZE, - augmentations=None, -) -phase2fluor_data.setup("fit") - - -# Train for 3 epochs to see if you can log graph. -trainer = VSTrainer(accelerator="gpu", devices=[GPU_ID], max_epochs=3, default_root_dir=log_dir) - -# trainer class takes the model and the data module as inputs. -trainer.fit(phase2fluor_model, datamodule=phase2fluor_data) - -# %% Is exmple_input_array present? -print(f'{phase2fluor_model.example_input_array.shape},{phase2fluor_model.example_input_array.dtype}') -trainer.logger.log_graph(phase2fluor_model, phase2fluor_model.example_input_array) -# %% From e18ae175a650fb910c87f7af9925d35a5b0b4801 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 18 Jul 2024 14:16:43 -0700 Subject: [PATCH 37/42] fix syntax --- examples/demo_dlmbl/solution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index 0867beef..d060e6f0 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -231,7 +231,7 @@ def launch_tensorboard(log_dir): - The layout on the disk is: `row/col/field/pyramid_level/timepoint/channel/z/y/x.` """ -# %%[markdown] +# %% [markdown] """
You can inspect the tree structure by using your terminal: From 5db9c62175ca48b82425b184058e2f2fa164af92 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 18 Jul 2024 14:19:39 -0700 Subject: [PATCH 38/42] fix hyperlink --- examples/demo_dlmbl/solution.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index d060e6f0..65e6457d 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -1394,8 +1394,10 @@ def min_max_scale(input):
Please remember to document the hyperparameters, snapshots of predictions on validation set, -and loss curves for your models and add the final performance in -this google doc . +and loss curves for your models and add the final performance in + +this google doc +. We'll discuss our combined results as a group.
""" \ No newline at end of file From 16acdf2df0fec872def813f896600b85c412ce1d Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 18 Jul 2024 14:22:30 -0700 Subject: [PATCH 39/42] fix batch size story --- examples/demo_dlmbl/solution.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index 65e6457d..a171ab5f 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -409,10 +409,10 @@ def log_batch_jupyter(batch): # %% # Initialize the data module. -BATCH_SIZE = 4 +BATCH_SIZE = 5 -# 6 is a perfectly reasonable batch size. After all, it is the answer to the ultimate question of life, the universe and everything. -# More seriously, batch size does not have to be a power of 2. +# 5 is a perfectly reasonable batch size +# (batch size does not have to be a power of 2) # See: https://sebastianraschka.com/blog/2022/batch-size-2.html data_module = HCSDataModule( From 15003664a9a6e0ca5cf425294a59a0a8d6939b4b Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Thu, 18 Jul 2024 14:23:50 -0700 Subject: [PATCH 40/42] remove emojis --- examples/demo_dlmbl/solution.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index a171ab5f..6618a3f4 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -495,7 +495,6 @@ def log_batch_jupyter(batch): sigma_z=(0.0, 0.0), prob=0.5, ), - # ####################### # ##### TODO ######## # ####################### @@ -526,7 +525,7 @@ def log_batch_jupyter(batch): log_batch_tensorboard(augmented_batch, 0, writer, "augmentation/some") writer.close() -#%% tags=["solution"] +# %% tags=["solution"] # ####################### # ##### SOLUTION ######## # ####################### @@ -582,9 +581,8 @@ def log_batch_jupyter(batch): writer.close() - # %% [markdown] -# Visualize directly on Jupyter ☄️ +# Visualize directly on Jupyter # %% log_batch_jupyter(augmented_batch) @@ -768,7 +766,7 @@ def log_batch_jupyter(batch): # ####################### # ##### Todo ############ # ####################### -# +# # ``` # # - Pearson Correlation: @@ -1400,4 +1398,4 @@ def min_max_scale(input): . We'll discuss our combined results as a group.
-""" \ No newline at end of file +""" From 3826087e816719ad36653995987982013bfd4b15 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Fri, 19 Jul 2024 09:15:28 -0700 Subject: [PATCH 41/42] making 3D_UNeXt2 example work --- examples/demo_dlmbl/solution.py | 57 +++++++++++++++++++++++++++++---- 1 file changed, 50 insertions(+), 7 deletions(-) diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index 6618a3f4..7c0d6e28 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -1327,19 +1327,59 @@ def min_max_scale(input): # TODO: Point to a 3D dataset (HEK, Neuromast) data_path = Path("./raw-and-reconstructed.zarr") BATCH_SIZE = 4 -YX_PATCH_SIZE = (256, 256) +YX_PATCH_SIZE = (384, 384) +GPU_ID = 0 +n_epochs = 50 ## For 3D training - VSCyto3D -source_channel = ["Phase3D"] -target_channel = ["Nucl", "Mem"] +source_channel = ["reconstructed-labelfree"] +target_channel = ["reconstructed-nucleus", "reconstructed-membrane"] + +# Setup the new augmentations +augmentations = [ + RandWeightedCropd( + keys=source_channel + target_channel, + spatial_size=(-1, 384, 384), + num_samples=2, + w_key=target_channel[0], + ), + RandAffined( + keys=source_channel + target_channel, + rotate_range=[3.14, 0.0, 0.0], + scale_range=[0.0, 0.3, 0.3], + prob=0.8, + padding_mode="zeros", + shear_range=[0.0, 0.01, 0.01], + ), + RandAdjustContrastd(keys=source_channel, prob=0.5, gamma=(0.8, 1.2)), + RandScaleIntensityd(keys=source_channel, factors=0.5, prob=0.5), + RandGaussianNoised(keys=source_channel, prob=0.5, mean=0.0, std=0.3), + RandGaussianSmoothd( + keys=source_channel, + sigma_x=(0.25, 0.75), + sigma_y=(0.25, 0.75), + sigma_z=(0.0, 0.0), + prob=0.5, + ), +] + +normalizations = [ + NormalizeSampled( + keys=source_channel + target_channel, + level="fov_statistics", + subtrahend="mean", + divisor="std", + ) +] phase2fluor_3D_config = dict( in_channels=1, out_channels=2, in_stack_depth=5, backbone="convnextv2_tiny", - deconder_conv_blocks=2, + decoder_conv_blocks=2, head_expansion_ratio=4, + stem_kernel_size=(5, 4, 4), ) phase2fluor_3D_data = HCSDataModule( data_path, @@ -1354,6 +1394,10 @@ def min_max_scale(input): augmentations=augmentations, normalizations=normalizations, ) +phase2fluor_3D_data.setup("fit") + +n_samples = len(phase2fluor_3D_data.train_dataset) +steps_per_epoch = n_samples // BATCH_SIZE # steps per epoch. phase2fluor_3D = VSUNet( architecture="UNeXt2", @@ -1371,15 +1415,14 @@ def min_max_scale(input): log_every_n_steps=steps_per_epoch, logger=TensorBoardLogger( save_dir=log_dir, - name="phase2fluor", + name="phase2fluor_3D", version="3D_UNeXt2", log_graph=True, ), - fast_dev_run=True, + fast_dev_run=True, # TODO: Set to False to run full-training ) trainer.fit(phase2fluor_3D, datamodule=phase2fluor_3D_data) - # %% [markdown] tags=[] """
From f6080c887e3031f61ffef3d49c2e337d7210b763 Mon Sep 17 00:00:00 2001 From: Eduardo Hirata-Miyasaki Date: Fri, 19 Jul 2024 09:31:22 -0700 Subject: [PATCH 42/42] pointing to config files from 0.1.0 release --- examples/demo_dlmbl/solution.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/demo_dlmbl/solution.py b/examples/demo_dlmbl/solution.py index 7c0d6e28..553fcf27 100644 --- a/examples/demo_dlmbl/solution.py +++ b/examples/demo_dlmbl/solution.py @@ -472,7 +472,7 @@ def log_batch_jupyter(batch): # Can you tell what augmentation were applied from looking at the augmented images in Tensorboard? # # HINT: -# [Compare your choice of augmentations here](https://github.com/mehta-lab/VisCy/blob/b89f778b34735553cf155904eef134c756708ff2/viscy/light/data.py#L529). +# [Compare your choice of augmentations by dowloading the pretrained models and config files](https://github.com/mehta-lab/VisCy/releases/download/v0.1.0/VisCy-0.1.0-VS-models.zip). #
# %% # Here we turn on data augmentation and rerun setup @@ -1339,7 +1339,7 @@ def min_max_scale(input): augmentations = [ RandWeightedCropd( keys=source_channel + target_channel, - spatial_size=(-1, 384, 384), + spatial_size=(-1, 512, 512), num_samples=2, w_key=target_channel[0], ),