From 6114f8634dba7a2363b307bab1c7795d05220420 Mon Sep 17 00:00:00 2001
From: Manman Ren
Date: Wed, 20 Nov 2024 17:30:14 -0800
Subject: [PATCH] add sort-by-length

Summary:

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:
---
 tritonbench/operators/ragged_attention/hstu.py     | 6 ++++--
 tritonbench/operators/ragged_attention/operator.py | 8 ++++++--
 2 files changed, 10 insertions(+), 4 deletions(-)

diff --git a/tritonbench/operators/ragged_attention/hstu.py b/tritonbench/operators/ragged_attention/hstu.py
index 132233ee..4e92a28f 100644
--- a/tritonbench/operators/ragged_attention/hstu.py
+++ b/tritonbench/operators/ragged_attention/hstu.py
@@ -48,6 +48,7 @@ def __init__(
         num_buckets,
         sparsity,
         target_size,
+        sort_by_length,
         requires_grad,
         persistent_kernel: bool = False,
     ) -> None:
@@ -58,6 +59,7 @@ def __init__(
         self.num_buckets = num_buckets
         self.sparsity = sparsity
         self.target_size = target_size
+        self.sort_by_length = sort_by_length
         self.all_ts_weights = torch.nn.Parameter(
             torch.randn(
                 (self.num_buckets + 1,),
@@ -175,7 +177,7 @@ def forward(
             kwargs["ATTN_BIAS_TYPE"],  # relative_bias_type
             kwargs["MAX_ATTN_LEN"],  # max_attn_len
             kwargs["CONTEXTUAL_SEQ_LEN"],  # contextual_seq_len
-            kwargs["sort_by_length_indices"],  # sort_by_length
+            self.sort_by_length,
         )
         return out

@@ -213,7 +215,7 @@ def generate_sparse_seq_len(


 def get_test_inputs(
-    batch_size, num_heads, max_seq_len, sparsity, target_size, requires_grad
+    batch_size, num_heads, max_seq_len, sparsity, target_size, sort_by_length, requires_grad
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     timestamp_deltas: torch.Tensor = torch.randint(
         86400,
diff --git a/tritonbench/operators/ragged_attention/operator.py b/tritonbench/operators/ragged_attention/operator.py
index c2042d1a..c157a3af 100644
--- a/tritonbench/operators/ragged_attention/operator.py
+++ b/tritonbench/operators/ragged_attention/operator.py
@@ -24,6 +24,7 @@ def parse_op_args(args: List[str]):
     parser.add_argument("--num-buckets", type=int, default=2048)
     parser.add_argument("--sparsity", type=float, default=0.8)
     parser.add_argument("--target-size", type=int, default=20)
+    parser.add_argument("--sort-by-length", type=bool, default=False)
     return parser.parse_args(args)


@@ -41,6 +42,7 @@ def __init__(
         self.num_buckets = args.num_buckets
         self.sparsity = args.sparsity
         self.target_size = args.target_size
+        self.sort_by_length = args.sort_by_length
         # set a default number of inputs
         self._num_inputs = 10 if self._num_inputs is None else self._num_inputs
         self.requires_grad = not (self.mode == Mode.FWD_NO_GRAD)
@@ -54,6 +56,7 @@ def hstu_triton_ragged_attention(self, qkv, seq_offsets, timestamps, num_targets
             self.num_buckets,
             self.sparsity,
             self.target_size,
+            self.sort_by_length,
             self.requires_grad,
             persistent_kernel=False,
         )
@@ -69,18 +72,19 @@ def hstu_triton_ragged_attention_persistent(self, qkv, seq_offsets, timestamps,
             self.num_buckets,
             self.sparsity,
             self.target_size,
+            self.sort_by_length,
             self.requires_grad,
             persistent_kernel=True,
         )
         return lambda: attn(qkv, seq_offsets, timestamps, num_targets)

     def get_x_val(self, example_inputs):
-        return (self.batch_size, self.num_heads, self.max_seq_len, self.num_buckets, self.sparsity, self.target_size)
+        return (self.batch_size, self.num_heads, self.max_seq_len, self.num_buckets, self.sparsity, self.target_size, self.sort_by_length)

     def get_input_iter(self):
         for _input_id in range(self._num_inputs):
             inputs = get_test_inputs(
-                self.batch_size, self.num_heads, self.max_seq_len, self.sparsity, self.target_size, self.requires_grad
+                self.batch_size, self.num_heads, self.max_seq_len, self.sparsity, self.target_size, self.sort_by_length, self.requires_grad
             )
             yield inputs

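Note on the new flag (not part of the patch): argparse's type=bool applies Python's bool() to the
raw string, so any non-empty value, including "False", parses as True; only omitting the flag keeps
the False default. The following standalone Python sketch illustrates that behavior and the
store_true idiom commonly used for on/off switches; it is for illustration only and is not code
from the repository.

    import argparse

    # Flag as defined in the patch: bool() is applied to the raw string value.
    p = argparse.ArgumentParser()
    p.add_argument("--sort-by-length", type=bool, default=False)
    print(p.parse_args([]).sort_by_length)                              # False (default)
    print(p.parse_args(["--sort-by-length", "False"]).sort_by_length)   # True, since bool("False") is True

    # Common alternative when a plain on/off switch is intended.
    q = argparse.ArgumentParser()
    q.add_argument("--sort-by-length", action="store_true")
    print(q.parse_args([]).sort_by_length)                              # False
    print(q.parse_args(["--sort-by-length"]).sort_by_length)            # True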