Support sparsity, target-size and sort_by_length for hstu #62

Closed
wants to merge 4 commits
Changes from 3 commits
80 changes: 70 additions & 10 deletions tritonbench/operators/ragged_attention/hstu.py
@@ -46,6 +46,9 @@ def __init__(
num_heads,
max_seq_len,
num_buckets,
sparsity,
target_size,
sort_by_length,
requires_grad,
persistent_kernel: bool = False,
) -> None:
@@ -54,6 +57,9 @@ def __init__(
self.num_heads = num_heads
self.max_seq_len = max_seq_len
self.num_buckets = num_buckets
self.sparsity = sparsity
self.target_size = target_size
self.sort_by_length = sort_by_length
self.all_ts_weights = torch.nn.Parameter(
torch.randn(
(self.num_buckets + 1,),
@@ -73,7 +79,11 @@ def __init__(
self.persistent_kernel = persistent_kernel

def forward(
self, qkv: torch.Tensor, seq_offsets: torch.Tensor, timestamps: torch.Tensor
self,
qkv: torch.Tensor,
seq_offsets: torch.Tensor,
timestamps: torch.Tensor,
num_targets: torch.Tensor,
) -> torch.Tensor:
NUM_BUCKETS = self.num_buckets
torch._check(timestamps.size(0) + 1 == seq_offsets.size(0))
@@ -99,7 +109,7 @@ def forward(
"PW": self.all_pos_weights,
"Bias": None,
"seq2_offsets": None,
"num_targets": None,
"num_targets": num_targets,
"Scale": None,
"Out": out,
"stride_qm": q.stride(0),
@@ -171,25 +181,75 @@ def forward(
kwargs["ATTN_BIAS_TYPE"], # relative_bias_type
kwargs["MAX_ATTN_LEN"], # max_attn_len
kwargs["CONTEXTUAL_SEQ_LEN"], # contextual_seq_len
kwargs["sort_by_length_indices"], # sort_by_length
self.sort_by_length,
)

return out


def generate_sparse_seq_len(
size: int,
max_seq_len: int,
sparsity: float,
device: torch.device,
) -> torch.Tensor:
if sparsity == 0.0:
return torch.zeros(size=(size,), device=device, dtype=torch.int)
elif sparsity == 1.0:
return torch.ones(size=(size,), device=device, dtype=torch.int) * max_seq_len
elif sparsity >= 0.5:
min_seq_len: int = int((2 * sparsity - 1.0) * max_seq_len)
return torch.randint(
low=min_seq_len,
high=max_seq_len,
size=(size,),
device=device,
dtype=torch.int,
)
else:
min_seq_len: int = 0
max_seq_len: int = int(2 * sparsity * max_seq_len)
return torch.randint(
low=min_seq_len,
high=max_seq_len,
size=(size,),
device=device,
dtype=torch.int,
)
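
A quick sanity sketch (not part of the PR diff) of what the new generate_sparse_seq_len produces: for sparsity s and max length L, lengths are drawn uniformly from [(2s-1)*L, L) when s >= 0.5 and from [0, 2s*L) when s < 0.5, so the mean length is roughly s*L in both branches. The import path below is assumed from the file location.

import torch
from tritonbench.operators.ragged_attention.hstu import generate_sparse_seq_len  # assumed import path

L = 512
for s in (0.0, 0.3, 0.8, 1.0):
    lengths = generate_sparse_seq_len(
        size=4096, max_seq_len=L, sparsity=s, device=torch.device("cpu")
    )
    # Expected means: ~0, ~153 (uniform on [0, 307)), ~409 (uniform on [307, 512)), 512
    print(s, lengths.float().mean().item())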


def get_test_inputs(
batch_size, num_heads, max_seq_len, requires_grad
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
batch_size,
num_heads,
max_seq_len,
sparsity,
target_size,
sort_by_length,
requires_grad,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
timestamp_deltas: torch.Tensor = torch.randint(
86400,
size=(batch_size, max_seq_len + 1),
).cuda()
timestamps = timestamp_deltas.cumsum(dim=1)

lengths = torch.randint(
max_seq_len + 1,
size=(batch_size,),
).cuda()
lengths = generate_sparse_seq_len(
size=batch_size,
max_seq_len=max_seq_len,
sparsity=sparsity,
device=torch.device("cuda"),
)
# assume has_delta_q is False
num_targets = None
if target_size != 0:
num_targets = torch.randint(
1,
target_size + 1,
(batch_size,),
device=lengths.device,
dtype=lengths.dtype,
)
num_targets = torch.where(num_targets > lengths, lengths, num_targets)
seq_offsets = torch.zeros(
(batch_size + 1,),
dtype=torch.int64,
@@ -208,4 +268,4 @@ def get_test_inputs(
.requires_grad_(requires_grad)
.cuda()
)
return qkv, seq_offsets, timestamps
return qkv, seq_offsets, timestamps, num_targets
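
For reference, a hedged usage sketch (not part of the diff) of the updated get_test_inputs: it now takes sparsity, target_size, and sort_by_length, and returns a 4-tuple whose last element is num_targets, which is None when target_size == 0 and otherwise clamped per sequence so it never exceeds the generated length. The import path is assumed, and a CUDA device is required since the helper allocates on GPU.

import torch
from tritonbench.operators.ragged_attention.hstu import get_test_inputs  # assumed import path

qkv, seq_offsets, timestamps, num_targets = get_test_inputs(
    batch_size=8,
    num_heads=4,
    max_seq_len=512,
    sparsity=0.8,
    target_size=20,
    sort_by_length=False,
    requires_grad=False,
)
# Assuming seq_offsets is the cumulative sum of the generated lengths (per the
# existing helper), each num_targets entry was clamped with torch.where and so
# never exceeds its sequence length.
assert num_targets is None or bool((num_targets <= seq_offsets.diff()).all())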
42 changes: 35 additions & 7 deletions tritonbench/operators/ragged_attention/operator.py
@@ -22,6 +22,9 @@ def parse_op_args(args: List[str]):
parser.add_argument("--heads", type=int, default=4, help="Number of heads")
parser.add_argument("--max-seq-len-log2", type=int, default=9)
parser.add_argument("--num-buckets", type=int, default=2048)
parser.add_argument("--sparsity", type=float, default=0.8)
parser.add_argument("--target-size", type=int, default=20)
parser.add_argument("--sort-by-length", type=bool, default=False)
return parser.parse_args(args)


@@ -37,42 +40,67 @@ def __init__(
self.num_heads = args.heads
self.max_seq_len = 2**args.max_seq_len_log2
self.num_buckets = args.num_buckets
self.sparsity = args.sparsity
self.target_size = args.target_size
self.sort_by_length = args.sort_by_length
# set a default number of inputs
self._num_inputs = 10 if self._num_inputs is None else self._num_inputs
self.requires_grad = not (self.mode == Mode.FWD_NO_GRAD)

@register_benchmark()
def hstu_triton_ragged_attention(self, qkv, seq_offsets, timestamps):
def hstu_triton_ragged_attention(self, qkv, seq_offsets, timestamps, num_targets):
attn = RaggedHSTUAttn(
self.batch_size,
self.num_heads,
self.max_seq_len,
self.num_buckets,
self.sparsity,
self.target_size,
self.sort_by_length,
self.requires_grad,
persistent_kernel=False,
)
return lambda: attn(qkv, seq_offsets, timestamps)
return lambda: attn(qkv, seq_offsets, timestamps, num_targets)

# TODO: enable persistent kernels when the OSS backward is ready
@register_benchmark(enabled=False)
def hstu_triton_ragged_attention_persistent(self, qkv, seq_offsets, timestamps):
def hstu_triton_ragged_attention_persistent(
self, qkv, seq_offsets, timestamps, num_targets
):
attn = RaggedHSTUAttn(
self.batch_size,
self.num_heads,
self.max_seq_len,
self.num_buckets,
self.sparsity,
self.target_size,
self.sort_by_length,
self.requires_grad,
persistent_kernel=True,
)
return lambda: attn(qkv, seq_offsets, timestamps)
return lambda: attn(qkv, seq_offsets, timestamps, num_targets)

def get_x_val(self, example_inputs):
return (self.batch_size, self.num_heads, self.max_seq_len, self.num_buckets)
return (
self.batch_size,
self.num_heads,
self.max_seq_len,
self.num_buckets,
self.sparsity,
self.target_size,
self.sort_by_length,
)

def get_input_iter(self):
for _input_id in range(self._num_inputs):
inputs = get_test_inputs(
self.batch_size, self.num_heads, self.max_seq_len, self.requires_grad
self.batch_size,
self.num_heads,
self.max_seq_len,
self.sparsity,
self.target_size,
self.sort_by_length,
self.requires_grad,
)
yield inputs

@@ -94,7 +122,7 @@ def tflops(
f1 = 0.0
f2 = 0.0
jagged = True
qkv, seq_offsets, timestamps = example_inputs
qkv, seq_offsets, timestamps, num_targets = example_inputs
q = qkv[:, :, :128]
v = qkv[:, :, 256:384]
_, nheads, attn_dim = q.shape
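
Finally, a small sketch (not part of the PR) of how the new operator flags parse, assuming the parse_op_args helper from tritonbench/operators/ragged_attention/operator.py. One caveat worth noting: argparse's type=bool converts any non-empty string to True, so passing --sort-by-length False on the command line still yields True; only omitting the flag keeps the default of False (action="store_true" is the usual alternative).

from tritonbench.operators.ragged_attention.operator import parse_op_args  # assumed import path

args = parse_op_args(["--sparsity", "0.5", "--target-size", "32"])
print(args.sparsity, args.target_size, args.sort_by_length)  # 0.5 32 False

# bool("False") is True, so this still enables sort_by_length:
args = parse_op_args(["--sort-by-length", "False"])
print(args.sort_by_length)  # True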