Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[triton-raise-block-ptr]: Increase test coverage #3198

Merged
merged 12 commits into from
Jan 21, 2025
Merged
59 changes: 59 additions & 0 deletions test/Triton/Intel/RaiseToBlockPointers/addptr_2d_example.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
// RUN: triton-opt %s -triton-raise-block-pointer -canonicalize | FileCheck %s

// Test: a 2-D tensor-of-pointers access pattern (make_range -> expand_dims ->
// broadcast -> addi/muli -> addptr) over three bf16 pointers is raised to
// tt.make_tensor_ptr block pointers. The offset tensor encodes
// offset = [%arg3, 0], strides = [1, 5] for a 4x256 block; the pass is
// expected to rewrite the three addptr/load/store chains into block-pointer
// loads/stores (see CHECK lines below).
module {
tt.func @kernel(
%arg0 : !tt.ptr<bf16>,
%arg1 : !tt.ptr<bf16>,
%arg2 : !tt.ptr<bf16>,
%arg3 : i32
)
{
%0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
// offset = 0, size = 4, stride = 1
%1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32>
// offset = [0,0], size = [4,1], stride = [1,0]
%2 = tt.broadcast %1 : tensor<4x1xi32> -> tensor<4x256xi32>
// offset = [0,0], size = [4,256], stride = [1,0]
%arg3splat = tt.splat %arg3 : i32 -> tensor<4x256xi32>
%offset3 = arith.addi %2, %arg3splat : tensor<4x256xi32>
// offset = [%arg3,0], size = [4,256], stride = [1,0]
%3 = tt.make_range {end = 256 : i32, start = 0 : i32}: tensor<256xi32>
// offset = 0, size = 256, stride = 1
%4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32>
// offset = [0,0], size = [1,256], stride = [0,1]
%5 = tt.broadcast %4 : tensor<1x256xi32> -> tensor<4x256xi32>
// offset = [0,0], size = [4,256], stride = [0,1]
%6 = arith.constant 5 : i32
%splat6 = tt.splat %6 : i32 -> tensor<4x256xi32>
// scaling the column index by 5 gives the second-dimension stride
%scale5 = arith.muli %5, %splat6 : tensor<4x256xi32>
// offset = [0,0], size = [4,256], stride = [0,5]
%7 = arith.addi %offset3, %scale5: tensor<4x256xi32>
// offset = [%arg3, 0], size = [4, 256], stride = [1, 5]
%8 = tt.splat %arg0 : !tt.ptr<bf16> -> tensor<4x256x!tt.ptr<bf16>>
%9 = tt.addptr %8, %7 : tensor<4x256x!tt.ptr<bf16>>, tensor<4x256xi32>
// source: %arg0, offset = [%arg3, 0], size = [4, 256], stride = [1, 5]
%10 = tt.load %9 : tensor<4x256x!tt.ptr<bf16>>
%11 = tt.splat %arg1 : !tt.ptr<bf16> -> tensor<4x256x!tt.ptr<bf16>>
%12 = tt.addptr %11, %7 : tensor<4x256x!tt.ptr<bf16>>, tensor<4x256xi32>
// source: %arg1, offset = [%arg3, 0], size = [4, 256], stride = [1, 5]
%13 = tt.load %12 : tensor<4x256x!tt.ptr<bf16>>
// elementwise add of the two loaded tiles, stored through the third pointer
%14 = arith.addf %10, %13 : tensor<4x256xbf16>
%15 = tt.splat %arg2 : !tt.ptr<bf16> -> tensor<4x256x!tt.ptr<bf16>>
%16 = tt.addptr %15, %7 : tensor<4x256x!tt.ptr<bf16>>, tensor<4x256xi32>
// source: %arg2, offset = [%arg3, 0], size = [4, 256], stride = [1, 5]
tt.store %16, %14 : tensor<4x256x!tt.ptr<bf16>>
tt.return
}
}

// CHECK: tt.func @kernel([[PARAM_0_:%.+]]: !tt.ptr<bf16>, [[PARAM_1_:%.+]]: !tt.ptr<bf16>, [[PARAM_2_:%.+]]: !tt.ptr<bf16>, [[PARAM_3_:%.+]]: i32) {
// CHECK-DAG: [[CST_5_:%.+]] = arith.constant 5 : i64
// CHECK: [[VAR_1_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{.*}} : <tensor<4x256xbf16>>
// CHECK-DAG: [[VAR_2_:%.+]] = tt.load [[VAR_1_]] : !tt.ptr<tensor<4x256xbf16>>
// CHECK: [[VAR_4_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{.*}} : <tensor<4x256xbf16>>
// CHECK: [[VAR_5_:%.+]] = tt.load [[VAR_4_]] : !tt.ptr<tensor<4x256xbf16>>
// CHECK-DAG: [[VAR_6_:%.+]] = arith.addf [[VAR_2_]], [[VAR_5_]] : tensor<4x256xbf16>
// CHECK: [[VAR_8_:%.+]] = tt.make_tensor_ptr [[PARAM_2_]], {{.*}} : <tensor<4x256xbf16>>
// CHECK: tt.store [[VAR_8_]], [[VAR_6_]] : !tt.ptr<tensor<4x256xbf16>>
// CHECK: tt.return
// CHECK: }
66 changes: 66 additions & 0 deletions test/Triton/Intel/RaiseToBlockPointers/addptr_add_value.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
// RUN: triton-opt %s -triton-raise-block-pointer -canonicalize | FileCheck %s

// Test: the row offset is the sum of two kernel arguments plus a constant
// (%arg2 + %arg3 + 10). The raise-block-pointer pass must fold this scalar
// arithmetic into the offsets operand of tt.make_tensor_ptr while keeping
// strides = [1, 6] for the 4x256 block (see CHECK lines below).
module {
tt.func @kernel(
%arg0 : !tt.ptr<bf16>,
%arg1 : !tt.ptr<bf16>,
%arg2 : i32,
%arg3 : i32
)
{
%0 = tt.make_range {end = 4 : i32, start = 0 : i32}:tensor<4xi32>
// offset = 0, size = 4, stride = 1
%1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32>
// offset = [0,0], size = [4,1], stride = [1,0]
%2 = tt.broadcast %1 : tensor<4x1xi32> -> tensor<4x256xi32>
// offset = [0,0], size = [4,256], stride = [1,0]
%arg2splat = tt.splat %arg2 : i32 -> tensor<4x256xi32>
%offset2 = arith.addi %2, %arg2splat : tensor<4x256xi32>
// offset = [%arg2,0], size = [4,256], stride = [1,0]
%arg3splat = tt.splat %arg3 : i32 -> tensor<4x256xi32>
%offset3 = arith.addi %offset2, %arg3splat : tensor<4x256xi32>
// offset = [%arg2+%arg3,0], size = [4,256], stride = [1,0]
%c10 = arith.constant 10 : i32
%c10splat = tt.splat %c10 : i32 -> tensor<4x256xi32>
%offset4 = arith.addi %offset3, %c10splat : tensor<4x256xi32>
// offset = [%arg2+%arg3+10,0], size = [4,256], stride = [1,0]
%3 = tt.make_range {end = 256 : i32, start = 0 : i32}:tensor<256xi32>
// offset = 0, size = 256, stride = 1
%4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32>
// offset = [0,0], size = [1,256], stride = [0,1]
%5 = tt.broadcast %4 : tensor<1x256xi32> -> tensor<4x256xi32>
// offset = [0,0], size = [4,256], stride = [0,1]
%c6 = arith.constant 6 : i32
%splat6 = tt.splat %c6 : i32 -> tensor<4x256xi32>
// scaling the column index by 6 gives the second-dimension stride
%scale5 = arith.muli %5, %splat6 : tensor<4x256xi32>
// offset = [0,0], size = [4,256], stride = [0,6]
%7 = arith.addi %offset4, %scale5: tensor<4x256xi32>
// offset = [%arg2+%arg3+10, 0], size = [4, 256], stride = [1, 6]
%8 = tt.splat %arg0 : !tt.ptr<bf16> -> tensor<4x256x!tt.ptr<bf16>>
%9 = tt.addptr %8, %7 : tensor<4x256x!tt.ptr<bf16>>,tensor<4x256xi32>
// source = %arg0, offset = [%arg2+%arg3+10, 0], size = [4, 256], stride = [1, 6]
%10 = tt.splat %arg1 : !tt.ptr<bf16> -> tensor<4x256x!tt.ptr<bf16>>
%11 = tt.addptr %10, %7 : tensor<4x256x!tt.ptr<bf16>>, tensor<4x256xi32>
// source = %arg1, offset = [%arg2+%arg3+10, 0], size = [4, 256], stride = [1, 6]
// copy one 4x256 tile from %arg0 to %arg1 through the two pointer tensors
%12 = tt.load %9 : tensor<4x256x!tt.ptr<bf16>>
tt.store %11, %12 : tensor<4x256x!tt.ptr<bf16>>
tt.return
}
}

// CHECK: tt.func @kernel([[PARAM_0_:%.+]]: !tt.ptr<bf16>, [[PARAM_1_:%.+]]: !tt.ptr<bf16>, [[PARAM_2_:%.+]]: i32, [[PARAM_3_:%.+]]: i32) {
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i64
// CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32
// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i64
// CHECK-DAG: [[CST_6_:%.+]] = arith.constant 6 : i64
// CHECK-DAG: [[CST_10_:%.+]] = arith.constant 10 : i32
// CHECK: [[VAR_2_:%.+]] = arith.addi [[PARAM_2_]], [[PARAM_3_]] : i32
// CHECK: [[VAR_3_:%.+]] = arith.addi [[VAR_2_]], [[CST_10_]] : i32
// CHECK-DAG: [[VAR_4_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_]], [[CST_0_]]], {{\[}}[[CST_1_]], [[CST_6_]]], {{\[}}[[VAR_3_]], [[CST_0_i32]]] {{.*}} : <tensor<4x256xbf16>>
// CHECK: [[VAR_7_:%.+]] = arith.addi [[PARAM_2_]], [[PARAM_3_]] : i32
// CHECK: [[VAR_8_:%.+]] = arith.addi [[VAR_7_]], [[CST_10_]] : i32
// CHECK-DAG: [[VAR_9_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_]], [[CST_0_]]], {{\[}}[[CST_1_]], [[CST_6_]]], {{\[}}[[VAR_8_]], [[CST_0_i32]]] {{.*}} : <tensor<4x256xbf16>>
// CHECK-DAG: [[VAR_10_:%.+]] = tt.load [[VAR_4_]] : !tt.ptr<tensor<4x256xbf16>>
// CHECK: tt.store [[VAR_9_]], [[VAR_10_]] : !tt.ptr<tensor<4x256xbf16>>
// CHECK: tt.return
// CHECK: }
76 changes: 76 additions & 0 deletions test/Triton/Intel/RaiseToBlockPointers/addptr_cmpge.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// RUN: triton-opt %s -triton-raise-block-pointer --split-input-file -canonicalize | FileCheck %s
// XFAIL: *

// These tests check that loads/stores that use a `cmp sge 0` comparison as a
// mask are handled correctly by the pointer analysis pass.

// Example of the triton kernel that generates the loads/stores with cmp ge 0.
//
// def kernel(in_ptr0, out_ptr0, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
// yoffset = tl.program_id(1) * YBLOCK
// xoffset = tl.program_id(0) * XBLOCK
// tmp0 = tl.load(tl.make_block_ptr(in_ptr0, shape=[16640, 10],
// strides=[1, 16640], block_shape=[XBLOCK, YBLOCK],
// order=[1, 0], offsets=[xoffset, yoffset]),
// boundary_check=[0, 1])
// tl.store(tl.make_block_ptr(out_ptr0, shape=[16640, 10],
// strides=[1, 16640], block_shape=[XBLOCK, YBLOCK],
// order=[1, 0], offsets=[xoffset, yoffset]),
// tl.broadcast_to(tmp0, [XBLOCK, YBLOCK]).to(tl.float16),
// boundary_check=[0, 1])

// Test (currently XFAIL, see header): a load whose mask is built from a
// `cmpi sge, ..., 0` comparison. The mask chain (%9..%15) is intentionally
// dead for now — the masked form of tt.load is kept in the TODO below until
// masked-load support lands; the pass is still expected to raise the
// unmasked load to a block pointer.
tt.func public @test_masked_load(%arg0: !tt.ptr<f16>) -> tensor<16x16xf16> {
%cst = arith.constant dense<0> : tensor<1x16xi64>
%c16_i32 = arith.constant 16 : i32
// xoffset/yoffset computation: program id scaled by the 16-wide block
%0 = tt.get_program_id y : i32
%1 = arith.muli %0, %c16_i32 : i32
%2 = arith.extsi %1 : i32 to i64
%3 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<16x16x!tt.ptr<f16>>
%4 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32>
%5 = arith.extsi %4 : tensor<16xi32> to tensor<16xi64>
%6 = tt.expand_dims %5 {axis = 1 : i32} : tensor<16xi64> -> tensor<16x1xi64>
%7 = tt.broadcast %6 : tensor<16x1xi64> -> tensor<16x16xi64>
%8 = tt.addptr %3, %7 : tensor<16x16x!tt.ptr<f16>>, tensor<16x16xi64>
// mask: (offset + index) >= 0, broadcast to the full 16x16 block
%9 = tt.splat %2 : i64 -> tensor<16xi64>
%10 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32>
%11 = arith.extsi %10 : tensor<16xi32> to tensor<16xi64>
%12 = arith.addi %9, %11 : tensor<16xi64>
%13 = tt.expand_dims %12 {axis = 0 : i32} : tensor<16xi64> -> tensor<1x16xi64>
%14 = arith.cmpi sge, %13, %cst : tensor<1x16xi64>
%15 = tt.broadcast %14 : tensor<1x16xi1> -> tensor<16x16xi1>
%16 = tt.load %8 evictionPolicy = evict_last : tensor<16x16x!tt.ptr<f16>>
// TODO: Replace above with below once support for masked loads is complete.
// %16 = tt.load %8, %15 evictionPolicy = evict_last : tensor<16x16x!tt.ptr<f16>>
tt.return %16 : tensor<16x16xf16>
}

// CHECK: tt.func public @test_masked_load([[arg0:%.+]]: !tt.ptr<f16>) -> tensor<16x16xf16> {
// CHECK: [[VAR_0:%.+]] = tt.make_tensor_ptr [[arg0]], {{.*}} {order = array<i32>} : <tensor<16x16xf16>>
// CHECK: [[VAR_1:%.+]] = tt.load [[VAR_0]] evictionPolicy = evict_last : !tt.ptr<tensor<16x16xf16>>
// CHECK: tt.return [[VAR_1]] : tensor<16x16xf16>
// CHECK: }

// -----

// Test (currently XFAIL, see header): a store whose mask is built from a
// `cmpi sge, ..., 0` comparison. As in test_masked_load, the mask (%6, %7)
// is intentionally dead until masked-store support is complete; the pass is
// expected to raise the unmasked store to a block pointer.
tt.func public @test_masked_store(%arg0: !tt.ptr<f16>) {
%cst = arith.constant dense<0> : tensor<16x1xi64>
%cst_0 = arith.constant dense<1.500000e+01> : tensor<16x16xf16>
%0 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<16x16x!tt.ptr<f16>>
%1 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32>
%2 = arith.extsi %1 : tensor<16xi32> to tensor<16xi64>
%3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<16xi64> -> tensor<16x1xi64>
%4 = tt.broadcast %3 : tensor<16x1xi64> -> tensor<16x16xi64>
%5 = tt.addptr %0, %4 : tensor<16x16x!tt.ptr<f16>>, tensor<16x16xi64>
// mask: row index >= 0, broadcast to the full 16x16 block (currently unused)
%6 = arith.cmpi sge, %3, %cst : tensor<16x1xi64>
%7 = tt.broadcast %6 : tensor<16x1xi1> -> tensor<16x16xi1>
// TODO: Replace above with below once support for masked stores is complete.
// tt.store %5, %cst_0, %7 : tensor<16x16x!tt.ptr<f16>>
tt.store %5, %cst_0 : tensor<16x16x!tt.ptr<f16>>
tt.return
}

// CHECK: tt.func public @test_masked_store([[arg0:%.+]]: !tt.ptr<f16>) {
// CHECK-DAG: [[VAR_cst:%.+]] = arith.constant dense<1.500000e+01> : tensor<16x16xf16>
// CHECK-DAG: [[VAR_0:%.+]] = tt.make_tensor_ptr [[arg0]], {{.*}} {order = array<i32>} : <tensor<16x16xf16>>
// CHECK: tt.store [[VAR_0]], [[VAR_cst]] : !tt.ptr<tensor<16x16xf16>>
// CHECK: }
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// RUN: triton-opt %s -triton-raise-block-pointer -canonicalize | FileCheck %s

// Test: a loop-carried tensor-of-pointers. The scf.for iterates a running
// sum and a pointer that is bumped by a splatted scalar each trip; the pass
// must turn the iter_arg pointer tensor into a block pointer and the in-loop
// tt.addptr into tt.advance (see CHECK lines below).
module {
tt.func @kernel(
%arg0 : !tt.ptr<bf16>,
%arg1 : !tt.ptr<bf16>,
%arg2 : !tt.ptr<bf16>,
%arg3 : i32,
%arg4 : i32
)
{
%0 = tt.make_range {end = 4 : i32, start = 0 : i32}:tensor<4xi32>
// offset = 0, size = 4, stride = 1
%1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32>
// offset = [0,0], size = [4,1], stride = [1,0]
%2 = tt.broadcast %1 : tensor<4x1xi32> -> tensor<4x256xi32>
// offset = [0,0], size = [4,256], stride = [1,0]
%arg3splat = tt.splat %arg3 : i32 -> tensor<4x256xi32>
%offset3 = arith.addi %2, %arg3splat : tensor<4x256xi32>
// offset = [%arg3,0], size = [4,256], stride = [1,0]
%3 = tt.make_range {end = 256 : i32, start = 0 : i32}:tensor<256xi32>
// offset = 0, size = 256, stride = 1
%4 = tt.expand_dims %3 {axis = 0 : i32} : tensor<256xi32> -> tensor<1x256xi32>
// offset = [0,0], size = [1,256], stride = [0,1]
%5 = tt.broadcast %4 : tensor<1x256xi32> -> tensor<4x256xi32>
// offset = [0,0], size = [4,256], stride = [0,1]
%c5 = arith.constant 5 : i32
%splat6 = tt.splat %c5 : i32 -> tensor<4x256xi32>
// scalar = 5
%scale5 = arith.muli %5, %splat6 : tensor<4x256xi32> // NOTE(review): why is the conversion function never called for these inputs?
// offset = [0,0], size = [4,256], stride = [0,5]
%7 = arith.addi %offset3, %scale5: tensor<4x256xi32> // NOTE(review): why is the conversion function never called for these inputs?
// offset = [%arg3, 0], size = [4, 256], stride = [1, 5]
%8 = tt.splat %arg0 : !tt.ptr<bf16> -> tensor<4x256x!tt.ptr<bf16>> // NOTE(review): why is the input unknown here?
%9 = tt.addptr %8, %7 : tensor<4x256x!tt.ptr<bf16>>, tensor<4x256xi32>
// source: %arg0, offset = [%arg3, 0], size = [4, 256], stride = [1, 5]
%19 = tt.load %9 {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<4x256x!tt.ptr<bf16>> // this will be replaced with a memref.copy
%11 = tt.splat %arg1 : !tt.ptr<bf16> -> tensor<4x256x!tt.ptr<bf16>>
%12 = tt.addptr %11, %7 : tensor<4x256x!tt.ptr<bf16>>, tensor<4x256xi32>
// source: %arg1, offset = [%arg3, 0], size = [4, 256], stride = [1, 5]
%c0 = arith.constant 0 : index
%c12 = arith.constant 12 : index
%c3 = arith.constant 3 : index
%i_c3 = arith.constant 3 : i32
// 4 iterations (0, 3, 6, 9): accumulate loads while advancing the pointer
%sum_out, %_ptr = scf.for %i = %c0 to %c12 step %c3 iter_args(%sum_iter = %19, %ptr_iter = %12) -> (tensor<4x256xbf16>, tensor<4x256x!tt.ptr<bf16>>) {
%20 = tt.load %ptr_iter {cache = 1 : i32, evict = 1 : i32, isVolatile = false}: tensor<4x256x!tt.ptr<bf16>>
%sum = arith.addf %sum_iter, %20 : tensor<4x256xbf16>
// pointer updates
%17 = tt.splat %i_c3 : i32 -> tensor<4x256xi32>
// offset: [3, 0], size = [4, 256], stride [0, 0]
// NOTE(review): the raised IR advances by [0, 3] (see the tt.advance CHECK
// line) — confirm whether this annotation should read offset [0, 3].
%ptr = tt.addptr %ptr_iter, %17 : tensor<4x256x!tt.ptr<bf16>>, tensor<4x256xi32>
// source: %arg1, offset = [%arg3+%i, 0], size = [4, 256], stride = [1, 5]
scf.yield %sum, %ptr : tensor<4x256xbf16>, tensor<4x256x!tt.ptr<bf16>>
}
%15 = tt.splat %arg2 : !tt.ptr<bf16> -> tensor<4x256x!tt.ptr<bf16>>
%16 = tt.addptr %15, %7 : tensor<4x256x!tt.ptr<bf16>>, tensor<4x256xi32>
// source: %arg2, offset = [%arg3, 0], size = [4, 256], stride = [1, 5]
tt.store %16, %sum_out : tensor<4x256x!tt.ptr<bf16>>
tt.return
}
}

// CHECK: tt.func @kernel([[PARAM_0_:%.+]]: !tt.ptr<bf16>, [[PARAM_1_:%.+]]: !tt.ptr<bf16>, [[PARAM_2_:%.+]]: !tt.ptr<bf16>, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32) {
// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : i32
// CHECK-DAG: [[CST_5_:%.+]] = arith.constant 5 : i64
// CHECK-DAG: [[CST_1_:%.+]] = arith.constant 1 : i64
// CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : i64
// CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32
// CHECK: [[VAR_1_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_]], [[CST_0_]]], {{\[}}[[CST_1_]], [[CST_5_]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {{.*}} : <tensor<4x256xbf16>>
// CHECK: [[VAR_2_:%.+]] = tt.load [[VAR_1_]] : !tt.ptr<tensor<4x256xbf16>>
// CHECK: [[VAR_3_:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_]], [[CST_0_]]], {{\[}}[[CST_1_]], [[CST_5_]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {{.*}} : <tensor<4x256xbf16>>
// CHECK: [[VAR_4_:%.+]]:2 = scf.for [[VAR_arg5_:%.+]] = {{.*}} iter_args([[VAR_arg6_:%.+]] = [[VAR_2_]], [[VAR_arg7_:%.+]] = [[VAR_3_]]) -> (tensor<4x256xbf16>, !tt.ptr<tensor<4x256xbf16>>) {
// CHECK: [[VAR_8_:%.+]] = tt.load [[VAR_arg7_]] : !tt.ptr<tensor<4x256xbf16>>
// CHECK-DAG: [[VAR_9_:%.+]] = arith.addf [[VAR_arg6_]], [[VAR_8_]] : tensor<4x256xbf16>
// CHECK-DAG: [[VAR_10_:%.+]] = tt.advance [[VAR_arg7_]], {{\[}}[[CST_0_i32]], [[CST_3_]]] : <tensor<4x256xbf16>>
// CHECK: scf.yield [[VAR_9_]], [[VAR_10_]] : tensor<4x256xbf16>, !tt.ptr<tensor<4x256xbf16>>
// CHECK: }
// COM: to sizes: [4, 256], strides: [1, [[CST_5_]]{{.}}, offsets: {{.}}[[VAR_5_]], 0], shape: [0, 0], order: [] : <bf16> to tensor<4x256x!tt.ptr<bf16>>
// CHECK: [[VAR_6_:%.+]] = tt.make_tensor_ptr [[PARAM_2_]], {{\[}}[[CST_0_]], [[CST_0_]]], {{\[}}[[CST_1_]], [[CST_5_]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {{.*}} : <tensor<4x256xbf16>>
// CHECK: tt.store [[VAR_6_]], [[VAR_4_]]#0 : !tt.ptr<tensor<4x256xbf16>>
// CHECK: tt.return
// CHECK: }
Loading