Skip to content

Commit

Permalink
[Pass] Support ModOp in .outline() (#127)
Browse files Browse the repository at this point in the history
  • Loading branch information
chhzh123 committed Aug 9, 2022
1 parent 6a856ef commit 20964ca
Showing 1 changed file with 49 additions and 7 deletions.
56 changes: 49 additions & 7 deletions lib/Transforms/LoopTransformations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2680,17 +2680,19 @@ LogicalResult runOutline(ModuleOp &mod, func::FuncOp &f, OutlineOp &outlineOp) {
SmallVector<AffineExpr> newExprs;
isDifferent = false;
isParameterized = false;
int targetCst = 1;
if (srcLoadMap != targetLoadMap) {
bool isMod = false;
for (auto item :
llvm::zip(srcLoadMap.getResults(), targetLoadMap.getResults())) {
auto expr = std::get<0>(item);
auto targetExpr = std::get<1>(item);
if (targetLoadMap.getNumSymbols() > 0 &&
targetExpr.isa<AffineBinaryOpExpr>() &&
targetExpr.getKind() == AffineExprKind::Mul) {
targetExpr.getKind() != AffineExprKind::Add) {
int cst = 1;
if (expr.isa<AffineBinaryOpExpr>() &&
expr.getKind() == AffineExprKind::Mul)
expr.getKind() != AffineExprKind::Add)
cst = expr.cast<AffineBinaryOpExpr>()
.getRHS()
.cast<AffineConstantExpr>()
Expand All @@ -2701,18 +2703,32 @@ LogicalResult runOutline(ModuleOp &mod, func::FuncOp &f, OutlineOp &outlineOp) {
allMemrefs.push_back(cstOp);
isParameterized = true;
} else if (expr.isa<AffineBinaryOpExpr>() &&
expr.getKind() == AffineExprKind::Mul) {
expr.getKind() != AffineExprKind::Add) {
auto cst = expr.cast<AffineBinaryOpExpr>()
.getRHS()
.cast<AffineConstantExpr>()
.getValue();
auto cstOp = call_builder.create<arith::ConstantIndexOp>(
targetFunc.getLoc(), cst);
if (targetExpr.isa<AffineBinaryOpExpr>()) {
targetCst = targetExpr.cast<AffineBinaryOpExpr>()
.getRHS()
.cast<AffineConstantExpr>()
.getValue();
}
if (!isDifferent)
allMemrefs.push_back(cstOp);
isDifferent = true;
AffineExpr newExpr = expr.cast<AffineBinaryOpExpr>().getLHS() *
call_builder.getAffineSymbolExpr(0);
AffineExpr newExpr;
if (expr.getKind() == AffineExprKind::Mul)
newExpr = expr.cast<AffineBinaryOpExpr>().getLHS() *
call_builder.getAffineSymbolExpr(0);
else if (expr.getKind() == AffineExprKind::Mod) {
newExpr = expr.cast<AffineBinaryOpExpr>().getLHS() %
call_builder.getAffineSymbolExpr(0);
isMod = true;
} else
assert(false && "Unexpected affine expr kind");
newExprs.push_back(newExpr);
} else {
newExprs.push_back(expr);
Expand All @@ -2725,15 +2741,41 @@ LogicalResult runOutline(ModuleOp &mod, func::FuncOp &f, OutlineOp &outlineOp) {
auto map = AffineMap::get(targetLoadMap.getNumDims(), 1, newExprs,
targetFunc.getContext());
targetLoadOp->setAttr("map", AffineMapAttr::get(map));
if (isMod) {
// See the issue:
// https://github.com/cornell-zhang/hcl-dialect-prototype/issues/127
OpBuilder builder(targetLoadOp);
SmallVector<Value> indices(targetLoadOp.getIndices());
int pos = -1;
for (auto item :
llvm::enumerate(targetLoadOp.getAffineMap().getResults())) {
auto expr = item.value();
if (expr.isa<AffineBinaryOpExpr>() &&
expr.getKind() == AffineExprKind::Mod) {
pos = item.index();
break;
}
}
assert(pos != -1 && "Mod op not found");
auto modOp = builder.create<arith::RemSIOp>(
targetLoadOp.getLoc(), indices[pos],
indices[indices.size() - 1]);
indices.pop_back();
indices[pos] = modOp.getResult();
auto loadOp = builder.create<memref::LoadOp>(
targetLoadOp.getLoc(), targetLoadOp.getMemRef(), indices);
targetLoadOp.getResult().replaceAllUsesWith(loadOp.getResult());
targetLoadOp.erase();
}
}
}
// update previous CallOp
if (isDifferent && !isParameterized) {
for (auto callOp : f.getOps<func::CallOp>()) {
if (callOp.getCallee() == targetFunc.getName()) {
OpBuilder builder(callOp);
auto cst =
builder.create<arith::ConstantIndexOp>(targetFunc.getLoc(), 1);
auto cst = builder.create<arith::ConstantIndexOp>(
targetFunc.getLoc(), targetCst);
callOp->insertOperands(callOp.getNumOperands(), {cst});
}
}
Expand Down

0 comments on commit 20964ca

Please sign in to comment.