diff --git a/include/taco/index_notation/transformations.h b/include/taco/index_notation/transformations.h
index a708ee965..e0bf0797a 100644
--- a/include/taco/index_notation/transformations.h
+++ b/include/taco/index_notation/transformations.h
@@ -140,5 +140,11 @@ IndexStmt parallelizeOuterLoop(IndexStmt stmt);
  */
 IndexStmt reorderLoopsTopologically(IndexStmt stmt);
 
+/**
+ * Insert where statements with temporaries into the following kinds of statements:
+ * 1. The result is scattered into but does not support random insert.
+ */
+IndexStmt insertTemporaries(IndexStmt stmt);
+
 }
 #endif
diff --git a/src/index_notation/transformations.cpp b/src/index_notation/transformations.cpp
index 80249fb64..b09258410 100644
--- a/src/index_notation/transformations.cpp
+++ b/src/index_notation/transformations.cpp
@@ -4,6 +4,7 @@
 #include "taco/index_notation/index_notation_rewriter.h"
 #include "taco/index_notation/index_notation_nodes.h"
 #include "taco/error/error_messages.h"
+#include "taco/util/collections.h"
 
 #include <iostream>
 #include <algorithm>
@@ -14,23 +15,28 @@
 using namespace std;
 
 namespace taco {
+
 // class Transformation
 Transformation::Transformation(Reorder reorder)
     : transformation(new Reorder(reorder)) {
 }
 
+
 Transformation::Transformation(Precompute precompute)
     : transformation(new Precompute(precompute)) {
 }
 
+
 Transformation::Transformation(Parallelize parallelize)
     : transformation(new Parallelize(parallelize)) {
 }
 
+
 IndexStmt Transformation::apply(IndexStmt stmt, string* reason) const {
   return transformation->apply(stmt, reason);
 }
 
+
 std::ostream& operator<<(std::ostream& os, const Transformation& t) {
   t.transformation->print(os);
   return os;
@@ -43,19 +49,23 @@ struct Reorder::Content {
   IndexVar j;
 };
 
+
 Reorder::Reorder(IndexVar i, IndexVar j) : content(new Content) {
   content->i = i;
   content->j = j;
 }
 
+
 IndexVar Reorder::geti() const {
   return content->i;
 }
 
+
 IndexVar Reorder::getj() const {
   return content->j;
 }
 
+
 IndexStmt Reorder::apply(IndexStmt stmt, string* reason) const {
   INIT_REASON(reason);
 
@@ -111,10 +121,12 @@ IndexStmt Reorder::apply(IndexStmt stmt, string* reason) const {
   return ReorderRewriter(*this, reason).reorder(stmt);
 }
 
+
 void Reorder::print(std::ostream& os) const {
   os << "reorder(" << geti() << ", " << getj() << ")";
 }
 
+
 std::ostream& operator<<(std::ostream& os, const Reorder& reorder) {
   reorder.print(os);
   return os;
@@ -129,9 +141,11 @@ struct Precompute::Content {
   TensorVar workspace;
 };
 
+
 Precompute::Precompute() : content(nullptr) {
 }
 
+
 Precompute::Precompute(IndexExpr expr, IndexVar i, IndexVar iw,
                        TensorVar workspace) : content(new Content) {
   content->expr = expr;
@@ -140,22 +154,27 @@ Precompute::Precompute(IndexExpr expr, IndexVar i, IndexVar iw,
   content->workspace = workspace;
 }
 
+
 IndexExpr Precompute::getExpr() const {
   return content->expr;
 }
 
+
 IndexVar Precompute::geti() const {
   return content->i;
 }
 
+
 IndexVar Precompute::getiw() const {
   return content->iw;
 }
 
+
 TensorVar Precompute::getWorkspace() const {
   return content->workspace;
 }
 
+
 static bool containsExpr(Assignment assignment, IndexExpr expr) {
   struct ContainsVisitor : public IndexNotationVisitor {
     using IndexNotationVisitor::visit;
@@ -192,6 +211,7 @@ static bool containsExpr(Assignment assignment, IndexExpr expr) {
   return visitor.contains;
 }
 
+
 static Assignment getAssignmentContainingExpr(IndexStmt stmt, IndexExpr expr) {
   Assignment assignment;
   match(stmt,
@@ -205,6 +225,7 @@ static Assignment getAssignmentContainingExpr(IndexStmt stmt, IndexExpr expr) {
   return assignment;
 }
 
+
 IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const {
   INIT_REASON(reason);
 
@@ -247,15 +268,18 @@ IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const {
   return rewriter.rewrite(stmt);
 }
 
+
 void Precompute::print(std::ostream& os) const {
   os << "precompute(" << getExpr() << ", " << geti() << ", "
      << getiw() << ", " << getWorkspace() << ")";
 }
 
+
 bool Precompute::defined() const {
   return content != nullptr;
 }
 
+
 std::ostream& operator<<(std::ostream& os, const Precompute& precompute) {
   precompute.print(os);
   return os;
@@ -267,17 +291,21 @@ struct Parallelize::Content {
   IndexVar i;
 };
 
+
 Parallelize::Parallelize() : content(nullptr) {
 }
 
+
 Parallelize::Parallelize(IndexVar i) : content(new Content) {
   content->i = i;
 }
 
+
 IndexVar Parallelize::geti() const {
   return content->i;
 }
 
+
 IndexStmt Parallelize::apply(IndexStmt stmt, std::string* reason) const {
   INIT_REASON(reason);
 
@@ -340,10 +368,12 @@ IndexStmt Parallelize::apply(IndexStmt stmt, std::string* reason) const {
   return rewritten;
 }
 
+
 void Parallelize::print(std::ostream& os) const {
   os << "parallelize(" << geti() << ")";
 }
 
+
 std::ostream& operator<<(std::ostream& os, const Parallelize& parallelize) {
   parallelize.print(os);
   return os;
@@ -391,8 +421,10 @@ static vector<pair<IndexVar, bool>> varOrderFromTensorLevels(set<pair<IndexVar,
   }
   return varOrder;
 }
 
-static map<IndexVar, set<IndexVar>> depsFromVarOrders(map<string, vector<pair<IndexVar, bool>>> varOrders) {
+
+static map<IndexVar, set<IndexVar>>
+depsFromVarOrders(map<string, vector<pair<IndexVar, bool>>> varOrders) {
   map<IndexVar, set<IndexVar>> deps;
   for (pair<string, vector<pair<IndexVar, bool>>> varOrderPair : varOrders) {
     vector<pair<IndexVar, bool>> varOrder = varOrderPair.second;
@@ -411,9 +443,10 @@ static map<IndexVar, set<IndexVar>> depsFromVarOrders(map<string, vector<pair<
   return deps;
 }
 
-vector<IndexVar> topologicallySort(map<IndexVar, set<IndexVar>> tensorDeps,
-                                   vector<IndexVar> originalOrder){
+
+static vector<IndexVar>
+topologicallySort(map<IndexVar, set<IndexVar>> tensorDeps,
+                  vector<IndexVar> originalOrder){
   vector<IndexVar> sortedVars;
   unsigned long countVars = originalOrder.size();
   while (sortedVars.size() < countVars) {
@@ -448,6 +481,7 @@ vector<IndexVar> topologicallySort(map<IndexVar, set<IndexVar>> tensorDeps,
   return sortedVars;
 }
 
+
 IndexStmt reorderLoopsTopologically(IndexStmt stmt) {
   // Collect tensorLevelVars which stores the pairs of IndexVar and tensor
   // level that each tensor is accessed at
@@ -554,4 +588,100 @@ IndexStmt reorderLoopsTopologically(IndexStmt stmt) {
   return rewriter.rewrite(stmt);
 }
 
+static bool compare(std::vector<IndexVar> vars1, std::vector<IndexVar> vars2) {
+  return util::all(util::zip(vars1, vars2),
+                   [](const pair<IndexVar, IndexVar>& v) {
+                     return v.first == v.second;
+                   });
+}
+
+// TODO Temporary function to insert workspaces into SpMM kernels
+static IndexStmt optimizeSpMM(IndexStmt stmt) {
+  if (!isa<Forall>(stmt)) {
+    return stmt;
+  }
+  Forall foralli = to<Forall>(stmt);
+  IndexVar i = foralli.getIndexVar();
+
+  if (!isa<Forall>(foralli.getStmt())) {
+    return stmt;
+  }
+  Forall forallk = to<Forall>(foralli.getStmt());
+  IndexVar k = forallk.getIndexVar();
+
+  if (!isa<Forall>(forallk.getStmt())) {
+    return stmt;
+  }
+  Forall forallj = to<Forall>(forallk.getStmt());
+  IndexVar j = forallj.getIndexVar();
+
+  if (!isa<Assignment>(forallj.getStmt())) {
+    return stmt;
+  }
+  Assignment assignment = to<Assignment>(forallj.getStmt());
+
+  if (!isa<Mul>(assignment.getRhs())) {
+    return stmt;
+  }
+  Mul mul = to<Mul>(assignment.getRhs());
+
+  taco_iassert(isa<Access>(assignment.getLhs()));
+  if (!isa<Access>(mul.getA())) {
+    return stmt;
+  }
+  if (!isa<Access>(mul.getB())) {
+    return stmt;
+  }
+
+  Access Aaccess = to<Access>(assignment.getLhs());
+  Access Baccess = to<Access>(mul.getA());
+  Access Caccess = to<Access>(mul.getB());
+
+  if (!compare(Aaccess.getIndexVars(), {i,j}) ||
+      !compare(Baccess.getIndexVars(), {i,k}) ||
+      !compare(Caccess.getIndexVars(), {k,j})) {
+    return stmt;
+  }
+
+  TensorVar A = Aaccess.getTensorVar();
+  TensorVar B = Baccess.getTensorVar();
+  TensorVar C = Caccess.getTensorVar();
+
+  if (A.getFormat() != CSR ||
+      B.getFormat() != CSR ||
+      C.getFormat() != CSR) {
+    return stmt;
+  }
+
+  // It's an SpMM statement so return an optimized SpMM statement
+  TensorVar w("w", Type(Float64, {Dimension()}), dense);
+  return forall(i,
+                where(forall(j,
+                             A(i,j) = w(j)),
+                      forall(k,
+                             forall(j,
+                                    w(j) += B(i,k) * C(k,j)))));
+}
+
+IndexStmt insertTemporaries(IndexStmt stmt)
+{
+  IndexStmt spmm = optimizeSpMM(stmt);
+  if (spmm != stmt) {
+    return spmm;
+  }
+
+  // TODO Implement general workspacing when scattering into sparse result modes
+
+  // Result dimensions that are indexed by free variables dominated by a
+  // reduction variable are scattered into. If any of these are compressed
+  // then we introduce a dense workspace to scatter into instead. The where
+  // statement must push the reduction loop into the producer side, leaving
+  // only the free variable loops on the consumer side.
+
+  //vector<IndexVar> reductionVars = getReductionVars(stmt);
+  //...
+
+  return stmt;
+}
+
 }
diff --git a/test/tests-transformation.cpp b/test/tests-transformation.cpp
index 0d333921e..0ac455ce5 100644
--- a/test/tests-transformation.cpp
+++ b/test/tests-transformation.cpp
@@ -24,10 +24,13 @@ static TensorVar b("b", vectype, Sparse);
 static TensorVar c("c", vectype, Sparse);
 static TensorVar w("w", vectype, denseNew);
 
-static TensorVar A("A", mattype, {Sparse, Sparse});
-static TensorVar B("B", mattype, {Sparse, Sparse});
-static TensorVar C("C", mattype, {Sparse, Sparse});
-static TensorVar D("D", mattype, {denseNew, denseNew});
+static TensorVar A("A", mattype, {Dense, Sparse});
+static TensorVar B("B", mattype, {Dense, Sparse});
+static TensorVar C("C", mattype, {Dense, Sparse});
+static TensorVar D("D", mattype, {Sparse, Sparse});
+static TensorVar E("E", mattype, {Sparse, Sparse});
+static TensorVar F("F", mattype, {Sparse, Sparse});
+static TensorVar G("G", mattype, {denseNew, denseNew});
 static TensorVar W("W", mattype, {denseNew, denseNew});
 
 static TensorVar S("S", tentype, Sparse);
@@ -197,21 +200,20 @@ INSTANTIATE_TEST_CASE_P(precompute, apply,
                  )
 );
 
-INSTANTIATE_TEST_CASE_P(parallelize, precondition,
-    Values(
-        PreconditionTest(Parallelize(i),
-                         forall(i, a(i) = b(i))
-        ),
-        PreconditionTest(Parallelize(i),
-                         forall(i, w(i) = a(i) + b(i))
-        ),
-        PreconditionTest(Parallelize(i),
-                         forall(i, forall(j, w(i) = A(i, j) * B(i, j)))
-        )/*, TODO: add precondition when lowering supports reductions
-        PreconditionTest(Parallelize(j),
-                         forall(i, forall(j, w(j) = W(i, j)))
-        )*/
-    )
+INSTANTIATE_TEST_CASE_P(parallelize, precondition, Values(
+  PreconditionTest(Parallelize(i),
+                   forall(i, a(i) = b(i))),
+
+  PreconditionTest(Parallelize(i),
+                   forall(i, w(i) = a(i) + b(i))),
+
+  PreconditionTest(Parallelize(i),
+                   forall(i, forall(j, w(i) = D(i, j) * E(i, j)))))
+
+  /*, TODO: add precondition when lowering supports reductions
+  PreconditionTest(Parallelize(j),
+                   forall(i, forall(j, w(j) = W(i, j)))
+  )*/
 );
 
 INSTANTIATE_TEST_CASE_P(parallelize, apply,
@@ -252,17 +254,17 @@ INSTANTIATE_TEST_CASE_P(misc, reorderLoopsTopologically, Values(
     NotationTest(forall(j, forall(i, W(i,j) = A(i,j))),
                  forall(i, forall(j, W(i,j) = A(i,j)))),
 
-    NotationTest(forall(j, forall(i, W(i,j) = D(i,j))),
-                 forall(j, forall(i, W(i,j) = D(i,j)))),
+    NotationTest(forall(j, forall(i, W(i,j) = G(i,j))),
+                 forall(j, forall(i, W(i,j) = G(i,j)))),
 
-    NotationTest(forall(i, forall(j, W(j,i) = D(i,j))),
-                 forall(i, forall(j, W(j,i) = D(i,j)))),
+    NotationTest(forall(i, forall(j, W(j,i) = G(i,j))),
+                 forall(i, forall(j, W(j,i) = G(i,j)))),
 
-    NotationTest(forall(j, forall(i, A(i,j) = D(i,j))),
-                 forall(i, forall(j, A(i,j) = D(i,j)))),
+    NotationTest(forall(j, forall(i, A(i,j) = G(i,j))),
+                 forall(i, forall(j, A(i,j) = G(i,j)))),
 
-    NotationTest(forall(j, forall(i, W(i,j) = D(i,j) + A(i, j))),
-                 forall(i, forall(j, W(i,j) = D(i,j) + A(i, j)))),
+    NotationTest(forall(j, forall(i, W(i,j) = G(i,j) + A(i, j))),
+                 forall(i, forall(j, W(i,j) = G(i,j) + A(i, j)))),
 
     NotationTest(forall(i, forall(j, forall(k, X(i,j,k) = V(i,j,k)))),
                  forall(i, forall(j, forall(k, X(i,j,k) = V(i,j,k))))),
@@ -283,6 +285,27 @@ INSTANTIATE_TEST_CASE_P(misc, reorderLoopsTopologically, Values(
                                     A(i,j) += B(i,k) * C(k,j)))))
 ));
 
+
+struct insertTemporaries : public TestWithParam<NotationTest> {};
+
+TEST_P(insertTemporaries, test) {
+  IndexStmt actual = taco::insertTemporaries(GetParam().actual);
+  ASSERT_NOTATION_EQ(GetParam().expected, actual);
+}
+
+INSTANTIATE_TEST_CASE_P(spmm, insertTemporaries, Values(
+  NotationTest(forall(i,
+                      forall(k,
+                             forall(j,
+                                    A(i,j) += B(i,k) * C(k,j)))),
+               forall(i,
+                      where(forall(j,
+                                   A(i,j) = w(j)),
+                            forall(k,
+                                   forall(j,
+                                          w(j) += B(i,k) * C(k,j))))))
+));
+
 /*
 TEST(schedule, workspace_spmspm) {
   TensorBase A("A", Float(64), {3,3}, Format({Dense,Sparse}));