forked from fannn1217/persona.chatbot
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathudc_test.py
42 lines (35 loc) · 1.35 KB
/
udc_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import os
import time
import itertools
import sys
import tensorflow as tf
import udc_model
import udc_hparams
import udc_metrics
import udc_inputs
from models.dual_encoder import dual_encoder_model
# GPU selection: enumerate devices in PCI bus order and default to GPU 1.
# setdefault() keeps any CUDA_* values the caller already exported instead of
# silently overriding them.
os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "1")

# Command-line flags for the evaluation run.
tf.flags.DEFINE_string("test_file", "./data/persona/test.tfrecords", "Path of test data in TFRecords format")
# The default points at one specific training run; override with --model_dir.
tf.flags.DEFINE_string("model_dir", "./runs/1542774662", "Directory to load model checkpoints from")
tf.flags.DEFINE_integer("loglevel", 20, "Tensorflow log level")
tf.flags.DEFINE_integer("test_batch_size", 8, "Batch size for testing")
FLAGS = tf.flags.FLAGS

if not FLAGS.model_dir:
    # sys.exit() with a string writes it to stderr and exits with status 1.
    sys.exit("You must specify a model directory")

tf.logging.set_verbosity(FLAGS.loglevel)
def main():
    """Evaluate the dual-encoder checkpoint in FLAGS.model_dir on FLAGS.test_file.

    Builds the model function from the project hyperparameters, wraps it in a
    ``tf.contrib.learn.Estimator`` that loads checkpoints from
    ``FLAGS.model_dir``, and evaluates it over one epoch of the test
    TFRecords using the project's evaluation metrics.

    Returns:
        The metrics dict returned by ``Estimator.evaluate``.
    """
    hparams = udc_hparams.create_hparams()
    model_fn = udc_model.create_model_fn(hparams, model_impl=dual_encoder_model)
    estimator = tf.contrib.learn.Estimator(
        model_fn=model_fn,
        model_dir=FLAGS.model_dir,
        config=tf.contrib.learn.RunConfig())
    input_fn_test = udc_inputs.create_input_fn(
        mode=tf.contrib.learn.ModeKeys.EVAL,
        input_files=[FLAGS.test_file],
        batch_size=FLAGS.test_batch_size,
        num_epochs=1)
    eval_metrics = udc_metrics.create_evaluation_metrics()
    # steps=None: keep evaluating until the single-epoch input is exhausted.
    results = estimator.evaluate(input_fn=input_fn_test, steps=None, metrics=eval_metrics)
    # Print the metrics explicitly so they are visible even if --loglevel
    # suppresses TensorFlow's INFO-level evaluation log.
    for metric_name, metric_value in sorted(results.items()):
        print("{}: {}".format(metric_name, metric_value))
    return results


if __name__ == "__main__":
    main()