From 3d1a4e19ff748897a37e0eb88b59999197dbc0f8 Mon Sep 17 00:00:00 2001
From: erwei-xilinx
Date: Tue, 29 Oct 2024 15:40:20 +0800
Subject: [PATCH] AIRPingpongTransform: Tokens may not always get passed into async `scf.for` via `init_args` (#756)

* Tokens used inside scf.for but declared outside should be handled the same way as init_args

* Re-enable pingpong buffering for vecmat example
---
 .../Transform/AIRDependencyScheduleOpt.cpp    | 25 ++++++---
 .../annotate_front_back.mlir                  | 52 +++++++++++++++++++
 test/xrt/26_vecmat_i8/aie.py                  |  8 +--
 3 files changed, 74 insertions(+), 11 deletions(-)

diff --git a/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp b/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp
index 4e84788b4..e2604d8f6 100644
--- a/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp
+++ b/mlir/lib/Transform/AIRDependencyScheduleOpt.cpp
@@ -610,13 +610,21 @@ struct AnnotateFrontAndBackOpsInForPattern
         continue;
 
       if (!dep_list.size())
-        op.setAttr("async_front", rewriter.getBoolAttr(true));
-      for (auto token : iterTokens) {
-        for (auto dep : dep_list) {
-          if (token == dep) {
-            setBoolAttrForAsyncOp(rewriter, &op, "async_front");
-          }
-        }
+        setBoolAttrForAsyncOp(rewriter, &op, "async_front");
+      for (auto dep : dep_list) {
+        // Token is in iter_args
+        if (llvm::any_of(iterTokens,
+                         [dep](Value token) { return token == dep; }))
+          setBoolAttrForAsyncOp(rewriter, &op, "async_front");
+      }
+      // Token is declared outside of for loop
+      if (llvm::any_of(dep_list, [for_op](Value token) {
+            auto tokenDefOp = token.getDefiningOp();
+            if (!tokenDefOp)
+              return false;
+            return !for_op->isProperAncestor(tokenDefOp);
+          })) {
+        setBoolAttrForAsyncOp(rewriter, &op, "async_front");
       }
     }
 
@@ -649,6 +657,9 @@ struct AnnotateFrontAndBackOpsInForPattern
     }
     for (auto op : back_candidates) {
       setBoolAttrForAsyncOp(rewriter, op, "async_back");
+      if (op->hasAttr("async_front"))
+        // An op cannot be both "async_back" and "async_front".
+        op->removeAttr("async_front");
     }
 
     return success();
diff --git a/mlir/test/Transform/AIRDependencyScheduleOpt/annotate_front_back.mlir b/mlir/test/Transform/AIRDependencyScheduleOpt/annotate_front_back.mlir
index a02b5b2f7..87c01180a 100644
--- a/mlir/test/Transform/AIRDependencyScheduleOpt/annotate_front_back.mlir
+++ b/mlir/test/Transform/AIRDependencyScheduleOpt/annotate_front_back.mlir
@@ -62,3 +62,55 @@ func.func @test(%arg0: memref<256x1024xbf16>, %arg1: memref<1024x1024xbf16>, %ar
   }
   return
 }
+
+// Label async_front based on tokens declared outside of for loop.
+// CHECK-LABEL: test1
+// CHECK: air.segment
+// CHECK: air.wait_all async
+// CHECK: air.wait_all async
+// CHECK: scf.for
+// CHECK: air.channel.get{{.*}}async_front = true
+// CHECK: air.channel.get{{.*}}async_front = true
+// CHECK: air.wait_all async{{.*}}async_back = true
+
+func.func @test1(%arg0: memref<2048xi8>, %arg1: memref<2048x1024xi8>, %arg2: memref<1024xi32>) {
+  %c4 = arith.constant 4 : index
+  %0 = air.launch async (%arg3) in (%arg4=%c4) {
+    %1 = air.segment @vecmat_i8_0 async {
+      %c4096 = arith.constant 4096 : index
+      %c8 = arith.constant 8 : index
+      %c16 = arith.constant 16 : index
+      %c256 = arith.constant 256 : index
+      %c1 = arith.constant 1 : index
+      %c0 = arith.constant 0 : index
+      %c2048 = arith.constant 2048 : index
+      %c128 = arith.constant 128 : index
+      %2 = air.wait_all async
+      %3 = air.wait_all async
+      %11 = scf.for %arg5 = %c0 to %c2048 step %c128 iter_args(%arg9 = %3) -> (!air.async.token) {
+        %async_token, %results = air.execute -> (memref<128xi8, 1>) {
+          %alloc = memref.alloc() {hoist_alloc = true} : memref<128xi8, 1>
+          air.execute_terminator %alloc : memref<128xi8, 1>
+        }
+        %async_token_0, %results_1 = air.execute -> (memref<128x256xi8, 1>) {
+          %alloc = memref.alloc() {hoist_alloc = true} : memref<128x256xi8, 1>
+          air.execute_terminator %alloc : memref<128x256xi8, 1>
+        }
+        %4 = air.channel.get async [%2] @channel_1[] (%results[%arg5] [%c128] [%c1]) {id = 4 : i32} : (memref<128xi8, 1>)
+        %5 = air.channel.get async [%3] @channel_2[] (%results_1[%arg5, %c0] [%c128, %c256] [%c256, %c1]) {id = 5 : i32} : (memref<128x256xi8, 1>)
+        %6 = air.channel.put async [%4] @channel_0[] (%results[%c0, %arg5] [%c8, %c16] [%c16, %c1]) {id = 6 : i32} : (memref<128xi8, 1>)
+        %7 = air.channel.put async [%5] @channel_3[%c0, %c0] (%results_1[%c0, %c0, %arg5, %c0] [%c16, %c8, %c16, %c8] [%c8, %c4096, %c256, %c1]) {id = 7 : i32} : (memref<128x256xi8, 1>)
+        %8 = air.channel.put async [%5] @channel_3[%c1, %c0] (%results_1[%c0, %c0, %arg5, %c128] [%c16, %c8, %c16, %c8] [%c8, %c4096, %c256, %c1]) {id = 7 : i32} : (memref<128x256xi8, 1>)
+        %async_token_2 = air.execute {
+          memref.dealloc %results : memref<128xi8, 1>
+        }
+        %async_token_3 = air.execute {
+          memref.dealloc %results_1 : memref<128x256xi8, 1>
+        }
+        %9 = air.wait_all async [%async_token, %async_token_0, %4, %5, %6, %7, %8, %async_token_2, %async_token_3]
+        scf.yield %9 : !air.async.token
+      } {isolated = true, unroll = 2 : i32}
+    }
+  }
+  return
+}
diff --git a/test/xrt/26_vecmat_i8/aie.py b/test/xrt/26_vecmat_i8/aie.py
index 6805297f1..60f43131e 100644
--- a/test/xrt/26_vecmat_i8/aie.py
+++ b/test/xrt/26_vecmat_i8/aie.py
@@ -142,10 +142,10 @@
                     "func.func(air-split-l2-memref)",
                     "air-isolate-async-dma-loop-nests",
                     "func.func(air-loop-fusion)",
-                    # "air-label-scf-for-to-ping-pong",
-                    # "air-ping-pong-transform{keep-memref-dealloc=true}",
-                    # "canonicalize",
-                    # "cse",
+                    "air-label-scf-for-to-ping-pong",
+                    "air-ping-pong-transform{keep-memref-dealloc=true}",
+                    "canonicalize",
+                    "cse",
                     "air-specialize-channel-wrap-and-stride",
                     "canonicalize",
                     "cse",