"""VIME: Extending the Success of Self- and Semi-supervised Learning to Tabular Domain (VIME) Codebase.
Reference: Jinsung Yoon, Yao Zhang, James Jordon, Mihaela van der Schaar,
"VIME: Extending the Success of Self- and Semi-supervised Learning to Tabular Domain,"
Neural Information Processing Systems (NeurIPS), 2020.
Paper link: TBD
Last updated Date: October 11th 2020
Code author: Jinsung Yoon ([email protected])
-----------------------------
vime_semi.py
- Semi-supervised learning parts of the VIME framework
- Using both labeled and unlabeled data to train the predictor with the help of trained encoder
"""
# Necessary packages
import os

import keras
import numpy as np
import tensorflow as tf
from tensorflow.contrib import layers as contrib_layers

from vime_utils import mask_generator, pretext_generator
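
# NOTE: this module targets TensorFlow 1.x APIs (tf.placeholder, tf.Session,
# tf.contrib.layers); tf.contrib was removed in TensorFlow 2.x, so running
# under TF 2.x would require a rewrite, not just the tf.compat.v1 shim.
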
def vime_semi(x_train, y_train, x_unlab, x_test, parameters,
              p_m, K, beta, file_name):
  """Semi-supervised learning part of VIME.

  Args:
    - x_train, y_train: training dataset
    - x_unlab: unlabeled dataset
    - x_test: testing features
    - parameters: network parameters (hidden_dim, batch_size, iterations)
    - p_m: corruption probability
    - K: number of augmented samples
    - beta: hyperparameter balancing the supervised and unsupervised losses
    - file_name: file name of the saved encoder

  Returns:
    - y_test_hat: prediction on x_test
  """

  # Network parameters
  hidden_dim = parameters['hidden_dim']
  act_fn = tf.nn.relu
  batch_size = parameters['batch_size']
  iterations = parameters['iterations']

  # Basic parameters
  data_dim = len(x_train[0, :])
  label_dim = len(y_train[0, :])

  # Divide training and validation sets (9:1)
  idx = np.random.permutation(len(x_train[:, 0]))
  train_idx = idx[:int(len(idx)*0.9)]
  valid_idx = idx[int(len(idx)*0.9):]

  x_valid = x_train[valid_idx, :]
  y_valid = y_train[valid_idx, :]

  x_train = x_train[train_idx, :]
  y_train = y_train[train_idx, :]

  # Input placeholders
  # Labeled data
  x_input = tf.placeholder(tf.float32, [None, data_dim])
  y_input = tf.placeholder(tf.float32, [None, label_dim])
  # Augmented unlabeled data
  xu_input = tf.placeholder(tf.float32, [None, None, data_dim])
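  # Shape note: xu_input is fed as [K, batch_size, encoded_dim]. In the
  # released self-supervised part of VIME the encoder output width equals
  # data_dim, which is why data_dim is used for the last axis here.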

  ## Predictor
  def predictor(x_input):
    """Returns prediction.

    Args:
      - x_input: input feature

    Returns:
      - y_hat_logit: logit prediction
      - y_hat: prediction
    """
    with tf.variable_scope('predictor', reuse=tf.AUTO_REUSE):
      # Stacks multi-layered perceptron
      inter_layer = contrib_layers.fully_connected(x_input,
                                                   hidden_dim,
                                                   activation_fn=act_fn)
      inter_layer = contrib_layers.fully_connected(inter_layer,
                                                   hidden_dim,
                                                   activation_fn=act_fn)
      y_hat_logit = contrib_layers.fully_connected(inter_layer,
                                                   label_dim,
                                                   activation_fn=None)
      y_hat = tf.nn.softmax(y_hat_logit)
    return y_hat_logit, y_hat

  # Build model
  y_hat_logit, y_hat = predictor(x_input)
  yv_hat_logit, yv_hat = predictor(xu_input)

  # Define losses
  # Supervised loss
  y_loss = tf.losses.softmax_cross_entropy(y_input, y_hat_logit)
  # Unsupervised loss
  yu_loss = tf.reduce_mean(tf.nn.moments(yv_hat_logit, axes=0)[1])
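  # The unsupervised term is a consistency loss: axis 0 of yv_hat_logit
  # indexes the K augmentations of each unlabeled sample, so
  # tf.nn.moments(..., axes=0)[1] is the per-element variance of the logits
  # across augmentations, averaged over the batch and output dimensions.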

  # Define variables
  p_vars = [v for v in tf.trainable_variables()
            if v.name.startswith('predictor')]

  # Define solver
  solver = tf.train.AdamOptimizer().minimize(y_loss + beta * yu_loss,
                                             var_list=p_vars)

  # Load encoder from self-supervised model
  encoder = keras.models.load_model(file_name)

  # Encode validation and testing features
  x_valid = encoder.predict(x_valid)
  x_test = encoder.predict(x_test)

  # Start session
  sess = tf.Session()
  sess.run(tf.global_variables_initializer())

  # Setup early stopping procedure
  os.makedirs('./save_model', exist_ok=True)  # ensure the checkpoint dir exists
  class_file_name = './save_model/class_model.ckpt'
  saver = tf.train.Saver(p_vars)

  yv_loss_min = 1e10
  yv_loss_min_idx = -1

  # Training iteration loop
  for it in range(iterations):

    # Select a batch of labeled data
    batch_idx = np.random.permutation(len(x_train[:, 0]))[:batch_size]
    x_batch = x_train[batch_idx, :]
    y_batch = y_train[batch_idx, :]

    # Encode labeled data
    x_batch = encoder.predict(x_batch)

    # Select a batch of unlabeled data
    batch_u_idx = np.random.permutation(len(x_unlab[:, 0]))[:batch_size]
    xu_batch_ori = x_unlab[batch_u_idx, :]

    # Augment unlabeled data
    xu_batch = list()
    for rep in range(K):
      # Mask vector generation
      m_batch = mask_generator(p_m, xu_batch_ori)
      # Pretext generator
      _, xu_batch_temp = pretext_generator(m_batch, xu_batch_ori)
      # Encode corrupted samples
      xu_batch_temp = encoder.predict(xu_batch_temp)
      xu_batch.append(xu_batch_temp)

    # Convert list to matrix
    xu_batch = np.asarray(xu_batch)
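    # xu_batch now has shape (K, batch_size, encoded_dim), matching the
    # [None, None, data_dim] layout of the xu_input placeholder.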

    # Train the model
    _, y_loss_curr = sess.run([solver, y_loss],
                              feed_dict={x_input: x_batch, y_input: y_batch,
                                         xu_input: xu_batch})

    # Current validation loss
    yv_loss_curr = sess.run(y_loss, feed_dict={x_input: x_valid,
                                               y_input: y_valid})

    if it % 100 == 0:
      print('Iteration: ' + str(it) + '/' + str(iterations) +
            ', Current loss: ' + str(np.round(yv_loss_curr, 4)))

    # Early stopping & best model save
    if yv_loss_min > yv_loss_curr:
      yv_loss_min = yv_loss_curr
      yv_loss_min_idx = it

      # Saves trained model
      saver.save(sess, class_file_name)

    if yv_loss_min_idx + 100 < it:
      break
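
  # The loop above exits early once the validation loss has not improved for
  # 100 consecutive iterations; the best checkpoint is restored below.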

  # Restore the best saved model (lowest validation loss). Restoring with the
  # saver that wrote the checkpoint avoids importing a duplicate meta graph
  # into the default graph.
  saver.restore(sess, class_file_name)

  # Predict on x_test
  y_test_hat = sess.run(y_hat, feed_dict={x_input: x_test})

  return y_test_hat
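

# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the original file). It
# assumes an encoder trained by the self-supervised part has been saved,
# e.g. to './save_model/encoder_model.h5' (a hypothetical path); the synthetic
# shapes and hyperparameter values below are likewise assumptions, not values
# from the paper.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
  encoder_path = './save_model/encoder_model.h5'  # hypothetical path
  if not os.path.exists(encoder_path):
    print('No saved encoder found; train one first (see vime_self.py).')
  else:
    # Synthetic data: 20 features, 2 one-hot classes
    x_train = np.random.rand(1000, 20)
    y_train = np.eye(2)[np.random.randint(0, 2, 1000)]
    x_unlab = np.random.rand(5000, 20)
    x_test = np.random.rand(200, 20)

    parameters = {'hidden_dim': 100, 'batch_size': 128, 'iterations': 1000}

    y_test_hat = vime_semi(x_train, y_train, x_unlab, x_test, parameters,
                           p_m=0.3, K=3, beta=1.0, file_name=encoder_path)
    print('Predictions shape:', y_test_hat.shape)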