Skip to content

Commit

Permalink
AIRLowering: Multi-dimensional air.launch bugfix (Xilinx#615)
Browse files Browse the repository at this point in the history
* Support for air.launch having more than 2 dims

* Unit test
  • Loading branch information
erwei-xilinx authored Jun 24, 2024
1 parent a412426 commit b82f610
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 16 deletions.
32 changes: 16 additions & 16 deletions mlir/lib/Conversion/AIRLoweringPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -546,8 +546,8 @@ AIRChannelInterfaceToAIRRtConversionImpl(OpBuilder builder,
opers.push_back(builder.create<arith::ConstantOp>(
loc, i64Ty, IntegerAttr::get(i64Ty, 0)));
else
assert(false && "lowering of air.launch with more than 2 dimensions is "
"currently unsupported");
opers.push_back(builder.create<arith::ConstantOp>(
loc, i64Ty, IntegerAttr::get(i64Ty, 0)));
}

opers.push_back(thisOp.getMemref());
Expand Down Expand Up @@ -1054,22 +1054,22 @@ LogicalResult ScfParToAffineForConversion(Operation *op) {
dyn_cast<arith::ConstantIndexOp>(v.getDefiningOp()).value());

OpBuilder builder(scf_par);
auto outer =
builder.create<affine::AffineForOp>(scf_par.getLoc(), 0, par_sizes[0]);
affine::AffineForOp inner = nullptr;
if (par_sizes.size() == 2) {
auto outer_builder = OpBuilder::atBlockBegin(outer.getBody());
inner = outer_builder.create<affine::AffineForOp>(scf_par.getLoc(), 0,
par_sizes[1]);
} else
inner = outer;
SmallVector<affine::AffineForOp> loops;
for (unsigned i = 0; i < par_sizes.size(); i++) {
if (i == 0)
loops.push_back(builder.create<affine::AffineForOp>(scf_par.getLoc(), 0,
par_sizes[0]));
else {
auto inner_builder = OpBuilder::atBlockBegin(loops[i - 1].getBody());
loops.push_back(inner_builder.create<affine::AffineForOp>(
scf_par.getLoc(), 0, par_sizes[i]));
}
}

builder.setInsertionPointToStart(inner.getBody());
builder.setInsertionPointToStart(loops.back().getBody());
IRMapping remap;
remap.map(scf_par.getInductionVars()[0], outer.getInductionVar());
if (par_sizes.size() == 2) {
remap.map(scf_par.getInductionVars()[1], inner.getInductionVar());
}
for (unsigned i = 0; i < par_sizes.size(); i++)
remap.map(scf_par.getInductionVars()[i], loops[i].getInductionVar());
for (auto &o : scf_par.getBody()->getOperations()) {
if (!isa<scf::ReduceOp>(o) && !isa<scf::YieldOp>(o) &&
!isa<scf::ParallelOp>(o)) {
Expand Down
42 changes: 42 additions & 0 deletions mlir/test/Conversion/AIRLowering/air_launch.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,45 @@ func.func @launch_1() {
}
return
}

// Multi-dimensional air.launch, with async. air.channel consuming the induction vars.

// CHECK-LABEL: launch_2
// CHECK: affine.for %[[VAL_0:.*]] = 0 to 2 {
// CHECK: affine.for %[[VAL_1:.*]] = 0 to 2 {
// CHECK: affine.for %[[VAL_2:.*]] = 0 to 2 {
// CHECK: affine.for %[[VAL_3:.*]] = 0 to 2 {

air.channel @channel_3 [1, 1]
air.channel @channel_2 [1, 1]
air.channel @channel_1 [1, 1]
func.func @launch_2(%arg0: memref<2x32x6x6xi32>, %arg1: memref<4x32x3x3xi32>, %arg2: memref<2x4x4x4xi32>) {
%c2 = arith.constant 2 : index
%0 = air.launch async (%arg3, %arg4, %arg5, %arg6) in (%arg7=%c2, %arg8=%c2, %arg9=%c2, %arg10=%c2) args(%arg11=%arg0, %arg12=%arg2, %arg13=%arg1) : memref<2x32x6x6xi32>, memref<2x4x4x4xi32>, memref<4x32x3x3xi32> attributes {id = 1 : i32} {
%c64 = arith.constant 64 : index
%c1152 = arith.constant 1152 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%1 = air.channel.put async @channel_1[] (%arg11[%arg3, %c0] [%c1, %c1152] [%c1152, %c1]) {id = 1 : i32} : (memref<2x32x6x6xi32>)
%2 = air.channel.put async @channel_2[] (%arg13[] [] []) {id = 2 : i32} : (memref<4x32x3x3xi32>)
%3 = air.channel.get async @channel_3[] (%arg12[%arg3, %c0] [%c1, %c64] [%c64, %c1]) {id = 3 : i32} : (memref<2x4x4x4xi32>)
%4 = air.segment @segment_0 async {
%async_token, %results = air.execute -> (memref<1x32x6x6xi32, 1>) {
%alloc = memref.alloc() : memref<1x32x6x6xi32, 1>
air.execute_terminator %alloc : memref<1x32x6x6xi32, 1>
}
%5 = air.channel.get async [%async_token] @channel_1[] (%results[] [] []) {id = 4 : i32} : (memref<1x32x6x6xi32, 1>)
%async_token_0, %results_1 = air.execute -> (memref<4x32x3x3xi32, 1>) {
%alloc = memref.alloc() : memref<4x32x3x3xi32, 1>
air.execute_terminator %alloc : memref<4x32x3x3xi32, 1>
}
%6 = air.channel.get async [%async_token_0] @channel_2[] (%results_1[] [] []) {id = 5 : i32} : (memref<4x32x3x3xi32, 1>)
%async_token_2, %results_3 = air.execute -> (memref<1x4x4x4xi32, 1>) {
%alloc = memref.alloc() : memref<1x4x4x4xi32, 1>
air.execute_terminator %alloc : memref<1x4x4x4xi32, 1>
}
%7 = air.channel.put async [%6] @channel_3[] (%results_3[] [] []) {id = 12 : i32} : (memref<1x4x4x4xi32, 1>)
}
}
return
}

0 comments on commit b82f610

Please sign in to comment.