AIR BD canonicalization: Unroll scf.parallel to specialize spatial offsets (Xilinx#617)

* Unroll scf.parallel loops so that complex induction-variable indexing can be composed and specialized

* Fix missing offset insertion in the wrap-and-stride canonicalizer; fix flawed logic in the new stride value calculation

* Add unit test
erwei-xilinx authored Jun 25, 2024
1 parent a5b66a1 commit 4dd631c
Showing 3 changed files with 182 additions and 15 deletions.
109 changes: 108 additions & 1 deletion mlir/lib/Transform/AIRDependencyScheduleOpt.cpp
@@ -1518,6 +1518,109 @@ struct LabelScfForLoopInAIRSegment : public OpRewritePattern<scf::ForOp> {
private:
};

struct UnrollScfParallel : public OpRewritePattern<scf::ParallelOp> {
  using OpRewritePattern<scf::ParallelOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(scf::ParallelOp par,
                                PatternRewriter &rewriter) const override {

    auto loc = rewriter.getUnknownLoc();

    for (auto lb : par.getLowerBound()) {
      auto constLB = getConstantIntValue(lb);
      assert(constLB && "non-static scf.parallel lb, NYI");
      assert(*constLB == 0 && "non-zero scf.parallel lb, NYI");
    }

    // Get parallel loop trip count.
    SmallVector<int, 2> lbs_spatial, ubs_spatial;
    air::getSizesFromSpatialLoop(par.getOperation(), lbs_spatial, ubs_spatial);
    std::vector<unsigned> par_size;
    unsigned par_vol = 1;
    for (unsigned i = 0; i < lbs_spatial.size(); i++) {
      par_size.push_back(ubs_spatial[i] - lbs_spatial[i] + 1);
      par_vol *= ubs_spatial[i] - lbs_spatial[i] + 1;
    }

    // Collect yielded tokens.
    SmallVector<Value> yieldedTokens;

    // Walk all iterations. Assumption: LB starts from 0.
    for (unsigned iter = 0; iter < par_vol; iter++) {
      IRMapping remap;
      std::vector<unsigned> position =
          air::getMDVectorFromIterator(par_size, iter);
      // Create one arith.constant op per position, and remap the induction
      // variables to those constants.
      SmallVector<Value> positionVals;
      for (unsigned i = 0; i < position.size(); i++) {
        positionVals.push_back(
            rewriter.create<arith::ConstantIndexOp>(loc, position[i]));
        remap.map(par.getInductionVars()[i], positionVals[i]);
      }

      // Splice the loop body: clone every non-terminator op, and turn the
      // terminator's yielded async tokens into an air.wait_all.
      for (auto &op : par.getBody()->getOperations()) {
        if (op.mightHaveTrait<OpTrait::IsTerminator>()) {
          if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
            SmallVector<Value> tokens;
            for (auto yieldOper : yieldOp->getOperands())
              if (isa<air::AsyncTokenType>(yieldOper.getType()))
                tokens.push_back(yieldOper);
            auto newWaitAll = rewriter.create<air::WaitAllOp>(
                loc, air::AsyncTokenType::get(rewriter.getContext()), tokens);
            yieldedTokens.push_back(newWaitAll.getAsyncToken());
          }
          continue;
        }
        rewriter.clone(op, remap);
      }
    }

    // Replace the scf.parallel's returned token with a wait_all over all
    // per-iteration tokens.
    if (par->getNumResults()) {
      auto newWaitAll = rewriter.create<air::WaitAllOp>(
          loc, air::AsyncTokenType::get(rewriter.getContext()), yieldedTokens);
      par->getResult(0).replaceAllUsesWith(newWaitAll.getAsyncToken());
    }

    rewriter.eraseOp(par);
    return success();
  }

private:
};
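The unroller's core bookkeeping is the linear-to-multi-dimensional index mapping provided by air::getMDVectorFromIterator. Below is a minimal, self-contained C++ sketch of that mapping, assuming an innermost-fastest ordering (an assumption for illustration; the AIR helper's actual ordering may differ):

// delinearize.cpp -- hypothetical stand-in for air::getMDVectorFromIterator.
#include <cstdio>
#include <vector>

// Map a flat iteration index to a multi-dimensional position, with the
// last (innermost) dimension varying fastest.
std::vector<unsigned> delinearize(const std::vector<unsigned> &sizes,
                                  unsigned iter) {
  std::vector<unsigned> position(sizes.size());
  for (int i = (int)sizes.size() - 1; i >= 0; i--) {
    position[i] = iter % sizes[i];
    iter /= sizes[i];
  }
  return position;
}

int main() {
  // A 2x3 scf.parallel has par_vol = 6; walk every iteration as the
  // pattern above does.
  std::vector<unsigned> par_size = {2, 3};
  for (unsigned iter = 0; iter < 6; iter++) {
    std::vector<unsigned> pos = delinearize(par_size, iter);
    std::printf("iter %u -> (%u, %u)\n", iter, pos[0], pos[1]);
  }
  return 0;
}

Each position vector then becomes a set of arith.constant ops that replace the loop's induction variables in the cloned body, which is what lets later patterns specialize the spatial offsets.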

struct CanonicalizeAIRExecute : public OpRewritePattern<air::ExecuteOp> {
  using OpRewritePattern<air::ExecuteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(air::ExecuteOp exec,
                                PatternRewriter &rewriter) const override {

    auto childOp = exec.getChildOp();
    assert(childOp && "air.execute op has no child op");
    // Canonicalize air.execute with an empty region, i.e. whose only child op
    // is the terminator: its token reduces to a wait_all on its dependencies.
    if (!childOp->mightHaveTrait<OpTrait::IsTerminator>())
      return failure();
    exec.getAsyncToken().replaceAllUsesWith(
        rewriter
            .create<air::WaitAllOp>(
                exec->getLoc(), air::AsyncTokenType::get(rewriter.getContext()),
                exec.getAsyncDependencies())
            .getAsyncToken());

    if (childOp->getNumOperands() != 1)
      return failure();
    assert(childOp->getNumOperands() == 1 &&
           "air.execute_terminator doesn't have exactly one operand, NYI");
    exec.getResult(1).replaceAllUsesWith(childOp->getOperand(0));

    rewriter.eraseOp(exec);
    return success();
  }

private:
};
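As a loose analogy for what this fold does (assuming only the ordering semantics of async tokens, and using standard C++ futures rather than anything in AIR): an air.execute whose body is just the terminator performs no work and only forwards a dependency edge, so its token is equivalent to a join over its dependencies.

// wait_all_analogy.cpp -- illustrative only; not AIR code.
#include <future>
#include <utility>
#include <vector>

// Join a set of dependency tokens, like air.wait_all joins async tokens.
std::shared_future<void> waitAll(std::vector<std::shared_future<void>> deps) {
  return std::async(std::launch::deferred, [deps = std::move(deps)] {
           for (const std::shared_future<void> &d : deps)
             d.wait(); // ordering is preserved; the empty region added nothing
         })
      .share();
}

int main() {
  auto dep0 = std::async(std::launch::async, [] {}).share();
  auto dep1 = std::async(std::launch::async, [] {}).share();
  // Stands in for the folded air.execute's async token.
  waitAll({dep0, dep1}).wait();
  return 0;
}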

struct CanonicalizeAffineApplyOnLoopInductionVar
    : public OpRewritePattern<affine::AffineApplyOp> {
  using OpRewritePattern<affine::AffineApplyOp>::OpRewritePattern;
@@ -2711,9 +2814,13 @@ class AIRSpecializeChannelWrapAndStridePattern
  void runOptPatterns(func::FuncOp funcOp) {
    MLIRContext *ctx = funcOp.getContext();
    RewritePatternSet patterns(&getContext());
-    patterns.insert<CanonicalizeAffineApplyOnLoopInductionVar,
+    patterns.insert<UnrollScfParallel, CanonicalizeAIRExecute,
+                    CanonicalizeAffineApplyOnLoopInductionVar,
                    AIRSpecializeChannelWrapAndStrideInScfFor,
                    AIRSpecializeChannelWrapAndStrideInAffineFor>(ctx);
+    // Canonicalize constant operands in affine.apply.
+    mlir::affine::AffineApplyOp::getCanonicalizationPatterns(patterns, ctx);
+    air::WaitAllOp::getCanonicalizationPatterns(patterns, ctx);
    (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));

    // Canonicalize wrap and stride lists to remove redundant dimensions
19 changes: 5 additions & 14 deletions mlir/lib/Util/Util.cpp
@@ -1002,7 +1002,7 @@ LogicalResult air::foldForLoopNestAsExtendedSizesAndStrides(
    for_loops.push_back(parent);
  }
  for (auto o : for_loops) {
-    uint64_t ind_var_factor = 1;
+    uint64_t ind_var_factor = 0;
    for (int i = offsets.size() - 1; i >= 0; i--) {
      Value iv = nullptr;
      int loop_lower_bound = 0;
Expand All @@ -1019,28 +1019,17 @@ LogicalResult air::foldForLoopNestAsExtendedSizesAndStrides(
        // Replace for loop induction vars in offsets with zero
        offsets[i] = builder.template create<arith::ConstantIndexOp>(
            loc, loop_lower_bound);
+        ind_var_factor = *getConstantIntValue(strides[i]);
        break;
      } else if (iv && offsets[i].getDefiningOp()) {
        if (isa<arith::IndexCastOp>(offsets[i].getDefiningOp()) &&
            offsets[i].getDefiningOp()->getOperand(0) == iv) {
          offsets[i] = builder.template create<arith::ConstantIndexOp>(
              loc, loop_lower_bound);
+          ind_var_factor = *getConstantIntValue(strides[i]);
          break;
        };
      }
-      // Index offset taking into account mismatch between memref rank and
-      // offset list size difference.
-      auto memref_rank = getTensorShape(memref.getType()).size();
-      if (memref_rank < offsets.size()) {
-        if ((unsigned)i < offsets.size() - memref_rank)
-          ind_var_factor *= getTensorVolume(memref.getType());
-        else
-          ind_var_factor *= getTensorShape(
-              memref.getType())[i + memref_rank - offsets.size()];
-      } else {
-        ind_var_factor *=
-            getTensorShape(memref.getType())[i + memref_rank - offsets.size()];
-      }
    }
    int trip_count = -1;
    if (auto afo = dyn_cast<affine::AffineForOp>(o))
@@ -1071,6 +1060,8 @@ LogicalResult air::foldForLoopNestAsExtendedSizesAndStrides(
    }

    // Insert the new dimension into the offsets, wraps and strides lists.
+    offsets.insert(offsets.begin(),
+                   builder.template create<arith::ConstantIndexOp>(loc, 0));
    wraps.insert(wraps.begin(), new_wrap);
    strides.insert(strides.begin(), new_stride);
  }
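A worked toy of the folding in this hunk, under its stated assumptions (zero loop lower bound, constant stride on the offset dimension driven by the induction variable). All names here are hypothetical, and the new_wrap/new_stride computation elided by the diff is approximated as the trip count and the step times the folded dimension's stride:

// fold_loop_toy.cpp -- hypothetical sketch, not the AIR utility itself.
#include <cstdint>
#include <cstdio>
#include <vector>

struct Access {
  std::vector<int64_t> offsets, wraps, strides;
};

// Fold a zero-based loop `for (iv = 0; iv < tripCount * step; iv += step)`
// whose induction variable drives offsets[dim] into one new leading
// wrap/stride dimension, zeroing the folded offset slot.
void foldLoop(Access &a, unsigned dim, int64_t tripCount, int64_t step) {
  int64_t indVarFactor = a.strides[dim]; // stride of the iv-driven dimension
  a.offsets[dim] = 0;                    // replace iv with the lower bound
  a.offsets.insert(a.offsets.begin(), 0);
  a.wraps.insert(a.wraps.begin(), tripCount);
  a.strides.insert(a.strides.begin(), step * indVarFactor);
}

int main() {
  // Toy: for iv in [0, 3): access buf[iv] [8] [6]
  // folds to:             access buf[0, 0] [3, 8] [6, 6]
  Access a{{0 /* iv-driven slot */}, {8}, {6}};
  foldLoop(a, 0, /*tripCount=*/3, /*step=*/1);
  std::printf("offsets=[%lld,%lld] wraps=[%lld,%lld] strides=[%lld,%lld]\n",
              (long long)a.offsets[0], (long long)a.offsets[1],
              (long long)a.wraps[0], (long long)a.wraps[1],
              (long long)a.strides[0], (long long)a.strides[1]);
  return 0;
}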
69 changes: 69 additions & 0 deletions (unit test file)
@@ -10,6 +10,7 @@
// Specialize air.channel ops in perfectly nested for loops in air.segment with wraps and strides.

#map = affine_map<()[s0] -> (s0 * 32)>
+#map1 = affine_map<()[s0, s1] -> (s0 + s1)>
module {

  // CHECK-LABEL: test0
@@ -296,4 +297,72 @@ module {
    air.channel.put @channel_23[] (%arg3[%c1, %c0, %c0, %c0, %c0, %c0] [%c1, %c1, %c4, %c8, %c4, %c8] [%c1024, %c1024, %c8, %c128, %c32, %c1]) : (memref<2x1x32x32xi32, 1 : i32>)
    return
  }

  // scf.parallel loop specialization; specialization of affine.apply on induction vars.
  // CHECK-LABEL: test10

  // CHECK: air.channel.put async [%{{.*}}] @channel_24[%c0, %c0] (%{{.*}}[%c0, %c0, %c0, %c0, %c0] [%c4, %c3, %c3, %c8, %c4] [%c288, %c6, %c1, %c36, %c1]) : (memref<1x32x6x6xi32, 1>)
  // CHECK: air.channel.put async [%{{.*}}] @channel_24[%c1, %c0] (%{{.*}}[%c0, %c0, %c0, %c0, %c6] [%c4, %c3, %c3, %c8, %c4] [%c288, %c6, %c1, %c36, %c1]) : (memref<1x32x6x6xi32, 1>)
  // CHECK: air.channel.put async [%{{.*}}] @channel_24[%c2{{.*}}, %c0] (%{{.*}}[%c0, %c0, %c0, %c0, %c12] [%c4, %c3, %c3, %c8, %c4] [%c288, %c6, %c1, %c36, %c1]) : (memref<1x32x6x6xi32, 1>)
  // CHECK: air.channel.put async [%{{.*}}] @channel_24[%c3, %c0] (%{{.*}}[%c0, %c0, %c0, %c0, %c18] [%c4, %c3, %c3, %c8, %c4] [%c288, %c6, %c1, %c36, %c1]) : (memref<1x32x6x6xi32, 1>)
  // CHECK: air.channel.get async [%{{.*}}] @channel_25[%c0, %c0] (%{{.*}}[%c0, %c0] [%c4, %c4] [%c16, %c1]) : (memref<1x4x4x4xi32, 1>)
  // CHECK: air.channel.get async [%{{.*}}] @channel_25[%c1, %c0] (%{{.*}}[%c0, %c4] [%c4, %c4] [%c16, %c1]) : (memref<1x4x4x4xi32, 1>)
  // CHECK: air.channel.get async [%{{.*}}] @channel_25[%c2{{.*}}, %c0] (%{{.*}}[%c0, %c8] [%c4, %c4] [%c16, %c1]) : (memref<1x4x4x4xi32, 1>)
  // CHECK: air.channel.get async [%{{.*}}] @channel_25[%c3, %c0] (%{{.*}}[%c0, %c12] [%c4, %c4] [%c16, %c1]) : (memref<1x4x4x4xi32, 1>)

  func.func @test10(%arg0: memref<2x32x6x6xi32>, %arg1: memref<4x32x3x3xi32>, %arg2: memref<2x4x4x4xi32>) {
    %c2 = arith.constant 2 : index
    %0 = air.launch async (%arg3) in (%arg4=%c2) attributes {id = 1 : i32} {
      %1 = air.segment @conv_static_0 async attributes {id = 2 : i32} {
        %c8 = arith.constant 8 : index
        %c3 = arith.constant 3 : index
        %c16 = arith.constant 16 : index
        %c36 = arith.constant 36 : index
        %c6 = arith.constant 6 : index
        %c32 = arith.constant 32 : index
        %c0 = arith.constant 0 : index
        %c1 = arith.constant 1 : index
        %c4 = arith.constant 4 : index
        %async_token, %results = air.execute -> (memref<1x32x6x6xi32, 1>) {
          %alloc = memref.alloc() : memref<1x32x6x6xi32, 1>
          air.execute_terminator %alloc : memref<1x32x6x6xi32, 1>
        }
        %async_token_0, %results_1 = air.execute -> (memref<1x4x4x4xi32, 1>) {
          %alloc = memref.alloc() : memref<1x4x4x4xi32, 1>
          air.execute_terminator %alloc : memref<1x4x4x4xi32, 1>
        }
        %2 = scf.parallel (%arg5) = (%c0) to (%c4) step (%c1) init (%async_token) -> !air.async.token {
          %5 = scf.for %arg6 = %c0 to %c32 step %c8 iter_args(%arg7 = %async_token) -> (!air.async.token) {
            %6 = scf.for %arg8 = %c0 to %c3 step %c1 iter_args(%arg9 = %arg7) -> (!air.async.token) {
              %7 = scf.for %arg10 = %c0 to %c3 step %c1 iter_args(%arg11 = %arg9) -> (!air.async.token) {
                %async_token_2, %results_3 = air.execute [%arg11] -> (index) {
                  %9 = affine.apply #map1()[%arg5, %arg8]
                  air.execute_terminator %9 : index
                }
                %8 = air.channel.put async [%async_token_2] @channel_24[%arg5, %c0] (%results[%arg6, %results_3, %arg10] [%c8, %c1, %c4] [%c36, %c6, %c1]) : (memref<1x32x6x6xi32, 1>)
                scf.yield %8 : !air.async.token
              }
              scf.yield %7 : !air.async.token
            }
            scf.yield %6 : !air.async.token
          }
          scf.reduce(%5 : !air.async.token) {
          ^bb0(%arg6: !air.async.token, %arg7: !air.async.token):
            %6 = air.wait_all async [%arg6, %arg7]
            scf.reduce.return %6 : !air.async.token
          }
        }
        %3 = scf.parallel (%arg5) = (%c0) to (%c4) step (%c1) init (%async_token_0) -> !air.async.token {
          %5 = air.channel.get async [%async_token_0] @channel_25[%arg5, %c0] (%results_1[%c0, %arg5, %c0] [%c4, %c1, %c4] [%c16, %c4, %c1]) : (memref<1x4x4x4xi32, 1>)
          scf.reduce(%5 : !air.async.token) {
          ^bb0(%arg6: !air.async.token, %arg7: !air.async.token):
            %6 = air.wait_all async [%arg6, %arg7]
            scf.reduce.return %6 : !air.async.token
          }
        }
        %4 = air.wait_all async [%2, %3]
      }
    }
    return
  }
}
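The CHECK lines above can be sanity-checked by hand. Assuming the usual element-offset convention start = sum(offset[i] * stride[i]), the (%arg5 + %arg8) offset sits on a dimension of stride 6, so after unrolling %arg5 over [0, 4) its contribution lands in a unit-stride offset column as 0, 6, 12, 18:

// check_offsets.cpp -- hand-verifying the specialized test10 put offsets.
#include <cstdio>

int main() {
  const int strideOfIvDim = 6; // stride on the (%arg5 + %arg8) offset dim
  for (int arg5 = 0; arg5 < 4; arg5++)
    std::printf("channel_24[%d, 0]: trailing offset = %d\n", arg5,
                arg5 * strideOfIvDim); // 0, 6, 12, 18, matching the CHECKs
  return 0;
}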
