# -*- coding: utf-8 -*-
'''
Adversarial Autoencoder (Makhzani et al., 2015, arXiv:1511.05644).
Use this code with no warranty and please respect the accompanying license.
'''
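# Training scheme (summarized from the paper and the code below): an
# encoder/decoder pair is trained for reconstruction, while a discriminator on
# the latent code pushes the encoder's aggregated posterior q(z) towards an
# imposed prior p(z) (a standard normal here), so the trained decoder can also
# generate images from samples z ~ p(z).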
import os
import sys
import shutil
import logging
from datetime import datetime

sys.path.append('../common')
from tools_config import data_dir, expr_dir
from tools_train import get_train_params, vis_square, count_model_params
from tools_general import tf, np
from tools_networks import deconv, conv, dense, clipped_crossentropy

def create_encoder(Xin, is_training, latentD, reuse=False, networktype='cdaeE'):
    '''Map Xin [batchsize * H * W * Cin] to a latent code [batchsize * latentD] (no output activation).'''
    with tf.variable_scope(networktype, reuse=reuse):
        Xout = conv(Xin, is_training, kernel_w=4, stride=2, Cout=64, pad=1, act='reLu', norm='batchnorm', name='conv1')  # 14*14
        Xout = conv(Xout, is_training, kernel_w=4, stride=2, Cout=128, pad=1, act='reLu', norm='batchnorm', name='conv2')  # 7*7
        Xout = dense(Xout, is_training, Cout=latentD, act=None, norm=None, name='dense_mean')
    return Xout

def create_decoder(Xin, is_training, latentD, Cout=1, reuse=False, networktype='vaeD'):
    '''Map a latent code [batchsize * latentD] back to an image [batchsize * 28 * 28 * Cout].'''
    with tf.variable_scope(networktype, reuse=reuse):
        Xout = dense(Xin, is_training, Cout=7 * 7 * 256, act='reLu', norm='batchnorm', name='dense1')
        Xout = tf.reshape(Xout, shape=[-1, 7, 7, 256])  # 7*7
        Xout = deconv(Xout, is_training, kernel_w=4, stride=2, Cout=256, epf=2, act='reLu', norm='batchnorm', name='deconv1')  # 14*14
        Xout = deconv(Xout, is_training, kernel_w=4, stride=2, Cout=Cout, epf=2, act=None, norm=None, name='deconv2')  # 28*28
        Xout = tf.nn.sigmoid(Xout)
    return Xout

def create_discriminator(Xin, is_training, reuse=False, networktype='ganD'):
    '''Score a latent code [batchsize * latentD]: ~1 for prior samples, ~0 for encoder outputs.'''
    with tf.variable_scope(networktype, reuse=reuse):
        Xout = dense(Xin, is_training, Cout=7 * 7 * 256, act='reLu', norm='batchnorm', name='dense1')
        Xout = tf.reshape(Xout, shape=[-1, 7, 7, 256])  # 7*7
        Xout = conv(Xout, is_training, kernel_w=3, stride=1, pad=1, Cout=128, act='lrelu', norm='batchnorm', name='conv1')  # 7*7
        Xout = conv(Xout, is_training, kernel_w=3, stride=1, pad=1, Cout=256, act='lrelu', norm='batchnorm', name='conv2')  # 7*7
        Xout = conv(Xout, is_training, kernel_w=3, stride=1, pad=None, Cout=1, act=None, norm='batchnorm', name='conv3')  # 5*5
        Xout = tf.nn.sigmoid(Xout)
    return Xout

def create_aae_trainer(base_lr=1e-4, latentD=2, networktype='AAE'):
    '''Build the training graph for an Adversarial Autoencoder.'''
    is_training = tf.placeholder(tf.bool, [], 'is_training')

    Zph = tf.placeholder(tf.float32, [None, latentD])
    Xph = tf.placeholder(tf.float32, [None, 28, 28, 1])

    # Corrupt the input with dropout during training (denoising-style encoder input).
    Xc_op = tf.cond(is_training, lambda: tf.nn.dropout(Xph, keep_prob=0.75), lambda: tf.identity(Xph))

    Z_op = create_encoder(Xc_op, is_training, latentD, reuse=False, networktype=networktype + '_Enc')
    Xrec_op = create_decoder(Z_op, is_training, latentD, reuse=False, networktype=networktype + '_Dec')
    Xgen_op = create_decoder(Zph, is_training, latentD, reuse=True, networktype=networktype + '_Dec')

    # The discriminator sees encoder codes as "fake" and prior samples as "real".
    fakeLogits = create_discriminator(Z_op, is_training, reuse=False, networktype=networktype + '_Dis')
    realLogits = create_discriminator(Zph, is_training, reuse=True, networktype=networktype + '_Dis')

    # Reconstruction loss: squared error summed over pixels, averaged over the batch.
    rec_loss_op = tf.reduce_mean(tf.reduce_sum(tf.square(tf.subtract(Xph, Xrec_op)), axis=[1, 2, 3]))
    dec_loss_op = rec_loss_op

    # Regularization (adversarial) losses, using the binary cross-entropy from
    # tools_networks: the encoder tries to make its codes look "real" to the
    # discriminator, with the reconstruction term weighted differently in the
    # two encoder objectives.
    enc_rec_loss_op = clipped_crossentropy(fakeLogits, tf.ones_like(fakeLogits)) + 10 * rec_loss_op
    enc_gen_loss_op = clipped_crossentropy(fakeLogits, tf.ones_like(fakeLogits)) + 0.1 * rec_loss_op
    dis_loss_op = clipped_crossentropy(fakeLogits, tf.zeros_like(fakeLogits)) + clipped_crossentropy(realLogits, tf.ones_like(realLogits))

    enc_varlist = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=networktype + '_Enc')
    dec_varlist = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=networktype + '_Dec')
    dis_varlist = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope=networktype + '_Dis')

    train_dec_op = tf.train.AdamOptimizer(learning_rate=1.0 * base_lr, beta1=0.5).minimize(dec_loss_op, var_list=dec_varlist)
    train_enc_rec_op = tf.train.AdamOptimizer(learning_rate=1.0 * base_lr, beta1=0.5).minimize(enc_rec_loss_op, var_list=enc_varlist)
    train_enc_gen_op = tf.train.AdamOptimizer(learning_rate=1.0 * base_lr, beta1=0.5).minimize(enc_gen_loss_op, var_list=enc_varlist)
    train_dis_op = tf.train.AdamOptimizer(learning_rate=1.0 * base_lr, beta1=0.5).minimize(dis_loss_op, var_list=dis_varlist)

    logging.info('Total Trainable Variables Count in Encoder: %2.3f M, Decoder: %2.3f M, and Discriminator: %2.3f M'
                 % (count_model_params(enc_varlist) * 1e-6, count_model_params(dec_varlist) * 1e-6, count_model_params(dis_varlist) * 1e-6))

    return train_dec_op, train_dis_op, train_enc_gen_op, train_enc_rec_op, rec_loss_op, dis_loss_op, enc_gen_loss_op, is_training, Zph, Xph, Xrec_op, Xgen_op

if __name__ == '__main__':
    exp_id = 6
    networktype = 'AAE_MNIST'

    batch_size = 128
    base_lr = 2e-4
    epochs = 400
    latentD = 2

    work_dir = expr_dir + '%s/%.2d/' % (networktype, exp_id)
    if not os.path.exists(work_dir):
        os.makedirs(work_dir)
    else:
        raise ValueError('Experiment folder already exists. You probably want to change the experiment ID.')
    starttime = datetime.now().replace(microsecond=0)
    log_name = datetime.strftime(starttime, '%Y%m%d_%H%M')
    logging.basicConfig(filename=work_dir + '%s.log' % log_name, level=logging.DEBUG, format='%(asctime)s :: %(message)s', datefmt='%Y%m%d-%H%M%S')
    console = logging.StreamHandler(sys.stdout)
    console.setLevel(logging.INFO)
    logging.getLogger('').addHandler(console)

    logging.info('Started Training of %s at %s' % (networktype, datetime.strftime(starttime, '%Y-%m-%d_%H:%M:%S')))
    logging.info('\nTraining Hyperparameters: batch_size= %d, base_lr= %1.1e, epochs= %d, latentD= %d\n' % (batch_size, base_lr, epochs, latentD))

    # Keep a copy of this script next to the experiment results.
    shutil.copy2(os.path.basename(sys.argv[0]), work_dir)

    data, max_iter, test_iter, test_int, disp_int = get_train_params(data_dir, batch_size, epochs=epochs, test_in_each_epoch=1, networktype=networktype)
    test_int = test_int * 3

    tf.reset_default_graph()
    sess = tf.InteractiveSession()

    train_dec_op, train_dis_op, train_enc_gen_op, train_enc_rec_op, rec_loss_op, dis_loss_op, enc_gen_loss_op, is_training, Zph, Xph, Xrec_op, Xgen_op = \
        create_aae_trainer(base_lr, latentD, networktype)
    tf.global_variables_initializer().run()

    var_list = [var for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) if (networktype.lower() in var.name.lower()) and ('adam' not in var.name.lower())]
    saver = tf.train.Saver(var_list=var_list, max_to_keep=int(epochs * .1))
    # saver.restore(sess, expr_dir + 'ganMNIST/20170707/214_model.ckpt')

    best_test_loss = np.ones([3, ]) * np.inf
    train_loss = np.zeros([max_iter, 3])
    test_loss = np.zeros([int(np.ceil(max_iter / float(test_int))), 3])
    for it in range(max_iter):
        X, _ = data.train.next_batch(batch_size)
        # Sample from the imposed prior p(z): a standard normal over the latent space.
        Z = np.random.normal(size=[batch_size, latentD], loc=0.0, scale=1.0).astype(np.float32)

        # 1- Train the encoder and the decoder to reconstruct the input
        sess.run([train_enc_rec_op, train_dec_op], feed_dict={Xph: X, is_training: True})

        # 2- Train the discriminator to tell prior samples from encoder codes
        dis_loss, _ = sess.run([dis_loss_op, train_dis_op], feed_dict={Xph: X, Zph: Z, is_training: True})

        # 3- Train the encoder (as generator) to fool the discriminator
        enc_loss, rec_loss, _ = sess.run([enc_gen_loss_op, rec_loss_op, train_enc_gen_op], feed_dict={Xph: X, is_training: True})

        if it % test_int == 0:  # record summaries and test-set losses
            acc_loss = np.zeros([1, 3])
            for i_test in range(test_iter):
                X, _ = data.test.next_batch(batch_size)
                resloss = sess.run([rec_loss_op, dis_loss_op, enc_gen_loss_op], feed_dict={Xph: X, Zph: Z, is_training: False})
                acc_loss = np.add(acc_loss, resloss)
            test_loss[it // test_int] = np.divide(acc_loss, test_iter)

            logging.info("Epoch %4d, Iteration #%4d, Test Loss [rec| dis| enc] = [%s]" % (data.train.epochs_completed, it, ' | '.join(['%2.5f' % a for a in test_loss[it // test_int]])))

            if test_loss[it // test_int, 0] < best_test_loss[0]:
                best_test_loss = test_loss[it // test_int]
                logging.info("### Best Test Results Yet. Test Loss [rec| dis| enc] = [%s]" % (' | '.join(['%2.5f' % a for a in test_loss[it // test_int]])))
                rec_sample = sess.run(Xrec_op, feed_dict={Xph: X, is_training: False})
                vis_square(rec_sample[:121], [11, 11], save_path=work_dir + 'Rec_Iter_%d.jpg' % it)
                gen_sample = sess.run(Xgen_op, feed_dict={Zph: Z, is_training: False})
                vis_square(gen_sample[:121], [11, 11], save_path=work_dir + 'Gen_Iter_%d.jpg' % it)
                saver.save(sess, work_dir + "Model_Iter_%.3d.ckpt" % it)

        train_loss[it] = [rec_loss, dis_loss, enc_loss]
        # if it % disp_int == 0:
        #     logging.info("Epoch %4d, Iteration #%4d, Train Loss [rec| dis| enc] = [%s]" % (data.train.epochs_completed, it, ' | '.join(['%2.5f' % a for a in train_loss[it]])))

    endtime = datetime.now().replace(microsecond=0)
    logging.info('Finished Training of %s at %s' % (networktype, datetime.strftime(endtime, '%Y-%m-%d_%H:%M:%S')))
    logging.info('Training done in %s ! Best Test Loss [rec| dis| enc] = [%s]' % (endtime - starttime, ' | '.join(['%2.5f' % a for a in best_test_loss])))
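
    # A possible follow-up, sketched here rather than part of the original
    # script (it assumes vis_square accepts any [N, 28, 28, 1] batch, as in the
    # calls above): with latentD == 2, the learned manifold can be inspected by
    # decoding a regular 11x11 grid of points from the prior.
    # zgrid = np.dstack(np.meshgrid(np.linspace(-3, 3, 11), np.linspace(-3, 3, 11))).reshape(-1, 2).astype(np.float32)
    # manifold = sess.run(Xgen_op, feed_dict={Zph: zgrid, is_training: False})
    # vis_square(manifold, [11, 11], save_path=work_dir + 'Manifold.jpg')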