forked from NVlabs/few-shot-vid2vid
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
53 lines (46 loc) · 1.92 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
#
# This work is made available
# under the Nvidia Source Code License (1-way Commercial).
# To view a copy of this license, visit
# https://nvlabs.github.io/few-shot-vid2vid/License.txt
import os
import numpy as np
import torch
import cv2
from collections import OrderedDict
from options.test_options import TestOptions
from data.data_loader import CreateDataLoader
from models.models import create_model
import util.util as util
from util.visualizer import Visualizer
from util import html
# Parse the options that configure this test run.
opt = TestOptions().parse()

### setup dataset
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()

### setup models
model = create_model(opt)
model.eval()
visualizer = Visualizer(opt)

# create website
# Results go under <results_dir>/<name>/<phase>_<which_epoch>, with a
# '_finetune' suffix appended when opt.finetune is set.
web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
if opt.finetune:
    web_dir = web_dir + '_finetune'
webpage = html.HTML(
    web_dir,
    'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch),
    infer=True,
)
# test
# Run the model on each dataset item and save the resulting images to the web page.
for i, data in enumerate(dataset):
    # Stop after opt.how_many items (or once the index reaches the dataset length).
    if i >= opt.how_many or i >= len(dataset):
        break

    img_path = data['path']

    # Model input: target label/image followed by reference label/image.
    # The None slots are inputs this script does not supply.
    data_list = [
        data['tgt_label'], data['tgt_image'], None, None,
        data['ref_label'], data['ref_image'], None, None,
    ]
    # The model returns six values; only the first (the synthesized frame) is used here.
    synthesized_image, _, _, _, _, _ = model(data_list)

    # Convert tensors to image arrays for saving.
    synthesized_image = util.tensor2im(synthesized_image)
    tgt_image = util.tensor2im(data['tgt_image'])
    ref_image = util.tensor2im(data['ref_image'], tile=True)

    seq = data['seq'][0]

    # Entries are added in the same order as before: side-by-side comparison,
    # then the synthesized frame, then the reference image (stored only for
    # the first item; None for every later one).
    visuals = OrderedDict()
    visuals[seq] = np.hstack([ref_image, tgt_image, synthesized_image])
    visuals[seq + '/synthesized'] = synthesized_image
    visuals[seq + '/ref_image'] = ref_image if i == 0 else None

    print('process image... %s' % img_path)
    visualizer.save_images(webpage, visuals, img_path)