From dcc52a73ddf67bbcc29e96cfba4762c544eabe73 Mon Sep 17 00:00:00 2001 From: Rohan Yadav Date: Wed, 13 Jan 2021 13:00:50 -0800 Subject: [PATCH] lower: fix a bug causing undefined variables when applying fuse Fixes #355. This commit fixes a bug where the fuse transformation would not generate necessary locator variables when applied to iteration over two dense variables. --- src/lower/iterator.cpp | 2 +- src/lower/lowerer_impl.cpp | 59 ++++++++++++++++++++++++-------------- test/tests-scheduling.cpp | 50 ++++++++++++++++++++++++++++++++ 3 files changed, 89 insertions(+), 22 deletions(-) diff --git a/src/lower/iterator.cpp b/src/lower/iterator.cpp index 9f5b7dd4b..a064469c8 100644 --- a/src/lower/iterator.cpp +++ b/src/lower/iterator.cpp @@ -349,7 +349,7 @@ std::ostream& operator<<(std::ostream& os, const Iterator& iterator) { if (iterator.isDimensionIterator()) { return os << "\u0394" << iterator.getIndexVar().getName(); } - return os << iterator.getTensor(); + return os << iterator.getTensor() << " " << iterator.getIndexVar(); } diff --git a/src/lower/lowerer_impl.cpp b/src/lower/lowerer_impl.cpp index 17a4dab3b..e7c2b0dce 100644 --- a/src/lower/lowerer_impl.cpp +++ b/src/lower/lowerer_impl.cpp @@ -872,6 +872,8 @@ Stmt LowererImpl::lowerForallDimension(Forall forall, { Expr coordinate = getCoordinateVar(forall.getIndexVar()); + // cout << "Lowering forall dimension? " << forall << endl; + if (forall.getParallelUnit() != ParallelUnit::NotParallel && forall.getOutputRaceStrategy() == OutputRaceStrategy::Atomics) { markAssignsAtomicDepth++; atomicParallelUnit = forall.getParallelUnit(); @@ -2250,35 +2252,50 @@ Stmt LowererImpl::declLocatePosVars(vector locators) { for (Iterator& locator : locators) { accessibleIterators.insert(locator); - bool doLocate = true; + // Pull out some logic for constructing the locators for a given iterator. + auto addLocator = [&](const Iterator& iter) { + ModeFunction locate = iter.locate(coordinates(iter)); + taco_iassert(isValue(locate.getResults()[1], true)); + Stmt declarePosVar = VarDecl::make(iter.getPosVar(), locate.getResults()[0]); + result.push_back(declarePosVar); + }; + + // Look through all of the parent iterators. If any of these iterators + // are not accessible, we need to construct their accessors before emitting + // locator's accessors. This is because locator may use the ancestor's + // variables in its accessors. We add these ancestors into a vector and reverse + // it so that the highest parent in the tree's accessors get declared first. + std::vector ancestors; for (Iterator ancestorIterator = locator.getParent(); !ancestorIterator.isRoot() && ancestorIterator.hasLocate(); ancestorIterator = ancestorIterator.getParent()) { if (!accessibleIterators.contains(ancestorIterator)) { - doLocate = false; + // Since we're going to emit the locators for this iterator, add it to + // accessibleIterators so that other locators with this as an ancestor + // don't do the same. + accessibleIterators.insert(ancestorIterator); + ancestors.push_back(ancestorIterator); } } + for (auto it = ancestors.rbegin(); it != ancestors.rend(); it++) addLocator(*it); - if (doLocate) { - Iterator locateIterator = locator; - if (locateIterator.hasPosIter()) { - taco_iassert(!provGraph.isUnderived(locateIterator.getIndexVar())); - continue; // these will be recovered with separate procedure - } - do { - ModeFunction locate = locateIterator.locate(coordinates(locateIterator)); - taco_iassert(isValue(locate.getResults()[1], true)); - Stmt declarePosVar = VarDecl::make(locateIterator.getPosVar(), - locate.getResults()[0]); - result.push_back(declarePosVar); - - if (locateIterator.isLeaf()) { - break; - } - - locateIterator = locateIterator.getChild(); - } while (accessibleIterators.contains(locateIterator)); + Iterator locateIterator = locator; + // Position iterators will be recovered with a separate procedure, so + // don't emit anything if locator is one. + if (locateIterator.hasPosIter()) { + taco_iassert(!provGraph.isUnderived(locateIterator.getIndexVar())); + continue; } + + // Once all parent locators have been declared, add the target and all + // children locators. + do { + addLocator(locateIterator); + if (locateIterator.isLeaf()) { + break; + } + locateIterator = locateIterator.getChild(); + } while (accessibleIterators.contains(locateIterator)); } return result.empty() ? Stmt() : Block::make(result); } diff --git a/test/tests-scheduling.cpp b/test/tests-scheduling.cpp index 0fa117be3..e66a89a08 100644 --- a/test/tests-scheduling.cpp +++ b/test/tests-scheduling.cpp @@ -72,6 +72,56 @@ TEST(scheduling, splitIndexStmt) { ASSERT_TRUE(equals(a(i) = b(i), i2Forall.getStmt())); } +TEST(scheduling, fuseDenseLoops) { + auto dim = 4; + Tensor A("A", {dim, dim, dim}, {Dense, Dense, Dense}); + Tensor B("B", {dim, dim, dim}, {Dense, Dense, Dense}); + Tensor expected("expected", {dim, dim, dim}, {Dense, Dense, Dense}); + IndexVar f("f"), g("g"); + for (int i = 0; i < dim; i++) { + for (int j = 0; j < dim; j++) { + for (int k = 0; k < dim; k++) { + A.insert({i, j, k}, i + j + k); + B.insert({i, j, k}, i + j + k); + expected.insert({i, j, k}, 2 * (i + j + k)); + } + } + } + A.pack(); + B.pack(); + expected.pack(); + + // Helper function to evaluate the target statement and verify the results. + // It takes in a function that applies some scheduling transforms to the + // input IndexStmt, and applies to the point-wise tensor addition below. + // The test is structured this way as TACO does its best to avoid re-compilation + // whenever possible. I.e. changing the stmt that a tensor is compiled with + // doesn't cause compilation to occur again. + auto testFn = [&](IndexStmt modifier (IndexStmt)) { + Tensor C("C", {dim, dim, dim}, {Dense, Dense, Dense}); + C(i, j, k) = A(i, j, k) + B(i, j, k); + auto stmt = C.getAssignment().concretize(); + C.compile(modifier(stmt)); + C.evaluate(); + ASSERT_TRUE(equals(C, expected)) << endl << C << endl << expected << endl; + }; + + // First, a sanity check with no transformations. + testFn([](IndexStmt stmt) { return stmt; }); + // Next, fuse the outer two loops. This tests the original bug in #355. + testFn([](IndexStmt stmt) { + IndexVar f("f"); + return stmt.fuse(i, j, f); + }); + // Lastly, fuse all of the loops into a single loop. This ensures that + // locators with a chain of ancestors have all of their dependencies + // generated in a valid ordering. + testFn([](IndexStmt stmt) { + IndexVar f("f"), g("g"); + return stmt.fuse(i, j, f).fuse(f, k, g); + }); +} + TEST(scheduling, lowerDenseMatrixMul) { Tensor A("A", {4, 4}, {Dense, Dense}); Tensor B("B", {4, 4}, {Dense, Dense});