-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathtrain.py
133 lines (99 loc) · 4.48 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#!/usr/bin/python
import tensorflow as tf
from config import Config
from model import CaptionGenerator
FLAGS = tf.app.flags.FLAGS

# Command-line flags for the training script. main() copies most of these
# onto the Config instance before the model is built.
tf.flags.DEFINE_string(
    'input_file_pattern',
    '/home/hillyess/coco_tfrecord/train-?????-of-00256',
    'Image feature extracted using faster rcnn and corresponding captions')
tf.flags.DEFINE_string(
    'train_dir', '../output/model',
    'Model checkpoints and summary save here')
tf.flags.DEFINE_string(
    'model_file', None,
    'If specified, load a pretrained model from this file')
tf.flags.DEFINE_string(
    'faster_rcnn_file', None,
    'The file containing a pretrained Faster R-CNN model')
tf.flags.DEFINE_string(
    "optimizer", "SGD",
    "Adam, RMSProp, Momentum or SGD")
# Numeric flags take numeric defaults (not strings) so the declared default
# already has the flag's type instead of relying on implicit parsing.
tf.flags.DEFINE_float(
    "initial_learning_rate", 0.001,
    "Initial learning rate.")
tf.flags.DEFINE_float(
    "learning_rate_decay_factor", 0.1,
    "Learning rate decay factor.")
tf.flags.DEFINE_integer(
    "num_steps_per_decay", 10000,
    "Number of steps per learning rate decay.")
tf.flags.DEFINE_float(
    "momentum", 0.9,
    "Momentum value.")
tf.flags.DEFINE_string(
    "attention", "bias",
    "fc1, fc2, bias, bias2, bias_fc1, bias_fc2, rnn")
tf.flags.DEFINE_integer(
    "number_of_steps", 20000,
    "Number of training steps.")
# tf.flags.DEFINE_boolean('train_cnn', False,
# 'Turn on to train both CNN and RNN. \
# Otherwise, only RNN is trained')

tf.logging.set_verbosity(tf.logging.INFO)
def main(argv):
    """Configure the caption model from the command-line flags and train it.

    The flag values are copied onto a fresh ``Config``, the training
    directory is created when it does not exist yet, the ``CaptionGenerator``
    is built inside its own graph, and its optimisation op is handed to
    ``tf.contrib.slim.learning.train``.

    When ``--faster_rcnn_file`` and/or ``--model_file`` are given, the
    corresponding restore ops are run through the ``init_fn`` passed to slim.

    Args:
        argv: Positional command-line arguments; not used by this function.
    """
    config = Config()

    # (config attribute, flag value) pairs, applied in this order.
    overrides = (
        ("input_file_pattern", FLAGS.input_file_pattern),
        ("optimizer", FLAGS.optimizer),
        ("initial_learning_rate", FLAGS.initial_learning_rate),
        ("learning_rate_decay_factor", FLAGS.learning_rate_decay_factor),
        ("num_steps_per_decay", FLAGS.num_steps_per_decay),
        ("momentum", FLAGS.momentum),
        ("attention_mechanism", FLAGS.attention),
        ("save_dir", FLAGS.train_dir),
    )
    for attr_name, flag_value in overrides:
        setattr(config, attr_name, flag_value)

    # Make sure the directory for checkpoints and summaries exists.
    checkpoint_dir = config.save_dir
    if not tf.gfile.IsDirectory(checkpoint_dir):
        tf.logging.info("Creating training directory: %s", checkpoint_dir)
        tf.gfile.MakeDirs(checkpoint_dir)

    graph = tf.Graph()
    with graph.as_default():
        model = CaptionGenerator(config, mode="train")
        model.build()

        # Each entry is an (op, feed) pair that init_fn runs once, in the
        # order collected: Faster R-CNN weights first, then the rest.
        restore_steps = []
        if FLAGS.faster_rcnn_file is not None:
            rcnn_op, rcnn_feed = model.load_faster_rcnn_feature_extractor(
                FLAGS.faster_rcnn_file)
            restore_steps.append((rcnn_op, rcnn_feed))
        if FLAGS.model_file is not None:
            rnn_op, rnn_feed = model.load_model_except_faster_rcnn(
                FLAGS.model_file)
            restore_steps.append((rnn_op, rnn_feed))

        def init_fn(sess):
            """Run every collected restore op in the given session."""
            for restore_op, feed in restore_steps:
                sess.run(restore_op, feed)

        # Saver used by slim for writing and restoring checkpoints.
        saver = tf.train.Saver(max_to_keep=config.max_checkpoints_to_keep)

    # Let the session grow its GPU memory usage instead of claiming it all.
    session_config = tf.ConfigProto()
    session_config.gpu_options.allow_growth = True

    tf.contrib.slim.learning.train(
        model.opt_op,
        checkpoint_dir,
        graph=graph,
        global_step=model.global_step,
        number_of_steps=FLAGS.number_of_steps,
        log_every_n_steps=config.log_every_n_steps,
        summary_op=model.summary,
        save_summaries_secs=600,
        save_interval_secs=60000,
        init_fn=init_fn,
        saver=saver,
        session_config=session_config)
# model = CaptionGenerator(config, mode="train")
# model.build()
# with tf.Session() as sess:
# # sess.run(tf.global_variables_initializer())
# sess.run(tf.local_variables_initializer())
# # if model.init_fn:
# # model.init_fn(sess)
# # Start populating the filename queue.
# coord = tf.train.Coordinator()
# threads = tf.train.start_queue_runners(coord=coord)
# for i in range(1):
# # Retrieve a single instance:
# img,cap,mask = sess.run([model.image,model.caption, model.mask])
# print(img,type(img),img.shape)
# print(cap,type(cap),cap.shape)
# print(mask,type(mask),mask.shape)
# coord.request_stop()
# coord.join(threads)
if __name__ == "__main__":
    # Script entry point: tf.app.run handles the command-line flags and then
    # calls main with the remaining arguments.
    tf.app.run(main=main)