diff --git a/mlir/lib/Dialect/AIR/IR/AIRDialect.cpp b/mlir/lib/Dialect/AIR/IR/AIRDialect.cpp index d8c1ce622..9ef8c8bc6 100644 --- a/mlir/lib/Dialect/AIR/IR/AIRDialect.cpp +++ b/mlir/lib/Dialect/AIR/IR/AIRDialect.cpp @@ -1181,17 +1181,21 @@ static LogicalResult FoldExecute(ExecuteOp op, PatternRewriter &rewriter) { return success(); } - // replace returns of constants with the constant + // replace returns of (1) constants with a clone of the constant, and (2) + // values defined outside the execute region with the original value itself int idx = 0; for (auto v : et->getOperands()) { idx++; if (op.getResult(idx).use_empty()) continue; - auto o = v.getDefiningOp(); - if (!o) - continue; - if (isa(o)) { - op.getResult(idx).replaceAllUsesWith(rewriter.clone(*o)->getResult(0)); + if (!op.getRegion().isAncestor(v.getParentRegion())) { + rewriter.replaceAllUsesWith(op.getResult(idx), v); + return success(); + } + if (auto constOp = + dyn_cast_if_present(v.getDefiningOp())) { + rewriter.replaceAllUsesWith(op.getResult(idx), + rewriter.clone(*constOp)->getResult(0)); return success(); } } diff --git a/mlir/test/Dialect/AIR/air_canonicalize.mlir b/mlir/test/Dialect/AIR/air_canonicalize.mlir index c6c27e50b..1a822a1e4 100644 --- a/mlir/test/Dialect/AIR/air_canonicalize.mlir +++ b/mlir/test/Dialect/AIR/air_canonicalize.mlir @@ -327,6 +327,31 @@ func.func @execute_4() -> (memref<1xi32>, !air.async.token) { return %results, %t : memref<1xi32>, !air.async.token } +// CHECK-LABEL: execute_5 +// CHECK: scf.for {{.*}} { +// CHECK: air.execute {{.*}} { +// CHECK-NEXT: memref.store +// CHECK-NEXT: } +// CHECK: scf.yield +// CHECK-NEXT: } +func.func @execute_5(%alloc : memref<4xi32>) -> (!air.async.token) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c4 = arith.constant 4 : index + %cst = arith.constant 2 : i32 + %t = air.wait_all async + %0 = scf.for %arg0 = %c0 to %c4 step %c1 iter_args(%arg1 = %t) -> (!air.async.token) { + %async_token_0, %results_1 = air.execute [%arg1] -> 
(index) { + air.execute_terminator %arg0 : index + } + %async_token_1 = air.execute [%async_token_0] { + memref.store %cst, %alloc[%results_1] : memref<4xi32> + } + scf.yield %async_token_1 : !air.async.token + } + return %0 :!air.async.token +} + // CHECK: func.func @chan_0 // CHECK: %[[TOKEN0:.*]] = air.channel.get async @channel_0 // CHECK: %[[TOKEN1:.*]] = air.channel.get async @channel_1