From eca3c10df023ef512452560616826848b9eee721 Mon Sep 17 00:00:00 2001
From: Maxime France-Pillois
Date: Tue, 20 Aug 2024 18:10:28 +0100
Subject: [PATCH] [RAISE-BP] Fix semantic differences (#1895)

Fix semantic differences between the Triton dialect and the triton-shared
dialect:
- Offsets are now divided by the stride before being written to the
  `MakeTensorPtr` op.
- Add a boundary check for every axis whose modulo/shape is not empty.
- Fix a few bugs caused by the fact that we now use a single target op,
  i.e. `MakeTensorPtrOp`, for all block pointer ops instead of
  `tts::MakeTensorPtrOp`.

Update the tests accordingly.
---
 test/Triton/raise-block-pointer.mlir | 132 ++++++++++-------
 .../TritonIntelGPU/Transforms/Utility.h | 4 +
 .../MaterializeBlockPointer.cpp | 15 +-
 .../lib/TritonIntelGPUTransforms/Utility.cpp | 46 ++++++
 .../TritonRaiseBlockPointer.cpp | 133 ++++++++++--------
 5 files changed, 202 insertions(+), 128 deletions(-)

diff --git a/test/Triton/raise-block-pointer.mlir b/test/Triton/raise-block-pointer.mlir
index 5e303fff3b..119f3789a7 100644
--- a/test/Triton/raise-block-pointer.mlir
+++ b/test/Triton/raise-block-pointer.mlir
@@ -101,9 +101,9 @@ tt.func @test_addptr_splat_splat_2d_store(%arg0 : !tt.ptr, %arg1: i64, %arg
// CHECK-LABEL: tt.func @test_addptr_splat_make_range_add(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr) -> tensor<128xf32> {
-// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i64
-// CHECK: %[[VAL_2:.*]] = arith.constant 0 : i32
-// CHECK: %[[VAL_3:.*]] = arith.constant 2 : i64
+// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : i64
+// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : i32
+// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 2 : i64
// CHECK: %[[VAL_4:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_1]]], {{\[}}%[[VAL_3]]], {{\[}}%[[VAL_2]]] {order = array} : >
// CHECK: %[[VAL_5:.*]] = tt.load %[[VAL_4]] : !tt.ptr>
// CHECK: tt.return %[[VAL_5]] : tensor<128xf32>
@@ -124,9 +124,11 @@ tt.func @test_addptr_splat_make_range_add(%arg0 : !tt.ptr) -> tensor<128xf3
// CHECK: %[[VAL_3:.*]] = arith.constant 0 : i32
// CHECK: %[[VAL_4:.*]] = arith.index_cast %[[VAL_1]] : i32 to index
// CHECK: %[[VAL_5:.*]] = arith.index_cast %[[VAL_4]] : index to i64
-// CHECK: %[[VAL_6:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_2]]], {{\[}}%[[VAL_5]]], {{\[}}%[[VAL_3]]] {order = array} : >
-// CHECK: %[[VAL_7:.*]] = tt.load %[[VAL_6]] : !tt.ptr>
-// CHECK: tt.return %[[VAL_7]] : tensor<128xf32>
+// CHECK: %[[VAL_6:.*]] = arith.trunci %[[VAL_5]] : i64 to i32
+// CHECK: %[[VAL_7:.*]] = arith.divui %[[VAL_3]], %[[VAL_6]] : i32
+// CHECK: %[[VAL_8:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_2]]], {{\[}}%[[VAL_5]]], {{\[}}%[[VAL_7]]] {order = array} : >
+// CHECK: %[[VAL_9:.*]] = tt.load %[[VAL_8]] : !tt.ptr>
+// CHECK: tt.return %[[VAL_9]] : tensor<128xf32>
tt.func @test_addptr_splat_make_range_mul(%arg0 : !tt.ptr, %arg1: i32) -> tensor<128xf32> {
%0 = tt.splat %arg0 : !tt.ptr -> tensor<128x!tt.ptr>
%1 = tt.splat %arg1 : i32 -> tensor<128xi32>
@@ -171,11 +173,12 @@ tt.func @test_expand_dims(%arg0 : !tt.ptr) -> tensor<1x128xf32> {
// CHECK-LABEL: tt.func @test_const_splat_addptr_2d(
// CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr) -> tensor<2x128xf32> {
-// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i64
-// CHECK: %[[VAL_2:.*]] = arith.constant 512 : i32
-// CHECK: %[[VAL_3:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_2]], %[[VAL_2]]] {order = array} : >
-// CHECK: %[[VAL_4:.*]] = tt.load %[[VAL_3]] : !tt.ptr>
-// CHECK: tt.return
%[[VAL_4]] : tensor<2x128xf32> +// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : i64 +// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 512 : i32 +// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 0 : i32 +// CHECK: %[[VAL_4:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_2]], %[[VAL_3]]] {order = array} : > +// CHECK: %[[VAL_5:.*]] = tt.load %[[VAL_4]] : !tt.ptr> +// CHECK: tt.return %[[VAL_5]] : tensor<2x128xf32> tt.func @test_const_splat_addptr_2d(%arg0 : !tt.ptr) -> tensor<2x128xf32> { %cst = arith.constant dense<512> : tensor<2x128xi32> %0 = tt.splat %arg0 : !tt.ptr -> tensor<2x128x!tt.ptr> @@ -186,11 +189,12 @@ tt.func @test_const_splat_addptr_2d(%arg0 : !tt.ptr) -> tensor<2x128xf32> { // CHECK-LABEL: tt.func @test_addptr_broadcast( // CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr) -> tensor<2x128xf32> { -// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i64 -// CHECK: %[[VAL_2:.*]] = arith.constant 1 : i32 -// CHECK: %[[VAL_3:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_2]], %[[VAL_2]]] {order = array} : > -// CHECK: %[[VAL_4:.*]] = tt.load %[[VAL_3]] : !tt.ptr> -// CHECK: tt.return %[[VAL_4]] : tensor<2x128xf32> +// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : i64 +// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : i32 +// CHECK: %[[VAL_4:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_3]], %[[VAL_2]]] {order = array} : > +// CHECK: %[[VAL_5:.*]] = tt.load %[[VAL_4]] : !tt.ptr> +// CHECK: tt.return %[[VAL_5]] : tensor<2x128xf32> tt.func @test_addptr_broadcast(%arg0 : !tt.ptr) -> tensor<2x128xf32> { %cst = arith.constant dense<1> : tensor<1x128xi32> %0 = tt.splat %arg0 : !tt.ptr -> tensor<2x128x!tt.ptr> @@ -202,11 +206,12 @@ tt.func @test_addptr_broadcast(%arg0 : !tt.ptr) -> tensor<2x128xf32> { // CHECK-LABEL: tt.func @test_addptr_broadcast_rank( // CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr) -> tensor<2x128xf32> { -// CHECK: %[[VAL_1:.*]] = arith.constant 0 : i64 -// CHECK: %[[VAL_2:.*]] = arith.constant 1 : i32 -// CHECK: %[[VAL_3:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_2]], %[[VAL_2]]] {order = array} : > -// CHECK: %[[VAL_4:.*]] = tt.load %[[VAL_3]] : !tt.ptr> -// CHECK: tt.return %[[VAL_4]] : tensor<2x128xf32> +// CHECK-DAG: %[[VAL_1:.*]] = arith.constant 0 : i64 +// CHECK-DAG: %[[VAL_2:.*]] = arith.constant 0 : i32 +// CHECK-DAG: %[[VAL_3:.*]] = arith.constant 1 : i32 +// CHECK: %[[VAL_4:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_3]], %[[VAL_2]]] {order = array} : > +// CHECK: %[[VAL_5:.*]] = tt.load %[[VAL_4]] : !tt.ptr> +// CHECK: tt.return %[[VAL_5]] : tensor<2x128xf32> tt.func @test_addptr_broadcast_rank(%arg0 : !tt.ptr) -> tensor<2x128xf32> { %cst = arith.constant dense<1> : tensor<1x128xi32> %0 = tt.splat %arg0 : !tt.ptr -> tensor<2x128x!tt.ptr> @@ -219,10 +224,11 @@ tt.func @test_addptr_broadcast_rank(%arg0 : !tt.ptr) -> tensor<2x128xf32> { // CHECK-LABEL: tt.func @test_addptr_broadcast_rank_2( // CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr) -> tensor<128x2x128xf32> { // CHECK: %[[VAL_1:.*]] = arith.constant 0 : i64 -// CHECK: %[[VAL_2:.*]] = arith.constant 1 : i32 -// CHECK: %[[VAL_3:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_1]], %[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_1]], %[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_2]], 
%[[VAL_2]], %[[VAL_2]]] {order = array} : > -// CHECK: %[[VAL_4:.*]] = tt.load %[[VAL_3]] : !tt.ptr> -// CHECK: tt.return %[[VAL_4]] : tensor<128x2x128xf32> +// CHECK: %[[VAL_2:.*]] = arith.constant 0 : i32 +// CHECK: %[[VAL_3:.*]] = arith.constant 1 : i32 +// CHECK: %[[VAL_4:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_1]], %[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_1]], %[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_3]], %[[VAL_2]], %[[VAL_2]]] {order = array} : > +// CHECK: %[[VAL_5:.*]] = tt.load %[[VAL_4]] : !tt.ptr> +// CHECK: tt.return %[[VAL_5]] : tensor<128x2x128xf32> tt.func @test_addptr_broadcast_rank_2(%arg0 : !tt.ptr) -> tensor<128x2x128xf32> { %cst = arith.constant dense<1> : tensor<128x1x128xi32> %0 = tt.splat %arg0 : !tt.ptr -> tensor<128x2x128x!tt.ptr> @@ -235,10 +241,11 @@ tt.func @test_addptr_broadcast_rank_2(%arg0 : !tt.ptr) -> tensor<128x2x128x // CHECK-LABEL: tt.func @test_addptr_broadcast_rank_3( // CHECK-SAME: %[[VAL_0:.*]]: !tt.ptr) -> tensor<128x2x128xf32> { // CHECK: %[[VAL_1:.*]] = arith.constant 0 : i64 -// CHECK: %[[VAL_2:.*]] = arith.constant 1 : i32 -// CHECK: %[[VAL_3:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_1]], %[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_1]], %[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_2]], %[[VAL_2]], %[[VAL_2]]] {order = array} : > -// CHECK: %[[VAL_4:.*]] = tt.load %[[VAL_3]] : !tt.ptr> -// CHECK: tt.return %[[VAL_4]] : tensor<128x2x128xf32> +// CHECK: %[[VAL_2:.*]] = arith.constant 0 : i32 +// CHECK: %[[VAL_3:.*]] = arith.constant 1 : i32 +// CHECK: %[[VAL_4:.*]] = tt.make_tensor_ptr %[[VAL_0]], {{\[}}%[[VAL_1]], %[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_1]], %[[VAL_1]], %[[VAL_1]]], {{\[}}%[[VAL_3]], %[[VAL_2]], %[[VAL_2]]] {order = array} : > +// CHECK: %[[VAL_5:.*]] = tt.load %[[VAL_4]] : !tt.ptr> +// CHECK: tt.return %[[VAL_5]] : tensor<128x2x128xf32> tt.func @test_addptr_broadcast_rank_3(%arg0 : !tt.ptr) -> tensor<128x2x128xf32> { %cst = arith.constant dense<1> : tensor<128x1x1xi32> %0 = tt.splat %arg0 : !tt.ptr -> tensor<128x2x128x!tt.ptr> @@ -263,14 +270,25 @@ tt.func @test_addptr_broadcast_rank_3(%arg0 : !tt.ptr) -> tensor<128x2x128x // CHECK: [[VAR_7_:%.+]] = arith.muli [[PARAM_4_]], [[CST_6_i32]] : i32 // CHECK: [[VAR_8_:%.+]] = arith.index_cast [[VAR_4_]] : index to i64 // CHECK: [[VAR_9_:%.+]] = arith.muli [[VAR_8_]], [[VAR_6_]] : i64 -// CHECK: [[VAR_10:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]], [[VAR_9_]]], {{\[}}[[VAR_2_]], [[VAR_6_]]], {{\[}}[[VAR_3_]], [[VAR_7_]]] {order = array} : > -// CHECK: [[VAR_11_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index -// CHECK: [[VAR_12_:%.+]] = arith.index_cast [[VAR_11_]] : index to i64 -// CHECK: [[VAR_13_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index -// CHECK: [[VAR_14_:%.+]] = arith.index_cast [[VAR_13_]] : index to i64 -// CHECK: [[VAR_15:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_12_]], [[VAR_14_]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {order = array} : > -// CHECK: [[VAR_16:%.+]] = tt.load [[VAR_10]] : !tt.ptr> -// CHECK: tt.store [[VAR_15]], [[VAR_16]] : !tt.ptr> +// CHECK: [[VAR_10_:%.+]] = arith.trunci [[VAR_2_]] : i64 to i32 +// CHECK: [[VAR_11_:%.+]] = arith.divui [[VAR_3_]], [[VAR_10_]] : i32 +// CHECK: [[VAR_12_:%.+]] = arith.trunci [[VAR_6_]] : i64 to i32 +// CHECK: [[VAR_13_:%.+]] = arith.divui [[VAR_7_]], [[VAR_12_]] : i32 +// CHECK: [[VAR_14_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]], [[VAR_9_]]], {{\[}}[[VAR_2_]], [[VAR_6_]]], {{\[}}[[VAR_11_]], [[VAR_13_]]] {order = array} : > +// 
CHECK: [[VAR_15_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index +// CHECK: [[VAR_16_:%.+]] = arith.index_cast [[VAR_15_]] : index to i64 +// CHECK: [[VAR_17_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index +// CHECK: [[VAR_18_:%.+]] = arith.index_cast [[VAR_17_]] : index to i64 + + +// CHECK: [[VAR_19_:%.+]] = arith.trunci [[VAR_16_]] : i64 to i32 +// CHECK: [[VAR_20_:%.+]] = arith.divui [[CST_0_i32]], [[VAR_19_]] : i32 +// CHECK: [[VAR_21_:%.+]] = arith.trunci [[VAR_18_]] : i64 to i32 +// CHECK: [[VAR_22_:%.+]] = arith.divui [[CST_0_i32]], [[VAR_21_]] : i32 + +// CHECK: [[VAR_23:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_16_]], [[VAR_18_]]], {{\[}}[[VAR_20_]], [[VAR_22_]]] {order = array} : > +// CHECK: [[VAR_24:%.+]] = tt.load [[VAR_14_]] {boundaryCheck = array} : !tt.ptr> +// CHECK: tt.store [[VAR_23]], [[VAR_24]] : !tt.ptr> // CHECK: tt.return module { tt.func public @wrap_side_by_side_masked(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32) { @@ -322,14 +340,14 @@ tt.func public @wrap_side_by_side_masked(%arg0: !tt.ptr, %arg1: !tt.ptr, [[PARAM_1_:%.+]]: !tt.ptr, [[PARAM_2_:%.+]]: !tt.ptr, [[PARAM_3_:%.+]]: i32, [[PARAM_4_:%.+]]: i32) { -// CHECK: [[CST_3_i32:%.+]] = arith.constant 3 : i32 -// CHECK: [[CST_3_:%.+]] = arith.constant 3 : index -// CHECK: [[CST_12_:%.+]] = arith.constant 12 : index -// CHECK: [[CST_0_:%.+]] = arith.constant 0 : index -// CHECK: [[CST_1_i64:%.+]] = arith.constant 1 : i64 -// CHECK: [[CST_0_i64:%.+]] = arith.constant 0 : i64 -// CHECK: [[CST_0_i32:%.+]] = arith.constant 0 : i32 -// CHECK: [[CST_5_i64:%.+]] = arith.constant 5 : i64 +// CHECK-DAG: [[CST_3_i32:%.+]] = arith.constant 3 : i32 +// CHECK-DAG: [[CST_3_:%.+]] = arith.constant 3 : index +// CHECK-DAG: [[CST_12_:%.+]] = arith.constant 12 : index +// CHECK-DAG: [[CST_0_:%.+]] = arith.constant 0 : index +// CHECK-DAG: [[CST_1_i64:%.+]] = arith.constant 1 : i64 +// CHECK-DAG: [[CST_0_i64:%.+]] = arith.constant 0 : i64 +// CHECK-DAG: [[CST_0_i32:%.+]] = arith.constant 0 : i32 +// CHECK-DAG: [[CST_5_i64:%.+]] = arith.constant 5 : i64 // CHECK: [[VAR_1_:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[CST_1_i64]], [[CST_5_i64]]], {{\[}}[[PARAM_3_]], [[CST_0_i32]]] {order = array} : > // CHECK: [[VAR_2_:%.+]] = tt.load [[VAR_1_]] : !tt.ptr> // CHECK: [[VAR_4_:%.+]]:2 = scf.for [[VAR_arg5_:%.+]] = [[CST_0_]] to [[CST_12_]] step [[CST_3_]] iter_args([[VAR_arg6_:%.+]] = [[VAR_2_]], [[VAR_arg7_:%.+]] = [[PARAM_3_]]) -> (tensor<4x256xbf16>, i32) { @@ -418,14 +436,22 @@ module { // CHECK: [[VAR_6_:%.+]] = arith.index_cast [[PARAM_4_]] : i32 to index // CHECK: [[VAR_7_:%.+]] = arith.index_cast [[VAR_6_]] : index to i64 // CHECK: [[VAR_8_:%.+]] = arith.muli [[PARAM_4_]], [[CST_3_i32]] : i32 -// CHECK: [[VAR_9:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[VAR_5_]], [[CST_0_i64]]], {{\[}}[[VAR_2_]], [[VAR_7_]]], {{\[}}[[VAR_3_]], [[VAR_8_]]] {order = array} : > -// CHECK: [[VAR_10_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index -// CHECK: [[VAR_11_:%.+]] = arith.index_cast [[VAR_10_]] : index to i64 -// CHECK: [[VAR_12_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index -// CHECK: [[VAR_13_:%.+]] = arith.index_cast [[VAR_12_]] : index to i64 -// CHECK: [[VAR_14:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_11_]], [[VAR_13_]]], {{\[}}[[CST_0_i32]], [[CST_0_i32]]] {order = array} : > -// CHECK: [[VAR_15:%.+]] = tt.load [[VAR_9]] : 
!tt.ptr> -// CHECK: tt.store [[VAR_14]], [[VAR_15]] : !tt.ptr> +// CHECK: [[VAR_9_:%.+]] = arith.trunci [[VAR_2_]] : i64 to i32 +// CHECK: [[VAR_10_:%.+]] = arith.divui [[VAR_3_]], [[VAR_9_]] : i32 +// CHECK: [[VAR_11_:%.+]] = arith.trunci [[VAR_7_]] : i64 to i32 +// CHECK: [[VAR_12_:%.+]] = arith.divui [[VAR_8_]], [[VAR_11_]] : i32 +// CHECK: [[VAR_13:%.+]] = tt.make_tensor_ptr [[PARAM_0_]], {{\[}}[[VAR_5_]], [[CST_0_i64]]], {{\[}}[[VAR_2_]], [[VAR_7_]]], {{\[}}[[VAR_10_]], [[VAR_12_]]] {order = array} : > +// CHECK: [[VAR_14_:%.+]] = arith.index_cast [[PARAM_5_]] : i32 to index +// CHECK: [[VAR_15_:%.+]] = arith.index_cast [[VAR_14_]] : index to i64 +// CHECK: [[VAR_16_:%.+]] = arith.index_cast [[PARAM_6_]] : i32 to index +// CHECK: [[VAR_17_:%.+]] = arith.index_cast [[VAR_16_]] : index to i64 +// CHECK: [[VAR_18_:%.+]] = arith.trunci [[VAR_15_]] : i64 to i32 +// CHECK: [[VAR_19_:%.+]] = arith.divui [[CST_0_i32]], [[VAR_18_]] : i32 +// CHECK: [[VAR_20_:%.+]] = arith.trunci [[VAR_17_]] : i64 to i32 +// CHECK: [[VAR_21_:%.+]] = arith.divui [[CST_0_i32]], [[VAR_20_]] : i32 +// CHECK: [[VAR_22:%.+]] = tt.make_tensor_ptr [[PARAM_1_]], {{\[}}[[CST_0_i64]], [[CST_0_i64]]], {{\[}}[[VAR_15_]], [[VAR_17_]]], {{\[}}[[VAR_19_]], [[VAR_21_]]] {order = array} : > +// CHECK: [[VAR_23:%.+]] = tt.load [[VAR_13]] {boundaryCheck = array} : !tt.ptr> +// CHECK: tt.store [[VAR_22]], [[VAR_23]] : !tt.ptr> // CHECK: tt.return module { tt.func public @wrap_stacked_masked_loop(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: i32, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32) { diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h index 7bf40dd4de..0bd4e1a0ca 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h +++ b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h @@ -52,6 +52,10 @@ LLVM::CallOp createSPIRVBuiltinCall(Location loc, ConversionPatternRewriter &rewriter, LLVM::LLVMFuncOp func, ValueRange args); +// Return true if the `val` value is a constant containing a value equal to +// expected. +bool isConstant(Value val, const unsigned expected); + } // namespace mlir::triton::gpu::intel #endif // TRITON_DIALECT_TRITONINTELGPU_TRANSFORMS_UTILITY_H diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp index 478e15af95..61b733ad10 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/MaterializeBlockPointer.cpp @@ -16,19 +16,6 @@ namespace mlir::triton::gpu::intel { namespace { -bool isConstant(Value v, unsigned expected) { - if (v.getDefiningOp() == nullptr) - return false; - - if (auto stride = dyn_cast(v.getDefiningOp())) { - if (auto strideInt = dyn_cast(stride.getValue())) - if (strideInt.getInt() == expected) - return true; - } - - return false; -} - struct TritonIntelGPUMaterializeBlockPointerPass : public triton::gpu::intel::impl:: TritonIntelGPUMaterializeBlockPointerBase< @@ -67,7 +54,7 @@ struct TritonIntelGPUMaterializeBlockPointerPass // HW 2D block read instruction only supports contiguous access. 
Value fastChangeStride = strides[fastChangeDim]; - if (!isConstant(fastChangeStride, 1)) + if (!mlir::triton::gpu::intel::isConstant(fastChangeStride, 1)) return; // Across Intel platforms, the strictest pitch restriction is to be a diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp index 983fec2659..c397312aac 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/Utility.cpp @@ -235,4 +235,50 @@ LLVM::CallOp createSPIRVBuiltinCall(Location loc, return call; } +static std::optional getIntAttr(const OpFoldResult ofr) { + if (ofr.is() && isa(ofr.get())) + return cast(ofr.get()).getInt(); + return std::nullopt; +} + +// This function folds the `op` operation and returns the constant value if it +// has successfully folded to a constant. Otherwise, it returns `std::nullopt`. +static std::optional getFoldedConstantValue(Operation *op) { + SmallVector results; + if (failed(op->fold(results))) { + return std::nullopt; + } + + // If fold succeeded but `results` is empty, we give a second try, after the + // operands have been switched during the first call to `fold()`. + if (results.empty()) { + if (failed(op->fold(results))) { + return std::nullopt; + } + } + + if (results.size() != 1) { + return std::nullopt; + } + + auto intAttr = getIntAttr(results[0]); + if (intAttr.has_value()) { + return intAttr.value(); + } + + auto val = cast(results[0]); + auto constOp = val.getDefiningOp(); + if (!constOp) + return std::nullopt; + + return getIntAttr(constOp.getValue()); +} + +bool isConstant(Value val, const unsigned expected) { + auto defOp = val.getDefiningOp(); + if (!defOp) + return false; + return (getFoldedConstantValue(defOp) == expected); +} + } // namespace mlir::triton::gpu::intel diff --git a/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp b/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp index 27eff464c7..62065a8308 100644 --- a/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp +++ b/third_party/intel/lib/TritonRaiseBlockPointer/TritonRaiseBlockPointer.cpp @@ -7,6 +7,7 @@ #include "intel/include/TritonRaiseBlockPointer/Passes.h" +#include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h" #include "mlir/IR/Matchers.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" @@ -25,51 +26,6 @@ namespace { constexpr unsigned offsetBitwidth = 32; constexpr unsigned shapeAndStridesBitwidth = 64; -std::optional getIntAttr(const OpFoldResult ofr) { - if (ofr.is() && isa(ofr.get())) - return cast(ofr.get()).getInt(); - return std::nullopt; -} - -// This function folds the `op` operation and returns the constant value if it -// has successfully folded to a constant. Otherwise, it returns `std::nullopt`. -std::optional getFoldedConstantValue(Operation *op) { - SmallVector results; - if (failed(op->fold(results))) { - return std::nullopt; - } - - // If fold succeeded but `results` is empty, we give a second try, after the - // operands have been switched during the first call to `fold()`. 
- if (results.empty()) { - if (failed(op->fold(results))) { - return std::nullopt; - } - } - - if (results.size() != 1) { - return std::nullopt; - } - - auto intAttr = getIntAttr(results[0]); - if (intAttr.has_value()) { - return intAttr.value(); - } - - auto val = cast(results[0]); - auto constOp = val.getDefiningOp(); - if (!constOp) - return std::nullopt; - - return getIntAttr(constOp.getValue()); -} - -// return true if the `val` value is a constant containing a value equal to zero -bool hasConstZero(Value val) { - auto intVal = getFoldedConstantValue(val.getDefiningOp()); - return (intVal.has_value() && (intVal.value() == 0)); -} - // Data structure used to decode pointer arithmetics. Offsets, sizes, and // strides are in unit of elements in a linearly laid-out memory, which is the // same as pointer arithmetic operations in Triton language. Scalar is a @@ -115,7 +71,7 @@ struct PtrState { // When PtrState describes a non-block pointer, shape field indicates how // address wraps around. As a result, a constant 0 indicates no wrap around // (i.e. modulo) for the dimension. - return !hasConstZero(shape[dim]); + return !mlir::triton::gpu::intel::isConstant(shape[dim], 0); } // @return true if addresses wrap around in any of the pointer dimension. @@ -144,6 +100,14 @@ struct PtrState { source = lhsState.source ? lhsState.source : rhsState.source; + if (lhsState.scalar && rhsState.scalar) { // both lhs and rhs are scalars + auto addOp = + builder.create(loc, lhsState.scalar, rhsState.scalar); + scalar = addOp.getResult(); + } else if (lhsState.getRank() == 0) { + scalar = lhsState.scalar ? lhsState.scalar : rhsState.scalar; + } + ArithBuilder abuilder(builder, loc); for (uint64_t i = 0; i < lhsState.getRank(); ++i) { Value newOffset = abuilder.add(lhsState.offsets[i], rhsState.offsets[i]); @@ -192,7 +156,7 @@ struct PtrState { for (uint64_t i = 0; i < lhs->getRank(); i++) { if (!lhs->dimHasModulo(i)) { shape.push_back(lhs->shape[i]); - } else if (hasConstZero(rhs->offsets[i])) { + } else if (mlir::triton::gpu::intel::isConstant(rhs->offsets[i], 0)) { shape.push_back(lhs->shape[i]); } else { op->emitRemark("TritonRaiseBlockPointer: do not support adding to " @@ -264,11 +228,22 @@ struct PtrState { SmallVector newOffsets; SmallVector newStrides; SmallVector newShape; + ArithBuilder abuilder(builder, loc); for (const auto &[offset, stride, dim] : llvm::zip(offsets, strides, shape)) { - newOffsets.push_back(getValueOrCreateCastToIndexLike( - builder, loc, builder.getI32Type(), offset)); + if (mlir::triton::gpu::intel::isConstant(stride, 0)) { + newOffsets.push_back(getValueOrCreateCastToIndexLike( + builder, loc, builder.getI32Type(), offset)); + } else { + auto divOffset = builder.create( + loc, builder.getI32Type(), + getValueOrCreateCastToIndexLike(builder, loc, builder.getI32Type(), + offset), + getValueOrCreateCastToIndexLike(builder, loc, builder.getI32Type(), + stride)); + newOffsets.push_back(divOffset); + } newStrides.push_back(getValueOrCreateCastToIndexLike( builder, loc, builder.getI64Type(), stride)); newShape.push_back(getValueOrCreateCastToIndexLike( @@ -524,8 +499,10 @@ struct TritonRaiseBlockPointer if (state.getRank() != 0) { OpBuilder::InsertionGuard guard(builder); builder.setInsertionPointToStart(&newOp.getRegion().front()); - auto maketptrOp = state.createTTMakeTensorPtrOp(builder, op.getLoc()); - ptrMap.map(key, maketptrOp.getResult()); + triton::MakeTensorPtrOp makePtrOp = + state.createTTMakeTensorPtrOp(builder, op.getLoc()); + ptrMap.map(key, makePtrOp.getResult()); + 
knownPtrs[makePtrOp.getResult()] = std::move(state); } } @@ -693,8 +670,10 @@ struct TritonRaiseBlockPointer Value result = op.getResult(); Value mapped = result; if (isa(result.getType())) { - Value maketptrOp = state.createTTMakeTensorPtrOp(builder, loc); - mapped = maketptrOp; + triton::MakeTensorPtrOp makePtrOp = + state.createTTMakeTensorPtrOp(builder, loc); + knownPtrs[makePtrOp.getResult()] = std::move(state); + mapped = makePtrOp.getResult(); } ptrMap.map(result, mapped); @@ -711,6 +690,13 @@ struct TritonRaiseBlockPointer OpBuilder &builder, bool addedByPass = false) { assert(state.isEmpty() && "state is a return argument"); + + if (auto iter = knownPtrs.find(makeTPtrOp.getResult()); + iter != knownPtrs.end()) { + state = iter->second; + return success(); + } + state.source = makeTPtrOp.getBase(); auto resType = cast(makeTPtrOp.getResult().getType()); @@ -719,9 +705,18 @@ struct TritonRaiseBlockPointer for (int64_t i = 0; i < pointeeType.getRank(); i++) { state.sizes.push_back(shape[i]); + + auto strideCst = builder.create( + loc, builder.getIndexType(), makeTPtrOp.getStrides()[i]); + auto offsetCst = builder.create( + loc, builder.getIndexType(), makeTPtrOp.getOffsets()[i]); + auto scaledOffset = builder.create( + loc, offsetCst.getResult(), strideCst.getResult()); + state.offsets.push_back(getValueOrCreateCastToIndexLike( + builder, loc, builder.getIntegerType(offsetBitwidth), + scaledOffset.getResult())); } state.strides = makeTPtrOp.getStrides(); - state.offsets = makeTPtrOp.getOffsets(); state.shape = makeTPtrOp.getShape(); state.order = SmallVector(makeTPtrOp.getOrder()); @@ -858,11 +853,21 @@ struct TritonRaiseBlockPointer return success(); } + SmallVector boundary; + if (auto iter = knownPtrs.find(ptr); iter != knownPtrs.end()) { + auto state = iter->second; + for (int axis = 0; axis < state.shape.size(); ++axis) { + if (!mlir::triton::gpu::intel::isConstant(state.shape[axis], 0)) + boundary.push_back(axis); + } + } + ArrayRef newBoundaryCheck(boundary); + OpBuilder builder(op); if constexpr (isLoad) { auto loadOp = builder.create( - op.getLoc(), ptr, op.getBoundaryCheck(), op.getPadding(), - op.getCache(), op.getEvict(), op.getIsVolatile()); + op.getLoc(), ptr, newBoundaryCheck, op.getPadding(), op.getCache(), + op.getEvict(), op.getIsVolatile()); LLVM_DEBUG(llvm::dbgs() << "creating tt.load: " << loadOp << "\n";); @@ -1094,13 +1099,19 @@ LogicalResult TritonRaiseBlockPointer::visitAddPointerOperand( auto resultType = cast(op.getResult().getType()); Value offset = convertScalarToDtype(builder, loc, state.scalar, offsetType, /*isUnsignedCast=*/true); + state.offsets.push_back(offset); + state.offsets.insert( + state.offsets.end(), resultType.getShape().size() - 1, + builder.create(loc, 0, offsetBitwidth)); + state.strides.insert( + state.strides.end(), resultType.getShape().size(), + builder.create(loc, 0, shapeAndStridesBitwidth)); + state.shape.insert( + state.shape.end(), resultType.getShape().size(), + builder.create(loc, 0, shapeAndStridesBitwidth)); + for (int32_t dim : resultType.getShape()) { - state.offsets.push_back(offset); state.sizes.push_back(dim); - state.strides.push_back( - builder.create(loc, 0, shapeAndStridesBitwidth)); - state.shape.push_back( - builder.create(loc, 0, shapeAndStridesBitwidth)); } return success();
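
Note on the first bullet of the commit message: the pointer-arithmetic state accumulated from `tt.addptr` keeps each dimension's offset already multiplied by that dimension's stride, while `tt.make_tensor_ptr` expects a per-dimension element index, so the pass now divides every accumulated offset by its stride (the new `arith.trunci`/`arith.divui` pairs in the CHECK lines) and asks for a boundary check on every axis whose wrap-around (modulo) shape is non-zero. The standalone C++ sketch below only mirrors that arithmetic; `PtrStateSketch`, `toBlockOffsets`, and `boundaryCheckAxes` are hypothetical names, not part of the patch or of MLIR.

// Illustrative sketch only -- not part of the patch.
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for the per-dimension pointer state tracked by the pass.
struct PtrStateSketch {
  std::vector<int64_t> offsets; // accumulated offsets, already stride-multiplied
  std::vector<int64_t> strides; // strides, in elements
  std::vector<int64_t> shape;   // wrap-around (modulo) sizes; 0 means "no wrap"
};

// Block-pointer offsets are expressed per dimension, so divide each accumulated
// offset by its stride (a zero stride keeps the raw offset, as in the pass).
std::vector<int64_t> toBlockOffsets(const PtrStateSketch &s) {
  std::vector<int64_t> out;
  for (size_t i = 0; i < s.offsets.size(); ++i)
    out.push_back(s.strides[i] == 0 ? s.offsets[i] : s.offsets[i] / s.strides[i]);
  return out;
}

// Axes whose modulo/shape entry is non-zero wrap around, so the generated
// tt.load needs a boundary check on them.
std::vector<int32_t> boundaryCheckAxes(const PtrStateSketch &s) {
  std::vector<int32_t> axes;
  for (size_t i = 0; i < s.shape.size(); ++i)
    if (s.shape[i] != 0)
      axes.push_back(static_cast<int32_t>(i));
  return axes;
}

int main() {
  // Row-major 6x4 view: stride 4 along dim 0, stride 1 along dim 1,
  // wrap-around only along dim 1.
  PtrStateSketch s{{8, 2}, {4, 1}, {0, 4}};
  for (int64_t o : toBlockOffsets(s))
    std::cout << o << ' '; // prints "2 2"
  std::cout << '\n';
  for (int32_t a : boundaryCheckAxes(s))
    std::cout << a << ' '; // prints "1"
  std::cout << '\n';
}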
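
Note on the shared constant-detection helper moved into Utility.cpp: as its in-code comment explains, `Operation::fold()` can "succeed" by only updating the operation in place (for example by reordering operands) while producing no results, so the helper retries the fold once before concluding that no constant is available. The sketch below reproduces only that control flow with a mock op type; `MockOp` and `getFoldedConstantValueSketch` are made-up names and this is not the real MLIR API.

// Illustrative control-flow sketch only -- not real MLIR code.
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

// Mock op that folds to a constant only after a first in-place canonicalization.
struct MockOp {
  bool canonicalized = false;
  // Returns true on "success"; may leave `results` empty after an in-place change.
  bool fold(std::vector<int64_t> &results) {
    if (!canonicalized) {
      canonicalized = true; // first call: operands reordered, no result yet
      return true;
    }
    results.push_back(42); // second call: now folds to a constant
    return true;
  }
};

std::optional<int64_t> getFoldedConstantValueSketch(MockOp &op) {
  std::vector<int64_t> results;
  if (!op.fold(results))
    return std::nullopt;
  // Fold "succeeded" but produced nothing: the op was only updated in place,
  // so give it a second try before giving up.
  if (results.empty() && !op.fold(results))
    return std::nullopt;
  if (results.size() != 1)
    return std::nullopt;
  return results.front();
}

int main() {
  MockOp op;
  if (auto v = getFoldedConstantValueSketch(op))
    std::cout << *v << '\n'; // prints 42
}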