# %% tags=["solution", "task"]
# ruff: noqa: F811
# %% [markdown] tags=[]
# # Build Your Own U-Net
#
# <hr style="height:2px;">
#
# In this notebook, we will implement a U-Net architecture. Through this exercise you should gain an understanding of the U-Net architecture in particular, learn how to approach the implementation of an architecture in general, and familiarize yourself a bit more with the inner workings of PyTorch.
#
# The exercise is split into three parts:
#
# In part 1 you will implement the building blocks of the U-Net. That includes the convolutions, downsampling, upsampling and skip connections. We will go in the order of how difficult they are to implement.
#
# In part 2 you will combine the modules you've built in part 1 to implement the U-Net module.
#
# Parts 3 and 4 are light on coding tasks, but you will learn about two important concepts: receptive fields and translational equivariance.
#
# Finally, in part 5 you will train your first U-Net of the course! This will just be a first taste, though, since you will learn much more about training in the next exercise.
#
#
# Written by Larissa Heinrich, Caroline Malin-Mayor, and Morgan Schwartz, with inspiration from William Patton.
# %% [markdown] tags=[]
# <hr style="height:2px;">
#
# ## The libraries
# %% tags=[]
# %matplotlib inline
import numpy as np
import torch
import subprocess
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import unet_tests
from local import (
NucleiDataset,
apply_and_show_random_image,
plot_receptive_field,
show_random_dataset_image,
pad_to_size,
unnormalize,
)
# %% tags=[]
# make sure gpu is available. Please call a TA if this cell fails
assert torch.cuda.is_available()
# %% [markdown] tags=[]
# ## The Dataset
# For our segmentation exercises, we will be using a nucleus segmentation dataset from [Kaggle 2018 Data Science Bowl](https://www.kaggle.com/c/data-science-bowl-2018/data). We have downloaded the dataset during setup and we provided a pytorch Dataset called `NucleiDataset` which we will use for training later. In addition to training, we will use these images to visualize the output of the individual building blocks of the U-Net we will be implementing.
# Below, we create a dataset and then visualize a random image.
# %% tags=[]
dataset = NucleiDataset("nuclei_train_data")
# %% tags=[]
show_random_dataset_image(dataset)
# %% [markdown] tags=[]
# Rerun the cell above a few times to see different images.
# %% [markdown] tags=[]
# <hr style="height:2px;">
#
# ## The Components of a U-Net
# %% [markdown] tags=[]
# The [U-Net](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/) architecture has proven to outperform the other architectures in segmenting biological and medical images. It is also commonly used for other tasks that require the output to be the same resolution as the input, such as style transfer and denoising. Below is an overview figure of the U-Net architecture from the original [paper](https://arxiv.org/pdf/1505.04597.pdf). We will go through each of the components first (hint: all of them can be found in the list of PyTorch modules [here](https://pytorch.org/docs/stable/nn.html#convolution-layers)), and then fit them all together to make our very own U-Net.
#
# <img src="static/unet.png" alt="UNet" style="width: 1500px;"/>
# %% [markdown] tags=[]
# ### Component 1: Upsampling
# %% [markdown] tags=[]
# We will start with the Upsample module that we will use in our U-Net. The right side of the U-Net contains upsampling between the levels. There are many ways to upsample: in the original U-Net, they used a transposed convolution, but this has since fallen a bit out of fashion so we will use the PyTorch Upsample Module [torch.nn.Upsample](https://pytorch.org/docs/stable/generated/torch.nn.Upsample.html#torch.nn.Upsample) instead.
# %% [markdown] tags=[]
# #### Pytorch Modules
# Modules are the building blocks of PyTorch models, and contain lots of magic that makes training models easy. If you aren't familiar with PyTorch modules, take a look at the official documentation [here](https://pytorch.org/docs/stable/notes/modules.html). For our purposes, it is crucial to note how Modules can have submodules defined in the `__init__` function, and how the `forward` function is defined and called.
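# %% [markdown] tags=[]
# As an optional, illustrative sketch (not part of the exercise): the toy module below defines learnable parameters in `__init__` and uses them in `forward`. The class and attribute names here are made up purely for illustration.
# %% tags=[]
class ScaleAndShift(torch.nn.Module):
    """Toy module: multiplies the input by a learnable scale and adds a learnable shift."""

    def __init__(self):
        super().__init__()
        # Parameters and submodules are registered simply by assigning them as attributes.
        self.scale = torch.nn.Parameter(torch.ones(1))
        self.shift = torch.nn.Parameter(torch.zeros(1))

    def forward(self, x):
        # forward() defines the computation; the module instance is then called like a function.
        return x * self.scale + self.shift


toy = ScaleAndShift()
toy(torch.tensor([1.0, 2.0, 3.0]))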
# %% tags=[]
# Here we make fake input to illustrate the upsampling techniques
# Pytorch expects a batch and channel dimension before the actual data,
# So this simulates a 1D input
sample_1d_input = torch.tensor([[[1, 2, 3, 4]]], dtype=torch.float64)
# And this simulates a 2D input
sample_2d_input = torch.tensor(
[[[[1, 2], [3, 4]]]],
dtype=torch.float64,
)
sample_2d_input.shape, sample_2d_input
# %% [markdown] tags=[]
# <div class="alert alert-block alert-info">
# <h4>Task 1: Try out different upsampling techniques</h4>
# <p>For our U-net, we will use the built-in PyTorch Upsample Module. Here we will practice declaring and calling an Upsample module with different parameters.</p>
# <ol>
# <li>Declare an instance of the pytorch Upsample module with <code>scale_factor</code> 2 and mode <code>"nearest"</code>.</li>
# <li>Call the instance of Upsample on the <code>sample_2d_input</code> to see what the nearest mode does.</li>
# <li>Vary the scale factor and mode to see what changes. Check the documentation for possible modes and required input dimensions.</li>
# </ol>
# </div>
# %% tags=["task"]
# TASK 1.1: initialize an upsample module
up = ... # YOUR CODE HERE
# TASK 1.2: apply your upsample module to `sample_2d_input`
# YOUR CODE HERE
# %% tags=["task"]
# TASK 1.3: vary scale factor and mode
# YOUR CODE HERE
# %% tags=["solution"]
# SOLUTION 1.1: initialize an upsample module
up = torch.nn.Upsample(scale_factor=2, mode="nearest")
# SOLUTION 1.2: apply your upsample module to `sample_2d_input`
up(sample_2d_input)
# %% tags=["solution"]
# TASK 1.3: vary scale factor and mode
up3 = torch.nn.Upsample(scale_factor=3, mode="bilinear")
up3(sample_2d_input)
# %% [markdown] tags=[]
# Here is an additional example on image data.
# %% tags=[]
apply_and_show_random_image(up, dataset)
# %% [markdown] tags=[]
# ### Component 2: Downsampling
# %% [markdown] tags=[]
# Between levels of the U-Net on the left side, there is a downsample step. Traditionally, this is done with a 2x2 max pooling operation. There are other ways to downsample, for example with average pooling, but we will stick with max pooling for this exercise.
# %% tags=[]
sample_2d_input = torch.tensor(np.arange(25, dtype=np.float64).reshape((1, 1, 5, 5)))
sample_2d_input = torch.randint(0, 10, (1, 1, 6, 6))
sample_2d_input
# %% [markdown] tags=[]
# <div class="alert alert-block alert-info">
# <h4>Task 2A: Try out max pooling</h4>
# <p>Using the docs for <a href=https://pytorch.org/docs/stable/generated/torch.nn.MaxPool2d.html>torch.nn.MaxPool2d</a>,
# try initializing the module and applying it to the sample input. Try varying the parameters to understand the effect of <code>kernel_size</code> and <code>stride</code>.
# </p>
# </div>
# %% tags=["task"]
# TASK 2A: Initialize max pooling and apply to sample input
# YOUR CODE HERE
# %% tags=["solution"]
# SOLUTION 2A: Initialize max pooling and apply to sample input
max_pool = torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
max_pool(sample_2d_input)
# %% [markdown] tags=[]
# <div class="alert alert-block alert-info">
# <h4>Task 2B: Implement a Downsample Module</h4>
# <p>This is very similar to the built-in MaxPool2d, but additionally has to check whether the downsample factor evenly divides the input size. Note that we provide the forward function for you - in future Modules, you will implement the forward yourself.</p>
# <ol>
# <li>Declare the submodules you want to use (in this case, <code>torch.nn.MaxPool2d</code> with the correct arguments) in the <code>__init__</code> function. In our Downsample Module, we do not want to use padding, and the stride should match the kernel size (which is the downsample factor).</li>
# <li>Write a function to check if the downsample factor is valid. If the downsample factor does not evenly divide the dimensions of the input to the layer, this function should return False.</li>
# </ol>
# </div>
# %% tags=["task"]
class Downsample(torch.nn.Module):
def __init__(self, downsample_factor: int):
"""Initialize a MaxPool2d module with the input downsample fator"""
super().__init__()
self.downsample_factor = downsample_factor
# TASK 2B1: Initialize the maxpool module
self.down = ... # YOUR CODE HERE
def check_valid(self, image_size: tuple[int, int]) -> bool:
"""Check if the downsample factor evenly divides each image dimension.
Returns `True` for valid image sizes and `False` for invalid image sizes.
Note: there are multiple ways to do this!
"""
# TASK 2B2: Check that the image_size is valid to use with the downsample factor
# YOUR CODE HERE
def forward(self, x):
if not self.check_valid(tuple(x.size()[2:])):
raise RuntimeError(
"Can not downsample shape %s with factor %s"
% (x.size(), self.downsample_factor)
)
return self.down(x)
# %% tags=["solution"]
class Downsample(torch.nn.Module):
def __init__(self, downsample_factor: int):
"""Initialize a MaxPool2d module with the input downsample fator"""
super().__init__()
self.downsample_factor = downsample_factor
# SOLUTION 2B1: Initialize the maxpool module
self.down = torch.nn.MaxPool2d(downsample_factor)
def check_valid(self, image_size: tuple[int, int]) -> bool:
"""Check if the downsample factor evenly divides each image dimension.
Returns `True` for valid image sizes and `False` for invalid image sizes.
Note: there are multiple ways to do this!
"""
# SOLUTION 2B2: Check that the image_size is valid to use with the downsample factor
for dim in image_size:
if dim % self.downsample_factor != 0:
return False
return True
def forward(self, x):
if not self.check_valid(tuple(x.size()[2:])):
raise RuntimeError(
"Can not downsample shape %s with factor %s"
% (x.size(), self.downsample_factor)
)
return self.down(x)
# %% tags=[]
down = Downsample(4)
apply_and_show_random_image(down, dataset)
# %% [markdown] tags=[]
# We wrote some rudimentary tests for each of the torch modules you are writing. If you get an error from your code or an AssertionError from the test, you should probably have another look at your implementation.
# %% tags=[]
unet_tests.TestDown(Downsample).run()
# %% [markdown] tags=[]
# ### Component 3: Convolution Block
# %% [markdown] tags=[]
# #### Convolution
# A U-Net is a convolutional neural network, which means that the main type of operation is a convolution. Convolutions with defined kernels were covered briefly in the pre-course materials.
#
# <img src="./static/2D_Convolution_Animation.gif" width="400" height="300">
# %% [markdown] tags=[]
# Shown here is a 3x3 kernel being convolved with an input array to get an output array of the same size. For each pixel of the input, the value at that same pixel of the output is computed by multiplying the kernel element-wise with the surrounding 3x3 neighborhood around the input pixel, and then summing the result.
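# %% [markdown] tags=[]
# If you would like to see a convolution on actual numbers rather than an animation, here is an optional sketch that applies a hand-written 3x3 averaging kernel with `torch.nn.functional.conv2d`. The input and kernel values are arbitrary and chosen only for illustration.
# %% tags=[]
import torch.nn.functional as F

# A 5x5 input with batch and channel dimensions, as pytorch expects.
example_input = torch.arange(25, dtype=torch.float32).reshape(1, 1, 5, 5)
# A 3x3 averaging kernel with shape (out_channels, in_channels, height, width).
kernel = torch.full((1, 1, 3, 3), 1.0 / 9.0)
# With padding=1, the output has the same spatial size as the input.
F.conv2d(example_input, kernel, padding=1)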
# %% [markdown] tags=[]
# #### Padding
#
# You will notice that at the edges of the input, this animation shows greyed out values that extend past the input. This is known as padding the input. This example uses "same" padding, which means enough values are added around the edges that the output is the same size as the input. The other option we will use in this exercise is "valid" padding, which essentially means no padding. In the case of valid padding, the output will be smaller than the input, as values at the edges of the output will not be computed. "Same" padding can introduce edge artifacts, but "valid" padding reduces the output size at every convolution. Note that the amount of padding (for same) and the amount of size lost (for valid) depends on the size of the kernel - a 3x3 convolution would require padding of 1, a 5x5 convolution would require a padding of 2, and so on.
#
# Additionally, there are different modes of padding that determine what strategy is used to make up the values for padding. In the animation above the mode is re-using values from the border. Even more commonly, the image is simply padded with zeros.
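# %% [markdown] tags=[]
# To make the size difference concrete, this optional sketch compares the output shapes of a "valid" and a "same" padded 3x3 convolution applied to the same 8x8 input.
# %% tags=[]
conv_valid = torch.nn.Conv2d(1, 1, kernel_size=3, padding="valid")
conv_same = torch.nn.Conv2d(1, 1, kernel_size=3, padding="same")
example_input = torch.zeros(1, 1, 8, 8)
# "valid" loses kernel_size - 1 = 2 pixels per spatial dimension; "same" keeps 8x8.
conv_valid(example_input).shape, conv_same(example_input).shape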
# %% [markdown] tags=[]
# #### ReLU Activation
# The Rectified Linear Unit (ReLU) is a common activation function, which takes the max of a value and 0, shown below. It introduces a non-linearity into the neural network - without a non-linear activation function, a neural network could not learn non-linear functions.
#
# <img src="./static/ReLU.png" width="400" height="300">
# %% [markdown] tags=[]
# <div class="alert alert-block alert-info">
# <h4>Task 3: Implement a ConvBlock module</h4>
# <p>The convolution block (ConvBlock) of a standard U-Net has two 3x3 convolutions, each of which is followed by a ReLU activation. Our implementation will handle other sizes of convolutions as well. The first convolution in the block changes the number of input feature maps/channels to the number of output feature maps, and the second convolution has the same number of feature maps in and out.</p>
# <ol>
# <li>Declare the submodules you want to use in the <code>__init__</code> function. Because you will always be calling four submodules in sequence (<a href=https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d>torch.nn.Conv2d</a>, <a href=https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html#torch.nn.ReLU>torch.nn.ReLU</a>, Conv2d, ReLU), you can use <a href=https://pytorch.org/docs/stable/generated/torch.nn.Sequential.html>torch.nn.Sequential</a> to hold the convolutions and ReLUs.</li>
# <li>Call the modules in the forward function. If you used <code>torch.nn.Sequential</code> in step 1, you only need to call the Sequential module, but if not, you can call the Conv2d and ReLU Modules explicitly.</li>
# </ol>
# </div>
#
# If you get stuck, refer back to the <a href=https://pytorch.org/docs/stable/notes/modules.html>Module</a> documentation for hints and examples of how to define a PyTorch Module.
# %% tags=["task"]
class ConvBlock(torch.nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
padding: str = "same",
):
"""A convolution block for a U-Net. Contains two convolutions, each followed by a ReLU.
Args:
in_channels (int): The number of input channels for this conv block. Depends on
the layer and side of the U-Net and the hyperparameters.
out_channels (int): The number of output channels for this conv block. Depends on
the layer and side of the U-Net and the hyperparameters.
kernel_size (int): The size of the kernel. A kernel size of N signifies an
NxN square kernel.
padding (str): The type of convolution padding to use. Either "same" or "valid".
Defaults to "same".
"""
super().__init__()
if kernel_size % 2 == 0:
msg = "Only allowing odd kernel sizes."
raise ValueError(msg)
# TASK 3.1: Initialize your modules and define layers.
# YOUR CODE HERE
for _name, layer in self.named_modules():
if isinstance(layer, torch.nn.Conv2d):
torch.nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")
def forward(self, x):
# TASK 3.2: Apply the modules you defined to the input x
... # YOUR CODE HERE
# %% tags=["solution"]
class ConvBlock(torch.nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
padding: str = "same",
):
"""A convolution block for a U-Net. Contains two convolutions, each followed by a ReLU.
Args:
in_channels (int): The number of input channels for this conv block. Depends on
the layer and side of the U-Net and the hyperparameters.
out_channels (int): The number of output channels for this conv block. Depends on
the layer and side of the U-Net and the hyperparameters.
kernel_size (int): The size of the kernel. A kernel size of N signifies an
NxN square kernel.
padding (str): The type of convolution padding to use. Either "same" or "valid".
Defaults to "same".
"""
super().__init__()
if kernel_size % 2 == 0:
msg = "Only allowing odd kernel sizes."
raise ValueError(msg)
# SOLUTION 3.1: Initialize your modules and define layers.
self.conv_pass = torch.nn.Sequential(
torch.nn.Conv2d(
in_channels, out_channels, kernel_size=kernel_size, padding=padding
),
torch.nn.ReLU(),
torch.nn.Conv2d(
out_channels, out_channels, kernel_size=kernel_size, padding=padding
),
torch.nn.ReLU(),
)
for _name, layer in self.named_modules():
if isinstance(layer, torch.nn.Conv2d):
torch.nn.init.kaiming_normal_(layer.weight, nonlinearity="relu")
def forward(self, x):
# SOLUTION 3.2: Apply the modules you defined to the input x
return self.conv_pass(x)
# %% [markdown] tags=[]
# #### Test and Visualize Output of ConvBlock
# Try rerunning the visualization a few times. What do you observe? Can you explain it?
# %% tags=[]
unet_tests.TestConvBlock(ConvBlock).run()
# %% tags=[]
torch.manual_seed(26)
conv = ConvBlock(1, 2, 5, "same")
apply_and_show_random_image(conv, dataset)
# %% [markdown] tags=[]
# <div class="alert alert-warning">
#
# <h4>Question: Padding</h4>
# As you saw, the convolution modules in pytorch allow you to directly use the keywords `"valid"` or `"same"` for your padding mode. How would you go about calculating the amount of padding you need based on the kernel size?
#
# If you'd like, you can test your assumption by editing the `ConvBlock` to pass your own calculated value to the `padding` keyword in the conv module and rerun the test
# </div>
# %% [markdown] tags=[]
# ### Component 4: Skip Connections and Concatenation
# %% [markdown] tags=[]
# The skip connections between the left and right side of the U-Net are central to successfully obtaining high-resolution output. At each layer, the output of the left conv block is concatenated to the output of the upsample block on the right side from the last layer below. Since upsampling, especially with the "nearest" algorithm, does not actually add high resolution information, the concatenation of the left side conv block output is crucial to generate high resolution segmentations.
#
# If the convolutions in the U-Net are valid, the right side will be smaller than the left side, so the left side output must be cropped before concatenation. We provide a helper function to do this cropping.
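# %% [markdown] tags=[]
# Before implementing the module, the optional sketch below shows how `torch.cat` with `dim=1` stacks two feature maps along the channel axis. The tensor shapes here are made up for illustration.
# %% tags=[]
left = torch.zeros(1, 16, 32, 32)  # e.g. the cropped encoder output: 16 channels
right = torch.zeros(1, 16, 32, 32)  # e.g. the upsampled decoder output: 16 channels
# Concatenating along dim=1 (the channel axis) yields 32 channels.
torch.cat([left, right], dim=1).shape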
# %% [markdown] tags=[]
# <div class="alert alert-block alert-info">
# <h4>Task 4: Implement a CropAndConcat module</h4>
# <p>Below, you must implement the <code>forward</code> method, including the cropping (using the provided helper function <code>center_crop</code>) and the concatenation (using <a href=https://pytorch.org/docs/stable/generated/torch.cat.html#torch.cat>torch.cat</a>).
# </p>
# Hint: Use the <code>dim</code> keyword argument of <a href=https://pytorch.org/docs/stable/generated/torch.cat.html#torch.cat>torch.cat</a> to choose along which axis to concatenate the tensors.
# </p>
# Hint: The tensors have the layout (batch, channel, x, y)
# </div>
# %% tags=["task"]
def center_crop(x, y):
"""Center-crop x to match spatial dimensions given by y."""
x_target_size = x.size()[:2] + y.size()[2:]
offset = tuple((a - b) // 2 for a, b in zip(x.size(), x_target_size))
slices = tuple(slice(o, o + s) for o, s in zip(offset, x_target_size))
return x[slices]
class CropAndConcat(torch.nn.Module):
def forward(self, encoder_output, upsample_output):
# TASK 4: Implement the forward function
...
# %% tags=["solution"]
def center_crop(x, y):
"""Center-crop x to match spatial dimensions given by y."""
x_target_size = x.size()[:2] + y.size()[2:]
offset = tuple((a - b) // 2 for a, b in zip(x.size(), x_target_size))
slices = tuple(slice(o, o + s) for o, s in zip(offset, x_target_size))
return x[slices]
class CropAndConcat(torch.nn.Module):
def forward(self, encoder_output, upsample_output):
# SOLUTION 4: Implement the forward function
encoder_cropped = center_crop(encoder_output, upsample_output)
return torch.cat([encoder_cropped, upsample_output], dim=1)
# %% tags=[]
unet_tests.TestCropAndConcat(CropAndConcat).run()
# %% [markdown] tags=[]
# ### Component 5: Output Block
# %% [markdown] tags=[]
# The final block we need to write for our U-Net is the output convolution block. The exact format of output you want depends on your task, so our U-Net must be flexible enough to handle different numbers of out channels and different final activation functions.
# %% [markdown] tags=[]
# <div class="alert alert-block alert-info">
# <h4>Task 5: Implement an OutputConv Module</h4>
# <ol>
# <li>Define the convolution module in the <code>__init__</code> function. You can use a convolution with kernel size 1 to get the appropriate number of output channels. The activation submodule is provided for you.</li>
# <li>Call the final convolution and activation modules in the <code>forward</code> function</li>
# </ol>
# </div>
# %% tags=["task"]
class OutputConv(torch.nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
activation: torch.nn.Module | None = None,
):
"""
A module that uses a convolution with kernel size 1 to get the appropriate
number of output channels, and then optionally applies a final activation.
Args:
in_channels (int): The number of feature maps that will be input to the
OutputConv block.
out_channels (int): The number of channels that you want in the output
activation (str | None, optional): Accepts the name of any torch activation
function (e.g., ``ReLU`` for ``torch.nn.ReLU``) or None for no final
activation. Defaults to None.
"""
super().__init__()
# TASK 5.1: Define the convolution submodule
# YOUR CODE HERE
self.activation = activation
def forward(self, x):
# TASK 5.2: Implement the forward function
# YOUR CODE HERE
...
# %% tags=["solution"]
class OutputConv(torch.nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
activation: torch.nn.Module | None = None,
):
"""
A module that uses a convolution with kernel size 1 to get the appropriate
number of output channels, and then optionally applies a final activation.
Args:
in_channels (int): The number of feature maps that will be input to the
OutputConv block.
out_channels (int): The number of channels that you want in the output
activation (str | None, optional): Accepts the name of any torch activation
function (e.g., ``ReLU`` for ``torch.nn.ReLU``) or None for no final
activation. Defaults to None.
"""
super().__init__()
# SOLUTION 5.1: Define the convolution submodule
self.final_conv = torch.nn.Conv2d(in_channels, out_channels, 1, padding=0)
self.activation = activation
def forward(self, x):
# SOLUTION 5.2: Implement the forward function
x = self.final_conv(x)
if self.activation is not None:
x = self.activation(x)
return x
# %% tags=[]
unet_tests.TestOutputConv(OutputConv).run()
# %% tags=[]
out_conv = OutputConv(in_channels=1, out_channels=1, activation=torch.nn.ReLU())
apply_and_show_random_image(out_conv, dataset)
# %% [markdown] tags=[]
# <div class="alert alert-block alert-success">
# <h2>Checkpoint 1</h2>
#
# Congratulations! You have implemented most of a U-Net!
# We will go over this portion together and answer any questions soon, but feel free to start on the next section where we put it all together.
#
# </div>
#
# <hr style="height:2px;">
# %% [markdown] tags=[]
# ## Putting the U-Net together
#
# Now we will make a U-Net class that combines all of these components as shown in the image. This image shows a U-Net of depth 4 with specific input channels, feature maps, upsampling, and final activation. Ours will be configurable with regard to depth and other features.
#
# <img src="static/unet.png" alt="UNet" style="width: 1500px;"/>
# %% [markdown] tags=[]
# <div class="alert alert-block alert-info">
# <h4>Task 6: U-Net Implementation</h4>
# <p>Now we will implement our U-Net! We have written some of it for you - follow the steps below to fill in the missing parts.</p>
# <ol>
# <li>Write the helper functions <code>compute_fmaps_encoder</code> and <code>compute_fmaps_decoder</code> that compute the number of input and output feature maps at each level of the U-Net.</li>
# <li>Declare a list of encoder (left) and decoder (right) ConvPasses depending on your depth using the helper functions you wrote above. Consider the special case at the bottom of the U-Net carefully!</li>
# <li>Declare an Upsample, Downsample, CropAndConcat, and OutputConv block.</li>
# <li>Implement the <code>forward</code> function, applying the modules you declared above in the proper order.</li>
# </ol>
# </div>
# %% tags=["task"]
class UNet(torch.nn.Module):
def __init__(
self,
depth: int,
in_channels: int,
out_channels: int = 1,
final_activation: torch.nn.Module | None = None,
num_fmaps: int = 64,
fmap_inc_factor: int = 2,
downsample_factor: int = 2,
kernel_size: int = 3,
padding: str = "same",
upsample_mode: str = "nearest",
):
"""A U-Net for 2D input that expects tensors shaped like::
``(batch, channels, height, width)``.
Args:
depth:
The number of levels in the U-Net. 2 is the smallest that really
makes sense for the U-Net architecture, as a one layer U-Net is
basically just 2 conv blocks.
in_channels:
The number of input channels in your dataset.
out_channels (optional):
How many output channels you want. Depends on your task. Defaults to 1.
final_activation (optional):
What activation to use in your final output block. Depends on your task.
Defaults to None.
num_fmaps (optional):
The number of feature maps in the first layer. Defaults to 64.
fmap_inc_factor (optional):
By how much to multiply the number of feature maps between
layers. Encoder layer ``l`` will have ``num_fmaps*fmap_inc_factor**l``
output feature maps. Defaults to 2.
downsample_factor (optional):
Factor to use for down- and up-sampling the feature maps between layers.
Defaults to 2.
kernel_size (optional):
Kernel size to use in convolutions on both sides of the UNet.
Defaults to 3.
padding (optional):
How to pad convolutions. Either 'same' or 'valid'. Defaults to "same."
upsample_mode (optional):
The upsampling mode to pass to torch.nn.Upsample. Usually "nearest"
or "bilinear." Defaults to "nearest."
"""
super().__init__()
self.depth = depth
self.in_channels = in_channels
self.out_channels = out_channels
self.final_activation = final_activation
self.num_fmaps = num_fmaps
self.fmap_inc_factor = fmap_inc_factor
self.downsample_factor = downsample_factor
self.kernel_size = kernel_size
self.padding = padding
self.upsample_mode = upsample_mode
# left convolutional passes
self.left_convs = torch.nn.ModuleList()
# TASK 6.2A: Initialize list here
# right convolutional passes
self.right_convs = torch.nn.ModuleList()
# TASK 6.2B: Initialize list here
# TASK 6.3: Initialize other modules here
def compute_fmaps_encoder(self, level: int) -> tuple[int, int]:
"""Compute the number of input and output feature maps for
a conv block at a given level of the UNet encoder (left side).
Args:
level (int): The level of the U-Net which we are computing
the feature maps for. Level 0 is the input level, level 1 is
the first downsampled layer, and level=depth - 1 is the bottom layer.
Output (tuple[int, int]): The number of input and output feature maps
of the encoder convolutional pass in the given level.
"""
# TASK 6.1A: Implement this function
pass
def compute_fmaps_decoder(self, level: int) -> tuple[int, int]:
"""Compute the number of input and output feature maps for a conv block
at a given level of the UNet decoder (right side). Note:
The bottom layer (depth - 1) is considered an "encoder" conv pass,
so this function is only valid up to depth - 2.
Args:
level (int): The level of the U-Net which we are computing
the feature maps for. Level 0 is the input level, level 1 is
the first downsampled layer, and level=depth - 1 is the bottom layer.
Output (tuple[int, int]): The number of input and output feature maps
of the decoder convolutional pass in the given level.
"""
# TASK 6.1B: Implement this function
pass
def forward(self, x):
# left side
# Hint - you will need the outputs of each convolutional block in the encoder for the skip connection, so you need to hold on to those output tensors
for i in range(self.depth - 1):
# TASK 6.4A: Implement encoder here
...
# bottom
# TASK 6.4B: Implement bottom of U-Net here
# right
for i in range(0, self.depth - 1)[::-1]:
# TASK 6.4C: Implement decoder here
...
# TASK 6.4D: Apply the final convolution and return the output
return
# %% tags=["solution"]
class UNet(torch.nn.Module):
def __init__(
self,
depth: int,
in_channels: int,
out_channels: int = 1,
final_activation: torch.nn.Module | None = None,
num_fmaps: int = 64,
fmap_inc_factor: int = 2,
downsample_factor: int = 2,
kernel_size: int = 3,
padding: str = "same",
upsample_mode: str = "nearest",
):
"""A U-Net for 2D input that expects tensors shaped like::
``(batch, channels, height, width)``.
Args:
depth:
The number of levels in the U-Net. 2 is the smallest that really
makes sense for the U-Net architecture, as a one layer U-Net is
basically just 2 conv blocks.
in_channels:
The number of input channels in your dataset.
out_channels (optional):
How many output channels you want. Depends on your task. Defaults to 1.
final_activation (optional):
What activation to use in your final output block. Depends on your task.
Defaults to None.
num_fmaps (optional):
The number of feature maps in the first layer. Defaults to 64.
fmap_inc_factor (optional):
By how much to multiply the number of feature maps between
layers. Encoder layer ``l`` will have ``num_fmaps*fmap_inc_factor**l``
output feature maps. Defaults to 2.
downsample_factor (optional):
Factor to use for down- and up-sampling the feature maps between layers.
Defaults to 2.
kernel_size (optional):
Kernel size to use in convolutions on both sides of the UNet.
Defaults to 3.
padding (optional):
How to pad convolutions. Either 'same' or 'valid'. Defaults to "same."
upsample_mode (optional):
The upsampling mode to pass to torch.nn.Upsample. Usually "nearest"
or "bilinear." Defaults to "nearest."
"""
super().__init__()
self.depth = depth
self.in_channels = in_channels
self.out_channels = out_channels
self.final_activation = final_activation
self.num_fmaps = num_fmaps
self.fmap_inc_factor = fmap_inc_factor
self.downsample_factor = downsample_factor
self.kernel_size = kernel_size
self.padding = padding
self.upsample_mode = upsample_mode
# left convolutional passes
self.left_convs = torch.nn.ModuleList()
# SOLUTION 6.2A: Initialize list here
for level in range(self.depth):
fmaps_in, fmaps_out = self.compute_fmaps_encoder(level)
self.left_convs.append(
ConvBlock(fmaps_in, fmaps_out, self.kernel_size, self.padding)
)
# right convolutional passes
self.right_convs = torch.nn.ModuleList()
# SOLUTION 6.2B: Initialize list here
for level in range(self.depth - 1):
fmaps_in, fmaps_out = self.compute_fmaps_decoder(level)
self.right_convs.append(
ConvBlock(
fmaps_in,
fmaps_out,
self.kernel_size,
self.padding,
)
)
# SOLUTION 6.3: Initialize other modules here
self.downsample = Downsample(self.downsample_factor)
self.upsample = torch.nn.Upsample(
scale_factor=self.downsample_factor,
mode=self.upsample_mode,
)
self.crop_and_concat = CropAndConcat()
self.final_conv = OutputConv(
self.compute_fmaps_decoder(0)[1], self.out_channels, self.final_activation
)
def compute_fmaps_encoder(self, level: int) -> tuple[int, int]:
"""Compute the number of input and output feature maps for
a conv block at a given level of the UNet encoder (left side).
Args:
level (int): The level of the U-Net which we are computing
the feature maps for. Level 0 is the input level, level 1 is
the first downsampled layer, and level=depth - 1 is the bottom layer.
Output (tuple[int, int]): The number of input and output feature maps
of the encoder convolutional pass in the given level.
"""
# SOLUTION 6.1A: Implement this function
if level == 0:
fmaps_in = self.in_channels
else:
fmaps_in = self.num_fmaps * self.fmap_inc_factor ** (level - 1)
fmaps_out = self.num_fmaps * self.fmap_inc_factor**level
return fmaps_in, fmaps_out
def compute_fmaps_decoder(self, level: int) -> tuple[int, int]:
"""Compute the number of input and output feature maps for a conv block
at a given level of the UNet decoder (right side). Note:
The bottom layer (depth - 1) is considered an "encoder" conv pass,
so this function is only valid up to depth - 2.
Args:
level (int): The level of the U-Net which we are computing
the feature maps for. Level 0 is the input level, level 1 is
the first downsampled layer, and level=depth - 1 is the bottom layer.
Output (tuple[int, int]): The number of input and output feature maps
of the decoder convolutional pass in the given level.
"""
# SOLUTION 6.1B: Implement this function
fmaps_out = self.num_fmaps * self.fmap_inc_factor ** (level)
concat_fmaps = self.compute_fmaps_encoder(level)[
1
] # The channels that come from the skip connection
fmaps_in = concat_fmaps + self.num_fmaps * self.fmap_inc_factor ** (level + 1)
return fmaps_in, fmaps_out
def forward(self, x):
# left side
convolution_outputs = []
layer_input = x
for i in range(self.depth - 1):
# SOLUTION 6.4A: Implement encoder here
conv_out = self.left_convs[i](layer_input)
convolution_outputs.append(conv_out)
downsampled = self.downsample(conv_out)
layer_input = downsampled
# bottom
# SOLUTION 6.4B: Implement bottom of U-Net here
conv_out = self.left_convs[-1](layer_input)
layer_input = conv_out
# right
for i in range(0, self.depth - 1)[::-1]:
# SOLUTION 6.4C: Implement decoder here
upsampled = self.upsample(layer_input)
concat = self.crop_and_concat(convolution_outputs[i], upsampled)
conv_output = self.right_convs[i](concat)
layer_input = conv_output
# SOLUTION 6.4D: Apply the final convolution and return the output
return self.final_conv(layer_input)
# %% [markdown] tags=[]
# Below we declare a very simple U-Net and then apply it to a random image. Because we have not trained the U-Net, the output should look similar to the output of random convolutions. If you get errors here, go back and fix your U-Net implementation!
# %% tags=[]
unet_tests.TestUNet(UNet).run()
# %% tags=[]
simple_net = UNet(depth=2, in_channels=1)
# %% tags=[]
apply_and_show_random_image(simple_net, dataset)
# %% [markdown] tags=[]
# <div class="alert alert-block alert-success">
# <h2>Checkpoint 2</h2>
#
# Congratulations! You have implemented a UNet architecture.
#
# Next we'll learn about receptive fields, which should demystify how to choose the UNet hyperparameters a little bit.
#
# </div>
# %% [markdown] tags=[]
# <hr style="height:2px;">
# %% [markdown] tags=[]
# ## Receptive Field
#
# The receptive field of an output value is the set of input values that can change the output value. The size of the receptive field is an important property of network architectures for image processing. Let's consider the receptive field size of the U-Net building blocks.
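# %% [markdown] tags=[]
# As an optional aid for the question below, here is a small sketch of the standard receptive field arithmetic for a chain of convolution and pooling operations: each operation with kernel size `k` grows the receptive field by `(k - 1) * jump`, where `jump` is the product of all strides seen so far. This simplified calculation ignores upsampling.
# %% tags=[]
def receptive_field(ops: list[tuple[int, int]]) -> int:
    """Receptive field of a chain of (kernel_size, stride) operations,
    e.g. a 3x3 convolution is (3, 1) and a 2x2 max pool is (2, 2)."""
    rf, jump = 1, 1
    for kernel_size, stride in ops:
        rf += (kernel_size - 1) * jump
        jump *= stride
    return rf


# Two 3x3 convolutions, a 2x2 max pool, then a 5x5 convolution.
receptive_field([(3, 1), (3, 1), (2, 2), (5, 1)])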
# %% [markdown] tags=[]
# <div class="alert alert-warning">
#
# <h4>Question: Receptive Field Size</h4>
# What are the receptive field sizes of the following operations?
#
# 1. <code>torch.nn.Conv2d(1, 5, 3)</code>
# 2. <code>torch.nn.Sequential(torch.nn.Conv2d(1, 5, 3), torch.nn.Conv2d(5,5,3))</code>
# 3. <code>torch.nn.Sequential(torch.nn.Conv2d(1, 5, 3), torch.nn.Conv2d(5,5,5))</code>
# 4. <code>Downsample(3)</code>
# 5. <code>torch.nn.Sequential(ConvBlock(1, 5, 3), Downsample(2), ConvBlock(5,5,3))</code>
# 6. <code>torch.nn.Upsample(2)</code>
# 7. <code>torch.nn.Sequential(ConvBlock(1,5,3), Upsample(2), ConvBlock(5,5,3))</code>
# 8. <code>torch.nn.Sequential(ConvBlock(1,5,3), Downsample(3), ConvBlock(5,5,3), Upsample(3), ConvBlock(5,5,3))</code>
# 9. <code>UNet(depth=2, in_channels=1, downsample_factor=3, kernel_size=3)</code>
#
#
# </div>
# %% [markdown] tags=[]
# <div class="alert alert-block alert-info">
# <h4>Task 7: Receptive Field</h4>
# <p>The <code>plot_receptive_field</code> function visualizes the receptive field of a given U-Net - the square shows how many input pixels contribute to the output at the center pixel. Try it out with different U-Nets to get a sense of how varying the depth, kernel size, and downsample factor affect the receptive field of a U-Net.</p>
# </div>
# %% tags=["task"]
new_net = ... # TASK 7: declare your U-Net here
if isinstance(new_net, UNet):
plot_receptive_field(new_net)
# %% tags=["solution"]
# SOLUTION 7: declare your U-Net here
new_net = UNet(
depth=2,
in_channels=1,
downsample_factor=2,
kernel_size=3,
)
if isinstance(new_net, UNet):
plot_receptive_field(new_net)
# %% [markdown] tags=[]
# <div class="alert alert-block alert-success">
# <h2>Checkpoint 3</h2>
#
# The receptive field of your network is one of the most important aspects to consider when you choose the hyperparameters for your network.
#
# Questions to consider:
# <ol>
# <li>Which hyperparameter of the U-Net has most effect on the receptive field?</li>
# <li>If two sets of hyperparameters result in the same receptive field size are those networks equally good choices?</li>
# <li>What's the relation between the receptive field size and the size of images you feed to your UNet?</li>
# <li>Do you see a connection between padding and receptive field size?</li>
# <li>For each hyperparameter: Can you think of scenarios in which you would consider changing this parameter? Why? </li>
# </ol>
#
# </div>
#
# <hr style="height:2px;">
# %% [markdown] tags=[]
# ## Translational equivariance
#
# Depending on the task you're trying to solve, you may care about translational (shift) invariance or equivariance.
#
# Let's first define what invariance and equivariance mean in mathematical notation.
#
# Let $T$ be a transformation and $F$ the function whose properties we're considering.
#
# $F$ is invariant under transformation $T$ if: $F(T(x)) = F(x)$. The output of the function remains the same whether the input was transformed or not.
#
# $F$ is equivariant under transformation $T$ if: $F(T(x)) = T(F(x))$. Applying the function on the transformed input is the same as applying the transformation on the output of the original input.
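# %% [markdown] tags=[]
# To make this concrete, the optional sketch below checks numerically that a convolution is translationally equivariant: shifting the input and then convolving gives the same result as convolving and then shifting. Circular padding and `torch.roll` are used here so that border effects do not confound the comparison; with zero padding, the equality only holds away from the image borders.
# %% tags=[]
torch.manual_seed(0)
conv = torch.nn.Conv2d(
    1, 1, kernel_size=3, padding=1, padding_mode="circular", bias=False
)
image = torch.rand(1, 1, 16, 16)
shift_then_conv = conv(torch.roll(image, shifts=4, dims=-1))
conv_then_shift = torch.roll(conv(image), shifts=4, dims=-1)
# True: the two orders of operations agree (up to floating point error).
torch.allclose(shift_then_conv, conv_then_shift, atol=1e-6)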
# %% [markdown] tags=[]
# If math isn't your thing, hopefully this picture helps convey the concept, here shown specifically for translations.
# %% [markdown] tags=[]
# <img src="static/equivariance.png" alt="Invariance and Equivariance" style="width: 1500px;"/>
# %% [markdown] tags=[]
# <div class="alert alert-warning">
#
# <h4>Question: Translational invariance and equivariance</h4>
# For what types of deep learning tasks would you want your network to be translationally invariant and equivariant, respectively? Where does the U-Net fit in?
# </div>
# %% [markdown] tags=[]
# <div class="alert alert-warning">
#
# <h4>Question: Translational properties of U-Net building blocks</h4>
# For each of these building blocks of the U-Net: Is it translationally equivariant or invariant?
# <ol>
# <li>ConvBlock</li>
# <li>Downsample</li>
# <li>Upsample</li>
# </ol>