-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathoriginal_seq2seq_abstractor.py
341 lines (277 loc) · 14 KB
/
original_seq2seq_abstractor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
"""
This module includes PyTorch implementations of Abstractor architectures from the paper
"Abstractors and relational cross-attention: An inductive bias for explicit relational reasoning in Transformers"
Awni Altabaa, Taylor Webb, Jonathan Cohen, John Lafferty. ICLR (2024)
This is used to run some ablations and comparisons
"""
import torch
from torch import nn
import sys; sys.path += ['..', '../..']
from transformer_blocks import EncoderBlock, DecoderBlock
from symbol_retrieval import SymbolicAttention, RelationalSymbolicAttention, PositionalSymbolRetriever
from positional_encoding import SinusoidalPositionalEncoding, LearnedPositionalEmbeddings
from original_abstractor_module import AbstractorModule
class Seq2SeqAbstractorArcha(nn.Module):
    """Seq2seq Abstractor, variant "a": embeddings -> Abstractor -> decoder blocks -> linear head.

    The source sequence is embedded, positionally encoded and passed through a
    single ``AbstractorModule``; the target sequence is embedded, positionally
    encoded and passed through ``n_layers_dec`` ``DecoderBlock`` modules, each
    called as ``block(target, abstractor_output)``.

    Parameters
    ----------
    input_spec : dict
        Describes the source input. ``input_spec['type']`` must be ``'token'``
        (then ``input_spec['vocab_size']`` is read) or ``'vector'`` (then
        ``input_spec['dim']`` is read).
    output_spec : dict
        Same layout as ``input_spec``, but for the target input.
    d_model : int
        Width of the embeddings; also passed to every ``DecoderBlock``.
    out_dim : int
        Output size of the final linear layer.
    n_layers_dec : int
        Number of ``DecoderBlock`` modules.
    abstractor_kwargs : dict
        Keyword arguments forwarded unchanged to ``AbstractorModule``.
    decoder_kwargs : dict
        Keyword arguments forwarded unchanged to each ``DecoderBlock``.
    in_block_size : int
        ``max_len`` of the source positional encoding.
    out_block_size : int
        ``max_len`` of the target positional encoding.

    Raises
    ------
    ValueError
        If either spec's ``'type'`` is neither ``'token'`` nor ``'vector'``.
    """

    def __init__(
        self,
        input_spec, output_spec, d_model, out_dim,
        n_layers_dec, abstractor_kwargs, decoder_kwargs,
        in_block_size, out_block_size,
    ):
        super().__init__()

        # keep the configuration around for later inspection
        self.input_spec, self.output_spec = input_spec, output_spec
        self.d_model, self.out_dim = d_model, out_dim
        self.n_layers_dec = n_layers_dec
        self.decoder_kwargs = decoder_kwargs
        self.in_block_size, self.out_block_size = in_block_size, out_block_size

        # TODO: make positional embedder configurable (learned or fixed sinusoidal, etc)

        # source embedder: lookup table for token ids, linear map for raw vectors
        source_kind = input_spec['type']
        if source_kind not in ('token', 'vector'):
            raise ValueError(f"input_spec['type'] must be 'token' or 'vector', not {input_spec['type']}")
        if source_kind == 'token':
            source_embedder = nn.Embedding(input_spec['vocab_size'], d_model)
        else:
            source_embedder = nn.Linear(input_spec['dim'], d_model)

        # target embedder: same rule, driven by output_spec
        target_kind = output_spec['type']
        if target_kind not in ('token', 'vector'):
            raise ValueError(f"output_spec['type'] must be 'token' or 'vector', not {output_spec['type']}")
        if target_kind == 'token':
            target_embedder = nn.Embedding(output_spec['vocab_size'], d_model)
        else:
            target_embedder = nn.Linear(output_spec['dim'], d_model)

        # Sub-modules are created in this exact order so that parameter
        # initialisation and state-dict layout stay stable.
        self.layers = nn.ModuleDict({
            'source_embedder': source_embedder,
            'target_embedder': target_embedder,
            'source_pos_embedder': SinusoidalPositionalEncoding(d_model, dropout=0., max_len=in_block_size),
            'target_pos_embedder': SinusoidalPositionalEncoding(d_model, dropout=0., max_len=out_block_size),
            'abstractor': AbstractorModule(**abstractor_kwargs),
            'decoder_blocks': nn.ModuleList(
                [DecoderBlock(d_model, **decoder_kwargs) for _ in range(n_layers_dec)]
            ),
            'final_out': nn.Linear(d_model, out_dim),
        })
        # NOTE: target embedder and output layer are intentionally NOT weight-tied here.

    def forward(self, x, y, targets=None):
        """Run the model on source ``x`` and target ``y``.

        Parameters
        ----------
        x, y
            Source and target inputs for the respective embedders.
        targets : optional
            If given, logits for every position are returned along with the
            cross-entropy loss against ``targets``; target entries equal to
            ``-1`` are ignored by the loss.

        Returns
        -------
        (logits, loss)
            Without ``targets``, ``loss`` is ``None`` and ``logits`` covers only
            the last index along dimension 1 of the decoder output.
        """
        net = self.layers

        source = net.source_embedder(x)
        target = net.target_embedder(y)
        source = net.source_pos_embedder(source)
        target = net.target_pos_embedder(target)

        abstract_states = net.abstractor(source)
        for block in net.decoder_blocks:
            target = block(target, abstract_states)

        if targets is None:
            # inference: project only the final position (list index keeps the dim)
            return net.final_out(target[:, [-1], :]), None

        # training/eval: flatten all leading dims for the loss
        logits = net.final_out(target)
        flat_logits = logits.view(-1, logits.size(-1))
        loss = torch.nn.functional.cross_entropy(flat_logits, targets.view(-1), ignore_index=-1)
        return logits, loss

    def get_num_params(self):
        """Return the total number of elements across all model parameters."""
        return sum(param.numel() for param in self.parameters())
class Seq2SeqAbstractorArchb(nn.Module):
    """Seq2seq Abstractor, variant "b": embeddings -> encoder -> Abstractor -> decoder -> linear head.

    The source sequence is embedded, positionally encoded, passed through
    ``n_layers_enc`` ``EncoderBlock`` modules and then through a single
    ``AbstractorModule``. The target sequence is embedded, positionally encoded
    and passed through ``n_layers_dec`` ``DecoderBlock`` modules, each called as
    ``block(target, abstractor_output)``.

    Parameters
    ----------
    input_spec : dict
        Describes the source input. ``input_spec['type']`` must be ``'token'``
        (then ``input_spec['vocab_size']`` is read) or ``'vector'`` (then
        ``input_spec['dim']`` is read).
    output_spec : dict
        Same layout as ``input_spec``, but for the target input.
    d_model : int
        Width of the embeddings; also passed to every encoder/decoder block.
    out_dim : int
        Output size of the final linear layer.
    n_layers_enc, n_layers_dec : int
        Number of ``EncoderBlock`` / ``DecoderBlock`` modules.
    encoder_kwargs, abstractor_kwargs, decoder_kwargs : dict
        Keyword arguments forwarded unchanged to ``EncoderBlock``,
        ``AbstractorModule`` and ``DecoderBlock`` respectively.
    in_block_size, out_block_size : int
        ``max_len`` of the source / target positional encodings.

    Raises
    ------
    ValueError
        If either spec's ``'type'`` is neither ``'token'`` nor ``'vector'``.
    """

    def __init__(
        self,
        input_spec, output_spec, d_model, out_dim,
        n_layers_enc, n_layers_dec, encoder_kwargs, abstractor_kwargs, decoder_kwargs,
        in_block_size, out_block_size,
    ):
        super().__init__()

        # keep the configuration around for later inspection
        self.input_spec, self.output_spec = input_spec, output_spec
        self.d_model, self.out_dim = d_model, out_dim
        self.n_layers_enc, self.n_layers_dec = n_layers_enc, n_layers_dec
        self.encoder_kwargs, self.decoder_kwargs = encoder_kwargs, decoder_kwargs
        self.in_block_size, self.out_block_size = in_block_size, out_block_size

        # TODO: make positional embedder configurable (learned or fixed sinusoidal, etc)

        # source embedder: lookup table for token ids, linear map for raw vectors
        source_kind = input_spec['type']
        if source_kind not in ('token', 'vector'):
            raise ValueError(f"input_spec['type'] must be 'token' or 'vector', not {input_spec['type']}")
        if source_kind == 'token':
            source_embedder = nn.Embedding(input_spec['vocab_size'], d_model)
        else:
            source_embedder = nn.Linear(input_spec['dim'], d_model)

        # target embedder: same rule, driven by output_spec
        target_kind = output_spec['type']
        if target_kind not in ('token', 'vector'):
            raise ValueError(f"output_spec['type'] must be 'token' or 'vector', not {output_spec['type']}")
        if target_kind == 'token':
            target_embedder = nn.Embedding(output_spec['vocab_size'], d_model)
        else:
            target_embedder = nn.Linear(output_spec['dim'], d_model)

        # Sub-modules are created in this exact order so that parameter
        # initialisation and state-dict layout stay stable.
        self.layers = nn.ModuleDict({
            'source_embedder': source_embedder,
            'target_embedder': target_embedder,
            'source_pos_embedder': SinusoidalPositionalEncoding(d_model, dropout=0., max_len=in_block_size),
            'target_pos_embedder': SinusoidalPositionalEncoding(d_model, dropout=0., max_len=out_block_size),
            'encoder_blocks': nn.ModuleList(
                [EncoderBlock(d_model, **encoder_kwargs) for _ in range(n_layers_enc)]
            ),
            'abstractor': AbstractorModule(**abstractor_kwargs),
            'decoder_blocks': nn.ModuleList(
                [DecoderBlock(d_model, **decoder_kwargs) for _ in range(n_layers_dec)]
            ),
            'final_out': nn.Linear(d_model, out_dim),
        })
        # NOTE: target embedder and output layer are intentionally NOT weight-tied here.

    def forward(self, x, y, targets=None):
        """Run the model on source ``x`` and target ``y``.

        Parameters
        ----------
        x, y
            Source and target inputs for the respective embedders.
        targets : optional
            If given, logits for every position are returned along with the
            cross-entropy loss against ``targets``; target entries equal to
            ``-1`` are ignored by the loss.

        Returns
        -------
        (logits, loss)
            Without ``targets``, ``loss`` is ``None`` and ``logits`` covers only
            the last index along dimension 1 of the decoder output.
        """
        net = self.layers

        source = net.source_embedder(x)
        target = net.target_embedder(y)
        source = net.source_pos_embedder(source)
        target = net.target_pos_embedder(target)

        for encoder in net.encoder_blocks:
            source = encoder(source)
        abstract_states = net.abstractor(source)

        for decoder in net.decoder_blocks:
            target = decoder(target, abstract_states)

        if targets is None:
            # inference: project only the final position (list index keeps the dim)
            return net.final_out(target[:, [-1], :]), None

        # training/eval: flatten all leading dims for the loss
        logits = net.final_out(target)
        flat_logits = logits.view(-1, logits.size(-1))
        loss = torch.nn.functional.cross_entropy(flat_logits, targets.view(-1), ignore_index=-1)
        return logits, loss

    def get_num_params(self):
        """Return the total number of elements across all model parameters."""
        return sum(param.numel() for param in self.parameters())

    def estimate_mfu(self, fwdbwd_per_iter, dt):
        """Placeholder for model-FLOPs-utilisation (MFU) estimation; always returns -1.

        MFU measures how much of the GPU's peak FLOPS is being used. The PaLM
        paper derives it for standard Transformers; no such derivation has been
        done for this encoder-decoder architecture yet, so both arguments are
        ignored and the sentinel ``-1`` is returned.
        """
        return -1
class Seq2SeqAbstractorArchd(nn.Module):
    """Seq2seq Abstractor, variant "d": the decoder receives both encoder and Abstractor outputs.

    The source sequence is embedded, positionally encoded and passed through
    ``n_layers_enc`` ``EncoderBlock`` modules, giving ``E``. ``E`` is then fed to
    a single ``AbstractorModule``, giving ``A``. The embedded, positionally
    encoded target is processed by a ``MultiAttentionDecoder`` built with
    ``n_contexts=2`` and called with the context list ``[E, A]``.

    Parameters
    ----------
    input_spec : dict
        Describes the source input. ``input_spec['type']`` must be ``'token'``
        (then ``input_spec['vocab_size']`` is read) or ``'vector'`` (then
        ``input_spec['dim']`` is read).
    output_spec : dict
        Same layout as ``input_spec``, but for the target input.
    d_model : int
        Width of the embeddings; also passed to the encoder blocks and decoder.
    out_dim : int
        Output size of the final linear layer.
    n_layers_enc : int
        Number of ``EncoderBlock`` modules.
    n_layers_dec : int
        ``n_layers`` of the ``MultiAttentionDecoder``.
    encoder_kwargs, abstractor_kwargs, decoder_kwargs : dict
        Keyword arguments forwarded unchanged to ``EncoderBlock``,
        ``AbstractorModule`` and ``MultiAttentionDecoder`` respectively.
    in_block_size, out_block_size : int
        ``max_len`` of the source / target positional encodings.

    Raises
    ------
    ValueError
        If either spec's ``'type'`` is neither ``'token'`` nor ``'vector'``.
    """

    def __init__(
        self,
        input_spec, output_spec, d_model, out_dim,
        n_layers_enc, n_layers_dec, encoder_kwargs, abstractor_kwargs, decoder_kwargs,
        in_block_size, out_block_size,
    ):
        super().__init__()

        # keep the configuration around for later inspection
        self.input_spec, self.output_spec = input_spec, output_spec
        self.d_model, self.out_dim = d_model, out_dim
        self.n_layers_enc, self.n_layers_dec = n_layers_enc, n_layers_dec
        self.encoder_kwargs, self.decoder_kwargs = encoder_kwargs, decoder_kwargs
        self.in_block_size, self.out_block_size = in_block_size, out_block_size

        # TODO: make positional embedder configurable (learned or fixed sinusoidal, etc)

        # source embedder: lookup table for token ids, linear map for raw vectors
        source_kind = input_spec['type']
        if source_kind not in ('token', 'vector'):
            raise ValueError(f"input_spec['type'] must be 'token' or 'vector', not {input_spec['type']}")
        if source_kind == 'token':
            source_embedder = nn.Embedding(input_spec['vocab_size'], d_model)
        else:
            source_embedder = nn.Linear(input_spec['dim'], d_model)

        # target embedder: same rule, driven by output_spec
        target_kind = output_spec['type']
        if target_kind not in ('token', 'vector'):
            raise ValueError(f"output_spec['type'] must be 'token' or 'vector', not {output_spec['type']}")
        if target_kind == 'token':
            target_embedder = nn.Embedding(output_spec['vocab_size'], d_model)
        else:
            target_embedder = nn.Linear(output_spec['dim'], d_model)

        # Sub-modules are created in this exact order so that parameter
        # initialisation and state-dict layout stay stable.
        self.layers = nn.ModuleDict({
            'source_embedder': source_embedder,
            'target_embedder': target_embedder,
            'source_pos_embedder': SinusoidalPositionalEncoding(d_model, dropout=0., max_len=in_block_size),
            'target_pos_embedder': SinusoidalPositionalEncoding(d_model, dropout=0., max_len=out_block_size),
            'encoder_blocks': nn.ModuleList(
                [EncoderBlock(d_model, **encoder_kwargs) for _ in range(n_layers_enc)]
            ),
            'abstractor': AbstractorModule(**abstractor_kwargs),
            'multi_attn_decoder': MultiAttentionDecoder(
                n_contexts=2, d_model=d_model, n_layers=n_layers_dec, **decoder_kwargs
            ),
            'final_out': nn.Linear(d_model, out_dim),
        })
        # NOTE: target embedder and output layer are intentionally NOT weight-tied here.

    def forward(self, x, y, targets=None):
        """Run the model on source ``x`` and target ``y``.

        Parameters
        ----------
        x, y
            Source and target inputs for the respective embedders.
        targets : optional
            If given, logits for every position are returned along with the
            cross-entropy loss against ``targets``; target entries equal to
            ``-1`` are ignored by the loss.

        Returns
        -------
        (logits, loss)
            Without ``targets``, ``loss`` is ``None`` and ``logits`` covers only
            the last index along dimension 1 of the decoder output.
        """
        net = self.layers

        source = net.source_embedder(x)
        target = net.target_embedder(y)
        source = net.source_pos_embedder(source)
        target = net.target_pos_embedder(target)

        for encoder in net.encoder_blocks:
            source = encoder(source)

        encoder_states = source                            # "E" context
        abstract_states = net.abstractor(encoder_states)   # "A" context
        # context order matters: encoder output first, abstractor output second
        target = net.multi_attn_decoder(target, [encoder_states, abstract_states])

        if targets is None:
            # inference: project only the final position (list index keeps the dim)
            return net.final_out(target[:, [-1], :]), None

        # training/eval: flatten all leading dims for the loss
        logits = net.final_out(target)
        flat_logits = logits.view(-1, logits.size(-1))
        loss = torch.nn.functional.cross_entropy(flat_logits, targets.view(-1), ignore_index=-1)
        return logits, loss

    def get_num_params(self):
        """Return the total number of elements across all model parameters."""
        return sum(param.numel() for param in self.parameters())
class MultiAttentionDecoder(nn.Module):
    """Decoder stack that processes its input against several context sequences.

    The module holds an ``n_layers`` x ``n_contexts`` grid of ``DecoderBlock``
    modules. Within each layer the running representation is passed through one
    block per context, in the order the contexts are supplied. Per the original
    description each block performs causal self-attention, cross-attention to
    its context and a feed-forward step (behaviour of ``DecoderBlock`` itself is
    defined elsewhere; confirm there).
    """

    def __init__(self, d_model, n_layers, n_contexts, **kwargs):
        """Create a MultiAttentionDecoder layer.

        Parameters
        ----------
        d_model : int
            model dimension, passed to every ``DecoderBlock`` as ``d_model``
        n_layers : int
            number of decoder layers (each layer owns one block per context)
        n_contexts : int
            number of context sequences, i.e. blocks per layer
        **kwargs
            forwarded unchanged to every ``DecoderBlock`` (presumably settings
            such as ``n_heads``, ``dff`` and ``dropout_rate``; confirm against
            ``DecoderBlock``'s signature)
        """
        super().__init__()
        self.n_layers = n_layers
        self.d_model = d_model
        self.n_contexts = n_contexts

        # Build the grid row by row: outer index = layer, inner index = context.
        block_grid = []
        for _ in range(self.n_layers):
            blocks_for_layer = [
                DecoderBlock(d_model=self.d_model, **kwargs)
                for _ in range(self.n_contexts)
            ]
            block_grid.append(nn.ModuleList(blocks_for_layer))
        self.decoder_blocks = nn.ModuleList(block_grid)

    def forward(self, x, contexts):
        """Apply every layer's blocks to ``x``, one block per entry of ``contexts``.

        Block ``[i][j]`` is called as ``block(x, contexts[j])``; the result
        becomes the input of the next block. Returns the final representation.
        """
        hidden = x
        for layer_idx in range(self.n_layers):
            blocks_for_layer = self.decoder_blocks[layer_idx]
            for ctx_idx, ctx in enumerate(contexts):
                hidden = blocks_for_layer[ctx_idx](hidden, ctx)
        return hidden
def configure_optimizers(model, weight_decay, learning_rate, betas, device_type):
    """Build an AdamW optimizer that applies weight decay only to >=2-D parameters.

    Trainable parameters are split into two groups: tensors with two or more
    dimensions (matmul weights, embeddings) receive ``weight_decay``; tensors
    with fewer dimensions (biases, layer-norm parameters) receive none. Summary
    lines for both groups and the fused setting are printed.

    Parameters
    ----------
    model : torch.nn.Module
        Model whose ``named_parameters()`` are optimized; parameters with
        ``requires_grad == False`` are skipped.
    weight_decay : float
        Weight-decay coefficient for the decayed group.
    learning_rate : float
        Learning rate passed to AdamW as ``lr``.
    betas : tuple[float, float]
        AdamW ``betas``.
    device_type : str
        The fused AdamW implementation is requested only when this equals
        ``'cuda'`` and the installed torch supports it.

    Returns
    -------
    torch.optim.AdamW
        Optimizer with two parameter groups: decayed first, non-decayed second.
    """
    import inspect  # local import keeps this function self-contained

    # only parameters that actually receive gradients are optimized
    trainable = [p for _, p in model.named_parameters() if p.requires_grad]

    # Any parameter that is at least 2D will be weight decayed, otherwise not:
    # all weight tensors in matmuls + embeddings decay, biases and layernorms don't.
    decay_params = [p for p in trainable if p.dim() >= 2]
    nodecay_params = [p for p in trainable if p.dim() < 2]
    optim_groups = [
        {'params': decay_params, 'weight_decay': weight_decay},
        {'params': nodecay_params, 'weight_decay': 0.0},
    ]

    num_decay_params = sum(p.numel() for p in decay_params)
    num_nodecay_params = sum(p.numel() for p in nodecay_params)
    print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
    print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")

    # Use the fused AdamW kernel only if this torch build exposes the `fused`
    # argument; passing it unconditionally raises TypeError on builds without it.
    fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
    use_fused = fused_available and device_type == 'cuda'
    # When not fusing, omit the argument entirely so torch picks its own default
    # implementation instead of being forced by an explicit fused=False.
    extra_args = {'fused': True} if use_fused else {}
    optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
    print(f"using fused AdamW: {use_fused}")
    return optimizer