Support sparsity, target-size and sort_by_length for hstu (#62)
Summary:
Copied over generate_sparse_seq_len so that ragged sequence lengths can be generated with a configurable sparsity.
Example output:
                                x_val    hstu_triton_ragged_attention-latency
-------------------------------------  --------------------------------------
(256, 4, 16384, 2048, 0.8, 20, False)                                 146.458
(256, 4, 16384, 2048, 0.8, 20, False)                                 148.616
(256, 4, 16384, 2048, 0.8, 20, False)                                 145.135
(256, 4, 16384, 2048, 0.8, 20, False)                                 148.98
(256, 4, 16384, 2048, 0.8, 20, False)                                 147.167
(256, 4, 16384, 2048, 0.8, 20, False)                                 146.155
(256, 4, 16384, 2048, 0.8, 20, False)                                 144.787
(256, 4, 16384, 2048, 0.8, 20, False)                                 144.055
(256, 4, 16384, 2048, 0.8, 20, False)                                 144.35
(256, 4, 16384, 2048, 0.8, 20, False)                                 146.67
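
Each x_val tuple above follows the updated get_x_val order in operator.py: (batch_size, num_heads, max_seq_len, num_buckets, sparsity, target_size, sort_by_length); the run shown uses the new defaults of sparsity=0.8, target_size=20, and sort_by_length=False.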

Pull Request resolved: #62

Reviewed By: bertmaher, xuzhao9

Differential Revision: D66276135

Pulled By: manman-ren

fbshipit-source-id: d664253915adadbbe9655302ae6c48988b7fccf9
manman-ren authored and facebook-github-bot committed Nov 21, 2024
1 parent f74fd56 commit 45d195c
Showing 2 changed files with 105 additions and 17 deletions.
tritonbench/operators/ragged_attention/hstu.py (70 additions, 10 deletions)
@@ -46,6 +46,9 @@ def __init__(
         num_heads,
         max_seq_len,
         num_buckets,
+        sparsity,
+        target_size,
+        sort_by_length,
         requires_grad,
         persistent_kernel: bool = False,
     ) -> None:
@@ -54,6 +57,9 @@ def __init__(
         self.num_heads = num_heads
         self.max_seq_len = max_seq_len
         self.num_buckets = num_buckets
+        self.sparsity = sparsity
+        self.target_size = target_size
+        self.sort_by_length = sort_by_length
         self.all_ts_weights = torch.nn.Parameter(
             torch.randn(
                 (self.num_buckets + 1,),
@@ -73,7 +79,11 @@ def __init__(
         self.persistent_kernel = persistent_kernel
 
     def forward(
-        self, qkv: torch.Tensor, seq_offsets: torch.Tensor, timestamps: torch.Tensor
+        self,
+        qkv: torch.Tensor,
+        seq_offsets: torch.Tensor,
+        timestamps: torch.Tensor,
+        num_targets: torch.Tensor,
     ) -> torch.Tensor:
         NUM_BUCKETS = self.num_buckets
         torch._check(timestamps.size(0) + 1 == seq_offsets.size(0))
@@ -99,7 +109,7 @@ def forward(
             "PW": self.all_pos_weights,
             "Bias": None,
             "seq2_offsets": None,
-            "num_targets": None,
+            "num_targets": num_targets,
             "Scale": None,
             "Out": out,
             "stride_qm": q.stride(0),
@@ -171,25 +181,75 @@ def forward(
                 kwargs["ATTN_BIAS_TYPE"],  # relative_bias_type
                 kwargs["MAX_ATTN_LEN"],  # max_attn_len
                 kwargs["CONTEXTUAL_SEQ_LEN"],  # contextual_seq_len
-                kwargs["sort_by_length_indices"],  # sort_by_length
+                self.sort_by_length,
             )
 
         return out
 
 
+def generate_sparse_seq_len(
+    size: int,
+    max_seq_len: int,
+    sparsity: float,
+    device: torch.device,
+) -> torch.Tensor:
+    if sparsity == 0.0:
+        return torch.zeros(size=(size,), device=device, dtype=torch.int)
+    elif sparsity == 1.0:
+        return torch.ones(size=(size,), device=device, dtype=torch.int) * max_seq_len
+    elif sparsity >= 0.5:
+        min_seq_len: int = int((2 * sparsity - 1.0) * max_seq_len)
+        return torch.randint(
+            low=min_seq_len,
+            high=max_seq_len,
+            size=(size,),
+            device=device,
+            dtype=torch.int,
+        )
+    else:
+        min_seq_len: int = 0
+        max_seq_len: int = int(2 * sparsity * max_seq_len)
+        return torch.randint(
+            low=min_seq_len,
+            high=max_seq_len,
+            size=(size,),
+            device=device,
+            dtype=torch.int,
+        )
+
+
 def get_test_inputs(
-    batch_size, num_heads, max_seq_len, requires_grad
-) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    batch_size,
+    num_heads,
+    max_seq_len,
+    sparsity,
+    target_size,
+    sort_by_length,
+    requires_grad,
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     timestamp_deltas: torch.Tensor = torch.randint(
         86400,
         size=(batch_size, max_seq_len + 1),
     ).cuda()
     timestamps = timestamp_deltas.cumsum(dim=1)
 
-    lengths = torch.randint(
-        max_seq_len + 1,
-        size=(batch_size,),
-    ).cuda()
+    lengths = generate_sparse_seq_len(
+        size=batch_size,
+        max_seq_len=max_seq_len,
+        sparsity=sparsity,
+        device=torch.device("cuda"),
+    )
+    # assume has_delta_q is False
+    num_targets = None
+    if target_size != 0:
+        num_targets = torch.randint(
+            1,
+            target_size + 1,
+            (batch_size,),
+            device=lengths.device,
+            dtype=lengths.dtype,
+        )
+        num_targets = torch.where(num_targets > lengths, lengths, num_targets)
     seq_offsets = torch.zeros(
         (batch_size + 1,),
         dtype=torch.int64,
@@ -208,4 +268,4 @@ def get_test_inputs(
         .requires_grad_(requires_grad)
         .cuda()
     )
-    return qkv, seq_offsets, timestamps
+    return qkv, seq_offsets, timestamps, num_targets
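
To make the new sparsity knob concrete: for sparsity >= 0.5, generate_sparse_seq_len draws lengths uniformly from [(2 * sparsity - 1) * max_seq_len, max_seq_len), and for sparsity < 0.5 from [0, 2 * sparsity * max_seq_len), so the mean length is roughly sparsity * max_seq_len. When target_size != 0, num_targets is drawn from [1, target_size] and clamped so it never exceeds the per-sequence length. The short standalone sketch below (an illustration, not part of the commit) reproduces the sampling bounds for the max_seq_len of 2 ** 14 = 16384 used in the example output above:

import torch

# Illustration only: mirror the sampling bounds used by generate_sparse_seq_len
# and report the empirical mean length for two sparsity settings.
max_seq_len = 16384
for sparsity in (0.3, 0.8):
    if sparsity >= 0.5:
        low, high = int((2 * sparsity - 1.0) * max_seq_len), max_seq_len
    else:
        low, high = 0, int(2 * sparsity * max_seq_len)
    lengths = torch.randint(low=low, high=high, size=(256,), dtype=torch.int)
    mean_len = lengths.float().mean().item()
    # Expected mean is roughly sparsity * max_seq_len.
    print(f"sparsity={sparsity}: lengths in [{low}, {high}), mean={mean_len:.0f}")
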
tritonbench/operators/ragged_attention/operator.py (35 additions, 7 deletions)
@@ -22,6 +22,9 @@ def parse_op_args(args: List[str]):
     parser.add_argument("--heads", type=int, default=4, help="Number of heads")
     parser.add_argument("--max-seq-len-log2", type=int, default=9)
     parser.add_argument("--num-buckets", type=int, default=2048)
+    parser.add_argument("--seq-sparsity", type=float, default=0.8)
+    parser.add_argument("--target-size", type=int, default=20)
+    parser.add_argument("--sort-by-length", type=bool, default=False)
     return parser.parse_args(args)
 
 
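A side note on the new --sort-by-length flag, about standard argparse behavior rather than anything this diff changes: type=bool applies Python's bool() to the raw command-line string, so any non-empty value, including the literal string "False", parses as True; only omitting the flag keeps the default of False. A tiny self-contained demonstration (hypothetical, not from the repo):

import argparse

# argparse converts the token with bool(); bool("False") is True, so the
# printed value below is True even though "False" was passed on the CLI.
parser = argparse.ArgumentParser()
parser.add_argument("--sort-by-length", type=bool, default=False)
print(parser.parse_args(["--sort-by-length", "False"]).sort_by_length)
# A boolean switch is more commonly declared with action="store_true".
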
@@ -37,42 +40,67 @@ def __init__(
         self.num_heads = args.heads
         self.max_seq_len = 2**args.max_seq_len_log2
         self.num_buckets = args.num_buckets
+        self.sparsity = args.seq_sparsity
+        self.target_size = args.target_size
+        self.sort_by_length = args.sort_by_length
         # set a default number of inputs
         self._num_inputs = 10 if self._num_inputs is None else self._num_inputs
         self.requires_grad = not (self.mode == Mode.FWD_NO_GRAD)
 
     @register_benchmark()
-    def hstu_triton_ragged_attention(self, qkv, seq_offsets, timestamps):
+    def hstu_triton_ragged_attention(self, qkv, seq_offsets, timestamps, num_targets):
         attn = RaggedHSTUAttn(
             self.batch_size,
             self.num_heads,
             self.max_seq_len,
             self.num_buckets,
+            self.sparsity,
+            self.target_size,
+            self.sort_by_length,
             self.requires_grad,
             persistent_kernel=False,
         )
-        return lambda: attn(qkv, seq_offsets, timestamps)
+        return lambda: attn(qkv, seq_offsets, timestamps, num_targets)
 
     # TODO: enable persistent kernels when the OSS backward is ready
     @register_benchmark(enabled=False)
-    def hstu_triton_ragged_attention_persistent(self, qkv, seq_offsets, timestamps):
+    def hstu_triton_ragged_attention_persistent(
+        self, qkv, seq_offsets, timestamps, num_targets
+    ):
         attn = RaggedHSTUAttn(
             self.batch_size,
             self.num_heads,
             self.max_seq_len,
             self.num_buckets,
+            self.sparsity,
+            self.target_size,
+            self.sort_by_length,
             self.requires_grad,
             persistent_kernel=True,
         )
-        return lambda: attn(qkv, seq_offsets, timestamps)
+        return lambda: attn(qkv, seq_offsets, timestamps, num_targets)
 
     def get_x_val(self, example_inputs):
-        return (self.batch_size, self.num_heads, self.max_seq_len, self.num_buckets)
+        return (
+            self.batch_size,
+            self.num_heads,
+            self.max_seq_len,
+            self.num_buckets,
+            self.sparsity,
+            self.target_size,
+            self.sort_by_length,
+        )
 
     def get_input_iter(self):
         for _input_id in range(self._num_inputs):
             inputs = get_test_inputs(
-                self.batch_size, self.num_heads, self.max_seq_len, self.requires_grad
+                self.batch_size,
+                self.num_heads,
+                self.max_seq_len,
+                self.sparsity,
+                self.target_size,
+                self.sort_by_length,
+                self.requires_grad,
             )
             yield inputs
 
@@ -94,7 +122,7 @@ def tflops(
         f1 = 0.0
         f2 = 0.0
         jagged = True
-        qkv, seq_offsets, timestamps = example_inputs
+        qkv, seq_offsets, timestamps, num_targets = example_inputs
         q = qkv[:, :, :128]
         v = qkv[:, :, 256:384]
         _, nheads, attn_dim = q.shape
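
Putting the two files together, here is a minimal sketch of how the new inputs flow through the benchmark outside the harness. It assumes a CUDA GPU with the hstu module's Triton kernels importable, and that RaggedHSTUAttn's first constructor argument is batch_size (inferred from the call in operator.py above); the small sizes are chosen only to keep the example light.

import torch

from tritonbench.operators.ragged_attention.hstu import (
    RaggedHSTUAttn,
    get_test_inputs,
)

# Argument order mirrors get_x_val and the constructor call in operator.py:
# (batch_size, num_heads, max_seq_len, num_buckets, sparsity, target_size, sort_by_length)
batch_size, num_heads, max_seq_len, num_buckets = 8, 4, 512, 2048
sparsity, target_size, sort_by_length = 0.8, 20, False

qkv, seq_offsets, timestamps, num_targets = get_test_inputs(
    batch_size,
    num_heads,
    max_seq_len,
    sparsity,
    target_size,
    sort_by_length,
    False,  # requires_grad
)
attn = RaggedHSTUAttn(
    batch_size,
    num_heads,
    max_seq_len,
    num_buckets,
    sparsity,
    target_size,
    sort_by_length,
    False,  # requires_grad
    persistent_kernel=False,
)
out = attn(qkv, seq_offsets, timestamps, num_targets)
print(out.shape)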
