train_localizer.py
"""
trains the localizer network
"""
import numpy
import cv2
import time
import tensorflow as tf
import csv
import os
from utils import utils
from utils import input_data
from models import localizer as localizer
flags = tf.app.flags
FLAGS = flags.FLAGS
flags.DEFINE_integer('input_size', 24, 'width and height of the input images')
flags.DEFINE_integer('label_size', 12, 'width and height of the label images')
flags.DEFINE_integer('batch_size', 256, 'training batch size')
flags.DEFINE_integer('max_steps', 3000, 'number of steps to run the trainer')
flags.DEFINE_float('learning_rate', 0.00001, 'initial learning rate')
flags.DEFINE_float('dropout', 0.75, 'keep probability for training dropout')
flags.DEFINE_string('weight_import_path', '../output/checkpoints/classifier', 'path to the classifier checkpoint')
flags.DEFINE_string('checkpoint_path', '../output/checkpoints/localizer', 'path to the localizer checkpoint')
flags.DEFINE_string('log_dir', '../output/log/localizer', 'path to the summary log directory')
flags.DEFINE_string('output_file', '../output/results/localizer/train.csv', 'path to the CSV results file')
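# Every flag above doubles as a command-line option via tf.app.flags,
# e.g. (hypothetical values):
#   python train_localizer.py --batch_size 128 --learning_rate 1e-5 --max_steps 5000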
# global TensorFlow session, created in main()
sess = None
def import_data():
    """
    Returns the training and evaluation data sets.
    """
    train_set = input_data.Data((FLAGS.input_size, FLAGS.input_size), (FLAGS.label_size, FLAGS.label_size))
    eval_set = input_data.Data((FLAGS.input_size, FLAGS.input_size), (FLAGS.label_size, FLAGS.label_size))
    # positive examples come with label images; negatives get an all-zero label
    train_set.add_examples("../data/train/8300_positives.png", 8300, 100, None)
    train_set.add_labels("../data/train/8300_labels.png", 8300, 100)
    train_set.add_examples("../data/train/20038_negatives.png", 20038, 100, numpy.zeros([FLAGS.label_size * FLAGS.label_size]))
    eval_set.add_examples("../data/train/eval_1510_positives.png", 1510, 100, None)
    eval_set.add_labels("../data/train/eval_1510_labels.png", 1510, 100)
    eval_set.add_examples("../data/train/eval_3710_negatives.png", 3710, 100, numpy.zeros([FLAGS.label_size * FLAGS.label_size]))
    train_set.finalize()
    eval_set.finalize()
    utils.print_to_file(FLAGS.output_file, 'training: ' + str(train_set.count))
    utils.print_to_file(FLAGS.output_file, 'evaluation: ' + str(eval_set.count))
    return train_set, eval_set
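# Sanity check (optional): given the counts above, one would expect
# train_set.count == 8300 + 20038 and eval_set.count == 1510 + 3710,
# assuming input_data.Data simply accumulates the examples it is given.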
# ============================================================= #
def evaluation(step, data_set, eval_op, x, y_, keep_prob):
    """
    Evaluates the current training progress.
    Args:
        step: current training step
        data_set: evaluation data set
        eval_op: evaluation operation
        x: input data placeholder
        y_: desired output placeholder
        keep_prob: keep probability placeholder (dropout)
    Returns:
        mean error per example
    """
    error = 0
    num_examples = 0
    for batch_xs, batch_ys, count in data_set.batches(FLAGS.batch_size):
        feed = {x: batch_xs, y_: batch_ys, keep_prob: 1.0}
        predictions = sess.run(eval_op, feed_dict=feed)
        error += numpy.sum(predictions)
        num_examples += count
    error_mean = float(error) / float(num_examples)
    return error_mean
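# With eval_op = tf.nn.l2_loss(model - y_), each sess.run() returns
# 0.5 * sum((prediction - label)^2) over the batch, so error_mean is the
# average half squared error per example.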
# ============================================================= #
def train_model(model, train_set, eval_set, x, y_, keep_prob):
    """
    Trains the model.
    Args:
        model: model to train
        train_set: training dataset
        eval_set: evaluation dataset
        x: input data placeholder
        y_: desired output placeholder
        keep_prob: keep probability placeholder
    """
    global sess
    # global step counter, incremented by the training op
    global_step = tf.Variable(0, trainable=False, name='global_step')
    # evaluation ops
    with tf.name_scope('test'):
        eval_op = tf.nn.l2_loss(tf.sub(model, y_))
    # training ops
    with tf.name_scope('train'):
        loss = localizer.loss(model, y_)
        train_step = localizer.train(loss, global_step)
    # summary ops
    merged_summary = tf.merge_all_summaries()
    writer = tf.train.SummaryWriter(FLAGS.log_dir, sess.graph)
    # init vars
    tf.initialize_all_variables().run(session=sess)
    # ---------- import weights from classifier ---------------#
    if FLAGS.weight_import_path is not None and tf.train.latest_checkpoint(FLAGS.weight_import_path) is not None:
        localizer.weights_saver().restore(sess, tf.train.latest_checkpoint(FLAGS.weight_import_path))
        utils.print_to_file(FLAGS.output_file, 'imported weights from classifier')
    # ---------- restore model ---------------#
    saver = tf.train.Saver()
    if tf.train.latest_checkpoint(FLAGS.checkpoint_path) is not None:
        saver.restore(sess, tf.train.latest_checkpoint(FLAGS.checkpoint_path))
    utils.print_to_file(FLAGS.output_file, 'step, test_error, train_error')
    # ------------- train --------------------#
    for i in xrange(FLAGS.max_steps + 1):
        # train on a mini batch
        batch_xs, batch_ys = train_set.next_batch(FLAGS.batch_size)
        feed = {x: batch_xs, y_: batch_ys, keep_prob: FLAGS.dropout}
        _, loss_value = sess.run([train_step, loss], feed_dict=feed)
        assert not numpy.isnan(loss_value), 'Model diverged with loss = NaN'
        # read the current global step count (incremented by train_step)
        step = tf.train.global_step(sess, global_step)
        # write summary
        if step % 100 == 0:
            summary_str = sess.run(merged_summary, feed_dict=feed)
            writer.add_summary(summary_str, step)
            writer.flush()
        # evaluation
        if step % 100 == 0:
            test_error = evaluation(step, eval_set, eval_op, x, y_, keep_prob)
            train_error = evaluation(step, train_set, eval_op, x, y_, keep_prob)
            utils.print_to_file(FLAGS.output_file, str(step) + ',' + str(test_error) + ',' + str(train_error))
        # save model
        if step % 1000 == 0 or i == FLAGS.max_steps:
            saver.save(sess, FLAGS.checkpoint_path + '/model.ckpt', global_step=step)
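# Note on resuming: checkpoints are written every 1000 steps and once more at
# the final iteration. On restart the latest checkpoint is restored and
# global_step resumes from its saved value, but the loop counter i starts over
# at 0, so a resumed run trains for another max_steps iterations.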
# ============================================================= #
def main(_):
    global sess
    sess = tf.Session()
    # ---------- import data ----------------#
    train_set, eval_set = import_data()
    # ---------- create model ----------------#
    # model input placeholder
    x = tf.placeholder("float", shape=[None, FLAGS.input_size * FLAGS.input_size])
    # desired output placeholder
    y_ = tf.placeholder("float", shape=[None, FLAGS.label_size * FLAGS.label_size])
    # keep probability placeholder
    keep_prob = tf.placeholder("float")
    # model
    model = localizer.create(x, keep_prob)
    utils.print_to_file(FLAGS.output_file, 'batch size, learning rate, drop out, image size')
    utils.print_to_file(FLAGS.output_file, str(FLAGS.batch_size) + ',' + str(FLAGS.learning_rate) + ',' + str(FLAGS.dropout) + ',' + str(FLAGS.input_size))
    # ---------- train model -----------------#
    train_model(model, train_set, eval_set, x, y_, keep_prob)
if __name__ == '__main__':
    tf.app.run()
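# Expected layout, inferred from the relative paths above: the script is run
# from the repository's source directory, with ../data/train/ holding the
# example PNGs and ../output/ holding checkpoints, summary logs and the CSV
# results file. The code targets the pre-1.0 TensorFlow API (tf.sub,
# tf.merge_all_summaries, tf.train.SummaryWriter) and Python 2 (xrange).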