NSL-gpt2.py
import numpy as np
import torch
import time
import math
torch.set_printoptions(8)
def gelu(x):
"""
    Task: Use the torch API to implement the tanh approximation of the `GELU`
    activation function. The formula is given below in LaTeX (you can paste it into the
    online renderer to view it).
    Website: https://www.latexlive.com/
    Formula: \frac{1}{2} x\left[1+\tanh \left(\sqrt{\frac{2}{\pi}}\left(x+0.044715 x^{3}\right)\right)\right]
Input: Tensor
Output: Tensor
"""
pass
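    # One possible implementation (a sketch of the tanh approximation given in the
    # docstring, not necessarily the intended solution):
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))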
def softmax(x):
"""
    Task: Use the torch API to implement the `softmax` function; look up the exact formula yourself
Input: Tensor
Output: Tensor
"""
pass
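    # One possible implementation (a sketch): a numerically stable softmax over the last dimension.
    exp_x = torch.exp(x - torch.max(x, dim=-1, keepdim=True).values)
    return exp_x / torch.sum(exp_x, dim=-1, keepdim=True)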
def layer_norm(x, g_b, eps:float = 1e-5):
"""
    Task: Use the torch API to implement the `layernorm` function; look up how layer normalization works yourself
Input:
x: Tensor
        g_b: dictionary loaded from the GPT-2 weights; 'g' (gamma) and 'b' (bias) are the keys
Output: Tensor
"""
g, b = torch.Tensor(g_b['g']), torch.Tensor(g_b['b'])
pass
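    # One possible implementation (a sketch): normalize over the last dimension, then
    # scale by g and shift by b. The biased variance (unbiased=False) is an assumption
    # made to match the original GPT-2 layer norm.
    mean = torch.mean(x, dim=-1, keepdim=True)
    var = torch.var(x, dim=-1, keepdim=True, unbiased=False)
    return g * (x - mean) / torch.sqrt(var + eps) + b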
def linear(x, w_b): # [m, in], [in, out], [out] -> [m, out]
"""
Task: implement linear layer
Input:
x: Tensor
        w_b: dictionary loaded from the GPT-2 weights; 'w' (weight) and 'b' (bias) are the keys
Output: Tensor
"""
w, b = w_b['w'], w_b['b']
pass
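    # One possible implementation (a sketch). The torch.Tensor(...) conversions assume the
    # checkpoint weights are loaded as numpy arrays, as the conversions elsewhere in this file suggest.
    return x @ torch.Tensor(w) + torch.Tensor(b)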
def ffn(x, mlp): # [n_seq, n_embd] -> [n_seq, n_embd]
"""
    Task: use `gelu` and `linear` to implement the feed-forward network (ffn)
Notes: x --linear--> --gelu--> --linear--> output
Input:
x: Tensor
        mlp: dictionary loaded from the GPT-2 weights; 'c_fc' and 'c_proj' hold the parameters of the two linear layers
Output: Tensor
"""
w_b1, w_b2 = mlp['c_fc'], mlp['c_proj']
pass
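    # One possible implementation (a sketch): project up, apply gelu, project back down,
    # following the linear -> gelu -> linear pattern in the Notes above.
    return linear(gelu(linear(x, w_b1)), w_b2)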
def attention(q, k, v, mask): # [n_q, d_k], [n_k, d_k], [n_k, d_v], [n_q, n_k] -> [n_q, d_v]
"""
    Task: use the torch API to implement the attention computation according to formula (1)
          of the paper below, where d_k is the size of the last dimension of `k`
Paper: https://arxiv.org/abs/1706.03762
Input:
q: Tensor
k: Tensor
v: Tensor
mask: Tensor
Output: Tensor
"""
pass
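    # One possible implementation (a sketch of formula (1) with the additive causal mask):
    #   softmax(q @ k^T / sqrt(d_k) + mask) @ v
    d_k = k.shape[-1]
    return softmax(q @ torch.transpose(k, -2, -1) / math.sqrt(d_k) + mask) @ v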
def mha(x, attn, n_head): # [n_seq, n_embd] -> [n_seq, n_embd]
"""
Task: Complete the code of the multi-head attention
Input:
x: Tensor
        attn: dictionary loaded from the GPT-2 weights; 'c_attn' and 'c_proj' hold the parameters of the two linear layers
n_head: number of head
    Output: Tensor after applying multi-head attention and the output projection, shape [n_seq, n_embd].
"""
c_attn, c_proj = attn['c_attn'], attn['c_proj']
# qkv projection
x = linear(x, c_attn) # [n_seq, n_embd] -> [n_seq, 3*n_embd]
# Split into qkv
"""
Task: Split the q,k,v matrix from the tensor x
Notes: [n_seq, 3*n_embd] -> 3 * [n_seq, n_embd]
"""
qkv = None # need to modify
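    # One possible split (a sketch): chunk the projected tensor into q, k, v along the last dim.
    qkv = list(torch.chunk(x, 3, dim=-1))  # [n_seq, 3*n_embd] -> 3 * [n_seq, n_embd]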
# Split into heads
qkv_heads = [qkv_part.chunk(n_head, dim=-1) for qkv_part in qkv] # 3 * [n_seq, n_embd] -> 3 * n_head * [n_seq, n_embd/n_head]
    qkv_heads = list(zip(*qkv_heads)) # [n_head, 3, n_seq, n_embd/n_head]: one (q, k, v) triple per head
# Causal mask to hide future inputs from being attended to
"""
Task: Construct mask matrix
Notes:
| 0 -inf -inf ... -inf |
| 0 0 -inf ... -inf |
| 0 0 0 ... -inf |
|... ... ... ... ... |
| 0 0 0 ... 0 |
Mask is a tensor whose dimension is [n_seq, n_seq]
"""
causal_mask = None # need to modify
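    # One possible construction (a sketch of the matrix sketched above): -inf above the diagonal, 0 elsewhere.
    n_seq = x.shape[0]
    causal_mask = torch.triu(torch.full((n_seq, n_seq), float('-inf')), diagonal=1)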
# Perform attention over each head
out_heads = [attention(q, k, v, causal_mask) for q, k, v in qkv_heads] # n_head * [n_seq, n_embd/n_head]
# Merge heads
"""
Task: merge multi-heads results
Notes: n_head * [n_seq, n_embd/n_head] --> [n_seq, n_embd]
"""
x = None # need to modify
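    # One possible merge (a sketch): concatenate the per-head outputs along the feature dimension.
    x = torch.cat(out_heads, dim=-1)  # n_head * [n_seq, n_embd/n_head] -> [n_seq, n_embd]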
# Out projection
x = linear(x, c_proj) # [n_seq, n_embd] -> [n_seq, n_embd]
return x
def transformer_block(x, block, n_head): # [n_seq, n_embd] -> [n_seq, n_embd]
mlp, attn, ln_1, ln_2 = block['mlp'], block['attn'], block['ln_1'], block['ln_2']
# multi-head causal self attention
x = x + mha(layer_norm(x, ln_1), attn, n_head=n_head) # [n_seq, n_embd] -> [n_seq, n_embd]
# position-wise feed forward network
x = x + ffn(layer_norm(x, ln_2), mlp) # [n_seq, n_embd] -> [n_seq, n_embd]
return x
def gpt2(inputs, params, n_head): # [n_seq] -> [n_seq, n_vocab]
wte, wpe, blocks, ln_f = params['wte'], params['wpe'], params['blocks'], params['ln_f']
# token + positional embeddings
x = wte[inputs] + wpe[range(len(inputs))] # [n_seq] -> [n_seq, n_embd]
x = torch.Tensor(x)
# forward pass through n_layer transformer blocks
for block in blocks:
x = transformer_block(x, block, n_head=n_head) # [n_seq, n_embd] -> [n_seq, n_embd]
# projection to vocab
x = layer_norm(x, ln_f) # [n_seq, n_embd] -> [n_seq, n_embd]
    return x @ torch.Tensor(wte).T # [n_seq, n_embd] -> [n_seq, n_vocab]; convert wte in case it was loaded as a numpy array
def generate(inputs, params, n_head, n_tokens_to_generate):
from tqdm import tqdm
for _ in tqdm(range(n_tokens_to_generate), "generating"): # auto-regressive decode loop
logits = gpt2(inputs, params, n_head=n_head) # model forward pass
next_id = np.argmax(logits[-1]) # greedy sampling
inputs.append(int(next_id)) # append prediction to input
return inputs[len(inputs) - n_tokens_to_generate :] # only return generated ids
def main(prompt: str, n_tokens_to_generate: int = 5, model_size: str = "124M", models_dir: str = "models"):
from utils import load_encoder_hparams_and_params
# load encoder, hparams, and params from the released open-ai gpt-2 files
encoder, hparams, params = load_encoder_hparams_and_params(model_size, models_dir)
# encode the input string using the BPE tokenizer
input_ids = encoder.encode(prompt)
# make sure we are not surpassing the max sequence length of our model
assert len(input_ids) + n_tokens_to_generate < hparams["n_ctx"]
# generate output ids
start = time.time()
output_ids = generate(input_ids, params, hparams["n_head"], n_tokens_to_generate)
end = time.time()
print(f"Time taken to generate {n_tokens_to_generate} tokens: {end - start:.2f}s")
# decode the ids back into a string
output_text = encoder.decode(output_ids)
return output_text
if __name__ == "__main__":
import fire
fire.Fire(main)
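# Example invocation (a sketch; it assumes `utils.load_encoder_hparams_and_params` can find or
# download the released 124M checkpoint under `models/`, which this file does not define):
#   python NSL-gpt2.py "Alan Turing theorized that computers would one day become" --n_tokens_to_generate 8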