From c58cfe7f26a4bef67c97876271fcb8fd5dcc0783 Mon Sep 17 00:00:00 2001
From: "Lu, Chengjun"
Date: Tue, 10 Sep 2024 15:28:22 +0000
Subject: [PATCH 1/3] Disable RewriteTensorPointer pass.

---
 third_party/intel/backend/compiler.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py
index 9c1d25b2c..7cba3e3c1 100644
--- a/third_party/intel/backend/compiler.py
+++ b/third_party/intel/backend/compiler.py
@@ -216,7 +216,7 @@ def make_ttgir(mod, metadata, opt, properties):
         intel.passes.ttgpuir.add_accelerate_matmul(pm)
         intel.passes.ttgpuir.add_remove_layout_conversions(pm)
         intel.passes.ttgpuir.add_materialize_block_pointer(pm)
-        intel.passes.ttgpuir.add_rewrite_tensor_pointer(pm)
+        # intel.passes.ttgpuir.add_rewrite_tensor_pointer(pm)
         intel.passes.ttgpuir.add_pipeline(pm, opt.num_stages, False)
         passes.ttgpuir.add_coalesce(pm)

From 48494041da3be88468e7b584fcf63ca0f3d2294f Mon Sep 17 00:00:00 2001
From: Whitney Tsang
Date: Thu, 26 Sep 2024 14:29:21 +0000
Subject: [PATCH 2/3] Remove RewriteTensorPointer pass

Signed-off-by: Whitney Tsang
---
 .../rewrite-tensor-pointer.mlir               | 276 -------
 third_party/intel/backend/compiler.py         |   1 -
 .../TritonIntelGPU/Transforms/Passes.td       |  13 -
 .../TritonIntelGPUTransforms/CMakeLists.txt   |   1 -
 .../RewriteTensorPointer.cpp                  | 751 ------------------
 third_party/intel/triton_xpu.cc               |   2 -
 6 files changed, 1044 deletions(-)
 delete mode 100644 test/TritonIntelGPU/rewrite-tensor-pointer.mlir
 delete mode 100644 third_party/intel/lib/TritonIntelGPUTransforms/RewriteTensorPointer.cpp

diff --git a/test/TritonIntelGPU/rewrite-tensor-pointer.mlir b/test/TritonIntelGPU/rewrite-tensor-pointer.mlir
deleted file mode 100644
index fd3551090..000000000
--- a/test/TritonIntelGPU/rewrite-tensor-pointer.mlir
+++ /dev/null
@@ -1,276 +0,0 @@
-// RUN: triton-opt %s -split-input-file -tritonintelgpu-rewrite-tensor-pointer | FileCheck %s
-
-// COM: Case 0:
-// COM: Check that operations using block pointers satisfying the following conditions are not rewritten:
-// COM:   - the block pointer has the "dot" layout attribute (with dpas parent layout) or has a dpas layout (for store op)
-// COM:   - the block pointers is advanced in row major order: strides[1] == 1
-// COM:   - the block pointer pitch is divisible by QW: strides[0] % (64 / elemTypeBitWidth) == 0
-// CHECK: #[[DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [16, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
-#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [16, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
-#dot0 = #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth=2}>
-#dot1 = #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}>
-module attributes {"triton_gpu.num-warps" = 64 : i32, "triton_gpu.threads-per-warp" = 16 : i32, "triton_intel_gpu.support_sg_2d_block"} {
-  tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, %arg8: i32 {tt.divisibility = 16 : i32}) {
-    // CHECK:
@matmul_kernel_with_block_pointers - %c4_i32 = arith.constant 4 : i32 - %c256_i32 = arith.constant 256 : i32 - %c1_i64 = arith.constant 1 : i64 - %c0_i32 = arith.constant 0 : i32 - %c32_i32 = arith.constant 32 : i32 - %c255_i32 = arith.constant 255 : i32 - %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #dpas> - %0 = tt.get_program_id x : i32 - %1 = arith.addi %arg3, %c255_i32 : i32 - %2 = arith.divsi %1, %c256_i32 : i32 - %3 = arith.addi %arg4, %c255_i32 : i32 - %4 = arith.divsi %3, %c256_i32 : i32 - %5 = arith.muli %4, %c4_i32 : i32 - %6 = arith.divsi %0, %5 : i32 - %7 = arith.muli %6, %c4_i32 : i32 - %8 = arith.subi %2, %7 : i32 - %9 = arith.minsi %8, %c4_i32 : i32 - %10 = arith.remsi %0, %9 : i32 - %11 = arith.addi %7, %10 : i32 - %12 = arith.remsi %0, %5 : i32 - %13 = arith.divsi %12, %9 : i32 - %14 = arith.muli %11, %c256_i32 : i32 - %15 = arith.extsi %arg3 : i32 to i64 - %16 = arith.extsi %arg5 : i32 to i64 - %17 = arith.extsi %arg6 : i32 to i64 - // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array} : >> - %18 = tt.make_tensor_ptr %arg0, [%15, %16], [%17, %c1_i64], [%14, %c0_i32] {order = array} : > - %19 = arith.muli %13, %c256_i32 : i32 - %20 = arith.extsi %arg4 : i32 to i64 - %21 = arith.extsi %arg7 : i32 to i64 - // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array} : >> - %22 = tt.make_tensor_ptr %arg1, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array} : > - %23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<256x256xf32, #dpas>, !tt.ptr>, !tt.ptr>) : i32 { - // CHECK: tt.load {{.*}} {boundaryCheck = array} : !tt.ptr>> - // CHECK: tt.load {{.*}} {boundaryCheck = array} : !tt.ptr>> - %28 = tt.load %arg11 {boundaryCheck = array} : !tt.ptr> - %29 = tt.load %arg12 {boundaryCheck = array} : !tt.ptr> - // CHECK: tt.dot {{.*}}, {{.*}}, {{.*}}, inputPrecision = tf32 : tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>> * tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>> -> tensor<256x256xf32, #[[DPAS]]> - // CHECK: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : >> - // CHECK: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : >> - %30 = tt.dot %28, %29, %arg10, inputPrecision = tf32 : tensor<256x32xf16, #dot0> * tensor<32x256xf16, #dot1> -> tensor<256x256xf32, #dpas> - %31 = tt.advance %arg11, [%c0_i32, %c32_i32] : > - %32 = tt.advance %arg12, [%c32_i32, %c0_i32] : > - scf.yield %30, %31, %32 : tensor<256x256xf32, #dpas>, !tt.ptr>, !tt.ptr> - } - %24 = arith.truncf %23#0 : tensor<256x256xf32, #dpas> to tensor<256x256xf16, #dpas> - %26 = arith.extsi %arg8 : i32 to i64 - // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array} : > - %27 = tt.make_tensor_ptr %arg2, [%15, %20], [%26, %c1_i64], [%14, %19] {order = array} : > - // CHECK: tt.store {{.*}}, {{.*}} {boundaryCheck = array} : !tt.ptr> - tt.store %27, %24 {boundaryCheck = array} : !tt.ptr> - tt.return - } -} - -// ----- - -// COM: Case 1: -// COM: Check that operations using block pointers satisfying the following conditions are not rewritten: -// COM: - the block pointer has the "dot" layout attribute (with dpas parent layout) -// COM: - the block pointers is advanced in row major order: strides[order[0]] == 1 -// COM: - the block pointer pitch is divisible by QW: strides[order[1]] % (64 / 
elemTypeBitWidth) == 0 -// COM: Check that store operations using block pointers with non Dpas layout is rewritten -// CHECK: #[[BLOCKED:.+]] = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [4, 16], order = [1, 0]}> -// CHECK: #[[DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [16, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [4, 16], order = [1, 0]}> -#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [16, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}> -#dot0 = #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth=2}> -#dot1 = #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}> -module attributes {"triton_gpu.num-warps" = 64 : i32, "triton_gpu.threads-per-warp" = 16 : i32, "triton_intel_gpu.support_sg_2d_block"} { - tt.func public @matmul_kernel_with_block_pointers(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, - %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, - %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32 {tt.divisibility = 16 : i32}, %arg7: i32 {tt.divisibility = 16 : i32}, - %arg8: i32 {tt.divisibility = 16 : i32}) { - // CHECK: @matmul_kernel_with_block_pointers - %c4_i32 = arith.constant 4 : i32 - %c256_i32 = arith.constant 256 : i32 - %c1_i64 = arith.constant 1 : i64 - %c0_i32 = arith.constant 0 : i32 - %c32_i32 = arith.constant 32 : i32 - %c255_i32 = arith.constant 255 : i32 - %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #dpas> - %0 = tt.get_program_id x : i32 - %1 = arith.addi %arg3, %c255_i32 : i32 - %2 = arith.divsi %1, %c256_i32 : i32 - %3 = arith.addi %arg4, %c255_i32 : i32 - %4 = arith.divsi %3, %c256_i32 : i32 - %5 = arith.muli %4, %c4_i32 : i32 - %6 = arith.divsi %0, %5 : i32 - %7 = arith.muli %6, %c4_i32 : i32 - %8 = arith.subi %2, %7 : i32 - %9 = arith.minsi %8, %c4_i32 : i32 - %10 = arith.remsi %0, %9 : i32 - %11 = arith.addi %7, %10 : i32 - %12 = arith.remsi %0, %5 : i32 - %13 = arith.divsi %12, %9 : i32 - %14 = arith.muli %11, %c256_i32 : i32 - %15 = arith.extsi %arg3 : i32 to i64 - %16 = arith.extsi %arg5 : i32 to i64 - %17 = arith.extsi %arg6 : i32 to i64 - // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array} : >> - %18 = tt.make_tensor_ptr %arg0, [%15, %16], [%17, %c1_i64], [%14, %c0_i32] {order = array} : > - %19 = arith.muli %13, %c256_i32 : i32 - %20 = arith.extsi %arg4 : i32 to i64 - %21 = arith.extsi %arg7 : i32 to i64 - // CHECK: tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array} : >> - %22 = tt.make_tensor_ptr %arg1, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array} : > - %23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<256x256xf32, #dpas>, !tt.ptr>, !tt.ptr>) : i32 { - // CHECK: tt.load {{.*}} {boundaryCheck = array} : !tt.ptr>> - // CHECK: tt.load {{.*}} {boundaryCheck = array} : !tt.ptr>> - %28 = tt.load %arg11 {boundaryCheck = array} : !tt.ptr> - %29 = tt.load %arg12 {boundaryCheck = array} : !tt.ptr> - // CHECK: tt.dot {{.*}}, {{.*}}, {{.*}}, 
inputPrecision = tf32 : tensor<256x32xf16, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>> * tensor<32x256xf16, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>> -> tensor<256x256xf32, #[[DPAS]]> - // CHECK: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : >> - // CHECK: tt.advance {{.*}}, {{\[}}{{.*}}, {{.*}}] : >> - %30 = tt.dot %28, %29, %arg10, inputPrecision = tf32 : tensor<256x32xf16, #dot0> * tensor<32x256xf16, #dot1> -> tensor<256x256xf32, #dpas> - %31 = tt.advance %arg11, [%c0_i32, %c32_i32] : > - %32 = tt.advance %arg12, [%c32_i32, %c0_i32] : > - scf.yield %30, %31, %32 : tensor<256x256xf32, #dpas>, !tt.ptr>, !tt.ptr> - } - %24 = arith.truncf %23#0 : tensor<256x256xf32, #dpas> to tensor<256x256xf16, #dpas> - %25 = triton_gpu.convert_layout %24 : tensor<256x256xf16, #dpas> -> tensor<256x256xf16, #blocked> - %26 = arith.extsi %arg8 : i32 to i64 - // CHECK-NOT: tt.make_tensor_ptr - %27 = tt.make_tensor_ptr %arg2, [%15, %20], [%26, %c1_i64], [%14, %19] {order = array} : > - // CHECK: tt.store {{.*}}, {{.*}}, {{.*}} : tensor<256x256x!tt.ptr, #[[BLOCKED]]> - tt.store %27, %25 {boundaryCheck = array} : !tt.ptr> - tt.return - } -} - -// ----- - -// COM: Case 2: -// COM: Check that operations using block pointers without divisibility attribute are rewritten to use a legacy pointer. -// CHECK: #[[DPAS:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [16, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}> -#blocked = #triton_gpu.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [4, 16], order = [1, 0]}> -#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [16, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}> -#dot0 = #triton_gpu.dot_op<{opIdx = 0, parent = #dpas, kWidth=2}> -#dot1 = #triton_gpu.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}> -module attributes {"triton_gpu.num-warps" = 64 : i32, "triton_gpu.threads-per-warp" = 16 : i32, "triton_intel_gpu.support_sg_2d_block"} { - tt.func public @matmul_kernel_with_block_pointers_indivisible(%arg0: !tt.ptr {tt.divisibility = 16 : i32}, %arg1: !tt.ptr {tt.divisibility = 16 : i32}, %arg2: !tt.ptr {tt.divisibility = 16 : i32}, %arg3: i32 {tt.divisibility = 16 : i32}, %arg4: i32 {tt.divisibility = 16 : i32}, %arg5: i32 {tt.divisibility = 16 : i32}, %arg6: i32, %arg7: i32, %arg8: i32 {tt.divisibility = 16 : i32}) { - // CHECK: @matmul_kernel_with_block_pointers_indivisible - %c4_i32 = arith.constant 4 : i32 - %c256_i32 = arith.constant 256 : i32 - %c1_i64 = arith.constant 1 : i64 - %c0_i32 = arith.constant 0 : i32 - %c32_i32 = arith.constant 32 : i32 - %c255_i32 = arith.constant 255 : i32 - %cst = arith.constant dense<0.000000e+00> : tensor<256x256xf32, #dpas> - %0 = tt.get_program_id x : i32 - %1 = arith.addi %arg3, %c255_i32 : i32 - %2 = arith.divsi %1, %c256_i32 : i32 - %3 = arith.addi %arg4, %c255_i32 : i32 - %4 = arith.divsi %3, %c256_i32 : i32 - %5 = arith.muli %4, %c4_i32 : i32 - %6 = arith.divsi %0, %5 : i32 - %7 = arith.muli %6, %c4_i32 : i32 - %8 = arith.subi %2, %7 : i32 - %9 = arith.minsi %8, %c4_i32 : i32 - %10 = arith.remsi %0, %9 : i32 - %11 = arith.addi %7, %10 : i32 - %12 = arith.remsi %0, %5 : i32 - %13 = arith.divsi %12, %9 : i32 - %14 = arith.muli %11, %c256_i32 : i32 - %15 = arith.extsi %arg3 : i32 to i64 - %16 = arith.extsi %arg5 : i32 to i64 - %17 = arith.extsi %arg6 : i32 
to i64 - // CHECK-NOT: tt.make_tensor_ptr - %18 = tt.make_tensor_ptr %arg0, [%15, %16], [%17, %c1_i64], [%14, %c0_i32] {order = array} : > - %19 = arith.muli %13, %c256_i32 : i32 - %20 = arith.extsi %arg4 : i32 to i64 - %21 = arith.extsi %arg7 : i32 to i64 - // CHECK-NOT: tt.make_tensor_ptr - %22 = tt.make_tensor_ptr %arg1, [%16, %20], [%21, %c1_i64], [%c0_i32, %19] {order = array} : > - %23:3 = scf.for %arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<256x256xf32, #dpas>, !tt.ptr>, !tt.ptr>) : i32 { - // CHECK: tt.load {{.*}}, {{.*}} : tensor<256x32x!tt.ptr, #triton_gpu.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 2}>> - // CHECK: tt.load {{.*}}, {{.*}} : tensor<32x256x!tt.ptr, #triton_gpu.dot_op<{opIdx = 1, parent = #[[DPAS]], kWidth = 2}>> - %28 = tt.load %arg11 {boundaryCheck = array} : !tt.ptr> - %29 = tt.load %arg12 {boundaryCheck = array} : !tt.ptr> - %30 = tt.dot %28, %29, %arg10, inputPrecision = tf32 : tensor<256x32xf16, #dot0> * tensor<32x256xf16, #dot1> -> tensor<256x256xf32, #dpas> - // CHECK-NOT: tt.advance - %31 = tt.advance %arg11, [%c0_i32, %c32_i32] : > - // CHECK-NOT: tt.advance - %32 = tt.advance %arg12, [%c32_i32, %c0_i32] : > - scf.yield %30, %31, %32 : tensor<256x256xf32, #dpas>, !tt.ptr>, !tt.ptr> - } - %24 = arith.truncf %23#0 : tensor<256x256xf32, #dpas> to tensor<256x256xf16, #dpas> - %25 = triton_gpu.convert_layout %24 : tensor<256x256xf16, #dpas> -> tensor<256x256xf16, #blocked> - %26 = arith.extsi %arg8 : i32 to i64 - // CHECK-NOT: tt.make_tensor_ptr - %27 = tt.make_tensor_ptr %arg2, [%15, %20], [%26, %c1_i64], [%14, %19] {order = array} : > - // CHECK: tt.store {{.*}}, {{.*}}, {{.*}} : tensor<256x256x!tt.ptr, #[[BLOCKED]]> - tt.store %27, %25 {boundaryCheck = array} : !tt.ptr> - tt.return - } -} - -// ----- - -// COM: Case 3: -// COM: Check that operations using block pointers without a layout attribute are rewritten to use a legacy pointer. 
-module attributes {"triton_intel_gpu.support_sg_2d_block"} { - tt.func public @matmul_kernel(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg3: i32, %arg4: i32, %arg5: i32, %arg6: i32, %arg7: i32) { - %c31_i32 = arith.constant 31 : i32 - %c127_i32 = arith.constant 127 : i32 - %c1 = arith.constant 1 : index - %c0 = arith.constant 0 : index - %cst = arith.constant dense<0.000000e+00> : tensor<128x32xf32> - %c0_i32 = arith.constant 0 : i32 - %c1_i64 = arith.constant 1 : i64 - %c32_i32 = arith.constant 32 : i32 - %c128_i32 = arith.constant 128 : i32 - %c8_i32 = arith.constant 8 : i32 - %0 = tt.get_program_id x : i32 - %1 = tt.get_program_id y : i32 - %2 = arith.addi %arg3, %c127_i32 : i32 - %3 = arith.divsi %2, %c128_i32 : i32 - %4 = arith.addi %arg4, %c31_i32 : i32 - %5 = arith.divsi %4, %c32_i32 : i32 - %6 = arith.muli %5, %c8_i32 : i32 - %7 = arith.divsi %0, %6 : i32 - %8 = arith.muli %7, %c8_i32 : i32 - %9 = arith.subi %3, %8 : i32 - %10 = arith.cmpi slt, %9, %c8_i32 : i32 - %11 = arith.select %10, %9, %c8_i32 : i32 - %12 = arith.remsi %0, %11 : i32 - %13 = arith.addi %8, %12 : i32 - %14 = arith.remsi %0, %6 : i32 - %15 = arith.divsi %14, %11 : i32 - %16 = arith.muli %13, %c128_i32 : i32 - %17 = arith.muli %1, %c32_i32 : i32 - %18 = arith.extsi %arg3 : i32 to i64 - %19 = arith.extsi %arg5 : i32 to i64 - %20 = arith.extsi %arg6 : i32 to i64 - // CHECK-NOT: tt.make_tensor_ptr - %21 = tt.make_tensor_ptr %arg0, [%18, %19], [%20, %c1_i64], [%16, %17] {order = array} : !tt.ptr> - %22 = arith.muli %15, %c32_i32 : i32 - %23 = arith.extsi %arg4 : i32 to i64 - %24 = arith.extsi %arg7 : i32 to i64 - // CHECK-NOT: tt.make_tensor_ptr - %25 = tt.make_tensor_ptr %arg1, [%19, %23], [%24, %c1_i64], [%17, %22] {order = array} : !tt.ptr> - %26 = arith.addi %arg5, %c31_i32 : i32 - %27 = arith.divsi %26, %c32_i32 : i32 - %28 = arith.index_cast %27 : i32 to index - // CHECK: scf.for - %29:3 = scf.for %arg9 = %c0 to %28 step %c1 iter_args(%arg10 = %cst, %arg11 = %21, %arg12 = %25) -> (tensor<128x32xf32>, !tt.ptr>, !tt.ptr>) { - // CHECK: tt.load %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}} : tensor<128x32x!tt.ptr> - %55 = tt.load %arg11 {boundaryCheck = array, padding = 2 : i32} : !tt.ptr> - // CHECK: tt.load %{{[0-9]+}}, %{{[0-9]+}}, %{{[0-9]+}} : tensor<32x32x!tt.ptr> - %56 = tt.load %arg12 {boundaryCheck = array, padding = 2 : i32} : !tt.ptr> - %57 = tt.dot %55, %56, %arg10 : tensor<128x32xf16> * tensor<32x32xf16> -> tensor<128x32xf32> - // CHECK-NOT: tt.advance - %58 = tt.advance %arg11, [%c0_i32, %c32_i32] : !tt.ptr> - // CHECK-NOT: tt.advance - %59 = tt.advance %arg12, [%c32_i32, %c0_i32] : !tt.ptr> - // CHECK: scf.yield - scf.yield %57, %58, %59 : tensor<128x32xf32>, !tt.ptr>, !tt.ptr> - } - tt.return - } -} diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py index 7cba3e3c1..b3cbe25ad 100644 --- a/third_party/intel/backend/compiler.py +++ b/third_party/intel/backend/compiler.py @@ -216,7 +216,6 @@ def make_ttgir(mod, metadata, opt, properties): intel.passes.ttgpuir.add_accelerate_matmul(pm) intel.passes.ttgpuir.add_remove_layout_conversions(pm) intel.passes.ttgpuir.add_materialize_block_pointer(pm) - # intel.passes.ttgpuir.add_rewrite_tensor_pointer(pm) intel.passes.ttgpuir.add_pipeline(pm, opt.num_stages, False) passes.ttgpuir.add_coalesce(pm) diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td index 42e386fe2..a7985536a 100644 --- 
a/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td
+++ b/third_party/intel/include/Dialect/TritonIntelGPU/Transforms/Passes.td
@@ -158,19 +158,6 @@ def TritonIntelGPURemoveLayoutConversions : Pass<"tritonintelgpu-remove-layout-c
 }
-def TritonIntelGPURewriteTensorPointer : Pass<"tritonintelgpu-rewrite-tensor-pointer", "mlir::ModuleOp"> {
-  let summary = "Rewrite load/store operations using tensor pointers that cannot be lowered to 2D Block Load/Store intrinsics";
-  let description = [{
-    This pass determines whether a load/store operation can be lowered to 2D
-    Block Load/Store intrinsic. If it cannot, it replaces the load/store
-    operation with a legacy pointer and removes the Triton operations that
-    create and advance the block pointer (that is `tt.make_tensor_tr` and
-    `tt.advance`).
-  }];
-
-  let dependentDialects = ["mlir::triton::TritonDialect"];
-}
-
 def TritonIntelGPUPrefetchBlock : Pass<"tritonintelgpu-prefetch-block", "mlir::ModuleOp"> {
   let summary = "Prefetch a tensor block around loop";
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt b/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt
index 8c2e290ad..24a76a41e 100644
--- a/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt
+++ b/third_party/intel/lib/TritonIntelGPUTransforms/CMakeLists.txt
@@ -8,7 +8,6 @@ add_triton_library(TritonIntelGPUTransforms
   PrefetchBlock.cpp
   ReduceDataDuplication.cpp
   RemoveLayoutConversions.cpp
-  RewriteTensorPointer.cpp
   ScheduleLoad.cpp
   Utility.cpp
diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/RewriteTensorPointer.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/RewriteTensorPointer.cpp
deleted file mode 100644
index ecfa0f465..000000000
--- a/third_party/intel/lib/TritonIntelGPUTransforms/RewriteTensorPointer.cpp
+++ /dev/null
@@ -1,751 +0,0 @@
-#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
-#include "mlir/Dialect/SCF/IR/SCF.h"
-#include "triton/Analysis/Utility.h"
-#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"
-#include "triton/Dialect/Triton/IR/Dialect.h"
-
-#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
-#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h"
-#include "intel/include/Dialect/TritonIntelGPU/Transforms/Utility.h"
-
-#include "llvm/ADT/TypeSwitch.h"
-#include "llvm/Support/Debug.h"
-#include <stack>
-
-using namespace mlir;
-namespace tt = mlir::triton;
-namespace ttg = mlir::triton::gpu;
-namespace ttgi = mlir::triton::gpu::intel;
-
-namespace mlir::triton::gpu::intel {
-#define GEN_PASS_DEF_TRITONINTELGPUREWRITETENSORPOINTER
-#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h.inc"
-} // namespace mlir::triton::gpu::intel
-
-#define DEBUG_TYPE "tritonintelgpu-rewrite-tensor-pointer"
-
-namespace {
-
-/// Check if the tensor pointer should be removed.
The tensor pointer should be -/// removed if: -/// - the tensor pointer does not have DotEncoding with DpasEncoding parent -/// and does not have DpasEncoding -/// - the tensor pointer pitch is not divisible by Qword bitwidth -/// - the tensor pointer is not contiguous on memory -bool shouldRemove(tt::MakeTensorPtrOp &op, bool isUsedByStoreOp) { - if (!op->getParentOfType()->hasAttr( - ttgi::TritonIntelGPUDialect::getSupportSG2DBlockAttrName())) - return true; - - auto ptrType = cast(op.getType()); - auto tensorType = cast(ptrType.getPointeeType()); - - if (!ttgi::hasDotDpasEncoding(tensorType) && - !(isUsedByStoreOp && ttgi::hasDpasEncoding(tensorType))) - return true; - - TypedValue base = op.getBase(); - Operation::operand_range shape = op.getShape(); - Operation::operand_range strides = op.getStrides(); - Operation::operand_range offsets = op.getOffsets(); - ArrayRef order = op.getOrder(); - ArrayRef tensorShape = tensorType.getShape(); - - // TODO: support column-major tensor - // HW 2D block read instruction has restriction on pitch divisibility - if (strides.size() == 2) { - auto pitch = strides[0]; - // Across Intel platforms, the strictest pitch restriction is to be a - // multiple of OWord(128 bits). - if (!ttgi::isDivisible(pitch, 128 / tensorType.getElementTypeBitWidth())) - return true; - } - - // HW 2D block read instruction only supports contiguous accessing. - auto fastChangeStride = strides[1]; - if (auto stride = fastChangeStride.getDefiningOp()) { - if (auto strideInt = dyn_cast(stride.getValue())) - return strideInt.getInt() != 1; - } - - return true; -} - -/// The `RewritedInfo` struct is used to store information about a rewritten -/// tensor pointer. It holds the base pointer, shape, strides, offsets, and -/// encoding of the tensor. This information is used later in the code to handle -/// the rewritten tensor pointer. -struct RewritedInfo { - RewritedInfo() = default; - - RewritedInfo(Value base, const SmallVector &shape, - const SmallVector &strides, - const SmallVector &offsets, - const ArrayRef &tensorShape, Attribute layout) - : base(base), shape(shape), strides(strides), offsets(offsets), - tensorShape(tensorShape), layout(layout) { - assert(shape.size() == strides.size() && shape.size() == offsets.size() && - shape.size() == tensorShape.size() && - "Expecting tensor shape, offsets and strides have the same size"); - } - - unsigned int length() const { return shape.size(); } - - Value getOffset(unsigned i) const { return offsets[i]; } - - SmallVector getOffsets() const { return offsets; } - - void setOffset(unsigned i, Value newOffset) { - offsets[i] = newOffset; - cachedOffsetWithRange.clear(); - } - - void setOffsets(const SmallVector &newOffsets) { - offsets = newOffsets; - cachedOffsetWithRange.clear(); - } - - void setEncoding(Attribute newLayout) { layout = newLayout; } - - // Creates a tensor with the values [0, tensorShape[axis]) + offsets[axis] - // broadcasted to N dimensions along axis (i.e. so that - // result[.., i, ...] = offsets[axis] + i). - Value getExpandedOffsetWithRange(OpBuilder &builder, Location loc, - unsigned i) { - if (cachedOffsetWithRange.count(i)) - return cachedOffsetWithRange.at(i); - - // Ultimately this will look like: - // - // % base = create_range ... : tensor - // %a0 = expand_dims %base : tensor - // %a1 = broadcast %a0 : tensor - // %b0 = expand_dims %a1 : tensor - // %b1 = broadcast %b1 : tensor - // ... - // - // The final result has layout this->layout. 
When we subtract a dim, - // that's equivalent to taking a sliced layout, so e.g. the layout of - // %a0/%a1 is a slice of %b0/%b1's layout. - size_t rank = tensorShape.size(); - MLIRContext *ctx = loc.getContext(); - - // This code is creating a vector of layout attributes for a tensor. If a - // layout is provided, it sets the layout of each axis based on the layout - // of the previous axis, starting from the last axis and moving towards the - // first. If the current axis is the one to remove, it skips it and moves to - // the previous axis. - SmallVector layouts(rank); - if (layout) { - layouts[rank - 1] = layout; - size_t axisToRemove = rank - 1; - for (int64_t k = rank - 2; k >= 0; --k) { - if (axisToRemove == i) - --axisToRemove; - layouts[k] = - ttg::SliceEncodingAttr::get(ctx, axisToRemove, layouts[k + 1]); - --axisToRemove; - } - } - - // Add range - auto indexI32RowType = RankedTensorType::get( - {tensorShape[i]}, builder.getI32Type(), layouts[0]); - auto indexRowType = RankedTensorType::get({tensorShape[i]}, - builder.getI64Type(), layouts[0]); - Value splatOffset = - builder.create(loc, indexRowType, offsets[i]); - Value range = builder.create(loc, indexI32RowType, 0, - tensorShape[i]); - Value i64Range = builder.create(loc, indexRowType, range); - - // Expand dimensions - Value expandedResult = - builder.create(loc, splatOffset, i64Range); - for (int j = 0; j < tensorShape.size(); ++j) { - if (j == i) - continue; - expandedResult = builder.create(loc, expandedResult, j); - } - - return cachedOffsetWithRange[i] = expandedResult; - } - - Value generatePtr(OpBuilder &builder, const Location &loc) { - assert(tensorShape.size() == offsets.size() && - tensorShape.size() == strides.size() && - "Expecting tensor shape, offsets and strides have the same size"); - auto indexTensorType = - RankedTensorType::get(tensorShape, builder.getI64Type(), layout); - auto ptrType = cast(base.getType()); - auto ptrTensorType = RankedTensorType::get(tensorShape, ptrType, layout); - - // Generate offsets per dimension - Value ptr = builder.create(loc, ptrTensorType, base); - for (unsigned i = 0; i < tensorShape.size(); ++i) { - auto offsetWithRange = getExpandedOffsetWithRange(builder, loc, i); - - // We must splat strides into the expanded shape not a row for retaining - // the divisibility information given by strides - Value splatStride = builder.create( - loc, offsetWithRange.getType(), strides[i]); - Value offsetWithStride = - builder.create(loc, offsetWithRange, splatStride); - Value broadcasted = builder.create(loc, indexTensorType, - offsetWithStride); - - // Add to the pointer - ptr = builder.create(loc, ptrTensorType, ptr, broadcasted); - } - - return ptr; - } - - Value generateMask(OpBuilder &builder, const Location &loc, - const std::optional> &boundaryCheck) { - if (!boundaryCheck.has_value()) - return {}; - - // Generate mask per dimension - auto maskTensorType = - RankedTensorType::get(tensorShape, builder.getI1Type(), layout); - Value mask; - for (auto i : boundaryCheck.value()) { - auto offsetWithRange = getExpandedOffsetWithRange(builder, loc, i); - - // Compare with lower bound - Value lowerBound = - builder.create(loc, 0, builder.getI64Type()); - Value splatLowerBound = builder.create( - loc, offsetWithRange.getType(), lowerBound); - Value cmpLower = builder.create( - loc, arith::CmpIPredicate::sge, offsetWithRange, splatLowerBound); - - // Compare with upper bound - Value splatUpperBound = - builder.create(loc, offsetWithRange.getType(), shape[i]); - Value cmpUpper = 
builder.create( - loc, arith::CmpIPredicate::slt, offsetWithRange, splatUpperBound); - - // And and broadcast - Value andResult = builder.create(loc, cmpLower, cmpUpper); - Value broadcasted = - builder.create(loc, maskTensorType, andResult); - - // And up all results - if (!mask) { - mask = broadcasted; - } else { - mask = builder.create(loc, mask, broadcasted); - } - } - - return mask; - } - - Value generateOther(OpBuilder &builder, const Location &loc, - const std::optional &padding) const { - if (!padding.has_value()) - return Value(); - - // Create element attribute - auto elementType = cast(base.getType()).getPointeeType(); - auto otherTensorType = - RankedTensorType::get(tensorShape, elementType, layout); - - // Set zero padding value - TypedAttr attr = - elementType.isIntOrIndex() - ? cast(builder.getIntegerAttr(elementType, 0)) - : cast(builder.getFloatAttr(elementType, 0)); - - // Float NaN padding case - if (padding.value() == tt::PaddingOption::PAD_NAN) { - assert(!elementType.isIntOrIndex() && - "Expect element type to be non-integer type"); - auto apNaN = llvm::APFloat::getNaN( - cast(attr).getValue().getSemantics()); - attr = builder.getFloatAttr(elementType, apNaN); - } - - // Create tensor - Value constant = builder.create(loc, attr); - return builder.create(loc, otherTensorType, constant); - } - -private: - Value base; - SmallVector shape; - SmallVector strides; - SmallVector offsets; - ArrayRef tensorShape; - Attribute layout; - - // A cache to avoid generating the same offset with range - DenseMap cachedOffsetWithRange; -}; -} // namespace - -// TODO: this pass relies on assumptions of how block pointers are created and -// on pattern matches that walks the SSA links to find the base/strides. This is -// very fragile and to solve we should expose convert Ptr of tensor to a -// structure contains all values and not only offsets. 
-class TritonIntelGPURewriteTensorPointerPass - : public triton::gpu::intel::impl::TritonIntelGPURewriteTensorPointerBase< - TritonIntelGPURewriteTensorPointerPass> { -private: - DenseMap rewritedInfo; - DenseSet valueToRemove; - -public: - using triton::gpu::intel::impl::TritonIntelGPURewriteTensorPointerBase< - TritonIntelGPURewriteTensorPointerPass>:: - TritonIntelGPURewriteTensorPointerBase; - - static bool needRewrite(Operation *op, const DenseSet &valueToRemove) { - return llvm::any_of(op->getOperands(), [&valueToRemove](Value operand) { - return tt::isTensorPointerType(operand.getType()) && - valueToRemove.count(operand); - }); - } - - static SmallVector - generateNewOperands(const SmallVector &oldOperands, unsigned index, - const SmallVector &newValues) { - assert(index < oldOperands.size() && "Index out of range"); - SmallVector newOperands; - for (int i = 0; i < index; ++i) - newOperands.push_back(oldOperands[i]); - for (auto value : newValues) - newOperands.push_back(value); - for (auto i = index + 1; i < oldOperands.size(); ++i) - newOperands.push_back(oldOperands[i]); - return newOperands; - } - - Operation *rewriteOp(OpBuilder &builder, tt::MakeTensorPtrOp op, - std::stack &eraser) { - if (!valueToRemove.count(op.getResult())) - return nullptr; - - // Save info for later use - auto ptrType = cast(op.getType()); - auto tensorType = cast(ptrType.getPointeeType()); - - // Cast I32 offsets into I64 - SmallVector i64Offsets; - for (auto offset : op.getOffsets()) { - auto i64Offset = builder.create( - op.getLoc(), builder.getI64Type(), offset); - i64Offsets.push_back(i64Offset); - } - - // Save information - rewritedInfo[op.getResult()] = - RewritedInfo(op.getBase(), op.getShape(), op.getStrides(), i64Offsets, - tensorType.getShape(), tensorType.getEncoding()); - - // Erase the original operation - eraser.push(op); - return nullptr; - } - - Operation *rewriteOp(OpBuilder &builder, tt::AdvanceOp op, - std::stack &eraser) { - if (!valueToRemove.count(op.getResult())) - return nullptr; - - // Get info from previous results - assert(rewritedInfo.count(op.getPtr()) && - "Expecting AdvanceOp ptr in rewritedInfo"); - auto info = rewritedInfo[op.getPtr()]; - - // Calculate new offsets - assert(info.length() == op.getOffsets().size() && - "Expecting AdvanceOp ptr shape and offsets have the same size"); - SmallVector newOffsets; - for (int i = 0; i < info.length(); ++i) { - Value i64Offset = builder.create( - op.getLoc(), builder.getI64Type(), op.getOffsets()[i]); - Value newOffset = builder.create( - op.getLoc(), info.getOffset(i), i64Offset); - newOffsets.push_back(newOffset); - } - - // Save info for later use - info.setOffsets(newOffsets); - rewritedInfo[op.getResult()] = std::move(info); - - // Erase the original operation - eraser.push(op); - return nullptr; - } - - Operation *rewriteOp(OpBuilder &builder, tt::LoadOp op, - std::stack &eraser) { - if (!valueToRemove.count(op->getOperand(0))) - return nullptr; - - // Get info from previous results - auto ptr = op->getOperand(0); - assert(rewritedInfo.count(ptr) && "Expecting LoadOp ptr in rewritedInfo"); - auto info = rewritedInfo[ptr]; - - assert(!op.getMask() && !op.getOther() && - "LoadOp with tensor pointer should not have mask and other"); - std::optional> boundaryCheck = op.getBoundaryCheck(); - if (auto valueType = dyn_cast(op.getResult().getType())) - info.setEncoding(valueType.getEncoding()); - - // Generate new `ptr`, `mask` and `other` - auto newPtr = info.generatePtr(builder, op->getLoc()); - auto newMask = 
info.generateMask(builder, op->getLoc(), boundaryCheck); - Value newOther = info.generateOther(builder, op->getLoc(), op.getPadding()); - - // Create a new operation - auto newResult = builder.create( - op.getLoc(), newPtr, newMask, newOther, op.getCache(), op.getEvict(), - op.getIsVolatile()); - op->getResult(0).replaceAllUsesWith(newResult); - - // Erase the original operation - eraser.push(op); - return nullptr; - } - - Operation *rewriteOp(OpBuilder &builder, tt::StoreOp op, - std::stack &eraser) { - if (!valueToRemove.count(op->getOperand(0))) - return nullptr; - - // Get info from previous results - auto ptr = op->getOperand(0); - assert(rewritedInfo.count(ptr) && "Expecting StoreOp ptr in rewritedInfo"); - auto info = rewritedInfo[ptr]; - - assert(!op.getMask() && "StoreOp with tensor pointer should not have mask"); - std::optional> boundaryCheck = op.getBoundaryCheck(); - if (auto valueType = dyn_cast(op.getValue().getType())) - info.setEncoding(valueType.getEncoding()); - - // Generate new `ptr`, `mask` and `other` - auto newPtr = info.generatePtr(builder, op->getLoc()); - auto newMask = info.generateMask(builder, op->getLoc(), boundaryCheck); - - // Create a new operation - builder.create(op.getLoc(), newPtr, op.getValue(), newMask, - op.getCache(), op.getEvict()); - - // Erase the original operation - eraser.push(op); - return nullptr; - } - - Operation *rewriteOp(OpBuilder &builder, scf::IfOp op, - std::stack &eraser) { - auto thenYieldOp = op.thenYield(); - assert(op.getNumResults() == thenYieldOp.getNumOperands() && - "Expecting IfOp results and its thenYieldOp operands have the same " - "number"); - SmallVector results = thenYieldOp.getOperands(); - - // get new result types - SmallVector newRetTypes; - bool needRewrite = false; - for (unsigned i = 0; i < results.size(); ++i) { - if (!tt::isTensorPointerType(results[i].getType()) || - !valueToRemove.count(results[i])) { - newRetTypes.push_back(results[i].getType()); - continue; - } - needRewrite = true; - auto makeTensorPtrOp = getMakeTensorPtrOp(results[i]); - assert(rewritedInfo.count(makeTensorPtrOp.getResult()) && - "Expecting MakeTensorPtrOp of IfOp result in rewritedInfo"); - const auto &info = rewritedInfo[makeTensorPtrOp.getResult()]; - for (unsigned j = 0; j < info.length(); ++j) { - newRetTypes.push_back(builder.getI64Type()); - } - } - if (!needRewrite) - return op; - // create and clone new IfOp - bool hasElse = !op.getElseRegion().empty(); - scf::IfOp newOp = builder.create(op.getLoc(), newRetTypes, - op.getCondition(), hasElse); - IRMapping mapping; - for (unsigned i = 0; i < op->getNumOperands(); ++i) { - mapping.map(op->getOperand(i), newOp->getOperand(i)); - } - auto rematerialize = [&](Block *block) { - for (Operation &opInIf : block->getOperations()) { - auto newOp = builder.clone(opInIf, mapping); - } - }; - builder.setInsertionPointToStart(newOp.thenBlock()); - rematerialize(op.thenBlock()); - if (hasElse) { - builder.setInsertionPointToStart(newOp.elseBlock()); - rematerialize(op.elseBlock()); - } - - // supported nested ops - for (auto &[k, v] : mapping.getValueMap()) - if (valueToRemove.find(k) != valueToRemove.end()) - valueToRemove.insert(v); - - // update rewritedInfo - unsigned oldResIdx = 0, newResIdx = 0; - while (oldResIdx < results.size()) { - if (!tt::isTensorPointerType(results[oldResIdx].getType()) || - !valueToRemove.count(results[oldResIdx])) { - oldResIdx++; - newResIdx++; - } else { - auto makeTensorPtrOp = getMakeTensorPtrOp(results[oldResIdx]); - 
assert(rewritedInfo.count(makeTensorPtrOp.getResult()) && - "Expecting MakeTensorPtrOp of IfOp result in rewritedInfo"); - auto info = rewritedInfo[makeTensorPtrOp.getResult()]; - for (unsigned j = 0; j < info.length(); ++j) { - info.setOffset(j, newOp->getResult(newResIdx++)); - } - rewritedInfo[op.getResult(oldResIdx)] = std::move(info); - oldResIdx++; - } - } - - eraser.push(op); - return newOp; - } - - Operation *rewriteOp(OpBuilder &builder, scf::ForOp op, - std::stack &eraser) { - // Generate new iteration operands and set rewrited information - SmallVector oldIterOperands = llvm::to_vector(op.getInitArgs()); - SmallVector newIterOperands = llvm::to_vector(op.getInitArgs()); - for (unsigned i = 0, oldI = 0, size = op.getInitArgs().size(); i < size; - ++i, ++oldI) { - if (!tt::isTensorPointerType(newIterOperands[i].getType())) - continue; - if (!valueToRemove.count(newIterOperands[i])) - continue; - - // Expand the tensor pointer into offsets - assert(rewritedInfo.count(newIterOperands[i]) && - "Expecting ForOp operands in rewritedInfo"); - const RewritedInfo &info = rewritedInfo[newIterOperands[i]]; - newIterOperands = - generateNewOperands(newIterOperands, i, info.getOffsets()); - i += info.length() - 1; - size += info.length() - 1; - } - - // Rebuild the loop type - auto newForOp = builder.create(op.getLoc(), op.getLowerBound(), - op.getUpperBound(), op.getStep(), - newIterOperands); - - // Create value mapping. Note that for tensor pointers, we use identity - // mapping. It may refer to a value in the old loop, but we will rewrite it - // later - IRMapping mapping; - for (unsigned i = 0, oldI = 0, sz = op.getInitArgs().size(); oldI < sz; - ++i, ++oldI) { - auto oldRegionIterArg = op.getRegionIterArg(oldI); - if (tt::isTensorPointerType(oldRegionIterArg.getType()) && - valueToRemove.count(oldIterOperands[oldI])) { - // Pass rewrited info inside - assert(rewritedInfo.count(oldIterOperands[oldI]) && - "Expecting ForOp operands in rewritedInfo"); - auto info = rewritedInfo[oldIterOperands[oldI]]; - mapping.map(oldRegionIterArg, oldRegionIterArg); - for (unsigned j = 0; j < info.length(); ++j) - info.setOffset(j, newForOp.getRegionIterArg(i + j)); - rewritedInfo[oldRegionIterArg] = info; - i += info.length() - 1; - } else { - mapping.map(oldRegionIterArg, newForOp.getRegionIterArg(i)); - } - } - mapping.map(op.getInductionVar(), newForOp.getInductionVar()); - - // Clone body - builder.setInsertionPointToStart(newForOp.getBody()); - for (auto &opInFor : *op.getBody()) { - auto *newOp = builder.clone(opInFor, mapping); - for (unsigned i = 0; i < opInFor.getNumResults(); ++i) { - if (valueToRemove.count(opInFor.getResult(i))) - valueToRemove.insert(newOp->getResult(i)); - mapping.map(op->getResult(i), newOp->getResult(i)); - } - } - - // supported nested scf.for ops - for (auto &[k, v] : mapping.getValueMap()) - if (valueToRemove.find(k) != valueToRemove.end()) - valueToRemove.insert(v); - - // Replace later usages - assert(op.getNumResults() == op.getInitArgs().size() && - "Expecting ForOp results and operands have the same number"); - for (unsigned i = 0, oldI = 0; oldI < op.getNumResults(); ++i, ++oldI) { - auto oldResult = op.getResult(oldI); - if (tt::isTensorPointerType(oldResult.getType()) && - valueToRemove.count(oldIterOperands[oldI])) { - // Pack new offsets into rewrited info - assert(rewritedInfo.count(oldIterOperands[oldI]) && - "Expecting ForOp operands in rewritedInfo"); - auto info = rewritedInfo[oldIterOperands[oldI]]; - for (unsigned j = 0; j < info.length(); ++j) - 
info.setOffset(j, newForOp.getResult(i + j)); - i += info.length() - 1; - rewritedInfo[oldResult] = std::move(info); - } else { - oldResult.replaceAllUsesWith(newForOp.getResult(i)); - } - } - - // Erase later - eraser.push(op); - return newForOp; - } - - Operation *rewriteOp(OpBuilder &builder, scf::YieldOp op, - std::stack &eraser) { - // Replace tensor pointers with offsets - SmallVector newOperands = op->getOperands(); - for (unsigned i = 0, size = op.getNumOperands(); i < size; ++i) { - if (!tt::isTensorPointerType(newOperands[i].getType())) - continue; - if (!valueToRemove.count(newOperands[i])) - continue; - - assert(rewritedInfo.count(newOperands[i]) && - "Expecting YieldOp operands in rewritedInfo"); - const RewritedInfo &info = rewritedInfo[newOperands[i]]; - newOperands = generateNewOperands(newOperands, i, info.getOffsets()); - i += info.length() - 1; - size += info.length() - 1; - } - op->setOperands(newOperands); - - // No need to erase - return nullptr; - } - - Operation *rewriteOp(Operation *op, std::stack &eraser) { - OpBuilder builder(op); - - // Rewrite `make_tensor_ptr`, `advance`, etc... - // Rewriting functions return the next operation to visit, or `nullptr` if - // there isn't one. - return TypeSwitch(op) - .Case( - [&](auto op) { return rewriteOp(builder, op, eraser); }) - .Case([&](auto op) { - return needRewrite(op, valueToRemove) ? rewriteOp(builder, op, eraser) - : op; - }) - .Default([&](Operation *op) { - StringRef opNamespace = op->getDialect()->getNamespace(); - if ((opNamespace == scf::SCFDialect::getDialectNamespace() || - opNamespace == cf::ControlFlowDialect::getDialectNamespace()) && - needRewrite(op, valueToRemove)) - llvm_unreachable( - "Currently we only support tensor pointer usages " - "inside a `scf::ForOp` or `scf::IfOp`, others such as " - "`scf::WhileOp`, `cf::BranchOp` or `cf::CondBranchOp` " - "are not supported yet"); - - return op; - }); - } - - void visitOperation(Operation *op, std::stack &eraser) { - for (auto ®ion : op->getRegions()) { - for (auto &block : region) { - // We need an extra copy because erasing operations may break the - // iterator behavior - SmallVector blockCopy; - for (auto &nestedOp : block) - blockCopy.push_back(&nestedOp); - - // Rewrite and recursively visit - for (auto &nestedOp : blockCopy) { - if (auto newOp = rewriteOp(nestedOp, eraser)) - visitOperation(newOp, eraser); - } - } - } - } - - void runOnOperation() override { - ModuleOp mod = getOperation(); - - auto usedByStoreOp = [](Value val) { - return llvm::any_of(val.getUsers(), [](Operation *user) { - return llvm::isa(user); - }); - }; - - auto markTensorPointerForRemoval = [this](Value val, - bool isUsedByStoreOp = false) { - if (tt::isTensorPointerType(val.getType())) { - tt::MakeTensorPtrOp makeTensorPtrOp = getMakeTensorPtrOp(val); - if (shouldRemove(makeTensorPtrOp, isUsedByStoreOp)) - valueToRemove.insert(val); - } - }; - - mod.walk([&](Operation *op) { - if (llvm::isa(op)) { - Value result = op->getResult(0); - markTensorPointerForRemoval(result, usedByStoreOp(result)); - } else if (llvm::isa(op)) { - markTensorPointerForRemoval(op->getOperand(0), - llvm::isa(op)); - } else if (auto forOp = dyn_cast(op)) { - for (auto arg : forOp.getInitArgs()) - markTensorPointerForRemoval(arg); - } else if (auto yieldOp = dyn_cast(op)) { - for (auto operand : yieldOp.getOperands()) - markTensorPointerForRemoval(operand); - } - }); - - LLVM_DEBUG({ - if (valueToRemove.empty()) - llvm::dbgs() << "No tensor pointer to remove\n"; - else { - llvm::dbgs() << "Values to 
remove: \n";
-        for (auto val : valueToRemove)
-          llvm::dbgs() << val << "\n";
-      }
-    });
-
-    // NOTES(Chenggang): we don't use `ConversionPatternRewriter`, because
-    // MLIR does not support one-multiple value mapping. For example, if we
-    // use `ConversionPatternRewriter`, we can not make a type converter,
-    // which converts `ptr<tensor>` into multiple types `ptr<>, int64, int64,
-    // ...` (containing the base/offsets/strides...). What we can do is to
-    // convert `ptr<tensor>` into a single type `Tuple<ptr<>, int64, int64,
-    // ...>`. But in this way, we also have to define `PackTuple` and
-    // `UnpackTuple` operations and make a canonicalization pass to optimize,
-    // which is much So here we recursively build the IR, to be specific, we
-    // have to rewrite `tt.make_tensor_ptr`, `tt.advance`, `tt.load`,
-    // `tt.store`, `scf.for` (tensor pointer usages may be in a loop fashion)
-    std::stack<Operation *> eraser;
-    visitOperation(getOperation(), eraser);
-
-    // The operation could not be erased during visit, because they may have
-    // later usages, so we erase after visit
-    rewritedInfo.clear();
-    valueToRemove.clear();
-    while (!eraser.empty()) {
-      auto op = eraser.top();
-      eraser.pop();
-      op->erase();
-    }
-  }
-};
diff --git a/third_party/intel/triton_xpu.cc b/third_party/intel/triton_xpu.cc
index eb7be1c08..05dcf6bca 100644
--- a/third_party/intel/triton_xpu.cc
+++ b/third_party/intel/triton_xpu.cc
@@ -80,8 +80,6 @@ void init_triton_intel_passes_ttgpuir(py::module &&m) {
                      gpu::intel::createTritonIntelGPUPipeline, int, bool);
   ADD_PASS_WRAPPER_0("add_remove_layout_conversions",
                      gpu::intel::createTritonIntelGPURemoveLayoutConversions);
-  ADD_PASS_WRAPPER_0("add_rewrite_tensor_pointer",
-                     gpu::intel::createTritonIntelGPURewriteTensorPointer);
   ADD_PASS_WRAPPER_OPT_2("add_prefetch_block",
                          gpu::intel::createTritonIntelGPUPrefetchBlock, int,
                          bool);

From 0203dbc06e498ccbb3e6ee90ccbe258c26e9efc9 Mon Sep 17 00:00:00 2001
From: Whitney Tsang
Date: Thu, 26 Sep 2024 16:01:13 +0000
Subject: [PATCH 3/3] Fix test

Signed-off-by: Whitney Tsang
---
 python/test/unit/language/test_core.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index 7aa7d1a8f..28f5a476b 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -4284,8 +4284,12 @@ def kernel(in_base_ptr, out_base_ptr, IN_SHAPE0: tl.constexpr, IN_SHAPE1: tl.con
     actual = torch.zeros(expected.shape, dtype=torch.int32, device=device)
 
     k = kernel[(1, )](input, actual, shape[0], shape[1])
-    assert k.asm['ttgir'].count(
-        'triton_gpu.convert_layout') == 1, "Expected exactly one convert_layout op in the TTGIR after optimization"
+    if is_xpu():
+        assert k.asm['ttgir'].count(
+            'triton_gpu.convert_layout') == 0, "Expected no convert_layout op in the TTGIR after optimization"
+    else:
+        assert k.asm['ttgir'].count(
+            'triton_gpu.convert_layout') == 1, "Expected exactly one convert_layout op in the TTGIR after optimization"
 
     np.testing.assert_equal(to_numpy(expected), to_numpy(actual))