forked from MorvanZhou/NLP-Tutorials
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcnn-lm.py
126 lines (110 loc) · 5.07 KB
/
cnn-lm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
# A modification of the model from [Convolutional Neural Networks for Sentence Classification](https://arxiv.org/pdf/1408.5882.pdf)
import tensorflow as tf
from tensorflow import keras
import numpy as np
import utils # this refers to utils.py in my [repo](https://github.com/MorvanZhou/NLP-Tutorials/)
import tensorflow_addons as tfa
class CNNTranslation(keras.Model):
    """Sequence-to-sequence model with a CNN encoder and an LSTM decoder.

    The encoder embeds the source tokens, applies three parallel ``Conv2D``
    layers (window heights 2, 3 and 4, each spanning the full embedding
    width), max-pools every feature map over its whole time axis,
    concatenates the pooled features and projects them with a dense layer.
    That projection is used as both entries of the initial LSTM state of the
    decoder.

    Args:
        enc_v_dim: size of the source (encoder) vocabulary.
        dec_v_dim: size of the target (decoder) vocabulary.
        emb_dim: output width of both embedding layers.
        units: width of the encoder projection and of the decoder LSTM cell.
        max_pred_len: number of decoding steps performed by ``inference``.
        start_token: token id used to start decoding in ``inference``.
        end_token: token id handed to the greedy sampler as the end marker.
    """

    def __init__(self, enc_v_dim, dec_v_dim, emb_dim, units, max_pred_len, start_token, end_token):
        super().__init__()
        self.units = units

        # encoder
        self.enc_embeddings = keras.layers.Embedding(
            input_dim=enc_v_dim, output_dim=emb_dim,  # [enc_n_vocab, emb_dim]
            embeddings_initializer=tf.initializers.RandomNormal(0., 0.1),
        )
        self.conv2ds = [
            keras.layers.Conv2D(16, (n, emb_dim), padding="valid", activation=keras.activations.relu)
            for n in range(2, 5)]
        # Global max-pooling instead of fixed (7, 1) / (6, 1) / (5, 1) windows:
        # identical result for 8-step inputs, but no longer tied to that length.
        self.max_pools = [keras.layers.GlobalMaxPool2D() for _ in self.conv2ds]
        self.encoder = keras.layers.Dense(units, activation=keras.activations.relu)

        # decoder
        self.dec_embeddings = keras.layers.Embedding(
            input_dim=dec_v_dim, output_dim=emb_dim,  # [dec_n_vocab, emb_dim]
            embeddings_initializer=tf.initializers.RandomNormal(0., 0.1),
        )
        self.decoder_cell = keras.layers.LSTMCell(units=units)
        # the same output projection is shared by the train and predict decoders
        decoder_dense = keras.layers.Dense(dec_v_dim)
        # train decoder
        self.decoder_train = tfa.seq2seq.BasicDecoder(
            cell=self.decoder_cell,
            sampler=tfa.seq2seq.sampler.TrainingSampler(),  # sampler for train
            output_layer=decoder_dense
        )
        # predict decoder
        self.decoder_eval = tfa.seq2seq.BasicDecoder(
            cell=self.decoder_cell,
            sampler=tfa.seq2seq.sampler.GreedyEmbeddingSampler(),  # sampler for predict
            output_layer=decoder_dense
        )

        self.cross_entropy = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
        self.opt = keras.optimizers.Adam(0.01)
        self.max_pred_len = max_pred_len
        self.start_token = start_token
        self.end_token = end_token

    def encode(self, x):
        """Encode source token ids into the decoder's initial state.

        Args:
            x: batch of source token ids (assumed shape ``[n, step]``).

        Returns:
            A two-element list ``[h, h]`` holding the same ``[n, units]``
            tensor twice; it is passed to the decoders as the initial state.
        """
        embedded = self.enc_embeddings(x)  # [n, step, emb]
        o = tf.expand_dims(embedded, axis=3)  # [n, step, emb, 1]
        # one feature map per window height k in (2, 3, 4): [n, step - k + 1, 1, 16]
        co = [conv2d(o) for conv2d in self.conv2ds]
        # max over the whole time axis of each feature map -> [n, 16] * 3
        co = [pool(c) for pool, c in zip(self.max_pools, co)]
        o = tf.concat(co, axis=1)  # [n, 16*3]
        h = self.encoder(o)  # [n, units]
        return [h, h]

    def inference(self, x):
        """Greedily decode ``max_pred_len`` target tokens for each source row.

        Decoding always runs for ``max_pred_len`` steps; it does not stop
        early once the sampler reports that every sequence is finished.

        NOTE: the embedding matrix is read from
        ``self.dec_embeddings.variables[0]``, so the decoder embedding layer
        must already own its weights (presumably after at least one
        ``train_logits``/``step`` call; verify before calling this on a
        freshly constructed model).

        Args:
            x: batch of source token ids with a static batch dimension.

        Returns:
            ``np.int32`` array of shape ``[n, max_pred_len]`` with the sampled
            token ids.
        """
        s = self.encode(x)
        done, i, s = self.decoder_eval.initialize(
            self.dec_embeddings.variables[0],
            start_tokens=tf.fill([x.shape[0], ], self.start_token),
            end_token=self.end_token,
            initial_state=s,
        )
        pred_id = np.zeros((x.shape[0], self.max_pred_len), dtype=np.int32)
        for t in range(self.max_pred_len):
            o, s, i, done = self.decoder_eval.step(
                time=t, inputs=i, state=s, training=False)
            pred_id[:, t] = o.sample_id
        return pred_id

    def train_logits(self, x, y, seq_len):
        """Run the training decoder and return its output logits.

        Args:
            x: batch of source token ids.
            y: batch of target token ids; the last position is dropped before
                the sequence is fed to the decoder.
            seq_len: forwarded to the training decoder as ``sequence_length``.

        Returns:
            The ``rnn_output`` tensor of the training decoder.
        """
        s = self.encode(x)
        dec_in = y[:, :-1]  # ignore <EOS>
        dec_emb_in = self.dec_embeddings(dec_in)
        o, _, _ = self.decoder_train(dec_emb_in, s, sequence_length=seq_len)
        logits = o.rnn_output
        return logits

    def step(self, x, y, seq_len):
        """Perform one optimisation step on a batch.

        Args:
            x: batch of source token ids.
            y: batch of target token ids; the first position is dropped to
                form the labels.
            seq_len: forwarded to ``train_logits``.

        Returns:
            The batch loss converted with ``.numpy()``.
        """
        with tf.GradientTape() as tape:
            logits = self.train_logits(x, y, seq_len)
            dec_out = y[:, 1:]  # ignore <GO>
            loss = self.cross_entropy(dec_out, logits)
        grads = tape.gradient(loss, self.trainable_variables)
        self.opt.apply_gradients(zip(grads, self.trainable_variables))
        return loss.numpy()
def train():
    """Train a ``CNNTranslation`` model on ``utils.DateData`` and log progress.

    Builds a 4000-sample date dataset, prints a few raw and indexed examples,
    then runs 1500 optimisation steps on batches of 32. Every 70th step the
    first sample of the current batch is decoded and printed next to its
    source and target strings.
    """
    # get and process data
    dataset = utils.DateData(4000)
    print("Chinese time order: yy/mm/dd ", dataset.date_cn[:3], "\nEnglish time order: dd/M/yyyy ", dataset.date_en[:3])
    print("vocabularies: ", dataset.vocab)
    print("x index sample: \n{}\n{}".format(dataset.idx2str(dataset.x[0]), dataset.x[0]),
          "\ny index sample: \n{}\n{}".format(dataset.idx2str(dataset.y[0]), dataset.y[0]))

    # source and target share one vocabulary
    model = CNNTranslation(
        enc_v_dim=dataset.num_word,
        dec_v_dim=dataset.num_word,
        emb_dim=16,
        units=32,
        max_pred_len=11,
        start_token=dataset.start_token,
        end_token=dataset.end_token,
    )

    # training
    for step_idx in range(1500):
        batch_x, batch_y, batch_dec_len = dataset.sample(32)
        loss = model.step(batch_x, batch_y, batch_dec_len)
        if step_idx % 70 != 0:
            continue

        # report on the first sample of the batch
        target = dataset.idx2str(batch_y[0, 1:-1])  # drop first and last position
        pred = model.inference(batch_x[0:1])
        res = dataset.idx2str(pred[0])
        src = dataset.idx2str(batch_x[0])
        print(
            "t: ", step_idx,
            "| loss: %.3f" % loss,
            "| input: ", src,
            "| target: ", target,
            "| inference: ", res,
        )
# Script entry point: start training only when this file is executed directly.
if __name__ == "__main__":
    train()