Skip to content

Commit

Permalink
Fixup AIRRt lowerings to support convolution examples (Xilinx#620)
Browse files Browse the repository at this point in the history
* Refactor lowering to airrt.dma to fix a bug where the wrap and the memref shape do not match

* Only use the last dimension for static offsets; switch to composing offsets via strides instead of memref shapes

* Remove unused variable

* Convolution board test

* Formatting
  • Loading branch information
erwei-xilinx authored Jun 26, 2024
1 parent bcbfed5 commit 4559217
Show file tree
Hide file tree
Showing 8 changed files with 653 additions and 68 deletions.
91 changes: 58 additions & 33 deletions mlir/lib/Conversion/AIRLoweringPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -453,16 +453,15 @@ AIRChannelInterfaceToAIRRtConversionImpl(OpBuilder builder,
auto i64Ty = builder.getI64Type();
auto zero =
builder.create<arith::ConstantOp>(loc, i64Ty, IntegerAttr::get(i64Ty, 0));
auto one =
builder.create<arith::ConstantOp>(loc, i64Ty, IntegerAttr::get(i64Ty, 1));
auto zero_idx = builder.create<arith::ConstantIndexOp>(loc, 0);
auto one_idx = builder.create<arith::ConstantIndexOp>(loc, 1);

auto idTy = IntegerType::get(ctx, 32);
// Get op id of the internal put/get op
if (auto id_attr = theOtherOp->getAttrOfType<IntegerAttr>("id")) {
opers.push_back(builder.create<arith::ConstantOp>(loc, idTy, id_attr));
} else {
opers.push_back(builder.create<arith::ConstantOp>(
loc, idTy, IntegerAttr::get(idTy, 0)));
opers.push_back(zero);
}

scf::ParallelOp launch = thisOp->getParentOfType<scf::ParallelOp>();
Expand All @@ -489,45 +488,71 @@ AIRChannelInterfaceToAIRRtConversionImpl(OpBuilder builder,
opers.push_back(builder.create<arith::IndexCastOp>(
loc, IntegerType::get(ctx, 64), launch.getInductionVars()[1]));
else if (launch.getNumLoops() == 1)
opers.push_back(builder.create<arith::ConstantOp>(
loc, i64Ty, IntegerAttr::get(i64Ty, 0)));
opers.push_back(zero);
else
opers.push_back(builder.create<arith::ConstantOp>(
loc, i64Ty, IntegerAttr::get(i64Ty, 0)));
opers.push_back(zero);
}

opers.push_back(thisOp.getMemref());

SmallVector<Value, 4> offsets(4, zero);
SmallVector<Value, 4> lengths(4, one);
SmallVector<Value, 3> strides(3, zero);
SmallVector<Value> offsets = thisOp.getOffsets();
SmallVector<Value> wraps = thisOp.getSizes();
SmallVector<Value> strides = thisOp.getStrides();

int idx = 4 - thisOp.getOffsets().size();
for (auto o : thisOp.getOffsets()) {
offsets[idx++] =
builder.create<arith::IndexCastOp>(loc, IntegerType::get(ctx, 64), o);
auto memrefType = thisOp.getMemref().getType();

// If empty offsets/sizes/strides, then populate the lists with default
// values.
if (offsets.empty() && wraps.empty() && strides.empty()) {
offsets.push_back(zero_idx);
auto memref_volume = air::getTensorVolume(memrefType);
wraps.push_back(builder.create<arith::ConstantIndexOp>(loc, memref_volume));
strides.push_back(one_idx);
}
// Stride field implicit last element one
auto lastStrideConst = getConstantIntValue(strides.back());
assert(lastStrideConst && "the last stride is not static");
// If the last dimension's stride value is not 1, then for AIE2 we use the
// second dimension of shim dma bd to implement the last dimension.
if (*lastStrideConst != 1) {
offsets.push_back(zero_idx);
wraps.push_back(one_idx);
strides.push_back(one_idx);
}

strides.pop_back();
while (offsets.size() < 4) {
offsets.insert(offsets.begin(), zero_idx);
}
while (wraps.size() < 4) {
wraps.insert(wraps.begin(), one_idx);
}
while (strides.size() < 3) {
strides.insert(strides.begin(), zero_idx);
}

for (unsigned i = 0; i < offsets.size(); i++)
offsets[i] = builder.create<arith::IndexCastOp>(
loc, IntegerType::get(ctx, 64), offsets[i]);

// In aiex.npu ops, stride value 0 means 1; only the highest dimension stride
// value 0 really means repeat.
for (unsigned i = 0; i < strides.size(); i++) {
auto constStride = getConstantIntValue(strides[i]);
assert(constStride && "stride is not static");
if (i > 0 && *constStride == 1)
strides[i] = zero;
else
strides[i] = builder.create<arith::IndexCastOp>(
loc, IntegerType::get(ctx, 64), strides[i]);
}

idx = 4 - thisOp.getStrides().size();
auto op_strides = thisOp.getStrides();
if (op_strides.size())
for (auto o : op_strides.drop_back())
strides[idx++] =
builder.create<arith::IndexCastOp>(loc, IntegerType::get(ctx, 64), o);
idx =
4 - std::max(thisOp.getSizes().size(), (size_t)thisMemrefType.getRank());
// If sizes field is empty, then infer sizes from memref shape
if (thisOp.getSizes().empty())
for (auto d : air::getTensorShape(thisMemrefType))
lengths[idx++] = builder.create<arith::ConstantOp>(
loc, i64Ty, IntegerAttr::get(i64Ty, d));
else
for (auto o : thisOp.getSizes())
lengths[idx++] =
builder.create<arith::IndexCastOp>(loc, IntegerType::get(ctx, 64), o);
for (unsigned i = 0; i < wraps.size(); i++)
wraps[i] = builder.create<arith::IndexCastOp>(
loc, IntegerType::get(ctx, 64), wraps[i]);

opers.append(offsets);
opers.append(lengths);
opers.append(wraps);
opers.append(strides);

SmallVector<Type, 1> tys;
Expand Down
36 changes: 23 additions & 13 deletions mlir/lib/Conversion/AIRRtToNpuPass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,22 +137,32 @@ struct DmaToNpuPattern : public OpConversionPattern<DmaMemcpyNdOp> {
.getResult();
};
SmallVector<Value> offsets;
SmallVector<int64_t> staticOffsets;
if (auto const_int = getConstantIntValue(adaptor.getOffset3()))
staticOffsets.push_back(*const_int);
else
SmallVector<int64_t>
staticOffsets; // Note: for static offsets we compose one single offset
// at the last dimension.
int64_t overallStaticOffset = 0;
if (auto const_int = getConstantIntValue(adaptor.getOffset3())) {
overallStaticOffset +=
*getConstantIntValue(adaptor.getStride3()) * (*const_int);
staticOffsets.push_back(0);
} else
offsets.push_back(adaptor.getOffset3());
if (auto const_int = getConstantIntValue(adaptor.getOffset2()))
staticOffsets.push_back(*const_int);
else
if (auto const_int = getConstantIntValue(adaptor.getOffset2())) {
overallStaticOffset +=
*getConstantIntValue(adaptor.getStride2()) * (*const_int);
staticOffsets.push_back(0);
} else
offsets.push_back(adaptor.getOffset2());
if (auto const_int = getConstantIntValue(adaptor.getOffset1()))
staticOffsets.push_back(*const_int);
else
if (auto const_int = getConstantIntValue(adaptor.getOffset1())) {
overallStaticOffset +=
*getConstantIntValue(adaptor.getStride1()) * (*const_int);
staticOffsets.push_back(0);
} else
offsets.push_back(adaptor.getOffset1());
if (auto const_int = getConstantIntValue(adaptor.getOffset0()))
staticOffsets.push_back(*const_int / div);
else
if (auto const_int = getConstantIntValue(adaptor.getOffset0())) {
overallStaticOffset += *const_int;
staticOffsets.push_back(overallStaticOffset / div);
} else
offsets.push_back(divOp(adaptor.getOffset0()));
SmallVector<Value> sizes;
SmallVector<int64_t> staticSizes;
Expand Down
26 changes: 14 additions & 12 deletions mlir/test/Conversion/AIRLowering/air_channel_get_put.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -24,25 +24,26 @@ module {
%c32 = arith.constant 32 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%0 = air.channel.put async @channel_0[%c0, %c0] (%arg0[%c8, %c0] [%c8, %c16] [%c32, %c0]) {id = 1 : i32} : (memref<32x16xi32>)
%1 = air.channel.get async @channel_1[%c0, %c0] (%arg1[%c8, %c0] [%c8, %c16] [%c32, %c0]) {id = 2 : i32} : (memref<32x16xi32>)
%0 = air.channel.put async @channel_0[%c0, %c0] (%arg0[%c8, %c0] [%c8, %c16] [%c32, %c1]) {id = 1 : i32} : (memref<32x16xi32>)
%1 = air.channel.get async @channel_1[%c0, %c0] (%arg1[%c8, %c0] [%c8, %c16] [%c32, %c1]) {id = 2 : i32} : (memref<32x16xi32>)
air.segment @segment_0 {
%c1_0 = arith.constant 1 : index
air.herd @herd_0 tile (%arg10, %arg11) in (%arg12=%c1_0, %arg13=%c1_0) {
%c0_4 = arith.constant 0 : index
%c1_4 = arith.constant 1 : index
%c32_5 = arith.constant 32 : index
%c16_6 = arith.constant 16 : index
%c8_7 = arith.constant 8 : index
%alloc = memref.alloc() {sym_name = "scratch"} : memref<16x8xi32, 2>
%alloc_8 = memref.alloc() {sym_name = "scratch_copy"} : memref<16x8xi32, 2>
air.channel.get @channel_0[%arg10, %arg11] (%alloc[%c0_4, %c0_4] [%c8_7, %c16_6] [%c32_5, %c0_4]) {id = 3 : i32} : (memref<16x8xi32, 2>)
air.channel.get @channel_0[%arg10, %arg11] (%alloc[%c0_4, %c0_4] [%c8_7, %c16_6] [%c32_5, %c1_4]) {id = 3 : i32} : (memref<16x8xi32, 2>)
affine.for %arg18 = 0 to 8 {
affine.for %arg19 = 0 to 16 {
%2 = affine.load %alloc[%arg19, %arg18] : memref<16x8xi32, 2>
affine.store %2, %alloc_8[%arg19, %arg18] : memref<16x8xi32, 2>
}
}
air.channel.put @channel_1[%arg10, %arg11] (%alloc_8[%c0_4, %c0_4] [%c8_7, %c16_6] [%c32_5, %c0_4]) {id = 4 : i32} : (memref<16x8xi32, 2>)
air.channel.put @channel_1[%arg10, %arg11] (%alloc_8[%c0_4, %c0_4] [%c8_7, %c16_6] [%c32_5, %c1_4]) {id = 4 : i32} : (memref<16x8xi32, 2>)
memref.dealloc %alloc_8 : memref<16x8xi32, 2>
memref.dealloc %alloc : memref<16x8xi32, 2>
}
Expand Down Expand Up @@ -76,15 +77,15 @@ module {
%c0 = arith.constant 0 : index
%0 = air.wait_all async
%1 = scf.parallel (%a2, %a3) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init (%0) -> !air.async.token {
%3 = air.channel.put async @channel_2[%a2, %a3] (%arg0[%c8, %c0] [%c8, %c16] [%c32, %c0]) {id = 1 : i32} : (memref<32x16xi32>)
%3 = air.channel.put async @channel_2[%a2, %a3] (%arg0[%c8, %c0] [%c8, %c16] [%c32, %c1]) {id = 1 : i32} : (memref<32x16xi32>)
scf.reduce(%3 : !air.async.token) {
^bb0(%a4: !air.async.token, %a5: !air.async.token):
%4 = air.wait_all async [%a4, %a5]
scf.reduce.return %4 : !air.async.token
}
}
%2 = scf.parallel (%a2, %a3) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init (%0) -> !air.async.token {
%3 = air.channel.get async @channel_3[%a2, %a3] (%arg1[%c8, %c0] [%c8, %c16] [%c32, %c0]) {id = 2 : i32} : (memref<32x16xi32>)
%3 = air.channel.get async @channel_3[%a2, %a3] (%arg1[%c8, %c0] [%c8, %c16] [%c32, %c1]) {id = 2 : i32} : (memref<32x16xi32>)
scf.reduce(%3 : !air.async.token) {
^bb0(%a4: !air.async.token, %a5: !air.async.token):
%4 = air.wait_all async [%a4, %a5]
Expand All @@ -96,19 +97,20 @@ module {
%c2_3 = arith.constant 2 : index
air.herd @herd_0 tile (%arg10, %arg11) in (%arg12=%c2_2, %arg13=%c2_3) args(%arg14=%arg6, %arg15=%arg7, %arg16=%arg8, %arg17=%arg9) : index, index, index, index {
%c0_4 = arith.constant 0 : index
%c1_4 = arith.constant 1 : index
%c32_5 = arith.constant 32 : index
%c16_6 = arith.constant 16 : index
%c8_7 = arith.constant 8 : index
%alloc = memref.alloc() {sym_name = "scratch"} : memref<16x8xi32, 2>
%alloc_8 = memref.alloc() {sym_name = "scratch_copy"} : memref<16x8xi32, 2>
air.channel.get @channel_2[%arg10, %arg11] (%alloc[%c0_4, %c0_4] [%c8_7, %c16_6] [%c32_5, %c0_4]) {id = 3 : i32} : (memref<16x8xi32, 2>)
air.channel.get @channel_2[%arg10, %arg11] (%alloc[%c0_4, %c0_4] [%c8_7, %c16_6] [%c32_5, %c1_4]) {id = 3 : i32} : (memref<16x8xi32, 2>)
affine.for %arg18 = 0 to 8 {
affine.for %arg19 = 0 to 16 {
%3 = affine.load %alloc[%arg19, %arg18] : memref<16x8xi32, 2>
affine.store %3, %alloc_8[%arg19, %arg18] : memref<16x8xi32, 2>
}
}
air.channel.put @channel_3[%arg10, %arg11] (%alloc_8[%c0_4, %c0_4] [%c8_7, %c16_6] [%c32_5, %c0_4]) {id = 4 : i32} : (memref<16x8xi32, 2>)
air.channel.put @channel_3[%arg10, %arg11] (%alloc_8[%c0_4, %c0_4] [%c8_7, %c16_6] [%c32_5, %c1_4]) {id = 4 : i32} : (memref<16x8xi32, 2>)
memref.dealloc %alloc_8 : memref<16x8xi32, 2>
memref.dealloc %alloc : memref<16x8xi32, 2>
}
Expand Down Expand Up @@ -144,7 +146,7 @@ module {
%c0 = arith.constant 0 : index
%0 = air.wait_all async
%1 = scf.parallel (%a2, %a3) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init (%0) -> !air.async.token {
%3 = air.channel.put async @channel_4[%a2, %a3] (%arg0[%c8, %c0] [%c8, %c16] [%c32, %c0]) {id = 1 : i32} : (memref<32x16xi32>)
%3 = air.channel.put async @channel_4[%a2, %a3] (%arg0[%c8, %c0] [%c8, %c16] [%c32, %c1]) {id = 1 : i32} : (memref<32x16xi32>)
scf.reduce(%3 : !air.async.token) {
^bb0(%a4: !air.async.token, %a5: !air.async.token):
%4 = air.wait_all async [%a4, %a5]
Expand All @@ -153,7 +155,7 @@ module {
}
%2 = scf.parallel (%a2, %a3) = (%c0, %c0) to (%c2, %c2) step (%c1, %c1) init (%0) -> !air.async.token {
%3 = scf.for %a4 = %c0 to %c2 step %c1 iter_args(%a5 = %0) -> (!air.async.token) {
%4 = air.channel.get async [%a5] @channel_5[%a2, %a3] (%arg1[%c8, %c0] [%c8, %c16] [%c32, %c0]) {id = 2 : i32} : (memref<32x16xi32>)
%4 = air.channel.get async [%a5] @channel_5[%a2, %a3] (%arg1[%c8, %c0] [%c8, %c16] [%c32, %c1]) {id = 2 : i32} : (memref<32x16xi32>)
scf.yield %4 : !air.async.token
}
scf.reduce(%3 : !air.async.token) {
Expand All @@ -174,15 +176,15 @@ module {
%c8_9 = arith.constant 8 : index
%alloc = memref.alloc() {sym_name = "scratch"} : memref<16x8xi32, 2>
%alloc_10 = memref.alloc() {sym_name = "scratch_copy"} : memref<16x8xi32, 2>
air.channel.get @channel_4[%arg10, %arg11] (%alloc[%c0_4, %c0_4] [%c8_9, %c16_8] [%c32_7, %c0_4]) {id = 3 : i32} : (memref<16x8xi32, 2>)
air.channel.get @channel_4[%arg10, %arg11] (%alloc[%c0_4, %c0_4] [%c8_9, %c16_8] [%c32_7, %c1_6]) {id = 3 : i32} : (memref<16x8xi32, 2>)
affine.for %arg18 = 0 to 8 {
affine.for %arg19 = 0 to 16 {
%3 = affine.load %alloc[%arg19, %arg18] : memref<16x8xi32, 2>
affine.store %3, %alloc_10[%arg19, %arg18] : memref<16x8xi32, 2>
}
}
scf.for %arg18 = %c0_4 to %c2_5 step %c1_6 {
air.channel.put @channel_5[%arg10, %arg11] (%alloc_10[%c0_4, %c0_4] [%c8_9, %c16_8] [%c32_7, %c0_4]) {id = 4 : i32} : (memref<16x8xi32, 2>)
air.channel.put @channel_5[%arg10, %arg11] (%alloc_10[%c0_4, %c0_4] [%c8_9, %c16_8] [%c32_7, %c1_6]) {id = 4 : i32} : (memref<16x8xi32, 2>)
}
memref.dealloc %alloc_10 : memref<16x8xi32, 2>
memref.dealloc %alloc : memref<16x8xi32, 2>
Expand Down
53 changes: 53 additions & 0 deletions mlir/test/Conversion/AIRLowering/air_to_npu.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -230,3 +230,56 @@ module {
return
}
}

// -----

// Convolution.

// CHECK-DAG: %[[CST_64:.*]] = arith.constant 64 : i64
// CHECK-DAG: %[[CST_1:.*]] = arith.constant 1 : i64
// CHECK-DAG: %[[CST_1152:.*]] = arith.constant 1152 : i64
// CHECK-DAG: %[[CST_18:.*]] = arith.constant 18 : i32
// CHECK-DAG: %[[CST_5:.*]] = arith.constant 5 : i32
// CHECK-DAG: %[[CST_4:.*]] = arith.constant 4 : i32
// CHECK-DAG: %[[CST_0:.*]] = arith.constant 0 : i64
// CHECK: affine.for %[[VAL_0:.*]] = 0 to 2 {
// CHECK: %[[VAL_1:.*]] = arith.index_cast %[[VAL_0]] : index to i64
// CHECK: airrt.dma_memcpy_nd(%[[CST_4]], %0, %[[CST_0]], %arg0[%[[CST_0]], %[[CST_0]], %0, %[[CST_0]]], [%[[CST_1]], %[[CST_1]], %[[CST_1]], %[[CST_1152]]], [%[[CST_0]], %[[CST_0]], %[[CST_1152]]]) {metadata = @airMemcpyId4} : (i32, i64, i64, memref<2x6x6x32xi32>, [i64, i64, i64, i64], [i64, i64, i64, i64], [i64, i64, i64]) : !airrt.event
// CHECK: airrt.dma_memcpy_nd(%[[CST_5]], %0, %[[CST_0]], %arg1[%[[CST_0]], %[[CST_0]], %[[CST_0]], %[[CST_0]]], [%[[CST_1]], %[[CST_1]], %[[CST_1]], %[[CST_1152]]], [%[[CST_0]], %[[CST_0]], %[[CST_0]]]) {metadata = @airMemcpyId5} : (i32, i64, i64, memref<3x3x32x4xi32>, [i64, i64, i64, i64], [i64, i64, i64, i64], [i64, i64, i64]) : !airrt.event
// CHECK: airrt.dma_memcpy_nd(%[[CST_18]], %0, %[[CST_0]], %arg2[%[[CST_0]], %[[CST_0]], %0, %[[CST_0]]], [%[[CST_1]], %[[CST_1]], %[[CST_1]], %[[CST_64]]], [%[[CST_0]], %[[CST_0]], %[[CST_64]]]) {metadata = @airMemcpyId18} : (i32, i64, i64, memref<2x4x4x4xi32>, [i64, i64, i64, i64], [i64, i64, i64, i64], [i64, i64, i64]) : !airrt.event

module {
air.channel @channel_5 [1, 1]
air.channel @channel_2 [1, 1]
air.channel @channel_1 [1, 1]
func.func @func3(%arg0: memref<2x6x6x32xi32>, %arg1: memref<3x3x32x4xi32>, %arg2: memref<2x4x4x4xi32>) {
%c2 = arith.constant 2 : index
%0 = air.launch async (%arg3) in (%arg4=%c2) args(%arg5=%arg0, %arg6=%arg2, %arg7=%arg1) : memref<2x6x6x32xi32>, memref<2x4x4x4xi32>, memref<3x3x32x4xi32> attributes {id = 1 : i32} {
%c64 = arith.constant 64 : index
%c1152 = arith.constant 1152 : index
%c1 = arith.constant 1 : index
%c0 = arith.constant 0 : index
%1 = air.channel.put async @channel_1[] (%arg5[%arg3, %c0] [%c1, %c1152] [%c1152, %c1]) {id = 1 : i32, metadata = @airMemcpyId4} : (memref<2x6x6x32xi32>)
%2 = air.channel.put async @channel_2[] (%arg7[] [] []) {id = 2 : i32, metadata = @airMemcpyId5} : (memref<3x3x32x4xi32>)
%3 = air.channel.get async @channel_5[] (%arg6[%arg3, %c0] [%c1, %c64] [%c64, %c1]) {id = 3 : i32, metadata = @airMemcpyId18} : (memref<2x4x4x4xi32>)
%4 = air.segment @conv async attributes {id = 2 : i32, x_loc = 0 : i64, x_size = 1 : i64, y_loc = 2 : i64, y_size = 4 : i64} {
%async_token, %results = air.execute -> (memref<1x6x6x32xi32, 1>) {
%alloc = memref.alloc() : memref<1x6x6x32xi32, 1>
air.execute_terminator %alloc : memref<1x6x6x32xi32, 1>
}
%5 = air.channel.get async [%async_token] @channel_1[] (%results[] [] []) {id = 4 : i32} : (memref<1x6x6x32xi32, 1>)
%async_token_0, %results_1 = air.execute -> (memref<3x3x32x4xi32, 1>) {
%alloc = memref.alloc() : memref<3x3x32x4xi32, 1>
air.execute_terminator %alloc : memref<3x3x32x4xi32, 1>
}
%6 = air.channel.get async [%async_token_0] @channel_2[] (%results_1[] [] []) {id = 5 : i32} : (memref<3x3x32x4xi32, 1>)
%async_token_2, %results_3 = air.execute -> (memref<1x4x4x4xi32, 1>) {
%alloc = memref.alloc() : memref<1x4x4x4xi32, 1>
air.execute_terminator %alloc : memref<1x4x4x4xi32, 1>
}
%7 = air.channel.put async [%async_token_2] @channel_5[] (%results_3[] [] []) {id = 18 : i32} : (memref<1x4x4x4xi32, 1>)
}
}
return
}
}
Loading

0 comments on commit 4559217

Please sign in to comment.