Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: torch cnn 3D unet fixes #275

Open
wants to merge 17 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion aydin/it/cnn_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,11 @@ def _get_model_class_from_string(model_name):
"aydin.nn.models" + '.' + module_of_interest.name
)

class_name = [x for x in dir(response) if model_name + "model" in x.lower()][0]
class_name = [
x
for x in dir(response)
if model_name.replace('_', '') + "model" in x.lower()
][0]

model_class = response.__getattribute__(class_name)

Expand Down
27 changes: 23 additions & 4 deletions aydin/it/test/test_cnn.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,37 @@
import numpy

from aydin.analysis.image_metrics import calculate_print_psnr_ssim
from aydin.io.datasets import camera, add_noise, normalise
from aydin.io.datasets import (
add_noise,
normalise,
examples_single,
camera,
)
from aydin.it.cnn_torch import ImageTranslatorCNNTorch


def test_it_cnn_jinet2D_light():
train_and_evaluate_cnn(camera(), model="jinet")


def test_it_cnn_jinet3D_light():
train_and_evaluate_cnn(
examples_single.myers_tribolium.get_array()[16:48, 300:332, 300:332],
model="jinet",
)


def test_it_cnn_unet2d():
train_and_evaluate_cnn(camera(), model="unet")


def test_it_cnn_unet3d():
train_and_evaluate_cnn(
examples_single.janelia_flybrain.get_array()[:32, 1:2, :32, :32],
model="unet",
)


def train_and_evaluate_cnn(input_image, model="jinet"):
"""
Demo for self-supervised denoising using camera image with synthetic noise
Expand All @@ -28,9 +47,9 @@ def train_and_evaluate_cnn(input_image, model="jinet"):
it.train(noisy, noisy)
denoised = it.translate(noisy, tile_size=image.shape[0])

image = numpy.clip(image, 0, 1)
noisy = numpy.clip(noisy.reshape(image.shape), 0, 1)
denoised = numpy.clip(denoised, 0, 1)
image = numpy.squeeze(numpy.clip(image, 0, 1))
noisy = numpy.squeeze(numpy.clip(noisy.reshape(image.shape), 0, 1))
denoised = numpy.squeeze(numpy.clip(denoised, 0, 1))

psnr_noisy, psnr_denoised, ssim_noisy, ssim_denoised = calculate_print_psnr_ssim(
image, noisy, denoised
Expand Down
26 changes: 23 additions & 3 deletions aydin/nn/datasets/random_masked_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,33 @@ def interpolate_mask(self, tensor, mask, mask_inv):
mask = mask.to(device)
mask_inv = mask_inv.to(device)

kernel = numpy.array([[0.5, 1.0, 0.5], [1.0, 0.0, 1.0], (0.5, 1.0, 0.5)])
if len(self.image.shape) == 4:
kernel = numpy.array([[0.5, 1.0, 0.5], [1.0, 0.0, 1.0], [0.5, 1.0, 0.5]])
elif len(self.image.shape) == 5:
kernel = numpy.array(
[
[[0.5, 0.5, 0.5], [0.5, 1.0, 0.5], [0.5, 0.5, 0.5]],
[[0.5, 1.0, 0.5], [1.0, 0.0, 1.0], [0.5, 1.0, 0.5]],
[[0.5, 0.5, 0.5], [0.5, 1.0, 0.5], [0.5, 0.5, 0.5]],
]
)
else:
raise ValueError("Image can only have 2 or 3 spacetime dimensions...")

kernel = kernel[numpy.newaxis, numpy.newaxis, :, :]
kernel = torch.Tensor(kernel).to(device)
kernel = kernel / kernel.sum()

filtered_tensor = torch.nn.functional.conv2d(
tensor, kernel, stride=1, padding=1
conv_method = (
torch.nn.functional.conv2d
if len(self.image.shape) == 4
else torch.nn.functional.conv3d
)

filtered_tensor = conv_method(
tensor,
kernel,
padding=1,
)

return filtered_tensor * mask + tensor * mask_inv
Expand Down
2 changes: 1 addition & 1 deletion aydin/nn/models/jinet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def __init__(
nb_dense_layers: int = 3,
nb_channels: int = None,
final_relu: bool = False,
degressive_residuals: bool = False, # TODO: check what happens when this is True
degressive_residuals: bool = True, # TODO: check what happens when this is True
):
super(JINetModel, self).__init__()

Expand Down
2 changes: 0 additions & 2 deletions aydin/nn/models/linear_scaling_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,13 @@ def __init__(
spacetime_ndim,
nb_unet_levels: int = 4,
nb_filters: int = 8,
learning_rate=0.01,
pooling_mode: str = 'max',
):
super(LinearScalingUNetModel, self).__init__()

self.spacetime_ndim = spacetime_ndim
self.nb_unet_levels = nb_unet_levels
self.nb_filters = nb_filters
self.learning_rate = learning_rate
self.pooling_down = PoolingDown(spacetime_ndim, pooling_mode)
self.upsampling = nn.Upsample(scale_factor=2, mode='nearest')

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,13 @@ def __init__(
spacetime_ndim,
nb_unet_levels: int = 4,
nb_filters: int = 8,
learning_rate=0.01,
pooling_mode: str = 'max',
):
super(ResidualUNetModel, self).__init__()

self.spacetime_ndim = spacetime_ndim
self.nb_unet_levels = nb_unet_levels
self.nb_filters = nb_filters
self.learning_rate = learning_rate
self.pooling_down = PoolingDown(spacetime_ndim, pooling_mode)
self.upsampling = nn.Upsample(scale_factor=2, mode='nearest')

Expand Down
2 changes: 1 addition & 1 deletion aydin/nn/models/test/test_residual_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import pytest
import torch

from aydin.nn.models.res_unet import ResidualUNetModel
from aydin.nn.models.residual_unet import ResidualUNetModel


@pytest.mark.parametrize("nb_unet_levels", [2, 3, 5, 8])
Expand Down
2 changes: 1 addition & 1 deletion aydin/nn/models/test/test_training_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import torch

from aydin.io.datasets import add_noise, camera, normalise
from aydin.nn.models.res_unet import ResidualUNetModel
from aydin.nn.models.residual_unet import ResidualUNetModel
from aydin.nn.models.unet import UNetModel
from aydin.nn.training_methods.n2s import n2s_train
from aydin.nn.training_methods.n2t import n2t_train
Expand Down
2 changes: 0 additions & 2 deletions aydin/nn/models/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,13 @@ def __init__(
spacetime_ndim,
nb_unet_levels: int = 3,
nb_filters: int = 8,
learning_rate=0.01,
pooling_mode: str = 'max',
):
super(UNetModel, self).__init__()

self.spacetime_ndim = spacetime_ndim
self.nb_unet_levels = nb_unet_levels
self.nb_filters = nb_filters
self.learning_rate = learning_rate
self.pooling_down = PoolingDown(spacetime_ndim, pooling_mode)
self.upsampling = nn.Upsample(scale_factor=2, mode='nearest')

Expand Down
4 changes: 2 additions & 2 deletions aydin/nn/training_methods/n2s.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def n2s_train(
nb_epochs: int = 128,
lr: float = 0.001,
# patch_size: int = 32,
patience: int = 128,
patience: int = 4,
verbose: bool = True,
):
"""
Expand All @@ -39,7 +39,7 @@ def n2s_train(
model = model.to(device)
print(f"device {device}")

optimizer = AdamW(model.parameters(), lr=lr)
optimizer = AdamW(model.parameters(), lr=lr, weight_decay=0)

# optimizer = ESAdam(
# chain(model.parameters()),
Expand Down