Skip to content

Commit

Permalink
address review comments + lint
Browse files Browse the repository at this point in the history
  • Loading branch information
alexbaden committed Jan 8, 2025
1 parent 335aaf3 commit 1146674
Showing 1 changed file with 24 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ def matmul_kernel_with_block_pointers_batched(
stride_az: tl.constexpr, stride_am: tl.constexpr, stride_ak: tl.constexpr, #
stride_bz: tl.constexpr, stride_bk: tl.constexpr, stride_bn: tl.constexpr, #
stride_cz: tl.constexpr, stride_cm: tl.constexpr, stride_cn: tl.constexpr, #
stride_dz: tl.constexpr, stride_dm: tl.constexpr, stride_dn: tl.constexpr, ACCUMULATOR_DTYPE: tl.constexpr,
stride_dz: tl.constexpr, stride_dm: tl.constexpr, stride_dn: tl.constexpr, #
ACCUMULATOR_DTYPE: tl.constexpr,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr):
bid = tl.program_id(axis=0)
Expand Down Expand Up @@ -227,28 +228,28 @@ def matmul(a, b, d, c):
x_names=['B', 'M', 'K', 'N', 'dtype'],
# different possible values for `x_name`
x_vals=[[1, 1024 * i, 1024 * i, 1024 * i, dtype] for i in [1, 2, 4, 8] for dtype in dtypes()] + #
[[*shape, dtype]
for shape in [[1, 1, 5120, 13824], #
[1, 4, 4096, 12288], #
[1, 512, 8192, 8192], #
[1, 512, 8192, 32768], #
[1, 512, 32768, 8192], #
[1, 1024, 16384, 8192], #
[1, 1024, 28672, 8192], #
[1, 3072, 4096, 3072], # FIXME: Remove this case when gemm_streamk_benchmark works
[1, 4096, 16384, 8192], #
[1, 8192, 16384, 1024], #
[1, 8192, 16384, 4096], #
[1, 16384, 1024, 8192], #
[1, 16384, 4096, 8192], #
[1, 16384, 8192, 1024], #
[1, 16384, 8192, 4096], #
[4, 32768, 128, 4096], #
[4, 32768, 4096, 128], #
[32, 4096, 4096, 128], #
[4096, 8, 128, 16384], #
[4096, 8, 16384, 128]]
for dtype in dtypes()],
[[*shape, dtype] for shape in [ #
[1, 1, 5120, 13824], #
[1, 4, 4096, 12288], #
[1, 512, 8192, 8192], #
[1, 512, 8192, 32768], #
[1, 512, 32768, 8192], #
[1, 1024, 16384, 8192], #
[1, 1024, 28672, 8192], #
[1, 3072, 4096, 3072], # FIXME: Remove this case when gemm_streamk_benchmark works
[1, 4096, 16384, 8192], #
[1, 8192, 16384, 1024], #
[1, 8192, 16384, 4096], #
[1, 16384, 1024, 8192], #
[1, 16384, 4096, 8192], #
[1, 16384, 8192, 1024], #
[1, 16384, 8192, 4096], #
[4, 32768, 128, 4096], #
[4, 32768, 4096, 128], #
[32, 4096, 4096, 128], #
[4096, 8, 128, 16384], #
[4096, 8, 16384, 128] #
] for dtype in dtypes()],
line_arg='provider',
# argument name whose value corresponds to a different line in the plot
# possible values for `line_arg``
Expand Down

0 comments on commit 1146674

Please sign in to comment.