-
Notifications
You must be signed in to change notification settings - Fork 185
/
Copy pathFM.py
324 lines (292 loc) · 16.3 KB
/
FM.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
'''
TensorFlow implementation of Factorization Machines (FM), as described in:
Xiangnan He, Tat-Seng Chua. Neural Factorization Machines for Sparse Predictive Analytics. In Proc. of SIGIR 2017.
Note that the original paper of FM is: Steffen Rendle. Factorization Machines. In Proc. of ICDM 2010.
@author:
Xiangnan He ([email protected])
Lizi Liao ([email protected])
@references:
'''
import math
import os
import numpy as np
import tensorflow as tf
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.metrics import mean_squared_error
from sklearn.metrics import accuracy_score
from sklearn.metrics import log_loss
from time import time
import argparse
import LoadData as DATA
from tensorflow.contrib.layers.python.layers import batch_norm as batch_norm
#################### Arguments ####################
def parse_args(argv=None):
    '''
    Parse the command-line options for training FM.

    Args:
        argv: list of argument strings to parse; None (the default) parses sys.argv[1:].

    Returns:
        argparse.Namespace with one attribute per option defined below.
    '''
    parser = argparse.ArgumentParser(description="Run FM.")
    parser.add_argument('--path', nargs='?', default='data/',
                        help='Input data path.')
    parser.add_argument('--dataset', nargs='?', default='frappe',
                        help='Choose a dataset.')
    parser.add_argument('--epoch', type=int, default=100,
                        help='Number of epochs.')
    parser.add_argument('--pretrain', type=int, default=-1,
                        help='flag for pretrain. 1: initialize from pretrain; 0: randomly initialize; -1: save the model to pretrain file')
    parser.add_argument('--batch_size', type=int, default=128,
                        help='Batch size.')
    parser.add_argument('--hidden_factor', type=int, default=64,
                        help='Number of hidden factors.')
    parser.add_argument('--lamda', type=float, default=0,
                        help='Regularizer for bilinear part.')
    parser.add_argument('--keep_prob', type=float, default=0.5,
                        help='Keep probability (1-dropout_ratio) for the Bi-Interaction layer. 1: no dropout')
    parser.add_argument('--lr', type=float, default=0.05,
                        help='Learning rate.')
    # Restrict to the values the model actually supports, so a typo fails here with a
    # clear usage message instead of deep inside graph construction.
    parser.add_argument('--loss_type', nargs='?', default='square_loss',
                        choices=['square_loss', 'log_loss'],
                        help='Specify a loss type (square_loss or log_loss).')
    parser.add_argument('--optimizer', nargs='?', default='AdagradOptimizer',
                        choices=['AdamOptimizer', 'AdagradOptimizer',
                                 'GradientDescentOptimizer', 'MomentumOptimizer'],
                        help='Specify an optimizer type (AdamOptimizer, AdagradOptimizer, GradientDescentOptimizer, MomentumOptimizer).')
    parser.add_argument('--verbose', type=int, default=1,
                        help='Show the results per X epochs (0, 1 ... any positive integer)')
    parser.add_argument('--batch_norm', type=int, default=0,
                        help='Whether to perform batch normalization (0 or 1)')
    return parser.parse_args(argv)
class FM(BaseEstimator, TransformerMixin):
    '''
    Factorization Machine trained with the TensorFlow 1.x graph API.

    For a sample whose active feature ids are x_1..x_n the model output is
        bias + sum_i w[x_i] + sum_k 0.5 * ((sum_i v[x_i, k])^2 - sum_i v[x_i, k]^2)
    where w holds per-feature biases and v the K-dimensional feature embeddings.
    The K-dimensional interaction vector optionally goes through batch normalization
    and always through dropout before being summed over k. For 'log_loss' the output
    is squashed with a sigmoid.
    '''

    def __init__(self, features_M, pretrain_flag, save_file, hidden_factor, loss_type, epoch, batch_size,
                 learning_rate, lamda_bilinear, keep, optimizer_type, batch_norm, verbose, random_seed=2016):
        '''
        Store the hyper-parameters and build the TensorFlow graph and session.

        Args:
            features_M: number of distinct feature ids (rows of the embedding tables).
            pretrain_flag: > 0 initializes the weights from ``save_file``; < 0 saves the
                trained model to ``save_file`` after training; 0 does neither.
            save_file: checkpoint path prefix used to load/save pretrained weights.
            hidden_factor: embedding size K.
            loss_type: 'square_loss' or 'log_loss'.
            epoch: maximum number of training epochs.
            batch_size: mini-batch size.
            learning_rate: optimizer learning rate.
            lamda_bilinear: L2 coefficient on the feature embeddings (0 disables it).
            keep: dropout keep probability applied to the FM layer during training.
            optimizer_type: 'AdamOptimizer', 'AdagradOptimizer',
                'GradientDescentOptimizer' or 'MomentumOptimizer'.
            batch_norm: truthy to batch-normalize the FM layer.
            verbose: report metrics every ``verbose`` epochs (0 disables reporting).
            random_seed: graph-level TensorFlow random seed.

        Raises:
            ValueError: if ``loss_type`` or ``optimizer_type`` is not supported.
        '''
        # bind params to class (attribute names mirror the __init__ parameters,
        # which sklearn's BaseEstimator.get_params relies on)
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.hidden_factor = hidden_factor
        self.save_file = save_file
        self.pretrain_flag = pretrain_flag
        self.loss_type = loss_type
        self.features_M = features_M
        self.lamda_bilinear = lamda_bilinear
        self.keep = keep
        self.epoch = epoch
        self.random_seed = random_seed
        self.optimizer_type = optimizer_type
        self.batch_norm = batch_norm
        self.verbose = verbose
        # Per-epoch metric history. Despite the names, for 'log_loss' these hold
        # log-loss values (see evaluate).
        self.train_rmse, self.valid_rmse, self.test_rmse = [], [], []
        # init all variables in a tensorflow graph
        self._init_graph()

    def _init_graph(self):
        '''
        Build a TensorFlow graph containing the inputs, variables, model, loss and
        optimizer, then open a session on it and initialize all variables.

        Raises:
            ValueError: if ``loss_type`` or ``optimizer_type`` is not supported.
        '''
        self.graph = tf.Graph()
        with self.graph.as_default():  # , tf.device('/cpu:0'):
            # Graph-level seed so weight initialization and dropout are reproducible.
            tf.set_random_seed(self.random_seed)

            # Input data.
            # Each row lists the ids of one sample's active features; all rows in a
            # single feed must have the same length so they form a rectangular array.
            self.train_features = tf.placeholder(tf.int32, shape=[None, None])  # None * n_active
            self.train_labels = tf.placeholder(tf.float32, shape=[None, 1])  # None * 1
            self.dropout_keep = tf.placeholder(tf.float32)
            self.train_phase = tf.placeholder(tf.bool)

            # Variables.
            self.weights = self._initialize_weights()

            # Model.
            # _________ sum_square part: (sum_i v_i)^2 _____________
            nonzero_embeddings = tf.nn.embedding_lookup(self.weights['feature_embeddings'], self.train_features)
            self.summed_features_emb = tf.reduce_sum(nonzero_embeddings, 1)  # None * K
            self.summed_features_emb_square = tf.square(self.summed_features_emb)  # None * K
            # _________ square_sum part: sum_i v_i^2 _____________
            self.squared_features_emb = tf.square(nonzero_embeddings)
            self.squared_sum_features_emb = tf.reduce_sum(self.squared_features_emb, 1)  # None * K
            # ________ FM __________
            # 0.5 * (square of sum - sum of squares) equals the element-wise sum of
            # v_i * v_j over all pairs i < j, computed in linear time.
            # (tf.subtract replaces tf.sub, which was removed in TensorFlow 1.0.)
            self.FM = 0.5 * tf.subtract(self.summed_features_emb_square, self.squared_sum_features_emb)  # None * K
            if self.batch_norm:
                self.FM = self.batch_norm_layer(self.FM, train_phase=self.train_phase, scope_bn='bn_fm')
            self.FM = tf.nn.dropout(self.FM, self.dropout_keep)  # dropout at the FM layer

            # _________ out _________
            Bilinear = tf.reduce_sum(self.FM, 1, keep_dims=True)  # None * 1
            self.Feature_bias = tf.reduce_sum(
                tf.nn.embedding_lookup(self.weights['feature_bias'], self.train_features), 1)  # None * 1
            # Broadcast the scalar global bias to one entry per sample.
            Bias = self.weights['bias'] * tf.ones_like(self.train_labels)  # None * 1
            self.out = tf.add_n([Bilinear, self.Feature_bias, Bias])  # None * 1

            # Compute the loss.
            if self.loss_type == 'square_loss':
                self.loss = tf.nn.l2_loss(tf.subtract(self.train_labels, self.out))
            elif self.loss_type == 'log_loss':
                self.out = tf.sigmoid(self.out)
                # weights/scope are left at their defaults (1.0 / None); the kwarg was
                # renamed from `weight` to `weights` across TF versions.
                self.loss = tf.contrib.losses.log_loss(self.out, self.train_labels, epsilon=1e-07)
            else:
                raise ValueError("Unsupported loss_type: %r (expected 'square_loss' or 'log_loss')"
                                 % (self.loss_type,))
            if self.lamda_bilinear > 0:
                # L2 regularizer on the embeddings (bilinear part) only; biases are not penalized.
                self.loss = self.loss + tf.contrib.layers.l2_regularizer(self.lamda_bilinear)(
                    self.weights['feature_embeddings'])

            # Optimizer.
            if self.optimizer_type == 'AdamOptimizer':
                optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate, beta1=0.9, beta2=0.999,
                                                   epsilon=1e-8)
            elif self.optimizer_type == 'AdagradOptimizer':
                optimizer = tf.train.AdagradOptimizer(learning_rate=self.learning_rate,
                                                      initial_accumulator_value=1e-8)
            elif self.optimizer_type == 'GradientDescentOptimizer':
                optimizer = tf.train.GradientDescentOptimizer(learning_rate=self.learning_rate)
            elif self.optimizer_type == 'MomentumOptimizer':
                optimizer = tf.train.MomentumOptimizer(learning_rate=self.learning_rate, momentum=0.95)
            else:
                raise ValueError("Unsupported optimizer_type: %r" % (self.optimizer_type,))
            self.optimizer = optimizer.minimize(self.loss)

            # init
            self.saver = tf.train.Saver()
            init = tf.global_variables_initializer()
            self.sess = tf.Session()
            self.sess.run(init)

            # number of params
            total_parameters = 0
            for variable in self.weights.values():
                shape = variable.get_shape()  # shape is an array of tf.Dimension
                variable_parameters = 1
                for dim in shape:
                    variable_parameters *= dim.value
                total_parameters += variable_parameters
            if self.verbose > 0:
                print("#params: %d" % total_parameters)

    def _initialize_weights(self):
        '''
        Create the model variables.

        Returns:
            dict with 'feature_embeddings' (features_M * K), 'feature_bias'
            (features_M * 1) and 'bias' (scalar). When pretrain_flag > 0 they are
            initialized from the checkpoint at ``save_file``; otherwise the embeddings
            are drawn from N(0, 0.01) and both biases start at zero.
        '''
        all_weights = dict()
        if self.pretrain_flag > 0:
            # Import the saved graph, restore its variable values and copy them into
            # fresh variables of this model.
            weight_saver = tf.train.import_meta_graph(self.save_file + '.meta')
            pretrain_graph = tf.get_default_graph()
            feature_embeddings = pretrain_graph.get_tensor_by_name('feature_embeddings:0')
            feature_bias = pretrain_graph.get_tensor_by_name('feature_bias:0')
            bias = pretrain_graph.get_tensor_by_name('bias:0')
            with tf.Session() as sess:
                weight_saver.restore(sess, self.save_file)
                fe, fb, b = sess.run([feature_embeddings, feature_bias, bias])
            all_weights['feature_embeddings'] = tf.Variable(fe, dtype=tf.float32)
            all_weights['feature_bias'] = tf.Variable(fb, dtype=tf.float32)
            all_weights['bias'] = tf.Variable(b, dtype=tf.float32)
        else:
            all_weights['feature_embeddings'] = tf.Variable(
                tf.random_normal([self.features_M, self.hidden_factor], 0.0, 0.01),
                name='feature_embeddings')  # features_M * K
            all_weights['feature_bias'] = tf.Variable(
                tf.random_uniform([self.features_M, 1], 0.0, 0.0), name='feature_bias')  # features_M * 1
            all_weights['bias'] = tf.Variable(tf.constant(0.0), name='bias')  # 1 * 1
        return all_weights

    def batch_norm_layer(self, x, train_phase, scope_bn):
        '''
        Batch-normalize ``x``, switching between batch statistics (training) and moving
        averages (inference) according to the boolean tensor ``train_phase``.

        Both branches share variables through ``scope_bn`` (the second call uses
        reuse=True). With updates_collections=None the moving averages are updated in
        place during training instead of via a separate update op.
        '''
        # Note: the decay parameter is tunable
        bn_train = batch_norm(x, decay=0.9, center=True, scale=True, updates_collections=None,
                              is_training=True, reuse=None, trainable=True, scope=scope_bn)
        bn_inference = batch_norm(x, decay=0.9, center=True, scale=True, updates_collections=None,
                                  is_training=False, reuse=True, trainable=True, scope=scope_bn)
        z = tf.cond(train_phase, lambda: bn_train, lambda: bn_inference)
        return z

    def partial_fit(self, data):  # fit a batch
        '''
        Run one optimizer step on a batch and return the batch loss.

        Args:
            data: dict with 'X' (feature-id rows) and 'Y' (labels as one-element lists),
                as produced by get_random_block_from_data.
        '''
        feed_dict = {self.train_features: data['X'], self.train_labels: data['Y'],
                     self.dropout_keep: self.keep, self.train_phase: True}
        loss, _ = self.sess.run((self.loss, self.optimizer), feed_dict=feed_dict)
        return loss

    def get_random_block_from_data(self, data, batch_size):  # generate a random block of training data
        '''
        Draw a contiguous block of up to ``batch_size`` samples around a random start.

        Only rows with the same length as the start row are taken, so the block can be
        fed as a rectangular array. Samples are collected forward from the start index
        and, if that stops short, backward from the row just before it. Each row is
        taken at most once.

        Returns:
            dict with 'X' (list of feature-id rows) and 'Y' (list of one-element label lists).
        '''
        num_samples = len(data['Y'])
        # randint's upper bound is exclusive; keep it >= 1 so a data set of exactly
        # batch_size samples does not raise ValueError.
        start_index = np.random.randint(0, max(num_samples - batch_size, 1))
        row_length = len(data['X'][start_index])
        X, Y = [], []
        # forward get sample
        i = start_index
        while len(X) < batch_size and i < len(data['X']) and len(data['X'][i]) == row_length:
            Y.append([data['Y'][i]])
            X.append(data['X'][i])
            i += 1
        # backward get sample; start before start_index, which was already taken above
        i = start_index - 1
        while len(X) < batch_size and i >= 0 and len(data['X'][i]) == row_length:
            Y.append([data['Y'][i]])
            X.append(data['X'][i])
            i -= 1
        return {'X': X, 'Y': Y}

    def shuffle_in_unison_scary(self, a, b):  # shuffle two lists simultaneously
        '''Shuffle ``a`` and ``b`` in place with the same permutation (equal lengths assumed).'''
        # Replaying the same RNG state makes both shuffles apply an identical permutation.
        rng_state = np.random.get_state()
        np.random.shuffle(a)
        np.random.set_state(rng_state)
        np.random.shuffle(b)

    def train(self, Train_data, Validation_data, Test_data):  # fit a dataset
        '''
        Train for up to ``self.epoch`` epochs, recording train/validation/test metrics
        after every epoch and stopping early when eva_termination says so.

        ``Train_data`` is shuffled in place at the start of every epoch. If
        pretrain_flag < 0 the final model is saved to ``save_file``; its parent
        directory is created if missing.
        '''
        # Check Init performance
        if self.verbose > 0:
            t2 = time()
            init_train = self.evaluate(Train_data)
            init_valid = self.evaluate(Validation_data)
            init_test = self.evaluate(Test_data)
            print("Init: \t train=%.4f, validation=%.4f, test=%.4f [%.1f s]"
                  % (init_train, init_valid, init_test, time() - t2))

        for epoch in range(self.epoch):
            t1 = time()
            self.shuffle_in_unison_scary(Train_data['X'], Train_data['Y'])
            total_batch = len(Train_data['Y']) // self.batch_size
            for _ in range(total_batch):
                # generate a batch
                batch_xs = self.get_random_block_from_data(Train_data, self.batch_size)
                # Fit training
                self.partial_fit(batch_xs)
            t2 = time()

            # output validation
            train_result = self.evaluate(Train_data)
            valid_result = self.evaluate(Validation_data)
            test_result = self.evaluate(Test_data)

            self.train_rmse.append(train_result)
            self.valid_rmse.append(valid_result)
            self.test_rmse.append(test_result)
            if self.verbose > 0 and epoch % self.verbose == 0:
                print("Epoch %d [%.1f s]\ttrain=%.4f, validation=%.4f, test=%.4f [%.1f s]"
                      % (epoch + 1, t2 - t1, train_result, valid_result, test_result, time() - t2))
            if self.eva_termination(self.valid_rmse):
                break

        if self.pretrain_flag < 0:
            print("Save model to file as pretrain.")
            # tf.train.Saver.save refuses to write into a non-existent directory.
            save_dir = os.path.dirname(self.save_file)
            if save_dir and not os.path.isdir(save_dir):
                os.makedirs(save_dir)
            self.saver.save(self.sess, self.save_file)

    def eva_termination(self, valid):
        '''
        Early-stopping rule: return True once more than five epochs have been recorded
        and the validation metric has increased (got worse) four epochs in a row.

        Both metrics produced by evaluate (RMSE for 'square_loss', log loss for
        'log_loss') are lower-is-better, so one rule serves both loss types.
        '''
        if len(valid) > 5:
            return all(valid[-k] > valid[-k - 1] for k in range(1, 5))
        return False

    def evaluate(self, data):  # evaluate the results for an input set
        '''
        Score the model on a whole data set in one forward pass (dropout disabled,
        batch normalization in inference mode).

        Args:
            data: dict with 'X' (feature-id rows; fed as a single array, so presumably
                all of equal length -- TODO confirm per data set) and 'Y' (labels).

        Returns:
            For 'square_loss', the RMSE of the predictions clipped to the observed label
            range. For 'log_loss', the sklearn log loss of the sigmoid outputs.
            Both are lower-is-better.
        '''
        num_example = len(data['Y'])
        feed_dict = {self.train_features: data['X'], self.train_labels: [[y] for y in data['Y']],
                     self.dropout_keep: 1.0, self.train_phase: False}
        predictions = self.sess.run(self.out, feed_dict=feed_dict)
        y_pred = np.reshape(predictions, (num_example,))
        y_true = np.reshape(data['Y'], (num_example,))
        if self.loss_type == 'square_loss':
            # bound predictions to [min(y_true), max(y_true)] before scoring
            predictions_bounded = np.clip(y_pred, y_true.min(), y_true.max())
            return math.sqrt(mean_squared_error(y_true, predictions_bounded))
        elif self.loss_type == 'log_loss':
            # self.out already holds sigmoid probabilities for this loss type.
            return log_loss(y_true, y_pred)
        # NOTE: to report classification accuracy instead, threshold y_pred at 0.5 and
        # score with accuracy_score; eva_termination would then need its comparison
        # flipped, because it assumes a lower-is-better metric.
if __name__ == '__main__':
    # Data loading
    args = parse_args()
    data = DATA.LoadData(args.path, args.dataset, args.loss_type)
    if args.verbose > 0:
        print("FM: dataset=%s, factors=%d, loss_type=%s, #epoch=%d, batch=%d, lr=%.4f, lambda=%.1e, keep=%.2f, optimizer=%s, batch_norm=%d"
              % (args.dataset, args.hidden_factor, args.loss_type, args.epoch, args.batch_size, args.lr, args.lamda,
                 args.keep_prob, args.optimizer, args.batch_norm))

    save_file = '../pretrain/%s_%d/%s_%d' % (args.dataset, args.hidden_factor, args.dataset, args.hidden_factor)
    # Training
    t1 = time()
    model = FM(data.features_M, args.pretrain, save_file, args.hidden_factor, args.loss_type, args.epoch,
               args.batch_size, args.lr, args.lamda, args.keep_prob, args.optimizer, args.batch_norm, args.verbose)
    model.train(data.Train_data, data.Validation_data, data.Test_data)

    # Find the best validation result across iterations.
    # model.evaluate reports RMSE for square_loss and log loss for log_loss; both are
    # lower-is-better, so the best epoch has the minimum validation score either way.
    if model.valid_rmse:
        best_valid_score = min(model.valid_rmse)
        best_epoch = model.valid_rmse.index(best_valid_score)
        print ("Best Iter(validation)= %d\t train = %.4f, valid = %.4f, test = %.4f [%.1f s]"
               % (best_epoch + 1, model.train_rmse[best_epoch], model.valid_rmse[best_epoch],
                  model.test_rmse[best_epoch], time() - t1))
    else:
        # e.g. --epoch 0: no per-epoch results were recorded.
        print("No training epochs were run; no best iteration to report.")