From 0fb668c22f3abf9336c2c6b0be01f0b2af2e0311 Mon Sep 17 00:00:00 2001 From: erwei-xilinx Date: Wed, 3 Jul 2024 10:25:32 -0700 Subject: [PATCH] Add support for rank-reduced vector.transfer_read/write (#645) --- mlir/lib/Util/Util.cpp | 16 +++++-- .../segment_loop_fusion.mlir | 43 +++++++++++++++++++ 2 files changed, 55 insertions(+), 4 deletions(-) diff --git a/mlir/lib/Util/Util.cpp b/mlir/lib/Util/Util.cpp index 7548a6051..2278b277d 100644 --- a/mlir/lib/Util/Util.cpp +++ b/mlir/lib/Util/Util.cpp @@ -1282,9 +1282,13 @@ air::writeAccessPattern(mlir::vector::TransferReadOp readOp) { std::get<0>(pattern), std::get<1>(pattern), std::get<2>(pattern)); // Update wraps based on vector shape and vector access patterns. - for (unsigned i = 0; i < std::get<1>(pattern).size(); i++) + unsigned rankOffset = + vectorTy.getShape().size() >= std::get<1>(pattern).size() + ? 0 + : std::get<1>(pattern).size() - vectorTy.getShape().size(); + for (unsigned i = rankOffset; i < std::get<1>(pattern).size(); i++) std::get<1>(pattern)[i] = builder.create( - builder.getUnknownLoc(), vectorTy.getShape()[i]); + builder.getUnknownLoc(), vectorTy.getShape()[i - rankOffset]); updateAccessPatternByScfForNest(pattern, readOp.getIndices(), builder); return pattern; } @@ -1303,9 +1307,13 @@ air::writeAccessPattern(mlir::vector::TransferWriteOp writeOp) { std::get<0>(pattern), std::get<1>(pattern), std::get<2>(pattern)); // Update wraps based on vector shape and vector access patterns. - for (unsigned i = 0; i < std::get<1>(pattern).size(); i++) + unsigned rankOffset = + vectorTy.getShape().size() >= std::get<1>(pattern).size() + ? 0 + : std::get<1>(pattern).size() - vectorTy.getShape().size(); + for (unsigned i = rankOffset; i < std::get<1>(pattern).size(); i++) std::get<1>(pattern)[i] = builder.create( - builder.getUnknownLoc(), vectorTy.getShape()[i]); + builder.getUnknownLoc(), vectorTy.getShape()[i - rankOffset]); updateAccessPatternByScfForNest(pattern, writeOp.getIndices(), builder); return pattern; } diff --git a/mlir/test/Transform/AIRDependencyScheduleOpt/segment_loop_fusion.mlir b/mlir/test/Transform/AIRDependencyScheduleOpt/segment_loop_fusion.mlir index 3d5f0467d..f203f5876 100644 --- a/mlir/test/Transform/AIRDependencyScheduleOpt/segment_loop_fusion.mlir +++ b/mlir/test/Transform/AIRDependencyScheduleOpt/segment_loop_fusion.mlir @@ -581,3 +581,46 @@ func.func @func6(%arg0: memref<512x512xi32>, %arg1: memref<512x512xi32>, %arg2: } return } + +// Rank-reduced vector transferRead and transferWrite. + +// CHECK-LABEL: func.func @func7 +// CHECK-DAG: %[[CST0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[CSTBF16:.*]] = arith.constant 0.000000e+00 : bf16 +// CHECK: vector.transfer_read %{{.*}}[%[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]]], %[[CSTBF16]] {in_bounds = [true, true, true]} : memref<1x1x4x8xbf16, 2 : i32>, vector<1x4x8xbf16> +// CHECK: vector.transfer_write %{{.*}}, %{{.*}}[%[[CST0]], %[[CST0]], %[[CST0]], %[[CST0]]] {in_bounds = [true, true, true]} : vector<1x4x8xbf16>, memref<1x1x4x8xbf16, 2 : i32> + +func.func @func7() { + %0 = air.launch async () in () { + %1 = air.segment @segment_0 async { + %c4 = arith.constant 4 : index + %c2 = arith.constant 2 : index + %2 = air.herd @herd_0 async tile (%arg0, %arg1) in (%arg2=%c2, %arg3=%c4) { + %cst = arith.constant 0.000000e+00 : bf16 + %c0 = arith.constant 0 : index + %async_token, %results = air.execute -> (memref<1x1x4x8xbf16, 2 : i32>) { + %alloc = memref.alloc() : memref<1x1x4x8xbf16, 2 : i32> + air.execute_terminator %alloc : memref<1x1x4x8xbf16, 2 : i32> + } + %async_token_0, %results_1 = air.execute -> (memref<1x1x4x8xbf16, 2 : i32>) { + %alloc = memref.alloc() : memref<1x1x4x8xbf16, 2 : i32> + air.execute_terminator %alloc : memref<1x1x4x8xbf16, 2 : i32> + } + %async_token_2, %results_3 = air.execute [%async_token_0] -> (vector<1x4x8xbf16>) { + %3 = vector.transfer_read %results_1[%c0, %c0, %c0, %c0], %cst {in_bounds = [true, true, true]} : memref<1x1x4x8xbf16, 2 : i32>, vector<1x4x8xbf16> + air.execute_terminator %3 : vector<1x4x8xbf16> + } + %async_token_4 = air.execute [%async_token_2] { + vector.transfer_write %results_3, %results[%c0, %c0, %c0, %c0] {in_bounds = [true, true, true]} : vector<1x4x8xbf16>, memref<1x1x4x8xbf16, 2 : i32> + } + %async_token_5 = air.execute [%async_token_2] { + memref.dealloc %results_1 : memref<1x1x4x8xbf16, 2 : i32> + } + %async_token_6 = air.execute [%async_token_4] { + memref.dealloc %results : memref<1x1x4x8xbf16, 2 : i32> + } + } + } + } + return +}