Skip to content

Commit

Permalink
lower: properly fix #355
Browse files Browse the repository at this point in the history
This commit properly fixes #355 by ensuring that duplicate locators are
not generated by different codepaths. This bug is masked by the
ir::simplify call which removes the extra locators in most situations.
  • Loading branch information
rohany committed Jan 2, 2022
1 parent 0ede002 commit 126370e
Showing 1 changed file with 27 additions and 13 deletions.
40 changes: 27 additions & 13 deletions src/lower/lowerer_impl_imperative.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,9 @@ Stmt LowererImplImperative::lowerForall(Forall forall)
inParallelLoopDepth++;
}

// Record that we might have some fresh locators that need to be recovered.
std::vector<Iterator> freshLocateIterators;

// Recover any available parents that were not recoverable previously
vector<Stmt> recoverySteps;
for (const IndexVar& varToRecover : provGraph.newlyRecoverableParents(forall.getIndexVar(), definedIndexVars)) {
Expand Down Expand Up @@ -634,17 +637,16 @@ Stmt LowererImplImperative::lowerForall(Forall forall)
// the accessors for those locator variables as part of the recovery process.
// This is necessary after a fuse transformation, for example: If we fuse
// two index variables (i, j) into f, then after we've generated the loop for
// f, all locate accessors for i and j are now available for use.
// f, all locate accessors for i and j are now available for use. So, remember
// that we have some new locate iterators that should be recovered.
std::vector<Iterator> itersForVar;
for (auto& iters : iterators.levelIterators()) {
// Collect all level iterators that have locate and iterate over
// the recovered index variable.
if (iters.second.getIndexVar() == varToRecover && iters.second.hasLocate()) {
itersForVar.push_back(iters.second);
freshLocateIterators.push_back(iters.second);
}
}
// Finally, declare all of the collected iterators' position access variables.
recoverySteps.push_back(this->declLocatePosVars(itersForVar));

// place underived guard
std::vector<ir::Expr> iterBounds = provGraph.deriveIterBounds(varToRecover, definedIndexVarsOrdered, underivedBounds, indexVarToExprMap, iterators);
Expand Down Expand Up @@ -799,7 +801,15 @@ Stmt LowererImplImperative::lowerForall(Forall forall)
}
// Emit dimension coordinate iteration loop
else if (iterator.isDimensionIterator()) {
loops = lowerForallDimension(forall, point.locators(),
// A proper fix to #355. Adding information that those locate iterators are now ready is the
// correct way to recover them, rather than blindly duplicating the emitted locators.
auto locatorsCopy = std::vector<Iterator>(point.locators());
for (auto it : freshLocateIterators) {
if (!util::contains(locatorsCopy, it)) {
locatorsCopy.push_back(it);
}
}
loops = lowerForallDimension(forall, locatorsCopy,
inserters, appenders, reducedAccesses, recoveryStmt);
}
// Emit position iteration loop
Expand Down Expand Up @@ -1772,14 +1782,19 @@ Stmt LowererImplImperative::lowerForallBody(Expr coordinate, IndexStmt stmt,
const set<Access>& reducedAccesses) {
Stmt initVals = resizeAndInitValues(appenders, reducedAccesses);

// Inserter positions
Stmt declInserterPosVars = declLocatePosVars(inserters);

// Locate positions
Stmt declLocatorPosVars = declLocatePosVars(locators);
// There can be overlaps between the inserters and locators, which results in
// duplicate emitting of variable declarations. We'll fix that here.
std::vector<Iterator> itersWithLocators;
for (auto it : inserters) {
if (!util::contains(itersWithLocators, it)) { itersWithLocators.push_back(it); }
}
for (auto it : locators) {
if (!util::contains(itersWithLocators, it)) { itersWithLocators.push_back(it); }
}
auto declPosVars = declLocatePosVars(itersWithLocators);

if (captureNextLocatePos) {
capturedLocatePos = Block::make(declInserterPosVars, declLocatorPosVars);
capturedLocatePos = declPosVars;
captureNextLocatePos = false;
}

Expand All @@ -1792,8 +1807,7 @@ Stmt LowererImplImperative::lowerForallBody(Expr coordinate, IndexStmt stmt,
// TODO: Emit code to insert coordinates

return Block::make(initVals,
declInserterPosVars,
declLocatorPosVars,
declPosVars,
body,
appendCoords);
}
Expand Down

0 comments on commit 126370e

Please sign in to comment.