assert_flash.py
import click
import torch
from ring_attention_pytorch import (
    default_attention,
    ring_flash_attn
)
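
# sanity check that ring_flash_attn matches plain default_attention
# on both forward outputs and input gradients (see the asserts below)
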
# variables
@click.command()
@click.option('--causal', is_flag = True)
@click.option('--seq-len', default = 62)
@click.option('--dim-head', default = 16)
@click.option('--heads', default = 2)
@click.option('--rand-key-pad-mask', is_flag = True)
@click.option('--softclamp-qk-sim', is_flag = True)
@click.option('--bucket_size', default = 4)
@click.option('--cuda-kernel', is_flag = True)
def test(
    causal: bool,
    seq_len: int,
    dim_head: int,
    heads: int,
    rand_key_pad_mask: bool,
    bucket_size: int,
    softclamp_qk_sim: bool,
    cuda_kernel: bool
):
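    # all tensors below are shaped (batch = 2, seq_len, heads, dim_head)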
    # base qkv

    q = torch.randn(2, seq_len, heads, dim_head)
    k = torch.randn(2, seq_len, heads, dim_head)
    v = torch.randn(2, seq_len, heads, dim_head)

    # key padding mask

    mask = None
    if rand_key_pad_mask:
        assert not causal
        mask = torch.randint(0, 2, (2, seq_len)).bool()
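    # the random key padding mask is only exercised for the non-causal path,
    # which is what the assert above enforces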
    # flash and regular qkv's

    fq = q.clone().requires_grad_()
    fk = k.clone().requires_grad_()
    fv = v.clone().requires_grad_()

    rq = q.clone().requires_grad_()
    rk = k.clone().requires_grad_()
    rv = v.clone().requires_grad_()

    if cuda_kernel:
        assert torch.cuda.is_available()

        fcq = q.clone().cuda().requires_grad_()
        fck = k.clone().cuda().requires_grad_()
        fcv = v.clone().cuda().requires_grad_()
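    # a third set of leaf tensors lives on the GPU so the fused CUDA kernel
    # is fed the same q / k / v values as the CPU paths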
    # forward

    o = default_attention(rq, rk, rv, causal = causal, mask = mask, softclamp_qk_sim = softclamp_qk_sim)
    fo = ring_flash_attn(fq, fk, fv, bucket_size = bucket_size, causal = causal, mask = mask, softclamp_qk_sim = softclamp_qk_sim)

    assert torch.allclose(o, fo, atol = 1e-6)

    if cuda_kernel:
        from ring_attention_pytorch.ring_flash_attention_cuda import ring_flash_attn_cuda

        if mask is not None:
            mask = mask.cuda()

        fco = ring_flash_attn_cuda(fcq, fck, fcv, mask, causal, softclamp_qk_sim = softclamp_qk_sim)
        fco.sum().backward()

        assert torch.allclose(o, fco.cpu(), atol = 1e-2)
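    # note: the CUDA kernel comparisons use a looser tolerance (1e-2) than the
    # pure PyTorch path (1e-6), presumably due to lower internal precision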
    # backwards

    o.sum().backward()
    fo.sum().backward()

    assert torch.allclose(rq.grad, fq.grad, atol = 1e-6)
    assert torch.allclose(rk.grad, fk.grad, atol = 1e-6)
    assert torch.allclose(rv.grad, fv.grad, atol = 1e-6)

    if cuda_kernel:
        assert torch.allclose(rv.grad, fcv.grad.cpu(), atol = 1e-2)
        assert torch.allclose(rq.grad, fcq.grad.cpu(), atol = 1e-2)
        assert torch.allclose(rk.grad, fck.grad.cpu(), atol = 1e-2)

    print('✅ outputs and gradients are same between regular attention and naive flash attention')

if __name__ == '__main__':
    test()
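
# example invocations (flag names taken from the click options above):
#   python assert_flash.py
#   python assert_flash.py --causal --seq-len 128
#   python assert_flash.py --cuda-kernel --softclamp-qk-sim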