#!/usr/bin/env python
"""
# > Script for training 4x generative SISR models on USR-248 data
# - Paper: https://arxiv.org/pdf/1909.09437.pdf
#
# Maintainer: Jahid (email: [email protected])
# Interactive Robotics and Vision Lab (http://irvlab.cs.umn.edu/)
# Any part of this repo can be used for academic and educational purposes only
"""
import os
import sys
import datetime
import numpy as np
os.environ['TF_CPP_MIN_LOG_LEVEL']='2' # less TF logs (must be set before keras/TF import)
# keras libs
from keras.optimizers import Adam
from keras.models import Model, model_from_json
# local libs
from utils.plot_utils import save_val_samples
from utils.data_utils import dataLoaderUSR4, deprocess
from utils.loss_utils import perceptual_distance, total_gen_loss
from utils.progr_utils import update_model_p2, update_fadein
# network
from nets.gen_models import PAL
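# note: PAL (the progressive generator) is imported for reference; this
# phase-2 script restores the network from the saved p1 checkpoint below
# instead of building it from scratch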
#############################################################################
## dataset and image information
channels = 3
lr_width, lr_height = 160, 120 # low res
hr_width, hr_height = 640, 480 # high res (4x)
# input and output data
lr_shape = (lr_height, lr_width, channels)
hr_shape = (hr_height, hr_width, channels)
data_loader = dataLoaderUSR4(DATA_PATH="/content/drive/My Drive/USR-248/", SCALE=4)
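# loader interface assumed by this script (inferred from its usage below):
#   data_loader.num_train          -> number of training pairs
#   data_loader.load_batch(n)      -> yields (imgs_lr, imgs_hr) batches
#   data_loader.load_val_data(n)   -> returns one validation (lr, hr) batch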
# training parameters
num_epochs = int(sys.argv[1])
batch_size = 2
sample_interval = 500 # per step
steps_per_epoch = (data_loader.num_train//batch_size)
num_step = num_epochs*steps_per_epoch
#progressive training phase
phase = "p2"
###################################################################################
# load old model
checkpoint_dir = "/content/drive/My Drive/USR/checkpoints/PAL4/p1"
model_h5 = checkpoint_dir+".h5"
model_json = checkpoint_dir+".json"
# sanity
assert (os.path.exists(model_h5) and os.path.exists(model_json))
# load json and create model
with open(model_json, "r") as json_file:
    loaded_model_json = json_file.read()
generator = model_from_json(loaded_model_json)
# load weights into the model
generator.load_weights(model_h5)
# update the model with fadein
model = update_model_p2(generator)
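# NOTE: update_model_p2() is assumed to grow the restored p1 generator to the
# next ("p2") resolution stage and blend old and new outputs via an
# alpha-weighted fade-in, as in progressive growing of GANs; a hypothetical
# sketch of the blend (not the repo's actual code):
#   merged = WeightedSum(alpha)([upsampled_old_output, new_block_output])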
model.summary() # summary() prints the architecture itself (and returns None)
# checkpoint directory
checkpoint_dir = "/content/drive/My Drive/USR/checkpoints/PAL4/" + phase
## sample directory
samples_dir = os.path.join("/content/drive/My Drive/USR/images/PAL4/", phase)
if not os.path.exists(samples_dir): os.makedirs(samples_dir)
#####################################################################
# compile model
optimizer_ = Adam(0.0002, 0.5)
model.compile(optimizer=optimizer_, loss=total_gen_loss)
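# Adam(0.0002, 0.5) sets lr=2e-4 and beta_1=0.5 (common GAN settings);
# total_gen_loss is the repo's combined generator objective from
# utils.loss_utils (perceptual_distance, imported above, can also be
# passed via metrics=[perceptual_distance] to monitor perceptual quality)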
print ("\nTraining: {0}".format(phase))
## training pipeline
step, epoch = 0, 0; start_time = datetime.datetime.now()
while (step <= num_step):
    for i, (imgs_lr, imgs_hr) in enumerate(data_loader.load_batch(batch_size)):
        # update alpha for the fade-in process
        update_fadein(model, step, num_step)
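        # NOTE: update_fadein() is assumed to ramp the fade-in weight alpha
        # linearly from 0 (all-p1 output) to 1 (all-p2 output) over the run,
        # as in progressive-growing training; a hypothetical sketch only:
        #   alpha = step / float(num_step)
        #   for layer in model.layers:
        #       if hasattr(layer, 'alpha'): K.set_value(layer.alpha, alpha)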
        # train the generator
        loss_i = model.train_on_batch(imgs_lr, imgs_hr)
        # increment step, and show the progress
        step += 1; elapsed_time = datetime.datetime.now() - start_time
        if (step%10==0):
            print ("[Epoch %d: batch %d/%d] [loss_i: %f]"
                   %(epoch, i+1, steps_per_epoch, loss_i))
        ## validate and save generated samples at regular intervals
        if (step % sample_interval==0):
            imgs_lr, imgs_hr = data_loader.load_val_data(batch_size=2)
            fake_hr = model.predict(imgs_lr)
            gen_imgs = np.concatenate([deprocess(fake_hr), deprocess(imgs_hr)])
            save_val_samples(samples_dir, gen_imgs, step)
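        # note: the concatenated batch puts generated and ground-truth HR
        # images in one array so save_val_samples() can write them together
        # for visual comparison (exact layout depends on plot_utils)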
    epoch += 1
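# persist the trained phase-2 model: architecture as JSON plus weights as H5,
# the same two-file Keras convention used for the p1 checkpoint loaded above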
with open(checkpoint_dir+".json", "w") as json_file:
json_file.write(model.to_json())
model.save_weights(checkpoint_dir+".h5")
print("\nSaved trained {0} model in {1}\n".format(phase,checkpoint_dir))