import numpy as np
import timeit
import torch


# ==== Utilities to generate data ====

def zipf_sentence_lengths(alpha: float, batch_size: int) -> torch.Tensor:
    # generate fake corpus by unigram Zipf distribution
    # from wikitext-2 corpus, we get rank "." = 3, "!" = 386, "?" = 858
    sentence_lengths = np.empty(batch_size, dtype=int)
    for ibatch in range(batch_size):
        sentence_lengths[ibatch] = 1
        word = np.random.zipf(alpha)
        while word != 3 and word != 386 and word != 858:
            sentence_lengths[ibatch] += 1
            word = np.random.zipf(alpha)
    return torch.tensor(sentence_lengths)
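

# Usage sketch (illustrative; `_demo_zipf_lengths` is not part of the original file):
# sample lengths for a small batch. With alpha=1.2 the distribution is heavy-tailed,
# so most sampled lengths are short but occasional very long "sentences" appear.
def _demo_zipf_lengths():
    np.random.seed(0)  # seed numpy's global RNG so the sample is repeatable
    lengths = zipf_sentence_lengths(alpha=1.2, batch_size=8)
    print(lengths)  # a 1-D tensor of 8 integer sentence lengths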


def gen_batch(N, E_q, E_k, E_v, device, dtype=torch.float32, query_seq_len_1=False):
    # generate semi-realistic data using Zipf distribution for sentence lengths
    sentence_lengths = zipf_sentence_lengths(alpha=1.2, batch_size=N)

    # Note: the torch.jagged layout is a nested tensor layout that supports a single
    # ragged dimension and works with torch.compile. The resulting nested tensors
    # have shape (B, S*, D), where B = batch size, S* = ragged sequence length, and
    # D = embedding dimension; each batch item is an (S*, D) tensor.
    if query_seq_len_1:
        query = torch.nested.nested_tensor([
            torch.randn(1, E_q, device=device, dtype=dtype)
            for _ in sentence_lengths
        ], layout=torch.jagged)
    else:
        query = torch.nested.nested_tensor([
            torch.randn(l.item(), E_q, device=device, dtype=dtype)
            for l in sentence_lengths
        ], layout=torch.jagged)

    key = torch.nested.nested_tensor([
        torch.randn(s.item(), E_k, device=device, dtype=dtype)
        for s in sentence_lengths
    ], layout=torch.jagged)
    value = torch.nested.nested_tensor([
        torch.randn(s.item(), E_v, device=device, dtype=dtype)
        for s in sentence_lengths
    ], layout=torch.jagged)

    return query, key, value, sentence_lengths
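

# Usage sketch (illustrative; `_demo_gen_batch` is not part of the original file):
# build a small jagged Q/K/V batch on CPU and check that each batch item carries
# its own ragged sequence length.
def _demo_gen_batch():
    query, key, value, lengths = gen_batch(N=4, E_q=8, E_k=8, E_v=8, device="cpu")
    for q, l in zip(query.unbind(), lengths):
        assert q.shape == (l.item(), 8)  # each component is (S*, E_q)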


# FIXME: can remove this one
def jagged_to_padded(jt, padding_val):
    # TODO: do jagged -> padded directly when this is supported
    return torch.nested.to_padded_tensor(
        torch.nested.nested_tensor(list(jt.unbind())),
        padding_val)
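

# Usage sketch (illustrative; `_demo_jagged_to_padded` is not part of the original
# file): pad a ragged query out to a dense (B, max_S, E) tensor, filling the tails
# of shorter sequences with the padding value.
def _demo_jagged_to_padded():
    query, _, _, lengths = gen_batch(N=4, E_q=8, E_k=8, E_v=8, device="cpu")
    padded = jagged_to_padded(query, padding_val=0.0)
    assert padded.shape == (4, lengths.max().item(), 8)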


def benchmark(func, *args, **kwargs):
    # wall-clock timing for CUDA work: synchronize before starting the timer and
    # again before stopping it so queued GPU kernels are fully accounted for
    torch.cuda.synchronize()
    begin = timeit.default_timer()
    output = func(*args, **kwargs)
    torch.cuda.synchronize()
    end = timeit.default_timer()
    return output, (end - begin)
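

# Usage sketch (illustrative; `_demo_benchmark` is not part of the original file):
# time a dense matmul. benchmark() calls torch.cuda.synchronize(), so it requires a
# CUDA-capable machine; the first call also pays one-time warm-up costs, so time
# several calls when you want stable numbers.
def _demo_benchmark():
    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")
    out, seconds = benchmark(torch.matmul, a, b)
    print(f"matmul: {seconds * 1e3:.3f} ms, output shape {tuple(out.shape)}")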