-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain.py
141 lines (105 loc) · 6.41 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
###############################################################################################
# This python code includes the function calls to train the proposed
# DeepDistanceModel and DeepDistanceExtendedModel. This file also includes the
# function to prepare training/validation patches from a set of training images.
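## The 'patchdata' dictionary must contain the training data as follows:
# for the training set   : 'x_patches', 'inner_dst_patches', 'outer_dst_patches'
# for the validation set : 'x_valid_patches', 'inner_dst_valid_patches', 'outer_dst_valid_patches'
# DeepDistanceExtendedModel additionally needs the segmentation maps
# 'y_patches' (training set) and 'y_valid_patches' (validation set).
## 'modelpath' is the file path (e.g. './deepDistance.hdf5') where the model with the
# lowest validation loss is saved.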
###############################################################################################
import numpy as np
from keras.callbacks import ModelCheckpoint, EarlyStopping
from deepDistanceModels import deepDistanceModel, deepDistanceExtendedModel
import os
# Expose only GPU 0 to CUDA by default (must take effect before the first GPU use).
# setdefault keeps any CUDA_VISIBLE_DEVICES value the user already exported,
# instead of silently overriding an explicit device choice.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")
###############################################################################################
def _fitDeepDistance(model, modelpath, x_train, y_train, x_valid, y_valid,
                     epochs, batch_size, patience):
    """Fit `model` on the given data, keeping the best model on disk.

    The full model (architecture + weights) is written to `modelpath` every time
    the validation loss improves, and training stops early once the validation
    loss has not improved for `patience` consecutive epochs. `model.fit` is
    called directly, so the model is assumed to be compiled already.

    Returns:
        The History object returned by `model.fit`.
    """
    model.summary()
    print('Model is ready !')
    # `period` is intentionally not passed: it is deprecated in tf.keras and not
    # accepted by Keras 3; the default already evaluates once per epoch.
    checkpointer = ModelCheckpoint(modelpath, monitor='val_loss', verbose=0,
                                   save_best_only=True, save_weights_only=False,
                                   mode='auto')
    earlystopper = EarlyStopping(patience=patience, verbose=0)
    # `validation_split` is omitted: explicit validation_data is supplied, and the
    # previously passed value (0) was the default anyway.
    hist = model.fit(x=x_train, y=y_train,
                     validation_data=(x_valid, y_valid),
                     epochs=epochs, batch_size=batch_size, shuffle=True,
                     callbacks=[checkpointer, earlystopper])
    print('Model is trained !')
    return hist


def trainDeepDistanceModel(patchdata, modelpath='./deepDistance.hdf5',
                           epochs=300, batch_size=1, patience=100,
                           input_height=512, input_width=512):
    """Train the DeepDistance model (inner + outer distance-map outputs).

    Args:
        patchdata: dict with the training arrays 'x_patches',
            'inner_dst_patches', 'outer_dst_patches' and the validation arrays
            'x_valid_patches', 'inner_dst_valid_patches',
            'outer_dst_valid_patches'.
        modelpath: file the model with the lowest val_loss is saved to.
        epochs: maximum number of training epochs.
        batch_size: batch size passed to `model.fit`.
        patience: epochs without val_loss improvement before stopping early.
        input_height, input_width: model input size; presumably this should
            match the patch size the data was cropped with.

    Returns:
        The Keras History object of the training run.
    """
    model = deepDistanceModel(input_height=input_height, input_width=input_width)
    return _fitDeepDistance(
        model, modelpath,
        patchdata['x_patches'],
        [patchdata['inner_dst_patches'], patchdata['outer_dst_patches']],
        patchdata['x_valid_patches'],
        [patchdata['inner_dst_valid_patches'],
         patchdata['outer_dst_valid_patches']],
        epochs, batch_size, patience)


def trainDeepDistanceExtendedModel(patchdata,
                                   modelpath='./deepDistanceExtended.hdf5',
                                   epochs=300, batch_size=1, patience=100,
                                   input_height=512, input_width=512):
    """Train the extended DeepDistance model (segmentation + distance maps).

    Args:
        patchdata: dict with everything trainDeepDistanceModel needs plus the
            segmentation targets 'y_patches' (training) and 'y_valid_patches'
            (validation).
        modelpath: file the model with the lowest val_loss is saved to.
        epochs: maximum number of training epochs.
        batch_size: batch size passed to `model.fit`.
        patience: epochs without val_loss improvement before stopping early.
        input_height, input_width: model input size; presumably this should
            match the patch size the data was cropped with.

    Returns:
        The Keras History object of the training run.
    """
    model = deepDistanceExtendedModel(input_height=input_height,
                                      input_width=input_width)
    # Target order must match the model's output order:
    # segmentation, inner distance, outer distance.
    return _fitDeepDistance(
        model, modelpath,
        patchdata['x_patches'],
        [patchdata['y_patches'], patchdata['inner_dst_patches'],
         patchdata['outer_dst_patches']],
        patchdata['x_valid_patches'],
        [patchdata['y_valid_patches'], patchdata['inner_dst_valid_patches'],
         patchdata['outer_dst_valid_patches']],
        epochs, batch_size, patience)
###############################################################################################
def normalizeImg(img, num_channels=3):
    """Standardize each of the first `num_channels` channels of an image.

    Every channel c in range(num_channels) is shifted to zero mean and scaled
    to unit standard deviation using that channel's own statistics. The result
    is a new float64 array with the same shape as `img`; channels at index
    >= num_channels are left as zeros.

    A constant channel (std == 0) is only mean-centred, so it becomes all
    zeros, instead of computing 0 / 0, which would fill it with NaNs.

    Args:
        img: array indexed as [rows, cols, channels] with at least
            `num_channels` channels.
        num_channels: number of leading channels to normalize (default 3).

    Returns:
        The normalized image as a float64 numpy array.
    """
    norm_img = np.zeros(img.shape)
    for c in range(num_channels):
        channel = img[:, :, c]
        std = channel.std()
        # Divide by 1 for flat channels to avoid a division by zero.
        norm_img[:, :, c] = (channel - channel.mean()) / (std if std > 0 else 1.0)
    return norm_img
###############################################################################################
# x_data         : The array of input images
# inner_dst_data : The array of the corresponding inner distance maps for the input images
# outer_dst_data : The array of the corresponding normalized outer distance maps for the
#                  input images
# y_data         : The array of segmentation maps (their patches are needed only to train
#                  the extended version of the DeepDistance model)
#
# Patches are cropped with a stride of half the patch size (50% overlap). The returned list
# [x_patches, inner_dst_patches, outer_dst_patches, y_patches] can be used to build the
# 'patchdata' dictionary used by trainDeepDistanceModel and trainDeepDistanceExtendedModel.
###############################################################################################
def _patch_starts(length, patch_size, step):
    """Return the patch start offsets along one image axis of size `length`.

    Offsets advance by `step` while offset + step < length; an offset whose
    patch would run past the end is clamped to length - patch_size, so the last
    patch ends exactly at the border.
    """
    return [min(offset, length - patch_size)
            for offset in range(0, length - step, step)]


def crop_patches(x_data, inner_dst_data, outer_dst_data, y_data=None,
                 patch_height=512, patch_width=512):
    """Crop every image and its target maps into overlapping fixed-size patches.

    Patches lie on a grid with a stride of half a patch in each direction.
    Patches that would cross the bottom/right border are shifted back so they
    end exactly at it, so every patch is patch_height x patch_width. Patches are
    emitted image by image, row by row, left to right.

    Args:
        x_data: images, each x_data[k] sliced as [rows, cols, channels].
        inner_dst_data: inner distance maps aligned with x_data, sliced as
            [rows, cols].
        outer_dst_data: outer distance maps aligned with x_data, sliced as
            [rows, cols].
        y_data: segmentation maps aligned with x_data, or None when they are
            not needed (only the extended model trains on them).
        patch_height, patch_width: patch size in pixels; both must be >= 2.

    Returns:
        [x_patches, inner_dst_patches, outer_dst_patches, y_patches] as numpy
        arrays; y_patches is None when y_data is None.

    Raises:
        ValueError: if a patch dimension is < 2 (the half-patch stride would be
            0 and never advance), or if an image is smaller than a patch.
    """
    if patch_height < 2 or patch_width < 2:
        raise ValueError('patch_height and patch_width must be at least 2, '
                         'got {} x {}'.format(patch_height, patch_width))
    # Half-patch stride -> neighbouring patches overlap by 50%.
    i_increment = patch_height // 2
    j_increment = patch_width // 2
    x_patches, inner_dst_patches, outer_dst_patches, y_patches = [], [], [], []
    for ind in range(len(x_data)):
        img = x_data[ind]
        height, width = img.shape[0], img.shape[1]
        if height < patch_height or width < patch_width:
            # A clamped start would be negative and yield wrong-size patches.
            raise ValueError(
                'image {} has size {}x{}, smaller than the {}x{} patch'.format(
                    ind, height, width, patch_height, patch_width))
        inner_dst = inner_dst_data[ind]
        outer_dst = outer_dst_data[ind]
        y = None if y_data is None else y_data[ind]
        col_starts = _patch_starts(width, patch_width, j_increment)
        for hs in _patch_starts(height, patch_height, i_increment):
            he = hs + patch_height
            for ws in col_starts:
                we = ws + patch_width
                x_patches.append(img[hs:he, ws:we, :])
                inner_dst_patches.append(inner_dst[hs:he, ws:we])
                outer_dst_patches.append(outer_dst[hs:he, ws:we])
                if y is not None:
                    y_patches.append(y[hs:he, ws:we])
    return [np.asarray(x_patches),
            np.asarray(inner_dst_patches),
            np.asarray(outer_dst_patches),
            None if y_data is None else np.asarray(y_patches)]