-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
94 lines (72 loc) · 3.2 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
#!/usr/bin/python
import tensorflow as tf
from config import Config
from model import CaptionGenerator
from dataset import prepare_train_data, prepare_eval_data, prepare_test_data
# Command-line flags for this script. Values are read through FLAGS after
# tf.app.run() has parsed the command line.
#
# Help strings are built with adjacent-literal concatenation rather than a
# backslash continuation inside the literal, so the help text does not pick
# up whatever indentation the continuation line happens to have.
FLAGS = tf.app.flags.FLAGS

tf.app.flags.DEFINE_string(
    'phase', 'train',
    'The phase can be train, eval or test')

tf.app.flags.DEFINE_boolean(
    'load', False,
    'Turn on to load a pretrained model from either '
    'the latest checkpoint or a specified file')

tf.app.flags.DEFINE_string(
    'model_file', None,
    'If specified, load a pretrained model from this file')

tf.app.flags.DEFINE_boolean(
    'load_cnn', False,
    'Turn on to load a pretrained CNN model')

tf.app.flags.DEFINE_string(
    'cnn_model_file', './vgg16_no_fc.npy',
    'The file containing a pretrained CNN model')

tf.app.flags.DEFINE_boolean(
    'train_cnn', False,
    'Turn on to train both CNN and RNN. '
    'Otherwise, only RNN is trained')

tf.app.flags.DEFINE_integer(
    'beam_size', 3,
    'The size of beam search for caption generation')

# Emit INFO-level (and above) TensorFlow log messages.
tf.logging.set_verbosity(tf.logging.INFO)
def main(argv):
    """Build the caption model and print one evaluation of ``model.images``.

    Copies the ``phase``, ``train_cnn`` and ``beam_size`` command-line flags
    onto a fresh ``Config``, constructs a ``CaptionGenerator`` from it, starts
    the graph's queue runners, evaluates ``model.images`` once and prints the
    fetched value together with its type and shape.

    Args:
        argv: Command-line arguments passed in by ``tf.app.run``; unused.
    """
    config = Config()
    config.phase = FLAGS.phase
    config.train_cnn = FLAGS.train_cnn
    config.beam_size = FLAGS.beam_size

    model = CaptionGenerator(config)
    # model.train()

    with tf.Session() as sess:
        # NOTE(review): only local variables are initialized; the global
        # initializer and model.init_fn call below are disabled. Presumably
        # that is enough because only model.images is fetched -- confirm.
        # sess.run(tf.global_variables_initializer())
        sess.run(tf.local_variables_initializer())
        # if model.init_fn:
        #     model.init_fn(sess)

        # Start the queue-runner threads that feed the input pipeline.
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for _ in range(1):
                # Retrieve a single instance:
                example = sess.run(model.images)
                print(example, type(example), example.shape)
        finally:
            # Always stop and join the runner threads, even when the fetch
            # above raises, so no background thread outlives the session.
            coord.request_stop()
            coord.join(threads)
# with tf.Session() as sess:
# if FLAGS.phase == 'train':
# # training phase
# # data = prepare_train_data(config)
# model = CaptionGenerator(config)
# sess.run(tf.global_variables_initializer())
# # if FLAGS.load:
# # model.load(sess, FLAGS.model_file)
# # if FLAGS.load_cnn:
# # model.load_cnn(sess, FLAGS.cnn_model_file)
# # tf.get_default_graph().finalize()
# # model.train(sess)
# elif FLAGS.phase == 'eval':
# # evaluation phase
# coco, data, vocabulary = prepare_eval_data(config)
# model = CaptionGenerator(config)
# model.load(sess, FLAGS.model_file)
# tf.get_default_graph().finalize()
# model.eval(sess, coco, data, vocabulary)
# else:
# # testing phase
# data, vocabulary = prepare_test_data(config)
# model = CaptionGenerator(config)
# model.load(sess, FLAGS.model_file)
# tf.get_default_graph().finalize()
# model.test(sess, data, vocabulary)
# Script entry point: tf.app.run() parses the command-line flags and then
# calls main() with the remaining arguments.
if __name__ == "__main__":
    tf.app.run(main=main)