
Update HSTU and use the OSS wrapper for non-persistent kernels #53

Closed
wants to merge 15 commits
5 changes: 5 additions & 0 deletions install.py
@@ -114,6 +114,7 @@ def install_xformers():
    parser.add_argument(
        "--fa3", action="store_true", help="Install optional flash_attention 3 kernels"
    )
    parser.add_argument("--hstu", action="store_true", help="Install HSTU.")
    parser.add_argument("--jax", action="store_true", help="Install jax nightly")
    parser.add_argument("--tk", action="store_true", help="Install ThunderKittens")
    parser.add_argument("--liger", action="store_true", help="Install Liger-kernel")
@@ -153,6 +154,10 @@ def install_xformers():
    if args.xformers or args.all:
        logger.info("[tritonbench] installing xformers...")
        install_xformers()
    if args.hstu or args.all:
        logger.info("[tritonbench] installing hstu...")
        from tools.hstu.install import install_hstu
        install_hstu()
    logger.info("[tritonbench] installation complete!")
    # run tests to check installation
    if args.test:
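For context, the new flag behaves like the other optional installs: running "python install.py --hstu" from the tritonbench checkout (or "--all", which now also covers HSTU) is what ends up invoking install_hstu() from tools/hstu/install.py below.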
13 changes: 13 additions & 0 deletions tools/hstu/hstu.patch
@@ -0,0 +1,13 @@
diff --git a/generative_recommenders/ops/triton/triton_ragged_hstu_attention.py b/generative_recommenders/ops/triton/triton_ragged_hstu_attention.py
index b4e318b..d6bc894 100644
--- a/generative_recommenders/ops/triton/triton_ragged_hstu_attention.py
+++ b/generative_recommenders/ops/triton/triton_ragged_hstu_attention.py
@@ -36,7 +36,7 @@ try:
         VersionedSpec,
     )
 except ImportError:
-    from hammer.oss.generative_recommenders.ops.triton.utils import (
+    from generative_recommenders.ops.triton.utils import (
         _switch_to_contiguous_if_needed,
         autotune_max_seq_len,
         NamedSpecType,
30 changes: 30 additions & 0 deletions tools/hstu/install.py
@@ -0,0 +1,30 @@
import os
import sys
import subprocess
from pathlib import Path

PATCH_DIR = str(Path(__file__).parent.parent.parent.joinpath("submodules", "generative-recommenders").absolute())
PATCH_FILE = os.path.join(os.path.dirname(os.path.abspath(__file__)), "hstu.patch")


def install_hstu():
    try:
        subprocess.check_output(
            [
                "patch",
                "-p1",
                "--forward",
                "-i",
                PATCH_FILE,
                "-r",
                "/tmp/rej",
            ],
            cwd=PATCH_DIR,
        )
    except subprocess.SubprocessError as e:
        output_str = str(e.output)
        if "previously applied" in output_str:
            return
        else:
            print(str(output_str))
            sys.exit(1)
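A minimal sketch of exercising the new helper directly, assuming the generative-recommenders submodule is already checked out under submodules/:

from tools.hstu.install import install_hstu

# Applies tools/hstu/hstu.patch inside the submodule; a repeated call returns
# early because `patch --forward` reports the hunks as previously applied.
install_hstu()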
96 changes: 50 additions & 46 deletions tritonbench/operators/ragged_attention/hstu.py
@@ -13,16 +13,15 @@
)
except ModuleNotFoundError:
# OSS Import
import importlib
with add_path(str(SUBMODULE_PATH.joinpath("generative-recommenders"))):
from generative_recommenders.ops.triton import triton_ragged_hstu_attention

with add_path(str(SUBMODULE_PATH)):
triton_ragged_hstu_attention = importlib.import_module(
"generative-recommenders.ops.triton.triton_ragged_hstu_attention"
)
_ragged_hstu_attn_fwd = triton_ragged_hstu_attention._ragged_hstu_attn_fwd
_ragged_hstu_attn_fwd_persistent = (
triton_ragged_hstu_attention._ragged_hstu_attn_fwd_persistent
)
_RaggedAttentionRelativeBiasFunction = (
triton_ragged_hstu_attention._RaggedAttentionRelativeBiasFunction
)

@torch.fx.wrap
def prev_power_of_2(x: int) -> int:
@@ -47,6 +46,7 @@ def __init__(
num_heads,
max_seq_len,
num_buckets,
requires_grad,
persistent_kernel: bool = False,
) -> None:
super().__init__()
@@ -58,13 +58,17 @@ def __init__(
torch.randn(
(self.num_buckets + 1,),
dtype=torch.bfloat16,
).cuda()
)
.requires_grad_(requires_grad)
.cuda()
)
self.all_pos_weights = torch.nn.Parameter(
torch.randn(
(2 * self.max_seq_len - 1,),
dtype=torch.bfloat16,
).cuda()
)
.requires_grad_(requires_grad)
.cuda()
)
self.persistent_kernel = persistent_kernel

@@ -141,57 +145,57 @@ def forward(
"HAS_SORT_BY_LENGTH_INDICES": False,
"sort_by_length_indices": None,
}
if not IS_FBCODE:
del kwargs["MAX_ATTN_LEN"]
del kwargs["HAS_CONTEXTUAL_SEQ_LEN"]
del kwargs["contextual_seq_len"]
del kwargs["HAS_SORT_BY_LENGTH_INDICES"]
del kwargs["sort_by_length_indices"]
kwargs["HAS_MAX_ATTN_LEN"] = False
kwargs["max_attn_len"] = 0

if self.persistent_kernel:
grid = (1216,)
_ragged_hstu_attn_fwd_persistent[grid](**kwargs)
else:
grid = lambda meta: ( # noqa E731
triton.cdiv(N, meta["BLOCK_M"]),
Z * H,
out = _RaggedAttentionRelativeBiasFunction.apply(
self.max_seq_len, # N
kwargs["alpha"],
q,
k,
v,
kwargs["seq_offsets"],
kwargs["INVALID_MASK_TYPE"],
timestamps,
self.all_ts_weights, # ts_weights
self.all_pos_weights, # pos_weights
kwargs["CAUSAL"], # causal,
kwargs["num_buckets"], # num_buckets
"sqrt", # time_bucket_fn
kwargs["time_bucket_incr"], # time_bucket_incr
kwargs["time_bucket_div"], # time_bucket_div
kwargs["time_delta"], # time_delta
kwargs["max_pos_ind"], # max_pos_ind
kwargs["num_targets"],
None, # attn_scale
kwargs["ATTN_BIAS_TYPE"], # relative_bias_type
kwargs["MAX_ATTN_LEN"], # max_attn_len
kwargs["contextual_seq_len"], # contextual_seq_len
kwargs["sort_by_length_indices"], # sort_by_length
)
_ragged_hstu_attn_fwd[grid](**kwargs)

return out


def get_test_inputs(
batch_size, num_heads, max_seq_len
batch_size, num_heads, max_seq_len, requires_grad
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
timestamp_deltas: torch.Tensor = (
torch.randint(
86400,
size=(batch_size, max_seq_len + 1),
)
.requires_grad_(False)
.cuda()
)
timestamp_deltas: torch.Tensor = torch.randint(
86400,
size=(batch_size, max_seq_len + 1),
).cuda()
timestamps = timestamp_deltas.cumsum(dim=1)

lengths = (
torch.randint(
max_seq_len + 1,
size=(batch_size,),
)
.requires_grad_(False)
.cuda()
)
seq_offsets = (
torch.zeros(
(batch_size + 1,),
dtype=torch.int64,
)
.requires_grad_(False)
.cuda()
)
lengths = torch.randint(
max_seq_len + 1,
size=(batch_size,),
).cuda()
seq_offsets = torch.zeros(
(batch_size + 1,),
dtype=torch.int64,
).cuda()
seq_offsets[1:] = torch.cumsum(
lengths,
dim=0,
@@ -203,7 +207,7 @@ def get_test_inputs(
(L, num_heads, 512),
dtype=torch.bfloat16,
)
.requires_grad_(False)
.requires_grad_(requires_grad)
.cuda()
)
return qkv, seq_offsets, timestamps
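For orientation, an illustrative call showing the shapes this helper produces (the concrete sizes here are made up):

qkv, seq_offsets, timestamps = get_test_inputs(
    batch_size=8, num_heads=4, max_seq_len=512, requires_grad=True
)
# qkv:         (total jagged length, num_heads, 512), bfloat16, grad per the flag
# seq_offsets: (batch_size + 1,), int64 prefix sum of the random per-sample lengths
# timestamps:  (batch_size, max_seq_len + 1), cumulative sum of random timestamp deltas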
66 changes: 62 additions & 4 deletions tritonbench/operators/ragged_attention/operator.py
@@ -1,8 +1,17 @@
import argparse

from typing import List, Optional
from typing import Any, Callable, List, Optional

from tritonbench.utils.triton_op import BenchmarkOperator, register_benchmark
import torch
from tritonbench.utils.input import input_filter

from tritonbench.utils.triton_op import (
    BenchmarkOperator,
    BenchmarkOperatorMetrics,
    Mode,
    register_benchmark,
    register_metric,
)

from .hstu import get_test_inputs, RaggedHSTUAttn

@@ -30,6 +39,7 @@ def __init__(
self.num_buckets = args.num_buckets
# set a default number of inputs
self._num_inputs = 10 if self._num_inputs is None else self._num_inputs
self.requires_grad = not (self.mode == Mode.FWD_NO_GRAD)

@register_benchmark()
def hstu_triton_ragged_attention(self, qkv, seq_offsets, timestamps):
@@ -38,17 +48,20 @@ def hstu_triton_ragged_attention(self, qkv, seq_offsets, timestamps):
self.num_heads,
self.max_seq_len,
self.num_buckets,
self.requires_grad,
persistent_kernel=False,
)
return lambda: attn(qkv, seq_offsets, timestamps)

@register_benchmark()
# TODO: enable persistent kernels when the OSS backward is ready
@register_benchmark(enabled=False)
def hstu_triton_ragged_attention_persistent(self, qkv, seq_offsets, timestamps):
attn = RaggedHSTUAttn(
self.batch_size,
self.num_heads,
self.max_seq_len,
self.num_buckets,
self.requires_grad,
persistent_kernel=True,
)
return lambda: attn(qkv, seq_offsets, timestamps)
@@ -58,5 +71,50 @@ def get_x_val(self, example_inputs):

def get_input_iter(self):
for _input_id in range(self._num_inputs):
inputs = get_test_inputs(self.batch_size, self.num_heads, self.max_seq_len)
inputs = get_test_inputs(
self.batch_size, self.num_heads, self.max_seq_len, self.requires_grad
)
yield inputs

    def get_bwd_fn(self, fwd_fn: Callable[..., Any]) -> Callable[..., Any]:
        o = fwd_fn()
        o_tensor = input_filter(
            lambda x: isinstance(x, torch.Tensor),
            o,
        )
        do = torch.rand_like(o_tensor)
        fn = lambda: o_tensor.backward(do, retain_graph=True)
        return fn

    @register_metric()
    def tflops(
        self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics
    ) -> float:
        ratio = 2.0  # triangular masking
        f1 = 0.0
        f2 = 0.0
        jagged = True
        qkv, seq_offsets, timestamps = example_inputs
        q = qkv[:, :, :128]
        v = qkv[:, :, 256:384]
        _, nheads, attn_dim = q.shape
        _, _, hidden_dim = v.shape
        max_seqlen = timestamps.size(1) - 1

        for i in range(self.batch_size):
            seq_len = (
                int((seq_offsets[i + 1] - seq_offsets[i]).item())
                if jagged
                else max_seqlen
            )
            # (QK^T), dQ = d(QK^T)K, dK^T = Q^Td(QK^T)
            f1 += 2 * self.num_heads * attn_dim * seq_len**2 // ratio
            # (QK^T)V, d(QK^T) = dOV^T, dV = (QK^T)^TdO,
            f2 += 2 * self.num_heads * hidden_dim * seq_len**2 // ratio
        if self.mode == Mode.FWD:
            tflops = f1 + f2  # computes (QK^T) and (QK^T)V
        elif self.mode == Mode.BWD:
            tflops = 3 * f1 + 2 * f2  # computes (QK^T), dQ, dK, dV, d(QK^T)
        elif self.mode == Mode.FWD_BWD:
            tflops = 4 * f1 + 3 * f2
        return tflops / metrics.latency * 1e-9
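To make the FLOP accounting above concrete, a small worked example with illustrative numbers (assuming attn_dim = hidden_dim = 128, matching the slices in the metric):

num_heads, attn_dim, hidden_dim = 4, 128, 128
seq_len, ratio = 1024, 2.0  # one jagged sequence; triangular masking halves the work

f1 = 2 * num_heads * attn_dim * seq_len**2 // ratio    # QK^T term
f2 = 2 * num_heads * hidden_dim * seq_len**2 // ratio  # (QK^T)V term

fwd_flops = f1 + f2              # Mode.FWD
fwd_bwd_flops = 4 * f1 + 3 * f2  # Mode.FWD_BWD
# Dividing by the measured latency and scaling, as in
# `tflops / metrics.latency * 1e-9`, gives the reported TFLOPS.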