Commit

add sort-by-length
Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
manman-ren committed Nov 21, 2024
1 parent 42e8154 commit 6114f86
Showing 2 changed files with 10 additions and 4 deletions.
6 changes: 4 additions & 2 deletions tritonbench/operators/ragged_attention/hstu.py
@@ -48,6 +48,7 @@ def __init__(
         num_buckets,
         sparsity,
         target_size,
+        sort_by_length,
         requires_grad,
         persistent_kernel: bool = False,
     ) -> None:
@@ -58,6 +59,7 @@ def __init__(
         self.num_buckets = num_buckets
         self.sparsity = sparsity
         self.target_size = target_size
+        self.sort_by_length = sort_by_length
         self.all_ts_weights = torch.nn.Parameter(
             torch.randn(
                 (self.num_buckets + 1,),
@@ -175,7 +177,7 @@ def forward(
             kwargs["ATTN_BIAS_TYPE"],  # relative_bias_type
             kwargs["MAX_ATTN_LEN"],  # max_attn_len
             kwargs["CONTEXTUAL_SEQ_LEN"],  # contextual_seq_len
-            kwargs["sort_by_length_indices"],  # sort_by_length
+            self.sort_by_length,
         )

         return out
@@ -213,7 +215,7 @@ def generate_sparse_seq_len(


 def get_test_inputs(
-    batch_size, num_heads, max_seq_len, sparsity, target_size, requires_grad
+    batch_size, num_heads, max_seq_len, sparsity, target_size, sort_by_length, requires_grad
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     timestamp_deltas: torch.Tensor = torch.randint(
         86400,
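The change in forward above swaps the per-call kwargs["sort_by_length_indices"] lookup for the flag stored on the module at construction time. A minimal sketch of that pattern, assuming nothing beyond what the hunks show (the KernelWrapper class and run_kernel stand-in are illustrative and not part of this repository):

```python
import torch


def run_kernel(qkv: torch.Tensor, sort_by_length: bool) -> torch.Tensor:
    # Hypothetical stand-in for the Triton HSTU kernel launch; illustration only.
    return qkv


class KernelWrapper(torch.nn.Module):
    def __init__(self, sort_by_length: bool = False) -> None:
        super().__init__()
        # Captured once at construction, mirroring the __init__ hunks above.
        self.sort_by_length = sort_by_length

    def forward(self, qkv: torch.Tensor) -> torch.Tensor:
        # The real call site previously read kwargs["sort_by_length_indices"] here.
        return run_kernel(qkv, sort_by_length=self.sort_by_length)


out = KernelWrapper(sort_by_length=True)(torch.randn(2, 8))
```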
8 changes: 6 additions & 2 deletions tritonbench/operators/ragged_attention/operator.py
@@ -24,6 +24,7 @@ def parse_op_args(args: List[str]):
     parser.add_argument("--num-buckets", type=int, default=2048)
     parser.add_argument("--sparsity", type=float, default=0.8)
     parser.add_argument("--target-size", type=int, default=20)
+    parser.add_argument("--sort-by-length", type=bool, default=False)
     return parser.parse_args(args)


@@ -41,6 +42,7 @@ def __init__(
         self.num_buckets = args.num_buckets
         self.sparsity = args.sparsity
         self.target_size = args.target_size
+        self.sort_by_length = args.sort_by_length
         # set a default number of inputs
         self._num_inputs = 10 if self._num_inputs is None else self._num_inputs
         self.requires_grad = not (self.mode == Mode.FWD_NO_GRAD)
@@ -54,6 +56,7 @@ def hstu_triton_ragged_attention(self, qkv, seq_offsets, timestamps, num_targets
             self.num_buckets,
             self.sparsity,
             self.target_size,
+            self.sort_by_length,
             self.requires_grad,
             persistent_kernel=False,
         )
@@ -69,18 +72,19 @@ def hstu_triton_ragged_attention_persistent(self, qkv, seq_offsets, timestamps,
             self.num_buckets,
             self.sparsity,
             self.target_size,
+            self.sort_by_length,
             self.requires_grad,
             persistent_kernel=True,
         )
         return lambda: attn(qkv, seq_offsets, timestamps, num_targets)

     def get_x_val(self, example_inputs):
-        return (self.batch_size, self.num_heads, self.max_seq_len, self.num_buckets, self.sparsity, self.target_size)
+        return (self.batch_size, self.num_heads, self.max_seq_len, self.num_buckets, self.sparsity, self.target_size, self.sort_by_length)

     def get_input_iter(self):
         for _input_id in range(self._num_inputs):
             inputs = get_test_inputs(
-                self.batch_size, self.num_heads, self.max_seq_len, self.sparsity, self.target_size, self.requires_grad
+                self.batch_size, self.num_heads, self.max_seq_len, self.sparsity, self.target_size, self.sort_by_length, self.requires_grad
             )
             yield inputs

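One parsing detail of the new flag as committed: argparse's type=bool applies Python's bool() to the raw argument string, so any non-empty value, including the literal "False", parses as True; only the default (flag omitted) or an explicit empty value yields False. A standalone snippet that reproduces just this flag to show the behavior (it is not the repository's full parse_op_args):

```python
import argparse

# Reproduce only the --sort-by-length argument added above.
parser = argparse.ArgumentParser()
parser.add_argument("--sort-by-length", type=bool, default=False)

print(parser.parse_args([]).sort_by_length)                             # False (default)
print(parser.parse_args(["--sort-by-length", "1"]).sort_by_length)      # True
print(parser.parse_args(["--sort-by-length", "False"]).sort_by_length)  # True: bool("False") is True
print(parser.parse_args(["--sort-by-length="]).sort_by_length)          # False: empty string
```

If strict true/false command-line semantics are needed, action="store_true" or a small string-to-bool converter is the usual argparse idiom; the snippet above only documents how the flag as committed parses.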
