[RAISE-BP] Fix semantic differences (#1895)
Fix semantic differences between the triton dialect and the
triton-shared dialect:
- Offsets are now divided by the stride before the `MakeTensorPtr` op is
  written (see the sketch after this list).
- Add a boundary check for each axis whose modulo/shape is not empty.
- Fix a few bugs caused by now using a single target op, i.e.
  `MakeTensorPtrOp`, for all block pointer ops instead of
  `tts::MakeTensorPtrOp`.
Update the tests accordingly.
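The first bullet shows up in the updated CHECK lines as a trunci + divui pair: each i32 offset is divided by its (truncated) i64 stride before being passed to `tt.make_tensor_ptr`, and the second bullet appears as a `boundaryCheck` attribute on the corresponding `tt.load`. Below is a minimal sketch of how such a division could be emitted with the MLIR builder API; the helper name and its placement in the rewrite are assumptions for illustration, not the pass's actual code.

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

// Hypothetical helper (name and structure are illustrative only): rescale an
// element offset into stride units before it is used as a tt.make_tensor_ptr
// offset, mirroring the trunci + divui pattern checked in the tests below.
static Value divideOffsetByStride(OpBuilder &builder, Location loc,
                                  Value offset /*i32*/, Value stride /*i64*/) {
  // tt.make_tensor_ptr takes i32 offsets, so the i64 stride is truncated
  // before the unsigned division.
  Value strideI32 =
      builder.create<arith::TruncIOp>(loc, builder.getI32Type(), stride);
  return builder.create<arith::DivUIOp>(loc, offset, strideI32);
}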
mfrancepillois authored Aug 20, 2024
1 parent 25fab7e commit eca3c10
Showing 5 changed files with 202 additions and 128 deletions.
132 changes: 79 additions & 53 deletions test/Triton/raise-block-pointer.mlir
@@ -101,9 +101,9 @@ tt.func @test_addptr_splat_splat_2d_store(%arg0 : !tt.ptr<f32>, %arg1: i64, %arg

// CHECK-LABEL: tt.func @test_addptr_splat_make_range_add(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>) -> tensor<128xf32> {
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i64
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : i32
// CHECK: %[[VAL_3:.*]] = arith.constant 2 : i64
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : i64
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 2 : i64
// CHECK: %[[VAL_4:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_1]]], {{\[}}%[[VAL_3]]], {{\[}}%[[VAL_2]]] {order = array<i32>} : <tensor<128xf32>>
// CHECK: %[[VAL_5:.*]] = tt.load %[[VAL_4]] : !tt.ptr<tensor<128xf32>>
// CHECK: tt.return %[[VAL_5]] : tensor<128xf32>
@@ -124,9 +124,11 @@ tt.func @test_addptr_splat_make_range_add(%arg0 : !tt.ptr<f32>) -> tensor<128xf3
// CHECK: %[[VAL_3:.*]] = arith.constant 0 : i32
// CHECK: %[[VAL_4:.*]] = arith.index_cast %[[VAL_1]] : i32 to index
// CHECK: %[[VAL_5:.*]] = arith.index_cast %[[VAL_4]] : index to i64
// CHECK: %[[VAL_6:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_2]]], {{\[}}%[[VAL_5]]], {{\[}}%[[VAL_3]]] {order = array<i32>} : <tensor<128xf32>>
// CHECK: %[[VAL_7:.*]] = tt.load %[[VAL_6]] : !tt.ptr<tensor<128xf32>>
// CHECK: tt.return %[[VAL_7]] : tensor<128xf32>
// CHECK: %[[VAL_6:.*]] = arith.trunci %[[VAL_5]] : i64 to i32
// CHECK: %[[VAL_7:.*]] = arith.divui %[[VAL_3]], %[[VAL_6]] : i32
// CHECK: %[[VAL_8:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_2]]], {{\[}}%[[VAL_5]]], {{\[}}%[[VAL_7]]] {order = array<i32>} : <tensor<128xf32>>
// CHECK: %[[VAL_9:.*]] = tt.load %[[VAL_8]] : !tt.ptr<tensor<128xf32>>
// CHECK: tt.return %[[VAL_9]] : tensor<128xf32>
tt.func @test_addptr_splat_make_range_mul(%arg0 : !tt.ptr<f32>, %arg1: i32) -> tensor<128xf32> {
%0 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<128x!tt.ptr<f32>>
%1 = tt.splat %arg1 : i32 -> tensor<128xi32>
@@ -171,11 +173,12 @@ tt.func @test_expand_dims(%arg0 : !tt.ptr<f32>) -> tensor<1x128xf32> {

// CHECK-LABEL: tt.func @test_const_splat_addptr_2d(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>) -> tensor<2x128xf32> {
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i64
// CHECK: %[[VAL_2:.*]] = arith.constant 512 : i32
// CHECK: %[[VAL_3:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_2]], %[[VAL_2]]] {order = array<i32>} : <tensor<2x128xf32>>
// CHECK: %[[VAL_4:.*]] = tt.load %[[VAL_3]] : !tt.ptr<tensor<2x128xf32>>
// CHECK: tt.return %[[VAL_4]] : tensor<2x128xf32>
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : i64
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 512 : i32
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : i32
// CHECK: %[[VAL_4:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_2]], %[[VAL_3]]] {order = array<i32>} : <tensor<2x128xf32>>
// CHECK: %[[VAL_5:.*]] = tt.load %[[VAL_4]] : !tt.ptr<tensor<2x128xf32>>
// CHECK: tt.return %[[VAL_5]] : tensor<2x128xf32>
tt.func @test_const_splat_addptr_2d(%arg0 : !tt.ptr<f32>) -> tensor<2x128xf32> {
%cst = arith.constant dense<512> : tensor<2x128xi32>
%0 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<2x128x!tt.ptr<f32>>
@@ -186,11 +189,12 @@ tt.func @test_const_splat_addptr_2d(%arg0 : !tt.ptr<f32>) -> tensor<2x128xf32> {

// CHECK-LABEL: tt.func @test_addptr_broadcast(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>) -> tensor<2x128xf32> {
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i64
// CHECK: %[[VAL_2:.*]] = arith.constant 1 : i32
// CHECK: %[[VAL_3:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_2]], %[[VAL_2]]] {order = array<i32>} : <tensor<2x128xf32>>
// CHECK: %[[VAL_4:.*]] = tt.load %[[VAL_3]] : !tt.ptr<tensor<2x128xf32>>
// CHECK: tt.return %[[VAL_4]] : tensor<2x128xf32>
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : i64
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : i32
// CHECK: %[[VAL_4:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_3]], %[[VAL_2]]] {order = array<i32>} : <tensor<2x128xf32>>
// CHECK: %[[VAL_5:.*]] = tt.load %[[VAL_4]] : !tt.ptr<tensor<2x128xf32>>
// CHECK: tt.return %[[VAL_5]] : tensor<2x128xf32>
tt.func @test_addptr_broadcast(%arg0 : !tt.ptr<f32>) -> tensor<2x128xf32> {
%cst = arith.constant dense<1> : tensor<1x128xi32>
%0 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<2x128x!tt.ptr<f32>>
@@ -202,11 +206,12 @@ tt.func @test_addptr_broadcast(%arg0 : !tt.ptr<f32>) -> tensor<2x128xf32> {

// CHECK-LABEL: tt.func @test_addptr_broadcast_rank(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>) -> tensor<2x128xf32> {
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i64
// CHECK: %[[VAL_2:.*]] = arith.constant 1 : i32
// CHECK: %[[VAL_3:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_2]], %[[VAL_2]]] {order = array<i32>} : <tensor<2x128xf32>>
// CHECK: %[[VAL_4:.*]] = tt.load %[[VAL_3]] : !tt.ptr<tensor<2x128xf32>>
// CHECK: tt.return %[[VAL_4]] : tensor<2x128xf32>
// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : i64
// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : i32
// CHECK: %[[VAL_4:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_3]], %[[VAL_2]]] {order = array<i32>} : <tensor<2x128xf32>>
// CHECK: %[[VAL_5:.*]] = tt.load %[[VAL_4]] : !tt.ptr<tensor<2x128xf32>>
// CHECK: tt.return %[[VAL_5]] : tensor<2x128xf32>
tt.func @test_addptr_broadcast_rank(%arg0 : !tt.ptr<f32>) -> tensor<2x128xf32> {
%cst = arith.constant dense<1> : tensor<1x128xi32>
%0 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<2x128x!tt.ptr<f32>>
@@ -219,10 +224,11 @@ tt.func @test_addptr_broadcast_rank(%arg0 : !tt.ptr<f32>) -> tensor<2x128xf32> {
// CHECK-LABEL: tt.func @test_addptr_broadcast_rank_2(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>) -> tensor<128x2x128xf32> {
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i64
// CHECK: %[[VAL_2:.*]] = arith.constant 1 : i32
// CHECK: %[[VAL_3:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_1]], %[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_1]], %[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_2]], %[[VAL_2]], %[[VAL_2]]] {order = array<i32>} : <tensor<128x2x128xf32>>
// CHECK: %[[VAL_4:.*]] = tt.load %[[VAL_3]] : !tt.ptr<tensor<128x2x128xf32>>
// CHECK: tt.return %[[VAL_4]] : tensor<128x2x128xf32>
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : i32
// CHECK: %[[VAL_3:.*]] = arith.constant 1 : i32
// CHECK: %[[VAL_4:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_1]], %[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_1]], %[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_3]], %[[VAL_2]], %[[VAL_2]]] {order = array<i32>} : <tensor<128x2x128xf32>>
// CHECK: %[[VAL_5:.*]] = tt.load %[[VAL_4]] : !tt.ptr<tensor<128x2x128xf32>>
// CHECK: tt.return %[[VAL_5]] : tensor<128x2x128xf32>
tt.func @test_addptr_broadcast_rank_2(%arg0 : !tt.ptr<f32>) -> tensor<128x2x128xf32> {
%cst = arith.constant dense<1> : tensor<128x1x128xi32>
%0 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<128x2x128x!tt.ptr<f32>>
@@ -235,10 +241,11 @@ tt.func @test_addptr_broadcast_rank_2(%arg0 : !tt.ptr<f32>) -> tensor<128x2x128x
// CHECK-LABEL: tt.func @test_addptr_broadcast_rank_3(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr<f32>) -> tensor<128x2x128xf32> {
// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i64
// CHECK: %[[VAL_2:.*]] = arith.constant 1 : i32
// CHECK: %[[VAL_3:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_1]], %[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_1]], %[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_2]], %[[VAL_2]], %[[VAL_2]]] {order = array<i32>} : <tensor<128x2x128xf32>>
// CHECK: %[[VAL_4:.*]] = tt.load %[[VAL_3]] : !tt.ptr<tensor<128x2x128xf32>>
// CHECK: tt.return %[[VAL_4]] : tensor<128x2x128xf32>
// CHECK: %[[VAL_2:.*]] = arith.constant 0 : i32
// CHECK: %[[VAL_3:.*]] = arith.constant 1 : i32
// CHECK: %[[VAL_4:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_1]], %[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_1]], %[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_3]], %[[VAL_2]], %[[VAL_2]]] {order = array<i32>} : <tensor<128x2x128xf32>>
// CHECK: %[[VAL_5:.*]] = tt.load %[[VAL_4]] : !tt.ptr<tensor<128x2x128xf32>>
// CHECK: tt.return %[[VAL_5]] : tensor<128x2x128xf32>
tt.func @test_addptr_broadcast_rank_3(%arg0 : !tt.ptr<f32>) -> tensor<128x2x128xf32> {
%cst = arith.constant dense<1> : tensor<128x1x1xi32>
%0 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<128x2x128x!tt.ptr<f32>>
@@ -263,14 +270,25 @@ tt.func @test_addptr_broadcast_rank_3(%arg0 : !tt.ptr<f32>) -> tensor<128x2x128x
// CHECK: [[VAR_7_:%.+]] = arith.muli [[PARAM_4_]], [[CST_6_i32]] : i32
// CHECK: [[VAR_8_:%.+]] = arith.index_cast [[VAR_4_]] : index to i64
// CHECK: [[VAR_9_:%.+]] = arith.muli [[VAR_8_]], [[VAR_6_]] : i64
// CHECK: [[VAR_10:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]], [[VAR_9_]]], {{\[}}[[VAR_2_]], [[VAR_6_]]], {{\[}}[[VAR_3_]], [[VAR_7_]]] {order = array<i32>} : <tensor<4x4xf32>>
// CHECK: [[VAR_11_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index
// CHECK: [[VAR_12_:%.+]] = arith.index_cast [[VAR_11_]] : index to i64
// CHECK: [[VAR_13_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index
// CHECK: [[VAR_14_:%.+]] = arith.index_cast [[VAR_13_]] : index to i64
// CHECK: [[VAR_15:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_12_]], [[VAR_14_]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {order = array<i32>} : <tensor<4x4xf32>>
// CHECK: [[VAR_16:%.+]] = tt.load [[VAR_10]] : !tt.ptr<tensor<4x4xf32>>
// CHECK: tt.store [[VAR_15]], [[VAR_16]] : !tt.ptr<tensor<4x4xf32>>
// CHECK: [[VAR_10_:%.+]] = arith.trunci [[VAR_2_]] : i64 to i32
// CHECK: [[VAR_11_:%.+]] = arith.divui [[VAR_3_]], [[VAR_10_]] : i32
// CHECK: [[VAR_12_:%.+]] = arith.trunci [[VAR_6_]] : i64 to i32
// CHECK: [[VAR_13_:%.+]] = arith.divui [[VAR_7_]], [[VAR_12_]] : i32
// CHECK: [[VAR_14_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]], [[VAR_9_]]], {{\[}}[[VAR_2_]], [[VAR_6_]]], {{\[}}[[VAR_11_]], [[VAR_13_]]] {order = array<i32>} : <tensor<4x4xf32>>
// CHECK: [[VAR_15_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index
// CHECK: [[VAR_16_:%.+]] = arith.index_cast [[VAR_15_]] : index to i64
// CHECK: [[VAR_17_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index
// CHECK: [[VAR_18_:%.+]] = arith.index_cast [[VAR_17_]] : index to i64


// CHECK: [[VAR_19_:%.+]] = arith.trunci [[VAR_16_]] : i64 to i32
// CHECK: [[VAR_20_:%.+]] = arith.divui [[CST_0_i32]], [[VAR_19_]] : i32
// CHECK: [[VAR_21_:%.+]] = arith.trunci [[VAR_18_]] : i64 to i32
// CHECK: [[VAR_22_:%.+]] = arith.divui [[CST_0_i32]], [[VAR_21_]] : i32

// CHECK: [[VAR_23:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_16_]], [[VAR_18_]]], {{\[}}[[VAR_20_]], [[VAR_22_]]] {order = array<i32>} : <tensor<4x4xf32>>
// CHECK: [[VAR_24:%.+]] = tt.load [[VAR_14_]] {boundaryCheck = array<i32: 1>} : !tt.ptr<tensor<4x4xf32>>
// CHECK: tt.store [[VAR_23]], [[VAR_24]] : !tt.ptr<tensor<4x4xf32>>
// CHECK: tt.return
module {
tt.func public @wrap_side_by_side_masked(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32) {
@@ -322,14 +340,14 @@ tt.func public @wrap_side_by_side_masked(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32


// CHECK: tt.func @test_addptr_for_accumulation([[PARAM_0_:%.+]]: !tt.ptr<bf16>, [[PARAM_1_:%.+]]: !tt.ptr<bf16>, [[PARAM_2_:%.+]]: !tt.ptr<bf16>, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32) {
// CHECK: [[CST_3_i32:%.+]] = arith.constant 3 : i32
// CHECK: [[CST_3_:%.+]] = arith.constant 3 : index
// CHECK: [[CST_12_:%.+]] = arith.constant 12 : index
// CHECK: [[CST_0_:%.+]] = arith.constant 0 : index
// CHECK: [[CST_1_i64:%.+]] = arith.constant 1 : i64
// CHECK: [[CST_0_i64:%.+]] = arith.constant 0 : i64
// CHECK: [[CST_0_i32:%.+]] = arith.constant 0 : i32
// CHECK: [[CST_5_i64:%.+]] = arith.constant 5 : i64
// CHECK-DAG: [[CST_3_i32:%.+]] = arith.constant 3 : i32
// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index
// CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index
// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index
// CHECK-DAG: [[CST_1_i64:%.+]] = arith.constant 1 : i64
// CHECK-DAG: [[CST_0_i64:%.+]] = arith.constant 0 : i64
// CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32
// CHECK-DAG: [[CST_5_i64:%.+]] = arith.constant 5 : i64
// CHECK: [[VAR_1_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[CST_1_i64]], [[CST_5_i64]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {order = array<i32>} : <tensor<4x256xbf16>>
// CHECK: [[VAR_2_:%.+]] = tt.load [[VAR_1_]] : !tt.ptr<tensor<4x256xbf16>>
// CHECK: [[VAR_4_:%.+]]:2 = scf.for [[VAR_arg5_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg6_:%.+]] = [[VAR_2_]], [[VAR_arg7_:%.+]] = [[PARAM_3_]]) -> (tensor<4x256xbf16>, i32) {
@@ -418,14 +436,22 @@ module {
// CHECK: [[VAR_6_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index
// CHECK: [[VAR_7_:%.+]] = arith.index_cast [[VAR_6_]] : index to i64
// CHECK: [[VAR_8_:%.+]] = arith.muli [[PARAM_4_]], [[CST_3_i32]] : i32
// CHECK: [[VAR_9:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[VAR_5_]], [[CST_0_i64]]], {{\[}}[[VAR_2_]], [[VAR_7_]]], {{\[}}[[VAR_3_]], [[VAR_8_]]] {order = array<i32>} : <tensor<4x4xf32>>
// CHECK: [[VAR_10_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index
// CHECK: [[VAR_11_:%.+]] = arith.index_cast [[VAR_10_]] : index to i64
// CHECK: [[VAR_12_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index
// CHECK: [[VAR_13_:%.+]] = arith.index_cast [[VAR_12_]] : index to i64
// CHECK: [[VAR_14:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_11_]], [[VAR_13_]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {order = array<i32>} : <tensor<4x4xf32>>
// CHECK: [[VAR_15:%.+]] = tt.load [[VAR_9]] : !tt.ptr<tensor<4x4xf32>>
// CHECK: tt.store [[VAR_14]], [[VAR_15]] : !tt.ptr<tensor<4x4xf32>>
// CHECK: [[VAR_9_:%.+]] = arith.trunci [[VAR_2_]] : i64 to i32
// CHECK: [[VAR_10_:%.+]] = arith.divui [[VAR_3_]], [[VAR_9_]] : i32
// CHECK: [[VAR_11_:%.+]] = arith.trunci [[VAR_7_]] : i64 to i32
// CHECK: [[VAR_12_:%.+]] = arith.divui [[VAR_8_]], [[VAR_11_]] : i32
// CHECK: [[VAR_13:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[VAR_5_]], [[CST_0_i64]]], {{\[}}[[VAR_2_]], [[VAR_7_]]], {{\[}}[[VAR_10_]], [[VAR_12_]]] {order = array<i32>} : <tensor<4x4xf32>>
// CHECK: [[VAR_14_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index
// CHECK: [[VAR_15_:%.+]] = arith.index_cast [[VAR_14_]] : index to i64
// CHECK: [[VAR_16_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index
// CHECK: [[VAR_17_:%.+]] = arith.index_cast [[VAR_16_]] : index to i64
// CHECK: [[VAR_18_:%.+]] = arith.trunci [[VAR_15_]] : i64 to i32
// CHECK: [[VAR_19_:%.+]] = arith.divui [[CST_0_i32]], [[VAR_18_]] : i32
// CHECK: [[VAR_20_:%.+]] = arith.trunci [[VAR_17_]] : i64 to i32
// CHECK: [[VAR_21_:%.+]] = arith.divui [[CST_0_i32]], [[VAR_20_]] : i32
// CHECK: [[VAR_22:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_15_]], [[VAR_17_]]], {{\[}}[[VAR_19_]], [[VAR_21_]]] {order = array<i32>} : <tensor<4x4xf32>>
// CHECK: [[VAR_23:%.+]] = tt.load [[VAR_13]] {boundaryCheck = array<i32: 0>} : !tt.ptr<tensor<4x4xf32>>
// CHECK: tt.store [[VAR_22]], [[VAR_23]] : !tt.ptr<tensor<4x4xf32>>
// CHECK: tt.return
module {
tt.func public @wrap_stacked_masked_loop(%arg0: !tt.ptr<f32>, %arg1: !tt.ptr<f32>, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32) {
@@ -52,6 +52,10 @@ LLVM::CallOp createSPIRVBuiltinCall(Location loc,
ConversionPatternRewriter &rewriter,
LLVM::LLVMFuncOp func, ValueRange args);

// Return true if `val` is a constant (possibly after folding) whose value
// equals `expected`.
bool isConstant(Value val, const unsigned expected);

} // namespace mlir::triton::gpu::intel

#endif // TRITON_DIALECT_TRITONINTELGPU_TRANSFORMS_UTILITY_H
@@ -16,19 +16,6 @@ namespace mlir::triton::gpu::intel {

namespace {

bool isConstant(Value v, unsigned expected) {
if (v.getDefiningOp() == nullptr)
return false;

if (auto stride = dyn_cast<arith::ConstantOp>(v.getDefiningOp())) {
if (auto strideInt = dyn_cast<IntegerAttr>(stride.getValue()))
if (strideInt.getInt() == expected)
return true;
}

return false;
}

struct TritonIntelGPUMaterializeBlockPointerPass
: public triton::gpu::intel::impl::
TritonIntelGPUMaterializeBlockPointerBase<
@@ -67,7 +54,7 @@ struct TritonIntelGPUMaterializeBlockPointerPass

// HW 2D block read instruction only supports contiguous access.
Value fastChangeStride = strides[fastChangeDim];
if (!isConstant(fastChangeStride, 1))
if (!mlir::triton::gpu::intel::isConstant(fastChangeStride, 1))
return;

// Across Intel platforms, the strictest pitch restriction is to be a
46 changes: 46 additions & 0 deletions third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp
@@ -235,4 +235,50 @@ LLVM::CallOp createSPIRVBuiltinCall(Location loc,
return call;
}

static std::optional<int64_t> getIntAttr(const OpFoldResult ofr) {
if (ofr.is<Attribute>() && isa<IntegerAttr>(ofr.get<Attribute>()))
return cast<IntegerAttr>(ofr.get<Attribute>()).getInt();
return std::nullopt;
}

// This function folds the `op` operation and returns the constant value if it
// has successfully folded to a constant. Otherwise, it returns `std::nullopt`.
static std::optional<int64_t> getFoldedConstantValue(Operation *op) {
SmallVector<OpFoldResult> results;
if (failed(op->fold(results))) {
return std::nullopt;
}

// If fold succeeded but `results` is empty, we give a second try, after the
// operands have been switched during the first call to `fold()`.
if (results.empty()) {
if (failed(op->fold(results))) {
return std::nullopt;
}
}

if (results.size() != 1) {
return std::nullopt;
}

auto intAttr = getIntAttr(results[0]);
if (intAttr.has_value()) {
return intAttr.value();
}

auto val = cast<Value>(results[0]);
auto constOp = val.getDefiningOp<arith::ConstantOp>();
if (!constOp)
return std::nullopt;

return getIntAttr(constOp.getValue());
}

bool isConstant(Value val, const unsigned expected) {
auto defOp = val.getDefiningOp();
if (!defOp)
return false;
return (getFoldedConstantValue(defOp) == expected);
}

} // namespace mlir::triton::gpu::intel
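A note on the design: unlike the helper removed from the materialize-block-pointer pass, which only matched a literal `arith.constant` defining op, this fold-based `isConstant` also accepts values that merely fold to the expected constant. A hedged caller-side sketch follows (assuming the `isConstant` declaration from the Utility.h hunk above is in scope; the wrapper name is illustrative, not part of the commit):

// Sketch only: mirrors the guard updated in
// TritonIntelGPUMaterializeBlockPointerPass, which requires the
// fastest-changing stride to fold to 1 (contiguous access) before a 2D block
// pointer load is materialized.
static bool hasUnitFastChangingStride(mlir::Value fastChangeStride) {
  return mlir::triton::gpu::intel::isConstant(fastChangeStride, /*expected=*/1);
}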