Skip to content

Commit

Permalink
Merge pull request #362 from rohany/fuse-bug-2
Browse files Browse the repository at this point in the history
lower: fix a bug causing undefined variables when applying fuse
  • Loading branch information
stephenchouca authored Jan 20, 2021
2 parents 864b65d + 6e57653 commit 468ad7f
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 1 deletion.
19 changes: 18 additions & 1 deletion src/lower/lowerer_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,24 @@ Stmt LowererImpl::lowerForall(Forall forall)
Expr recoveredValue = provGraph.recoverVariable(varToRecover, definedIndexVarsOrdered, underivedBounds, indexVarToExprMap, iterators);
taco_iassert(indexVarToExprMap.count(varToRecover));
recoverySteps.push_back(VarDecl::make(indexVarToExprMap[varToRecover], recoveredValue));

// After we've recovered this index variable, some iterators are now
// accessible for use when declaring locator access variables. So, generate
// the accessors for those locator variables as part of the recovery process.
// This is necessary after a fuse transformation, for example: If we fuse
// two index variables (i, j) into f, then after we've generated the loop for
// f, all locate accessors for i and j are now available for use.
std::vector<Iterator> itersForVar;
for (auto& iters : iterators.levelIterators()) {
// Collect all level iterators that have locate and iterate over
// the recovered index variable.
if (iters.second.getIndexVar() == varToRecover && iters.second.hasLocate()) {
itersForVar.push_back(iters.second);
}
}
// Finally, declare all of the collected iterators' position access variables.
recoverySteps.push_back(this->declLocatePosVars(itersForVar));

// place underived guard
std::vector<ir::Expr> iterBounds = provGraph.deriveIterBounds(varToRecover, definedIndexVarsOrdered, underivedBounds, indexVarToExprMap, iterators);
if (forallNeedsUnderivedGuards && underivedBounds.count(varToRecover) &&
Expand Down Expand Up @@ -2275,7 +2293,6 @@ Stmt LowererImpl::declLocatePosVars(vector<Iterator> locators) {
if (locateIterator.isLeaf()) {
break;
}

locateIterator = locateIterator.getChild();
} while (accessibleIterators.contains(locateIterator));
}
Expand Down
48 changes: 48 additions & 0 deletions test/tests-scheduling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,54 @@ TEST(scheduling, splitIndexStmt) {
ASSERT_TRUE(equals(a(i) = b(i), i2Forall.getStmt()));
}

TEST(scheduling, fuseDenseLoops) {
auto dim = 4;
Tensor<int> A("A", {dim, dim, dim}, {Dense, Dense, Dense});
Tensor<int> B("B", {dim, dim, dim}, {Dense, Dense, Dense});
Tensor<int> expected("expected", {dim, dim, dim}, {Dense, Dense, Dense});
IndexVar f("f"), g("g");
for (int i = 0; i < dim; i++) {
for (int j = 0; j < dim; j++) {
for (int k = 0; k < dim; k++) {
A.insert({i, j, k}, i + j + k);
B.insert({i, j, k}, i + j + k);
expected.insert({i, j, k}, 2 * (i + j + k));
}
}
}
A.pack();
B.pack();
expected.pack();

// Helper function to evaluate the target statement and verify the results.
// It takes in a function that applies some scheduling transforms to the
// input IndexStmt, and applies to the point-wise tensor addition below.
// The test is structured this way as TACO does its best to avoid re-compilation
// whenever possible. I.e. changing the stmt that a tensor is compiled with
// doesn't cause compilation to occur again.
auto testFn = [&](std::function<IndexStmt(IndexStmt)> modifier) {
Tensor<int> C("C", {dim, dim, dim}, {Dense, Dense, Dense});
C(i, j, k) = A(i, j, k) + B(i, j, k);
auto stmt = C.getAssignment().concretize();
C.compile(modifier(stmt));
C.evaluate();
ASSERT_TRUE(equals(C, expected)) << endl << C << endl << expected << endl;
};

// First, a sanity check with no transformations.
testFn([](IndexStmt stmt) { return stmt; });
// Next, fuse the outer two loops. This tests the original bug in #355.
testFn([&](IndexStmt stmt) {
return stmt.fuse(i, j, f);
});
// Lastly, fuse all of the loops into a single loop. This ensures that
// locators with a chain of ancestors have all of their dependencies
// generated in a valid ordering.
testFn([&](IndexStmt stmt) {
return stmt.fuse(i, j, f).fuse(f, k, g);
});
}

TEST(scheduling, lowerDenseMatrixMul) {
Tensor<double> A("A", {4, 4}, {Dense, Dense});
Tensor<double> B("B", {4, 4}, {Dense, Dense});
Expand Down
23 changes: 23 additions & 0 deletions test/tests-transpose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,26 @@ TEST(DISABLED_lower, transpose3) {
&reason));
ASSERT_EQ(error::expr_transposition, reason);
}

// denseIterationTranspose tests a dense iteration that contain a transposition
// of one of the tensors.
TEST(lower, denseIterationTranspose) {
auto dim = 4;
Tensor<int> A("A", {dim, dim, dim}, {Dense, Dense, Dense});
Tensor<int> B("B", {dim, dim, dim}, {Dense, Dense, Dense});
Tensor<int> C("C", {dim, dim, dim}, {Dense, Dense, Dense});
Tensor<int> expected("expected", {dim, dim, dim}, {Dense, Dense, Dense});
for (int i = 0; i < dim; i++) {
for (int j = 0; j < dim; j++) {
for (int k = 0; k < dim; k++) {
A.insert({i, j, k}, i + j + k);
B.insert({i, j, k}, i + j + k);
expected.insert({i, j, k}, 2 * (i + j + k));
}
}
}
A.pack(); B.pack(); expected.pack();
C(i, j, k) = A(i, j, k) + B(k, j, i);
C.evaluate();
ASSERT_TRUE(equals(C, expected));
}

0 comments on commit 468ad7f

Please sign in to comment.