-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathrun.py
169 lines (137 loc) · 5.27 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
# -*- coding: utf-8 -*-
"""
Created on Fri Aug 23 17:09:01 2019
@author: HareeshRavi
"""
import configAll
import cnsi
import baseline
import process_vist
import vggfeat_vist
import coherenceVec
import argparse
import time
import json
# Option values accepted by each command-line stage. An empty or missing
# value means "skip this stage".
PREPROCESS_TYPES = ('data', 'imagefeatures', 'coherencevectors')
TRAIN_TYPES = ('stage1', 'cnsi', 'nsi', 'baseline')
MODEL_TYPES = ('cnsi', 'nsi', 'baseline')

_PREPROCESS_ERROR = ('preprocess types are data, imagefeatures and '
                     'coherencevectors')
_TRAIN_ERROR = 'args for train can be stage1, cnsi, nsi or baseline only'
_EVAL_ERROR = 'args for eval can be cnsi, nsi or baseline only'
_SHOW_ERROR = 'args for show can be cnsi, nsi or baseline only'


def parse_args(argv=None):
    """Parse the command line.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``preprocess``, ``pretrain``, ``train``,
        ``eval`` and ``show`` attributes.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--preprocess', type=str, default=None,
                        help='use this argument to preprocess VIST data')
    parser.add_argument('--pretrain', action='store_true', default=False,
                        help='use this to pre-train stage 1 of the network')
    parser.add_argument('--train', type=str, default=None,
                        help='train stage1, cnsi, nsi or baseline')
    parser.add_argument('--eval', type=str, default=None,
                        help='evaluate cnsi, nsi or baseline')
    parser.add_argument('--show', type=str, default=None,
                        help='show the story for cnsi, nsi or baseline')
    return parser.parse_args(argv)


def validate_args(args):
    """Reject unknown stage names before any stage starts running.

    Checking everything up front means a typo in, e.g., ``--show`` fails
    immediately instead of after preprocessing/training has already run.
    Empty or ``None`` values are accepted and mean "skip this stage".

    Raises:
        ValueError: if any stage option names an unsupported value.
    """
    checks = (
        (args.preprocess, PREPROCESS_TYPES, _PREPROCESS_ERROR),
        (args.train, TRAIN_TYPES, _TRAIN_ERROR),
        (args.eval, MODEL_TYPES, _EVAL_ERROR),
        (args.show, MODEL_TYPES, _SHOW_ERROR),
    )
    for value, allowed, message in checks:
        if value and value not in allowed:
            raise ValueError(message)


def load_configs(path='config.json'):
    """Load the JSON config at *path*; fall back to configAll if missing.

    Returns:
        The parsed config dict, or ``configAll.create_config()`` when the
        file does not exist.
    """
    try:
        # ``with`` guarantees the handle is closed after parsing.
        with open(path, encoding='utf-8') as config_file:
            return json.load(config_file)
    except FileNotFoundError:
        return configAll.create_config()


def _timed(label, step, configs):
    """Run ``step(configs)`` and print '<label> in <seconds> secs'."""
    start = time.perf_counter()
    step(configs)
    print('{} in {} secs'.format(label, time.perf_counter() - start))


def _run_preprocess(kind, configs):
    """Run one preprocessing step: data, imagefeatures or coherencevectors."""
    if not kind:
        return
    if kind == 'data':
        # process VIST data jsons and arrange them according to usage
        _timed('vist data files created', process_vist.main, configs)
    elif kind == 'imagefeatures':
        # extract VGG features for all images; per the original notes this
        # also drops images (and stories) whose images are not present
        _timed('vggfeat extracted for all images', vggfeat_vist.main,
               configs)
    elif kind == 'coherencevectors':
        # Start a Stanford CoreNLP server first, e.g.:
        # java -mx4g -cp "*" edu.stanford.nlp.pipeline.StanfordCoreNLPServer
        _timed('coherence vector extracted for all stories',
               coherenceVec.main, configs)
    else:
        raise ValueError(_PREPROCESS_ERROR)


def _run_train(model, configs):
    """Train stage 1, stage 2 of the CNSI/NSI model, or the baseline."""
    if not model:
        return
    if model == 'stage1':
        # train stage 1 on the VIST dataset
        cnsi.main(configs, 'trainstage1')
    elif model in ('cnsi', 'nsi'):
        # train stage 2 of the CNSI or NSI model
        cnsi.main(configs, 'trainstage2', model)
    elif model == 'baseline':
        # Mutates the shared configs dict, so later --eval/--show steps in
        # the same run also see configs['model'] == 'baseline'.
        configs['model'] = 'baseline'
        baseline.main(configs, 'train')
    else:
        raise ValueError(_TRAIN_ERROR)


def _run_eval(model, configs):
    """Run the 'test' mode of a model on the VIST test samples.

    Per the original script's notes, predictions are saved to file for
    later use by the metrics and nothing is printed or shown here. For
    cnsi/nsi the weights loaded are
    ``configs['savemodel'] + 'stage2_<model>_' + configs['date'] + '.h5'``.
    """
    if not model:
        return
    if model in ('cnsi', 'nsi'):
        model2test = (configs['savemodel'] + 'stage2_' + model + '_' +
                      configs['date'] + '.h5')
        cnsi.main(configs, 'test', model, model2test)
    elif model == 'baseline':
        # NOTE(review): unlike --train baseline, this does not set
        # configs['model'] = 'baseline'; confirm baseline.main doesn't
        # rely on it.
        baseline.main(configs, 'test')
    else:
        raise ValueError(_EVAL_ERROR)


def _run_show(model, configs):
    """Run the 'show' mode of a model to display predicted stories.

    For cnsi/nsi the name ``'results_<model>_' + configs['date'] +
    '.pickle'`` is passed to ``cnsi.main``.
    """
    if not model:
        return
    if model in ('cnsi', 'nsi'):
        results2show = ('results_' + model + '_' + configs['date'] +
                        '.pickle')
        cnsi.main(configs, 'show', model, results2show)
    elif model == 'baseline':
        baseline.main(configs, 'show')
    else:
        raise ValueError(_SHOW_ERROR)


def main(argv=None):
    """Run the requested stages in order: preprocess, pretrain, train,
    eval, show.

    Raises:
        ValueError: for an unsupported stage value, before any stage runs.
    """
    args = parse_args(argv)
    validate_args(args)
    configs = load_configs()
    _run_preprocess(args.preprocess, configs)
    if args.pretrain:
        # pretrain stage 1 on the MSCOCO dataset
        cnsi.main(configs, 'pretrain')
    _run_train(args.train, configs)
    _run_eval(args.eval, configs)
    _run_show(args.show, configs)


if __name__ == '__main__':
    main()