-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathres_unet.py
194 lines (151 loc) · 7.17 KB
/
res_unet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
import tensorflow as tf
def res_block_initial(x, num_filters, kernel_size, strides, name):
    """Residual Unet block layer for the first layer.

    In the residual unet the first residual block does not contain an
    initial batch normalization and activation, so this separate block is
    used for it.

    Main path: conv -> batch norm -> relu -> conv.
    Shortcut path: 1x1 conv -> batch norm.
    The two paths are summed element-wise.

    Args:
        x: tensor, image or image activation
        num_filters: list, number of filters for each subblock convolution;
            a single-element list uses the same count for both convolutions
        kernel_size: int, size of the convolutional kernel
        strides: list, stride for each of the two subblock convolutions
            (each an int or an (height, width) pair)
        name: name prefix for the convolution layers

    Returns:
        x1: tensor, output from the residual connection of x and x1
    """
    if len(num_filters) == 1:
        num_filters = [num_filters[0], num_filters[0]]

    def _as_pair(stride):
        # Conv2D accepts an int or an (h, w) pair; normalize to a pair so
        # the two per-conv strides can be combined component-wise.
        try:
            return tuple(stride)
        except TypeError:
            return (stride, stride)

    first_stride, second_stride = _as_pair(strides[0]), _as_pair(strides[1])
    # With 'same' padding, two stacked convolutions downsample by the product
    # of their strides (ceil(ceil(n/a)/b) == ceil(n/(a*b))). The shortcut
    # must downsample by the same factor or Add() receives mismatched shapes.
    shortcut_strides = (first_stride[0] * second_stride[0],
                        first_stride[1] * second_stride[1])

    x1 = tf.keras.layers.Conv2D(filters=num_filters[0],
                                kernel_size=kernel_size,
                                strides=strides[0],
                                padding='same',
                                name=name+'_1')(x)
    x1 = tf.keras.layers.BatchNormalization()(x1)
    x1 = tf.keras.layers.Activation('relu')(x1)
    x1 = tf.keras.layers.Conv2D(filters=num_filters[1],
                                kernel_size=kernel_size,
                                strides=strides[1],
                                padding='same',
                                name=name+'_2')(x1)
    # 1x1 projection so the shortcut has the same channel count as x1.
    x = tf.keras.layers.Conv2D(filters=num_filters[-1],
                               kernel_size=1,
                               strides=shortcut_strides,
                               padding='same',
                               name=name+'_shortcut')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x1 = tf.keras.layers.Add()([x, x1])
    return x1
def res_block(x, num_filters, kernel_size, strides, name):
    """Residual Unet block layer.

    Main path: batch norm -> relu -> conv -> batch norm -> relu -> conv.
    Shortcut path: the input goes through a 1x1 convolution and batch norm.
    The two paths are summed element-wise.

    Args:
        x: tensor, image or image activation
        num_filters: list, number of filters for each subblock convolution;
            a single-element list uses the same count for both convolutions
        kernel_size: int, size of the convolutional kernel
        strides: list, stride for each of the two subblock convolutions
            (each an int or an (height, width) pair)
        name: name prefix for the convolution layers

    Returns:
        x1: tensor, output from the residual connection of x and x1
    """
    if len(num_filters) == 1:
        num_filters = [num_filters[0], num_filters[0]]

    def _as_pair(stride):
        # Conv2D accepts an int or an (h, w) pair; normalize to a pair so
        # the two per-conv strides can be combined component-wise.
        try:
            return tuple(stride)
        except TypeError:
            return (stride, stride)

    first_stride, second_stride = _as_pair(strides[0]), _as_pair(strides[1])
    # With 'same' padding, two stacked convolutions downsample by the product
    # of their strides (ceil(ceil(n/a)/b) == ceil(n/(a*b))). Using only
    # strides[0] for the shortcut would break Add() whenever strides[1] != 1.
    shortcut_strides = (first_stride[0] * second_stride[0],
                        first_stride[1] * second_stride[1])

    x1 = tf.keras.layers.BatchNormalization()(x)
    x1 = tf.keras.layers.Activation('relu')(x1)
    x1 = tf.keras.layers.Conv2D(filters=num_filters[0],
                                kernel_size=kernel_size,
                                strides=strides[0],
                                padding='same',
                                name=name+'_1')(x1)
    x1 = tf.keras.layers.BatchNormalization()(x1)
    x1 = tf.keras.layers.Activation('relu')(x1)
    x1 = tf.keras.layers.Conv2D(filters=num_filters[1],
                                kernel_size=kernel_size,
                                strides=strides[1],
                                padding='same',
                                name=name+'_2')(x1)
    # 1x1 projection so the shortcut has the same channel count as x1.
    x = tf.keras.layers.Conv2D(filters=num_filters[-1],
                               kernel_size=1,
                               strides=shortcut_strides,
                               padding='same',
                               name=name+'_shortcut')(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x1 = tf.keras.layers.Add()([x, x1])
    return x1
def upsample(x, target_size, method=None):
    """Upsampling function, resizes the feature map to ``target_size``.

    The Deep Residual Unet paper does not describe the upsampling function
    in detail. The original Unet uses a transpose convolution that halves
    the number of feature maps. In order to restrict the number of
    parameters, here we use a parameter-free resampling layer
    (``tf.image.resize``, bilinear by default). This results in the
    concatenation layer concatenating feature maps with n and n/2 features,
    as opposed to n/2 and n/2 in the original unet.

    Args:
        x: tensor, feature map
        target_size: (height, width) to resize the feature map to
        method: optional interpolation method forwarded to
            ``tf.image.resize``. When None (default), ``tf.image.resize``
            uses its own default method, which is bilinear.

    Returns:
        x_resized: tensor, upsampled feature map
    """
    if method is None:
        # Keep the exact original call when no method is requested.
        resize_fn = lambda t: tf.image.resize(t, target_size)
    else:
        resize_fn = lambda t: tf.image.resize(t, target_size, method=method)
    x_resized = tf.keras.layers.Lambda(resize_fn)(x)
    return x_resized
def encoder(x, num_filters, kernel_size):
    """Unet encoder.

    The first level is a ``res_block_initial`` with stride 1 (named
    'layer1'). Every further level is a ``res_block`` whose first
    convolution uses stride 2 (named 'encoder_layer<i>').

    Args:
        x: tensor, output from previous layer
        num_filters: list, number of filters for each encoder layer
        kernel_size: int, size of the convolutional kernel

    Returns:
        encoder_output: list, output from every encoder layer, in the
            order they were built (first level first)
    """
    features = res_block_initial(x, [num_filters[0]], kernel_size,
                                 strides=[1, 1], name='layer1')
    level_outputs = [features]
    for depth, filters in enumerate(num_filters[1:], start=1):
        block_name = 'encoder_layer{}'.format(depth)
        features = res_block(features, [filters], kernel_size,
                             strides=[2, 1], name=block_name)
        level_outputs.append(features)
    return level_outputs
def decoder(x, encoder_output, num_filters, kernel_size):
    """Unet decoder.

    Walks the encoder outputs from the last one to the first. At each
    level the running feature map is:
    - upsampled to the spatial size of that encoder output,
    - concatenated with it along the last axis (skip connection),
    - passed through a stride-1 residual block.

    Assumes channels-last tensors, as implied by ``shape[1:3]`` being used
    as the spatial size and ``axis=-1`` for concatenation.

    Args:
        x: tensor, output from previous layer
        encoder_output: list, output from all previous encoder layers; the
            last element is consumed first
        num_filters: list, number of filters per level, consumed from the
            last element to the first (mirroring encoder_output)
        kernel_size: int, size of the convolutional kernel

    Returns:
        x: tensor, output from last layer of decoder
    """
    for i in range(1, len(num_filters) + 1):
        skip = encoder_output[-i]
        layer = 'decoder_layer' + str(i)
        # Resize to the skip connection's height/width so both tensors can
        # be concatenated on the channel axis.
        x = upsample(x, skip.shape[1:3])
        x = tf.keras.layers.Concatenate(axis=-1)([x, skip])
        x = res_block(x, [num_filters[-i]], kernel_size, strides=[1, 1],
                      name=layer)
    return x
def res_unet(input_size, num_filters, kernel_size, num_channels, num_classes):
    """Residual Unet.

    Builds a residual unet Keras model from four parts:
    - an encoder with one level per entry of ``num_filters``;
    - a bridge residual block that downsamples once more (stride 2) and
      uses twice the last encoder filter count;
    - a decoder that mirrors the encoder using skip connections;
    - a final convolution producing ``num_classes`` channels.

    Args:
        input_size: int, height and width of the (square) input image
        num_filters: list, number of filters for each encoder layer; its
            length sets the number of encoder/decoder levels (the bridge
            is not counted)
        kernel_size: size of the kernel, applied to all convolutions
        num_channels: int, number of channels for the input image
        num_classes: int, number of output classes for the output

    Returns:
        model: tf.keras.Model for the residual unet architecture. No
            activation is applied to the output convolution.
    """
    x = tf.keras.Input(shape=[input_size, input_size, num_channels])
    encoder_output = encoder(x, num_filters, kernel_size)
    # Bridge layer: number of filters is double that of the last encoder layer.
    bridge = res_block(encoder_output[-1], [num_filters[-1]*2], kernel_size,
                       strides=[2, 1], name='bridge')
    decoder_output = decoder(bridge, encoder_output, num_filters, kernel_size)
    output = tf.keras.layers.Conv2D(num_classes,
                                    kernel_size,
                                    strides=1,
                                    padding='same',
                                    name='output')(decoder_output)
    model = tf.keras.Model(x, output)
    return model