diff --git a/aydin/it/cnn_torch.py b/aydin/it/cnn_torch.py
index 8d5d1e89..986768ef 100644
--- a/aydin/it/cnn_torch.py
+++ b/aydin/it/cnn_torch.py
@@ -120,7 +120,11 @@ def _get_model_class_from_string(model_name):
         "aydin.nn.models" + '.' + module_of_interest.name
     )
 
-    class_name = [x for x in dir(response) if model_name + "model" in x.lower()][0]
+    class_name = [
+        x
+        for x in dir(response)
+        if model_name.replace('_', '') + "model" in x.lower()
+    ][0]
 
     model_class = response.__getattribute__(class_name)
 
diff --git a/aydin/it/test/test_cnn.py b/aydin/it/test/test_cnn.py
index 738c3ade..baec946b 100644
--- a/aydin/it/test/test_cnn.py
+++ b/aydin/it/test/test_cnn.py
@@ -1,7 +1,12 @@
 import numpy
 
 from aydin.analysis.image_metrics import calculate_print_psnr_ssim
-from aydin.io.datasets import camera, add_noise, normalise
+from aydin.io.datasets import (
+    add_noise,
+    normalise,
+    examples_single,
+    camera,
+)
 from aydin.it.cnn_torch import ImageTranslatorCNNTorch
 
 
@@ -9,10 +14,24 @@ def test_it_cnn_jinet2D_light():
     train_and_evaluate_cnn(camera(), model="jinet")
 
 
+def test_it_cnn_jinet3D_light():
+    train_and_evaluate_cnn(
+        examples_single.myers_tribolium.get_array()[16:48, 300:332, 300:332],
+        model="jinet",
+    )
+
+
 def test_it_cnn_unet2d():
     train_and_evaluate_cnn(camera(), model="unet")
 
 
+def test_it_cnn_unet3d():
+    train_and_evaluate_cnn(
+        examples_single.janelia_flybrain.get_array()[:32, 1:2, :32, :32],
+        model="unet",
+    )
+
+
 def train_and_evaluate_cnn(input_image, model="jinet"):
     """
     Demo for self-supervised denoising using camera image with synthetic noise
@@ -28,9 +47,9 @@ def train_and_evaluate_cnn(input_image, model="jinet"):
     it.train(noisy, noisy)
     denoised = it.translate(noisy, tile_size=image.shape[0])
 
-    image = numpy.clip(image, 0, 1)
-    noisy = numpy.clip(noisy.reshape(image.shape), 0, 1)
-    denoised = numpy.clip(denoised, 0, 1)
+    image = numpy.squeeze(numpy.clip(image, 0, 1))
+    noisy = numpy.squeeze(numpy.clip(noisy.reshape(image.shape), 0, 1))
+    denoised = numpy.squeeze(numpy.clip(denoised, 0, 1))
 
     psnr_noisy, psnr_denoised, ssim_noisy, ssim_denoised = calculate_print_psnr_ssim(
         image, noisy, denoised
diff --git a/aydin/nn/datasets/random_masked_dataset.py b/aydin/nn/datasets/random_masked_dataset.py
index d6788911..3fa62c32 100644
--- a/aydin/nn/datasets/random_masked_dataset.py
+++ b/aydin/nn/datasets/random_masked_dataset.py
@@ -52,13 +52,33 @@ def interpolate_mask(self, tensor, mask, mask_inv):
         mask = mask.to(device)
         mask_inv = mask_inv.to(device)
 
-        kernel = numpy.array([[0.5, 1.0, 0.5], [1.0, 0.0, 1.0], (0.5, 1.0, 0.5)])
+        if len(self.image.shape) == 4:
+            kernel = numpy.array([[0.5, 1.0, 0.5], [1.0, 0.0, 1.0], [0.5, 1.0, 0.5]])
+        elif len(self.image.shape) == 5:
+            kernel = numpy.array(
+                [
+                    [[0.5, 0.5, 0.5], [0.5, 1.0, 0.5], [0.5, 0.5, 0.5]],
+                    [[0.5, 1.0, 0.5], [1.0, 0.0, 1.0], [0.5, 1.0, 0.5]],
+                    [[0.5, 0.5, 0.5], [0.5, 1.0, 0.5], [0.5, 0.5, 0.5]],
+                ]
+            )
+        else:
+            raise ValueError("Image can only have 2 or 3 spacetime dimensions...")
+
         kernel = kernel[numpy.newaxis, numpy.newaxis, :, :]
         kernel = torch.Tensor(kernel).to(device)
         kernel = kernel / kernel.sum()
 
-        filtered_tensor = torch.nn.functional.conv2d(
-            tensor, kernel, stride=1, padding=1
+        conv_method = (
+            torch.nn.functional.conv2d
+            if len(self.image.shape) == 4
+            else torch.nn.functional.conv3d
+        )
+
+        filtered_tensor = conv_method(
+            tensor,
+            kernel,
+            padding=1,
         )
 
         return filtered_tensor * mask + tensor * mask_inv
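The interpolate_mask change above replaces masked (blind-spot) pixels with a normalised average of their neighbours, picking a 2D or 3D kernel and conv2d/conv3d depending on whether the batch tensor is 4D or 5D. Below is a minimal, self-contained sketch of that idea; the function name interpolate_masked_pixels and the toy shapes are assumptions for illustration, not part of the dataset class.

import numpy
import torch


def interpolate_masked_pixels(tensor, mask, spacetime_ndim=2):
    # Replace masked (blind-spot) pixels with a normalised average of their
    # neighbours; kernel and convolution are chosen per dimensionality, as above.
    if spacetime_ndim == 2:
        kernel = numpy.array([[0.5, 1.0, 0.5], [1.0, 0.0, 1.0], [0.5, 1.0, 0.5]])
        conv = torch.nn.functional.conv2d
    elif spacetime_ndim == 3:
        kernel = numpy.array(
            [
                [[0.5, 0.5, 0.5], [0.5, 1.0, 0.5], [0.5, 0.5, 0.5]],
                [[0.5, 1.0, 0.5], [1.0, 0.0, 1.0], [0.5, 1.0, 0.5]],
                [[0.5, 0.5, 0.5], [0.5, 1.0, 0.5], [0.5, 0.5, 0.5]],
            ]
        )
        conv = torch.nn.functional.conv3d
    else:
        raise ValueError("Only 2 or 3 spacetime dimensions are supported")

    # Two leading axes turn the kernel into a (out_channels, in_channels, ...) weight.
    kernel = torch.Tensor(kernel[numpy.newaxis, numpy.newaxis])
    kernel = kernel / kernel.sum()

    filtered = conv(tensor, kernel, padding=1)
    return filtered * mask + tensor * (1 - mask)


# Example: interpolate a 10%-density random mask on a (batch, channel, height, width) image.
image = torch.rand(1, 1, 32, 32)
mask = (torch.rand(1, 1, 32, 32) < 0.1).float()
interpolated = interpolate_masked_pixels(image, mask, spacetime_ndim=2)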
diff --git a/aydin/nn/models/jinet.py b/aydin/nn/models/jinet.py
index 375840d3..ddebe107 100644
--- a/aydin/nn/models/jinet.py
+++ b/aydin/nn/models/jinet.py
@@ -14,7 +14,7 @@ def __init__(
         nb_dense_layers: int = 3,
         nb_channels: int = None,
         final_relu: bool = False,
-        degressive_residuals: bool = False,  # TODO: check what happens when this is True
+        degressive_residuals: bool = True,  # TODO: check what happens when this is True
     ):
         super(JINetModel, self).__init__()
 
diff --git a/aydin/nn/models/linear_scaling_unet.py b/aydin/nn/models/linear_scaling_unet.py
index bb2dc9d5..d477de27 100644
--- a/aydin/nn/models/linear_scaling_unet.py
+++ b/aydin/nn/models/linear_scaling_unet.py
@@ -11,7 +11,6 @@ def __init__(
         spacetime_ndim,
         nb_unet_levels: int = 4,
         nb_filters: int = 8,
-        learning_rate=0.01,
         pooling_mode: str = 'max',
     ):
         super(LinearScalingUNetModel, self).__init__()
@@ -19,7 +18,6 @@
         self.spacetime_ndim = spacetime_ndim
         self.nb_unet_levels = nb_unet_levels
         self.nb_filters = nb_filters
-        self.learning_rate = learning_rate
 
         self.pooling_down = PoolingDown(spacetime_ndim, pooling_mode)
         self.upsampling = nn.Upsample(scale_factor=2, mode='nearest')
diff --git a/aydin/nn/models/res_unet.py b/aydin/nn/models/residual_unet.py
similarity index 98%
rename from aydin/nn/models/res_unet.py
rename to aydin/nn/models/residual_unet.py
index f82b3c23..74822edf 100644
--- a/aydin/nn/models/res_unet.py
+++ b/aydin/nn/models/residual_unet.py
@@ -11,7 +11,6 @@ def __init__(
         spacetime_ndim,
         nb_unet_levels: int = 4,
         nb_filters: int = 8,
-        learning_rate=0.01,
         pooling_mode: str = 'max',
     ):
         super(ResidualUNetModel, self).__init__()
@@ -19,7 +18,6 @@
         self.spacetime_ndim = spacetime_ndim
         self.nb_unet_levels = nb_unet_levels
         self.nb_filters = nb_filters
-        self.learning_rate = learning_rate
 
         self.pooling_down = PoolingDown(spacetime_ndim, pooling_mode)
         self.upsampling = nn.Upsample(scale_factor=2, mode='nearest')
diff --git a/aydin/nn/models/test/test_residual_unet.py b/aydin/nn/models/test/test_residual_unet.py
index fa36e5d8..36c13a4e 100644
--- a/aydin/nn/models/test/test_residual_unet.py
+++ b/aydin/nn/models/test/test_residual_unet.py
@@ -2,7 +2,7 @@
 import pytest
 import torch
 
-from aydin.nn.models.res_unet import ResidualUNetModel
+from aydin.nn.models.residual_unet import ResidualUNetModel
 
 
 @pytest.mark.parametrize("nb_unet_levels", [2, 3, 5, 8])
diff --git a/aydin/nn/models/test/test_training_models.py b/aydin/nn/models/test/test_training_models.py
index 965edf1b..06839dce 100644
--- a/aydin/nn/models/test/test_training_models.py
+++ b/aydin/nn/models/test/test_training_models.py
@@ -3,7 +3,7 @@
 import torch
 
 from aydin.io.datasets import add_noise, camera, normalise
-from aydin.nn.models.res_unet import ResidualUNetModel
+from aydin.nn.models.residual_unet import ResidualUNetModel
 from aydin.nn.models.unet import UNetModel
 from aydin.nn.training_methods.n2s import n2s_train
 from aydin.nn.training_methods.n2t import n2t_train
diff --git a/aydin/nn/models/unet.py b/aydin/nn/models/unet.py
index c4f280da..ca621f27 100644
--- a/aydin/nn/models/unet.py
+++ b/aydin/nn/models/unet.py
@@ -11,7 +11,6 @@ def __init__(
         spacetime_ndim,
         nb_unet_levels: int = 3,
         nb_filters: int = 8,
-        learning_rate=0.01,
         pooling_mode: str = 'max',
     ):
         super(UNetModel, self).__init__()
@@ -19,7 +18,6 @@
         self.spacetime_ndim = spacetime_ndim
         self.nb_unet_levels = nb_unet_levels
         self.nb_filters = nb_filters
-        self.learning_rate = learning_rate
 
         self.pooling_down = PoolingDown(spacetime_ndim, pooling_mode)
         self.upsampling = nn.Upsample(scale_factor=2, mode='nearest')
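With res_unet.py renamed to residual_unet.py and the unused learning_rate argument dropped from the UNet-family constructors (the learning rate is passed to the training methods instead), callers import and construct the models as in the hedged sketch below; the single-channel 64x64 input and the direct forward call are assumptions for illustration, not confirmed by this diff.

import torch

# New import path after the rename; the constructor no longer takes learning_rate.
from aydin.nn.models.residual_unet import ResidualUNetModel

model = ResidualUNetModel(spacetime_ndim=2, nb_unet_levels=2, nb_filters=8)

with torch.no_grad():
    # Assumed (batch, channel, height, width) single-channel input for illustration.
    output = model(torch.rand(1, 1, 64, 64))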
diff --git a/aydin/nn/training_methods/n2s.py b/aydin/nn/training_methods/n2s.py
index b6ad3aef..b2d7f68a 100644
--- a/aydin/nn/training_methods/n2s.py
+++ b/aydin/nn/training_methods/n2s.py
@@ -16,7 +16,7 @@ def n2s_train(
     nb_epochs: int = 128,
     lr: float = 0.001,
     # patch_size: int = 32,
-    patience: int = 128,
+    patience: int = 4,
     verbose: bool = True,
 ):
     """
@@ -39,7 +39,7 @@
     model = model.to(device)
     print(f"device {device}")
 
-    optimizer = AdamW(model.parameters(), lr=lr)
+    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=0)
 
     # optimizer = ESAdam(
     #     chain(model.parameters()),
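The n2s_train changes above lower the early-stopping patience from 128 to 4 epochs and disable AdamW's default weight decay (0.01). The following is a rough, hedged sketch of how those two settings interact in a patience-based loop, assuming a generic loss_fn(model, batch) closure; the real n2s_train masking and loss logic differ.

import torch
from torch.optim import AdamW


def train_with_patience(model, loss_fn, batch, lr=0.001, nb_epochs=128, patience=4):
    # weight_decay=0 switches off AdamW's default decoupled weight decay of 0.01.
    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=0)
    best_loss, stale_epochs = float("inf"), 0
    for _ in range(nb_epochs):
        optimizer.zero_grad()
        loss = loss_fn(model, batch)
        loss.backward()
        optimizer.step()
        if loss.item() < best_loss:
            best_loss, stale_epochs = loss.item(), 0
        else:
            stale_epochs += 1
            if stale_epochs >= patience:
                break  # stop once the loss has not improved for `patience` epochs
    return model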