From 535d777ac4cd074438129ec45fa6362c8180ae7c Mon Sep 17 00:00:00 2001 From: Rawn Date: Mon, 13 Jan 2020 02:59:15 -0500 Subject: [PATCH 01/27] IndexVars can be used in computation. Need to add more tests and allow propagation of indexVar type to generated code --- include/taco/index_notation/index_notation.h | 52 +++++++------- .../index_notation/index_notation_nodes.h | 22 ++++++ .../index_notation/index_notation_printer.h | 1 + .../index_notation/index_notation_rewriter.h | 2 + .../index_notation/index_notation_visitor.h | 4 ++ include/taco/lower/lowerer_impl.h | 3 + src/index_notation/index_notation.cpp | 67 +++++++++++++++++-- src/index_notation/index_notation_nodes.cpp | 22 ++++++ src/index_notation/index_notation_printer.cpp | 4 ++ .../index_notation_rewriter.cpp | 18 +++++ src/index_notation/index_notation_visitor.cpp | 3 + src/index_notation/transformations.cpp | 3 +- src/lower/expr_tools.cpp | 9 +++ src/lower/iterator.cpp | 21 +++--- src/lower/lowerer_impl.cpp | 5 ++ src/lower/merge_lattice.cpp | 5 ++ test/tests-expr.cpp | 19 ++++++ test/tests-indexexpr.cpp | 9 +++ 18 files changed, 228 insertions(+), 41 deletions(-) diff --git a/include/taco/index_notation/index_notation.h b/include/taco/index_notation/index_notation.h index 2debd3c43..cd113b938 100644 --- a/include/taco/index_notation/index_notation.h +++ b/include/taco/index_notation/index_notation.h @@ -47,6 +47,7 @@ struct DivNode; struct CastNode; struct CallIntrinsicNode; struct ReductionNode; +struct IndexVarNode; struct AssignmentNode; struct YieldNode; @@ -415,6 +416,31 @@ class CallIntrinsic : public IndexExpr { typedef CallIntrinsicNode Node; }; + +/// Index variables are used to index into tensors in index expressions, and +/// they represent iteration over the tensor modes they index into. +/// Index variables can also be used in computation +class IndexVar : public IndexExpr { +public: + IndexVar(); + IndexVar(const std::string& name); + IndexVar(const std::string& name, const Datatype& type); + IndexVar(const IndexVarNode *); + + /// Returns the name of the index variable. + std::string getName() const; + + // Need these to overshadow the comparisons in for the IndexExpr instrusive pointer + friend bool operator==(const IndexVar&, const IndexVar&); + friend bool operator<(const IndexVar&, const IndexVar&); + friend bool operator!=(const IndexVar&, const IndexVar&); + friend bool operator>=(const IndexVar&, const IndexVar&); + friend bool operator<=(const IndexVar&, const IndexVar&); + friend bool operator>(const IndexVar&, const IndexVar&); + + typedef IndexVarNode Node; +}; + /// Create calls to various intrinsics. IndexExpr mod(IndexExpr, IndexExpr); IndexExpr abs(IndexExpr); @@ -792,29 +818,6 @@ class Multi : public IndexStmt { /// Create a multi index statement. Multi multi(IndexStmt stmt1, IndexStmt stmt2); -/// Index variables are used to index into tensors in index expressions, and -/// they represent iteration over the tensor modes they index into. -class IndexVar : public util::Comparable { -public: - IndexVar(); - IndexVar(const std::string& name); - - /// Returns the name of the index variable. - std::string getName() const; - - friend bool operator==(const IndexVar&, const IndexVar&); - friend bool operator<(const IndexVar&, const IndexVar&); - - -private: - struct Content; - std::shared_ptr content; -}; - -struct IndexVar::Content { - std::string name; -}; - std::ostream& operator<<(std::ostream&, const IndexVar&); /// A suchthat statement provides a set of IndexVarRel that constrain @@ -916,7 +919,8 @@ bool isEinsumNotation(IndexStmt, std::string* reason=nullptr); bool isReductionNotation(IndexStmt, std::string* reason=nullptr); /// Check whether the statement is in the concrete index notation dialect. -/// This means every index variable has a forall node, there are no reduction +/// This means every index variable has a forall node, each index variable used +/// for computation is under a forall node for that variable, there are no reduction /// nodes, and that every reduction variable use is nested inside a compound /// assignment statement. You can optionally pass in a pointer to a string /// that the reason why it is not concrete notation is printed to. diff --git a/include/taco/index_notation/index_notation_nodes.h b/include/taco/index_notation/index_notation_nodes.h index 95439cd6b..bde1b709a 100644 --- a/include/taco/index_notation/index_notation_nodes.h +++ b/include/taco/index_notation/index_notation_nodes.h @@ -5,6 +5,7 @@ #include #include "taco/type.h" +#include "taco/util/comparable.h" #include "taco/index_notation/index_notation.h" #include "taco/index_notation/index_notation_nodes_abstract.h" #include "taco/index_notation/index_notation_visitor.h" @@ -185,6 +186,27 @@ struct ReductionNode : public IndexExprNode { IndexExpr a; }; +struct IndexVarNode : public IndexExprNode, public util::Comparable { + IndexVarNode() = delete; + IndexVarNode(const std::string& name, const Datatype& type); + + void accept(IndexExprVisitorStrict* v) const { + v->visit(this); + } + + std::string getName() const; + + friend bool operator==(const IndexVarNode& a, const IndexVarNode& b); + friend bool operator<(const IndexVarNode& a, const IndexVarNode& b); + +private: + struct Content; + std::shared_ptr content; +}; + +struct IndexVarNode::Content { + std::string name; +}; // Index Statements struct AssignmentNode : public IndexStmtNode { diff --git a/include/taco/index_notation/index_notation_printer.h b/include/taco/index_notation/index_notation_printer.h index 61b080fc0..ed2dd7abb 100644 --- a/include/taco/index_notation/index_notation_printer.h +++ b/include/taco/index_notation/index_notation_printer.h @@ -27,6 +27,7 @@ class IndexNotationPrinter : public IndexNotationVisitorStrict { void visit(const CastNode*); void visit(const CallIntrinsicNode*); void visit(const ReductionNode*); + void visit(const IndexVarNode*); // Tensor Expressions void visit(const AssignmentNode*); diff --git a/include/taco/index_notation/index_notation_rewriter.h b/include/taco/index_notation/index_notation_rewriter.h index 3551aac5e..a4e340b0f 100644 --- a/include/taco/index_notation/index_notation_rewriter.h +++ b/include/taco/index_notation/index_notation_rewriter.h @@ -34,6 +34,7 @@ class IndexExprRewriterStrict : public IndexExprVisitorStrict { virtual void visit(const CastNode* op) = 0; virtual void visit(const CallIntrinsicNode* op) = 0; virtual void visit(const ReductionNode* op) = 0; + virtual void visit(const IndexVarNode* op) = 0; }; @@ -95,6 +96,7 @@ class IndexNotationRewriter : public IndexNotationRewriterStrict { virtual void visit(const CastNode* op); virtual void visit(const CallIntrinsicNode* op); virtual void visit(const ReductionNode* op); + virtual void visit(const IndexVarNode* op); virtual void visit(const AssignmentNode* op); virtual void visit(const YieldNode* op); diff --git a/include/taco/index_notation/index_notation_visitor.h b/include/taco/index_notation/index_notation_visitor.h index 97a70adc2..9e3622289 100644 --- a/include/taco/index_notation/index_notation_visitor.h +++ b/include/taco/index_notation/index_notation_visitor.h @@ -24,6 +24,7 @@ struct CallIntrinsicNode; struct UnaryExprNode; struct BinaryExprNode; struct ReductionNode; +struct IndexVarNode; struct AssignmentNode; struct YieldNode; @@ -52,6 +53,7 @@ class IndexExprVisitorStrict { virtual void visit(const CastNode*) = 0; virtual void visit(const CallIntrinsicNode*) = 0; virtual void visit(const ReductionNode*) = 0; + virtual void visit(const IndexVarNode*) = 0; }; class IndexStmtVisitorStrict { @@ -100,6 +102,7 @@ class IndexNotationVisitor : public IndexNotationVisitorStrict { virtual void visit(const UnaryExprNode* node); virtual void visit(const BinaryExprNode* node); virtual void visit(const ReductionNode* node); + virtual void visit(const IndexVarNode* node); // Index Statments virtual void visit(const AssignmentNode* node); @@ -169,6 +172,7 @@ class Matcher : public IndexNotationVisitor { RULE(BinaryExprNode) RULE(UnaryExprNode) + RULE(IndexVarNode) RULE(AssignmentNode) RULE(YieldNode) diff --git a/include/taco/lower/lowerer_impl.h b/include/taco/lower/lowerer_impl.h index 589c192ed..66fa5330d 100644 --- a/include/taco/lower/lowerer_impl.h +++ b/include/taco/lower/lowerer_impl.h @@ -204,6 +204,9 @@ class LowererImpl : public util::Uncopyable { /// Lower an intrinsic function call expression. virtual ir::Expr lowerCallIntrinsic(CallIntrinsic call); + /// Lower an IndexVar expression + virtual ir::Expr lowerIndexVar(IndexVar var); + /// Lower a concrete index variable statement. ir::Stmt lower(IndexStmt stmt); diff --git a/src/index_notation/index_notation.cpp b/src/index_notation/index_notation.cpp index 5c2a31217..8c419752b 100644 --- a/src/index_notation/index_notation.cpp +++ b/src/index_notation/index_notation.cpp @@ -124,6 +124,21 @@ struct Equals : public IndexNotationVisitorStrict { using IndexNotationVisitorStrict::visit; + void visit(const IndexVarNode* anode) { + if(!isa(bExpr.ptr)) { + eq = false; + return; + } + + auto bnode = to(bExpr.ptr); + if(anode != bnode) { + eq = false; + return; + } + + eq = true; + } + void visit(const AccessNode* anode) { if (!isa(bExpr.ptr)) { eq = false; @@ -1419,20 +1434,47 @@ template <> SuchThat to(IndexStmt s) { // class IndexVar IndexVar::IndexVar() : IndexVar(util::uniqueName('i')) {} -IndexVar::IndexVar(const std::string& name) : content(new Content) { - content->name = name; +IndexVar::IndexVar(const std::string& name) : IndexVar(name, Datatype::Int32) {} + +IndexVar::IndexVar(const std::string& name, const Datatype& type) : IndexVar(new IndexVarNode(name, type)) {} + +IndexVar::IndexVar(const IndexVarNode* n) : IndexExpr(n) {} + +template <> bool isa(IndexExpr e) { + return isa(e.ptr); +} + +template <> IndexVar to(IndexExpr e) { + taco_iassert(isa(e)); + return IndexVar(to(e.ptr)); } std::string IndexVar::getName() const { - return content->name; + return getNode(*this)->getName(); } bool operator==(const IndexVar& a, const IndexVar& b) { - return a.content == b.content; + return *getNode(a) == *getNode(b); } bool operator<(const IndexVar& a, const IndexVar& b) { - return a.content < b.content; + return *getNode(a) < *getNode(b); +} + +bool operator!=(const IndexVar& a , const IndexVar& b) { + return *getNode(a) != *getNode(b); +} + +bool operator>=(const IndexVar& a, const IndexVar& b) { + return *getNode(a) >= *getNode(b); +} + +bool operator<=(const IndexVar& a, const IndexVar& b) { + return *getNode(a) <= *getNode(b); +} + +bool operator>(const IndexVar& a , const IndexVar& b) { + return *getNode(a) > *getNode(b); } std::ostream& operator<<(std::ostream& os, const IndexVar& var) { @@ -1709,13 +1751,22 @@ bool isConcreteNotation(IndexStmt stmt, std::string* reason) { std::function([&](const AccessNode* op) { for (auto& var : op->indexVars) { // non underived variables may appear in temporaries, but we don't check these - if (!boundVars.contains(var) && provGraph.isUnderived(var) && (provGraph.isFullyDerived(var) || !provGraph.isRecoverable(var, definedVars))) { + if (!boundVars.contains(var) && provGraph.isUnderived(var) && + (provGraph.isFullyDerived(var) || !provGraph.isRecoverable(var, definedVars))) { *reason = "all variables in concrete notation must be bound by a " "forall statement"; isConcrete = false; } } }), + std::function([&](const IndexVarNode* op) { + IndexVar var(op); + if (!boundVars.contains(var) && provGraph.isUnderived(var) && + (provGraph.isFullyDerived(var) || !provGraph.isRecoverable(var, definedVars))) { + *reason = "index variables used in compute statements must be nested under a forall"; + isConcrete = false; + } + }), std::function([&](const WhereNode* op, Matcher* ctx) { bool alreadyInProducer = inWhereProducer; inWhereProducer = true; @@ -2154,6 +2205,10 @@ struct Zero : public IndexNotationRewriterStrict { expr = op; } + void visit(const IndexVarNode* op) { + expr = op; + } + template IndexExpr visitUnaryOp(const T *op) { IndexExpr a = rewrite(op->a); diff --git a/src/index_notation/index_notation_nodes.cpp b/src/index_notation/index_notation_nodes.cpp index bc67e9218..813e8a4e8 100644 --- a/src/index_notation/index_notation_nodes.cpp +++ b/src/index_notation/index_notation_nodes.cpp @@ -36,4 +36,26 @@ ReductionNode::ReductionNode(IndexExpr op, IndexVar var, IndexExpr a) taco_iassert(isa(op.ptr)); } +IndexVarNode::IndexVarNode(const std::string& name, const Datatype& type) + : IndexExprNode(type), content(new Content) { + + if (!type.isInt() && !type.isUInt()) { + taco_not_supported_yet << ". IndexVars must be integral type."; + } + + content->name = name; +} + +std::string IndexVarNode::getName() const { + return content->name; +} + +bool operator==(const IndexVarNode& a, const IndexVarNode& b) { + return a.content->name == b.content->name; +} + +bool operator<(const IndexVarNode& a, const IndexVarNode& b) { + return a.content->name < b.content->name; +} + } diff --git a/src/index_notation/index_notation_printer.cpp b/src/index_notation/index_notation_printer.cpp index 58305077e..96501fec6 100644 --- a/src/index_notation/index_notation_printer.cpp +++ b/src/index_notation/index_notation_printer.cpp @@ -25,6 +25,10 @@ void IndexNotationPrinter::visit(const AccessNode* op) { } } +void IndexNotationPrinter::visit(const IndexVarNode* op) { + os << op->getName(); +} + void IndexNotationPrinter::visit(const LiteralNode* op) { switch (op->getDataType().getKind()) { case Datatype::Bool: diff --git a/src/index_notation/index_notation_rewriter.cpp b/src/index_notation/index_notation_rewriter.cpp index 954a1c78f..0a8570a9b 100644 --- a/src/index_notation/index_notation_rewriter.cpp +++ b/src/index_notation/index_notation_rewriter.cpp @@ -42,6 +42,10 @@ void IndexNotationRewriter::visit(const AccessNode* op) { expr = op; } +void IndexNotationRewriter::visit(const IndexVarNode* op) { + expr = op; +} + template IndexExpr visitUnaryOp(const T *op, IndexNotationRewriter *rw) { IndexExpr a = rw->rewrite(op->a); @@ -246,6 +250,10 @@ struct ReplaceRewriter : public IndexNotationRewriter { SUBSTITUTE_EXPR; } + void visit(const IndexVarNode* op) { + SUBSTITUTE_EXPR; + } + void visit(const LiteralNode* op) { SUBSTITUTE_EXPR; } @@ -334,6 +342,16 @@ struct ReplaceIndexVars : public IndexNotationRewriter { } } + void visit(const IndexVarNode* op) { + + IndexVar var(op); + if(util::contains(substitutions, var)) { + expr = substitutions.at(var); + } else { + expr = var; + } + } + // TODO: Replace in assignments }; diff --git a/src/index_notation/index_notation_visitor.cpp b/src/index_notation/index_notation_visitor.cpp index ec77d99ba..e9b1f952d 100644 --- a/src/index_notation/index_notation_visitor.cpp +++ b/src/index_notation/index_notation_visitor.cpp @@ -35,6 +35,9 @@ IndexNotationVisitor::~IndexNotationVisitor() { void IndexNotationVisitor::visit(const AccessNode* op) { } +void IndexNotationVisitor::visit(const IndexVarNode *op) { +} + void IndexNotationVisitor::visit(const LiteralNode* op) { } diff --git a/src/index_notation/transformations.cpp b/src/index_notation/transformations.cpp index 752416955..a7cf6f857 100644 --- a/src/index_notation/transformations.cpp +++ b/src/index_notation/transformations.cpp @@ -236,9 +236,10 @@ IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const { TensorVar ws = precompute.getWorkspace(); IndexExpr e = precompute.getExpr(); IndexVar iw = precompute.getiw(); + const std::map index_var_substitutions = {{i,iw}}; IndexStmt consumer = forall(i, replace(s, {{e, ws(i)}})); - IndexStmt producer = forall(iw, ws(iw) = replace(e, {{i,iw}})); + IndexStmt producer = forall(iw, ws(iw) = replace(e, index_var_substitutions)); Where where(consumer, producer); stmt = where; diff --git a/src/lower/expr_tools.cpp b/src/lower/expr_tools.cpp index ded5c53dd..d34a79703 100644 --- a/src/lower/expr_tools.cpp +++ b/src/lower/expr_tools.cpp @@ -189,6 +189,15 @@ class SubExprVisitor : public IndexExprVisitorStrict { using IndexExprVisitorStrict::visit; + void visit(const IndexVarNode* op) { + IndexVar var(op); + if (util::contains(vars, var)) { + subExpr = op; + return; + } + subExpr = IndexExpr(); + } + void visit(const AccessNode* op) { // If any variable is in the set of index variables, then the expression // has not been emitted at a previous level, so we keep it. diff --git a/src/lower/iterator.cpp b/src/lower/iterator.cpp index 6c86cc434..d90b3a248 100644 --- a/src/lower/iterator.cpp +++ b/src/lower/iterator.cpp @@ -41,12 +41,12 @@ Iterator::Iterator(IndexVar indexVar) : Iterator(indexVar, true) { Iterator::Iterator(IndexVar indexVar, bool isFull) : content(new Content) { content->indexVar = indexVar; - content->coordVar = Var::make(indexVar.getName(), Int()); - content->posVar = Var::make(indexVar.getName() + "_pos", Int()); + content->coordVar = Var::make(indexVar.getName(), indexVar.getDataType()); + content->posVar = Var::make(indexVar.getName() + "_pos", indexVar.getDataType()); if (!isFull) { - content->beginVar = Var::make(indexVar.getName() + "_begin", Int()); - content->endVar = Var::make(indexVar.getName() + "_end", Int()); + content->beginVar = Var::make(indexVar.getName() + "_begin", indexVar.getDataType()); + content->endVar = Var::make(indexVar.getName() + "_end", indexVar.getDataType()); } } @@ -72,12 +72,12 @@ Iterator::Iterator(IndexVar indexVar, Expr tensor, Mode mode, Iterator parent, if (useNameForPos) { posNamePrefix = name; } - content->posVar = Var::make(name, Int()); - content->endVar = Var::make("p" + modeName + "_end", Int()); - content->beginVar = Var::make("p" + modeName + "_begin", Int()); + content->posVar = Var::make(name, indexVar.getDataType()); + content->endVar = Var::make("p" + modeName + "_end", indexVar.getDataType()); + content->beginVar = Var::make("p" + modeName + "_begin", indexVar.getDataType()); - content->coordVar = Var::make(name, Int()); - content->segendVar = Var::make(modeName + "_segend", Int()); + content->coordVar = Var::make(name, indexVar.getDataType()); + content->segendVar = Var::make(modeName + "_segend", indexVar.getDataType()); content->validVar = Var::make("v" + modeName, Bool); } @@ -400,6 +400,7 @@ Iterators::Iterators(IndexStmt stmt, const map& tensorVars) { ProvenanceGraph provGraph = ProvenanceGraph(stmt); set underivedAdded; + set computeVars; // Create dimension iterators match(stmt, function([&](auto n, auto m) { @@ -424,7 +425,7 @@ Iterators::Iterators(IndexStmt stmt, const map& tensorVars) }), function([&](auto n, auto m) { m->match(n->rhs); - m->match(n->lhs); + m->match(n->lhs); }) ); diff --git a/src/lower/lowerer_impl.cpp b/src/lower/lowerer_impl.cpp index 4a08a02e1..7ff648477 100644 --- a/src/lower/lowerer_impl.cpp +++ b/src/lower/lowerer_impl.cpp @@ -59,6 +59,7 @@ class LowererImpl::Visitor : public IndexNotationVisitorStrict { void visit(const ReductionNode* node) { taco_ierror << "Reduction nodes not supported in concrete index notation"; } + void visit(const IndexVarNode* node) { expr = impl->lowerIndexVar(node); } }; LowererImpl::LowererImpl() : visitor(new Visitor(this)) { @@ -1371,6 +1372,10 @@ Expr LowererImpl::lowerAccess(Access access) { : getReducedValueVar(access); } +Expr LowererImpl::lowerIndexVar(IndexVar var) { + return indexVarToExprMap.at(var); +} + Expr LowererImpl::lowerLiteral(Literal literal) { switch (literal.getDataType().getKind()) { diff --git a/src/lower/merge_lattice.cpp b/src/lower/merge_lattice.cpp index 5c657dbca..9cef143e0 100644 --- a/src/lower/merge_lattice.cpp +++ b/src/lower/merge_lattice.cpp @@ -77,6 +77,11 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict { return MergeLattice({MergePoint({iterators.modeIterator(i)}, {}, {})}); } + void visit(const IndexVarNode* varNode) { + IndexVar var(varNode); + lattice = MergeLattice({MergePoint({iterators.modeIterator(var)}, {}, {})}); + } + void visit(const AccessNode* access) { if (util::contains(latticesOfTemporaries, access->tensorVar)) { diff --git a/test/tests-expr.cpp b/test/tests-expr.cpp index 9877e7c5c..3fccdb172 100644 --- a/test/tests-expr.cpp +++ b/test/tests-expr.cpp @@ -160,3 +160,22 @@ TEST(expr, redefine) { a.evaluate(); ASSERT_EQ(a.begin()->second, 42.0); } + +TEST(expr, indexVarSimple) { + Tensor a("actual", {3,3}, Dense); + a(i, j) = i + j; + a.compile(); + a.assemble(); + a.compute(); + + Tensor expected("expected", a.getDimensions(), Dense); + + for(int i = 0; i < a.getDimensions()[0]; ++i) { + for(int j = 0; j < a.getDimensions()[1]; ++j) { + expected.insert({i, j}, i+j); + } + } + expected.pack(); + + ASSERT_TENSOR_EQ(expected, a); +} \ No newline at end of file diff --git a/test/tests-indexexpr.cpp b/test/tests-indexexpr.cpp index 72d28c5b4..cdca2b2eb 100644 --- a/test/tests-indexexpr.cpp +++ b/test/tests-indexexpr.cpp @@ -73,3 +73,12 @@ TEST(indexexpr, div) { ASSERT_TRUE(equals(div.getA(), b(i))); ASSERT_TRUE(equals(div.getB(), Literal(2))); } + +TEST(indexexpr, indexvar) { + IndexExpr expr = i; + ASSERT_TRUE(isa(expr)); + ASSERT_TRUE(isa(expr.ptr)); + IndexVar var = to(expr); + ASSERT_EQ(type(), var.getDataType()); + ASSERT_EQ("i", var.getName()); +} From 3c4560f2292ba9c8b1e42a42ec3b0fd747b93eb6 Mon Sep 17 00:00:00 2001 From: Rawn Date: Tue, 28 Jan 2020 17:17:39 -0500 Subject: [PATCH 02/27] Added tests --- src/lower/iterator.cpp | 6 +++- src/lower/lowerer_impl.cpp | 2 ++ test/tests-expr.cpp | 29 ++++++++++++++++-- test/tests-scheduling.cpp | 63 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 96 insertions(+), 4 deletions(-) diff --git a/src/lower/iterator.cpp b/src/lower/iterator.cpp index d90b3a248..f357e202c 100644 --- a/src/lower/iterator.cpp +++ b/src/lower/iterator.cpp @@ -404,7 +404,8 @@ Iterators::Iterators(IndexStmt stmt, const map& tensorVars) // Create dimension iterators match(stmt, function([&](auto n, auto m) { - content->modeIterators.insert({n->indexVar, Iterator(n->indexVar, !provGraph.hasCoordBounds(n->indexVar) && provGraph.isCoordVariable(n->indexVar))}); + content->modeIterators.insert({n->indexVar, Iterator(n->indexVar, !provGraph.hasCoordBounds(n->indexVar) + && provGraph.isCoordVariable(n->indexVar))}); for (const IndexVar& underived : provGraph.getUnderivedAncestors(n->indexVar)) { if (!underivedAdded.count(underived)) { content->modeIterators.insert({underived, underived}); @@ -412,6 +413,9 @@ Iterators::Iterators(IndexStmt stmt, const map& tensorVars) } } m->match(n->stmt); + }), + function([&](const IndexVarNode* var) { + }) ); diff --git a/src/lower/lowerer_impl.cpp b/src/lower/lowerer_impl.cpp index 7ff648477..efbe7c39f 100644 --- a/src/lower/lowerer_impl.cpp +++ b/src/lower/lowerer_impl.cpp @@ -1373,6 +1373,8 @@ Expr LowererImpl::lowerAccess(Access access) { } Expr LowererImpl::lowerIndexVar(IndexVar var) { + taco_iassert(util::contains(indexVarToExprMap, var)); + taco_iassert(provGraph.isRecoverable(var, definedIndexVars)); return indexVarToExprMap.at(var); } diff --git a/test/tests-expr.cpp b/test/tests-expr.cpp index 3fccdb172..537615800 100644 --- a/test/tests-expr.cpp +++ b/test/tests-expr.cpp @@ -164,9 +164,7 @@ TEST(expr, redefine) { TEST(expr, indexVarSimple) { Tensor a("actual", {3,3}, Dense); a(i, j) = i + j; - a.compile(); - a.assemble(); - a.compute(); + a.evaluate(); Tensor expected("expected", a.getDimensions(), Dense); @@ -177,5 +175,30 @@ TEST(expr, indexVarSimple) { } expected.pack(); + ASSERT_TENSOR_EQ(expected, a); +} + +TEST(expr, indexVarMix) { + Tensor a("actual", {3, 3}, dense); + Tensor b("input", {3, 3}, compressed); + + Tensor expected("expected", a.getDimensions(), Dense); + const int n = a.getDimensions()[0]; + const int m = a.getDimensions()[1]; + + for(int i = 0; i < n; ++i) { + b.insert({i, i}, 2); + } + b.pack(); + + a(i, j) = b(i, j) * (i * m + j); + a.evaluate(); + + for(int i = 0; i < n; ++i) { + int flattened_idx = i * m + i; + expected.insert({i, i}, 2 * flattened_idx); + } + expected.pack(); + ASSERT_TENSOR_EQ(expected, a); } \ No newline at end of file diff --git a/test/tests-scheduling.cpp b/test/tests-scheduling.cpp index 1bc0bdffa..3973817f9 100644 --- a/test/tests-scheduling.cpp +++ b/test/tests-scheduling.cpp @@ -779,3 +779,66 @@ TEST(scheduling_eval_test, spmv_fuse) { expected.compute(); ASSERT_TENSOR_EQ(expected, y); } + +TEST(scheduling_eval_test, indexVarSplit) { + + Tensor a("A", {4, 4}, dense); + Tensor b("B", {4, 4}, compressed); + + Tensor expected("C", a.getDimensions(), Dense); + const int n = a.getDimensions()[0]; + const int m = a.getDimensions()[1]; + + for(int i = 0; i < n; ++i) { + b.insert({i, i}, 2); + } + b.pack(); + + a(i, j) = b(i, j) * (i * m + j); + IndexStmt stmt = a.getAssignment().concretize(); + IndexVar j0("j0"), j1("j1"); + stmt = stmt.split(j, j0, j1, 2); + + a.compile(stmt); + a.assemble(); + a.compute(); + + for(int i = 0; i < n; ++i) { + int flattened_idx = i * m + i; + expected.insert({i, i}, 2 * flattened_idx); + } + expected.pack(); + + ASSERT_TENSOR_EQ(expected, a); +} + +TEST(scheduling_eval_test, indexVarReorder) { + + Tensor a("A", {4, 4}, dense); + Tensor b("B", {4, 4}, dense); + + Tensor expected("C", a.getDimensions(), Dense); + const int n = a.getDimensions()[0]; + const int m = a.getDimensions()[1]; + + for(int i = 0; i < n; ++i) { + b.insert({i, i}, 2); + } + b.pack(); + + a(i, j) = b(i, j) * (i * m + j); + IndexStmt stmt = a.getAssignment().concretize(); + stmt = stmt.reorder(i, j); + + a.compile(stmt); + a.assemble(); + a.compute(); + + for(int i = 0; i < n; ++i) { + int flattened_idx = i * m + i; + expected.insert({i, i}, 2 * flattened_idx); + } + expected.pack(); + + ASSERT_TENSOR_EQ(expected, a); +} \ No newline at end of file From 42e7f9900160d436808bce25c8337eeaafd18672 Mon Sep 17 00:00:00 2001 From: Rawn Date: Mon, 3 Feb 2020 22:58:25 -0500 Subject: [PATCH 03/27] Moved output of indexVar closer to class definition. Added test for Computing with IndexVars when using split command. Started implementing boilerplate code for iteration space algebra --- include/taco/index_notation/index_notation.h | 5 +- .../taco/index_notation/iteration_algebra.h | 192 ++++++++++++++++++ src/index_notation/iteration_algebra.cpp | 129 ++++++++++++ test/tests-scheduling-eval.cpp | 32 +++ 4 files changed, 356 insertions(+), 2 deletions(-) create mode 100644 include/taco/index_notation/iteration_algebra.h create mode 100644 src/index_notation/iteration_algebra.cpp diff --git a/include/taco/index_notation/index_notation.h b/include/taco/index_notation/index_notation.h index cd113b938..422d712e6 100644 --- a/include/taco/index_notation/index_notation.h +++ b/include/taco/index_notation/index_notation.h @@ -21,6 +21,7 @@ #include "taco/ir_tags.h" #include "taco/lower/iterator.h" #include "taco/index_notation/provenance_graph.h" +//#include "taco/index_notation/iteration_algebra.h" namespace taco { @@ -441,6 +442,8 @@ class IndexVar : public IndexExpr { typedef IndexVarNode Node; }; +std::ostream& operator<<(std::ostream&, const IndexVar&); + /// Create calls to various intrinsics. IndexExpr mod(IndexExpr, IndexExpr); IndexExpr abs(IndexExpr); @@ -818,8 +821,6 @@ class Multi : public IndexStmt { /// Create a multi index statement. Multi multi(IndexStmt stmt1, IndexStmt stmt2); -std::ostream& operator<<(std::ostream&, const IndexVar&); - /// A suchthat statement provides a set of IndexVarRel that constrain /// the iteration space for the child concrete index notation class SuchThat : public IndexStmt { diff --git a/include/taco/index_notation/iteration_algebra.h b/include/taco/index_notation/iteration_algebra.h new file mode 100644 index 000000000..0ffb1e795 --- /dev/null +++ b/include/taco/index_notation/iteration_algebra.h @@ -0,0 +1,192 @@ +#ifndef TACO_ITERATION_ALGEBRA_H +#define TACO_ITERATION_ALGEBRA_H + +#include +#include "taco/util/comparable.h" +#include "taco/index_notation/index_notation.h" +#include "taco/util/intrusive_ptr.h" + + +namespace taco { + +class IterationAlgebraVisitorStrict; + +struct IterationAlgebraNode; +struct SegmentNode; +struct ComplementNode; +struct IntersectNode; +struct UnionNode; + + +/// The iteration algebra class describes a set expression composed of complements, intersections and unions on +/// TensorVars to describe the spaces in a Venn Diagram where computation will occur. +/// This algebra is used to generate merge lattices to co-iterate over tensors in an expression. +class IterationAlgebra : public util::IntrusivePtr { + +public: + IterationAlgebra(); + IterationAlgebra(const IterationAlgebraNode* n); + IterationAlgebra(TensorVar var); + + void accept(IterationAlgebraVisitorStrict* v) const; +}; + +/// A basic segment in an Iteration space. Given a Tensor A, this produces values only where A is defined. +class Segment: public IterationAlgebra { + Segment(); + Segment(TensorVar var); + Segment(const SegmentNode*); +}; + +/// This complements an iteration space algebra expression. Thus, it will flip the segments that are produced and +/// omitted in the input segment. +/// Example: Given a segment A which produces values where A is defined and omits values outside of A, +/// complement(A) will not compute where A is defined but compute over the background of A. +class Complement: public IterationAlgebra { + Complement(IterationAlgebra alg); + Complement(const ComplementNode* n); +}; + +/// This intersects two iteration space algebra expressions. This instructs taco to compute over areas where BOTH +/// set expressions produce values and ignore all other segments. +/// +/// Examples +/// +/// Given two tensors A and B: +/// Intersect(A, B) will produce values where both A and B are defined. An example of an operation with this property +/// is multiplication where both tensors are sparse over 0. +/// +/// Intersect(Complement(A), B) will produce values where only B is defined. This pattern can be useful for filtering +/// one tensor based on the values of another. +class Intersect: public IterationAlgebra { + Intersect(IterationAlgebra, IterationAlgebra); + Intersect(const IterationAlgebraNode*); +}; + +/// This takes the union of two iteration space algebra expressions. This instructs taco to compute over areas where +/// either set expression produces a value or both set expressions produce a value and ignore all other segments. +/// +/// Examples +/// +/// Given two tensors A and B: +/// Union(A, B) will produce values where either A or B is defined. Addition is an example of a union operator. +/// +/// Union(Complement(A), B) will produce values wherever A is not defined. In the places A is not defined, the compiler +/// will replace the value of A in the indexExpression with the fill value of the tensor A. Likewise, when B is not +/// defined, the compiler will replace the value of B in the index expression with the fill value of B. +class Union: public IterationAlgebra { + Union(IterationAlgebra, IterationAlgebra); + Union(const IterationAlgebraNode*); +}; + +/// A node in the iteration space algebra +struct IterationAlgebraNode: public util::Manageable, + private util::Uncopyable { +public: + IterationAlgebraNode() {} + + virtual ~IterationAlgebraNode() = default; + virtual void accept(IterationAlgebraVisitorStrict*) const = 0; +}; + +/// A binary node in the iteration space algebra. Used for Unions and Intersects +struct BinaryIterationAlgebraNode: public IterationAlgebraNode { + IterationAlgebra a; + IterationAlgebra b; +protected: + BinaryIterationAlgebraNode(IterationAlgebra a, IterationAlgebra b) : a(a), b(b) {} +}; + +/// A node which is wrapped by Segment. @see Segment +struct SegmentNode: public IterationAlgebraNode { +public: + SegmentNode() : IterationAlgebraNode() {} + SegmentNode(TensorVar var) : var(var) {} + void accept(IterationAlgebraVisitorStrict*) const; +private: + TensorVar var; +}; + +/// A node which is wrapped by Complement. @see Complement +struct ComplementNode: public IterationAlgebraNode { + IterationAlgebra a; +public: + ComplementNode(IterationAlgebra a) : a(a) {} + + void accept(IterationAlgebraVisitorStrict*) const; +}; + +/// A node which is wrapped by Intersect. @see Intersect +struct IntersectNode: public BinaryIterationAlgebraNode { +public: + IntersectNode(IterationAlgebra a, IterationAlgebra b) : BinaryIterationAlgebraNode(a, b) {} + + void accept(IterationAlgebraVisitorStrict*) const; +}; + +/// A node which is wrapped by Union. @see Union +struct UnionNode: public BinaryIterationAlgebraNode { +public: + UnionNode(IterationAlgebra a, IterationAlgebra b) : BinaryIterationAlgebraNode(a, b) {} + + void accept(IterationAlgebraVisitorStrict*) const; +}; + +/// Visits an iteration space algebra expression +class IterationAlgebraVisitorStrict { +public: + virtual ~IterationAlgebraVisitorStrict() {} + void visit(const IterationAlgebra& alg); + + virtual void visit(const SegmentNode*) = 0; + virtual void visit(const ComplementNode*) = 0; + virtual void visit(const IntersectNode*) = 0; + virtual void visit(const UnionNode*) = 0; +}; + +// Default Iteration Algebra visitor +class IterationAlgebraVisitor : public IterationAlgebraVisitorStrict { + virtual ~IterationAlgebraVisitor() {} + using IterationAlgebraVisitorStrict::visit; + + virtual void visit(const SegmentNode* n); + virtual void visit(const ComplementNode*); + virtual void visit(const IntersectNode*); + virtual void visit(const UnionNode*); +}; + +/// Rewrites an iteration algebra expression +class IterationAlgebraRewriterStrict : public IterationAlgebraVisitorStrict { +public: + virtual ~IterationAlgebraRewriterStrict() {} + IterationAlgebra rewrite(IterationAlgebra); + +protected: + /// Assign new algebra in visit method to replace the algebra nodes visited + IterationAlgebra alg; + + using IterationAlgebraVisitorStrict::visit; + + virtual void visit(const SegmentNode*) = 0; + virtual void visit(const ComplementNode*) = 0; + virtual void visit(const IntersectNode*) = 0; + virtual void visit(const UnionNode*) = 0; +}; + +class IterationAlgebraRewriter : public IterationAlgebraRewriterStrict { +public: + virtual ~IterationAlgebraRewriter() {} + +protected: + using IterationAlgebraRewriterStrict::visit; + + virtual void visit(const SegmentNode* n); + virtual void visit(const ComplementNode*); + virtual void visit(const IntersectNode*); + virtual void visit(const UnionNode*); +}; + + +} + +#endif // TACO_ITERATION_ALGEBRA_H diff --git a/src/index_notation/iteration_algebra.cpp b/src/index_notation/iteration_algebra.cpp new file mode 100644 index 000000000..b0ac1c75a --- /dev/null +++ b/src/index_notation/iteration_algebra.cpp @@ -0,0 +1,129 @@ +#include "taco/index_notation/iteration_algebra.h" + +namespace taco { + +// Iteration Algebra Definitions + +IterationAlgebra::IterationAlgebra() : util::IntrusivePtr(nullptr) {} +IterationAlgebra::IterationAlgebra(const IterationAlgebraNode* n) : util::IntrusivePtr(n) {} +IterationAlgebra::IterationAlgebra(TensorVar var) : IterationAlgebra(new SegmentNode(var)) {} + +void IterationAlgebra::accept(IterationAlgebraVisitorStrict *v) const { + ptr->accept(v); +} + +// Definitions for Iteration Algebra + +// Segment +Segment::Segment() : IterationAlgebra(new SegmentNode) {} +Segment::Segment(TensorVar var) : IterationAlgebra(var) {} +Segment::Segment(const taco::SegmentNode *n) : IterationAlgebra(n) {} + +// Complement +Complement::Complement(const ComplementNode* n): IterationAlgebra(n) {} +Complement::Complement(IterationAlgebra alg) : Complement(new ComplementNode(alg)) {} + +// Intersect +Intersect::Intersect(IterationAlgebra a, IterationAlgebra b) : Intersect(new IntersectNode(a, b)) {} +Intersect::Intersect(const IterationAlgebraNode* n) : IterationAlgebra(n) {} + +// Union +Union::Union(IterationAlgebra a, IterationAlgebra b) : Union(new UnionNode(a, b)) {} +Union::Union(const IterationAlgebraNode* n) : IterationAlgebra(n) {} + +// Node method definitions start here: + +// Definitions for SegmentNode +void SegmentNode::accept(IterationAlgebraVisitorStrict *v) const { + v->visit(this); +} + +// Definitions for ComplementNode +void ComplementNode::accept(IterationAlgebraVisitorStrict *v) const { + v->visit(this); +} + +// Definitions for IntersectNode +void IntersectNode::accept(IterationAlgebraVisitorStrict *v) const { + v->visit(this); +} + +// Definitions for UnionNode +void UnionNode::accept(IterationAlgebraVisitorStrict *v) const { + v->visit(this); +} + +// Visitor definitions start here: + +// IterationAlgebraVisitorStrict definitions +void IterationAlgebraVisitorStrict::visit(const IterationAlgebra &alg) { + alg.accept(this); +} + +void IterationAlgebraVisitor::visit(const SegmentNode *n) { +} + +void IterationAlgebraVisitor::visit(const ComplementNode *n) { + n->a.accept(this); +} + +void IterationAlgebraVisitor::visit(const IntersectNode *n) { + n->a.accept(this); + n->b.accept(this); +} + +void IterationAlgebraVisitor::visit(const UnionNode *n) { + n->a.accept(this); + n->b.accept(this); +} + +// IterationAlgebraRewriter definitions start here: +IterationAlgebra IterationAlgebraRewriterStrict::rewrite(IterationAlgebra iter_alg) { + if(iter_alg.defined()) { + iter_alg.accept(this); + alg = iter_alg; + } + else { + iter_alg = IterationAlgebra(); + } + + alg = IterationAlgebra(); + return iter_alg; +} + +void IterationAlgebraRewriter::visit(const SegmentNode *n) { + alg = n; +} + +void IterationAlgebraRewriter::visit(const ComplementNode *n) { + IterationAlgebra a = rewrite(n->a); + if(n-> a == a) { + alg = n; + } else { + alg = new ComplementNode(a); + } +} + +void IterationAlgebraRewriter::visit(const IntersectNode *n) { + IterationAlgebra a = rewrite(n->a); + IterationAlgebra b = rewrite(n->b); + + if(n->a == a && n->b == b) { + alg = n; + } else { + alg = new IntersectNode(a, b); + } +} + +void IterationAlgebraRewriter::visit(const UnionNode *n) { + IterationAlgebra a = rewrite(n->a); + IterationAlgebra b = rewrite(n->b); + + if(n->a == a && n->b == b) { + alg = n; + } else { + alg = new UnionNode(a, b); + } +} + +} \ No newline at end of file diff --git a/test/tests-scheduling-eval.cpp b/test/tests-scheduling-eval.cpp index 67d3631f1..06e60909e 100644 --- a/test/tests-scheduling-eval.cpp +++ b/test/tests-scheduling-eval.cpp @@ -1107,6 +1107,38 @@ TEST(scheduling_eval, mttkrpGPU) { ASSERT_TENSOR_EQ(expected, A); } +TEST(scheduling_eval, indexVarSplit) { + + Tensor a("A", {4, 4}, dense); + Tensor b("B", {4, 4}, compressed); + + Tensor expected("C", a.getDimensions(), Dense); + const int n = a.getDimensions()[0]; + const int m = a.getDimensions()[1]; + + for(int i = 0; i < n; ++i) { + b.insert({i, i}, 2); + } + b.pack(); + + a(i, j) = b(i, j) * (i * m + j); + IndexStmt stmt = a.getAssignment().concretize(); + IndexVar j0("j0"), j1("j1"); + stmt = stmt.split(j, j0, j1, 2); + + a.compile(stmt); + a.assemble(); + a.compute(); + + for(int i = 0; i < n; ++i) { + int flattened_idx = i * m + i; + expected.insert({i, i}, 2 * flattened_idx); + } + expected.pack(); + + ASSERT_TENSOR_EQ(expected, a); +} + TEST(generate_evaluation_files, DISABLED_cpu) { if (should_use_CUDA_codegen()) { return; From 3cd794482a47c5b63411339517702080ef5b1aa4 Mon Sep 17 00:00:00 2001 From: Rawn Date: Tue, 4 Feb 2020 16:17:24 -0500 Subject: [PATCH 04/27] Added printer for iteration algebra and refactored some code --- include/taco/index_notation/index_notation.h | 1 - .../taco/index_notation/iteration_algebra.h | 18 ++++-- .../iteration_algebra_printer.h | 33 +++++++++++ src/index_notation/iteration_algebra.cpp | 23 +++++++- .../iteration_algebra_printer.cpp | 55 +++++++++++++++++++ test/test-iteration_algebra.cpp | 15 +++++ 6 files changed, 137 insertions(+), 8 deletions(-) create mode 100644 include/taco/index_notation/iteration_algebra_printer.h create mode 100644 src/index_notation/iteration_algebra_printer.cpp create mode 100644 test/test-iteration_algebra.cpp diff --git a/include/taco/index_notation/index_notation.h b/include/taco/index_notation/index_notation.h index 422d712e6..d53f3aaea 100644 --- a/include/taco/index_notation/index_notation.h +++ b/include/taco/index_notation/index_notation.h @@ -21,7 +21,6 @@ #include "taco/ir_tags.h" #include "taco/lower/iterator.h" #include "taco/index_notation/provenance_graph.h" -//#include "taco/index_notation/iteration_algebra.h" namespace taco { diff --git a/include/taco/index_notation/iteration_algebra.h b/include/taco/index_notation/iteration_algebra.h index 0ffb1e795..6d60cd554 100644 --- a/include/taco/index_notation/iteration_algebra.h +++ b/include/taco/index_notation/iteration_algebra.h @@ -3,13 +3,12 @@ #include #include "taco/util/comparable.h" -#include "taco/index_notation/index_notation.h" #include "taco/util/intrusive_ptr.h" - namespace taco { class IterationAlgebraVisitorStrict; +class TensorVar; struct IterationAlgebraNode; struct SegmentNode; @@ -17,12 +16,10 @@ struct ComplementNode; struct IntersectNode; struct UnionNode; - /// The iteration algebra class describes a set expression composed of complements, intersections and unions on /// TensorVars to describe the spaces in a Venn Diagram where computation will occur. /// This algebra is used to generate merge lattices to co-iterate over tensors in an expression. class IterationAlgebra : public util::IntrusivePtr { - public: IterationAlgebra(); IterationAlgebra(const IterationAlgebraNode* n); @@ -31,8 +28,11 @@ class IterationAlgebra : public util::IntrusivePtr { void accept(IterationAlgebraVisitorStrict* v) const; }; +std::ostream& operator<<(std::ostream&, const IterationAlgebra&); + /// A basic segment in an Iteration space. Given a Tensor A, this produces values only where A is defined. class Segment: public IterationAlgebra { +public: Segment(); Segment(TensorVar var); Segment(const SegmentNode*); @@ -43,6 +43,7 @@ class Segment: public IterationAlgebra { /// Example: Given a segment A which produces values where A is defined and omits values outside of A, /// complement(A) will not compute where A is defined but compute over the background of A. class Complement: public IterationAlgebra { +public: Complement(IterationAlgebra alg); Complement(const ComplementNode* n); }; @@ -59,6 +60,7 @@ class Complement: public IterationAlgebra { /// Intersect(Complement(A), B) will produce values where only B is defined. This pattern can be useful for filtering /// one tensor based on the values of another. class Intersect: public IterationAlgebra { +public: Intersect(IterationAlgebra, IterationAlgebra); Intersect(const IterationAlgebraNode*); }; @@ -75,6 +77,7 @@ class Intersect: public IterationAlgebra { /// will replace the value of A in the indexExpression with the fill value of the tensor A. Likewise, when B is not /// defined, the compiler will replace the value of B in the index expression with the fill value of B. class Union: public IterationAlgebra { +public: Union(IterationAlgebra, IterationAlgebra); Union(const IterationAlgebraNode*); }; @@ -103,6 +106,7 @@ struct SegmentNode: public IterationAlgebraNode { SegmentNode() : IterationAlgebraNode() {} SegmentNode(TensorVar var) : var(var) {} void accept(IterationAlgebraVisitorStrict*) const; + const TensorVar tensorVar() const; private: TensorVar var; }; @@ -122,6 +126,8 @@ struct IntersectNode: public BinaryIterationAlgebraNode { IntersectNode(IterationAlgebra a, IterationAlgebra b) : BinaryIterationAlgebraNode(a, b) {} void accept(IterationAlgebraVisitorStrict*) const; + + const std::string algebraString() const; }; /// A node which is wrapped by Union. @see Union @@ -130,6 +136,8 @@ struct UnionNode: public BinaryIterationAlgebraNode { UnionNode(IterationAlgebra a, IterationAlgebra b) : BinaryIterationAlgebraNode(a, b) {} void accept(IterationAlgebraVisitorStrict*) const; + + const std::string algebraString() const; }; /// Visits an iteration space algebra expression @@ -185,8 +193,6 @@ class IterationAlgebraRewriter : public IterationAlgebraRewriterStrict { virtual void visit(const IntersectNode*); virtual void visit(const UnionNode*); }; - - } #endif // TACO_ITERATION_ALGEBRA_H diff --git a/include/taco/index_notation/iteration_algebra_printer.h b/include/taco/index_notation/iteration_algebra_printer.h new file mode 100644 index 000000000..37e11c774 --- /dev/null +++ b/include/taco/index_notation/iteration_algebra_printer.h @@ -0,0 +1,33 @@ +#ifndef TACO_ITERATION_ALGEBRA_PRINTER_H +#define TACO_ITERATION_ALGEBRA_PRINTER_H + +#include + +namespace taco { + +// Iteration Algebra Printer +class IterationAlgebraPrinter : IterationAlgebraVisitorStrict { +public: + IterationAlgebraPrinter(std::ostream& os); + void print(const IterationAlgebra& alg); + void visit(const SegmentNode* n); + void visit(const ComplementNode* n); + void visit(const IntersectNode* n); + void visit(const UnionNode* n); + +private: + std::ostream& os; + enum class Precedence { + COMPLEMENT = 3, + INTERSECT = 4, + UNION = 5, + TOP = 20 + }; + + Precedence parentPrecedence; + + template + void visitBinary(Node n, Precedence precedence); +}; +} +#endif //TACO_ITERATION_ALGEBRA_PRINTER_H diff --git a/src/index_notation/iteration_algebra.cpp b/src/index_notation/iteration_algebra.cpp index b0ac1c75a..28541ce14 100644 --- a/src/index_notation/iteration_algebra.cpp +++ b/src/index_notation/iteration_algebra.cpp @@ -1,4 +1,5 @@ #include "taco/index_notation/iteration_algebra.h" +#include "iteration_algebra_printer.cpp" namespace taco { @@ -12,6 +13,13 @@ void IterationAlgebra::accept(IterationAlgebraVisitorStrict *v) const { ptr->accept(v); } +std::ostream& operator<<(std::ostream& os, const IterationAlgebra& algebra) { + if(!algebra.defined()) return os << "{}"; + IterationAlgebraPrinter printer(os); + printer.print(algebra); + return os; +} + // Definitions for Iteration Algebra // Segment @@ -38,6 +46,10 @@ void SegmentNode::accept(IterationAlgebraVisitorStrict *v) const { v->visit(this); } +const TensorVar SegmentNode::tensorVar() const { + return var; +} + // Definitions for ComplementNode void ComplementNode::accept(IterationAlgebraVisitorStrict *v) const { v->visit(this); @@ -48,11 +60,19 @@ void IntersectNode::accept(IterationAlgebraVisitorStrict *v) const { v->visit(this); } +const std::string IntersectNode::algebraString() const { + return "*"; +} + // Definitions for UnionNode void UnionNode::accept(IterationAlgebraVisitorStrict *v) const { v->visit(this); } +const std::string UnionNode::algebraString() const { + return "U"; +} + // Visitor definitions start here: // IterationAlgebraVisitorStrict definitions @@ -60,6 +80,7 @@ void IterationAlgebraVisitorStrict::visit(const IterationAlgebra &alg) { alg.accept(this); } +// Default IterationAlgebraVisitor definitions void IterationAlgebraVisitor::visit(const SegmentNode *n) { } @@ -91,6 +112,7 @@ IterationAlgebra IterationAlgebraRewriterStrict::rewrite(IterationAlgebra iter_a return iter_alg; } +// Default IterationAlgebraRewriter definitions void IterationAlgebraRewriter::visit(const SegmentNode *n) { alg = n; } @@ -125,5 +147,4 @@ void IterationAlgebraRewriter::visit(const UnionNode *n) { alg = new UnionNode(a, b); } } - } \ No newline at end of file diff --git a/src/index_notation/iteration_algebra_printer.cpp b/src/index_notation/iteration_algebra_printer.cpp new file mode 100644 index 000000000..935506e9c --- /dev/null +++ b/src/index_notation/iteration_algebra_printer.cpp @@ -0,0 +1,55 @@ +#include "taco/index_notation/iteration_algebra_printer.h" + +namespace taco { + +// Iteration Algebra Printer +IterationAlgebraPrinter::IterationAlgebraPrinter(std::ostream& os) : os(os) {} + +void IterationAlgebraPrinter::print(const IterationAlgebra& alg) { + parentPrecedence = Precedence::TOP; + alg.accept(this); +} + +void IterationAlgebraPrinter::visit(const SegmentNode* n) { + os << n->tensorVar().getName(); +} + +void IterationAlgebraPrinter::visit(const ComplementNode* n) { + Precedence precedence = Precedence::COMPLEMENT; + bool parenthesize = precedence > parentPrecedence; + parentPrecedence = precedence; + os << "~"; + if (parenthesize) { + os << "("; + } + n->a.accept(this); + if (parenthesize) { + os << ")"; + } +} + +void IterationAlgebraPrinter::visit(const IntersectNode* n) { + visitBinary(n, Precedence::INTERSECT); +} + +void IterationAlgebraPrinter::visit(const UnionNode* n) { + visitBinary(n, Precedence::UNION); +} + +template +void IterationAlgebraPrinter::visitBinary(Node n, Precedence precedence) { + bool parenthesize = precedence > parentPrecedence; + if (parenthesize) { + os << "("; + } + parentPrecedence = precedence; + n->a.accept(this); + os << " " << n->algebraString() << " "; + parentPrecedence = precedence; + n->b.accept(this); + if (parenthesize) { + os << ")"; + } +} + +} \ No newline at end of file diff --git a/test/test-iteration_algebra.cpp b/test/test-iteration_algebra.cpp new file mode 100644 index 000000000..b21ab5d42 --- /dev/null +++ b/test/test-iteration_algebra.cpp @@ -0,0 +1,15 @@ +#include "test.h" + +#include "taco/type.h" +#include "taco/index_notation/iteration_algebra.h" + +using namespace taco; + +const TensorVar A("A", Type()), B("B", Type()), C("C", Type()); + +TEST(iteration_algebra, iter_alg_print) { + std::ostringstream ss; + ss << Intersect(Union(Complement(A), B), C); + std::string expected("(~A U B) * C"); + ASSERT_EQ(expected, ss.str()); +} \ No newline at end of file From e459616d19d8efceab3831136f914d182a16dcb5 Mon Sep 17 00:00:00 2001 From: Rawn Date: Tue, 4 Feb 2020 17:03:18 -0500 Subject: [PATCH 05/27] Added include to iteration_algebra_printer.h --- include/taco/index_notation/iteration_algebra_printer.h | 1 + 1 file changed, 1 insertion(+) diff --git a/include/taco/index_notation/iteration_algebra_printer.h b/include/taco/index_notation/iteration_algebra_printer.h index 37e11c774..705a627ec 100644 --- a/include/taco/index_notation/iteration_algebra_printer.h +++ b/include/taco/index_notation/iteration_algebra_printer.h @@ -2,6 +2,7 @@ #define TACO_ITERATION_ALGEBRA_PRINTER_H #include +#include "iteration_algebra.h" namespace taco { From a65a362f2be51145688629e24e6b5146df6e0b2c Mon Sep 17 00:00:00 2001 From: Rawn Date: Fri, 21 Feb 2020 09:32:31 -0500 Subject: [PATCH 06/27] Adds basic functionality to MergeLattices and lowerer for general code generation. --- .../index_notation/index_notation_nodes.h | 1 - .../taco/index_notation/iteration_algebra.h | 38 +- .../iteration_algebra_printer.h | 4 +- include/taco/lower/merge_lattice.h | 19 +- include/taco/util/collections.h | 20 +- src/index_notation/iteration_algebra.cpp | 24 +- .../iteration_algebra_printer.cpp | 4 +- src/lower/iterator.cpp | 3 +- src/lower/lower.cpp | 2 +- src/lower/lowerer_impl.cpp | 17 +- src/lower/merge_lattice.cpp | 465 +++++++++++++++--- src/lower/mode_access.cpp | 34 +- test/test-iteration_algebra.cpp | 5 +- 13 files changed, 524 insertions(+), 112 deletions(-) diff --git a/include/taco/index_notation/index_notation_nodes.h b/include/taco/index_notation/index_notation_nodes.h index bde1b709a..68ecf0b17 100644 --- a/include/taco/index_notation/index_notation_nodes.h +++ b/include/taco/index_notation/index_notation_nodes.h @@ -52,7 +52,6 @@ struct LiteralNode : public IndexExprNode { void* val; }; - struct UnaryExprNode : public IndexExprNode { IndexExpr a; diff --git a/include/taco/index_notation/iteration_algebra.h b/include/taco/index_notation/iteration_algebra.h index 6d60cd554..1399b91a0 100644 --- a/include/taco/index_notation/iteration_algebra.h +++ b/include/taco/index_notation/iteration_algebra.h @@ -8,34 +8,34 @@ namespace taco { class IterationAlgebraVisitorStrict; -class TensorVar; +class IndexExpr; struct IterationAlgebraNode; -struct SegmentNode; +struct RegionNode; struct ComplementNode; struct IntersectNode; struct UnionNode; /// The iteration algebra class describes a set expression composed of complements, intersections and unions on -/// TensorVars to describe the spaces in a Venn Diagram where computation will occur. +/// IndexExprs to describe the spaces in a Venn Diagram where computation will occur. /// This algebra is used to generate merge lattices to co-iterate over tensors in an expression. class IterationAlgebra : public util::IntrusivePtr { public: IterationAlgebra(); IterationAlgebra(const IterationAlgebraNode* n); - IterationAlgebra(TensorVar var); + IterationAlgebra(IndexExpr expr); void accept(IterationAlgebraVisitorStrict* v) const; }; std::ostream& operator<<(std::ostream&, const IterationAlgebra&); -/// A basic segment in an Iteration space. Given a Tensor A, this produces values only where A is defined. -class Segment: public IterationAlgebra { +/// A region in an Iteration space. Given a Tensor A, this produces values everywhere the tensorVar or access is defined. +class Region: public IterationAlgebra { public: - Segment(); - Segment(TensorVar var); - Segment(const SegmentNode*); + Region(); + Region(IndexExpr expr); + Region(const RegionNode*); }; /// This complements an iteration space algebra expression. Thus, it will flip the segments that are produced and @@ -100,15 +100,15 @@ struct BinaryIterationAlgebraNode: public IterationAlgebraNode { BinaryIterationAlgebraNode(IterationAlgebra a, IterationAlgebra b) : a(a), b(b) {} }; -/// A node which is wrapped by Segment. @see Segment -struct SegmentNode: public IterationAlgebraNode { +/// A node which is wrapped by Region. @see Region +struct RegionNode: public IterationAlgebraNode { public: - SegmentNode() : IterationAlgebraNode() {} - SegmentNode(TensorVar var) : var(var) {} + RegionNode() : IterationAlgebraNode() {} + RegionNode(IndexExpr expr) : expr(expr) {} void accept(IterationAlgebraVisitorStrict*) const; - const TensorVar tensorVar() const; + const IndexExpr indexExpr() const; private: - TensorVar var; + IndexExpr expr; }; /// A node which is wrapped by Complement. @see Complement @@ -146,7 +146,7 @@ class IterationAlgebraVisitorStrict { virtual ~IterationAlgebraVisitorStrict() {} void visit(const IterationAlgebra& alg); - virtual void visit(const SegmentNode*) = 0; + virtual void visit(const RegionNode*) = 0; virtual void visit(const ComplementNode*) = 0; virtual void visit(const IntersectNode*) = 0; virtual void visit(const UnionNode*) = 0; @@ -157,7 +157,7 @@ class IterationAlgebraVisitor : public IterationAlgebraVisitorStrict { virtual ~IterationAlgebraVisitor() {} using IterationAlgebraVisitorStrict::visit; - virtual void visit(const SegmentNode* n); + virtual void visit(const RegionNode* n); virtual void visit(const ComplementNode*); virtual void visit(const IntersectNode*); virtual void visit(const UnionNode*); @@ -175,7 +175,7 @@ class IterationAlgebraRewriterStrict : public IterationAlgebraVisitorStrict { using IterationAlgebraVisitorStrict::visit; - virtual void visit(const SegmentNode*) = 0; + virtual void visit(const RegionNode*) = 0; virtual void visit(const ComplementNode*) = 0; virtual void visit(const IntersectNode*) = 0; virtual void visit(const UnionNode*) = 0; @@ -188,7 +188,7 @@ class IterationAlgebraRewriter : public IterationAlgebraRewriterStrict { protected: using IterationAlgebraRewriterStrict::visit; - virtual void visit(const SegmentNode* n); + virtual void visit(const RegionNode* n); virtual void visit(const ComplementNode*); virtual void visit(const IntersectNode*); virtual void visit(const UnionNode*); diff --git a/include/taco/index_notation/iteration_algebra_printer.h b/include/taco/index_notation/iteration_algebra_printer.h index 705a627ec..6f97b6b49 100644 --- a/include/taco/index_notation/iteration_algebra_printer.h +++ b/include/taco/index_notation/iteration_algebra_printer.h @@ -2,7 +2,7 @@ #define TACO_ITERATION_ALGEBRA_PRINTER_H #include -#include "iteration_algebra.h" +#include "taco/index_notation/iteration_algebra.h" namespace taco { @@ -11,7 +11,7 @@ class IterationAlgebraPrinter : IterationAlgebraVisitorStrict { public: IterationAlgebraPrinter(std::ostream& os); void print(const IterationAlgebra& alg); - void visit(const SegmentNode* n); + void visit(const RegionNode* n); void visit(const ComplementNode* n); void visit(const IntersectNode* n); void visit(const UnionNode* n); diff --git a/include/taco/lower/merge_lattice.h b/include/taco/lower/merge_lattice.h index 9f1592d5f..f821ffc68 100644 --- a/include/taco/lower/merge_lattice.h +++ b/include/taco/lower/merge_lattice.h @@ -80,6 +80,12 @@ class MergeLattice { */ bool exact() const; + /** + * Get a list of iterators that should be omitted at this merge point. + */ + std::vector retrieveIteratorsToOmit(const MergePoint& point) const; + + private: std::vector points_; @@ -146,6 +152,16 @@ class MergePoint { */ const std::vector& results() const; + /** + * Returns the iterators that iterate over or locate into tensors + */ + const std::set tensorRegion() const; + + /** + * Returns true if this merge point may leave out the tensors it iterates + */ + bool isOmitter() const; + private: struct Content; std::shared_ptr content_; @@ -158,7 +174,8 @@ class MergePoint { */ MergePoint(const std::vector& iterators, const std::vector& locators, - const std::vector& results); + const std::vector& results, + bool omitPoint = false); }; std::ostream& operator<<(std::ostream&, const MergePoint&); diff --git a/include/taco/util/collections.h b/include/taco/util/collections.h index 6edddf380..7c687cafe 100644 --- a/include/taco/util/collections.h +++ b/include/taco/util/collections.h @@ -97,6 +97,20 @@ std::vector remove(const std::vector& vector, return result; } +template +std::vector removeDuplicates(const std::vector& vector) { + std::set seen; + std::vector result; + + for(const V& v: vector) { + if(!contains(seen, v)) { + seen.insert(v); + result.push_back(v); + } + } + return result; +} + template std::vector filter(const std::vector& vector, T test) { std::vector result; @@ -149,9 +163,9 @@ bool all(const C& collection, T test) { return true; } -template -bool any(const std::vector& vector, T test) { - for (auto& element : vector) { +template +bool any(const C& collection, T test) { + for (auto& element : collection) { if (test(element)) { return true; } diff --git a/src/index_notation/iteration_algebra.cpp b/src/index_notation/iteration_algebra.cpp index 28541ce14..4cf2a41a1 100644 --- a/src/index_notation/iteration_algebra.cpp +++ b/src/index_notation/iteration_algebra.cpp @@ -1,5 +1,5 @@ #include "taco/index_notation/iteration_algebra.h" -#include "iteration_algebra_printer.cpp" +#include "taco/index_notation/iteration_algebra_printer.h" namespace taco { @@ -7,7 +7,7 @@ namespace taco { IterationAlgebra::IterationAlgebra() : util::IntrusivePtr(nullptr) {} IterationAlgebra::IterationAlgebra(const IterationAlgebraNode* n) : util::IntrusivePtr(n) {} -IterationAlgebra::IterationAlgebra(TensorVar var) : IterationAlgebra(new SegmentNode(var)) {} +IterationAlgebra::IterationAlgebra(IndexExpr expr) : IterationAlgebra(new RegionNode(expr)) {} void IterationAlgebra::accept(IterationAlgebraVisitorStrict *v) const { ptr->accept(v); @@ -22,10 +22,10 @@ std::ostream& operator<<(std::ostream& os, const IterationAlgebra& algebra) { // Definitions for Iteration Algebra -// Segment -Segment::Segment() : IterationAlgebra(new SegmentNode) {} -Segment::Segment(TensorVar var) : IterationAlgebra(var) {} -Segment::Segment(const taco::SegmentNode *n) : IterationAlgebra(n) {} +// Region +Region::Region() : IterationAlgebra(new RegionNode) {} +Region::Region(IndexExpr expr) : IterationAlgebra(expr) {} +Region::Region(const taco::RegionNode *n) : IterationAlgebra(n) {} // Complement Complement::Complement(const ComplementNode* n): IterationAlgebra(n) {} @@ -41,13 +41,13 @@ Union::Union(const IterationAlgebraNode* n) : IterationAlgebra(n) {} // Node method definitions start here: -// Definitions for SegmentNode -void SegmentNode::accept(IterationAlgebraVisitorStrict *v) const { +// Definitions for RegionNode +void RegionNode::accept(IterationAlgebraVisitorStrict *v) const { v->visit(this); } -const TensorVar SegmentNode::tensorVar() const { - return var; +const IndexExpr RegionNode::indexExpr() const { + return expr; } // Definitions for ComplementNode @@ -81,7 +81,7 @@ void IterationAlgebraVisitorStrict::visit(const IterationAlgebra &alg) { } // Default IterationAlgebraVisitor definitions -void IterationAlgebraVisitor::visit(const SegmentNode *n) { +void IterationAlgebraVisitor::visit(const RegionNode *n) { } void IterationAlgebraVisitor::visit(const ComplementNode *n) { @@ -113,7 +113,7 @@ IterationAlgebra IterationAlgebraRewriterStrict::rewrite(IterationAlgebra iter_a } // Default IterationAlgebraRewriter definitions -void IterationAlgebraRewriter::visit(const SegmentNode *n) { +void IterationAlgebraRewriter::visit(const RegionNode *n) { alg = n; } diff --git a/src/index_notation/iteration_algebra_printer.cpp b/src/index_notation/iteration_algebra_printer.cpp index 935506e9c..85507243f 100644 --- a/src/index_notation/iteration_algebra_printer.cpp +++ b/src/index_notation/iteration_algebra_printer.cpp @@ -10,8 +10,8 @@ void IterationAlgebraPrinter::print(const IterationAlgebra& alg) { alg.accept(this); } -void IterationAlgebraPrinter::visit(const SegmentNode* n) { - os << n->tensorVar().getName(); +void IterationAlgebraPrinter::visit(const RegionNode* n) { + os << n->indexExpr(); } void IterationAlgebraPrinter::visit(const ComplementNode* n) { diff --git a/src/lower/iterator.cpp b/src/lower/iterator.cpp index f357e202c..07f0c2fb3 100644 --- a/src/lower/iterator.cpp +++ b/src/lower/iterator.cpp @@ -363,7 +363,6 @@ struct Iterators::Content { map modeIterators; }; - Iterators::Iterators() : content(new Content) { @@ -499,7 +498,7 @@ Iterator Iterators::levelIterator(ModeAccess modeAccess) const taco_iassert(content != nullptr); taco_iassert(util::contains(content->levelIterators, modeAccess)) << "Cannot find " << modeAccess << " in " - << util::join(content->levelIterators); + << util::join(content->levelIterators) << "\n" << modeAccess.getAccess(); return content->levelIterators.at(modeAccess); } diff --git a/src/lower/lower.cpp b/src/lower/lower.cpp index 676f9d346..4a8b241e7 100644 --- a/src/lower/lower.cpp +++ b/src/lower/lower.cpp @@ -51,7 +51,7 @@ ir::Stmt lower(IndexStmt stmt, std::string name, bool assemble, bool compute, std::string messages; - std::cout << "Suppressing verifier output" << endl; // TODO: + // std::cout << "Suppressing verifier output" << endl; // TODO: // TODO: verify(lowered, &messages); if (!messages.empty()) { std::cerr << "Verifier messages:\n" << messages << "\n"; diff --git a/src/lower/lowerer_impl.cpp b/src/lower/lowerer_impl.cpp index efbe7c39f..a50b1dea4 100644 --- a/src/lower/lowerer_impl.cpp +++ b/src/lower/lowerer_impl.cpp @@ -1191,18 +1191,31 @@ Stmt LowererImpl::lowerMergeCases(ir::Expr coordinate, IndexVar coordinateVar, I appenders, reducedAccesses); result.push_back(body); } - else { + else if (!lattice.points().empty()){ vector> cases; for (MergePoint point : lattice.points()) { + if(point.isOmitter()) { + continue; + } + // Construct case expression vector coordComparisons; for (Iterator iterator : point.rangers()) { - if (!(provGraph.isCoordVariable(iterator.getIndexVar()) && provGraph.isDerivedFrom(iterator.getIndexVar(), coordinateVar))) { + if (!(provGraph.isCoordVariable(iterator.getIndexVar()) && + provGraph.isDerivedFrom(iterator.getIndexVar(), coordinateVar))) { coordComparisons.push_back(Eq::make(iterator.getCoordVar(), coordinate)); } } + vector omittedRangers = lattice.retrieveIteratorsToOmit(point); + for (auto iterator: omittedRangers) { + if (!(provGraph.isCoordVariable(iterator.getIndexVar()) && + provGraph.isDerivedFrom(iterator.getIndexVar(), coordinateVar))) { + coordComparisons.push_back(Neq::make(iterator.getCoordVar(), coordinate)); + } + } + // Construct case body IndexStmt zeroedStmt = zero(stmt, getExhaustedAccesses(point, lattice)); Stmt body = lowerForallBody(coordinate, zeroedStmt, {}, diff --git a/src/lower/merge_lattice.cpp b/src/lower/merge_lattice.cpp index 9cef143e0..0928d4019 100644 --- a/src/lower/merge_lattice.cpp +++ b/src/lower/merge_lattice.cpp @@ -12,16 +12,19 @@ #include "mode_access.h" #include "taco/util/collections.h" #include "taco/util/strings.h" +#include "taco/index_notation/iteration_algebra.h" #include "taco/util/scopedmap.h" using namespace std; namespace taco { -class MergeLatticeBuilder : public IndexNotationVisitorStrict { +class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationAlgebraVisitorStrict { public: - MergeLatticeBuilder(IndexVar i, Iterators iterators, ProvenanceGraph provGraph, std::set definedIndexVars, std::map whereTempsToResult = {}) - : i(i), iterators(iterators), provGraph(provGraph), definedIndexVars(definedIndexVars), whereTempsToResult(whereTempsToResult) {} + MergeLatticeBuilder(IndexVar i, Iterators iterators, ProvenanceGraph provGraph, std::set definedIndexVars, + std::map whereTempsToResult = {}) + : i(i), iterators(iterators), provGraph(provGraph), definedIndexVars(definedIndexVars), + whereTempsToResult(whereTempsToResult) {} MergeLattice build(IndexStmt stmt) { stmt.accept(this); @@ -37,6 +40,13 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict { return l; } + MergeLattice build(IterationAlgebra alg) { + alg.accept(this); + MergeLattice l = lattice; + lattice = MergeLattice({}); + return l; + } + Iterator getIterator(Access access, IndexVar accessVar) { // must have matching underived ancestor map accessUnderivedAncestorsToLoc; @@ -72,11 +82,81 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict { std::set definedIndexVars; map latticesOfTemporaries; std::map whereTempsToResult; + map baseMergePoints; MergeLattice modeIterationLattice() { return MergeLattice({MergePoint({iterators.modeIterator(i)}, {}, {})}); } + void visit(const RegionNode* node) { + lattice = build(node->indexExpr()); + } + + void visit(const ComplementNode* node) { + lattice = build(node->a); + vector points = flipPoints(lattice.points()); + + // TODO: Handle complementing with broadcasting - Can't distinguish dimension iterators inserted + // as optimizations to unordered tensors with dimension iterators inserted to broadcast. + // Could perhaps do the optimization at the end of lattice construction instead of after + // each union? + // In case 1, we want to complement the lattice but in case two we can return the empty + // lattice + + // Otherwise, all tensors are sparse + points = includeMissingProducerPoints(points); + + // Add dimension point + Iterator dimIter = iterators.modeIterator(i); + points = includeDimensionIterator(points, dimIter); + + bool needsDimPoint = true; + for(const auto& point: points) { + if(point.locators().empty() && point.iterators().size() == 1 && point.iterators()[0] == dimIter) { + needsDimPoint = false; + break; + } + } + + if(needsDimPoint) { + points.push_back(MergePoint({dimIter}, {}, {})); + } + + points = removeUnnecessaryOmitterPoints(points); + lattice = MergeLattice(points); + } + + void visit(const IntersectNode* node) { + MergeLattice a = build(node->a); + MergeLattice b = build(node->b); + + if (a.points().size() > 0 && b.points().size() > 0) { + lattice = intersectLattices(a, b); + } + // Scalar operands + else if (a.points().size() > 0) { + lattice = a; + } + else if (b.points().size() > 0) { + lattice = b; + } + } + + void visit(const UnionNode* node) { + MergeLattice a = build(node->a); + MergeLattice b = build(node->b); + if (a.points().size() > 0 && b.points().size() > 0) { + lattice = unionLattices(a, b); + } + // Scalar operands + else if (a.points().size() > 0) { + lattice = a; + } + else if (b.points().size() > 0) { + lattice = b; + } + } + void visit(const IndexVarNode* varNode) { IndexVar var(varNode); lattice = MergeLattice({MergePoint({iterators.modeIterator(var)}, {}, {})}); @@ -84,6 +164,8 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict { void visit(const AccessNode* access) { + // TODO: Case where Access is used in computation but not iteration algebra + if (util::contains(latticesOfTemporaries, access->tensorVar)) { // If the accessed tensor variable is a temporary with an associated merge // lattice then we return that lattice. @@ -146,6 +228,8 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict { : MergePoint(pointIterators, {}, {}); lattice = MergeLattice({point}); } + + baseMergePoints.insert({iterator, lattice.points()[0]}); } void visit(const LiteralNode* node) { @@ -159,63 +243,19 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict { } void visit(const AddNode* node) { - MergeLattice a = build(node->a); - MergeLattice b = build(node->b); - if (a.points().size() > 0 && b.points().size() > 0) { - lattice = unionLattices(a, b); - } - // Scalar operands - else if (a.points().size() > 0) { - lattice = a; - } - else if (b.points().size() > 0) { - lattice = b; - } + lattice = build(new UnionNode(Region(node->a), Region(node->b))); } void visit(const SubNode* expr) { - MergeLattice a = build(expr->a); - MergeLattice b = build(expr->b); - if (a.points().size() > 0 && b.points().size() > 0) { - lattice = unionLattices(a, b); - } - // Scalar operands - else if (a.points().size() > 0) { - lattice = a; - } - else if (b.points().size() > 0) { - lattice = b; - } + lattice = build(new UnionNode(Region(expr->a), Region(expr->b))); } void visit(const MulNode* expr) { - MergeLattice a = build(expr->a); - MergeLattice b = build(expr->b); - if (a.points().size() > 0 && b.points().size() > 0) { - lattice = intersectLattices(a, b); - } - // Scalar operands - else if (a.points().size() > 0) { - lattice = a; - } - else if (b.points().size() > 0) { - lattice = b; - } + lattice = build(new IntersectNode(Region(expr->a), Region(expr->b))); } void visit(const DivNode* expr) { - MergeLattice a = build(expr->a); - MergeLattice b = build(expr->b); - if (a.points().size() > 0 && b.points().size() > 0) { - lattice = intersectLattices(a, b); - } - // Scalar operands - else if (a.points().size() > 0) { - lattice = a; - } - else if (b.points().size() > 0) { - lattice = b; - } + lattice = build(new IntersectNode(Region(expr->a), Region(expr->b))); } void visit(const SqrtNode* expr) { @@ -296,7 +336,8 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict { vector points; for (auto &point : lattice.points()) { points.push_back(MergePoint(point.iterators(), point.locators(), - vector(resultIterators.begin(), resultIterators.end()))); + vector(resultIterators.begin(), resultIterators.end()), + point.isOmitter())); } lattice = MergeLattice(points); } @@ -335,6 +376,98 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict { taco_not_supported_yet; } + vector + enumerateChildrenPoints(const MergePoint& point, const map, MergePoint>& originalPoints, + set>& seen) { + set pointIters(point.iterators().begin(), point.iterators().end()); + set pointLocs(point.locators().begin(), point.locators().end()); + + set regions = point.tensorRegion(); + set currentRegion = point.tensorRegion(); + + vector result; + for(const auto& tensorIt: regions) { + currentRegion.erase(tensorIt); + + if(util::contains(seen, currentRegion)) { + currentRegion.insert(tensorIt); + continue; + } + + if(util::contains(originalPoints, currentRegion)) { + result.push_back(originalPoints.at(currentRegion)); + } + else if(!currentRegion.empty()){ + MergePoint mp({}, {}, {}); + for(const auto& it: currentRegion) { + mp = unionPoints(mp, baseMergePoints.at(it)); + } + + vector newIters; + vector newLocators = mp.locators(); + for(const auto& it: mp.iterators()) { + if(util::contains(pointLocs, it)) { + newLocators.push_back(it); + } + else { + newIters.push_back(it); + } + } + + result.push_back(MergePoint(newIters, newLocators, point.results())); + } + + seen.insert(currentRegion); + currentRegion.insert(tensorIt); + } + return result; + } + + vector + includeMissingProducerPoints(const vector& points) { + if(points.empty()) return points; + + map, MergePoint> originalPoints; + set> seen; + for(const auto& point: points) { + originalPoints.insert({point.tensorRegion(), point}); + } + + vector frontier = {points[0]}; + vector exactLattice; + + while(!frontier.empty()) { + vector nextFrontier; + for (const auto &frontierPoint: frontier) { + exactLattice.push_back(frontierPoint); + util::append(nextFrontier, enumerateChildrenPoints(frontierPoint, originalPoints, seen)); + } + + frontier = nextFrontier; + } + + return exactLattice; + } + + static vector + includeDimensionIterator(const vector& points, const Iterator& dimIter) { + vector results; + for (auto& point : points) { + vector iterators = point.iterators(); + if (!any(iterators, [](Iterator it){ return it.isDimensionIterator(); })) { + taco_iassert(point.iterators().size() > 0); + results.push_back(MergePoint(combine(iterators, {dimIter}), + point.locators(), + point.results(), + point.isOmitter())); + } + else { + results.push_back(point); + } + } + return results; + } + /** * The intersection of two lattices is the result of merging all the * combinations of merge points from the two lattices. @@ -355,6 +488,28 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict { } } + // + + // Correctness: ensures that points produced on BOTH the left and the + // right lattices are produced in the final intersection. + // Needed since some subPoints may omit leading to erroneous + // omit intersection points. + points = correctPointTypesAfterIntersect(left.points(), right.points(), points); + + // Correctness: Deduplicate regions that are described by multiple lattice + // points and resolves conflicts arising between omitters and + // producers + points = removeDuplicatedTensorRegions(points, true); + + // Optimization: Removed a subLattice of points if the entire subLattice is + // made of only omitters + points = removeUnnecessaryOmitterPoints(points); + + // Optimization: remove lattice points whose iterators are identical to the + // iterators of an earlier point, since we have already iterated + // over this sub-space. + points = removePointsWithIdenticalIterators(points); + return MergeLattice(points); } @@ -379,6 +534,15 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict { // Append the merge points of b util::append(points, right.points()); + // Correctness: This ensures that points omitted on BOTH the left and the + // right lattices are omitted in the Union. Needed since some + // subpoints may produce leading to erroneous producer regions + points = correctPointTypesAfterUnion(left.points(), right.points(), points); + + // Correctness: Deduplicate regions that are described by multiple lattice + // points and resolves conflicts arising between omitters and + // producers + points = removeDuplicatedTensorRegions(points, false); // Optimization: insert a dimension iterator if one of the iterators in the // iterate set is not ordered. @@ -393,6 +557,10 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict { // have iterated over the whole space. points = removePointsThatLackFullIterators(points); + // Optimization: Removes a subLattice of points if the entire subLattice is + // made of only omitters + points = removeUnnecessaryOmitterPoints(points); + // Optimization: remove lattice points whose iterators are identical to the // iterators of an earlier point, since we have already iterated // over this sub-space. @@ -432,7 +600,7 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict { vector results = combine(left.results(), right.results()); - return MergePoint(iterators, locators, results); + return MergePoint(iterators, locators, results, left.isOmitter() || right.isOmitter()); } /** @@ -450,7 +618,7 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict { // Remove duplicate iterators. iterators = deduplicateDimensionIterators(iterators); - return MergePoint(iterators, locaters, results); + return MergePoint(iterators, locaters, results, left.isOmitter() && right.isOmitter()); } static bool locateFromLeft(MergeLattice left, MergeLattice right) @@ -484,7 +652,7 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict { } static vector - insertDimensionIteratorIfNotOrdered(vector points) + insertDimensionIteratorIfNotOrdered(const vector& points) { vector results; for (auto& point : points) { @@ -495,7 +663,8 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict { Iterator dimension(iterators[0].getIndexVar()); results.push_back(MergePoint(combine(iterators, {dimension}), point.locators(), - point.results())); + point.results(), + point.isOmitter())); } else { results.push_back(point); @@ -505,7 +674,7 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict { } static vector - moveLocateSubsetIteratorsToLocateSet(vector points) + moveLocateSubsetIteratorsToLocateSet(const vector& points) { vector full = filter(points[0].iterators(), [](Iterator it){ return it.isFull(); }); @@ -528,13 +697,14 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict { }); result.push_back(MergePoint(iterators, combine(point.locators(), locators), - point.results())); + point.results(), + point.isOmitter())); } return result; } static vector - removePointsThatLackFullIterators(vector points) + removePointsThatLackFullIterators(const vector& points) { vector result; vector fullIterators = filter(points[0].iterators(), @@ -591,6 +761,124 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict { } return deduplicates; } + + static vector + flipPoints(const vector& points) { + vector flippedPoints; + for(const auto& mp: points) { + MergePoint flippedPoint(mp.iterators(), mp.locators(), mp.results(), !mp.isOmitter()); + flippedPoints.push_back(flippedPoint); + } + return flippedPoints; + } + + static set> + getProducerOrOmitterRegions(const std::vector& points, bool getOmitters) { + set> result; + + for(const auto& point: points) { + if(point.isOmitter() == getOmitters) { + set region = point.tensorRegion(); + result.insert(region); + } + } + return result; + } + + static vector + correctPointTypes(const vector& left, const vector& right, + const vector& points, bool preserveOmit) { + vector result; + set> leftSet = getProducerOrOmitterRegions(left, preserveOmit); + set> rightSet = getProducerOrOmitterRegions(right, preserveOmit); + + for (auto& point : points) { + set iteratorSet = point.tensorRegion(); + + MergePoint newPoint = point; + if(util::contains(leftSet, iteratorSet) && util::contains(rightSet, iteratorSet)) { + // Both regions produce/omit, so we ensure that this is preserved + newPoint = MergePoint(point.iterators(), point.locators(), point.results(), preserveOmit); + } + result.push_back(newPoint); + } + + return result; + } + + static vector + correctPointTypesAfterIntersect(const vector& left, const vector& right, + const vector& points) { + return correctPointTypes(left, right, points, false); + } + + static vector + correctPointTypesAfterUnion(const vector& left, const vector& right, + const vector& points) { + return correctPointTypes(left, right, points, true); + } + + static vector + removeDuplicatedTensorRegions(const vector& points, bool preserveOmitters) { + + set> producerRegions = getProducerOrOmitterRegions(points, false); + set> omitterRegions = getProducerOrOmitterRegions(points, true); + + vector result; + + set> regionSets; + for (auto& point : points) { + set region = point.tensorRegion(); + + if(util::contains(regionSets, region)) { + continue; + } + + MergePoint p = point; + if(util::contains(producerRegions, region) && util::contains(omitterRegions, region)) { + // If a region is marked as both produce and omit resolve the ambiguity based on the preserve + // omitters flag. + p = MergePoint(point.iterators(), point.locators(), point.results(), preserveOmitters); + } + + result.push_back(p); + regionSets.insert(region); + } + + return result; + } + + static vector + removeUnnecessaryOmitterPoints(const vector& points) { + vector filteredPoints; + + MergeLattice l(points); + set> removed; + + for(const auto& point : points) { + + if(util::contains(removed, point.tensorRegion())) { + continue; + } + + MergeLattice subLattice = l.subLattice(point); + + if(util::all(subLattice.points(), [](const MergePoint p) {return p.isOmitter();})) { + for(const auto& p : subLattice.points()) { + removed.insert(p.tensorRegion()); + } + } + } + + for(const auto& point : points) { + if(!util::contains(removed, point.tensorRegion())) { + filteredPoints.push_back(point); + } + } + + return filteredPoints; + } + }; @@ -603,6 +891,7 @@ MergeLattice MergeLattice::make(Forall forall, Iterators iterators, ProvenanceGr { // Can emit merge lattice once underived ancestor can be recovered IndexVar indexVar = forall.getIndexVar(); + MergeLatticeBuilder builder(indexVar, iterators, provGraph, definedIndexVars, whereTempsToResult); vector underivedAncestors = provGraph.getUnderivedAncestors(indexVar); @@ -666,6 +955,11 @@ bool MergeLattice::exact() const { // A lattice is full if any merge point iterates over only full iterators // or if each sparse iterator is uniquely iterated by some lattice point. set uniquelyMergedIterators; + + if (util::any(points(), [](const MergePoint& m) {return m.isOmitter();})) { + return false; + } + for (auto& point : this->points()) { if (all(point.iterators(), [](Iterator it) {return it.isFull();})) { return true; @@ -686,6 +980,27 @@ bool MergeLattice::exact() const { return true; } +std::vector MergeLattice::retrieveIteratorsToOmit(const MergePoint &point) const { + + vector omittedIterators; + const size_t levelOfParent = point.iterators().size() + 1; + vector pointIterators = point.iterators(); + sort(pointIterators.begin(), pointIterators.end()); + + for(const auto& mp: points()) { + if(mp.iterators().size() == levelOfParent && mp.isOmitter()) { + // We are one level above this point + vector parentIterators = mp.iterators(); + sort(parentIterators.begin(), parentIterators.end()); + set_difference(parentIterators.begin(), parentIterators.end(), + pointIterators.begin(), pointIterators.end(), + back_inserter(omittedIterators)); + } + } + + return omittedIterators; +} + ostream& operator<<(ostream& os, const MergeLattice& ml) { return os << util::join(ml.points(), ", "); } @@ -714,19 +1029,22 @@ struct MergePoint::Content { std::vector iterators; std::vector locators; std::vector results; + bool omitPoint; }; MergePoint::MergePoint(const vector& iterators, const vector& locators, - const vector& results) : content_(new Content) { + const vector& results, + bool omitPoint) : content_(new Content) { taco_uassert(all(iterators, [](Iterator it){ return it.hasLocate() || it.isOrdered(); })) << "Merge points do not support iterators that do not have locate and " << "that are not ordered."; - content_->iterators = iterators; - content_->locators = locators; - content_->results = results; + content_->iterators = util::removeDuplicates(iterators); + content_->locators = util::removeDuplicates(locators); + content_->results = util::removeDuplicates(results); + content_->omitPoint = omitPoint; } const vector& MergePoint::iterators() const { @@ -794,6 +1112,18 @@ const std::vector& MergePoint::results() const { return content_->results; } +const std::set MergePoint::tensorRegion() const { + std::vector iterators = filter(content_->iterators, + [](Iterator it) {return !it.isDimensionIterator();}); + + append(iterators, content_->locators); + return set(iterators.begin(), iterators.end()); +} + +bool MergePoint::isOmitter() const { + return content_->omitPoint; +} + ostream& operator<<(ostream& os, const MergePoint& mlp) { os << "["; os << util::join(mlp.iterators(), ", "); @@ -805,6 +1135,14 @@ ostream& operator<<(ostream& os, const MergePoint& mlp) { os << "|"; if (mlp.results().size() > 0) os << " "; os << util::join(mlp.results(), ", "); + + os << "|"; + if(mlp.isOmitter()) { + os << " O "; + } else { + os << " P "; + } + os << "]"; return os; } @@ -825,6 +1163,7 @@ bool operator==(const MergePoint& a, const MergePoint& b) { if (!compare(a.iterators(), b.iterators())) return false; if (!compare(a.locators(), b.locators())) return false; if (!compare(a.results(), b.results())) return false; + if ((a.isOmitter() != b.isOmitter())) return false; return true; } diff --git a/src/lower/mode_access.cpp b/src/lower/mode_access.cpp index 8aea151bc..fcb0d19c4 100644 --- a/src/lower/mode_access.cpp +++ b/src/lower/mode_access.cpp @@ -13,15 +13,43 @@ size_t ModeAccess::getModePos() const { return mode; } +static bool accessEqual(const Access& a, const Access& b) { + return a == b || + (a.getTensorVar() == b.getTensorVar() && a.getIndexVars() == b.getIndexVars()); +} + bool operator==(const ModeAccess& a, const ModeAccess& b) { - return a.getAccess() == b.getAccess() && a.getModePos() == b.getModePos(); + return accessEqual(a.getAccess(), b.getAccess()) && a.getModePos() == b.getModePos(); } bool operator<(const ModeAccess& a, const ModeAccess& b) { - if (a.getAccess() == b.getAccess()) { + + // fast path for when access pointers are equal + if(a.getAccess() == b.getAccess()) { return a.getModePos() < b.getModePos(); } - return a.getAccess() < b.getAccess(); + + // First break on tensorVars + if(a.getAccess().getTensorVar() != b.getAccess().getTensorVar()) { + return a.getAccess().getTensorVar() < b.getAccess().getTensorVar(); + } + + // Then break on the indexVars used in the access + std::vector aVars = a.getAccess().getIndexVars(); + std::vector bVars = b.getAccess().getIndexVars(); + + if(aVars.size() != bVars.size()) { + return aVars.size() < bVars.size(); + } + + for(size_t i = 0; i < aVars.size(); ++i) { + if(aVars[i] != bVars[i]) { + return aVars[i] < bVars[i]; + } + } + + // Finally, break on the mode position + return a.getModePos() < b.getModePos(); } std::ostream &operator<<(std::ostream &os, const ModeAccess & modeAccess) { diff --git a/test/test-iteration_algebra.cpp b/test/test-iteration_algebra.cpp index b21ab5d42..b02bbc592 100644 --- a/test/test-iteration_algebra.cpp +++ b/test/test-iteration_algebra.cpp @@ -1,11 +1,14 @@ #include "test.h" #include "taco/type.h" + +#include "taco/index_notation/index_notation.h" #include "taco/index_notation/iteration_algebra.h" using namespace taco; -const TensorVar A("A", Type()), B("B", Type()), C("C", Type()); +const TensorVar A_var("A", Type()), B_var("B", Type()), C_var("C", Type()); +const Access A(A_var), B(B_var), C(C_var); TEST(iteration_algebra, iter_alg_print) { std::ostringstream ss; From 978852441ead50240e55239f58e195d2233f0c64 Mon Sep 17 00:00:00 2001 From: Rawn Date: Fri, 28 Feb 2020 00:32:18 -0500 Subject: [PATCH 07/27] Started adding front end for defining new operators --- include/taco.h | 1 + include/taco/index_notation/index_notation.h | 24 +++++++ .../index_notation/index_notation_nodes.h | 21 ++++++ .../taco/index_notation/iteration_algebra.h | 3 +- include/taco/index_notation/properties.h | 12 ++++ include/taco/index_notation/tensor_operator.h | 64 +++++++++++++++++++ src/index_notation/index_notation.cpp | 33 +++++++++- src/index_notation/index_notation_nodes.cpp | 6 ++ src/lower/lowerer_impl.cpp | 2 + 9 files changed, 164 insertions(+), 2 deletions(-) create mode 100644 include/taco/index_notation/properties.h create mode 100644 include/taco/index_notation/tensor_operator.h diff --git a/include/taco.h b/include/taco.h index 016960ffc..9293b3414 100644 --- a/include/taco.h +++ b/include/taco.h @@ -3,6 +3,7 @@ #include "taco/tensor.h" #include "taco/format.h" +#include "taco/index_notation/tensor_operator.h" #include "taco/index_notation/index_notation.h" #endif diff --git a/include/taco/index_notation/index_notation.h b/include/taco/index_notation/index_notation.h index d53f3aaea..56e0cc4c7 100644 --- a/include/taco/index_notation/index_notation.h +++ b/include/taco/index_notation/index_notation.h @@ -8,7 +8,9 @@ #include #include #include +#include +#include "taco/util/name_generator.h" #include "taco/format.h" #include "taco/error.h" #include "taco/util/intrusive_ptr.h" @@ -21,6 +23,7 @@ #include "taco/ir_tags.h" #include "taco/lower/iterator.h" #include "taco/index_notation/provenance_graph.h" +#include "taco/index_notation/properties.h" namespace taco { @@ -36,6 +39,8 @@ class IndexExpr; class Assignment; class Access; +class IterationAlgebra; + struct AccessNode; struct LiteralNode; struct NegNode; @@ -45,6 +50,7 @@ struct SubNode; struct MulNode; struct DivNode; struct CastNode; +struct TensorOpNode; struct CallIntrinsicNode; struct ReductionNode; struct IndexVarNode; @@ -396,6 +402,24 @@ class Cast : public IndexExpr { typedef CastNode Node; }; +/// A call to an operator +class TensorOp: public IndexExpr { +public: + TensorOp() = default; + TensorOp(const TensorOpNode*, std::string name = util::uniqueName("Op")); + + const std::vector& getArgs() const; + const IterationAlgebra& getAlgebra() const; + const Properties& getProperties() const; + const std::map, std::function&)>> getDefs() const; + + std::string getName() const; + + typedef TensorOpNode Node; + +private: + std::string name; +}; /// A call to an intrinsic. /// ``` diff --git a/include/taco/index_notation/index_notation_nodes.h b/include/taco/index_notation/index_notation_nodes.h index 68ecf0b17..232828e7e 100644 --- a/include/taco/index_notation/index_notation_nodes.h +++ b/include/taco/index_notation/index_notation_nodes.h @@ -11,6 +11,8 @@ #include "taco/index_notation/index_notation_visitor.h" #include "taco/index_notation/intrinsic.h" #include "taco/util/strings.h" +#include "iteration_algebra.h" +#include "properties.h" namespace taco { @@ -171,6 +173,25 @@ struct CallIntrinsicNode : public IndexExprNode { std::vector args; }; +struct TensorOpNode : public IndexExprNode { + typedef std::function&)> opImpl; + typedef std::function&)> algebraImpl; + typedef std::function&)> regionDefinition; + + TensorOpNode(const std::vector& exprs, opImpl lowerFunc, const IterationAlgebra& iterAlg, + const Properties& properties, const std::map, regionDefinition>& regionDefinitions, + Datatype type); + + void accept(IndexExprVisitorStrict* v) const { + v->visit(this); + } + + std::vector exprs; + opImpl lowerFunc; + IterationAlgebra iterAlg; + Properties properties; + std::map, regionDefinition> regionDefinitions; +}; struct ReductionNode : public IndexExprNode { ReductionNode(IndexExpr op, IndexVar var, IndexExpr a); diff --git a/include/taco/index_notation/iteration_algebra.h b/include/taco/index_notation/iteration_algebra.h index 1399b91a0..33ee1db79 100644 --- a/include/taco/index_notation/iteration_algebra.h +++ b/include/taco/index_notation/iteration_algebra.h @@ -1,7 +1,8 @@ #ifndef TACO_ITERATION_ALGEBRA_H #define TACO_ITERATION_ALGEBRA_H -#include +#include "taco/index_notation/index_notation.h" +#include "taco/util/uncopyable.h" #include "taco/util/comparable.h" #include "taco/util/intrusive_ptr.h" diff --git a/include/taco/index_notation/properties.h b/include/taco/index_notation/properties.h new file mode 100644 index 000000000..ad14cb855 --- /dev/null +++ b/include/taco/index_notation/properties.h @@ -0,0 +1,12 @@ +#ifndef TACO_PROPERTIES_H +#define TACO_PROPERTIES_H + +namespace taco { + +class Properties { + +}; + +} + +#endif //TACO_PROPERTIES_H diff --git a/include/taco/index_notation/tensor_operator.h b/include/taco/index_notation/tensor_operator.h new file mode 100644 index 000000000..a60c64593 --- /dev/null +++ b/include/taco/index_notation/tensor_operator.h @@ -0,0 +1,64 @@ +#ifndef TACO_OPS_H +#define TACO_OPS_H + +#include +#include +#include + +#include "taco/ir/ir.h" +#include "taco/util/collections.h" +#include "taco/index_notation/properties.h" +#include "taco/index_notation/index_notation.h" +#include "taco/index_notation/index_notation_nodes.h" + +namespace taco { + +class Op { + +using opImpl = TensorOpNode::opImpl; +using algebraImpl = TensorOpNode::algebraImpl; +using regionDefinition = TensorOpNode::regionDefinition; + +public: + Op(opImpl lowererFunc, algebraImpl algebraFunc, std::map, regionDefinition> specialDefinitions) : + Op(lowererFunc, algebraFunc, Properties(), specialDefinitions) {} + + Op(opImpl lowererFunc, algebraImpl algebraFunc, Properties properties = Properties(), + std::map, regionDefinition> specialDefinitions = {}) : + lowererFunc(lowererFunc), algebraFunc(algebraFunc), + properties(properties), regionDefinitions(specialDefinitions) {} + + template + TensorOp operator()(IndexExprs&&... exprs) { + std::vector actualArgs{exprs...}; + IterationAlgebra nodeAlgebra = algebraFunc(actualArgs); + Datatype returnType = inferReturnType(actualArgs); + + TensorOpNode* op = new TensorOpNode(actualArgs, lowererFunc, nodeAlgebra, properties, + regionDefinitions, returnType); + + return TensorOp(op, util::uniqueName("Op")); + } + + +private: + opImpl lowererFunc; + algebraImpl algebraFunc; + Properties properties; + std::map, regionDefinition> regionDefinitions; + + Datatype inferReturnType(const std::vector& inputs) { + std::function getExprs = [](IndexExpr arg) { return ir::Var::make("t", arg.getDataType()); }; + std::vector exprs = util::map(inputs, getExprs); + return lowererFunc(exprs).type(); + } + +}; + +} +#endif //TACO_OPS_H + +// Using vectors for interface to keep it consistent + +// Can't use variadic functions since the lower function would need to be stored meaning methods in the compiler +// would have to be templated. diff --git a/src/index_notation/index_notation.cpp b/src/index_notation/index_notation.cpp index 8c419752b..688113cd4 100644 --- a/src/index_notation/index_notation.cpp +++ b/src/index_notation/index_notation.cpp @@ -21,7 +21,6 @@ #include "taco/index_notation/index_notation_rewriter.h" #include "taco/index_notation/index_notation_printer.h" #include "taco/ir/ir.h" -#include "taco/lower/lower.h" #include "taco/codegen/module.h" #include "taco/util/name_generator.h" @@ -752,6 +751,38 @@ template <> Cast to(IndexExpr e) { return Cast(to(e.ptr)); } +// class TensorOp, most construction should happen from tensor_operator.h +TensorOp::TensorOp(const TensorOpNode *n, std::string name) : IndexExpr(n), name(name) { +} + +const std::vector& TensorOp::getArgs() const { + return getNode(*this)->exprs; +} + +const IterationAlgebra& TensorOp::getAlgebra() const { + return getNode(*this)->iterAlg; +} + +const Properties& TensorOp::getProperties() const { + return getNode(*this)->properties; +} + +const std::map, TensorOpNode::regionDefinition> TensorOp::getDefs() const { + return getNode(*this)->regionDefinitions; +} + +std::string TensorOp::getName() const { + return name; +} + +template <> bool isa(IndexExpr e) { + return isa(e.ptr); +} + +template <> TensorOp to(IndexExpr e) { + taco_iassert(isa(e)); + return TensorOp(to(e.ptr)); +} // class CallIntrinsic CallIntrinsic::CallIntrinsic(const CallIntrinsicNode* n) : IndexExpr(n) { diff --git a/src/index_notation/index_notation_nodes.cpp b/src/index_notation/index_notation_nodes.cpp index 813e8a4e8..854402c45 100644 --- a/src/index_notation/index_notation_nodes.cpp +++ b/src/index_notation/index_notation_nodes.cpp @@ -29,6 +29,12 @@ CallIntrinsicNode::CallIntrinsicNode(const std::shared_ptr& func, func(func), args(args) { } +// class TensorOpNode +TensorOpNode::TensorOpNode(const std::vector& exprs, opImpl lowerFunc, + const IterationAlgebra &iterAlg, const Properties &properties, + const std::map, regionDefinition>& regionDefinitions, Datatype type) : + IndexExprNode(type), exprs(exprs), lowerFunc(lowerFunc), iterAlg(iterAlg), + properties(properties), regionDefinitions(regionDefinitions) {} // class ReductionNode ReductionNode::ReductionNode(IndexExpr op, IndexVar var, IndexExpr a) diff --git a/src/lower/lowerer_impl.cpp b/src/lower/lowerer_impl.cpp index a50b1dea4..c7bb89ef3 100644 --- a/src/lower/lowerer_impl.cpp +++ b/src/lower/lowerer_impl.cpp @@ -1187,6 +1187,7 @@ Stmt LowererImpl::lowerMergeCases(ir::Expr coordinate, IndexVar coordinateVar, I // Just one iterator so no conditionals if (lattice.iterators().size() == 1) { + taco_iassert(!lattice.points()[0].isOmitter()); Stmt body = lowerForallBody(coordinate, stmt, {}, inserters, appenders, reducedAccesses); result.push_back(body); @@ -1453,6 +1454,7 @@ Expr LowererImpl::lowerSub(Sub sub) { Expr LowererImpl::lowerMul(Mul mul) { + const IndexExpr t = mul.getA(); Expr a = lower(mul.getA()); Expr b = lower(mul.getB()); return (mul.getDataType().getKind() == Datatype::Bool) From 9144c887421654828079a90c0989ae4f2f36be06 Mon Sep 17 00:00:00 2001 From: Rawn Date: Fri, 28 Feb 2020 20:52:41 -0500 Subject: [PATCH 08/27] Rework user facing API --- include/taco/index_notation/index_notation.h | 5 +- .../index_notation/index_notation_nodes.h | 5 +- include/taco/index_notation/properties.h | 2 +- include/taco/index_notation/tensor_operator.h | 48 +++++++++++++++---- src/index_notation/index_notation.cpp | 5 +- src/index_notation/index_notation_nodes.cpp | 2 +- 6 files changed, 52 insertions(+), 15 deletions(-) diff --git a/include/taco/index_notation/index_notation.h b/include/taco/index_notation/index_notation.h index 56e0cc4c7..e1328d1e3 100644 --- a/include/taco/index_notation/index_notation.h +++ b/include/taco/index_notation/index_notation.h @@ -406,11 +406,12 @@ class Cast : public IndexExpr { class TensorOp: public IndexExpr { public: TensorOp() = default; - TensorOp(const TensorOpNode*, std::string name = util::uniqueName("Op")); + TensorOp(const TensorOpNode*); + TensorOp(const TensorOpNode*, std::string name); const std::vector& getArgs() const; const IterationAlgebra& getAlgebra() const; - const Properties& getProperties() const; + const std::vector& getProperties() const; const std::map, std::function&)>> getDefs() const; std::string getName() const; diff --git a/include/taco/index_notation/index_notation_nodes.h b/include/taco/index_notation/index_notation_nodes.h index 232828e7e..e64ae8d3f 100644 --- a/include/taco/index_notation/index_notation_nodes.h +++ b/include/taco/index_notation/index_notation_nodes.h @@ -179,7 +179,8 @@ struct TensorOpNode : public IndexExprNode { typedef std::function&)> regionDefinition; TensorOpNode(const std::vector& exprs, opImpl lowerFunc, const IterationAlgebra& iterAlg, - const Properties& properties, const std::map, regionDefinition>& regionDefinitions, + const std::vector& properties, + const std::map, regionDefinition>& regionDefinitions, Datatype type); void accept(IndexExprVisitorStrict* v) const { @@ -189,7 +190,7 @@ struct TensorOpNode : public IndexExprNode { std::vector exprs; opImpl lowerFunc; IterationAlgebra iterAlg; - Properties properties; + std::vector properties; std::map, regionDefinition> regionDefinitions; }; diff --git a/include/taco/index_notation/properties.h b/include/taco/index_notation/properties.h index ad14cb855..fde1110ba 100644 --- a/include/taco/index_notation/properties.h +++ b/include/taco/index_notation/properties.h @@ -3,7 +3,7 @@ namespace taco { -class Properties { +class Property { }; diff --git a/include/taco/index_notation/tensor_operator.h b/include/taco/index_notation/tensor_operator.h index a60c64593..e1b359cf1 100644 --- a/include/taco/index_notation/tensor_operator.h +++ b/include/taco/index_notation/tensor_operator.h @@ -20,31 +20,52 @@ using algebraImpl = TensorOpNode::algebraImpl; using regionDefinition = TensorOpNode::regionDefinition; public: - Op(opImpl lowererFunc, algebraImpl algebraFunc, std::map, regionDefinition> specialDefinitions) : - Op(lowererFunc, algebraFunc, Properties(), specialDefinitions) {} - Op(opImpl lowererFunc, algebraImpl algebraFunc, Properties properties = Properties(), - std::map, regionDefinition> specialDefinitions = {}) : - lowererFunc(lowererFunc), algebraFunc(algebraFunc), + // Full construction + Op(opImpl lowererFunc, algebraImpl algebraFunc, std::vector properties = {}, + std::map, regionDefinition> specialDefinitions = {}) : + name(util::uniqueName("Op")), lowererFunc(lowererFunc), algebraFunc(algebraFunc), properties(properties), regionDefinitions(specialDefinitions) {} + Op(std::string name, opImpl lowererFunc, algebraImpl algebraFunc, std::vector properties = {}, + std::map, regionDefinition> specialDefinitions = {}) : + name(name), lowererFunc(lowererFunc), algebraFunc(algebraFunc), + properties(properties), regionDefinitions(specialDefinitions) {} + + // Construct without specifying algebra + Op(std::string name, opImpl lowererFunc, std::vector properties, + std::map, regionDefinition> specialDefinitions = {}) : + Op(name, lowererFunc, nullptr, properties, specialDefinitions) {} + + Op(opImpl lowererFunc, std::vector properties, + std::map, regionDefinition> specialDefinitions = {}) : + Op(util::uniqueName("Op"), lowererFunc, nullptr, properties, specialDefinitions) {} + + // Construct without algebra or properties + Op(std::string name, opImpl lowererFunc) : Op(name, lowererFunc, nullptr) {} + + explicit Op(opImpl lowererFunc) : Op(lowererFunc, nullptr) {} + + template TensorOp operator()(IndexExprs&&... exprs) { std::vector actualArgs{exprs...}; - IterationAlgebra nodeAlgebra = algebraFunc(actualArgs); + + IterationAlgebra nodeAlgebra = algebraFunc != nullptr? algebraFunc(actualArgs): inferAlgFromProperties(actualArgs); Datatype returnType = inferReturnType(actualArgs); TensorOpNode* op = new TensorOpNode(actualArgs, lowererFunc, nodeAlgebra, properties, regionDefinitions, returnType); - return TensorOp(op, util::uniqueName("Op")); + return TensorOp(op, name); } private: + std::string name; opImpl lowererFunc; algebraImpl algebraFunc; - Properties properties; + std::vector properties; std::map, regionDefinition> regionDefinitions; Datatype inferReturnType(const std::vector& inputs) { @@ -53,6 +74,17 @@ using regionDefinition = TensorOpNode::regionDefinition; return lowererFunc(exprs).type(); } + IterationAlgebra inferAlgFromProperties(const std::vector& exprs) { + if(properties.empty()) { + return constructDefaultAlgebra(exprs); + } + return {}; + } + + IterationAlgebra constructDefaultAlgebra(const std::vector& exprs) { + return {}; + } + }; } diff --git a/src/index_notation/index_notation.cpp b/src/index_notation/index_notation.cpp index 688113cd4..20efd975a 100644 --- a/src/index_notation/index_notation.cpp +++ b/src/index_notation/index_notation.cpp @@ -752,6 +752,9 @@ template <> Cast to(IndexExpr e) { } // class TensorOp, most construction should happen from tensor_operator.h +TensorOp::TensorOp(const TensorOpNode* n) : IndexExpr(n) { +} + TensorOp::TensorOp(const TensorOpNode *n, std::string name) : IndexExpr(n), name(name) { } @@ -763,7 +766,7 @@ const IterationAlgebra& TensorOp::getAlgebra() const { return getNode(*this)->iterAlg; } -const Properties& TensorOp::getProperties() const { +const std::vector& TensorOp::getProperties() const { return getNode(*this)->properties; } diff --git a/src/index_notation/index_notation_nodes.cpp b/src/index_notation/index_notation_nodes.cpp index 854402c45..6e18ca6b9 100644 --- a/src/index_notation/index_notation_nodes.cpp +++ b/src/index_notation/index_notation_nodes.cpp @@ -31,7 +31,7 @@ CallIntrinsicNode::CallIntrinsicNode(const std::shared_ptr& func, // class TensorOpNode TensorOpNode::TensorOpNode(const std::vector& exprs, opImpl lowerFunc, - const IterationAlgebra &iterAlg, const Properties &properties, + const IterationAlgebra &iterAlg, const std::vector &properties, const std::map, regionDefinition>& regionDefinitions, Datatype type) : IndexExprNode(type), exprs(exprs), lowerFunc(lowerFunc), iterAlg(iterAlg), properties(properties), regionDefinitions(regionDefinitions) {} From 4f82ccb21f0fd45e3a33a2507c42a8d2e758fdf1 Mon Sep 17 00:00:00 2001 From: Rawn Date: Wed, 4 Mar 2020 23:23:03 -0500 Subject: [PATCH 09/27] Added TensorOpNode and made tests for new iteration algebra functions --- include/taco/index_notation/index_notation.h | 4 +- .../index_notation/index_notation_nodes.h | 6 +- .../index_notation/index_notation_printer.h | 1 + .../index_notation/index_notation_rewriter.h | 2 + .../index_notation/index_notation_visitor.h | 4 + .../taco/index_notation/iteration_algebra.h | 50 +++- include/taco/index_notation/properties.h | 86 +++++++ include/taco/index_notation/tensor_operator.h | 22 +- include/taco/lower/lowerer_impl.h | 3 +- src/index_notation/index_notation.cpp | 190 ++++++++++++++- src/index_notation/index_notation_nodes.cpp | 12 +- src/index_notation/index_notation_printer.cpp | 7 + .../index_notation_rewriter.cpp | 27 +++ src/index_notation/index_notation_visitor.cpp | 6 + src/index_notation/iteration_algebra.cpp | 159 +++++++++++- .../iteration_algebra_printer.cpp | 2 +- src/lower/expr_tools.cpp | 4 + src/lower/lowerer_impl.cpp | 8 + src/lower/merge_lattice.cpp | 6 +- test/test-iteration_algebra.cpp | 227 +++++++++++++++++- 20 files changed, 784 insertions(+), 42 deletions(-) diff --git a/include/taco/index_notation/index_notation.h b/include/taco/index_notation/index_notation.h index e1328d1e3..9c5c4ed2f 100644 --- a/include/taco/index_notation/index_notation.h +++ b/include/taco/index_notation/index_notation.h @@ -410,12 +410,12 @@ class TensorOp: public IndexExpr { TensorOp(const TensorOpNode*, std::string name); const std::vector& getArgs() const; + const std::function&)> getFunc() const; const IterationAlgebra& getAlgebra() const; const std::vector& getProperties() const; + const std::string getName() const; const std::map, std::function&)>> getDefs() const; - std::string getName() const; - typedef TensorOpNode Node; private: diff --git a/include/taco/index_notation/index_notation_nodes.h b/include/taco/index_notation/index_notation_nodes.h index e64ae8d3f..01ade2412 100644 --- a/include/taco/index_notation/index_notation_nodes.h +++ b/include/taco/index_notation/index_notation_nodes.h @@ -178,7 +178,8 @@ struct TensorOpNode : public IndexExprNode { typedef std::function&)> algebraImpl; typedef std::function&)> regionDefinition; - TensorOpNode(const std::vector& exprs, opImpl lowerFunc, const IterationAlgebra& iterAlg, + TensorOpNode(std::string name, const std::vector& args, opImpl lowerFunc, + const IterationAlgebra& iterAlg, const std::vector& properties, const std::map, regionDefinition>& regionDefinitions, Datatype type); @@ -187,7 +188,8 @@ struct TensorOpNode : public IndexExprNode { v->visit(this); } - std::vector exprs; + std::string name; + std::vector args; opImpl lowerFunc; IterationAlgebra iterAlg; std::vector properties; diff --git a/include/taco/index_notation/index_notation_printer.h b/include/taco/index_notation/index_notation_printer.h index ed2dd7abb..3aceaaa20 100644 --- a/include/taco/index_notation/index_notation_printer.h +++ b/include/taco/index_notation/index_notation_printer.h @@ -25,6 +25,7 @@ class IndexNotationPrinter : public IndexNotationVisitorStrict { void visit(const MulNode*); void visit(const DivNode*); void visit(const CastNode*); + void visit(const TensorOpNode*); void visit(const CallIntrinsicNode*); void visit(const ReductionNode*); void visit(const IndexVarNode*); diff --git a/include/taco/index_notation/index_notation_rewriter.h b/include/taco/index_notation/index_notation_rewriter.h index a4e340b0f..729349a58 100644 --- a/include/taco/index_notation/index_notation_rewriter.h +++ b/include/taco/index_notation/index_notation_rewriter.h @@ -32,6 +32,7 @@ class IndexExprRewriterStrict : public IndexExprVisitorStrict { virtual void visit(const MulNode* op) = 0; virtual void visit(const DivNode* op) = 0; virtual void visit(const CastNode* op) = 0; + virtual void visit(const TensorOpNode* op) = 0; virtual void visit(const CallIntrinsicNode* op) = 0; virtual void visit(const ReductionNode* op) = 0; virtual void visit(const IndexVarNode* op) = 0; @@ -94,6 +95,7 @@ class IndexNotationRewriter : public IndexNotationRewriterStrict { virtual void visit(const MulNode* op); virtual void visit(const DivNode* op); virtual void visit(const CastNode* op); + virtual void visit(const TensorOpNode* op); virtual void visit(const CallIntrinsicNode* op); virtual void visit(const ReductionNode* op); virtual void visit(const IndexVarNode* op); diff --git a/include/taco/index_notation/index_notation_visitor.h b/include/taco/index_notation/index_notation_visitor.h index 9e3622289..2cbde9489 100644 --- a/include/taco/index_notation/index_notation_visitor.h +++ b/include/taco/index_notation/index_notation_visitor.h @@ -20,6 +20,7 @@ struct MulNode; struct DivNode; struct SqrtNode; struct CastNode; +struct TensorOpNode; struct CallIntrinsicNode; struct UnaryExprNode; struct BinaryExprNode; @@ -51,6 +52,7 @@ class IndexExprVisitorStrict { virtual void visit(const DivNode*) = 0; virtual void visit(const SqrtNode*) = 0; virtual void visit(const CastNode*) = 0; + virtual void visit(const TensorOpNode*) = 0; virtual void visit(const CallIntrinsicNode*) = 0; virtual void visit(const ReductionNode*) = 0; virtual void visit(const IndexVarNode*) = 0; @@ -98,6 +100,7 @@ class IndexNotationVisitor : public IndexNotationVisitorStrict { virtual void visit(const DivNode* node); virtual void visit(const SqrtNode* node); virtual void visit(const CastNode* node); + virtual void visit(const TensorOpNode* node); virtual void visit(const CallIntrinsicNode* node); virtual void visit(const UnaryExprNode* node); virtual void visit(const BinaryExprNode* node); @@ -167,6 +170,7 @@ class Matcher : public IndexNotationVisitor { RULE(MulNode) RULE(DivNode) RULE(CastNode) + RULE(TensorOpNode) RULE(CallIntrinsicNode) RULE(ReductionNode) diff --git a/include/taco/index_notation/iteration_algebra.h b/include/taco/index_notation/iteration_algebra.h index 33ee1db79..2b712df5f 100644 --- a/include/taco/index_notation/iteration_algebra.h +++ b/include/taco/index_notation/iteration_algebra.h @@ -98,25 +98,25 @@ struct BinaryIterationAlgebraNode: public IterationAlgebraNode { IterationAlgebra a; IterationAlgebra b; protected: - BinaryIterationAlgebraNode(IterationAlgebra a, IterationAlgebra b) : a(a), b(b) {} + BinaryIterationAlgebraNode(IterationAlgebra a, IterationAlgebra b) : IterationAlgebraNode(), a(a), b(b) {} }; /// A node which is wrapped by Region. @see Region struct RegionNode: public IterationAlgebraNode { public: RegionNode() : IterationAlgebraNode() {} - RegionNode(IndexExpr expr) : expr(expr) {} + RegionNode(IndexExpr expr) : IterationAlgebraNode(), expr_(expr) {} void accept(IterationAlgebraVisitorStrict*) const; - const IndexExpr indexExpr() const; + const IndexExpr expr() const; private: - IndexExpr expr; + IndexExpr expr_; }; /// A node which is wrapped by Complement. @see Complement struct ComplementNode: public IterationAlgebraNode { IterationAlgebra a; public: - ComplementNode(IterationAlgebra a) : a(a) {} + ComplementNode(IterationAlgebra a) : IterationAlgebraNode(), a(a) {} void accept(IterationAlgebraVisitorStrict*) const; }; @@ -155,6 +155,7 @@ class IterationAlgebraVisitorStrict { // Default Iteration Algebra visitor class IterationAlgebraVisitor : public IterationAlgebraVisitorStrict { +public: virtual ~IterationAlgebraVisitor() {} using IterationAlgebraVisitorStrict::visit; @@ -194,6 +195,45 @@ class IterationAlgebraRewriter : public IterationAlgebraRewriterStrict { virtual void visit(const IntersectNode*); virtual void visit(const UnionNode*); }; + +/// Returns true if algebra e is of type E. +template +inline bool isa(const IterationAlgebraNode* e) { + return e != nullptr && dynamic_cast(e) != nullptr; +} + +/// Casts the algebraNode e to type E. +template +inline const E* to(const IterationAlgebraNode* e) { + taco_iassert(isa(e)) << + "Cannot convert " << typeid(e).name() << " to " << typeid(E).name(); + return static_cast(e); +} + +/// Return true if the iteration algebra is of the given subtype. The subtypes +/// are Region, Complement, Union and Intersect. +template bool isa(IterationAlgebra); + +/// Casts the iteration algebra to the given subtype. Assumes S is a subtype and +/// the subtypes are Region, Complement, Union and Intersect. +template SubType to(IterationAlgebra); + +/// Returns true if the structure of the iteration algebra is the same. +/// This means that intersections, unions, complements and regions appear +/// in the same places but the IndexExpressions these operations are applied +/// to are not necessarily the same. +bool algStructureEqual(const IterationAlgebra&, const IterationAlgebra&); + +/// Returns true if the iterations algebras passed in have the same structure +/// and the Index Expressions that they operate on are the same. +bool algEqual(const IterationAlgebra&, const IterationAlgebra&); + +/// Applies demorgan's laws to the algebra passed in and returns a new algebra +/// which describes the same space but with complements appearing only around region +/// nodes. +IterationAlgebra applyDemorgan(IterationAlgebra alg); + } + #endif // TACO_ITERATION_ALGEBRA_H diff --git a/include/taco/index_notation/properties.h b/include/taco/index_notation/properties.h index fde1110ba..a148ca2b4 100644 --- a/include/taco/index_notation/properties.h +++ b/include/taco/index_notation/properties.h @@ -1,12 +1,98 @@ #ifndef TACO_PROPERTIES_H #define TACO_PROPERTIES_H +#include +#include + +#include "taco/error.h" +#include "taco/util/comparable.h" + namespace taco { +class Literal; + class Property { +public: + virtual ~Property(); + virtual bool defined() const; + virtual bool equals(const Property&) const; +}; + +class Annihilator : public Property { +public: + Annihilator(); + Annihilator(Literal); + const Literal& getAnnihilator() const; + virtual bool defined() const; + virtual bool equals(const Property&) const; + +private: + struct Content; + std::shared_ptr content; +}; + +class Identity : public Property { +public: + Identity(); + Identity(Literal); + const Literal& getIdentity() const; + virtual bool defined() const; + virtual bool equals(const Property&) const; +private: + struct Content; + std::shared_ptr content; }; + +class Associative : public Property { +public: + Associative(); + static Associative makeUndefined(); + + virtual bool defined() const; + virtual bool equals(const Property&) const; + +private: + bool isDefined; +}; + +class Commutative : public Property { +public: + Commutative(); + Commutative(std::vector); + static Commutative makeUndefined(); + + const std::vector& ordering() const; + virtual bool defined() const; + virtual bool equals(const Property&) const; + +private: + const std::vector ordering_; + bool isDefined; +}; + +/// Returns true if property p is of type P. +template +inline bool isa(const Property& p) { + return dynamic_cast(&p) != nullptr; +} + +/// Casts the Property p to type P. +template +inline const P& to(const Property& p) { + taco_iassert(isa

(p)) << "Cannot convert " << typeid(p).name() << " to " << typeid(P).name(); + return static_cast(p); +} + +template +inline const P findProperty(const std::vector& properties, P defaultProperty) { + for (const auto& p: properties) { + if(isa

(p)) return to

(p); + } + return defaultProperty; +} + } #endif //TACO_PROPERTIES_H diff --git a/include/taco/index_notation/tensor_operator.h b/include/taco/index_notation/tensor_operator.h index e1b359cf1..c0f18e7c3 100644 --- a/include/taco/index_notation/tensor_operator.h +++ b/include/taco/index_notation/tensor_operator.h @@ -51,13 +51,13 @@ using regionDefinition = TensorOpNode::regionDefinition; TensorOp operator()(IndexExprs&&... exprs) { std::vector actualArgs{exprs...}; - IterationAlgebra nodeAlgebra = algebraFunc != nullptr? algebraFunc(actualArgs): inferAlgFromProperties(actualArgs); + IterationAlgebra nodeAlgebra = algebraFunc == nullptr? inferAlgFromProperties(actualArgs): algebraFunc(actualArgs); Datatype returnType = inferReturnType(actualArgs); - TensorOpNode* op = new TensorOpNode(actualArgs, lowererFunc, nodeAlgebra, properties, + TensorOpNode* op = new TensorOpNode(name, actualArgs, lowererFunc, nodeAlgebra, properties, regionDefinitions, returnType); - return TensorOp(op, name); + return TensorOp(op); } @@ -81,8 +81,17 @@ using regionDefinition = TensorOpNode::regionDefinition; return {}; } - IterationAlgebra constructDefaultAlgebra(const std::vector& exprs) { - return {}; + // Constructs an algebra that iterates over the entire space + static IterationAlgebra constructDefaultAlgebra(const std::vector& exprs) { + if(exprs.empty()) return Region(); + + IterationAlgebra tensorsRegions(exprs[0]); + for(size_t i = 1; i < exprs.size(); ++i) { + tensorsRegions = Union(tensorsRegions, exprs[i]); + } + + IterationAlgebra background = Complement(tensorsRegions); + return Union(tensorsRegions, background); } }; @@ -90,7 +99,4 @@ using regionDefinition = TensorOpNode::regionDefinition; } #endif //TACO_OPS_H -// Using vectors for interface to keep it consistent -// Can't use variadic functions since the lower function would need to be stored meaning methods in the compiler -// would have to be templated. diff --git a/include/taco/lower/lowerer_impl.h b/include/taco/lower/lowerer_impl.h index 66fa5330d..c81e5ca37 100644 --- a/include/taco/lower/lowerer_impl.h +++ b/include/taco/lower/lowerer_impl.h @@ -207,6 +207,8 @@ class LowererImpl : public util::Uncopyable { /// Lower an IndexVar expression virtual ir::Expr lowerIndexVar(IndexVar var); + /// Lower a generic tensor operation expression + virtual ir::Expr lowerTensorOp(TensorOp op); /// Lower a concrete index variable statement. ir::Stmt lower(IndexStmt stmt); @@ -214,7 +216,6 @@ class LowererImpl : public util::Uncopyable { /// Lower a concrete index variable expression. ir::Expr lower(IndexExpr expr); - /// Check whether the lowerer should generate code to assemble result indices. bool generateAssembleCode() const; diff --git a/src/index_notation/index_notation.cpp b/src/index_notation/index_notation.cpp index 20efd975a..b3c48c117 100644 --- a/src/index_notation/index_notation.cpp +++ b/src/index_notation/index_notation.cpp @@ -14,6 +14,7 @@ #include "taco/type.h" #include "taco/format.h" +#include "taco/index_notation/properties.h" #include "taco/index_notation/intrinsic.h" #include "taco/index_notation/schedule.h" #include "taco/index_notation/transformations.h" @@ -27,6 +28,7 @@ #include "taco/util/scopedmap.h" #include "taco/util/strings.h" #include "taco/util/collections.h" +#include "taco/util/functions.h" using namespace std; @@ -240,6 +242,63 @@ struct Equals : public IndexNotationVisitorStrict { eq = true; } + void visit(const TensorOpNode* anode) { + if (!isa(bExpr.ptr)) { + eq = false; + return; + } + auto bnode = to(bExpr.ptr); + + // Properties + if (anode->properties.size() != bnode->properties.size()) { + eq = false; + return; + } + + for(const auto& a_prop : anode->properties) { + bool found = false; + for(const auto& b_prop : bnode->properties) { + if(a_prop.equals(b_prop)) { + found = true; + break; + } + } + if (!found) { + eq = false; + return; + } + } + + // Lower function + // TODO: For now just check that the function pointers are the same. + if(!util::targetPtrEqual(anode->lowerFunc, bnode->lowerFunc)) { + eq = false; + return; + } + + // Check arguments + if (anode->args.size() != bnode->args.size()) { + eq = false; + return; + } + + for (size_t i = 0; i < anode->args.size(); ++i) { + if (!equals(anode->args[i], bnode->args[i])) { + eq = false; + return; + } + } + + // Algebra + if (!checkIterationAlg(anode, bnode)) { + eq = false; + return; + } + + // Special definitions + eq = checkRegionDefinitions(anode, bnode); + } + void visit(const CallIntrinsicNode* anode) { if (!isa(bExpr.ptr)) { eq = false; @@ -381,6 +440,72 @@ struct Equals : public IndexNotationVisitorStrict { } eq = true; } + + static bool checkRegionDefinitions(const TensorOpNode* anode, const TensorOpNode* bnode) { + // Check region definitions + if (anode->regionDefinitions.size() != bnode->regionDefinitions.size()) { + return false; + } + + auto& aDefs = anode->regionDefinitions; + auto& bDefs = bnode->regionDefinitions; + for (auto itA = aDefs.begin(), itB = bDefs.begin(); itA != aDefs.end(); ++itA, ++itB) { + if(itA->first != itB->first) { + return false; + } + + std::vector aArgs; + std::vector bArgs; + for(int idx : itA->first) { + taco_iassert((size_t)idx < anode->args.size()); // We already know anode->args.size == bnode->args.size + aArgs.push_back(anode->args[idx]); + bArgs.push_back(bnode->args[idx]); + } + + IndexExpr aRes = itA->second(aArgs); + IndexExpr bRes = itB->second(bArgs); + if(!equals(aRes, bRes)) { + return false; + } + } + + return true; + } + + /// Checks if the iteration algebra structure is the same and the ordering of the index expressions + /// nested under regions is the same for each op node. + static bool checkIterationAlg(const TensorOpNode* anode, const TensorOpNode* bnode) { + // Check IterationAlgebra structures + if(!algStructureEqual(anode->iterAlg, bnode->iterAlg)) { + return false; + } + + struct OrderChecker : public IterationAlgebraVisitor { + explicit OrderChecker(const TensorOpNode* op) : op(op) {} + + std::vector& check() { + op->iterAlg.accept(this); + return ordering; + } + + using IterationAlgebraVisitor::visit; + + void visit(const RegionNode* region) { + const IndexExpr& e = region->expr(); + auto it = std::find(op->args.begin(), op->args.end(), e); + taco_iassert(it != op->args.end()) << "Iteration algebra region expressions must be in arguments"; + size_t loc = it - op->args.begin(); + ordering.push_back(loc); + } + + std::vector ordering; + const TensorOpNode* op; + }; + + std::vector aOrdering = OrderChecker(anode).check(); + std::vector bOrdering = OrderChecker(bnode).check(); + return aOrdering == bOrdering; + } }; bool equals(IndexExpr a, IndexExpr b) { @@ -759,7 +884,11 @@ TensorOp::TensorOp(const TensorOpNode *n, std::string name) : IndexExpr(n), name } const std::vector& TensorOp::getArgs() const { - return getNode(*this)->exprs; + return getNode(*this)->args; +} + +const std::function &)> TensorOp::getFunc() const { + return getNode(*this)->lowerFunc; } const IterationAlgebra& TensorOp::getAlgebra() const { @@ -770,13 +899,15 @@ const std::vector& TensorOp::getProperties() const { return getNode(*this)->properties; } +const std::string TensorOp::getName() const { + return getNode(*this)->name; +} + const std::map, TensorOpNode::regionDefinition> TensorOp::getDefs() const { return getNode(*this)->regionDefinitions; } -std::string TensorOp::getName() const { - return name; -} + template <> bool isa(IndexExpr e) { return isa(e.ptr); @@ -2346,6 +2477,57 @@ struct Zero : public IndexNotationRewriterStrict { } } + void visit(const TensorOpNode* op) { + std::vector args; + bool rewritten = false; + + Annihilator annihilator = findProperty(op->properties, Annihilator()); + Literal annihilatorVal = annihilator.defined()? annihilator.getAnnihilator(): Literal(); + + // TODO: Check exhausted default against result default + for(auto& arg : op->args) { + IndexExpr rewrittenArg = rewrite(arg); + rewrittenArg = rewrittenArg.defined()? rewrittenArg: Literal::zero(arg.getDataType()); + if(equals(annihilatorVal, rewrittenArg)) { + expr = IndexExpr(); + return; + } + + args.push_back(rewrittenArg); + if (arg != rewrittenArg) { + rewritten = true; + } + } + + Identity identity = findProperty(op->properties, Identity()); + Literal identityVal = identity.defined()? identity.getIdentity(): Literal(); + + // If only one term is not the identity, replace expr with just that term + size_t nonIdentityTerms = 0; + IndexExpr nonIdentityTerm; + for(const auto& arg : args) { + if(!equals(identityVal, arg)) { + nonIdentityTerm = arg; + ++nonIdentityTerms; + } + if(nonIdentityTerms > 1) break; + } + + if(nonIdentityTerms == 1) { + expr = nonIdentityTerm; + return; + } + + if (rewritten) { + expr = new TensorOpNode(op->name, args, op->lowerFunc, op->iterAlg, op->properties, + op->regionDefinitions, op->getDataType()); + } + else { + expr = op; + } + + } + void visit(const CallIntrinsicNode* op) { std::vector args; std::vector zeroArgs; diff --git a/src/index_notation/index_notation_nodes.cpp b/src/index_notation/index_notation_nodes.cpp index 6e18ca6b9..8c2ab8c45 100644 --- a/src/index_notation/index_notation_nodes.cpp +++ b/src/index_notation/index_notation_nodes.cpp @@ -30,11 +30,17 @@ CallIntrinsicNode::CallIntrinsicNode(const std::shared_ptr& func, } // class TensorOpNode -TensorOpNode::TensorOpNode(const std::vector& exprs, opImpl lowerFunc, +TensorOpNode::TensorOpNode(std::string name, const std::vector& args, opImpl lowerFunc, const IterationAlgebra &iterAlg, const std::vector &properties, const std::map, regionDefinition>& regionDefinitions, Datatype type) : - IndexExprNode(type), exprs(exprs), lowerFunc(lowerFunc), iterAlg(iterAlg), - properties(properties), regionDefinitions(regionDefinitions) {} + IndexExprNode(type), name(name), args(args), lowerFunc(lowerFunc), + iterAlg(applyDemorgan(iterAlg)), properties(properties), + regionDefinitions(regionDefinitions) { + taco_iassert(lowerFunc != nullptr); + for (const auto& pair: regionDefinitions) { + taco_iassert(args.size() >= pair.first.size()); + } + } // class ReductionNode ReductionNode::ReductionNode(IndexExpr op, IndexVar var, IndexExpr a) diff --git a/src/index_notation/index_notation_printer.cpp b/src/index_notation/index_notation_printer.cpp index 96501fec6..7cdd8edcc 100644 --- a/src/index_notation/index_notation_printer.cpp +++ b/src/index_notation/index_notation_printer.cpp @@ -157,6 +157,13 @@ static inline void acceptJoin(IndexNotationPrinter* printer, } } +void IndexNotationPrinter::visit(const TensorOpNode* op) { + parentPrecedence = Precedence::FUNC; + os << op->name << "("; + acceptJoin(this, os, op->args, ", "); + os << ")"; +} + void IndexNotationPrinter::visit(const CallIntrinsicNode* op) { parentPrecedence = Precedence::FUNC; os << op->func->getName(); diff --git a/src/index_notation/index_notation_rewriter.cpp b/src/index_notation/index_notation_rewriter.cpp index 0a8570a9b..4b847fc8b 100644 --- a/src/index_notation/index_notation_rewriter.cpp +++ b/src/index_notation/index_notation_rewriter.cpp @@ -107,6 +107,25 @@ void IndexNotationRewriter::visit(const CastNode* op) { } } +void IndexNotationRewriter::visit(const TensorOpNode* op) { + std::vector args; + bool rewritten = false; + for(auto& arg : op->args) { + IndexExpr rewrittenArg = rewrite(arg); + args.push_back(rewrittenArg); + if (arg != rewrittenArg) { + rewritten = true; + } + } + if (rewritten) { + expr = new TensorOpNode(op->name, args, op->lowerFunc, op->iterAlg, op->properties, + op->regionDefinitions, op->getDataType()); + } + else { + expr = op; + } +} + void IndexNotationRewriter::visit(const CallIntrinsicNode* op) { std::vector args; bool rewritten = false; @@ -282,6 +301,14 @@ struct ReplaceRewriter : public IndexNotationRewriter { SUBSTITUTE_EXPR; } + void visit(const TensorOpNode* op) { + SUBSTITUTE_EXPR; + } + + void visit(const CallIntrinsicNode* op) { + SUBSTITUTE_EXPR; + } + void visit(const ReductionNode* op) { SUBSTITUTE_EXPR; } diff --git a/src/index_notation/index_notation_visitor.cpp b/src/index_notation/index_notation_visitor.cpp index e9b1f952d..e954895f7 100644 --- a/src/index_notation/index_notation_visitor.cpp +++ b/src/index_notation/index_notation_visitor.cpp @@ -69,6 +69,12 @@ void IndexNotationVisitor::visit(const CastNode* op) { op->a.accept(this); } +void IndexNotationVisitor::visit(const TensorOpNode* op) { + for (auto& arg : op->args) { + arg.accept(this); + } +} + void IndexNotationVisitor::visit(const CallIntrinsicNode* op) { for (auto& arg : op->args) { arg.accept(this); diff --git a/src/index_notation/iteration_algebra.cpp b/src/index_notation/iteration_algebra.cpp index 4cf2a41a1..66ce3614e 100644 --- a/src/index_notation/iteration_algebra.cpp +++ b/src/index_notation/iteration_algebra.cpp @@ -27,18 +27,59 @@ Region::Region() : IterationAlgebra(new RegionNode) {} Region::Region(IndexExpr expr) : IterationAlgebra(expr) {} Region::Region(const taco::RegionNode *n) : IterationAlgebra(n) {} +template <> bool isa(IterationAlgebra alg) { + return isa(alg.ptr); +} + +template <> Region to(IterationAlgebra alg) { + taco_iassert(isa(alg)); + return Region(to(alg.ptr)); +} + // Complement -Complement::Complement(const ComplementNode* n): IterationAlgebra(n) {} -Complement::Complement(IterationAlgebra alg) : Complement(new ComplementNode(alg)) {} +Complement::Complement(const ComplementNode* n): IterationAlgebra(n) { +} + +Complement::Complement(IterationAlgebra alg) : Complement(new ComplementNode(alg)) { +} + + +template <> bool isa(IterationAlgebra alg) { + return isa(alg.ptr); +} + +template <> Complement to(IterationAlgebra alg) { + taco_iassert(isa(alg)); + return Complement(to(alg.ptr)); +} // Intersect Intersect::Intersect(IterationAlgebra a, IterationAlgebra b) : Intersect(new IntersectNode(a, b)) {} Intersect::Intersect(const IterationAlgebraNode* n) : IterationAlgebra(n) {} +template <> bool isa(IterationAlgebra alg) { + return isa(alg.ptr); +} + +template <> Intersect to(IterationAlgebra alg) { + taco_iassert(isa(alg)); + return Intersect(to(alg.ptr)); +} + // Union Union::Union(IterationAlgebra a, IterationAlgebra b) : Union(new UnionNode(a, b)) {} Union::Union(const IterationAlgebraNode* n) : IterationAlgebra(n) {} +template <> bool isa(IterationAlgebra alg) { + return isa(alg.ptr); +} + +template <> Union to(IterationAlgebra alg) { + taco_iassert(isa(alg)); + return Union(to(alg.ptr)); +} + + // Node method definitions start here: // Definitions for RegionNode @@ -46,8 +87,8 @@ void RegionNode::accept(IterationAlgebraVisitorStrict *v) const { v->visit(this); } -const IndexExpr RegionNode::indexExpr() const { - return expr; +const IndexExpr RegionNode::expr() const { + return expr_; } // Definitions for ComplementNode @@ -102,7 +143,7 @@ void IterationAlgebraVisitor::visit(const UnionNode *n) { IterationAlgebra IterationAlgebraRewriterStrict::rewrite(IterationAlgebra iter_alg) { if(iter_alg.defined()) { iter_alg.accept(this); - alg = iter_alg; + iter_alg = alg; } else { iter_alg = IterationAlgebra(); @@ -147,4 +188,112 @@ void IterationAlgebraRewriter::visit(const UnionNode *n) { alg = new UnionNode(a, b); } } + +struct AlgComparer : public IterationAlgebraVisitorStrict { + + bool eq = false; + IterationAlgebra bAlg; + bool checkIndexExprs; + + explicit AlgComparer(bool checkIndexExprs) : checkIndexExprs(checkIndexExprs) { + } + + bool compare(const IterationAlgebra& a, const IterationAlgebra& b) { + bAlg = b; + a.accept(this); + return eq; + } + + void visit(const RegionNode* node) { + if(!isa(bAlg.ptr)) { + eq = false; + return; + } + + auto bnode = to(bAlg.ptr); + if (checkIndexExprs && !equals(node->expr(), bnode->expr())) { + eq = false; + return; + } + eq = true; + } + + void visit(const ComplementNode* node) { + if (!isa(bAlg.ptr)) { + eq = false; + return; + } + + auto bNode = to(bAlg.ptr); + eq = AlgComparer(checkIndexExprs).compare(node->a, bNode->a); + } + + template + bool binaryCheck(const T* anode, IterationAlgebra b) { + if (!isa(b.ptr)) { + return false; + } + auto bnode = to(b.ptr); + return AlgComparer(checkIndexExprs).compare(anode->a, bnode->a) && + AlgComparer(checkIndexExprs).compare(anode->b, bnode->b); + } + + + void visit(const IntersectNode* node) { + eq = binaryCheck(node, bAlg); + } + + void visit(const UnionNode* node) { + eq = binaryCheck(node, bAlg); + } + +}; + +bool algStructureEqual(const IterationAlgebra& a, const IterationAlgebra& b) { + return AlgComparer(false).compare(a, b); +} + +bool algEqual(const IterationAlgebra& a, const IterationAlgebra& b) { + return AlgComparer(true).compare(a, b); +} + +class DeMorganApplier : public IterationAlgebraRewriterStrict { + + void visit(const RegionNode* n) { + alg = Complement(n); + } + + void visit(const ComplementNode* n) { + alg = applyDemorgan(n->a); + } + + template + IterationAlgebra binaryVisit(Node n) { + IterationAlgebra a = applyDemorgan(Complement(n->a)); + IterationAlgebra b = applyDemorgan(Complement(n->b)); + return new ComplementedNode(a, b); + } + + void visit(const IntersectNode* n) { + alg = binaryVisit(n); + } + + void visit(const UnionNode* n) { + alg = binaryVisit(n); + } +}; + +struct DeMorganDispatcher : public IterationAlgebraRewriter { + + using IterationAlgebraRewriter::visit; + + void visit(const ComplementNode *n) { + alg = DeMorganApplier().rewrite(n->a); + } +}; + +IterationAlgebra applyDemorgan(IterationAlgebra alg) { + return DeMorganDispatcher().rewrite(alg); +} + } \ No newline at end of file diff --git a/src/index_notation/iteration_algebra_printer.cpp b/src/index_notation/iteration_algebra_printer.cpp index 85507243f..d582b1f5b 100644 --- a/src/index_notation/iteration_algebra_printer.cpp +++ b/src/index_notation/iteration_algebra_printer.cpp @@ -11,7 +11,7 @@ void IterationAlgebraPrinter::print(const IterationAlgebra& alg) { } void IterationAlgebraPrinter::visit(const RegionNode* n) { - os << n->indexExpr(); + os << n->expr(); } void IterationAlgebraPrinter::visit(const ComplementNode* n) { diff --git a/src/lower/expr_tools.cpp b/src/lower/expr_tools.cpp index d34a79703..2d86bd966 100644 --- a/src/lower/expr_tools.cpp +++ b/src/lower/expr_tools.cpp @@ -261,6 +261,10 @@ class SubExprVisitor : public IndexExprVisitorStrict { subExpr = binarySubExpr(op); } + void visit(const TensorOpNode* op) { + taco_not_supported_yet; + } + void visit(const CastNode* op) { taco_not_supported_yet; } diff --git a/src/lower/lowerer_impl.cpp b/src/lower/lowerer_impl.cpp index c7bb89ef3..7bc88f67f 100644 --- a/src/lower/lowerer_impl.cpp +++ b/src/lower/lowerer_impl.cpp @@ -56,6 +56,7 @@ class LowererImpl::Visitor : public IndexNotationVisitorStrict { void visit(const SqrtNode* node) { expr = impl->lowerSqrt(node); } void visit(const CastNode* node) { expr = impl->lowerCast(node); } void visit(const CallIntrinsicNode* node) { expr = impl->lowerCallIntrinsic(node); } + void visit(const TensorOpNode* node) { expr = impl->lowerTensorOp(node); } void visit(const ReductionNode* node) { taco_ierror << "Reduction nodes not supported in concrete index notation"; } @@ -1485,6 +1486,13 @@ Expr LowererImpl::lowerCallIntrinsic(CallIntrinsic call) { return call.getFunc().lower(args); } +Expr LowererImpl::lowerTensorOp(TensorOp op) { + std::vector args; + for (auto& arg : op.getArgs()) { + args.push_back(lower(arg)); + } + return op.getFunc()(args); +} Stmt LowererImpl::lower(IndexStmt stmt) { return visitor->lower(stmt); diff --git a/src/lower/merge_lattice.cpp b/src/lower/merge_lattice.cpp index 0928d4019..3445bb592 100644 --- a/src/lower/merge_lattice.cpp +++ b/src/lower/merge_lattice.cpp @@ -89,7 +89,7 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA } void visit(const RegionNode* node) { - lattice = build(node->indexExpr()); + lattice = build(node->expr()); } void visit(const ComplementNode* node) { @@ -266,6 +266,10 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA lattice = build(expr->a); } + void visit(const TensorOpNode* expr) { + lattice = build(expr->iterAlg); + } + void visit(const CallIntrinsicNode* expr) { const auto zeroPreservingArgsSets = expr->func->zeroPreservingArgs(expr->args); diff --git a/test/test-iteration_algebra.cpp b/test/test-iteration_algebra.cpp index b02bbc592..b95513881 100644 --- a/test/test-iteration_algebra.cpp +++ b/test/test-iteration_algebra.cpp @@ -1,18 +1,225 @@ #include "test.h" - -#include "taco/type.h" - #include "taco/index_notation/index_notation.h" #include "taco/index_notation/iteration_algebra.h" using namespace taco; -const TensorVar A_var("A", Type()), B_var("B", Type()), C_var("C", Type()); -const Access A(A_var), B(B_var), C(C_var); +const IndexVar i("i"), j("j"), k("k"); + +Type vec(type(), {3}); +TensorVar v1("v1", vec), v2("v2", vec), v3("v3", vec); + +TEST(iteration_algebra, region) { + Access access = v1(i); + IterationAlgebra alg = access; + ASSERT_TRUE(isa(alg)); + ASSERT_TRUE(isa(alg.ptr)); + Region region = to(alg); + const RegionNode* node = to(alg.ptr); + ASSERT_FALSE(isa(alg)); + ASSERT_EQ(access, node->expr()); +} + +TEST(iteration_algebra, Complement) { + IndexExpr expr = v1(i) + v2(i); + IterationAlgebra alg = Complement(expr); + ASSERT_TRUE(isa(alg)); + ASSERT_TRUE(isa(alg.ptr)); + Complement complement = to(alg); + const ComplementNode* n = to(alg.ptr); + ASSERT_FALSE(isa(n)); + ASSERT_TRUE(algEqual(expr, n->a)); + + ASSERT_TRUE(isa(n->a.ptr)); + const RegionNode* r = to(n->a.ptr); + ASSERT_EQ(expr, r->expr()); +} + +TEST(iteration_algebra, Union) { + IndexExpr exprA = v1(i); + IndexExpr exprB = v2(i); + IterationAlgebra alg = Union(exprA, exprB); + ASSERT_TRUE(isa(alg)); + ASSERT_TRUE(isa(alg.ptr)); + Union u = to(alg); + const UnionNode* n = to(alg.ptr); + ASSERT_FALSE(isa(n)); + ASSERT_TRUE(algEqual(exprA, n->a)); + ASSERT_TRUE(algEqual(exprB, n->b)); + + ASSERT_TRUE(isa(n->a.ptr)); + const RegionNode* r1 = to(n->a.ptr); + ASSERT_TRUE(isa(n->b.ptr)); + const RegionNode* r2 = to(n->b.ptr); + + ASSERT_EQ(exprA, r1->expr()); + ASSERT_EQ(exprB, r2->expr()); +} + +TEST(iteration_algebra, Intersect) { + IndexExpr exprA = v1(i); + IndexExpr exprB = v2(i); + IterationAlgebra alg = Intersect(exprA, exprB); + ASSERT_TRUE(isa(alg)); + ASSERT_TRUE(isa(alg.ptr)); + Intersect i = to(alg); + const IntersectNode* n = to(alg.ptr); + ASSERT_FALSE(isa(n)); + ASSERT_TRUE(algEqual(exprA, n->a)); + ASSERT_TRUE(algEqual(exprB, n->b)); + + ASSERT_TRUE(isa(n->a.ptr)); + const RegionNode* r1 = to(n->a.ptr); + ASSERT_TRUE(isa(n->b.ptr)); + const RegionNode* r2 = to(n->b.ptr); + + ASSERT_EQ(exprA, r1->expr()); + ASSERT_EQ(exprB, r2->expr()); +} + +TEST(iteration_algebra, comparatorRegion) { + IterationAlgebra alg1(v1(i)); + IterationAlgebra alg2(v2(j)); + ASSERT_TRUE(algStructureEqual(alg1, alg2)); + ASSERT_FALSE(algEqual(alg1, alg2)); + + ASSERT_TRUE(algEqual(alg1, alg1)); + ASSERT_TRUE(algStructureEqual(alg1, alg1)); +} + +TEST(iteration_algebra, comparatorComplement) { + IterationAlgebra alg1 = Complement(v2(i)); + IterationAlgebra alg2 = Complement(v3(j)); +// ASSERT_TRUE(algStructureEqual(alg1, alg2)); + ASSERT_FALSE(algEqual(alg1, alg2)); + + ASSERT_TRUE(algStructureEqual(alg1, alg1)); + ASSERT_TRUE(algEqual(alg1, alg1)); +} + +TEST(iteration_algebra, comparatorIntersect) { + IterationAlgebra alg1 = Intersect(v1(i), v2(i)); + IterationAlgebra alg2 = Intersect(v1(j), v3(j)); + ASSERT_TRUE(algStructureEqual(alg1, alg2)); + ASSERT_FALSE(algEqual(alg1, alg2)); + + ASSERT_TRUE(algStructureEqual(alg1, alg1)); + ASSERT_TRUE(algEqual(alg1, alg1)); +} + +TEST(iteration_algebra, comparatorUnion) { + IterationAlgebra alg1 = Union(v1(i), v2(i)); + IterationAlgebra alg2 = Union(v1(j), v3(j)); + ASSERT_TRUE(algStructureEqual(alg1, alg2)); + ASSERT_FALSE(algEqual(alg1, alg2)); + + ASSERT_TRUE(algStructureEqual(alg1, alg1)); + ASSERT_TRUE(algEqual(alg1, alg1)); +} + +TEST(iteration_algebra, comparatorMix) { + IterationAlgebra alg1 = Union(Intersect(v1(i), v2(i)), Complement(v3(i))); + IterationAlgebra alg2 = Union(Intersect(v1(j), v2(j)), Complement(v3(j))); + ASSERT_TRUE(algStructureEqual(alg1, alg2)); + ASSERT_FALSE(algEqual(alg1, alg2)); + + ASSERT_TRUE(algStructureEqual(alg1, alg1)); + ASSERT_TRUE(algEqual(alg1, alg1)); +} + +TEST(iteration_algebra, deMorganRegion) { + IterationAlgebra alg(v1(i)); + IterationAlgebra simplified = applyDemorgan(alg); + + ASSERT_TRUE(algEqual(alg, simplified)); + ASSERT_TRUE(algStructureEqual(alg, simplified)); +} + +TEST(iteration_algebra, deMorganComplement) { + IterationAlgebra alg = Complement(v1(i)); + IterationAlgebra simplified = applyDemorgan(alg); + + ASSERT_TRUE(algEqual(alg, simplified)); + ASSERT_TRUE(algStructureEqual(alg, simplified)); +} + +TEST(iteration_algebra, DeMorganNestedComplements) { + IterationAlgebra alg = v1(i); + for(int cnt = 0; cnt < 10; ++cnt) { + if(cnt % 2 == 0) { + IterationAlgebra simplified = applyDemorgan(alg); + IterationAlgebra expectedEven = v1(i); + ASSERT_TRUE(algEqual(expectedEven, simplified)); + } + else { + IterationAlgebra simplified = applyDemorgan(alg); + IterationAlgebra expectedOdd = Complement(v1(i)); + ASSERT_TRUE(algEqual(expectedOdd, simplified)); + } + alg = Complement(alg); + } +} + +TEST(iteration_algebra, deMorganIntersect) { + IterationAlgebra alg = Intersect(v1(i), v2(i)); + IterationAlgebra simplified = applyDemorgan(alg); + + ASSERT_TRUE(algEqual(alg, simplified)); + ASSERT_TRUE(algStructureEqual(alg, simplified)); +} + +TEST(iteration_algebra, deMorganUnion) { + IterationAlgebra alg = Union(v1(i), v2(i)); + IterationAlgebra simplified = applyDemorgan(alg); + + ASSERT_TRUE(algEqual(alg, simplified)); + ASSERT_TRUE(algStructureEqual(alg, simplified)); +} + +TEST(iteration_algebra, UnionComplement) { + IterationAlgebra alg = Union(v1(i), Complement(v2(i))); + IterationAlgebra simplified = applyDemorgan(alg); + + ASSERT_TRUE(algEqual(alg, simplified)); + ASSERT_TRUE(algStructureEqual(alg, simplified)); +} + +TEST(iteration_algebra, flipUnionToIntersect) { + IterationAlgebra alg = Complement(Union(v1(i), v2(i))); + IterationAlgebra simplified = applyDemorgan(alg); + + ASSERT_FALSE(algEqual(alg, simplified)); + ASSERT_FALSE(algStructureEqual(alg, simplified)); + + IterationAlgebra expected = Intersect(Complement(v1(i)), Complement(v2(i))); + ASSERT_TRUE(algEqual(simplified, expected)); + ASSERT_TRUE(algStructureEqual(simplified, expected)); +} + +TEST(iteration_algebra, flipIntersectToUnion) { + IterationAlgebra alg = Complement(Intersect(v1(i), v2(i))); + IterationAlgebra simplified = applyDemorgan(alg); + + ASSERT_FALSE(algEqual(alg, simplified)); + ASSERT_FALSE(algStructureEqual(alg, simplified)); + + IterationAlgebra expected = Union(Complement(v1(i)), Complement(v2(i))); + ASSERT_TRUE(algEqual(simplified, expected)); + ASSERT_TRUE(algStructureEqual(simplified, expected)); +} + +TEST(iteration_algebra, hiddenIntersect) { + IterationAlgebra alg = Complement(Union(Complement(v1(i)), Complement(v2(i)))); + IterationAlgebra simplified = applyDemorgan(alg); + + IterationAlgebra expected = Intersect(v1(i), v2(i)); + ASSERT_TRUE(algEqual(expected, simplified)); +} + +TEST(iteration_algebra, hiddenUnion) { + IterationAlgebra alg = Complement(Intersect(Complement(v1(i)), Complement(v2(i)))); + IterationAlgebra simplified = applyDemorgan(alg); -TEST(iteration_algebra, iter_alg_print) { - std::ostringstream ss; - ss << Intersect(Union(Complement(A), B), C); - std::string expected("(~A U B) * C"); - ASSERT_EQ(expected, ss.str()); + IterationAlgebra expected = Union(v1(i), v2(i)); + ASSERT_TRUE(algEqual(expected, simplified)); } \ No newline at end of file From 2316c38a1f157c469a15e23f5540dde49532ca02 Mon Sep 17 00:00:00 2001 From: Rawn Date: Thu, 5 Mar 2020 17:22:03 -0500 Subject: [PATCH 10/27] Reformat some code and fixed some bugs in lowering. Added one test for tensorOp nodes. --- .../index_notation/index_notation_nodes.h | 11 ++++- .../taco/index_notation/iteration_algebra.h | 5 +++ include/taco/index_notation/tensor_operator.h | 43 ++++++++----------- include/taco/util/collections.h | 10 +++++ src/index_notation/index_notation.cpp | 6 ++- src/index_notation/index_notation_nodes.cpp | 18 ++++---- .../index_notation_rewriter.cpp | 6 ++- src/index_notation/iteration_algebra.cpp | 25 +++++++++++ test/test_properties.cpp | 1 + test/tests-lower.cpp | 38 ++++++++++++++++ 10 files changed, 124 insertions(+), 39 deletions(-) create mode 100644 test/test_properties.cpp diff --git a/include/taco/index_notation/index_notation_nodes.h b/include/taco/index_notation/index_notation_nodes.h index 01ade2412..adb540b2d 100644 --- a/include/taco/index_notation/index_notation_nodes.h +++ b/include/taco/index_notation/index_notation_nodes.h @@ -5,6 +5,7 @@ #include #include "taco/type.h" +#include "taco/util/collections.h" #include "taco/util/comparable.h" #include "taco/index_notation/index_notation.h" #include "taco/index_notation/index_notation_nodes_abstract.h" @@ -181,8 +182,7 @@ struct TensorOpNode : public IndexExprNode { TensorOpNode(std::string name, const std::vector& args, opImpl lowerFunc, const IterationAlgebra& iterAlg, const std::vector& properties, - const std::map, regionDefinition>& regionDefinitions, - Datatype type); + const std::map, regionDefinition>& regionDefinitions); void accept(IndexExprVisitorStrict* v) const { v->visit(this); @@ -194,6 +194,13 @@ struct TensorOpNode : public IndexExprNode { IterationAlgebra iterAlg; std::vector properties; std::map, regionDefinition> regionDefinitions; + +private: + static Datatype inferReturnType(opImpl f, const std::vector& inputs) { + std::function getExprs = [](IndexExpr arg) { return ir::Var::make("t", arg.getDataType()); }; + std::vector exprs = util::map(inputs, getExprs); + return f(exprs).type(); + } }; struct ReductionNode : public IndexExprNode { diff --git a/include/taco/index_notation/iteration_algebra.h b/include/taco/index_notation/iteration_algebra.h index 2b712df5f..2b2d3e508 100644 --- a/include/taco/index_notation/iteration_algebra.h +++ b/include/taco/index_notation/iteration_algebra.h @@ -233,6 +233,11 @@ bool algEqual(const IterationAlgebra&, const IterationAlgebra&); /// nodes. IterationAlgebra applyDemorgan(IterationAlgebra alg); +/// Rewrites the algebra to replace the IndexExprs in the algebra with new index exprs as +/// specified by the input map. If the map does not contain an indexExpr, it is kept the +/// same as the input algebra. +IterationAlgebra replaceIndexExprs(IterationAlgebra alg, const std::map&); + } diff --git a/include/taco/index_notation/tensor_operator.h b/include/taco/index_notation/tensor_operator.h index c0f18e7c3..4aa89d521 100644 --- a/include/taco/index_notation/tensor_operator.h +++ b/include/taco/index_notation/tensor_operator.h @@ -20,47 +20,49 @@ using algebraImpl = TensorOpNode::algebraImpl; using regionDefinition = TensorOpNode::regionDefinition; public: - // Full construction Op(opImpl lowererFunc, algebraImpl algebraFunc, std::vector properties = {}, - std::map, regionDefinition> specialDefinitions = {}) : - name(util::uniqueName("Op")), lowererFunc(lowererFunc), algebraFunc(algebraFunc), - properties(properties), regionDefinitions(specialDefinitions) {} + std::map, regionDefinition> specialDefinitions = {}) + : name(util::uniqueName("Op")), lowererFunc(lowererFunc), algebraFunc(algebraFunc), + properties(properties), regionDefinitions(specialDefinitions) { + } Op(std::string name, opImpl lowererFunc, algebraImpl algebraFunc, std::vector properties = {}, - std::map, regionDefinition> specialDefinitions = {}) : - name(name), lowererFunc(lowererFunc), algebraFunc(algebraFunc), - properties(properties), regionDefinitions(specialDefinitions) {} + std::map, regionDefinition> specialDefinitions = {}) + : name(name), lowererFunc(lowererFunc), algebraFunc(algebraFunc), properties(properties), + regionDefinitions(specialDefinitions) { + } // Construct without specifying algebra Op(std::string name, opImpl lowererFunc, std::vector properties, - std::map, regionDefinition> specialDefinitions = {}) : - Op(name, lowererFunc, nullptr, properties, specialDefinitions) {} + std::map, regionDefinition> specialDefinitions = {}) + : Op(name, lowererFunc, nullptr, properties, specialDefinitions) { + } Op(opImpl lowererFunc, std::vector properties, - std::map, regionDefinition> specialDefinitions = {}) : - Op(util::uniqueName("Op"), lowererFunc, nullptr, properties, specialDefinitions) {} + std::map, regionDefinition> specialDefinitions = {}) + : Op(util::uniqueName("Op"), lowererFunc, nullptr, properties, specialDefinitions) { + } // Construct without algebra or properties - Op(std::string name, opImpl lowererFunc) : Op(name, lowererFunc, nullptr) {} - - explicit Op(opImpl lowererFunc) : Op(lowererFunc, nullptr) {} + Op(std::string name, opImpl lowererFunc) : Op(name, lowererFunc, nullptr) { + } + explicit Op(opImpl lowererFunc) : Op(lowererFunc, nullptr) { + } template TensorOp operator()(IndexExprs&&... exprs) { std::vector actualArgs{exprs...}; IterationAlgebra nodeAlgebra = algebraFunc == nullptr? inferAlgFromProperties(actualArgs): algebraFunc(actualArgs); - Datatype returnType = inferReturnType(actualArgs); TensorOpNode* op = new TensorOpNode(name, actualArgs, lowererFunc, nodeAlgebra, properties, - regionDefinitions, returnType); + regionDefinitions); return TensorOp(op); } - private: std::string name; opImpl lowererFunc; @@ -68,12 +70,6 @@ using regionDefinition = TensorOpNode::regionDefinition; std::vector properties; std::map, regionDefinition> regionDefinitions; - Datatype inferReturnType(const std::vector& inputs) { - std::function getExprs = [](IndexExpr arg) { return ir::Var::make("t", arg.getDataType()); }; - std::vector exprs = util::map(inputs, getExprs); - return lowererFunc(exprs).type(); - } - IterationAlgebra inferAlgFromProperties(const std::vector& exprs) { if(properties.empty()) { return constructDefaultAlgebra(exprs); @@ -93,7 +89,6 @@ using regionDefinition = TensorOpNode::regionDefinition; IterationAlgebra background = Complement(tensorsRegions); return Union(tensorsRegions, background); } - }; } diff --git a/include/taco/util/collections.h b/include/taco/util/collections.h index 7c687cafe..e0a07ba78 100644 --- a/include/taco/util/collections.h +++ b/include/taco/util/collections.h @@ -133,6 +133,16 @@ size_t count(const std::vector& vector, T test) { return count; } +template +std::map zipToMap(const std::vector& keys, const std::vector& values) { + std::map result; + size_t limit = std::min(keys.size(), values.size()); + for(size_t i = 0; i < limit; ++i) { + result.insert({keys[i], values[i]}); + } + return result; +} + /** * Split the vector into two vectors where elements in the first pass the test * and elements in the second do not. diff --git a/src/index_notation/index_notation.cpp b/src/index_notation/index_notation.cpp index b3c48c117..d7bef4b2b 100644 --- a/src/index_notation/index_notation.cpp +++ b/src/index_notation/index_notation.cpp @@ -2519,8 +2519,10 @@ struct Zero : public IndexNotationRewriterStrict { } if (rewritten) { - expr = new TensorOpNode(op->name, args, op->lowerFunc, op->iterAlg, op->properties, - op->regionDefinitions, op->getDataType()); + const std::map subs = util::zipToMap(op->args, args); + IterationAlgebra newAlg = replaceIndexExprs(op->iterAlg, subs); + expr = new TensorOpNode(op->name, args, op->lowerFunc, newAlg, op->properties, + op->regionDefinitions); } else { expr = op; diff --git a/src/index_notation/index_notation_nodes.cpp b/src/index_notation/index_notation_nodes.cpp index 8c2ab8c45..5e2ef0ebe 100644 --- a/src/index_notation/index_notation_nodes.cpp +++ b/src/index_notation/index_notation_nodes.cpp @@ -32,15 +32,15 @@ CallIntrinsicNode::CallIntrinsicNode(const std::shared_ptr& func, // class TensorOpNode TensorOpNode::TensorOpNode(std::string name, const std::vector& args, opImpl lowerFunc, const IterationAlgebra &iterAlg, const std::vector &properties, - const std::map, regionDefinition>& regionDefinitions, Datatype type) : - IndexExprNode(type), name(name), args(args), lowerFunc(lowerFunc), - iterAlg(applyDemorgan(iterAlg)), properties(properties), - regionDefinitions(regionDefinitions) { - taco_iassert(lowerFunc != nullptr); - for (const auto& pair: regionDefinitions) { - taco_iassert(args.size() >= pair.first.size()); - } - } + const std::map, regionDefinition>& regionDefinitions) + : IndexExprNode(inferReturnType(lowerFunc, args)), name(name), args(args), lowerFunc(lowerFunc), + iterAlg(applyDemorgan(iterAlg)), properties(properties), regionDefinitions(regionDefinitions) { + + taco_iassert(lowerFunc != nullptr); + for (const auto& pair: regionDefinitions) { + taco_iassert(args.size() >= pair.first.size()); + } +} // class ReductionNode ReductionNode::ReductionNode(IndexExpr op, IndexVar var, IndexExpr a) diff --git a/src/index_notation/index_notation_rewriter.cpp b/src/index_notation/index_notation_rewriter.cpp index 4b847fc8b..48e81553a 100644 --- a/src/index_notation/index_notation_rewriter.cpp +++ b/src/index_notation/index_notation_rewriter.cpp @@ -118,8 +118,10 @@ void IndexNotationRewriter::visit(const TensorOpNode* op) { } } if (rewritten) { - expr = new TensorOpNode(op->name, args, op->lowerFunc, op->iterAlg, op->properties, - op->regionDefinitions, op->getDataType()); + const std::map subs = util::zipToMap(op->args, args); + IterationAlgebra newAlg = replaceIndexExprs(op->iterAlg, subs); + expr = new TensorOpNode(op->name, args, op->lowerFunc, newAlg, op->properties, + op->regionDefinitions); } else { expr = op; diff --git a/src/index_notation/iteration_algebra.cpp b/src/index_notation/iteration_algebra.cpp index 66ce3614e..f1cad202f 100644 --- a/src/index_notation/iteration_algebra.cpp +++ b/src/index_notation/iteration_algebra.cpp @@ -1,3 +1,4 @@ +#include "taco/util/collections.h" #include "taco/index_notation/iteration_algebra.h" #include "taco/index_notation/iteration_algebra_printer.h" @@ -296,4 +297,28 @@ IterationAlgebra applyDemorgan(IterationAlgebra alg) { return DeMorganDispatcher().rewrite(alg); } +class IndexExprReplacer : public IterationAlgebraRewriter { + +public: + IndexExprReplacer(const std::map& substitutions) : substitutions(substitutions) { + } + +private: + using IterationAlgebraRewriter::visit; + + void visit(const RegionNode* node) { + if (util::contains(substitutions, node->expr())) { + alg = new RegionNode(substitutions.at(node->expr())); + return; + } + alg = node; + } + + const std::map substitutions; +}; + +IterationAlgebra replaceIndexExprs(IterationAlgebra alg, const std::map& substitutions) { + return IndexExprReplacer(substitutions).rewrite(alg); +} + } \ No newline at end of file diff --git a/test/test_properties.cpp b/test/test_properties.cpp new file mode 100644 index 000000000..a1bb4cbc4 --- /dev/null +++ b/test/test_properties.cpp @@ -0,0 +1 @@ +#include "taco/index_notation/properties.h" \ No newline at end of file diff --git a/test/tests-lower.cpp b/test/tests-lower.cpp index 89c322172..4efad9e42 100644 --- a/test/tests-lower.cpp +++ b/test/tests-lower.cpp @@ -8,6 +8,7 @@ #include "taco/index_notation/index_notation.h" #include "taco/index_notation/index_notation_rewriter.h" #include "taco/index_notation/index_notation_nodes.h" +#include "taco/index_notation/tensor_operator.h" #include "taco/index_notation/kernel.h" #include "taco/codegen/module.h" #include "taco/storage/storage.h" @@ -1556,4 +1557,41 @@ TEST_STMT(vector_not, } ) +// Test tensorOps +struct lowerOp { + ir::Expr operator()(const std::vector& v) { + return ir::Add::make(ir::Mul::make(v[0], v[1]), v[2]); + } +}; + +struct algebraGen { + IterationAlgebra operator()(const std::vector& v) { + IterationAlgebra r1 = Intersect(v[0], v[1]); + IterationAlgebra r2 = Intersect(v[0], v[2]); + IterationAlgebra r3 = Intersect(v[1], v[2]); + + IterationAlgebra omit = Complement(Intersect(Intersect(v[0], v[1]), v[2])); + return Intersect(Union(Union(r1, r2), r3), omit); + } +}; + +Op testOp("testOp", lowerOp(), algebraGen()); + +TEST_STMT(testOp1, + forall(i, + a(i) = testOp(b(i), c(i), d(i)) + ), + Values( + Formats({{a,sparse}, {b,sparse}, {c,sparse}, {d, sparse}}) + ), + { + TestCase( + {{b, {{{0}, 2.0}, {{1}, 2.0}, {{4}, 4.0}}}, + {c, {{{0}, 3.0}, {{2}, 3.0}, {{4}, 6.0}}}, + {d, {{{1}, 1.0}, {{2}, 4.0}, {{4}, 5.0}}}}, + + {{a, {{{0}, 6.0}, {{1}, 1.0}, {{2}, 4.0}}}}) + } +) + }} From c75b22795d3a0779fc70243a22d36004b2a8b56b Mon Sep 17 00:00:00 2001 From: Rawn Date: Thu, 5 Mar 2020 17:23:06 -0500 Subject: [PATCH 11/27] Added some missing files to git --- include/taco/util/functions.h | 21 +++++ src/index_notation/properties.cpp | 133 ++++++++++++++++++++++++++++++ 2 files changed, 154 insertions(+) create mode 100644 include/taco/util/functions.h create mode 100644 src/index_notation/properties.cpp diff --git a/include/taco/util/functions.h b/include/taco/util/functions.h new file mode 100644 index 000000000..1893147ae --- /dev/null +++ b/include/taco/util/functions.h @@ -0,0 +1,21 @@ +#ifndef TACO_FUNCTIONAL_H +#define TACO_FUNCTIONAL_H + +#include + +namespace taco { +namespace util { + +template +Fnptr functorAddress(std::function f) { + return *f.template target(); +} + +template +bool targetPtrEqual(std::function f, std::function g) { + return functorAddress(f) != nullptr && functorAddress(f) == functorAddress(g); +} + +} +} +#endif //TACO_FUNCTIONAL_H diff --git a/src/index_notation/properties.cpp b/src/index_notation/properties.cpp new file mode 100644 index 000000000..e7dbc7a7e --- /dev/null +++ b/src/index_notation/properties.cpp @@ -0,0 +1,133 @@ +#include "taco/index_notation/properties.h" +#include "taco/index_notation/index_notation.h" + +namespace taco { + +struct Annihilator::Content { + Literal annihilator; +}; + +struct Identity::Content { + Literal identity; +}; + +// Property class definitions +Property::~Property() {} + +bool Property::defined() const { + return false; +} + +bool Property::equals(const Property& p) const { + return defined() == p.defined(); +} + +// Annihilator class definitions +Annihilator::Annihilator() {} + +Annihilator::Annihilator(Literal annihilator) : content(new Content) { + content->annihilator = annihilator; +} + +const Literal& Annihilator::getAnnihilator() const { + taco_iassert(defined()); + return content->annihilator; +} + +bool Annihilator::defined() const { + return content.get() != nullptr; +} + +bool Annihilator::equals(const Property& p) const { + if(!isa(p)) return false; + + Annihilator a = to(p); + if (!defined() && !a.defined()) return true; + + if(defined() && a.defined()) { + return ::taco::equals(getAnnihilator(), a.getAnnihilator()); + } + return false; +} + +// Identity class definitions +Identity::Identity() {} + +Identity::Identity(Literal identity) : content(new Content) { + content->identity = identity; +} + +const Literal& Identity::getIdentity() const { + taco_iassert(defined()); + return content->identity; +} + +bool Identity::defined() const { + return content.get() != nullptr; +} + +bool Identity::equals(const Property& p) const { + if(!isa(p)) return false; + + Identity i = to(p); + if (!defined() && !i.defined()) return true; + + if(defined() && i.defined()) { + return ::taco::equals(getIdentity(), i.getIdentity()); + } + return false; +} + +// Associative class definitions +Associative::Associative() : isDefined(true) {} + +Associative Associative::makeUndefined() { + Associative a = Associative(); + a.isDefined = false; + return a; +} + +bool Associative::defined() const { + return isDefined; +} + +bool Associative::equals(const Property& p) const { + if(!isa(p)) return false; + Associative a = to(p); + return defined() == a.defined(); +} + +// Commutative class definitions +Commutative::Commutative() : isDefined(true) {} + +Commutative::Commutative(std::vector ordering) : ordering_(ordering), isDefined(true) { +} + +Commutative Commutative::makeUndefined() { + Commutative com; + com.isDefined = false; + return com; +} + +const std::vector & Commutative::ordering() const { + return ordering_; +} + +bool Commutative::defined() const { + return isDefined; +} + +bool Commutative::equals(const Property& p) const { + if(!isa(p)) return false; + + Commutative c = to(p); + if (!defined() && !c.defined()) return true; + + if(defined() && c.defined()) { + return ordering() == c.ordering(); + } + + return false; +} + +} \ No newline at end of file From 67a4e6b8356c7d44e0662c3e6a5f3c7c78a4e351 Mon Sep 17 00:00:00 2001 From: Rawn Date: Fri, 6 Mar 2020 19:09:55 -0500 Subject: [PATCH 12/27] Redesign of property class. Properties are now wrapped for functions so users can create vectors of properties which are actually a vector of pointers. This is a work around to avoid object slicing when storing properties in a vector --- include/taco/index_notation/properties.h | 98 +++++------- .../taco/index_notation/property_pointers.h | 92 +++++++++++ src/index_notation/index_notation.cpp | 8 +- src/index_notation/properties.cpp | 143 ++++++++---------- src/index_notation/property_pointers.cpp | 121 +++++++++++++++ src/lower/merge_lattice.cpp | 1 + test/test_properties.cpp | 109 ++++++++++++- 7 files changed, 432 insertions(+), 140 deletions(-) create mode 100644 include/taco/index_notation/property_pointers.h create mode 100644 src/index_notation/property_pointers.cpp diff --git a/include/taco/index_notation/properties.h b/include/taco/index_notation/properties.h index a148ca2b4..d8c8348f7 100644 --- a/include/taco/index_notation/properties.h +++ b/include/taco/index_notation/properties.h @@ -1,96 +1,82 @@ #ifndef TACO_PROPERTIES_H #define TACO_PROPERTIES_H -#include -#include - -#include "taco/error.h" -#include "taco/util/comparable.h" +#include "taco/index_notation/property_pointers.h" +#include "taco/util/intrusive_ptr.h" namespace taco { -class Literal; - -class Property { +/// A class containing properties about an operation +class Property : public util::IntrusivePtr { public: - virtual ~Property(); - virtual bool defined() const; - virtual bool equals(const Property&) const; + Property(); + Property(const PropertyPtr* p); + + bool equals(const Property& p) const; + std::ostream& print(std::ostream&) const; }; +std::ostream& operator<<(std::ostream&, const Property&); + +/// A class wrapping the annihilator property pointer class Annihilator : public Property { public: - Annihilator(); - Annihilator(Literal); - const Literal& getAnnihilator() const; - virtual bool defined() const; - virtual bool equals(const Property&) const; - -private: - struct Content; - std::shared_ptr content; + explicit Annihilator(Literal); + Annihilator(const PropertyPtr*); + + const Literal& annihilator() const; + + typedef AnnihilatorPtr Ptr; }; +/// A class wrapping an identity property pointer class Identity : public Property { public: - Identity(); - Identity(Literal); - const Literal& getIdentity() const; - virtual bool defined() const; - virtual bool equals(const Property&) const; - -private: - struct Content; - std::shared_ptr content; -}; + explicit Identity(Literal); + Identity(const PropertyPtr*); + + const Literal& identity() const; + typedef IdentityPtr Ptr; +}; +/// A class wrapping an associative property pointer class Associative : public Property { public: Associative(); - static Associative makeUndefined(); + Associative(const PropertyPtr*); - virtual bool defined() const; - virtual bool equals(const Property&) const; - -private: - bool isDefined; + typedef AssociativePtr Ptr; }; +/// A class wrapping a commutative property pointer class Commutative : public Property { public: Commutative(); - Commutative(std::vector); - static Commutative makeUndefined(); + explicit Commutative(const std::vector&); + Commutative(const PropertyPtr*); const std::vector& ordering() const; - virtual bool defined() const; - virtual bool equals(const Property&) const; -private: - const std::vector ordering_; - bool isDefined; + typedef CommutativePtr Ptr; }; /// Returns true if property p is of type P. -template -inline bool isa(const Property& p) { - return dynamic_cast(&p) != nullptr; -} +template bool isa(const Property& p); /// Casts the Property p to type P. -template -inline const P& to(const Property& p) { - taco_iassert(isa

(p)) << "Cannot convert " << typeid(p).name() << " to " << typeid(P).name(); - return static_cast(p); -} +template P to(const Property& p); +/// Finds and returns the property of type P if it exists in the vector. If +/// the property does not exist, returns an undefined instance of the property +/// requested. +/// The vector of properties should not contain duplicates so this is sufficient. template -inline const P findProperty(const std::vector& properties, P defaultProperty) { - for (const auto& p: properties) { - if(isa

(p)) return to

(p); +inline const P findProperty(const std::vector &properties) { + for (const auto &p: properties) { + if (isa

(p)) return to

(p); } - return defaultProperty; + return P(nullptr); } } diff --git a/include/taco/index_notation/property_pointers.h b/include/taco/index_notation/property_pointers.h new file mode 100644 index 000000000..a0ad502aa --- /dev/null +++ b/include/taco/index_notation/property_pointers.h @@ -0,0 +1,92 @@ +#ifndef TACO_PROPERTY_POINTERS_H +#define TACO_PROPERTY_POINTERS_H + +#include +#include +#include +#include +#include + +#include "taco/error.h" +#include "taco/util/comparable.h" + +namespace taco { + +class Literal; +struct PropertyPtr; + +/// A pointer to the property data. This will be wrapped in an auxillary class +/// to allow a user to create a vector of properties. Needed since properties +/// have different methods and data +struct PropertyPtr : public util::Manageable, + private util::Uncopyable { +public: + PropertyPtr(); + virtual ~PropertyPtr(); + virtual std::ostream& print(std::ostream& os) const; + virtual bool equals(const PropertyPtr* p) const; +}; + +/// Pointer class for annihilators +struct AnnihilatorPtr : public PropertyPtr { + AnnihilatorPtr(); + AnnihilatorPtr(Literal); + const Literal& annihilator() const; + virtual std::ostream& print(std::ostream& os) const; + virtual bool equals(const PropertyPtr* p) const; + + struct Content; + std::shared_ptr content; +}; + +/// Pointer class for identities +struct IdentityPtr : public PropertyPtr { +public: + IdentityPtr(); + IdentityPtr(Literal); + const Literal& identity() const; + virtual std::ostream& print(std::ostream& os) const; + virtual bool equals(const PropertyPtr* p) const; + + struct Content; + std::shared_ptr content; +}; + +/// Pointer class for associativity +struct AssociativePtr : public PropertyPtr { + AssociativePtr(); + virtual std::ostream& print(std::ostream& os) const; + virtual bool equals(const PropertyPtr* p) const; +}; + +/// Pointer class for commutativity +struct CommutativePtr : public PropertyPtr { + CommutativePtr(); + CommutativePtr(const std::vector&); + const std::vector ordering_; + virtual std::ostream& print(std::ostream& os) const; + virtual bool equals(const PropertyPtr* p) const; +}; + +template +inline bool isa(const PropertyPtr* p) { + return p != nullptr && dynamic_cast(p) != nullptr; +} + +template +inline const P* to(const PropertyPtr* p) { + taco_iassert(isa

(p)) << + "Cannot convert " << typeid(p).name() << " to " << typeid(P).name();; + return static_cast(p); +} + +template +inline const typename P::Ptr* getPtr(const P& propertyPtr) { + taco_iassert(isa(propertyPtr.ptr)); + return static_cast(propertyPtr.ptr); +} + + +} + +#endif //TACO_PROPERTY_POINTERS_H diff --git a/src/index_notation/index_notation.cpp b/src/index_notation/index_notation.cpp index d7bef4b2b..dd2c7c4ea 100644 --- a/src/index_notation/index_notation.cpp +++ b/src/index_notation/index_notation.cpp @@ -2481,8 +2481,8 @@ struct Zero : public IndexNotationRewriterStrict { std::vector args; bool rewritten = false; - Annihilator annihilator = findProperty(op->properties, Annihilator()); - Literal annihilatorVal = annihilator.defined()? annihilator.getAnnihilator(): Literal(); + Annihilator annihilator = findProperty(op->properties); + Literal annihilatorVal = annihilator.defined()? annihilator.annihilator(): Literal(); // TODO: Check exhausted default against result default for(auto& arg : op->args) { @@ -2499,8 +2499,8 @@ struct Zero : public IndexNotationRewriterStrict { } } - Identity identity = findProperty(op->properties, Identity()); - Literal identityVal = identity.defined()? identity.getIdentity(): Literal(); + Identity identity = findProperty(op->properties); + Literal identityVal = identity.defined()? identity.identity(): Literal(); // If only one term is not the identity, replace expr with just that term size_t nonIdentityTerms = 0; diff --git a/src/index_notation/properties.cpp b/src/index_notation/properties.cpp index e7dbc7a7e..a305fbaa4 100644 --- a/src/index_notation/properties.cpp +++ b/src/index_notation/properties.cpp @@ -3,131 +3,116 @@ namespace taco { -struct Annihilator::Content { - Literal annihilator; -}; +// Property class definitions +Property::Property() : util::IntrusivePtr(nullptr) { +} -struct Identity::Content { - Literal identity; -}; +Property::Property(const PropertyPtr* p) : util::IntrusivePtr(p) { +} -// Property class definitions -Property::~Property() {} +bool Property::equals(const Property &p) const { + if(!defined() && !p.defined()) { + return true; + } + + if(defined() && p.defined()) { + return ptr->equals(p.ptr); + } -bool Property::defined() const { return false; } -bool Property::equals(const Property& p) const { - return defined() == p.defined(); +std::ostream & Property::print(std::ostream& os) const { + if(!defined()) { + os << "Property(undef)"; + return os; + } + return ptr->print(os); } -// Annihilator class definitions -Annihilator::Annihilator() {} - -Annihilator::Annihilator(Literal annihilator) : content(new Content) { - content->annihilator = annihilator; +std::ostream& operator<<(std::ostream& os, const Property& p) { + return p.print(os); } -const Literal& Annihilator::getAnnihilator() const { - taco_iassert(defined()); - return content->annihilator; +// Annihilator class definitions +template<> bool isa(const Property& p) { + return isa(p.ptr); } -bool Annihilator::defined() const { - return content.get() != nullptr; +template<> Annihilator to(const Property& p) { + taco_iassert(isa(p)); + return Annihilator(to(p.ptr)); } -bool Annihilator::equals(const Property& p) const { - if(!isa(p)) return false; +Annihilator::Annihilator(Literal annihilator) : Annihilator(new AnnihilatorPtr(annihilator)) { +} - Annihilator a = to(p); - if (!defined() && !a.defined()) return true; +Annihilator::Annihilator(const PropertyPtr* p) : Property(p) { +} - if(defined() && a.defined()) { - return ::taco::equals(getAnnihilator(), a.getAnnihilator()); - } - return false; +const Literal& Annihilator::annihilator() const { + taco_iassert(defined()); + return getPtr(*this)->annihilator(); } // Identity class definitions -Identity::Identity() {} - -Identity::Identity(Literal identity) : content(new Content) { - content->identity = identity; +template<> bool isa(const Property& p) { + return isa(p.ptr); } -const Literal& Identity::getIdentity() const { - taco_iassert(defined()); - return content->identity; +template<> Identity to(const Property& p) { + taco_iassert(isa(p)); + return Identity(to(p.ptr)); } -bool Identity::defined() const { - return content.get() != nullptr; +Identity::Identity(Literal identity) : Identity(new IdentityPtr(identity)) { } -bool Identity::equals(const Property& p) const { - if(!isa(p)) return false; - - Identity i = to(p); - if (!defined() && !i.defined()) return true; +Identity::Identity(const PropertyPtr* p) : Property(p) { +} - if(defined() && i.defined()) { - return ::taco::equals(getIdentity(), i.getIdentity()); - } - return false; +const Literal& Identity::identity() const { + taco_iassert(defined()); + return getPtr(*this)->identity(); } // Associative class definitions -Associative::Associative() : isDefined(true) {} +template<> bool isa(const Property& p) { + return isa(p.ptr); +} -Associative Associative::makeUndefined() { - Associative a = Associative(); - a.isDefined = false; - return a; +template<> Associative to(const Property& p) { + taco_iassert(isa(p)); + return Associative(to(p.ptr)); } -bool Associative::defined() const { - return isDefined; +Associative::Associative() : Associative(new AssociativePtr) { } -bool Associative::equals(const Property& p) const { - if(!isa(p)) return false; - Associative a = to(p); - return defined() == a.defined(); +Associative::Associative(const PropertyPtr* p) : Property(p) { } // Commutative class definitions -Commutative::Commutative() : isDefined(true) {} - -Commutative::Commutative(std::vector ordering) : ordering_(ordering), isDefined(true) { +template<> bool isa(const Property& p) { + return isa(p.ptr); } -Commutative Commutative::makeUndefined() { - Commutative com; - com.isDefined = false; - return com; +template<> Commutative to(const Property& p) { + taco_iassert(isa(p)); + return Commutative(to(p.ptr)); } -const std::vector & Commutative::ordering() const { - return ordering_; +Commutative::Commutative() : Commutative(new CommutativePtr) { } -bool Commutative::defined() const { - return isDefined; +Commutative::Commutative(const std::vector& ordering) : Commutative(new CommutativePtr(ordering)) { } -bool Commutative::equals(const Property& p) const { - if(!isa(p)) return false; - - Commutative c = to(p); - if (!defined() && !c.defined()) return true; - - if(defined() && c.defined()) { - return ordering() == c.ordering(); - } +Commutative::Commutative(const PropertyPtr* p) : Property(p) { +} - return false; +const std::vector & Commutative::ordering() const { + return getPtr(*this)->ordering_; } } \ No newline at end of file diff --git a/src/index_notation/property_pointers.cpp b/src/index_notation/property_pointers.cpp new file mode 100644 index 000000000..1086f1d2d --- /dev/null +++ b/src/index_notation/property_pointers.cpp @@ -0,0 +1,121 @@ +#include "taco/index_notation/property_pointers.h" +#include "taco/index_notation/index_notation.h" +#include "taco/util/strings.h" + +namespace taco { + +struct AnnihilatorPtr::Content { + Literal annihilator; +}; + +struct IdentityPtr::Content { + Literal identity; +}; + +// Property pointer definitions +PropertyPtr::PropertyPtr() { +} + +PropertyPtr::~PropertyPtr() { +} + +std::ostream& PropertyPtr::print(std::ostream& os) const { + os << "Property()"; + return os; +} + +bool PropertyPtr::equals(const PropertyPtr* p) const { + return this == p; +} + +// Annihilator pointer definitions +AnnihilatorPtr::AnnihilatorPtr() : PropertyPtr(), content(nullptr) { +} + +AnnihilatorPtr::AnnihilatorPtr(Literal annihilator) : PropertyPtr(), content(new Content) { + content->annihilator = annihilator; +} + +const Literal& AnnihilatorPtr::annihilator() const { + return content->annihilator; +} + +std::ostream& AnnihilatorPtr::print(std::ostream& os) const { + os << "Annihilator("; + if (annihilator().defined()) { + os << annihilator(); + } else { + os << "undef"; + } + os << ")"; + return os; +} + +bool AnnihilatorPtr::equals(const PropertyPtr* p) const { + if(!isa(p)) return false; + const AnnihilatorPtr* a = to(p); + return ::taco::equals(annihilator(), a->annihilator()); +} + +// Identity pointer definitions +IdentityPtr::IdentityPtr() : PropertyPtr(), content(nullptr) { +} + +IdentityPtr::IdentityPtr(Literal identity) : PropertyPtr(), content(new Content) { + content->identity = identity; +} + +const Literal& IdentityPtr::identity() const { + return content->identity; +} + +std::ostream& IdentityPtr::print(std::ostream& os) const { + os << "Identity("; + if (identity().defined()) { + os << identity(); + } else { + os << "undef"; + } + os << ")"; + return os; +} + +bool IdentityPtr::equals(const PropertyPtr* p) const { + if(!isa(p)) return false; + const IdentityPtr* idnty = to(p); + return ::taco::equals(identity(), idnty->identity()); +} + +// Associative pointer definitions +AssociativePtr::AssociativePtr() : PropertyPtr() { +} + +std::ostream& AssociativePtr::print(std::ostream& os) const { + os << "Associative()"; + return os; +} + +bool AssociativePtr::equals(const PropertyPtr* p) const { + return isa(p); +} + +// CommutativePtr definitions +CommutativePtr::CommutativePtr() : PropertyPtr() { +} + +CommutativePtr::CommutativePtr(const std::vector& ordering) : ordering_(ordering) { +} + +std::ostream& CommutativePtr::print(std::ostream& os) const { + os << "Commutative("; + os << "{" << util::join(ordering_) << "})"; + return os; +} + +bool CommutativePtr::equals(const PropertyPtr* p) const { + if(!isa(p)) return false; + const CommutativePtr* idnty = to(p); + return ordering_ == idnty->ordering_; +} + +} diff --git a/src/lower/merge_lattice.cpp b/src/lower/merge_lattice.cpp index 3445bb592..8deeff71a 100644 --- a/src/lower/merge_lattice.cpp +++ b/src/lower/merge_lattice.cpp @@ -93,6 +93,7 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA } void visit(const ComplementNode* node) { + taco_iassert(isa(node->a)) << "Demorgan's rule must be applied before lowering."; lattice = build(node->a); vector points = flipPoints(lattice.points()); diff --git a/test/test_properties.cpp b/test/test_properties.cpp index a1bb4cbc4..8e17af052 100644 --- a/test/test_properties.cpp +++ b/test/test_properties.cpp @@ -1 +1,108 @@ -#include "taco/index_notation/properties.h" \ No newline at end of file +#include "test.h" +#include "taco/index_notation/index_notation.h" +#include "taco/index_notation/properties.h" + +using namespace taco; + +TEST(properties, annihilator) { + Literal z(0); + Annihilator a(nullptr); + Annihilator zero(z); + ASSERT_FALSE(a.defined()); + ASSERT_TRUE(zero.defined()); + + ASSERT_EQ(zero.annihilator(), z); + ASSERT_TRUE(equals(zero.annihilator(), z)); + + ASSERT_TRUE(zero.equals(Annihilator(Literal(0)))); + ASSERT_FALSE(zero.equals(a)); +} + +TEST(properties, identity) { + Literal z(0); + Identity a(nullptr); + Identity zero(z); + ASSERT_FALSE(a.defined()); + ASSERT_TRUE(zero.defined()); + + ASSERT_EQ(zero.identity(), z); + ASSERT_TRUE(equals(zero.identity(), Literal(0))); + + ASSERT_TRUE(zero.equals(Identity(Literal(0)))); + ASSERT_FALSE(zero.equals(a)); +} + +TEST(properties, associative) { + Associative a; + Associative undef(nullptr); + + ASSERT_TRUE(a.equals(a)); + ASSERT_FALSE(a.equals(undef)); + ASSERT_TRUE(a.defined()); + ASSERT_FALSE(undef.defined()); +} + +TEST(properties, commutative) { + Commutative com; + Commutative specific({0, 1}); + Commutative specific2({1, 2}); + Commutative undef(nullptr); + + ASSERT_TRUE(specific.defined()); + ASSERT_TRUE(com.defined()); + ASSERT_FALSE(undef.defined()); + + ASSERT_NE(specific.ordering(), specific2.ordering()); + ASSERT_EQ(specific.ordering(), std::vector({0, 1})); + ASSERT_TRUE(specific.equals(specific)); +} + +TEST(properties, property_conversion) { + Property annh = Annihilator(10); + Property identity = Identity(40); + Property assoc = Associative(); + Property com = Commutative({0,1,2}); + + ASSERT_TRUE(isa(annh)); + ASSERT_FALSE(isa(annh)); + Annihilator a = to(annh); + ASSERT_TRUE(equals(a.annihilator(), Literal(10))); + + ASSERT_TRUE(isa(identity)); + ASSERT_FALSE(isa(identity)); + Identity idnty = to(identity); + ASSERT_TRUE(equals(idnty.identity(), Literal(40))); + + ASSERT_TRUE(isa(assoc)); + ASSERT_FALSE(isa(assoc)); + Associative assc = to(assoc); + ASSERT_TRUE(assc.defined()); + + ASSERT_TRUE(isa(com)); + ASSERT_FALSE(isa(com)); + Commutative comm = to(com); + ASSERT_EQ(comm.ordering(), std::vector({0,1,2})); +} + +TEST(properties, findProperty) { + Annihilator a(10); + Identity i(10); + Associative as; + Commutative c({0, 1}); + + std::vector properties({a, i, as, c}); + ASSERT_TRUE(a.equals(findProperty(properties))); + ASSERT_TRUE(i.equals(findProperty(properties))); + ASSERT_TRUE(as.equals(findProperty(properties))); + ASSERT_TRUE(c.equals(findProperty(properties))); + + std::vector partialProperties({a, c}); + ASSERT_FALSE(i.equals(findProperty(partialProperties))); + ASSERT_FALSE(as.equals(findProperty(partialProperties))); + + ASSERT_FALSE(findProperty(partialProperties).defined()); + ASSERT_FALSE(findProperty(partialProperties).defined()); + + ASSERT_TRUE(properties[0].equals(Annihilator(10))); + ASSERT_FALSE(properties[0].equals(i)); +} \ No newline at end of file From 6371d49b5cb459b3d20a871914a2583dd99f717d Mon Sep 17 00:00:00 2001 From: Rawn Date: Sat, 21 Mar 2020 23:17:05 -0400 Subject: [PATCH 13/27] Bug fixes. Added more tests for new lattice machinery. Fixed issues with lowering generic tensorOp. Changed spec of special definition to take an IR function. --- include/taco/index_notation/index_notation.h | 3 +- .../index_notation/index_notation_nodes.h | 27 +- .../taco/index_notation/iteration_algebra.h | 3 +- include/taco/index_notation/tensor_operator.h | 32 +- src/index_notation/index_notation.cpp | 45 ++- src/index_notation/index_notation_nodes.cpp | 19 +- .../index_notation_rewriter.cpp | 5 +- src/index_notation/iteration_algebra.cpp | 4 +- src/lower/lowerer_impl.cpp | 14 +- src/lower/merge_lattice.cpp | 63 ++-- test/op_factory.h | 128 +++++++ test/tests-lower.cpp | 46 ++- test/tests-merge_lattice.cpp | 321 ++++++++++++++++++ 13 files changed, 627 insertions(+), 83 deletions(-) create mode 100644 test/op_factory.h diff --git a/include/taco/index_notation/index_notation.h b/include/taco/index_notation/index_notation.h index f22613b59..ba3b1738a 100644 --- a/include/taco/index_notation/index_notation.h +++ b/include/taco/index_notation/index_notation.h @@ -414,7 +414,8 @@ class TensorOp: public IndexExpr { const IterationAlgebra& getAlgebra() const; const std::vector& getProperties() const; const std::string getName() const; - const std::map, std::function&)>> getDefs() const; + const std::map, std::function&)>> getDefs() const; + const std::vector& getDefinedArgs() const; typedef TensorOpNode Node; diff --git a/include/taco/index_notation/index_notation_nodes.h b/include/taco/index_notation/index_notation_nodes.h index adb540b2d..7e51364b0 100644 --- a/include/taco/index_notation/index_notation_nodes.h +++ b/include/taco/index_notation/index_notation_nodes.h @@ -3,6 +3,7 @@ #include #include +#include #include "taco/type.h" #include "taco/util/collections.h" @@ -177,12 +178,17 @@ struct CallIntrinsicNode : public IndexExprNode { struct TensorOpNode : public IndexExprNode { typedef std::function&)> opImpl; typedef std::function&)> algebraImpl; - typedef std::function&)> regionDefinition; TensorOpNode(std::string name, const std::vector& args, opImpl lowerFunc, const IterationAlgebra& iterAlg, const std::vector& properties, - const std::map, regionDefinition>& regionDefinitions); + const std::map, opImpl>& regionDefinitions, + const std::vector& definedRegions); + + TensorOpNode(std::string name, const std::vector& args, opImpl lowerFunc, + const IterationAlgebra& iterAlg, + const std::vector& properties, + const std::map, opImpl>& regionDefinitions); void accept(IndexExprVisitorStrict* v) const { v->visit(this); @@ -190,10 +196,13 @@ struct TensorOpNode : public IndexExprNode { std::string name; std::vector args; - opImpl lowerFunc; + opImpl defaultLowerFunc; IterationAlgebra iterAlg; std::vector properties; - std::map, regionDefinition> regionDefinitions; + std::map, opImpl> regionDefinitions; + + // Needed to track which inputs have been exhausted so the lowerer can know which lower func to use + std::vector definedRegions; private: static Datatype inferReturnType(opImpl f, const std::vector& inputs) { @@ -201,6 +210,16 @@ struct TensorOpNode : public IndexExprNode { std::vector exprs = util::map(inputs, getExprs); return f(exprs).type(); } + + static std::vector definedIndices(std::vector args) { + std::vector v; + for(int i = 0; i < (int) args.size(); ++i) { + if(args[i].defined()) { + v.push_back(i); + } + } + return v; + } }; struct ReductionNode : public IndexExprNode { diff --git a/include/taco/index_notation/iteration_algebra.h b/include/taco/index_notation/iteration_algebra.h index 2b2d3e508..a3b8fa34e 100644 --- a/include/taco/index_notation/iteration_algebra.h +++ b/include/taco/index_notation/iteration_algebra.h @@ -236,8 +236,7 @@ IterationAlgebra applyDemorgan(IterationAlgebra alg); /// Rewrites the algebra to replace the IndexExprs in the algebra with new index exprs as /// specified by the input map. If the map does not contain an indexExpr, it is kept the /// same as the input algebra. -IterationAlgebra replaceIndexExprs(IterationAlgebra alg, const std::map&); - +IterationAlgebra replaceAlgIndexExprs(IterationAlgebra alg, const std::map&); } diff --git a/include/taco/index_notation/tensor_operator.h b/include/taco/index_notation/tensor_operator.h index 4aa89d521..9b517cdaf 100644 --- a/include/taco/index_notation/tensor_operator.h +++ b/include/taco/index_notation/tensor_operator.h @@ -17,38 +17,50 @@ class Op { using opImpl = TensorOpNode::opImpl; using algebraImpl = TensorOpNode::algebraImpl; -using regionDefinition = TensorOpNode::regionDefinition; public: // Full construction - Op(opImpl lowererFunc, algebraImpl algebraFunc, std::vector properties = {}, - std::map, regionDefinition> specialDefinitions = {}) + Op(opImpl lowererFunc, algebraImpl algebraFunc, std::vector properties, + std::map, opImpl> specialDefinitions = {}) : name(util::uniqueName("Op")), lowererFunc(lowererFunc), algebraFunc(algebraFunc), properties(properties), regionDefinitions(specialDefinitions) { } - Op(std::string name, opImpl lowererFunc, algebraImpl algebraFunc, std::vector properties = {}, - std::map, regionDefinition> specialDefinitions = {}) + Op(std::string name, opImpl lowererFunc, algebraImpl algebraFunc, std::vector properties, + std::map, opImpl> specialDefinitions = {}) : name(name), lowererFunc(lowererFunc), algebraFunc(algebraFunc), properties(properties), regionDefinitions(specialDefinitions) { } // Construct without specifying algebra Op(std::string name, opImpl lowererFunc, std::vector properties, - std::map, regionDefinition> specialDefinitions = {}) + std::map, opImpl> specialDefinitions = {}) : Op(name, lowererFunc, nullptr, properties, specialDefinitions) { } Op(opImpl lowererFunc, std::vector properties, - std::map, regionDefinition> specialDefinitions = {}) + std::map, opImpl> specialDefinitions = {}) : Op(util::uniqueName("Op"), lowererFunc, nullptr, properties, specialDefinitions) { } + // Construct without properties + Op(std::string name, opImpl lowererFunc, algebraImpl algebraFunc, + std::map, opImpl> specialDefinitions = {}) + : Op(name, lowererFunc, algebraFunc, {}, specialDefinitions) { + } + + Op(opImpl lowererFunc, algebraImpl algebraFunc, + std::map, opImpl> specialDefinitions = {}) + : Op(util::uniqueName("Op"), lowererFunc, algebraFunc, {}, specialDefinitions) { + } + // Construct without algebra or properties - Op(std::string name, opImpl lowererFunc) : Op(name, lowererFunc, nullptr) { + Op(std::string name, opImpl lowererFunc, std::map, opImpl> specialDefinitions = {}) + : Op(name, lowererFunc, nullptr, specialDefinitions) { } - explicit Op(opImpl lowererFunc) : Op(lowererFunc, nullptr) { + explicit Op(opImpl lowererFunc, std::map, opImpl> specialDefinitions = {}) + : Op(lowererFunc, nullptr, specialDefinitions) { } template @@ -68,7 +80,7 @@ using regionDefinition = TensorOpNode::regionDefinition; opImpl lowererFunc; algebraImpl algebraFunc; std::vector properties; - std::map, regionDefinition> regionDefinitions; + std::map, opImpl> regionDefinitions; IterationAlgebra inferAlgFromProperties(const std::vector& exprs) { if(properties.empty()) { diff --git a/src/index_notation/index_notation.cpp b/src/index_notation/index_notation.cpp index 7d5bcf5b7..2fedbdc12 100644 --- a/src/index_notation/index_notation.cpp +++ b/src/index_notation/index_notation.cpp @@ -269,9 +269,15 @@ struct Equals : public IndexNotationVisitorStrict { } } + // Exhausted regions + if (anode->definedRegions != bnode->definedRegions) { + eq = false; + return; + } + // Lower function // TODO: For now just check that the function pointers are the same. - if(!util::targetPtrEqual(anode->lowerFunc, bnode->lowerFunc)) { + if(!util::targetPtrEqual(anode->defaultLowerFunc, bnode->defaultLowerFunc)) { eq = false; return; } @@ -462,9 +468,8 @@ struct Equals : public IndexNotationVisitorStrict { bArgs.push_back(bnode->args[idx]); } - IndexExpr aRes = itA->second(aArgs); - IndexExpr bRes = itB->second(bArgs); - if(!equals(aRes, bRes)) { + // TODO lower and check IR + if(!util::targetPtrEqual(itA->second, itB->second)) { return false; } } @@ -887,8 +892,8 @@ const std::vector& TensorOp::getArgs() const { return getNode(*this)->args; } -const std::function &)> TensorOp::getFunc() const { - return getNode(*this)->lowerFunc; +const TensorOpNode::opImpl TensorOp::getFunc() const { + return getNode(*this)->defaultLowerFunc; } const IterationAlgebra& TensorOp::getAlgebra() const { @@ -903,10 +908,13 @@ const std::string TensorOp::getName() const { return getNode(*this)->name; } -const std::map, TensorOpNode::regionDefinition> TensorOp::getDefs() const { +const std::map, TensorOpNode::opImpl> TensorOp::getDefs() const { return getNode(*this)->regionDefinitions; } +const std::vector& TensorOp::getDefinedArgs() const { + return getNode(*this)->definedRegions; +} template <> bool isa(IndexExpr e) { @@ -2523,15 +2531,26 @@ struct Zero : public IndexNotationRewriterStrict { void visit(const TensorOpNode* op) { std::vector args; + std::vector rewrittenArgs; + std::vector definedArgs; bool rewritten = false; Annihilator annihilator = findProperty(op->properties); Literal annihilatorVal = annihilator.defined()? annihilator.annihilator(): Literal(); // TODO: Check exhausted default against result default - for(auto& arg : op->args) { + for(int argIdx = 0; argIdx < (int) op->args.size(); ++argIdx) { + IndexExpr arg = op->args[argIdx]; IndexExpr rewrittenArg = rewrite(arg); - rewrittenArg = rewrittenArg.defined()? rewrittenArg: Literal::zero(arg.getDataType()); + rewrittenArgs.push_back(rewrittenArg); + + if (rewrittenArg.defined()) { + definedArgs.push_back(argIdx); + } else { + // TODO: fill value instead of 0 + rewrittenArg = Literal::zero(arg.getDataType()); + } + if(equals(annihilatorVal, rewrittenArg)) { expr = IndexExpr(); return; @@ -2563,10 +2582,10 @@ struct Zero : public IndexNotationRewriterStrict { } if (rewritten) { - const std::map subs = util::zipToMap(op->args, args); - IterationAlgebra newAlg = replaceIndexExprs(op->iterAlg, subs); - expr = new TensorOpNode(op->name, args, op->lowerFunc, newAlg, op->properties, - op->regionDefinitions); + const std::map subs = util::zipToMap(op->args, rewrittenArgs); + IterationAlgebra newAlg = replaceAlgIndexExprs(op->iterAlg, subs); + expr = new TensorOpNode(op->name, args, op->defaultLowerFunc, newAlg, op->properties, + op->regionDefinitions, definedArgs); } else { expr = op; diff --git a/src/index_notation/index_notation_nodes.cpp b/src/index_notation/index_notation_nodes.cpp index 5e2ef0ebe..a28e1f62d 100644 --- a/src/index_notation/index_notation_nodes.cpp +++ b/src/index_notation/index_notation_nodes.cpp @@ -30,13 +30,22 @@ CallIntrinsicNode::CallIntrinsicNode(const std::shared_ptr& func, } // class TensorOpNode -TensorOpNode::TensorOpNode(std::string name, const std::vector& args, opImpl lowerFunc, + TensorOpNode::TensorOpNode(std::string name, const std::vector& args, opImpl defaultLowerFunc, + const IterationAlgebra &iterAlg, const std::vector &properties, + const std::map, opImpl>& regionDefinitions) + : TensorOpNode(name, args, defaultLowerFunc, iterAlg, properties, regionDefinitions, definedIndices(args)){ + } + +// class TensorOpNode +TensorOpNode::TensorOpNode(std::string name, const std::vector& args, opImpl defaultLowerFunc, const IterationAlgebra &iterAlg, const std::vector &properties, - const std::map, regionDefinition>& regionDefinitions) - : IndexExprNode(inferReturnType(lowerFunc, args)), name(name), args(args), lowerFunc(lowerFunc), - iterAlg(applyDemorgan(iterAlg)), properties(properties), regionDefinitions(regionDefinitions) { + const std::map, opImpl>& regionDefinitions, + const std::vector& definedRegions) + : IndexExprNode(inferReturnType(defaultLowerFunc, args)), name(name), args(args), defaultLowerFunc(defaultLowerFunc), + iterAlg(applyDemorgan(iterAlg)), properties(properties), regionDefinitions(regionDefinitions), + definedRegions(definedRegions) { - taco_iassert(lowerFunc != nullptr); + taco_iassert(defaultLowerFunc != nullptr); for (const auto& pair: regionDefinitions) { taco_iassert(args.size() >= pair.first.size()); } diff --git a/src/index_notation/index_notation_rewriter.cpp b/src/index_notation/index_notation_rewriter.cpp index 48e81553a..5e206c4e0 100644 --- a/src/index_notation/index_notation_rewriter.cpp +++ b/src/index_notation/index_notation_rewriter.cpp @@ -117,10 +117,11 @@ void IndexNotationRewriter::visit(const TensorOpNode* op) { rewritten = true; } } + if (rewritten) { const std::map subs = util::zipToMap(op->args, args); - IterationAlgebra newAlg = replaceIndexExprs(op->iterAlg, subs); - expr = new TensorOpNode(op->name, args, op->lowerFunc, newAlg, op->properties, + IterationAlgebra newAlg = replaceAlgIndexExprs(op->iterAlg, subs); + expr = new TensorOpNode(op->name, args, op->defaultLowerFunc, newAlg, op->properties, op->regionDefinitions); } else { diff --git a/src/index_notation/iteration_algebra.cpp b/src/index_notation/iteration_algebra.cpp index f1cad202f..7f56bb7cc 100644 --- a/src/index_notation/iteration_algebra.cpp +++ b/src/index_notation/iteration_algebra.cpp @@ -6,7 +6,7 @@ namespace taco { // Iteration Algebra Definitions -IterationAlgebra::IterationAlgebra() : util::IntrusivePtr(nullptr) {} +IterationAlgebra::IterationAlgebra() : IterationAlgebra(new RegionNode(nullptr)) {} IterationAlgebra::IterationAlgebra(const IterationAlgebraNode* n) : util::IntrusivePtr(n) {} IterationAlgebra::IterationAlgebra(IndexExpr expr) : IterationAlgebra(new RegionNode(expr)) {} @@ -317,7 +317,7 @@ class IndexExprReplacer : public IterationAlgebraRewriter { const std::map substitutions; }; -IterationAlgebra replaceIndexExprs(IterationAlgebra alg, const std::map& substitutions) { +IterationAlgebra replaceAlgIndexExprs(IterationAlgebra alg, const std::map& substitutions) { return IndexExprReplacer(substitutions).rewrite(alg); } diff --git a/src/lower/lowerer_impl.cpp b/src/lower/lowerer_impl.cpp index efddc33d3..e422bdfc9 100644 --- a/src/lower/lowerer_impl.cpp +++ b/src/lower/lowerer_impl.cpp @@ -1494,11 +1494,23 @@ Expr LowererImpl::lowerCallIntrinsic(CallIntrinsic call) { return call.getFunc().lower(args); } + Expr LowererImpl::lowerTensorOp(TensorOp op) { + auto definedArgs = op.getDefinedArgs(); std::vector args; - for (auto& arg : op.getArgs()) { + + if(util::contains(op.getDefs(), definedArgs)) { + auto lowerFunc = op.getDefs().at(definedArgs); + for (auto& argIdx : definedArgs) { + args.push_back(lower(op.getArgs()[argIdx])); + } + return lowerFunc(args); + } + + for(const auto& arg : op.getArgs()) { args.push_back(lower(arg)); } + return op.getFunc()(args); } diff --git a/src/lower/merge_lattice.cpp b/src/lower/merge_lattice.cpp index 3b4daa078..30a23ad05 100644 --- a/src/lower/merge_lattice.cpp +++ b/src/lower/merge_lattice.cpp @@ -90,6 +90,12 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA } void visit(const RegionNode* node) { + if(!node->expr().defined()) { + // Region is empty so return empty lattice + lattice = MergeLattice({}); + return; + } + lattice = build(node->expr()); } @@ -98,13 +104,6 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA lattice = build(node->a); vector points = flipPoints(lattice.points()); - // TODO: Handle complementing with broadcasting - Can't distinguish dimension iterators inserted - // as optimizations to unordered tensors with dimension iterators inserted to broadcast. - // Could perhaps do the optimization at the end of lattice construction instead of after - // each union? - // In case 1, we want to complement the lattice but in case two we can return the empty - // lattice - // Otherwise, all tensors are sparse points = includeMissingProducerPoints(points); @@ -124,7 +123,6 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA points.push_back(MergePoint({dimIter}, {}, {})); } - points = removeUnnecessaryOmitterPoints(points); lattice = MergeLattice(points); } @@ -134,13 +132,9 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA if (a.points().size() > 0 && b.points().size() > 0) { lattice = intersectLattices(a, b); - } - // Scalar operands - else if (a.points().size() > 0) { - lattice = a; - } - else if (b.points().size() > 0) { - lattice = b; + } else { + // If any side of an intersection is empty, the entire intersection must be empty + lattice = MergeLattice({}); } } @@ -494,8 +488,6 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA } } - // - // Correctness: ensures that points produced on BOTH the left and the // right lattices are produced in the final intersection. // Needed since some subPoints may omit leading to erroneous @@ -505,16 +497,16 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA // Correctness: Deduplicate regions that are described by multiple lattice // points and resolves conflicts arising between omitters and // producers - points = removeDuplicatedTensorRegions(points, true); + points = removeDuplicatedTensorRegions(points, true); // Optimization: Removed a subLattice of points if the entire subLattice is // made of only omitters - points = removeUnnecessaryOmitterPoints(points); + // points = removeUnnecessaryOmitterPoints(points); // Optimization: remove lattice points whose iterators are identical to the // iterators of an earlier point, since we have already iterated // over this sub-space. - points = removePointsWithIdenticalIterators(points); + points = removeProducersWithIdenticalIterators(points); return MergeLattice(points); } @@ -540,11 +532,22 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA // Append the merge points of b util::append(points, right.points()); + struct pointSort { + bool operator()(const MergePoint& a, const MergePoint& b) { + size_t left_size = a.iterators().size() + a.locators().size(); + size_t right_size = b.iterators().size() + b.locators().size(); + return left_size > right_size; + } + } pointSorter; + + std::sort(points.begin(), points.end(), pointSorter); + // Correctness: This ensures that points omitted on BOTH the left and the // right lattices are omitted in the Union. Needed since some // subpoints may produce leading to erroneous producer regions points = correctPointTypesAfterUnion(left.points(), right.points(), points); + // Correctness: Deduplicate regions that are described by multiple lattice // points and resolves conflicts arising between omitters and // producers @@ -565,12 +568,12 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA // Optimization: Removes a subLattice of points if the entire subLattice is // made of only omitters - points = removeUnnecessaryOmitterPoints(points); + // points = removeUnnecessaryOmitterPoints(points); // Optimization: remove lattice points whose iterators are identical to the // iterators of an earlier point, since we have already iterated // over this sub-space. - points = removePointsWithIdenticalIterators(points); + points = removeProducersWithIdenticalIterators(points); return MergeLattice(points); } @@ -731,18 +734,26 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA } static vector - removePointsWithIdenticalIterators(vector points) + removeProducersWithIdenticalIterators(vector points) { + // Can't remove points if lattice contains omitters since we lose merge cases during lowering. + if(util::any(points, [](const MergePoint& point){return point.isOmitter();})) { + return points; + } + vector result; - set> iteratorSets; + set> producerIteratorSets; for (auto& point : points) { set iteratorSet(point.iterators().begin(), point.iterators().end()); - if (util::contains(iteratorSets, iteratorSet)) { + if (!point.isOmitter() && util::contains(producerIteratorSets, iteratorSet)) { continue; } result.push_back(point); - iteratorSets.insert(iteratorSet); + + if (!point.isOmitter()) { + producerIteratorSets.insert(iteratorSet); + } } return result; } diff --git a/test/op_factory.h b/test/op_factory.h new file mode 100644 index 000000000..76d9398a4 --- /dev/null +++ b/test/op_factory.h @@ -0,0 +1,128 @@ +#ifndef TACO_OP_FACTORY_H +#define TACO_OP_FACTORY_H + +#include "taco/index_notation/index_notation.h" +#include "taco/ir/ir.h" + + +namespace taco { + +// Algebras +struct BC_BD_CD { + IterationAlgebra operator()(const std::vector &v) { + IterationAlgebra r1 = Intersect(v[0], v[1]); + IterationAlgebra r2 = Intersect(v[0], v[2]); + IterationAlgebra r3 = Intersect(v[1], v[2]); + + IterationAlgebra omit = Complement(Intersect(Intersect(v[0], v[1]), v[2])); + return Intersect(Union(Union(r1, r2), r3), omit); + } +}; + +struct UnionDeMorgan { + IterationAlgebra operator()(const std::vector& regions) { + if(regions.empty()) { + return IterationAlgebra(); + } + + if (regions.size() == 1) { + return regions[0]; + } + + IterationAlgebra intersections = Complement(regions[0]); + for(size_t i = 1; i < regions.size(); ++i) { + intersections = Intersect(intersections, Complement(regions[i])); + } + return Complement(intersections); + } +}; + +struct ComplementUnion { + IterationAlgebra operator()(const std::vector& regions) { + taco_iassert(regions.size() >= 2); + IterationAlgebra unions = Complement(regions[0]); + for(size_t i = 1; i < regions.size(); ++i) { + unions = Union(unions, regions[i]); + } + return unions; + } +}; + +struct IntersectGen { + IterationAlgebra operator()(const std::vector& regions) { + if (regions.size() < 2) { + return IterationAlgebra(); + } + + IterationAlgebra intersections = regions[0]; + for(size_t i = 1; i < regions.size(); ++i) { + intersections = Intersect(intersections, regions[i]); + } + return intersections; + } +}; + +struct ComplementIntersect { + IterationAlgebra operator()(const std::vector& regions) { + if (regions.size() < 2) { + return IterationAlgebra(); + } + + IterationAlgebra intersections = Complement(regions[0]); + for(size_t i = 1; i < regions.size(); ++i) { + intersections = Intersect(intersections, regions[i]); + } + return intersections; + } +}; + + +struct IntersectGenDeMorgan { + IterationAlgebra operator()(const std::vector& regions) { + IterationAlgebra unions; + for(const auto& region : regions) { + unions = Union(unions, Complement(region)); + } + return Complement(unions); + } +}; + + +// Lowerers +struct MulAdd { + ir::Expr operator()(const std::vector &v) { + return ir::Add::make(ir::Mul::make(v[0], v[1]), v[2]); + } +}; + +struct GeneralAdd { + ir::Expr operator()(const std::vector &v) { + taco_iassert(v.size() >= 2) << "Add operator needs at least two operands"; + ir::Expr add = ir::Add::make(v[0], v[1]); + + for (size_t idx = 2; idx < v.size(); ++idx) { + add = ir::Add::make(add, v[idx]); + } + + return add; + } +}; + +// Special definitions +struct MulRegionDef { + ir::Expr operator()(const std::vector &v) { + taco_iassert(v.size() == 2) << "Add operator needs at least two operands"; + return ir::Mul::make(v[0], v[1]); + } +}; + +struct SubRegionDef { + ir::Expr operator()(const std::vector &v) { + taco_iassert(v.size() == 2) << "Sub def needs two operands"; + return ir::Sub::make(v[1], v[0]); + } +}; + + +} +#endif //TACO_OP_FACTORY_H diff --git a/test/tests-lower.cpp b/test/tests-lower.cpp index 25994cf45..7f06b5c89 100644 --- a/test/tests-lower.cpp +++ b/test/tests-lower.cpp @@ -17,6 +17,8 @@ #include "taco/format.h" #include "taco/util/strings.h" +#include "op_factory.h" + namespace taco { namespace test { @@ -1558,24 +1560,8 @@ TEST_STMT(vector_not, ) // Test tensorOps -struct lowerOp { - ir::Expr operator()(const std::vector& v) { - return ir::Add::make(ir::Mul::make(v[0], v[1]), v[2]); - } -}; - -struct algebraGen { - IterationAlgebra operator()(const std::vector& v) { - IterationAlgebra r1 = Intersect(v[0], v[1]); - IterationAlgebra r2 = Intersect(v[0], v[2]); - IterationAlgebra r3 = Intersect(v[1], v[2]); - - IterationAlgebra omit = Complement(Intersect(Intersect(v[0], v[1]), v[2])); - return Intersect(Union(Union(r1, r2), r3), omit); - } -}; -Op testOp("testOp", lowerOp(), algebraGen()); +Op testOp("testOp", MulAdd(), BC_BD_CD()); TEST_STMT(testOp1, forall(i, @@ -1594,4 +1580,30 @@ TEST_STMT(testOp1, } ) + +Op specialOp("specialOp", GeneralAdd(), BC_BD_CD(), {{{0,1}, MulRegionDef()}, {{0,2}, SubRegionDef()}}); + + +TEST_STMT(testSpecialOp, + forall(i, + forall(j, + A(i, j) = specialOp(B(i, j), C(i, j), D(i, j)) + )), + Values( + Formats({{A, Format({dense, sparse})}, {B, Format({dense, sparse})}, {C, Format({dense,sparse})}, + {D, Format({dense,sparse})}}), + Formats({{A, Format({sparse, sparse})}, {B, Format({sparse, sparse})}, {C, Format({sparse,sparse})}, + {D, Format({sparse,sparse})}}) + ), + { + TestCase( + {{B, {{{0, 1}, 2.0}, {{1, 1}, 3.0}, {{1, 2}, 2.0}, {{4, 3}, 4.0}}}, + {C, {{{0, 1}, 3.0}, {{2, 1}, 3.0}, {{2, 2}, 4.0}, {{4, 3}, 6.0}}}, + {D, {{{1, 2}, 1.0}, {{2, 1}, 4.0}, {{3, 3}, 5.0}, {{4, 3}, 5.0}}}}, + + {{A, {{{0, 1}, 6.0}, {{1, 2}, -1.0}, {{2, 1}, 7.0}}}}) + } +) + + }} diff --git a/test/tests-merge_lattice.cpp b/test/tests-merge_lattice.cpp index a05e8db73..6f3f4f325 100644 --- a/test/tests-merge_lattice.cpp +++ b/test/tests-merge_lattice.cpp @@ -9,6 +9,8 @@ #include "lower/mode_access.h" #include "taco/ir/ir.h" #include "taco/lower/mode_format_impl.h" +#include "taco/index_notation/tensor_operator.h" +#include "op_factory.h" using namespace std; @@ -275,6 +277,30 @@ INSTANTIATE_TEST_CASE_P(add, merge_lattice, {it(rd)}) }) ), + Test(forall(i, rd = s1 + (s2 + s3)), + MergeLattice({MergePoint({it(s1), it(s2), it(s3)}, + {}, + {it(rd)}), + MergePoint({it(s1), it(s2)}, + {}, + {it(rd)}), + MergePoint({it(s1), it(s3)}, + {}, + {it(rd)}), + MergePoint({it(s2), it(s3)}, + {}, + {it(rd)}), + MergePoint({it(s1)}, + {}, + {it(rd)}), + MergePoint({it(s2)}, + {}, + {it(rd)}), + MergePoint({it(s3)}, + {}, + {it(rd)}) + }) + ), Test(forall(i, rd = d1 + s2), MergeLattice({MergePoint({i, it(s2)}, {it(d1)}, @@ -589,6 +615,301 @@ INSTANTIATE_TEST_CASE_P(hashmap, merge_lattice, ) ); +Op intersectAdd("intersectAdd", GeneralAdd(), IntersectGen()); +Op intersectAddDeMorgan("intersectAddDeMorgan", GeneralAdd(), IntersectGenDeMorgan()); + +INSTANTIATE_TEST_CASE_P(deMorganIntersect, merge_lattice, + Values( + Test(forall(i, rd = intersectAdd(s1, s2)), + MergeLattice({MergePoint({it(s1), it(s2)}, + {}, + {it(rd)}) + }) + ), + Test(forall(i, rd = intersectAddDeMorgan(s1, s2)), + MergeLattice({MergePoint({it(s1), it(s2)}, + {}, + {it(rd)}) + }) + ), + Test(forall(i, rd = intersectAdd(d1, d2)), + MergeLattice({MergePoint({i}, + {it(d1), it(d2)}, + {it(rd)}) + }) + ), + Test(forall(i, rd = intersectAddDeMorgan(d1, d2)), + MergeLattice({MergePoint({i}, + {it(d1), it(d2)}, + {it(rd)}) + }) + ), + Test(forall(i, rd = intersectAddDeMorgan(h1, h2)), + MergeLattice({MergePoint({it(h1)}, + {it(h2)}, + {it(rd)}) + }) + ), + Test(forall(i, rd = intersectAddDeMorgan(d1, h1)), + MergeLattice({MergePoint({it(h1)}, + {it(d1)}, + {it(rd)}) + }) + ), + Test(forall(i, rd = intersectAddDeMorgan(d1, h1, s1)), + MergeLattice({MergePoint({it(s1)}, + {it(h1), it(d1)}, + {it(rd)}) + }) + ) + + ) +); + +Op complementIntersect("complementIntersect", GeneralAdd(), ComplementIntersect()); + +INSTANTIATE_TEST_CASE_P(complementIntersect, merge_lattice, + Values( + Test(forall(i, rd = complementIntersect(s1, s2)), + MergeLattice({MergePoint({it(s1), it(s2)}, + {}, + {it(rd)}, + true), + + MergePoint({it(s2)}, + {}, + {it(rd)}) + }) + ), + Test(forall(i, rd = complementIntersect(d1, d2)), + MergeLattice({MergePoint({i}, + {it(d1), it(d2)}, + {it(rd)}, + true), + MergePoint({i}, + {it(d2)}, + {it(rd)}) + }) + ), + Test(forall(i, rd = complementIntersect(s1, d1)), + MergeLattice({MergePoint({it(s1), i}, + {it(d1)}, + {it(rd)}, + true), + MergePoint({i}, + {it(d1)}, + {it(rd)}) + }) + ), + Test(forall(i, rd = complementIntersect(d1, s1)), + MergeLattice({MergePoint({it(s1)}, + {it(d1)}, + {it(rd)}, + true), + MergePoint({{it(s1)}, + {}, + {it(rd)}}) + }) + ), + Test(forall(i, rd = complementIntersect(h1, h2)), + MergeLattice({MergePoint({it(h2)}, + {it(h1)}, + {it(rd)}, + true), + MergePoint({{it(h2)}, + {}, + {it(rd)}}) + }) + ), + Test(forall(i, rd = complementIntersect(h1, s1)), + MergeLattice({MergePoint({it(s1)}, + {it(h1)}, + {it(rd)}, + true), + MergePoint({{it(s1)}, + {}, + {it(rd)}}) + }) + ), + Test(forall(i, rd = complementIntersect(s1, s2, s3)), + MergeLattice({MergePoint({it(s1), it(s2), it(s3)}, + {}, + {it(rd)}, + true), + MergePoint({it(s2), it(s3)}, + {}, + {it(rd)}) + }) + ), + Test(forall(i, rd = complementIntersect(d1, h1, s1)), + MergeLattice({MergePoint({it(s1)}, + {it(h1), it(d1)}, + {it(rd)}, + true), + MergePoint({it(s1)}, + {it(h1)}, + {it(rd)}) + }) + ), + Test(forall(i, rd = complementIntersect(h1, d1, s1)), + MergeLattice({MergePoint({it(s1)}, + {it(h1), it(d1)}, + {it(rd)}, + true), + MergePoint({it(s1)}, + {it(d1)}, + {it(rd)}) + }) + ), + Test(forall(i, rd = complementIntersect(d1, d2, d3)), + MergeLattice({MergePoint({i}, + {it(d1), it(d2), it(d3)}, + {it(rd)}, + true), + MergePoint({i}, + {it(d2), it(d3)}, + {it(rd)}) + }) + ) + + ) +); + + +Op complementUnion("complementUnion", GeneralAdd(), ComplementUnion()); +INSTANTIATE_TEST_CASE_P(complementUnion, merge_lattice, + Values( + Test(forall(i, rd = complementUnion(s1, s2)), + MergeLattice({MergePoint({it(s1), i, it(s2)}, + {}, + {it(rd)}), + MergePoint({i, it(s2)}, + {}, + {it(rd)}), + MergePoint({it(s1), i}, + {}, + {it(rd)}, + true), + MergePoint({i}, + {}, + {it(rd)}) + }) + ), + Test(forall(i, rd = complementUnion(d1, d2)), + MergeLattice({MergePoint({i}, + {it(d1), it(d2)}, + {it(rd)}), + MergePoint({i}, + {it(d2)}, + {it(rd)}), + MergePoint({i}, + {it(d1)}, + {it(rd)}, + true), + MergePoint({i}, + {}, + {it(rd)}) + }) + ), + Test(forall(i, rd = complementUnion(s1, d1)), + MergeLattice({MergePoint({it(s1), i}, + {it(d1)}, + {it(rd)}), + MergePoint({i}, + {it(d1)}, + {it(rd)}), + MergePoint({it(s1), i}, + {}, + {it(rd)}, + true), + MergePoint({i}, + {}, + {it(rd)}) + }) + ), + Test(forall(i, rd = complementUnion(d1, s1)), + MergeLattice({MergePoint({i, it(s1)}, + {it(d1)}, + {it(rd)}), + MergePoint({i, it(s1)}, + {}, + {it(rd)}), + MergePoint({i}, + {it(d1)}, + {it(rd)}, + true), + MergePoint({i}, + {}, + {it(rd)}) + }) + ), + Test(forall(i, rd = complementUnion(h1, h2)), + MergeLattice({MergePoint({i}, + {it(h1), it(h2)}, + {it(rd)}), + MergePoint({i}, + {it(h2)}, + {it(rd)}), + MergePoint({i}, + {it(h1)}, + {it(rd)}, + true), + MergePoint({i}, + {}, + {it(rd)}) + }) + ), + Test(forall(i, rd = complementUnion(h1, s1)), + MergeLattice({MergePoint({i, it(s1)}, + {it(h1)}, + {it(rd)}), + MergePoint({i, it(s1)}, + {}, + {it(rd)}), + MergePoint({i}, + {it(h1)}, + {it(rd)}, + true), + MergePoint({i}, + {}, + {it(rd)}) + }) + ), + Test(forall(i, rd = complementUnion(s1, s2, s3)), + MergeLattice({MergePoint({it(s1), i, it(s2), it(s3)}, + {}, + {it(rd)}), + MergePoint({i, it(s2), it(s3)}, + {}, + {it(rd)}), + MergePoint({it(s1), i, it(s3)}, + {}, + {it(rd)}), + MergePoint({it(s1), i, it(s2)}, + {}, + {it(rd)}), + MergePoint({i, it(s3)}, + {}, + {it(rd)}), + MergePoint({i, it(s2)}, + {}, + {it(rd)}), + MergePoint({it(s1), i}, + {}, + {it(rd)}, + true), + MergePoint({i}, + {}, + {it(rd)}) + }) + + ) + + ) +); + + + IndexVar i1, i2; TEST(merge_lattice, split) { From 895bed92c3b2e116ac18d23b2ce44f7823a4a676 Mon Sep 17 00:00:00 2001 From: Rawn Date: Mon, 23 Mar 2020 19:36:44 -0400 Subject: [PATCH 14/27] Added more tests for lattice construction. Moved code that applies lattice optimizations to the end of lattice construction. Conditions to apply these optimizations still not quite right so all tests pass for now. Need to also check that no producer regions have a special definition. Lowerer also needs to be altered to handle compute regions more generally instead of just sparse. Lowerer needs to be altered to handle explicit zeros. --- include/taco/lower/merge_lattice.h | 13 +++ src/lower/merge_lattice.cpp | 110 +++++++++----------- test/op_factory.h | 57 ++++++---- test/tests-lower.cpp | 8 +- test/tests-merge_lattice.cpp | 161 +++++++++++++++++++++++++++++ 5 files changed, 264 insertions(+), 85 deletions(-) diff --git a/include/taco/lower/merge_lattice.h b/include/taco/lower/merge_lattice.h index f821ffc68..a3d0cf657 100644 --- a/include/taco/lower/merge_lattice.h +++ b/include/taco/lower/merge_lattice.h @@ -49,6 +49,19 @@ class MergeLattice { static MergeLattice make(Forall forall, Iterators iterators, ProvenanceGraph provGraph, std::set definedIndexVars, std::map whereTempsToResult = {}); + + /** + * Removes lattice points whose iterators are identical to the iterators of an earlier point, since we have + * already iterated over this sub-space. + */ + static MergeLattice removeProducersWithIdenticalIterators(const MergeLattice&); + + /** + * remove lattice points that lack any of the full iterators of the first point, since when a full iterator exhausts + * we have iterated over the whole space. + */ + static MergeLattice removePointsThatLackFullIterators(const MergeLattice &l); + /** * Returns the sub-lattice rooted at the given merge point. */ diff --git a/src/lower/merge_lattice.cpp b/src/lower/merge_lattice.cpp index 30a23ad05..4d0ed499f 100644 --- a/src/lower/merge_lattice.cpp +++ b/src/lower/merge_lattice.cpp @@ -503,11 +503,6 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA // made of only omitters // points = removeUnnecessaryOmitterPoints(points); - // Optimization: remove lattice points whose iterators are identical to the - // iterators of an earlier point, since we have already iterated - // over this sub-space. - points = removeProducersWithIdenticalIterators(points); - return MergeLattice(points); } @@ -561,20 +556,9 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA // are subsets of some other iterator. points = moveLocateSubsetIteratorsToLocateSet(points); - // Optimization: remove lattice points that lack any of the full iterators - // of the first point, since when a full iterator exhausts we - // have iterated over the whole space. - points = removePointsThatLackFullIterators(points); - // Optimization: Removes a subLattice of points if the entire subLattice is // made of only omitters // points = removeUnnecessaryOmitterPoints(points); - - // Optimization: remove lattice points whose iterators are identical to the - // iterators of an earlier point, since we have already iterated - // over this sub-space. - points = removeProducersWithIdenticalIterators(points); - return MergeLattice(points); } @@ -712,52 +696,6 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA return result; } - static vector - removePointsThatLackFullIterators(const vector& points) - { - vector result; - vector fullIterators = filter(points[0].iterators(), - [](Iterator it){return it.isFull();}); - for (auto& point : points) { - bool missingFullIterator = false; - for (auto& fullIterator : fullIterators) { - if (!util::contains(point.iterators(), fullIterator)) { - missingFullIterator = true; - break; - } - } - if (!missingFullIterator) { - result.push_back(point); - } - } - return result; - } - - static vector - removeProducersWithIdenticalIterators(vector points) - { - // Can't remove points if lattice contains omitters since we lose merge cases during lowering. - if(util::any(points, [](const MergePoint& point){return point.isOmitter();})) { - return points; - } - - vector result; - set> producerIteratorSets; - for (auto& point : points) { - set iteratorSet(point.iterators().begin(), - point.iterators().end()); - if (!point.isOmitter() && util::contains(producerIteratorSets, iteratorSet)) { - continue; - } - result.push_back(point); - - if (!point.isOmitter()) { - producerIteratorSets.insert(iteratorSet); - } - } - return result; - } - static vector deduplicateDimensionIterators(const vector& iterators) { @@ -919,7 +857,12 @@ MergeLattice MergeLattice::make(Forall forall, Iterators iterators, ProvenanceGr } MergeLattice lattice = builder.build(forall.getStmt()); - return lattice; + // Can't remove points if lattice contains omitters since we lose merge cases during lowering. + if(util::any(lattice.points(), [](const MergePoint& point){return point.isOmitter();})) { + return lattice; + } + lattice = removePointsThatLackFullIterators(lattice); + return removeProducersWithIdenticalIterators(lattice); } MergeLattice MergeLattice::subLattice(MergePoint lp) const { @@ -1018,6 +961,47 @@ std::vector MergeLattice::retrieveIteratorsToOmit(const MergePoint &po return omittedIterators; } +MergeLattice +MergeLattice::removePointsThatLackFullIterators(const MergeLattice& l) +{ + vector result; + vector fullIterators = filter(l.points()[0].iterators(), + [](Iterator it){return it.isFull();}); + for (auto& point : l.points()) { + bool missingFullIterator = false; + for (auto& fullIterator : fullIterators) { + if (!util::contains(point.iterators(), fullIterator)) { + missingFullIterator = true; + break; + } + } + if (!missingFullIterator) { + result.push_back(point); + } + } + return MergeLattice(result); +} + +MergeLattice +MergeLattice::removeProducersWithIdenticalIterators(const MergeLattice& l) +{ + vector result; + set> producerIteratorSets; + for (auto& point : l.points()) { + set iteratorSet(point.iterators().begin(), + point.iterators().end()); + if (!point.isOmitter() && util::contains(producerIteratorSets, iteratorSet)) { + continue; + } + result.push_back(point); + + if (!point.isOmitter()) { + producerIteratorSets.insert(iteratorSet); + } + } + return MergeLattice(result); +} + ostream& operator<<(ostream& os, const MergeLattice& ml) { return os << util::join(ml.points(), ", "); } diff --git a/test/op_factory.h b/test/op_factory.h index 76d9398a4..2a46cc9e5 100644 --- a/test/op_factory.h +++ b/test/op_factory.h @@ -19,24 +19,6 @@ struct BC_BD_CD { } }; -struct UnionDeMorgan { - IterationAlgebra operator()(const std::vector& regions) { - if(regions.empty()) { - return IterationAlgebra(); - } - - if (regions.size() == 1) { - return regions[0]; - } - - IterationAlgebra intersections = Complement(regions[0]); - for(size_t i = 1; i < regions.size(); ++i) { - intersections = Intersect(intersections, Complement(regions[i])); - } - return Complement(intersections); - } -}; - struct ComplementUnion { IterationAlgebra operator()(const std::vector& regions) { taco_iassert(regions.size() >= 2); @@ -76,7 +58,6 @@ struct ComplementIntersect { } }; - struct IntersectGenDeMorgan { IterationAlgebra operator()(const std::vector& regions) { IterationAlgebra unions; @@ -87,6 +68,38 @@ struct IntersectGenDeMorgan { } }; +struct xorGen { + IterationAlgebra operator()(const std::vector& regions) { + IterationAlgebra noIntersect = Complement(Intersect(regions[0], regions[1])); + return Intersect(noIntersect, Union(regions[0], regions[1])); + } +}; + +struct fullSpaceGen { + IterationAlgebra operator()(const std::vector& regions) { + return Union(Complement(regions[0]), regions[0]); + } +}; + +struct emptyGen { + IterationAlgebra operator()(const std::vector& regions) { + return Intersect(Complement(regions[0]), regions[0]); + } +}; + +struct intersectEdge { + IterationAlgebra operator()(const std::vector& regions) { + std::vector r = regions; + return Intersect(Complement(Intersect(r[0], r[1])), Intersect(r[0], r[1])); + } +}; + +struct unionEdge { + IterationAlgebra operator()(const std::vector& regions) { + std::vector r = regions; + return Union(Complement(Intersect(r[0], r[1])), Intersect(r[0], r[1])); + } +}; // Lowerers struct MulAdd { @@ -95,6 +108,12 @@ struct MulAdd { } }; +struct identityFunc { + ir::Expr operator()(const std::vector &v) { + return v[0]; + } +}; + struct GeneralAdd { ir::Expr operator()(const std::vector &v) { taco_iassert(v.size() >= 2) << "Add operator needs at least two operands"; diff --git a/test/tests-lower.cpp b/test/tests-lower.cpp index 7f06b5c89..b9e16bfda 100644 --- a/test/tests-lower.cpp +++ b/test/tests-lower.cpp @@ -1590,9 +1590,11 @@ TEST_STMT(testSpecialOp, A(i, j) = specialOp(B(i, j), C(i, j), D(i, j)) )), Values( - Formats({{A, Format({dense, sparse})}, {B, Format({dense, sparse})}, {C, Format({dense,sparse})}, - {D, Format({dense,sparse})}}), - Formats({{A, Format({sparse, sparse})}, {B, Format({sparse, sparse})}, {C, Format({sparse,sparse})}, + Formats({{A, Format({dense,dense})}, {B, Format({dense,dense})}, {C, Format({dense,dense})}, + {D, Format({dense,dense})}}), +// Formats({{A, Format({dense,sparse})}, {B, Format({dense,sparse})}, {C, Format({dense,sparse})}, +// {D, Format({dense,sparse})}}), + Formats({{A, Format({sparse,sparse})}, {B, Format({sparse,sparse})}, {C, Format({sparse,sparse})}, {D, Format({sparse,sparse})}}) ), { diff --git a/test/tests-merge_lattice.cpp b/test/tests-merge_lattice.cpp index 6f3f4f325..e43898d10 100644 --- a/test/tests-merge_lattice.cpp +++ b/test/tests-merge_lattice.cpp @@ -908,6 +908,167 @@ INSTANTIATE_TEST_CASE_P(complementUnion, merge_lattice, ) ); +Op xorOp("xor", GeneralAdd(), xorGen()); +INSTANTIATE_TEST_CASE_P(xorLattice, merge_lattice, + Values(Test(forall(i, rd = xorOp(s1, s2)), + MergeLattice({MergePoint({it(s1), it(s2)}, + {}, + {it(rd)}, + true), + MergePoint({it(s1)}, + {}, + {it(rd)}), + MergePoint({it(s2)}, + {}, + {it(rd)}) + }) + ), + Test(forall(i, rd = xorOp(d1, d2)), + MergeLattice({MergePoint({i}, + {it(d1), it(d2)}, + {it(rd)}, + true), + MergePoint({i}, + {it(d1)}, + {it(rd)}), + MergePoint({i}, + {it(d2)}, + {it(rd)}) + }) + ), + Test(forall(i, rd = xorOp(h1, h2)), + MergeLattice({MergePoint({i}, + {it(h1), it(h2)}, + {it(rd)}, + true), + MergePoint({i}, + {it(h1)}, + {it(rd)}), + MergePoint({i}, + {it(h2)}, + {it(rd)}) + }) + ), + Test(forall(i, rd = xorOp(d1, s1)), + MergeLattice({MergePoint({i, it(s1)}, + {it(d1)}, + {it(rd)}, + true), + MergePoint({i}, + {it(d1)}, + {it(rd)}), + MergePoint({i, it(s1)}, + {}, + {it(rd)}) + }) + ), + Test(forall(i, rd = xorOp(h1, s1)), + MergeLattice({MergePoint({i, it(s1)}, + {it(h1)}, + {it(rd)}, + true), + MergePoint({i}, + {it(h1)}, + {it(rd)}), + MergePoint({i, it(s1)}, + {}, + {it(rd)}) + }) + ), + Test(forall(i, rd = xorOp(h1, d1)), + MergeLattice({MergePoint({i}, + {it(d1), it(h1)}, + {it(rd)}, + true), + MergePoint({i}, + {it(h1)}, + {it(rd)}), + MergePoint({i}, + {it(d1)}, + {it(rd)}) + }) + ) + ) +); + +Op identity("identity", identityFunc(), fullSpaceGen()); +INSTANTIATE_TEST_CASE_P(singleCompUnion, merge_lattice, + Values(Test(forall(i, rd = identity(s1)), + MergeLattice({MergePoint({it(s1), i}, + {}, + {it(rd)}), + MergePoint({i}, + {}, + {it(rd)}) + }) + ), + Test(forall(i, rd = identity(d1)), + MergeLattice({MergePoint({i}, + {it(d1)}, + {it(rd)}) + }) + ), + Test(forall(i, rd = identity(h1)), + MergeLattice({MergePoint({i}, + {it(h1)}, + {it(rd)}) + }) + ) + ) +); + +Op emptyIdentity("emptyIdentity", identityFunc(), emptyGen()); +Op intersectEdgeCase("intersectEdgeCase", GeneralAdd(), intersectEdge()); +Op unionEdgeCase("unionEdgeCase", GeneralAdd(), unionEdge()); +INSTANTIATE_TEST_CASE_P(edgeCases, merge_lattice, + Values(Test(forall(i, rd = emptyIdentity(s1)), + MergeLattice({MergePoint({it(s1)}, + {}, + {it(rd)}, + true) + }) + ), + Test(forall(i, rd = emptyIdentity(d1)), + MergeLattice({MergePoint({i}, + {it(d1)}, + {it(rd)}, + true) + }) + ), + Test(forall(i, rd = emptyIdentity(h1)), + MergeLattice({MergePoint({it(h1)}, + {it(h1)}, + {it(rd)}, + true) + }) + ), + Test(forall(i, rd = intersectEdgeCase(s1, s2)), + MergeLattice({MergePoint({it(s1), it(s2)}, + {}, + {it(rd)}, + true) + }) + ), + Test(forall(i, rd = unionEdgeCase(s1, s2)), + MergeLattice({MergePoint({it(s1), i, it(s2)}, + {}, + {it(rd)}), + MergePoint({it(s1), i}, + {}, + {it(rd)}), + MergePoint({i, it(s2)}, + {}, + {it(rd)}), + MergePoint({i}, + {}, + {it(rd)}) + }) + ) + + ) +); + + IndexVar i1, i2; From 09b504017ad2d7ffb976df020c7807205d991ad2 Mon Sep 17 00:00:00 2001 From: Rawn Date: Thu, 9 Apr 2020 17:22:29 -0400 Subject: [PATCH 15/27] Added fill value to tensor. Differentiated between case and loop lattice in lowerer. Added explicit zero checks in bottom loop of compute --- include/taco/index_notation/index_notation.h | 22 +- include/taco/ir/ir.h | 1 + include/taco/lower/iterator.h | 1 - include/taco/lower/lowerer_impl.h | 53 +++- include/taco/lower/merge_lattice.h | 29 +- include/taco/storage/pack.h | 11 +- include/taco/storage/storage.h | 10 +- include/taco/taco_tensor_t.h | 5 +- include/taco/tensor.h | 52 +++- src/codegen/codegen.cpp | 6 + src/codegen/codegen_c.cpp | 1 + src/codegen/codegen_cuda.cpp | 1 + src/index_notation/index_notation.cpp | 44 ++- src/ir/ir.cpp | 3 + src/ir/ir_printer.cpp | 7 +- src/lower/lowerer_impl.cpp | 297 +++++++++++++++---- src/lower/merge_lattice.cpp | 195 ++++++++---- src/storage/pack.cpp | 19 +- src/storage/storage.cpp | 33 ++- src/tensor.cpp | 50 +++- test/tests-api.cpp | 2 + test/tests-lower.cpp | 73 ++++- test/tests-tensor.cpp | 25 ++ 23 files changed, 747 insertions(+), 193 deletions(-) diff --git a/include/taco/index_notation/index_notation.h b/include/taco/index_notation/index_notation.h index ba3b1738a..8c62c4b2f 100644 --- a/include/taco/index_notation/index_notation.h +++ b/include/taco/index_notation/index_notation.h @@ -278,11 +278,14 @@ class Literal : public IndexExpr { Literal(std::complex); Literal(std::complex); - static IndexExpr zero(Datatype); + static Literal zero(Datatype); /// Returns the literal value. template T getVal() const; + /// Returns an untyped pointer to the literal value + void* getValPtr(); + typedef LiteralNode Node; }; @@ -869,10 +872,10 @@ SuchThat suchthat(IndexStmt stmt, std::vector predicate); class TensorVar : public util::Comparable { public: TensorVar(); - TensorVar(const Type& type); - TensorVar(const std::string& name, const Type& type); - TensorVar(const Type& type, const Format& format); - TensorVar(const std::string& name, const Type& type, const Format& format); + TensorVar(const Type& type, const Literal& fill = Literal()); + TensorVar(const std::string& name, const Type& type, const Literal& fill = Literal()); + TensorVar(const Type& type, const Format& format, const Literal& fill = Literal()); + TensorVar(const std::string& name, const Type& type, const Format& format, const Literal& fill = Literal()); /// Returns the name of the tensor variable. std::string getName() const; @@ -890,6 +893,12 @@ class TensorVar : public util::Comparable { /// and execute it's expression. const Schedule& getSchedule() const; + /// Gets the fill value of the tensor variable. May be left undefined. + const Literal& getFill() const; + + /// Set the fill value of the tensor variable + void setFill(const Literal& fill); + /// Set the name of the tensor variable. void setName(std::string name); @@ -1004,5 +1013,8 @@ IndexExpr zero(IndexExpr, const std::set& zeroed); /// zero and then propagating and removing zeroes. IndexStmt zero(IndexStmt, const std::set& zeroed); +/// Returns true if there are no forall nodes in the indexStmt. Used to check +/// if the last loop is being lowered. +bool hasNoForAlls(IndexStmt); } #endif diff --git a/include/taco/ir/ir.h b/include/taco/ir/ir.h index 15dbdc7aa..12b6937cc 100644 --- a/include/taco/ir/ir.h +++ b/include/taco/ir/ir.h @@ -76,6 +76,7 @@ enum class TensorProperty { ModeTypes, Indices, Values, + FillValue, ValuesSize }; diff --git a/include/taco/lower/iterator.h b/include/taco/lower/iterator.h index 1d871ffaa..93a9447a0 100644 --- a/include/taco/lower/iterator.h +++ b/include/taco/lower/iterator.h @@ -56,7 +56,6 @@ class Iterator : public util::Comparable { /// Get the child of this iterator in its iterator list. const Iterator getChild() const; - /// Returns true if the iterator iterates over the dimension. bool isDimensionIterator() const; diff --git a/include/taco/lower/lowerer_impl.h b/include/taco/lower/lowerer_impl.h index f9ac237e6..47768f4cf 100644 --- a/include/taco/lower/lowerer_impl.h +++ b/include/taco/lower/lowerer_impl.h @@ -77,6 +77,7 @@ class LowererImpl : public util::Uncopyable { std::vector locaters, std::vector inserters, std::vector appenders, + MergeLattice caseLattice, std::set reducedAccesses, ir::Stmt recoveryStmt); @@ -86,6 +87,7 @@ class LowererImpl : public util::Uncopyable { std::vector locaters, std::vector inserters, std::vector appenders, + MergeLattice caseLattice, std::set reducedAccesses, ir::Stmt recoveryStmt); @@ -96,6 +98,7 @@ class LowererImpl : public util::Uncopyable { std::vector locaters, std::vector inserters, std::vector appenders, + MergeLattice caseLattice, std::set reducedAccesses, ir::Stmt recoveryStmt); @@ -103,6 +106,7 @@ class LowererImpl : public util::Uncopyable { std::vector locaters, std::vector inserters, std::vector appenders, + MergeLattice caseLattice, std::set reducedAccesses, ir::Stmt recoveryStmt); @@ -166,6 +170,7 @@ class LowererImpl : public util::Uncopyable { std::vector locaters, std::vector inserters, std::vector appenders, + MergeLattice caseLattice, const std::set& reducedAccesses); @@ -311,7 +316,7 @@ class LowererImpl : public util::Uncopyable { * Generate code to zero-initialize values array in range * [begin * size, (begin + 1) * size). */ - ir::Stmt zeroInitValues(ir::Expr tensor, ir::Expr begin, ir::Expr size); + ir::Stmt initValues(ir::Expr tensor, ir::Expr initVal, ir::Expr begin, ir::Expr size); /// Declare position variables and initialize them with a locate. ir::Stmt declLocatePosVars(std::vector iterators); @@ -359,6 +364,52 @@ class LowererImpl : public util::Uncopyable { /// Expression that evaluates to true if none of the iteratators are exhausted ir::Expr checkThatNoneAreExhausted(std::vector iterators); + /// Lowers a merge lattice to cases assuming there are no more loops to be emitted in stmt. + /// Will emit checks for explicit zeros for each mode iterator and each locator in the lattice. + ir::Stmt lowerMergeCasesWithExplicitZeroChecks(ir::Expr coordinate, IndexVar coordinateVar, IndexStmt stmt, + MergeLattice lattice, const std::set& reducedAccesses); + + /// Constructs cases comparing the coordVar for each iterator to the resolved coordinate. + /// Returns a vector where coordComparisons[i] corresponds to a case for iters[i] + /// If no case can be formed for a given iterator, an undefined expr is appended where a case would normally be. + template + std::vector compareToResolvedCoordinate(const std::vector& iters, ir::Expr resolvedCoordinate, + IndexVar coordinateVar) { + std::vector coordComparisons; + + for (Iterator iterator : iters) { + if (!(provGraph.isCoordVariable(iterator.getIndexVar()) && + provGraph.isDerivedFrom(iterator.getIndexVar(), coordinateVar))) { + coordComparisons.push_back(C::make(iterator.getCoordVar(), resolvedCoordinate)); + } else { + coordComparisons.push_back(ir::Expr()); + } + } + + return coordComparisons; + } + + /// Makes the preamble of booleans used in case checks for the inner most loop of the computations + /// The iterator to condition map contains the name of the boolean indicating if each corresponding mode iterator + /// and each locator is non-zero. This function populates this map so the caller can user the boolean names to emit + /// checks for each lattice point. + std::vector constructInnerLoopCasePreamble(ir::Expr coordinate, IndexVar coordinateVar, + MergeLattice lattice, + std::map& iteratorToConditionMap); + + /// Lowers merge cases in the lattice using a map to know what expr to emit for each iterator in the lattice. + /// The map must be of iterators to exprs of boolean types + std::vector lowerCasesFromMap(std::map iteratorToCondition, + ir::Expr coordinate, IndexStmt stmt, const MergeLattice& lattice, + const std::set& reducedAccesses); + + /// Constructs an expression which checks if this access is "zero" + ir::Expr constructCheckForAccessZero(Access); + + /// Filters out a list of iterators and returns those the lowerer should explicitly check for zeros. + /// For now, we only check mode iterators. + std::vector getModeIterators(const std::vector&); + private: bool assemble; bool compute; diff --git a/include/taco/lower/merge_lattice.h b/include/taco/lower/merge_lattice.h index a3d0cf657..731152ff5 100644 --- a/include/taco/lower/merge_lattice.h +++ b/include/taco/lower/merge_lattice.h @@ -54,13 +54,16 @@ class MergeLattice { * Removes lattice points whose iterators are identical to the iterators of an earlier point, since we have * already iterated over this sub-space. */ - static MergeLattice removeProducersWithIdenticalIterators(const MergeLattice&); + static std::vector removePointsWithIdenticalIterators(const std::vector&); /** * remove lattice points that lack any of the full iterators of the first point, since when a full iterator exhausts * we have iterated over the whole space. */ - static MergeLattice removePointsThatLackFullIterators(const MergeLattice &l); + static std::vector removePointsThatLackFullIterators(const std::vector&); + + /// Returns true if we need to emit checks for explicit zeros in the lattice given. + static bool needExplicitZeroChecks(const MergeLattice& lattice); /** * Returns the sub-lattice rooted at the given merge point. @@ -77,6 +80,11 @@ class MergeLattice { */ const std::vector& iterators() const; + /** + * Retrieve all the locators in this lattice. + */ + const std::vector& locators() const; + /** * Returns iterators that have been exhausted prior to the merge point. */ @@ -96,11 +104,24 @@ class MergeLattice { /** * Get a list of iterators that should be omitted at this merge point. */ - std::vector retrieveIteratorsToOmit(const MergePoint& point) const; + std::vector retrieveRegionIteratorsToOmit(const MergePoint& point) const; + /** + * Returns a set of sets of tensor iterators. A merge point with a tensorRegion in this set should not + * be removed from the lattice. + * + * Needed so that special regions are kept when applying optimizations that remove merge points. + */ + std::set> getTensorRegionsToKeep() const; + + /** + * Removes points from the lattice that would duplicate iteration over the input tensors. + */ + MergeLattice getLoopLattice() const; private: std::vector points_; + std::set> regionsToKeep; public: /** @@ -108,7 +129,7 @@ class MergeLattice { * is primarily intended for testing purposes and most construction should * happen through `MergeLattice::make`. */ - MergeLattice(std::vector points); + MergeLattice(std::vector points, std::set> regionsToKeep = {}); }; std::ostream& operator<<(std::ostream&, const MergeLattice&); diff --git a/include/taco/storage/pack.h b/include/taco/storage/pack.h index 437a441ee..e39536c56 100644 --- a/include/taco/storage/pack.h +++ b/include/taco/storage/pack.h @@ -17,6 +17,8 @@ namespace taco { +class Literal; + namespace ir { class Stmt; } @@ -25,12 +27,13 @@ TensorStorage pack(Datatype datatype, const std::vector& dimensions, const Format& format, const std::vector& coordinates, - const void* values); - + const void* values, + const Literal& fill); template TensorStorage pack(std::vector dimensions, Format format, - const std::vector,V>>& components){ + const std::vector,V>>& components, + const Literal& fill){ size_t order = dimensions.size(); size_t nnz = components.size(); @@ -45,7 +48,7 @@ TensorStorage pack(std::vector dimensions, Format format, } } - return pack(type(), dimensions, format, coordinates, values.data()); + return pack(type(), dimensions, format, coordinates, values.data(), fill); } } diff --git a/include/taco/storage/storage.h b/include/taco/storage/storage.h index 81692e061..50bb92124 100644 --- a/include/taco/storage/storage.h +++ b/include/taco/storage/storage.h @@ -12,6 +12,7 @@ class Type; class Datatype; class Index; class Array; +class Literal; /// Storage for a tensor object. Tensor storage consists of a value array that /// contains the tensor values and one index per mode. The type of each @@ -22,7 +23,7 @@ class TensorStorage { /// Construct tensor storage for the given format. TensorStorage(Datatype componentType, const std::vector& dimensions, - Format format); + Format format, Literal fillVal); /// Returns the tensor storage format. const Format& getFormat() const; @@ -45,6 +46,10 @@ class TensorStorage { /// Returns the value array that contains the tensor components. const Array& getValues() const; + /// Returns the fill array containing the tensor fill value. This is always + /// of size one. + const Array& getFill() const; + /// Returns the tensor component value array. Array getValues(); @@ -60,6 +65,9 @@ class TensorStorage { /// Set the tensor component value array. void setValues(const Array& values); + /// Set the fill array. This should always be size 1 + void setFill(const Array& fill); + private: struct Content; std::shared_ptr content; diff --git a/include/taco/taco_tensor_t.h b/include/taco/taco_tensor_t.h index e39ac9e74..479bb029c 100644 --- a/include/taco/taco_tensor_t.h +++ b/include/taco/taco_tensor_t.h @@ -18,12 +18,13 @@ typedef struct taco_tensor_t { taco_mode_t* mode_types; // mode storage types uint8_t*** indices; // tensor index data (per mode) uint8_t* vals; // tensor values + uint8_t* fill_value; // tensor fill value int32_t vals_size; // values array size } taco_tensor_t; taco_tensor_t *init_taco_tensor_t(int32_t order, int32_t csize, - int32_t* dimensions, int32_t* modeOrdering, - taco_mode_t* mode_types); + int32_t* dimensions, int32_t* modeOrdering, + taco_mode_t* mode_types); void deinit_taco_tensor_t(taco_tensor_t* t); diff --git a/include/taco/tensor.h b/include/taco/tensor.h index 05bc1773b..4a69ce6cd 100644 --- a/include/taco/tensor.h +++ b/include/taco/tensor.h @@ -53,19 +53,27 @@ class TensorBase { /// Create a tensor with the given dimensions. The format defaults to sparse /// in every mode. TensorBase(Datatype ctype, std::vector dimensions, - ModeFormat modeType = ModeFormat::compressed); + ModeFormat modeType = ModeFormat::compressed, Literal fill = Literal()); /// Create a tensor with the given dimensions and format. - TensorBase(Datatype ctype, std::vector dimensions, Format format); + TensorBase(Datatype ctype, std::vector dimensions, Format format, Literal fill = Literal()); /// Create a tensor with the given data type, dimensions and format. The /// format defaults to sparse in every mode. TensorBase(std::string name, Datatype ctype, std::vector dimensions, - ModeFormat modeType = ModeFormat::compressed); - + ModeFormat modeType = ModeFormat::compressed, Literal fill = Literal()); + + /// Create a tensor with the given dimensions and fill value. The format + /// defaults to sparse in every mode. + TensorBase(Datatype ctype, std::vector dimensions, Literal fill); + + /// Create a tensor with the given data type, dimensions and fill value. The + /// format defaults to sparse in every mode. + TensorBase(std::string name, Datatype ctype, std::vector dimensions, Literal fill); + /// Create a tensor with the given data type, dimensions and format. TensorBase(std::string name, Datatype ctype, std::vector dimensions, - Format format); + Format format, Literal fill = Literal()); /// Set the name of the tensor. void setName(std::string name) const; @@ -440,6 +448,9 @@ class TensorBase { /// Get the taco_tensor_t representation of this tensor. taco_tensor_t* getTacoTensorT(); + /// Get the fill value of this tensor. + Literal getFillValue() const; + /// True iff two tensors have the same type and the same values. friend bool equals(const TensorBase&, const TensorBase&); @@ -486,7 +497,7 @@ template class Tensor : public TensorBase { public: /// Create a scalar - Tensor() : TensorBase() {} + Tensor() : TensorBase(type()) {} /// Create a scalar with the given name explicit Tensor(std::string name) : TensorBase(name, type()) {} @@ -496,22 +507,35 @@ class Tensor : public TensorBase { /// Create a tensor with the given dimensions. The format defaults to sparse /// in every mode. - Tensor(std::vector dimensions, ModeFormat modeType = ModeFormat::compressed) - : TensorBase(type(), dimensions) {} + Tensor(std::vector dimensions, ModeFormat modeType = ModeFormat::compressed, + CType fill = CType()) + : TensorBase(type(), dimensions, fill) {} /// Create a tensor with the given dimensions and format - Tensor(std::vector dimensions, Format format) - : TensorBase(type(), dimensions, format) {} + Tensor(std::vector dimensions, Format format, CType fill = CType()) + : TensorBase(type(), dimensions, format, fill) {} /// Create a tensor with the given name, dimensions and format. The format /// defaults to sparse in every mode. Tensor(std::string name, std::vector dimensions, - ModeFormat modeType = ModeFormat::compressed) - : TensorBase(name, type(), dimensions, modeType) {} + ModeFormat modeType = ModeFormat::compressed, + CType fill = CType()) + : TensorBase(name, type(), dimensions, modeType, fill) {} + + /// Create a tensor with the given dimensions and fill value. The format + /// defaults to sparse in every mode. + Tensor(std::vector dimensions, CType fill) + : TensorBase(type(), dimensions, fill) {} + + /// Create a tensor with the given data type, dimensions and fill value. The + /// format defaults to sparse in every mode. + Tensor(std::string name, std::vector dimensions, CType fill) + : TensorBase(name, type(), dimensions, fill) {} /// Create a tensor with the given name, dimensions and format - Tensor(std::string name, std::vector dimensions, Format format) - : TensorBase(name, type(), dimensions, format) {} + Tensor(std::string name, std::vector dimensions, Format format, + CType fill = CType()) + : TensorBase(name, type(), dimensions, format, fill) {} /// Create a tensor from a TensorBase instance. The Tensor and TensorBase /// objects will reference the same underlying tensor so it is a shallow copy. diff --git a/src/codegen/codegen.cpp b/src/codegen/codegen.cpp index 126fe1712..9930f4969 100644 --- a/src/codegen/codegen.cpp +++ b/src/codegen/codegen.cpp @@ -209,6 +209,10 @@ string CodeGen::unpackTensorProperty(string varname, const GetProperty* op, } else if (op->property == TensorProperty::ValuesSize) { ret << "int " << varname << " = " << tensor->name << "->vals_size;\n"; return ret.str(); + } else if (op->property == TensorProperty::FillValue) { + ret << printType(tensor->type, false) << " " << varname << " = "; + ret << "*((" <type, true) << ")(" << tensor->name << "->fill_value));\n"; + return ret.str(); } string tp; @@ -246,6 +250,8 @@ string CodeGen::packTensorProperty(string varname, Expr tnsr, } else if (property == TensorProperty::ValuesSize) { ret << tensor->name << "->vals_size = " << varname << ";\n"; return ret.str(); + } else if (property == TensorProperty::FillValue) { + return ""; } string tp; diff --git a/src/codegen/codegen_c.cpp b/src/codegen/codegen_c.cpp index 4e3a4fc21..11674db42 100644 --- a/src/codegen/codegen_c.cpp +++ b/src/codegen/codegen_c.cpp @@ -48,6 +48,7 @@ const string cHeaders = " taco_mode_t* mode_types; // mode storage types\n" " uint8_t*** indices; // tensor index data (per mode)\n" " uint8_t* vals; // tensor values\n" + " uint8_t* fill_value; // tensor fill value\n" " int32_t vals_size; // values array size\n" "} taco_tensor_t;\n" "#endif\n" diff --git a/src/codegen/codegen_cuda.cpp b/src/codegen/codegen_cuda.cpp index 5eb57c7ad..301a73a33 100644 --- a/src/codegen/codegen_cuda.cpp +++ b/src/codegen/codegen_cuda.cpp @@ -49,6 +49,7 @@ const string cHeaders = " taco_mode_t* mode_types; // mode storage types\n" " uint8_t*** indices; // tensor index data (per mode)\n" " uint8_t* vals; // tensor values\n" + " uint8_t* fill_value; // tensor fill value\n" " int32_t vals_size; // values array size\n" "} taco_tensor_t;\n" "#endif\n" diff --git a/src/index_notation/index_notation.cpp b/src/index_notation/index_notation.cpp index 2fedbdc12..99d454ee7 100644 --- a/src/index_notation/index_notation.cpp +++ b/src/index_notation/index_notation.cpp @@ -656,7 +656,7 @@ Literal::Literal(std::complex val) : Literal(new LiteralNode(val)) { Literal::Literal(std::complex val) : Literal(new LiteralNode(val)) { } -IndexExpr Literal::zero(Datatype type) { +Literal Literal::zero(Datatype type) { switch (type.getKind()) { case Datatype::Bool: return Literal(false); case Datatype::UInt8: return Literal(uint8_t(0)); @@ -674,7 +674,7 @@ IndexExpr Literal::zero(Datatype type) { default: taco_ierror << "unsupported type"; }; - return IndexExpr(); + return Literal(); } template T Literal::getVal() const { @@ -697,6 +697,10 @@ template double Literal::getVal() const; template std::complex Literal::getVal() const; template std::complex Literal::getVal() const; +void* Literal::getValPtr() { + return getNode(*this)->val; +} + template <> bool isa(IndexExpr e) { return isa(e.ptr); } @@ -1700,6 +1704,7 @@ struct TensorVar::Content { Type type; Format format; Schedule schedule; + Literal fill; }; TensorVar::TensorVar() : content(nullptr) { @@ -1709,23 +1714,24 @@ static Format createDenseFormat(const Type& type) { return Format(vector(type.getOrder(), ModeFormat(Dense))); } -TensorVar::TensorVar(const Type& type) -: TensorVar(type, createDenseFormat(type)) { +TensorVar::TensorVar(const Type& type, const Literal& fill) +: TensorVar(type, createDenseFormat(type), fill) { } -TensorVar::TensorVar(const std::string& name, const Type& type) -: TensorVar(name, type, createDenseFormat(type)) { +TensorVar::TensorVar(const std::string& name, const Type& type, const Literal& fill) +: TensorVar(name, type, createDenseFormat(type), fill) { } -TensorVar::TensorVar(const Type& type, const Format& format) - : TensorVar(util::uniqueName('A'), type, format) { +TensorVar::TensorVar(const Type& type, const Format& format, const Literal& fill) + : TensorVar(util::uniqueName('A'), type, format, fill) { } -TensorVar::TensorVar(const string& name, const Type& type, const Format& format) +TensorVar::TensorVar(const string& name, const Type& type, const Format& format, const Literal& fill) : content(new Content) { content->name = name; content->type = type; content->format = format; + content->fill = fill.defined()? fill : Literal::zero(type.getDataType()); } std::string TensorVar::getName() const { @@ -1761,6 +1767,14 @@ const Schedule& TensorVar::getSchedule() const { return content->schedule; } +const Literal& TensorVar::getFill() const { + return content->fill; +} + +void TensorVar::setFill(const Literal &fill) { + content->fill = fill; +} + void TensorVar::setName(std::string name) { content->name = name; } @@ -2723,4 +2737,16 @@ IndexStmt zero(IndexStmt stmt, const std::set& zeroed) { return Zero(zeroed).rewrite(stmt); } +bool hasNoForAlls(IndexStmt stmt) { + + bool noForAlls = true; + match(stmt, + std::function([&](const ForallNode* op) { + noForAlls = false; + }) + ); + return noForAlls; +} + + } diff --git a/src/ir/ir.cpp b/src/ir/ir.cpp index dbe941fe6..5eb6c9af3 100644 --- a/src/ir/ir.cpp +++ b/src/ir/ir.cpp @@ -855,6 +855,9 @@ Expr GetProperty::make(Expr tensor, TensorProperty property, int mode) { case TensorProperty::ValuesSize: gp->name = tensorVar->name + "_vals_size"; break; + case TensorProperty::FillValue: + gp->name = tensorVar->name + "_fill_value"; + break; } return gp; diff --git a/src/ir/ir_printer.cpp b/src/ir/ir_printer.cpp index 0fca68786..6338430bd 100644 --- a/src/ir/ir_printer.cpp +++ b/src/ir/ir_printer.cpp @@ -131,7 +131,12 @@ void IRPrinter::visit(const Var* op) { } void IRPrinter::visit(const Neg* op) { - stream << "-"; + if(op->type == taco::Bool) { + stream << "!"; + } + else { + stream << "-"; + } parentPrecedence = Precedence::NEG; op->a.accept(this); } diff --git a/src/lower/lowerer_impl.cpp b/src/lower/lowerer_impl.cpp index e422bdfc9..0bdf87a33 100644 --- a/src/lower/lowerer_impl.cpp +++ b/src/lower/lowerer_impl.cpp @@ -222,7 +222,6 @@ LowererImpl::lower(IndexStmt stmt, string name, bool assemble, bool compute) } } } - // Allocate and initialize append and insert mode indices Stmt initializeResults = initResultArrays(resultAccesses, inputAccesses, reducedAccesses); @@ -397,7 +396,7 @@ Stmt LowererImpl::lowerForall(Forall forall) parallelUnitSizes[forall.getParallelUnit()] = ir::Sub::make(bounds[1], bounds[0]); } - MergeLattice lattice = MergeLattice::make(forall, iterators, provGraph, definedIndexVars, whereTempsToResult); + MergeLattice caseLattice = MergeLattice::make(forall, iterators, provGraph, definedIndexVars, whereTempsToResult); vector resultAccesses; set reducedAccesses; std::tie(resultAccesses, reducedAccesses) = getResultAccesses(forall); @@ -410,11 +409,12 @@ Stmt LowererImpl::lowerForall(Forall forall) Stmt loops; // Emit a loop that iterates over over a single iterator (optimization) - if (lattice.iterators().size() == 1 && lattice.iterators()[0].isUnique()) { - taco_iassert(lattice.points().size() == 1); + if (caseLattice.iterators().size() == 1 && caseLattice.iterators()[0].isUnique()) { + MergeLattice loopLattice = caseLattice.getLoopLattice(); +// taco_iassert(caseLattice.points().size() == 1) << "\n Lattice received was " << caseLattice; - MergePoint point = lattice.points()[0]; - Iterator iterator = lattice.iterators()[0]; + MergePoint point = loopLattice.points()[0]; + Iterator iterator = loopLattice.iterators()[0]; vector locators = point.locators(); vector appenders; @@ -441,18 +441,18 @@ Stmt LowererImpl::lowerForall(Forall forall) } } if (!isWhereProducer && hasPosDescendant && underivedAncestors.size() > 1 && provGraph.isPosVariable(iterator.getIndexVar()) && posDescendant == forall.getIndexVar()) { - loops = lowerForallFusedPosition(forall, iterator, locators, - inserters, appenders, reducedAccesses, recoveryStmt); + loops = lowerForallFusedPosition(forall, iterator, locators, inserters, appenders, caseLattice, + reducedAccesses, recoveryStmt); } // Emit dimension coordinate iteration loop else if (iterator.isDimensionIterator()) { - loops = lowerForallDimension(forall, point.locators(), - inserters, appenders, reducedAccesses, recoveryStmt); + loops = lowerForallDimension(forall, point.locators(), inserters, appenders, caseLattice, + reducedAccesses, recoveryStmt); } // Emit position iteration loop else if (iterator.hasPosIter()) { - loops = lowerForallPosition(forall, iterator, locators, - inserters, appenders, reducedAccesses, recoveryStmt); + loops = lowerForallPosition(forall, iterator, locators, inserters, appenders, caseLattice, + reducedAccesses, recoveryStmt); } // Emit coordinate iteration loop else { @@ -465,7 +465,7 @@ Stmt LowererImpl::lowerForall(Forall forall) else { std::vector underivedAncestors = provGraph.getUnderivedAncestors(forall.getIndexVar()); taco_iassert(underivedAncestors.size() == 1); // TODO: add support for fused coordinate of pos loop - loops = lowerMergeLattice(lattice, underivedAncestors[0], + loops = lowerMergeLattice(caseLattice, underivedAncestors[0], forall.getStmt(), reducedAccesses); } // taco_iassert(loops.defined()); @@ -779,6 +779,7 @@ Stmt LowererImpl::lowerForallDimension(Forall forall, vector locators, vector inserters, vector appenders, + MergeLattice caseLattice, set reducedAccesses, ir::Stmt recoveryStmt) { @@ -789,8 +790,8 @@ Stmt LowererImpl::lowerForallDimension(Forall forall, atomicParallelUnit = forall.getParallelUnit(); } - Stmt body = lowerForallBody(coordinate, forall.getStmt(), - locators, inserters, appenders, reducedAccesses); + Stmt body = lowerForallBody(coordinate, forall.getStmt(), locators, inserters, + appenders, caseLattice, reducedAccesses); if (forall.getParallelUnit() != ParallelUnit::NotParallel && forall.getOutputRaceStrategy() == OutputRaceStrategy::Atomics) { markAssignsAtomicDepth--; @@ -823,6 +824,7 @@ Stmt LowererImpl::lowerForallCoordinate(Forall forall, Iterator iterator, vector locators, vector inserters, vector appenders, + MergeLattice caseLattice, set reducedAccesses, ir::Stmt recoveryStmt) { taco_not_supported_yet; @@ -833,6 +835,7 @@ Stmt LowererImpl::lowerForallPosition(Forall forall, Iterator iterator, vector locators, vector inserters, vector appenders, + MergeLattice caseLattice, set reducedAccesses, ir::Stmt recoveryStmt) { @@ -847,8 +850,7 @@ Stmt LowererImpl::lowerForallPosition(Forall forall, Iterator iterator, markAssignsAtomicDepth++; } - Stmt body = lowerForallBody(coordinate, forall.getStmt(), - locators, inserters, appenders, reducedAccesses); + Stmt body = lowerForallBody(coordinate, forall.getStmt(), locators, inserters, appenders, caseLattice, reducedAccesses); if (forall.getParallelUnit() != ParallelUnit::NotParallel && forall.getOutputRaceStrategy() == OutputRaceStrategy::Atomics) { markAssignsAtomicDepth--; @@ -909,6 +911,7 @@ Stmt LowererImpl::lowerForallFusedPosition(Forall forall, Iterator iterator, vector locators, vector inserters, vector appenders, + MergeLattice caseLattice, set reducedAccesses, ir::Stmt recoveryStmt) { @@ -996,7 +999,7 @@ Stmt LowererImpl::lowerForallFusedPosition(Forall forall, Iterator iterator, } Stmt body = lowerForallBody(coordinate, forall.getStmt(), - locators, inserters, appenders, reducedAccesses); + locators, inserters, appenders, caseLattice, reducedAccesses); if (forall.getParallelUnit() != ParallelUnit::NotParallel && forall.getOutputRaceStrategy() == OutputRaceStrategy::Atomics) { markAssignsAtomicDepth--; @@ -1062,31 +1065,34 @@ Stmt LowererImpl::lowerForallFusedPosition(Forall forall, Iterator iterator, } -Stmt LowererImpl::lowerMergeLattice(MergeLattice lattice, IndexVar coordinateVar, +Stmt LowererImpl::lowerMergeLattice(MergeLattice caseLattice, IndexVar coordinateVar, IndexStmt statement, const std::set& reducedAccesses) { + // Lower merge lattice always gets called from lowerForAll. So we want loop lattice + MergeLattice loopLattice = caseLattice.getLoopLattice(); + Expr coordinate = getCoordinateVar(coordinateVar); - vector appenders = filter(lattice.results(), + vector appenders = filter(loopLattice.results(), [](Iterator it){return it.hasAppend();}); - vector mergers = lattice.points()[0].mergers(); - Stmt iteratorVarInits = codeToInitializeIteratorVars(lattice.iterators(), lattice.points()[0].rangers(), mergers, coordinate, coordinateVar); + vector mergers = loopLattice.points()[0].mergers(); + Stmt iteratorVarInits = codeToInitializeIteratorVars(loopLattice.iterators(), loopLattice.points()[0].rangers(), mergers, coordinate, coordinateVar); // if modeiteratornonmerger then will be declared in codeToInitializeIteratorVars auto modeIteratorsNonMergers = - filter(lattice.points()[0].iterators(), [mergers](Iterator it){ + filter(loopLattice.points()[0].iterators(), [mergers](Iterator it){ bool isMerger = find(mergers.begin(), mergers.end(), it) != mergers.end(); return it.isDimensionIterator() && !isMerger; }); bool resolvedCoordDeclared = !modeIteratorsNonMergers.empty(); vector mergeLoopsVec; - for (MergePoint point : lattice.points()) { + for (MergePoint point : loopLattice.points()) { // Each iteration of this loop generates a while loop for one of the merge // points in the merge lattice. - IndexStmt zeroedStmt = zero(statement, getExhaustedAccesses(point,lattice)); - MergeLattice sublattice = lattice.subLattice(point); + IndexStmt zeroedStmt = zero(statement, getExhaustedAccesses(point, caseLattice)); + MergeLattice sublattice = caseLattice.subLattice(point); Stmt mergeLoop = lowerMergePoint(sublattice, coordinate, coordinateVar, zeroedStmt, reducedAccesses, resolvedCoordDeclared); mergeLoopsVec.push_back(mergeLoop); } @@ -1189,19 +1195,28 @@ Stmt LowererImpl::lowerMergeCases(ir::Expr coordinate, IndexVar coordinateVar, I const std::set& reducedAccesses) { vector result; + if (hasNoForAlls(stmt) && MergeLattice::needExplicitZeroChecks(lattice)) { + // In the bottom most loop. + Stmt body = lowerMergeCasesWithExplicitZeroChecks(coordinate, coordinateVar, stmt, lattice, reducedAccesses); + result.push_back(body); + return Block::make(result); + } + + // Emitting structural cases so unconditionally apply lattice optimizations. + lattice = lattice.getLoopLattice(); vector appenders; vector inserters; tie(appenders, inserters) = splitAppenderAndInserters(lattice.results()); - // Just one iterator so no conditionals if (lattice.iterators().size() == 1) { + // Just one iterator so no conditional taco_iassert(!lattice.points()[0].isOmitter()); - Stmt body = lowerForallBody(coordinate, stmt, {}, inserters, - appenders, reducedAccesses); + Stmt body = lowerForallBody(coordinate, stmt, {}, inserters, + appenders, lattice, reducedAccesses); result.push_back(body); } - else if (!lattice.points().empty()){ + else if (!lattice.points().empty()) { vector> cases; for (MergePoint point : lattice.points()) { @@ -1210,29 +1225,20 @@ Stmt LowererImpl::lowerMergeCases(ir::Expr coordinate, IndexVar coordinateVar, I } // Construct case expression - vector coordComparisons; - for (Iterator iterator : point.rangers()) { - if (!(provGraph.isCoordVariable(iterator.getIndexVar()) && - provGraph.isDerivedFrom(iterator.getIndexVar(), coordinateVar))) { - coordComparisons.push_back(Eq::make(iterator.getCoordVar(), coordinate)); - } - } + vector coordComparisons = compareToResolvedCoordinate(point.rangers(), coordinate, coordinateVar); + vector omittedRegionIterators = lattice.retrieveRegionIteratorsToOmit(point); + std::vector neqComparisons = compareToResolvedCoordinate(omittedRegionIterators, coordinate, coordinateVar); + append(coordComparisons, neqComparisons); - vector omittedRangers = lattice.retrieveIteratorsToOmit(point); - for (auto iterator: omittedRangers) { - if (!(provGraph.isCoordVariable(iterator.getIndexVar()) && - provGraph.isDerivedFrom(iterator.getIndexVar(), coordinateVar))) { - coordComparisons.push_back(Neq::make(iterator.getCoordVar(), coordinate)); - } - } + coordComparisons = filter(coordComparisons, [](const Expr& e) { return e.defined(); }); // Construct case body IndexStmt zeroedStmt = zero(stmt, getExhaustedAccesses(point, lattice)); Stmt body = lowerForallBody(coordinate, zeroedStmt, {}, - inserters, appenders, reducedAccesses); + inserters, appenders, MergeLattice({point}), reducedAccesses); if (coordComparisons.empty()) { Stmt body = lowerForallBody(coordinate, stmt, {}, inserters, - appenders, reducedAccesses); + appenders, MergeLattice({point}), reducedAccesses); result.push_back(body); break; } @@ -1244,11 +1250,177 @@ Stmt LowererImpl::lowerMergeCases(ir::Expr coordinate, IndexVar coordinateVar, I return Block::make(result); } +ir::Expr LowererImpl::constructCheckForAccessZero(Access access) { + Expr tensorValue = lower(access); + IndexExpr zeroVal = Literal::zero(tensorValue.type()); //TODO ARRAY Generalize + return Neq::make(tensorValue, lower(zeroVal)); +} + +std::vector LowererImpl::getModeIterators(const std::vector& iters) { + // For now only check mode iterators. + return filter(iters, [](const Iterator& it){return it.isModeIterator();}); +} + +std::vector LowererImpl::constructInnerLoopCasePreamble(ir::Expr coordinate, IndexVar coordinateVar, + MergeLattice lattice, + map& iteratorToConditionMap) { + vector result; + + // First, get mode iterator coordinate comparisons + std::vector modeIterators = getModeIterators(lattice.iterators()); + vector coordComparisonsForModeIters = compareToResolvedCoordinate(modeIterators, coordinate, coordinateVar); + + std::vector modeItersWithIndexCases; + std::vector coordComparisons; + for(size_t i = 0; i < coordComparisonsForModeIters.size(); ++i) { + Expr expr = coordComparisonsForModeIters[i]; + if (expr.defined()) { + modeItersWithIndexCases.push_back(modeIterators[i]); + coordComparisons.push_back(expr); + } + } + + // Construct tensor iterators with modeIterators first then locate iterators to keep a mapping between vector indices + vector tensorIterators = combine(modeItersWithIndexCases, lattice.locators()); + tensorIterators = getModeIterators(tensorIterators); + + // Get value comparisons for all tensor iterators + vector itAccesses; + vector valueComparisons; + for(auto it : tensorIterators) { + Access itAccess = iterators.modeAccess(it).getAccess(); + itAccesses.push_back(itAccess); + valueComparisons.push_back(constructCheckForAccessZero(itAccess)); + } + + // Construct isNonZero cases + for(size_t i = 0; i < coordComparisons.size(); ++i) { + Expr nonZeroCase; + if(coordComparisons[i].defined()) { + nonZeroCase = conjunction({coordComparisons[i], valueComparisons[i]}); + } else { + nonZeroCase = valueComparisons[i]; + } + Expr caseName = Var::make(itAccesses[i].getTensorVar().getName() + "_isNonZero", taco::Bool); + Stmt declaration = VarDecl::make(caseName, nonZeroCase); + result.push_back(declaration); + iteratorToConditionMap[tensorIterators[i]] = caseName; + } + + for(size_t i = modeItersWithIndexCases.size(); i < valueComparisons.size(); ++i) { + Expr caseName = Var::make(itAccesses[i].getTensorVar().getName() + "_isNonZero", taco::Bool); + Stmt declaration = VarDecl::make(caseName, valueComparisons[i]); + result.push_back(declaration); + iteratorToConditionMap[tensorIterators[i]] = caseName; + } + + return result; +} + +vector LowererImpl::lowerCasesFromMap(map iteratorToCondition, + ir::Expr coordinate, IndexStmt stmt, const MergeLattice& lattice, + const std::set& reducedAccesses) { + + vector appenders; + vector inserters; + tie(appenders, inserters) = splitAppenderAndInserters(lattice.results()); + + std::vector result; + vector> cases; + for (MergePoint point : lattice.points()) { + + if(point.isOmitter()) { + continue; + } + + // Construct case expression + vector isNonZeroComparisions; + for(auto& it : combine(point.rangers(), point.locators())) { + if(util::contains(iteratorToCondition, it)) { + taco_iassert(iteratorToCondition.at(it).type() == taco::Bool) << "Map must have boolean types"; + isNonZeroComparisions.push_back(iteratorToCondition.at(it)); + } + } + + function getNegatedComparison = [&](const Iterator& it) {return ir::Neg::make(iteratorToCondition.at(it));}; + vector omittedRegionIterators = lattice.retrieveRegionIteratorsToOmit(point); + for(auto& it : omittedRegionIterators) { + if(util::contains(iteratorToCondition, it)) { + isNonZeroComparisions.push_back(ir::Neg::make(iteratorToCondition.at(it))); + } + } + + // Construct case body + // TODO - Rawn "Exhaust" locators + IndexStmt zeroedStmt = zero(stmt, getExhaustedAccesses(point, lattice)); + Stmt body = lowerForallBody(coordinate, zeroedStmt, {}, + inserters, appenders, MergeLattice({point}), reducedAccesses); + if (isNonZeroComparisions.empty()) { + Stmt body = lowerForallBody(coordinate, stmt, {}, inserters, + appenders, MergeLattice({point}), reducedAccesses); + result.push_back(body); + break; + } + cases.push_back({taco::ir::conjunction(isNonZeroComparisions), body}); + } + + vector inputs = combine(lattice.iterators(), lattice.locators()); + inputs = getModeIterators(inputs); + + if(!lattice.exact() && util::any(inserters, [](Iterator it){return it.isFull();}) && hasNoForAlls(stmt) + && any(inputs, [](Iterator it){return it.isFull();})) { + // Currently, if the lattice is not exact, the output is full and any of the inputs are full, we initialize + // the result tensor + vector stmts; + for(auto& it : inserters) { + if(it.isFull()) { + Access access = iterators.modeAccess(it).getAccess(); + IndexStmt initStmt = Assignment(access, Literal::zero(access.getDataType())); + Stmt initialization = lowerForallBody(coordinate, initStmt, {}, inserters, + appenders, MergeLattice({}), reducedAccesses); + stmts.push_back(initialization); + } + } + Stmt backgroundInit = Block::make(stmts); + cases.push_back({Expr((bool) true), backgroundInit}); + result.push_back(Case::make(cases, true)); + } else { + result.push_back(Case::make(cases, lattice.exact())); + } + return result; +} + +/// Lowers a merge lattice to cases assuming there are no more loops to be emitted in stmt. +Stmt LowererImpl::lowerMergeCasesWithExplicitZeroChecks(ir::Expr coordinate, IndexVar coordinateVar, IndexStmt stmt, + MergeLattice lattice, const std::set& reducedAccesses) { + + vector result; + if (lattice.points().size() == 1 && lattice.iterators().size() == 1) { + // Just one iterator so no conditional + vector appenders; + vector inserters; + tie(appenders, inserters) = splitAppenderAndInserters(lattice.results()); + taco_iassert(!lattice.points()[0].isOmitter()); + Stmt body = lowerForallBody(coordinate, stmt, {}, inserters, + appenders, lattice, reducedAccesses); + result.push_back(body); + } else if (!lattice.points().empty()) { + map iteratorToConditionMap; + + vector preamble = constructInnerLoopCasePreamble(coordinate, coordinateVar, lattice, iteratorToConditionMap); + util::append(result, preamble); + vector cases = lowerCasesFromMap(iteratorToConditionMap, coordinate, stmt, lattice, reducedAccesses); + util::append(result, cases); + } + + return Block::make(result); +} Stmt LowererImpl::lowerForallBody(Expr coordinate, IndexStmt stmt, vector locators, vector inserters, vector appenders, + MergeLattice caseLattice, const set& reducedAccesses) { Stmt initVals = resizeAndInitValues(appenders, reducedAccesses); @@ -1264,7 +1436,24 @@ Stmt LowererImpl::lowerForallBody(Expr coordinate, IndexStmt stmt, } // Code of loop body statement - Stmt body = lower(stmt); + Stmt body; + if (hasNoForAlls(stmt) && caseLattice.points().size() > 1) { + std::vector stmts; + + // Need to emit checks based on case lattice + vector modeIterators = getModeIterators(combine(caseLattice.iterators(), caseLattice.locators())); + std::map caseMap; + for(auto it : modeIterators) { + Access itAccess = iterators.modeAccess(it).getAccess(); + Expr accessCase = constructCheckForAccessZero(itAccess); + caseMap.insert({it, accessCase}); + } + std::vector loweredCases = lowerCasesFromMap(caseMap, coordinate, stmt, caseLattice, reducedAccesses); + append(stmts, loweredCases); + body = Block::make(stmts); + } else { + body = lower(stmt); + } // Code to append coordinates Stmt appendCoords = appendCoordinate(appenders, coordinate); @@ -1666,6 +1855,7 @@ Stmt LowererImpl::initResultArrays(vector writes, taco_iassert(!iterators.empty()); Expr tensor = getTensorVar(write.getTensorVar()); + Expr fill = GetProperty::make(tensor, TensorProperty::FillValue); Expr valuesArr = GetProperty::make(tensor, TensorProperty::Values); Expr parentSize = 1; @@ -1742,7 +1932,7 @@ Stmt LowererImpl::initResultArrays(vector writes, // iteration of all the iterators is not full. We can check this by seeing if we can recover a // full iterator from our set of iterators. Expr size = generateAssembleCode() ? getCapacityVar(tensor) : parentSize; - result.push_back(zeroInitValues(tensor, 0, size)); + result.push_back(initValues(tensor, fill, 0, size)); } } return result.empty() ? Stmt() : Block::blanks(result); @@ -1833,6 +2023,7 @@ Stmt LowererImpl::initResultArrays(IndexVar var, vector writes, vector result; for (auto& write : writes) { Expr tensor = getTensorVar(write.getTensorVar()); + Expr fill = GetProperty::make(tensor, TensorProperty::FillValue); Expr values = GetProperty::make(tensor, TensorProperty::Values); vector iterators = getIteratorsFrom(var, getIterators(write)); @@ -1892,7 +2083,7 @@ Stmt LowererImpl::initResultArrays(IndexVar var, vector writes, util::contains(reducedAccesses, write)) { // Zero-initialize values array if might not assign to every element // in values array during compute - result.push_back(zeroInitValues(tensor, resultParentPos, stride)); + result.push_back(initValues(tensor, fill, resultParentPos, stride)); } } } @@ -1938,18 +2129,20 @@ Stmt LowererImpl::resizeAndInitValues(const std::vector& appenders, } -Stmt LowererImpl::zeroInitValues(Expr tensor, Expr begin, Expr size) { +Stmt LowererImpl::initValues(Expr tensor, Expr initVal, Expr begin, Expr size) { Expr lower = simplify(ir::Mul::make(begin, size)); Expr upper = simplify(ir::Mul::make(ir::Add::make(begin, 1), size)); Expr p = Var::make("p" + util::toString(tensor), Int()); Expr values = GetProperty::make(tensor, TensorProperty::Values); - Stmt zeroInit = Store::make(values, p, ir::Literal::zero(tensor.type())); + Stmt zeroInit = Store::make(values, p, initVal); LoopKind parallel = (isa(size) && to(size)->getIntValue() < (1 << 10)) ? LoopKind::Serial : LoopKind::Static_Chunked; if (should_use_CUDA_codegen() && util::contains(parallelUnitSizes, ParallelUnit::GPUBlock)) { return ir::VarDecl::make(ir::Var::make("status", Int()), - ir::Call::make("cudaMemset", {values, ir::Literal::make(0, Int()), ir::Mul::make(ir::Sub::make(upper, lower), ir::Literal::make(values.type().getNumBytes()))}, Int())); + ir::Call::make("cudaMemset", {values, ir::Literal::make(0, Int()), + ir::Mul::make(ir::Sub::make(upper, lower), + ir::Literal::make(values.type().getNumBytes()))}, Int())); } return For::make(p, lower, upper, 1, zeroInit, parallel); } diff --git a/src/lower/merge_lattice.cpp b/src/lower/merge_lattice.cpp index 4d0ed499f..a238ea8a6 100644 --- a/src/lower/merge_lattice.cpp +++ b/src/lower/merge_lattice.cpp @@ -83,7 +83,7 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA std::set definedIndexVars; map latticesOfTemporaries; std::map whereTempsToResult; - map baseMergePoints; + map seenMergePoints; MergeLattice modeIterationLattice() { return MergeLattice({MergePoint({iterators.modeIterator(i)}, {}, {})}); @@ -161,6 +161,10 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA void visit(const AccessNode* access) { // TODO: Case where Access is used in computation but not iteration algebra + if(seenMergePoints.find(access) != seenMergePoints.end()) { + lattice = MergeLattice({seenMergePoints.at(access)}); + return; + } if (util::contains(latticesOfTemporaries, access->tensorVar)) { // If the accessed tensor variable is a temporary with an associated merge @@ -225,7 +229,7 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA lattice = MergeLattice({point}); } - baseMergePoints.insert({iterator, lattice.points()[0]}); + seenMergePoints.insert({access, lattice.points()[0]}); } void visit(const LiteralNode* node) { @@ -264,6 +268,28 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA void visit(const TensorOpNode* expr) { lattice = build(expr->iterAlg); + + // Now we need to store regions that should be kept when applying optimizations. + // Can't remove regions described by special regions since the lowerer must emit checks for those in + // all cases. + const auto regionDefs = expr->regionDefinitions; + const vector inputs = expr->args; + set> regionsToKeep; + + for(auto& it : regionDefs) { + vector region = it.first; + set regionToKeep; + for(auto idx : region) { + match(inputs[idx], + function([&](const AccessNode* n) { + set tensorRegion = seenMergePoints.at(n).tensorRegion(); + regionToKeep.insert(tensorRegion.begin(), tensorRegion.end()); + }) + ); + } + regionsToKeep.insert(regionToKeep); + } + lattice = MergeLattice(lattice.points(), regionsToKeep); } void visit(const CallIntrinsicNode* expr) { @@ -339,7 +365,7 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA vector(resultIterators.begin(), resultIterators.end()), point.isOmitter())); } - lattice = MergeLattice(points); + lattice = MergeLattice(points, lattice.getTensorRegionsToKeep()); } } @@ -400,7 +426,7 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA else if(!currentRegion.empty()){ MergePoint mp({}, {}, {}); for(const auto& it: currentRegion) { - mp = unionPoints(mp, baseMergePoints.at(it)); + mp = unionPoints(mp, seenMergePoints.at(iterators.modeAccess(it).getAccess())); } vector newIters; @@ -503,7 +529,11 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA // made of only omitters // points = removeUnnecessaryOmitterPoints(points); - return MergeLattice(points); + set> toKeep = left.getTensorRegionsToKeep(); + set> toKeepRight = right.getTensorRegionsToKeep(); + + toKeep.insert(toKeepRight.begin(), toKeepRight.end()); + return MergeLattice(points, toKeep); } /** @@ -559,7 +589,11 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA // Optimization: Removes a subLattice of points if the entire subLattice is // made of only omitters // points = removeUnnecessaryOmitterPoints(points); - return MergeLattice(points); + set> toKeep = left.getTensorRegionsToKeep(); + set> toKeepRight = right.getTensorRegionsToKeep(); + + toKeep.insert(toKeepRight.begin(), toKeepRight.end()); + return MergeLattice(points, toKeep); } /** @@ -838,7 +872,8 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA // class MergeLattice -MergeLattice::MergeLattice(vector points) : points_(points) +MergeLattice::MergeLattice(vector points, set> regionsToKeep) : points_(points), + regionsToKeep(regionsToKeep) { } @@ -857,12 +892,59 @@ MergeLattice MergeLattice::make(Forall forall, Iterators iterators, ProvenanceGr } MergeLattice lattice = builder.build(forall.getStmt()); + // Can't remove points if lattice contains omitters since we lose merge cases during lowering. - if(util::any(lattice.points(), [](const MergePoint& point){return point.isOmitter();})) { + if(hasNoForAlls(forall.getStmt()) && needExplicitZeroChecks(lattice)) { return lattice; } - lattice = removePointsThatLackFullIterators(lattice); - return removeProducersWithIdenticalIterators(lattice); + + // Loop lattice and case lattice are identical so simplify here + return lattice.getLoopLattice(); +} + +std::vector +MergeLattice::removePointsThatLackFullIterators(const std::vector& points) +{ + vector result; + vector fullIterators = filter(points[0].iterators(), + [](Iterator it){return it.isFull();}); + for (auto& point : points) { + bool missingFullIterator = false; + for (auto& fullIterator : fullIterators) { + if (!util::contains(point.iterators(), fullIterator)) { + missingFullIterator = true; + break; + } + } + if (!missingFullIterator) { + result.push_back(point); + } + } + return result; +} + +std::vector +MergeLattice::removePointsWithIdenticalIterators(const std::vector& points) +{ + vector result; + set> producerIteratorSets; + for (auto& point : points) { + set iteratorSet(point.iterators().begin(), + point.iterators().end()); + if (util::contains(producerIteratorSets, iteratorSet)) { + continue; + } + result.push_back(point); + producerIteratorSets.insert(iteratorSet); + } + return result; +} + +bool MergeLattice::needExplicitZeroChecks(const MergeLattice &lattice) { + if(util::any(lattice.points(), [](const MergePoint& mp) {return mp.isOmitter();})) { + return true; + } + return !lattice.getTensorRegionsToKeep().empty(); } MergeLattice MergeLattice::subLattice(MergePoint lp) const { @@ -892,12 +974,20 @@ const vector& MergeLattice::iterators() const { return points()[0].iterators(); } +const vector& MergeLattice::locators() const { + // The iterators merged by a lattice are those merged by the first point + taco_iassert(points().size() > 0) << "No merge points in the merge lattice"; + return points()[0].locators(); +} + set MergeLattice::exhausted(MergePoint point) { - set notExhaustedIters(point.iterators().begin(), - point.iterators().end()); + set notExhaustedIters = point.tensorRegion(); set exhausted; - for (auto& iterator : iterators()) { + vector modeIterators = combine(iterators(), locators()); + modeIterators = filter(modeIterators, [](const Iterator& it) {return it.isModeIterator();}); + + for (auto& iterator : modeIterators) { if (!util::contains(notExhaustedIters, iterator)) { exhausted.insert(iterator); } @@ -940,66 +1030,47 @@ bool MergeLattice::exact() const { return true; } -std::vector MergeLattice::retrieveIteratorsToOmit(const MergePoint &point) const { - +std::vector MergeLattice::retrieveRegionIteratorsToOmit(const MergePoint &point) const { vector omittedIterators; - const size_t levelOfParent = point.iterators().size() + 1; - vector pointIterators = point.iterators(); - sort(pointIterators.begin(), pointIterators.end()); + set pointRegion = point.tensorRegion(); + set seen; + const size_t levelOfPoint = pointRegion.size(); + + if(point.isOmitter()) { + seen = set(pointRegion.begin(), pointRegion.end()); + omittedIterators = vector(seen.begin(), seen.end()); + } + // Look at all points above for(const auto& mp: points()) { - if(mp.iterators().size() == levelOfParent && mp.isOmitter()) { - // We are one level above this point - vector parentIterators = mp.iterators(); - sort(parentIterators.begin(), parentIterators.end()); - set_difference(parentIterators.begin(), parentIterators.end(), - pointIterators.begin(), pointIterators.end(), - back_inserter(omittedIterators)); + if((mp.tensorRegion().size() > levelOfPoint) && mp.isOmitter()) { + // Grab the omitted tensors + set parentRegion = mp.tensorRegion(); + std::vector parentItersToOmit; + set_difference(parentRegion.begin(), parentRegion.end(), + pointRegion.begin(), pointRegion.end(), + back_inserter(parentItersToOmit)); + + // Add iterators not in present point to the iterators to omit + for(const auto& it : parentItersToOmit) { + if(!util::contains(seen, it)) { + seen.insert(it); + omittedIterators.push_back(it); + } + } } } return omittedIterators; } -MergeLattice -MergeLattice::removePointsThatLackFullIterators(const MergeLattice& l) -{ - vector result; - vector fullIterators = filter(l.points()[0].iterators(), - [](Iterator it){return it.isFull();}); - for (auto& point : l.points()) { - bool missingFullIterator = false; - for (auto& fullIterator : fullIterators) { - if (!util::contains(point.iterators(), fullIterator)) { - missingFullIterator = true; - break; - } - } - if (!missingFullIterator) { - result.push_back(point); - } - } - return MergeLattice(result); +set> MergeLattice::getTensorRegionsToKeep() const { + return regionsToKeep; } -MergeLattice -MergeLattice::removeProducersWithIdenticalIterators(const MergeLattice& l) -{ - vector result; - set> producerIteratorSets; - for (auto& point : l.points()) { - set iteratorSet(point.iterators().begin(), - point.iterators().end()); - if (!point.isOmitter() && util::contains(producerIteratorSets, iteratorSet)) { - continue; - } - result.push_back(point); - - if (!point.isOmitter()) { - producerIteratorSets.insert(iteratorSet); - } - } - return MergeLattice(result); +MergeLattice MergeLattice::getLoopLattice() const { + std::vector p = removePointsThatLackFullIterators(points()); + return removePointsWithIdenticalIterators(p); } ostream& operator<<(ostream& os, const MergeLattice& ml) { diff --git a/src/storage/pack.cpp b/src/storage/pack.cpp index 95705d13f..852b230a0 100644 --- a/src/storage/pack.cpp +++ b/src/storage/pack.cpp @@ -5,6 +5,7 @@ #include "taco/format.h" #include "taco/error.h" #include "taco/ir/ir.h" +#include "taco/index_notation/index_notation.h" #include "taco/storage/storage.h" #include "taco/storage/index.h" #include "taco/storage/array.h" @@ -19,12 +20,14 @@ namespace taco { if (cbegin < cend) { \ memcpy(&values[valuesIndex], &vals[cbegin*dataType.getNumBytes()], dataType.getNumBytes()); \ } \ - else { \ + else if (fill == nullptr){ \ memset(&values[valuesIndex], 0, dataType.getNumBytes()); \ + } else { \ + memcpy(&values[valuesIndex], fill, dataType.getNumBytes()); \ } \ valuesIndex += dataType.getNumBytes(); \ } else { \ - valuesIndex = packTensor(dimensions, coords, vals, cbegin, (cend), modeTypes, i+1, \ + valuesIndex = packTensor(dimensions, coords, vals, fill, cbegin, (cend), modeTypes, i+1, \ indices, values, dataType, valuesIndex); \ } \ } @@ -56,7 +59,7 @@ static TypedIndexVector getUniqueEntries(TypedIndexVector v, /// [0,2] index arrays. static int packTensor(const vector& dimensions, const vector& coords, - char* vals, + char* vals, const void* fill, size_t begin, size_t end, const vector& modeTypes, size_t i, std::vector>* indices, @@ -119,16 +122,19 @@ TensorStorage pack(Datatype componentType, const std::vector& dimensions, const Format& format, const std::vector& coordinates, - const void * values) { + const void * values, + const Literal& fill) { + taco_iassert(dimensions.size() == (size_t)format.getOrder()); taco_iassert(coordinates.size() == (size_t)format.getOrder()); taco_iassert(sameSize(coordinates)); taco_iassert(dimensions.size() > 0) << "Scalar packing not supported"; + taco_iassert(fill.getDataType() == componentType) << "Component type must match value type"; size_t order = dimensions.size(); size_t numCoordinates = coordinates[0].size(); - TensorStorage storage(componentType, dimensions, format); + TensorStorage storage(componentType, dimensions, format, fill); // Create vectors to store pointers to indices/index sizes vector> indices; @@ -155,7 +161,8 @@ TensorStorage pack(Datatype componentType, } void* vals = malloc(maxSize * componentType.getNumBytes()); - int actual_size = packTensor(dimensions, coordinates, (char *) values, 0, + const void* fillData = storage.getFill().getData(); + int actual_size = packTensor(dimensions, coordinates, (char *) values, fillData, 0, numCoordinates, format.getModeFormats(), 0, &indices, (char *)vals, componentType, 0); vals = realloc(vals, actual_size); diff --git a/src/storage/storage.cpp b/src/storage/storage.cpp index 80af9f0a0..32fcac31c 100644 --- a/src/storage/storage.cpp +++ b/src/storage/storage.cpp @@ -10,6 +10,7 @@ #include "taco/storage/index.h" #include "taco/storage/array.h" #include "taco/util/strings.h" +#include "taco/index_notation/index_notation.h" using namespace std; @@ -26,11 +27,15 @@ struct TensorStorage::Content { Index index; Array values; - Content(Datatype componentType, vector dimensions, Format format) + // Always an array of size 1 + Array fillValue; + + Content(Datatype componentType, vector dimensions, Format format, Literal fill) : componentType(componentType), dimensions(dimensions), format(format), index(format) { int order = (int)dimensions.size(); + taco_iassert(fill.getDataType() == componentType) << "Fill value must be of same type as data array"; taco_iassert(order <= INT_MAX && componentType.getNumBits() <= INT_MAX); vector dimensionsInt32(order); vector modeOrdering(order); @@ -50,9 +55,14 @@ struct TensorStorage::Content { } } + void* fillData = malloc(componentType.getNumBytes()); + memcpy(fillData, fill.getValPtr(), componentType.getNumBytes()); + + fillValue = Array(componentType, fillData, 1, Array::Policy::Free); + tensorData = init_taco_tensor_t(order, componentType.getNumBits(), - dimensionsInt32.data(), modeOrdering.data(), - modeTypes.data()); + dimensionsInt32.data(), modeOrdering.data(), + modeTypes.data()); } ~Content() { @@ -60,9 +70,9 @@ struct TensorStorage::Content { } }; -TensorStorage::TensorStorage(Datatype componentType, - const vector& dimensions, Format format) - : content(new Content(componentType, dimensions, format)) { +TensorStorage::TensorStorage(Datatype componentType, const vector& dimensions, + Format format, Literal fillVal) + : content(new Content(componentType, dimensions, format, fillVal)) { } const Format& TensorStorage::getFormat() const { @@ -93,6 +103,11 @@ const Array& TensorStorage::getValues() const { return content->values; } +const Array& TensorStorage::getFill() const { + taco_iassert(content->fillValue.getSize() == 1); + return content->fillValue; +} + Array TensorStorage::getValues() { return content->values; } @@ -159,6 +174,7 @@ TensorStorage::operator struct taco_tensor_t*() const { } tensorData->vals = (uint8_t*)getValues().getData(); + tensorData->fill_value = (uint8_t*)getFill().getData(); return content->tensorData; } @@ -171,6 +187,11 @@ void TensorStorage::setValues(const Array& values) { content->values = values; } +void TensorStorage::setFill(const Array &fill) { + taco_iassert(fill.getSize() == 1); + content->fillValue = fill; +} + bool equals(TensorStorage a, TensorStorage b) { return false; } diff --git a/src/tensor.cpp b/src/tensor.cpp index 5efba93cc..e0e0f6398 100644 --- a/src/tensor.cpp +++ b/src/tensor.cpp @@ -59,10 +59,10 @@ struct TensorBase::Content { shared_ptr module; Content(string name, Datatype dataType, const vector& dimensions, - Format format) + Format format, Literal fill) : dataType(dataType), dimensions(dimensions), - storage(TensorStorage(dataType, dimensions, format)), - tensorVar(TensorVar(name, Type(dataType,convert(dimensions)),format)) {} + storage(TensorStorage(dataType, dimensions, format, fill)), + tensorVar(TensorVar(name, Type(dataType,convert(dimensions)),format, fill)) {} }; TensorBase::TensorBase() : TensorBase(Float()) { @@ -73,23 +73,31 @@ TensorBase::TensorBase(Datatype ctype) } TensorBase::TensorBase(std::string name, Datatype ctype) - : TensorBase(name, ctype, {}, Format()) { + : TensorBase(name, ctype, {}, Format(), Literal::zero(ctype)) { } TensorBase::TensorBase(Datatype ctype, vector dimensions, - ModeFormat modeType) + ModeFormat modeType, Literal fill) : TensorBase(util::uniqueName('A'), ctype, dimensions, - std::vector(dimensions.size(), modeType)) { + std::vector(dimensions.size(), modeType), fill) { } -TensorBase::TensorBase(Datatype ctype, vector dimensions, Format format) - : TensorBase(util::uniqueName('A'), ctype, dimensions, format) { +TensorBase::TensorBase(Datatype ctype, vector dimensions, Format format, Literal fill) + : TensorBase(util::uniqueName('A'), ctype, dimensions, format, fill) { } TensorBase::TensorBase(std::string name, Datatype ctype, - std::vector dimensions, ModeFormat modeType) + std::vector dimensions, ModeFormat modeType, Literal fill) : TensorBase(name, ctype, dimensions, - std::vector(dimensions.size(), modeType)) { + std::vector(dimensions.size(), modeType), fill) { +} + +TensorBase::TensorBase(Datatype ctype, std::vector dimensions, Literal fill) + : TensorBase(ctype, dimensions, ModeFormat::compressed, fill) { +} + +TensorBase::TensorBase(std::string name, Datatype ctype, std::vector dimensions, Literal fill) + : TensorBase(name, ctype, dimensions, ModeFormat::compressed, fill) { } static Format initFormat(Format format) { @@ -118,12 +126,19 @@ static Format initFormat(Format format) { } TensorBase::TensorBase(string name, Datatype ctype, vector dimensions, - Format format) - : content(new Content(name, ctype, dimensions, initFormat(format))) { + Format format, Literal fill) { + + // Default fill to zero since undefined. This is done since we need the ctype to initialize the + // fill and we can't use this inside the default arguments. + fill = fill.defined()? fill : Literal::zero(ctype); + content = shared_ptr(new Content(name, ctype, dimensions, initFormat(format), fill)); + taco_uassert((size_t)format.getOrder() == dimensions.size()) << "The number of format mode types (" << format.getOrder() << ") " << "must match the tensor order (" << dimensions.size() << ")."; + taco_uassert(ctype == fill.getDataType()) << "Fill value must be of the same type as the tensor."; + content->allocSize = 1 << 20; vector modeIndices(format.getOrder()); @@ -269,6 +284,7 @@ void TensorBase::pack() { bufferStorage->indices[0][0] = (uint8_t*)pos.data(); bufferStorage->indices[0][1] = (uint8_t*)bufferCoords.data(); bufferStorage->vals = (uint8_t*)this->coordinateBuffer->data(); + bufferStorage->fill_value = (uint8_t*)(getStorage().getFill().getData()); std::vector arguments = {content->storage, bufferStorage}; helperFuncs->callFuncPacked("pack", arguments.data()); @@ -339,6 +355,7 @@ void TensorBase::pack() { bufferStorage->indices[i][1] = (uint8_t*)coordinates[i].data(); } bufferStorage->vals = (uint8_t*)values; + bufferStorage->fill_value = (uint8_t*)(getStorage().getFill().getData()); // Pack nonzero components into required format std::vector arguments = {content->storage, bufferStorage}; @@ -412,6 +429,10 @@ taco_tensor_t* TensorBase::getTacoTensorT() { return getStorage(); } +Literal TensorBase::getFillValue() const { + return content->tensorVar.getFill(); +} + static inline map getTensors(const IndexExpr& expr) { struct GetOperands : public IndexNotationVisitor { using IndexNotationVisitor::visit; @@ -705,6 +726,11 @@ bool equals(const TensorBase& a, const TensorBase& b) { return false; } + // Fill values must be the same + if (!equals(a.getFillValue(), b.getFillValue())) { + return false; + } + // Orders must be the same if (a.getOrder() != b.getOrder()) { return false; diff --git a/test/tests-api.cpp b/test/tests-api.cpp index da4f7082d..d1afee89c 100644 --- a/test/tests-api.cpp +++ b/test/tests-api.cpp @@ -121,6 +121,8 @@ TEST_P(apiget, api) { ASSERT_ARRAY_EQ(GetParam().getExpectedValues(), {(double*)storage.getValues().getData(), storage.getIndex().getSize()}); + + ASSERT_TRUE(equals(tensor.getFillValue(), Literal::zero(tensor.getComponentType()))); } TEST_P(apiwrb, api) { diff --git a/test/tests-lower.cpp b/test/tests-lower.cpp index b9e16bfda..b9d54509a 100644 --- a/test/tests-lower.cpp +++ b/test/tests-lower.cpp @@ -42,6 +42,8 @@ static TensorVar b("b", vectype, Format()); static TensorVar c("c", vectype, Format()); static TensorVar d("d", vectype, Format()); +static TensorVar fill_10("fillA", vectype, Format(), Literal((double) 10)); + static TensorVar w("w", vectype, dense); static TensorVar A("A", mattype, Format()); @@ -83,7 +85,7 @@ struct TestCase { TensorStorage getResult(TensorVar var, Format format) const { auto dimensions = getDimensions(var); - TensorStorage storage(type(), dimensions, format); + TensorStorage storage(type(), dimensions, format, var.getFill()); // TODO: Get rid of this and lower to use dimensions instead vector modeIndices(format.getOrder()); @@ -100,11 +102,12 @@ struct TestCase { static TensorStorage pack(Format format, const vector& dims, - const vector,double>>& components){ + const vector,double>>& components, + Literal fill){ size_t order = dims.size(); size_t num = components.size(); if (order == 0) { - TensorStorage storage = TensorStorage(type(), {}, format); + TensorStorage storage = TensorStorage(type(), {}, format, fill); Array array = makeArray(type(), 1); *((double*)array.getData()) = components[0].second; storage.setValues(array); @@ -123,18 +126,18 @@ struct TestCase { } values[i] = components[i].second; } - return taco::pack(type(), dims, format, coords, values.data()); + return taco::pack(type(), dims, format, coords, values.data(), fill); } } TensorStorage getArgument(TensorVar var, Format format) const { taco_iassert(contains(inputs, var)) << var; - return pack(format, getDimensions(var), inputs.at(var)); + return pack(format, getDimensions(var), inputs.at(var), var.getFill()); } TensorStorage getExpected(TensorVar var, Format format) const { taco_iassert(contains(expected, var)) << var; - return pack(format, getDimensions(var), expected.at(var)); + return pack(format, getDimensions(var), expected.at(var), var.getFill()); } }; @@ -202,7 +205,7 @@ map formatVars(const std::vector& vars, // Default format is dense in all dimensions format = Format(vector(var.getOrder(), dense)); } - formatted.insert({var, TensorVar(var.getName(), var.getType(), format)}); + formatted.insert({var, TensorVar(var.getName(), var.getType(), format, var.getFill())}); } return formatted; } @@ -1559,6 +1562,23 @@ TEST_STMT(vector_not, } ) +//TEST_STMT(vector_add_fill, +// forall(i, +// fill_10(i) = b(i) * c(i) +// ), +// Values( +// Formats({{fill_10, dense}, {b,sparse}, {c, sparse}}), +// Formats({{fill_10, sparse}, {b,sparse}, {c, sparse}}) +// ), +// { +// TestCase( +// {{b, {{{0}, 2.0}}}, +// {c, {{{0}, 3.0}, {{1}, 6.0}}}}, +// +// {{fill_10, {{{0}, 5.0}, {{1}, 6.0}}}}) +// } +//) + // Test tensorOps Op testOp("testOp", MulAdd(), BC_BD_CD()); @@ -1568,7 +1588,11 @@ TEST_STMT(testOp1, a(i) = testOp(b(i), c(i), d(i)) ), Values( - Formats({{a,sparse}, {b,sparse}, {c,sparse}, {d, sparse}}) + Formats({{a,sparse}, {b,sparse}, {c,sparse}, {d, sparse}}), + Formats({{a,dense}, {b,dense}, {c,dense}, {d, dense}}), + Formats({{a,dense}, {b,sparse}, {c,dense}, {d, sparse}}), + Formats({{a,dense}, {b,sparse}, {c,dense}, {d, dense}}), + Formats({{a,dense}, {b,sparse}, {c,sparse}, {d, sparse}}) ), { TestCase( @@ -1582,9 +1606,7 @@ TEST_STMT(testOp1, Op specialOp("specialOp", GeneralAdd(), BC_BD_CD(), {{{0,1}, MulRegionDef()}, {{0,2}, SubRegionDef()}}); - - -TEST_STMT(testSpecialOp, +TEST_STMT(lowerSpecialRegions1, forall(i, forall(j, A(i, j) = specialOp(B(i, j), C(i, j), D(i, j)) @@ -1592,8 +1614,8 @@ TEST_STMT(testSpecialOp, Values( Formats({{A, Format({dense,dense})}, {B, Format({dense,dense})}, {C, Format({dense,dense})}, {D, Format({dense,dense})}}), -// Formats({{A, Format({dense,sparse})}, {B, Format({dense,sparse})}, {C, Format({dense,sparse})}, -// {D, Format({dense,sparse})}}), + Formats({{A, Format({dense,sparse})}, {B, Format({dense,sparse})}, {C, Format({dense,sparse})}, + {D, Format({dense,sparse})}}), Formats({{A, Format({sparse,sparse})}, {B, Format({sparse,sparse})}, {C, Format({sparse,sparse})}, {D, Format({sparse,sparse})}}) ), @@ -1607,5 +1629,30 @@ TEST_STMT(testSpecialOp, } ) +Op compUnion("compUnion", GeneralAdd(), ComplementUnion()); +TEST_STMT(lowerCompUnion, + forall(i, + forall(j, + A(i, j) = compUnion(B(i, j), C(i, j), D(i, j)) + )), + Values( + Formats({{A, Format({dense,dense})}, {B, Format({dense,dense})}, {C, Format({dense,dense})}, + {D, Format({dense,dense})}}), + Formats({{A, Format({dense,sparse})}, {B, Format({dense,sparse})}, {C, Format({dense,sparse})}, + {D, Format({dense,sparse})}}), + Formats({{A, Format({sparse,sparse})}, {B, Format({sparse,sparse})}, {C, Format({sparse,sparse})}, + {D, Format({sparse,sparse})}}) + ), + { + TestCase( + {{B, {{{0, 1}, 2.0}, {{1, 1}, 3.0}, {{1, 2}, 2.0}, {{4, 3}, 4.0}}}, + {C, {{{0, 1}, 3.0}, {{2, 1}, 3.0}, {{2, 2}, 4.0}, {{4, 3}, 6.0}}}, + {D, {{{1, 2}, 1.0}, {{2, 1}, 4.0}, {{3, 3}, 5.0}, {{4, 3}, 5.0}}}}, + + {{A, {{{0, 1}, 5.0}, {{1, 2}, 3.0}, {{2, 1}, 7.0}, + {{2, 2}, 4.0}, {{3, 3}, 5.0}, {{4, 3}, 15.0}}}}) + } +) + }} diff --git a/test/tests-tensor.cpp b/test/tests-tensor.cpp index 3400e413b..32ad42d5f 100644 --- a/test/tests-tensor.cpp +++ b/test/tests-tensor.cpp @@ -7,6 +7,12 @@ using namespace taco; +template +void testFill(T fillVal) { + Tensor a({2,2}, fillVal); + ASSERT_TRUE(equals(a.getFillValue(), Literal((T) fillVal))); +} + TEST(tensor, double_scalar) { Tensor a(4.2); ASSERT_DOUBLE_EQ(4.2, a.begin()->second); @@ -75,6 +81,25 @@ TEST(tensor, duplicates_scalar) { ASSERT_TRUE(++val == a.end()); } +TEST(tensor, scalar_type_correct) { + Tensor a; + ASSERT_EQ(a.getComponentType(), Int32); +} + +TEST(tensor, non_zero_fill) { + testFill(3); + testFill(34762); + testFill((1 << 10)); + testFill((1 << 30)); + testFill((1ULL << 42)); + testFill((1 << 10)); + testFill(-1); + testFill((1ULL << 42)); + testFill(std::numeric_limits::min()); + testFill(std::numeric_limits::max()); + testFill(std::numeric_limits::infinity()); +} + TEST(tensor, transpose) { TensorData testData = TensorData({5, 3, 2}, { {{0,0,0}, 0.0}, From 008c385072fb02bbf4dbb5a5a7f9574f7e497591 Mon Sep 17 00:00:00 2001 From: Rawn Date: Sun, 12 Apr 2020 14:52:02 -0400 Subject: [PATCH 16/27] Moved tensorOp to its own Cpp file. Fixed bug with lowering not emitting explicit zero checks before bottommost loop. Got Masked BFS optimization working for pull bfs! --- include/taco/index_notation/tensor_operator.h | 61 ++++---------- include/taco/lower/merge_lattice.h | 9 +- src/index_notation/index_notation.cpp | 2 +- src/index_notation/tensor_operator.cpp | 82 +++++++++++++++++++ src/ir/simplify.cpp | 5 ++ src/lower/lowerer_impl.cpp | 57 +++++++------ src/lower/merge_lattice.cpp | 17 +++- test/op_factory.h | 13 +++ test/tests-lower.cpp | 18 ++++ 9 files changed, 191 insertions(+), 73 deletions(-) create mode 100644 src/index_notation/tensor_operator.cpp diff --git a/include/taco/index_notation/tensor_operator.h b/include/taco/index_notation/tensor_operator.h index 9b517cdaf..e70964ef2 100644 --- a/include/taco/index_notation/tensor_operator.h +++ b/include/taco/index_notation/tensor_operator.h @@ -21,47 +21,28 @@ using algebraImpl = TensorOpNode::algebraImpl; public: // Full construction Op(opImpl lowererFunc, algebraImpl algebraFunc, std::vector properties, - std::map, opImpl> specialDefinitions = {}) - : name(util::uniqueName("Op")), lowererFunc(lowererFunc), algebraFunc(algebraFunc), - properties(properties), regionDefinitions(specialDefinitions) { - } + std::map, opImpl> specialDefinitions = {}); Op(std::string name, opImpl lowererFunc, algebraImpl algebraFunc, std::vector properties, - std::map, opImpl> specialDefinitions = {}) - : name(name), lowererFunc(lowererFunc), algebraFunc(algebraFunc), properties(properties), - regionDefinitions(specialDefinitions) { - } + std::map, opImpl> specialDefinitions = {}); // Construct without specifying algebra Op(std::string name, opImpl lowererFunc, std::vector properties, - std::map, opImpl> specialDefinitions = {}) - : Op(name, lowererFunc, nullptr, properties, specialDefinitions) { - } + std::map, opImpl> specialDefinitions = {}); Op(opImpl lowererFunc, std::vector properties, - std::map, opImpl> specialDefinitions = {}) - : Op(util::uniqueName("Op"), lowererFunc, nullptr, properties, specialDefinitions) { - } + std::map, opImpl> specialDefinitions = {}); // Construct without properties Op(std::string name, opImpl lowererFunc, algebraImpl algebraFunc, - std::map, opImpl> specialDefinitions = {}) - : Op(name, lowererFunc, algebraFunc, {}, specialDefinitions) { - } + std::map, opImpl> specialDefinitions = {}); - Op(opImpl lowererFunc, algebraImpl algebraFunc, - std::map, opImpl> specialDefinitions = {}) - : Op(util::uniqueName("Op"), lowererFunc, algebraFunc, {}, specialDefinitions) { - } + Op(opImpl lowererFunc, algebraImpl algebraFunc, std::map, opImpl> specialDefinitions = {}); // Construct without algebra or properties - Op(std::string name, opImpl lowererFunc, std::map, opImpl> specialDefinitions = {}) - : Op(name, lowererFunc, nullptr, specialDefinitions) { - } + Op(std::string name, opImpl lowererFunc, std::map, opImpl> specialDefinitions = {}); - explicit Op(opImpl lowererFunc, std::map, opImpl> specialDefinitions = {}) - : Op(lowererFunc, nullptr, specialDefinitions) { - } + explicit Op(opImpl lowererFunc, std::map, opImpl> specialDefinitions = {}); template TensorOp operator()(IndexExprs&&... exprs) { @@ -82,25 +63,17 @@ using algebraImpl = TensorOpNode::algebraImpl; std::vector properties; std::map, opImpl> regionDefinitions; - IterationAlgebra inferAlgFromProperties(const std::vector& exprs) { - if(properties.empty()) { - return constructDefaultAlgebra(exprs); - } - return {}; - } - - // Constructs an algebra that iterates over the entire space - static IterationAlgebra constructDefaultAlgebra(const std::vector& exprs) { - if(exprs.empty()) return Region(); + IterationAlgebra inferAlgFromProperties(const std::vector& exprs); - IterationAlgebra tensorsRegions(exprs[0]); - for(size_t i = 1; i < exprs.size(); ++i) { - tensorsRegions = Union(tensorsRegions, exprs[i]); - } + // Constructs an algebra for iterating over the operator assuming the annihilator + // of the expression is the input to this function. + // Returns a pair where pair.first is the algebra for iteration and pair.second is + // the number of subregions iterated by the algebra. + std::pair constructAnnihilatorAlg(const std::vector& args, + Annihilator annihilator); - IterationAlgebra background = Complement(tensorsRegions); - return Union(tensorsRegions, background); - } + // Constructs an algebra that iterates over the entire space + static IterationAlgebra constructDefaultAlgebra(const std::vector& exprs); }; } diff --git a/include/taco/lower/merge_lattice.h b/include/taco/lower/merge_lattice.h index 731152ff5..49f0051bb 100644 --- a/include/taco/lower/merge_lattice.h +++ b/include/taco/lower/merge_lattice.h @@ -63,7 +63,7 @@ class MergeLattice { static std::vector removePointsThatLackFullIterators(const std::vector&); /// Returns true if we need to emit checks for explicit zeros in the lattice given. - static bool needExplicitZeroChecks(const MergeLattice& lattice); + bool needExplicitZeroChecks(); /** * Returns the sub-lattice rooted at the given merge point. @@ -101,6 +101,13 @@ class MergeLattice { */ bool exact() const; + /** + * True if any of the mode iterators in the lattice is a leaf iterator. + * + * This method checks both the iterators and locators since they both contain mode iterators. + */ + bool anyModeIteratorIsLeaf() const; + /** * Get a list of iterators that should be omitted at this merge point. */ diff --git a/src/index_notation/index_notation.cpp b/src/index_notation/index_notation.cpp index 99d454ee7..7e1fa7dcf 100644 --- a/src/index_notation/index_notation.cpp +++ b/src/index_notation/index_notation.cpp @@ -2565,7 +2565,7 @@ struct Zero : public IndexNotationRewriterStrict { rewrittenArg = Literal::zero(arg.getDataType()); } - if(equals(annihilatorVal, rewrittenArg)) { + if(annihilatorVal.defined() && equals(annihilatorVal, rewrittenArg)) { expr = IndexExpr(); return; } diff --git a/src/index_notation/tensor_operator.cpp b/src/index_notation/tensor_operator.cpp new file mode 100644 index 000000000..379235963 --- /dev/null +++ b/src/index_notation/tensor_operator.cpp @@ -0,0 +1,82 @@ +#include "taco/index_notation/tensor_operator.h" + +namespace taco { + +// Full construction +Op::Op(opImpl lowererFunc, algebraImpl algebraFunc, std::vector properties, + std::map, opImpl> specialDefinitions) + : name(util::uniqueName("Op")), lowererFunc(lowererFunc), algebraFunc(algebraFunc), + properties(properties), regionDefinitions(specialDefinitions) { +} + +Op::Op(std::string name, opImpl lowererFunc, algebraImpl algebraFunc, std::vector properties, + std::map, opImpl> specialDefinitions) + : name(name), lowererFunc(lowererFunc), algebraFunc(algebraFunc), properties(properties), + regionDefinitions(specialDefinitions) { +} + +// Construct without specifying algebra +Op::Op(std::string name, opImpl lowererFunc, std::vector properties, + std::map, opImpl> specialDefinitions) + : Op(name, lowererFunc, nullptr, properties, specialDefinitions) { +} + +Op::Op(opImpl lowererFunc, std::vector properties, + std::map, opImpl> specialDefinitions) + : Op(util::uniqueName("Op"), lowererFunc, nullptr, properties, specialDefinitions) { +} + +// Construct without properties +Op::Op(std::string name, opImpl lowererFunc, algebraImpl algebraFunc, + std::map, opImpl> specialDefinitions) + : Op(name, lowererFunc, algebraFunc, {}, specialDefinitions) { +} + +Op::Op(opImpl lowererFunc, algebraImpl algebraFunc, + std::map, opImpl> specialDefinitions) : + Op(util::uniqueName("Op"), lowererFunc, algebraFunc, {}, specialDefinitions) { +} + +// Construct without algebra or properties +Op::Op(std::string name, opImpl lowererFunc, std::map, opImpl> specialDefinitions) + : Op(name, lowererFunc, nullptr, specialDefinitions) { +} + +Op::Op(opImpl lowererFunc, std::map, opImpl> specialDefinitions) + : Op(lowererFunc, nullptr, specialDefinitions) { +} + +IterationAlgebra Op::inferAlgFromProperties(const std::vector& exprs) { + if(properties.empty()) { + return constructDefaultAlgebra(exprs); + } + + // Start with smallest regions first. So we first check for annihilator and positional annihilator + if(findProperty(properties).defined()) { + Literal annihilator = findProperty(properties).annihilator(); + + } + + return {}; +} + +// Constructs an algebra that iterates over the entire space +IterationAlgebra Op::constructDefaultAlgebra(const std::vector& exprs) { + if(exprs.empty()) return Region(); + + IterationAlgebra tensorsRegions(exprs[0]); + for(size_t i = 1; i < exprs.size(); ++i) { + tensorsRegions = Union(tensorsRegions, exprs[i]); + } + + IterationAlgebra background = Complement(tensorsRegions); + return Union(tensorsRegions, background); +} + +std::pair Op::constructAnnihilatorAlg(const std::vector &args, + taco::Annihilator annihilator) { + taco_iassert(args.size() > 1) << "Annihilator must be applied to operand with at least two arguments"; + +} + +} \ No newline at end of file diff --git a/src/ir/simplify.cpp b/src/ir/simplify.cpp index 06e27ef12..2614ef6c1 100644 --- a/src/ir/simplify.cpp +++ b/src/ir/simplify.cpp @@ -13,6 +13,11 @@ namespace taco { namespace ir { +template +Literal foldConstant() { + +} + struct ExpressionSimplifier : IRRewriter { using IRRewriter::visit; void visit(const Or* op) { diff --git a/src/lower/lowerer_impl.cpp b/src/lower/lowerer_impl.cpp index 0bdf87a33..fe51c67c4 100644 --- a/src/lower/lowerer_impl.cpp +++ b/src/lower/lowerer_impl.cpp @@ -411,7 +411,6 @@ Stmt LowererImpl::lowerForall(Forall forall) // Emit a loop that iterates over over a single iterator (optimization) if (caseLattice.iterators().size() == 1 && caseLattice.iterators()[0].isUnique()) { MergeLattice loopLattice = caseLattice.getLoopLattice(); -// taco_iassert(caseLattice.points().size() == 1) << "\n Lattice received was " << caseLattice; MergePoint point = loopLattice.points()[0]; Iterator iterator = loopLattice.iterators()[0]; @@ -1191,34 +1190,35 @@ Stmt LowererImpl::resolveCoordinate(std::vector mergers, ir::Expr coor } Stmt LowererImpl::lowerMergeCases(ir::Expr coordinate, IndexVar coordinateVar, IndexStmt stmt, - MergeLattice lattice, + MergeLattice caseLattice, const std::set& reducedAccesses) { vector result; - if (hasNoForAlls(stmt) && MergeLattice::needExplicitZeroChecks(lattice)) { - // In the bottom most loop. - Stmt body = lowerMergeCasesWithExplicitZeroChecks(coordinate, coordinateVar, stmt, lattice, reducedAccesses); + + if (caseLattice.anyModeIteratorIsLeaf() && caseLattice.needExplicitZeroChecks()) { + // Can check value array of some tensor + Stmt body = lowerMergeCasesWithExplicitZeroChecks(coordinate, coordinateVar, stmt, caseLattice, reducedAccesses); result.push_back(body); return Block::make(result); } // Emitting structural cases so unconditionally apply lattice optimizations. - lattice = lattice.getLoopLattice(); + MergeLattice loopLattice = caseLattice.getLoopLattice(); vector appenders; vector inserters; - tie(appenders, inserters) = splitAppenderAndInserters(lattice.results()); + tie(appenders, inserters) = splitAppenderAndInserters(loopLattice.results()); - if (lattice.iterators().size() == 1) { + if (loopLattice.iterators().size() == 1) { // Just one iterator so no conditional - taco_iassert(!lattice.points()[0].isOmitter()); + taco_iassert(!loopLattice.points()[0].isOmitter()); Stmt body = lowerForallBody(coordinate, stmt, {}, inserters, - appenders, lattice, reducedAccesses); + appenders, loopLattice, reducedAccesses); result.push_back(body); } - else if (!lattice.points().empty()) { + else if (!loopLattice.points().empty()) { vector> cases; - for (MergePoint point : lattice.points()) { + for (MergePoint point : loopLattice.points()) { if(point.isOmitter()) { continue; @@ -1226,14 +1226,14 @@ Stmt LowererImpl::lowerMergeCases(ir::Expr coordinate, IndexVar coordinateVar, I // Construct case expression vector coordComparisons = compareToResolvedCoordinate(point.rangers(), coordinate, coordinateVar); - vector omittedRegionIterators = lattice.retrieveRegionIteratorsToOmit(point); + vector omittedRegionIterators = loopLattice.retrieveRegionIteratorsToOmit(point); std::vector neqComparisons = compareToResolvedCoordinate(omittedRegionIterators, coordinate, coordinateVar); append(coordComparisons, neqComparisons); coordComparisons = filter(coordComparisons, [](const Expr& e) { return e.defined(); }); // Construct case body - IndexStmt zeroedStmt = zero(stmt, getExhaustedAccesses(point, lattice)); + IndexStmt zeroedStmt = zero(stmt, getExhaustedAccesses(point, loopLattice)); Stmt body = lowerForallBody(coordinate, zeroedStmt, {}, inserters, appenders, MergeLattice({point}), reducedAccesses); if (coordComparisons.empty()) { @@ -1244,7 +1244,7 @@ Stmt LowererImpl::lowerMergeCases(ir::Expr coordinate, IndexVar coordinateVar, I } cases.push_back({taco::ir::conjunction(coordComparisons), body}); } - result.push_back(Case::make(cases, lattice.exact())); + result.push_back(Case::make(cases, loopLattice.exact())); } return Block::make(result); @@ -1290,16 +1290,24 @@ std::vector LowererImpl::constructInnerLoopCasePreamble(ir::Expr coord for(auto it : tensorIterators) { Access itAccess = iterators.modeAccess(it).getAccess(); itAccesses.push_back(itAccess); - valueComparisons.push_back(constructCheckForAccessZero(itAccess)); + if(it.isLeaf()) { + valueComparisons.push_back(constructCheckForAccessZero(itAccess)); + } else { + valueComparisons.push_back(Expr()); + } } // Construct isNonZero cases for(size_t i = 0; i < coordComparisons.size(); ++i) { Expr nonZeroCase; - if(coordComparisons[i].defined()) { + if(coordComparisons[i].defined() && valueComparisons[i].defined()) { nonZeroCase = conjunction({coordComparisons[i], valueComparisons[i]}); - } else { + } else if (valueComparisons[i].defined()) { nonZeroCase = valueComparisons[i]; + } else if (coordComparisons[i].defined()) { + nonZeroCase = coordComparisons[i]; + } else { + continue; } Expr caseName = Var::make(itAccesses[i].getTensorVar().getName() + "_isNonZero", taco::Bool); Stmt declaration = VarDecl::make(caseName, nonZeroCase); @@ -1351,7 +1359,6 @@ vector LowererImpl::lowerCasesFromMap(map iteratorToCondit } // Construct case body - // TODO - Rawn "Exhaust" locators IndexStmt zeroedStmt = zero(stmt, getExhaustedAccesses(point, lattice)); Stmt body = lowerForallBody(coordinate, zeroedStmt, {}, inserters, appenders, MergeLattice({point}), reducedAccesses); @@ -1437,16 +1444,20 @@ Stmt LowererImpl::lowerForallBody(Expr coordinate, IndexStmt stmt, // Code of loop body statement Stmt body; - if (hasNoForAlls(stmt) && caseLattice.points().size() > 1) { + if (caseLattice.anyModeIteratorIsLeaf() && caseLattice.points().size() > 1) { std::vector stmts; // Need to emit checks based on case lattice vector modeIterators = getModeIterators(combine(caseLattice.iterators(), caseLattice.locators())); std::map caseMap; for(auto it : modeIterators) { - Access itAccess = iterators.modeAccess(it).getAccess(); - Expr accessCase = constructCheckForAccessZero(itAccess); - caseMap.insert({it, accessCase}); + if(it.isLeaf()) { + // Only emit explicit 0 checks for leaf iterators since these are the only iterators can can access tensor + // values array + Access itAccess = iterators.modeAccess(it).getAccess(); + Expr accessCase = constructCheckForAccessZero(itAccess); + caseMap.insert({it, accessCase}); + } } std::vector loweredCases = lowerCasesFromMap(caseMap, coordinate, stmt, caseLattice, reducedAccesses); append(stmts, loweredCases); diff --git a/src/lower/merge_lattice.cpp b/src/lower/merge_lattice.cpp index a238ea8a6..e915dc164 100644 --- a/src/lower/merge_lattice.cpp +++ b/src/lower/merge_lattice.cpp @@ -289,6 +289,7 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA } regionsToKeep.insert(regionToKeep); } + lattice = MergeLattice(lattice.points(), regionsToKeep); } @@ -894,7 +895,7 @@ MergeLattice MergeLattice::make(Forall forall, Iterators iterators, ProvenanceGr MergeLattice lattice = builder.build(forall.getStmt()); // Can't remove points if lattice contains omitters since we lose merge cases during lowering. - if(hasNoForAlls(forall.getStmt()) && needExplicitZeroChecks(lattice)) { + if(lattice.anyModeIteratorIsLeaf() && lattice.needExplicitZeroChecks()) { return lattice; } @@ -940,11 +941,11 @@ MergeLattice::removePointsWithIdenticalIterators(const std::vector& return result; } -bool MergeLattice::needExplicitZeroChecks(const MergeLattice &lattice) { - if(util::any(lattice.points(), [](const MergePoint& mp) {return mp.isOmitter();})) { +bool MergeLattice::needExplicitZeroChecks() { + if(util::any(points(), [](const MergePoint& mp) {return mp.isOmitter();})) { return true; } - return !lattice.getTensorRegionsToKeep().empty(); + return !getTensorRegionsToKeep().empty(); } MergeLattice MergeLattice::subLattice(MergePoint lp) const { @@ -1030,6 +1031,14 @@ bool MergeLattice::exact() const { return true; } +bool MergeLattice::anyModeIteratorIsLeaf() const { + if(points().empty()) { + return false; + } + vector latticeIters = util::combine(iterators(), locators()); + return util::any(latticeIters, [](const Iterator& it) {return it.isModeIterator() && it.isLeaf();}); +} + std::vector MergeLattice::retrieveRegionIteratorsToOmit(const MergePoint &point) const { vector omittedIterators; set pointRegion = point.tensorRegion(); diff --git a/test/op_factory.h b/test/op_factory.h index 2a46cc9e5..95471673d 100644 --- a/test/op_factory.h +++ b/test/op_factory.h @@ -101,6 +101,13 @@ struct unionEdge { } }; +struct BfsMaskAlg { + IterationAlgebra operator()(const std::vector& regions) { + std::vector r = regions; + return Intersect(Intersect(r[0], r[1]), Complement(r[2])); + } +}; + // Lowerers struct MulAdd { ir::Expr operator()(const std::vector &v) { @@ -127,6 +134,12 @@ struct GeneralAdd { } }; +struct BfsLower { + ir::Expr operator()(const std::vector &v) { + return ir::Mul::make(v[0], v[1]); + } +}; + // Special definitions struct MulRegionDef { ir::Expr operator()(const std::vector &v) { diff --git a/test/tests-lower.cpp b/test/tests-lower.cpp index b9d54509a..4d9438895 100644 --- a/test/tests-lower.cpp +++ b/test/tests-lower.cpp @@ -1654,5 +1654,23 @@ TEST_STMT(lowerCompUnion, } ) +Op bfsMaskOp("bfsMask", BfsLower(), BfsMaskAlg()); +TEST_STMT(bfsMask, + forall(i, + forall(j, + a(i) += bfsMaskOp(B(i, j), c(j), c(i)) + )), + Values( + Formats({{a, Format({dense})}, {B, Format({dense,sparse})}, {c, Format({dense})}}) + ), + { + TestCase( + {{B, {{{0, 1}, 1.0}, {{1, 1}, 1.0}, {{1, 2}, 1.0}, {{4, 3}, 1.0}}}, + {c, {{{1}, 1.0}}}}, + + {{a, {{{0}, 1.0}}}}) + } +) + }} From 68d3431efd88a7bef726ffb3047499e63eaa08f1 Mon Sep 17 00:00:00 2001 From: Rawn Date: Fri, 8 May 2020 04:52:19 -0400 Subject: [PATCH 17/27] Allowed code for any reductions. Fixed bugs. Added a trivial fill value inferer that only uses properties. --- include/taco/index_notation/index_notation.h | 11 +- .../index_notation/index_notation_nodes.h | 5 + include/taco/index_notation/properties.h | 18 +- .../taco/index_notation/property_pointers.h | 8 + include/taco/index_notation/tensor_operator.h | 6 +- include/taco/ir/ir.h | 20 +- include/taco/ir/ir_printer.h | 1 + include/taco/ir/ir_rewriter.h | 1 + include/taco/ir/ir_visitor.h | 3 + include/taco/lower/lowerer_impl.h | 4 + src/index_notation/index_notation.cpp | 205 ++++++++++++++++-- src/index_notation/index_notation_nodes.cpp | 2 +- src/index_notation/index_notation_printer.cpp | 8 + src/index_notation/iteration_algebra.cpp | 2 +- src/index_notation/properties.cpp | 76 +++++++ src/index_notation/property_pointers.cpp | 21 ++ src/index_notation/tensor_operator.cpp | 73 ++++++- src/ir/ir.cpp | 9 + src/ir/ir_generators.cpp | 2 +- src/ir/ir_generators.h | 2 +- src/ir/ir_printer.cpp | 7 +- src/ir/ir_rewriter.cpp | 4 + src/ir/ir_visitor.cpp | 3 + src/ir/simplify.cpp | 5 - src/lower/lowerer_impl.cpp | 62 ++++-- src/lower/merge_lattice.cpp | 1 + test/op_factory.h | 29 ++- test/tests-index_notation.cpp | 14 ++ test/tests-lower.cpp | 72 +++++- 29 files changed, 606 insertions(+), 68 deletions(-) diff --git a/include/taco/index_notation/index_notation.h b/include/taco/index_notation/index_notation.h index 8c62c4b2f..4f4251db1 100644 --- a/include/taco/index_notation/index_notation.h +++ b/include/taco/index_notation/index_notation.h @@ -978,7 +978,12 @@ std::vector getResults(IndexStmt stmt); /// Returns the input tensors to the index statement, in the order they appear. std::vector getArguments(IndexStmt stmt); -/// Returns the temporaries in the index statement, in the order they appear. +/// Returns true iff all of the loops over free variables come before all of the loops over +/// reduction variables. Therefore, this returns true if the reduction controlled by the loops +/// does not a scatter. +bool allForFreeLoopsBeforeAllReductionLoops(IndexStmt stmt); + + /// Returns the temporaries in the index statement, in the order they appear. std::vector getTemporaries(IndexStmt stmt); /// Returns the tensors in the index statement. @@ -1013,6 +1018,10 @@ IndexExpr zero(IndexExpr, const std::set& zeroed); /// zero and then propagating and removing zeroes. IndexStmt zero(IndexStmt, const std::set& zeroed); +/// Infers the fill value of the input expression by applying properties if possible. If unable +/// to successfully infer the fill value of the result, returns the empty IndexExpr +IndexExpr inferFill(IndexExpr); + /// Returns true if there are no forall nodes in the indexStmt. Used to check /// if the last loop is being lowered. bool hasNoForAlls(IndexStmt); diff --git a/include/taco/index_notation/index_notation_nodes.h b/include/taco/index_notation/index_notation_nodes.h index 7e51364b0..906292a21 100644 --- a/include/taco/index_notation/index_notation_nodes.h +++ b/include/taco/index_notation/index_notation_nodes.h @@ -208,6 +208,11 @@ struct TensorOpNode : public IndexExprNode { static Datatype inferReturnType(opImpl f, const std::vector& inputs) { std::function getExprs = [](IndexExpr arg) { return ir::Var::make("t", arg.getDataType()); }; std::vector exprs = util::map(inputs, getExprs); + + if(exprs.empty()) { + return taco::Datatype(); + } + return f(exprs).type(); } diff --git a/include/taco/index_notation/properties.h b/include/taco/index_notation/properties.h index d8c8348f7..1504dd8d1 100644 --- a/include/taco/index_notation/properties.h +++ b/include/taco/index_notation/properties.h @@ -6,11 +6,13 @@ namespace taco { +class IndexExpr; + /// A class containing properties about an operation class Property : public util::IntrusivePtr { public: Property(); - Property(const PropertyPtr* p); + explicit Property(const PropertyPtr* p); bool equals(const Property& p) const; std::ostream& print(std::ostream&) const; @@ -22,9 +24,12 @@ std::ostream& operator<<(std::ostream&, const Property&); class Annihilator : public Property { public: explicit Annihilator(Literal); - Annihilator(const PropertyPtr*); + Annihilator(Literal, std::vector&); + explicit Annihilator(const PropertyPtr*); const Literal& annihilator() const; + const std::vector& positions() const; + IndexExpr annihilates(const std::vector&) const; typedef AnnihilatorPtr Ptr; }; @@ -33,9 +38,12 @@ class Annihilator : public Property { class Identity : public Property { public: explicit Identity(Literal); - Identity(const PropertyPtr*); + Identity(Literal, std::vector&); + explicit Identity(const PropertyPtr*); const Literal& identity() const; + const std::vector& positions() const; + IndexExpr simplify(const std::vector&) const; typedef IdentityPtr Ptr; }; @@ -44,7 +52,7 @@ class Identity : public Property { class Associative : public Property { public: Associative(); - Associative(const PropertyPtr*); + explicit Associative(const PropertyPtr*); typedef AssociativePtr Ptr; }; @@ -54,7 +62,7 @@ class Commutative : public Property { public: Commutative(); explicit Commutative(const std::vector&); - Commutative(const PropertyPtr*); + explicit Commutative(const PropertyPtr*); const std::vector& ordering() const; diff --git a/include/taco/index_notation/property_pointers.h b/include/taco/index_notation/property_pointers.h index 250730fc8..81bee0492 100644 --- a/include/taco/index_notation/property_pointers.h +++ b/include/taco/index_notation/property_pointers.h @@ -31,7 +31,11 @@ struct PropertyPtr : public util::Manageable, struct AnnihilatorPtr : public PropertyPtr { AnnihilatorPtr(); AnnihilatorPtr(Literal); + AnnihilatorPtr(Literal, std::vector&); + const Literal& annihilator() const; + const std::vector& positions() const; + virtual std::ostream& print(std::ostream& os) const; virtual bool equals(const PropertyPtr* p) const; @@ -44,7 +48,11 @@ struct IdentityPtr : public PropertyPtr { public: IdentityPtr(); IdentityPtr(Literal); + IdentityPtr(Literal, std::vector&); + const Literal& identity() const; + const std::vector& positions() const; + virtual std::ostream& print(std::ostream& os) const; virtual bool equals(const PropertyPtr* p) const; diff --git a/include/taco/index_notation/tensor_operator.h b/include/taco/index_notation/tensor_operator.h index e70964ef2..d3aabe87e 100644 --- a/include/taco/index_notation/tensor_operator.h +++ b/include/taco/index_notation/tensor_operator.h @@ -69,8 +69,10 @@ using algebraImpl = TensorOpNode::algebraImpl; // of the expression is the input to this function. // Returns a pair where pair.first is the algebra for iteration and pair.second is // the number of subregions iterated by the algebra. - std::pair constructAnnihilatorAlg(const std::vector& args, - Annihilator annihilator); + IterationAlgebra constructAnnihilatorAlg(const std::vector& args, Annihilator annihilator); + + IterationAlgebra constructIdentityAlg(const std::vector& args, Identity identity); + // Constructs an algebra that iterates over the entire space static IterationAlgebra constructDefaultAlgebra(const std::vector& exprs); diff --git a/include/taco/ir/ir.h b/include/taco/ir/ir.h index 12b6937cc..31f268f88 100644 --- a/include/taco/ir/ir.h +++ b/include/taco/ir/ir.h @@ -65,7 +65,8 @@ enum class IRNodeType { BlankLine, Print, GetProperty, - Break + Break, + Continue }; enum class TensorProperty { @@ -234,6 +235,16 @@ struct Literal : public ExprNode { return *static_cast(value.get()); } + Expr promote(Datatype dt) const { + taco_iassert(max_type(dt, type) == dt); + if(type == dt) { + return Literal::make(getTypedVal(), dt); + } + + + + } + TypedComponentVal getTypedVal() const { return *value; } @@ -725,6 +736,13 @@ struct Break : public StmtNode { static const IRNodeType _type_info = IRNodeType::Break; }; +/** Continues to the next iteration of the current loop */ +struct Continue : public StmtNode { + static Stmt make(); + + static const IRNodeType _type_info = IRNodeType::Continue; +}; + /** A print statement. * Takes in a printf-style format string and Exprs to pass * for the values. diff --git a/include/taco/ir/ir_printer.h b/include/taco/ir/ir_printer.h index 759d21ad3..46587c039 100644 --- a/include/taco/ir/ir_printer.h +++ b/include/taco/ir/ir_printer.h @@ -66,6 +66,7 @@ class IRPrinter : public IRVisitorStrict { virtual void visit(const Comment*); virtual void visit(const BlankLine*); virtual void visit(const Break*); + virtual void visit(const Continue*); virtual void visit(const Print*); virtual void visit(const GetProperty*); diff --git a/include/taco/ir/ir_rewriter.h b/include/taco/ir/ir_rewriter.h index efb9eaf89..1c6156cc4 100644 --- a/include/taco/ir/ir_rewriter.h +++ b/include/taco/ir/ir_rewriter.h @@ -66,6 +66,7 @@ class IRRewriter : public IRVisitorStrict { virtual void visit(const Comment* op); virtual void visit(const BlankLine* op); virtual void visit(const Break* op); + virtual void visit(const Continue* op); virtual void visit(const Print* op); virtual void visit(const GetProperty* op); }; diff --git a/include/taco/ir/ir_visitor.h b/include/taco/ir/ir_visitor.h index f6331035b..693331825 100644 --- a/include/taco/ir/ir_visitor.h +++ b/include/taco/ir/ir_visitor.h @@ -46,6 +46,7 @@ struct Free; struct Comment; struct BlankLine; struct Break; +struct Continue; struct Print; struct GetProperty; @@ -96,6 +97,7 @@ class IRVisitorStrict { virtual void visit(const Comment*) = 0; virtual void visit(const BlankLine*) = 0; virtual void visit(const Break*) = 0; + virtual void visit(const Continue*) = 0; virtual void visit(const Print*) = 0; virtual void visit(const GetProperty*) = 0; }; @@ -149,6 +151,7 @@ class IRVisitor : public IRVisitorStrict { virtual void visit(const Comment* op); virtual void visit(const BlankLine* op); virtual void visit(const Break* op); + virtual void visit(const Continue* op); virtual void visit(const Print* op); virtual void visit(const GetProperty* op); }; diff --git a/include/taco/lower/lowerer_impl.h b/include/taco/lower/lowerer_impl.h index 47768f4cf..e78e9f8bf 100644 --- a/include/taco/lower/lowerer_impl.h +++ b/include/taco/lower/lowerer_impl.h @@ -410,9 +410,13 @@ class LowererImpl : public util::Uncopyable { /// For now, we only check mode iterators. std::vector getModeIterators(const std::vector&); + /// Emit early exit + ir::Stmt emitEarlyExit(ir::Expr reductionExpr, std::vector&); + private: bool assemble; bool compute; + bool loopOrderAllowsShortCircuit = false; int markAssignsAtomicDepth = 0; ParallelUnit atomicParallelUnit; diff --git a/src/index_notation/index_notation.cpp b/src/index_notation/index_notation.cpp index 7e1fa7dcf..c97c36c6a 100644 --- a/src/index_notation/index_notation.cpp +++ b/src/index_notation/index_notation.cpp @@ -2138,14 +2138,16 @@ IndexStmt makeConcreteNotation(IndexStmt stmt) { // that's not a reduction vector topLevelReductions; IndexExpr rhs = node->rhs; + IndexExpr reductionOp; while (isa(rhs)) { Reduction reduction = to(rhs); topLevelReductions.push_back(reduction.getVar()); rhs = reduction.getExpr(); + reductionOp = reduction.getOp(); } if (rhs != node->rhs) { - stmt = Assignment(node->lhs, rhs, Add()); + stmt = Assignment(node->lhs, rhs, reductionOp); for (auto& i : util::reverse(topLevelReductions)) { stmt = forall(i, stmt); } @@ -2236,6 +2238,45 @@ vector getArguments(IndexStmt stmt) { return result; } +bool allForFreeLoopsBeforeAllReductionLoops(IndexStmt stmt) { + + struct LoopOrderGetter : IndexNotationVisitor { + + std::vector loopOrder; + std::set freeVars; + + using IndexNotationVisitor::visit; + void visit(const AssignmentNode *op) { + for(const auto& var : op->lhs.getIndexVars()) { + freeVars.insert(var); + } + IndexNotationVisitor::visit(op); + } + + void visit(const ForallNode *op) { + loopOrder.push_back(op->indexVar); + IndexNotationVisitor::visit(op); + } + }; + + + LoopOrderGetter getter; + getter.visit(stmt); + + bool seenReductionVar = false; + for(auto& var : getter.loopOrder) { + if(util::contains(getter.freeVars, var)) { + if(seenReductionVar) { + // A reduction loop came before a loop over a free var + return false; + } + } else { + seenReductionVar = true; + } + } + return true; +} + std::vector getTemporaries(IndexStmt stmt) { vector temporaries; bool firstAssignment = true; @@ -2550,7 +2591,6 @@ struct Zero : public IndexNotationRewriterStrict { bool rewritten = false; Annihilator annihilator = findProperty(op->properties); - Literal annihilatorVal = annihilator.defined()? annihilator.annihilator(): Literal(); // TODO: Check exhausted default against result default for(int argIdx = 0; argIdx < (int) op->args.size(); ++argIdx) { @@ -2565,34 +2605,27 @@ struct Zero : public IndexNotationRewriterStrict { rewrittenArg = Literal::zero(arg.getDataType()); } - if(annihilatorVal.defined() && equals(annihilatorVal, rewrittenArg)) { - expr = IndexExpr(); - return; - } - args.push_back(rewrittenArg); if (arg != rewrittenArg) { rewritten = true; } } - Identity identity = findProperty(op->properties); - Literal identityVal = identity.defined()? identity.identity(): Literal(); - - // If only one term is not the identity, replace expr with just that term - size_t nonIdentityTerms = 0; - IndexExpr nonIdentityTerm; - for(const auto& arg : args) { - if(!equals(identityVal, arg)) { - nonIdentityTerm = arg; - ++nonIdentityTerms; + if(annihilator.defined()) { + IndexExpr e = annihilator.annihilates(args); + if(e.defined()) { + expr = e; + return; } - if(nonIdentityTerms > 1) break; } - if(nonIdentityTerms == 1) { - expr = nonIdentityTerm; - return; + Identity identity = findProperty(op->properties); + if(identity.defined()) { + IndexExpr e = identity.simplify(args); + if(e.defined()) { + expr = e; + return; + } } if (rewritten) { @@ -2737,6 +2770,136 @@ IndexStmt zero(IndexStmt stmt, const std::set& zeroed) { return Zero(zeroed).rewrite(stmt); } +struct fillValueInferrer : IndexExprRewriterStrict { + public: + virtual void visit(const AccessNode* op) { + expr = op->tensorVar.getFill(); + }; + + virtual void visit(const LiteralNode* op) { + expr = op; + } + + virtual void visit(const NegNode* op) { + IndexExpr a = rewrite(op->a); + if(equals(a, Literal::zero(a.getDataType()))) { + expr = a; + return; + } + expr = IndexExpr(); + } + + virtual void visit(const AddNode* op) { + IndexExpr a = rewrite(op->a); + IndexExpr b = rewrite(op->b); + + if(equals(a, Literal::zero(a.getDataType())) && isa(b)) { + expr = b; + return; + } + + if(equals(b, Literal::zero(b.getDataType())) && isa(a)) { + expr = a; + return; + } + + expr = IndexExpr(); + } + + virtual void visit(const SubNode* op) { + IndexExpr a = rewrite(op->a); + IndexExpr b = rewrite(op->b); + + if(equals(b, Literal::zero(b.getDataType())) && isa(a)) { + expr = a; + return; + } + + expr = IndexExpr(); + } + + virtual void visit(const MulNode* op) { + IndexExpr a = rewrite(op->a); + IndexExpr b = rewrite(op->b); + + if(equals(a, Literal::zero(a.getDataType()))) { + expr = a; + return; + } + + if(equals(b, Literal::zero(b.getDataType()))) { + expr = b; + return; + } + + expr = IndexExpr(); + } + + virtual void visit(const DivNode* op) { + IndexExpr a = rewrite(op->a); + IndexExpr b = rewrite(op->b); + + if(equals(a, Literal::zero(a.getDataType()))) { + expr = a; + return; + } + + expr = IndexExpr(); + } + + virtual void visit(const SqrtNode* op) { + IndexExpr a = rewrite(op->a); + if(equals(a, Literal::zero(a.getDataType()))) { + expr = a; + return; + } + expr = IndexExpr(); + } + + virtual void visit(const CastNode* op) { + expr = IndexExpr(); + } + + virtual void visit(const TensorOpNode* op) { + Annihilator annihilator = findProperty(op->properties); + if(annihilator.defined()) { + IndexExpr e = annihilator.annihilates(op->args); + if(e.defined()) { + expr = e; + return; + } + } + + Identity identity = findProperty(op->properties); + if(identity.defined()) { + IndexExpr e = identity.simplify(op->args); + if(e.defined()) { + expr = e; + return; + } + } + + expr = IndexExpr(); + } + + virtual void visit(const CallIntrinsicNode*) { + + } + + virtual void visit(const ReductionNode*) { + expr = IndexExpr(); + } + + virtual void visit(const IndexVarNode*) { + expr = IndexExpr(); + } + }; + + +IndexExpr inferFill(IndexExpr expr) { + return fillValueInferrer().rewrite(expr); +} + bool hasNoForAlls(IndexStmt stmt) { bool noForAlls = true; diff --git a/src/index_notation/index_notation_nodes.cpp b/src/index_notation/index_notation_nodes.cpp index a28e1f62d..105d521c9 100644 --- a/src/index_notation/index_notation_nodes.cpp +++ b/src/index_notation/index_notation_nodes.cpp @@ -54,7 +54,7 @@ TensorOpNode::TensorOpNode(std::string name, const std::vector& args, // class ReductionNode ReductionNode::ReductionNode(IndexExpr op, IndexVar var, IndexExpr a) : IndexExprNode(a.getDataType()), op(op), var(var), a(a) { - taco_iassert(isa(op.ptr)); + taco_iassert(isa(op.ptr) || isa(op.ptr)); } IndexVarNode::IndexVarNode(const std::string& name, const Datatype& type) diff --git a/src/index_notation/index_notation_printer.cpp b/src/index_notation/index_notation_printer.cpp index 7cdd8edcc..a95278e1e 100644 --- a/src/index_notation/index_notation_printer.cpp +++ b/src/index_notation/index_notation_printer.cpp @@ -189,6 +189,10 @@ void IndexNotationPrinter::visit(const ReductionNode* op) { void visit(const BinaryExprNode* node) { reductionName = "reduction(" + node->getOperatorString() + ")"; } + + void visit(const TensorOpNode* node) { + reductionName = node->name + "Reduce"; + } }; parentPrecedence = Precedence::REDUCTION; os << ReductionName().get(op->op) << "(" << op->var << ", "; @@ -208,6 +212,10 @@ void IndexNotationPrinter::visit(const AssignmentNode* op) { void visit(const BinaryExprNode* node) { operatorName = node->getOperatorString(); } + + void visit(const TensorOpNode* node) { + operatorName = node->name; + } }; op->lhs.accept(this); diff --git a/src/index_notation/iteration_algebra.cpp b/src/index_notation/iteration_algebra.cpp index 7f56bb7cc..ebab586c9 100644 --- a/src/index_notation/iteration_algebra.cpp +++ b/src/index_notation/iteration_algebra.cpp @@ -6,7 +6,7 @@ namespace taco { // Iteration Algebra Definitions -IterationAlgebra::IterationAlgebra() : IterationAlgebra(new RegionNode(nullptr)) {} +IterationAlgebra::IterationAlgebra() : IterationAlgebra(nullptr) {} IterationAlgebra::IterationAlgebra(const IterationAlgebraNode* n) : util::IntrusivePtr(n) {} IterationAlgebra::IterationAlgebra(IndexExpr expr) : IterationAlgebra(new RegionNode(expr)) {} diff --git a/src/index_notation/properties.cpp b/src/index_notation/properties.cpp index a305fbaa4..aca434443 100644 --- a/src/index_notation/properties.cpp +++ b/src/index_notation/properties.cpp @@ -47,6 +47,9 @@ template<> Annihilator to(const Property& p) { Annihilator::Annihilator(Literal annihilator) : Annihilator(new AnnihilatorPtr(annihilator)) { } +Annihilator::Annihilator(Literal annihilator, std::vector &p) : Annihilator(new AnnihilatorPtr(annihilator, p)) { +} + Annihilator::Annihilator(const PropertyPtr* p) : Property(p) { } @@ -55,6 +58,32 @@ const Literal& Annihilator::annihilator() const { return getPtr(*this)->annihilator(); } +const std::vector & Annihilator::positions() const { + taco_iassert(defined()); + return getPtr(*this)->positions(); +} + +IndexExpr Annihilator::annihilates(const std::vector& exprs) const { + taco_iassert(defined()); + Literal a = annihilator(); + std::vector pos = positions(); + if (pos.empty()) { + for(int i = 0; i < (int)exprs.size(); ++i) { + pos.push_back(i); + } + } + + for(const auto& idx : pos) { + taco_uassert(idx < (int)exprs.size()) << "Not enough args in expression"; + if(::taco::equals(exprs[idx], a)) { + return a; + } + } + + // We could not simplify. + return IndexExpr(); +} + // Identity class definitions template<> bool isa(const Property& p) { return isa(p.ptr); @@ -71,11 +100,58 @@ Identity::Identity(Literal identity) : Identity(new IdentityPtr(identity)) { Identity::Identity(const PropertyPtr* p) : Property(p) { } +Identity::Identity(Literal identity, std::vector& positions) : Identity(new IdentityPtr(identity, positions)) { +} + const Literal& Identity::identity() const { taco_iassert(defined()); return getPtr(*this)->identity(); } +const std::vector& Identity::positions() const { + taco_iassert(defined()); + return getPtr(*this)->positions(); +} + +IndexExpr Identity::simplify(const std::vector& exprs) const { + // If only one term is not the identity, replace expr with just that term. + // If all terms are identity, replace with identity. + Literal identityVal = identity(); + size_t nonIdentityTermsChecked = 0; + IndexExpr nonIdentityTerm; + + std::vector pos = positions(); + if (pos.empty()) { + for(int i = 0; i < (int)exprs.size(); ++i) { + pos.push_back(i); + } + } + + + for(const auto& idx : pos) { + if(!::taco::equals(identityVal, exprs[idx])) { + nonIdentityTerm = exprs[idx]; + ++nonIdentityTermsChecked; + } + if(nonIdentityTermsChecked > 1) { + return IndexExpr(); + } + } + + size_t identityTermsChecked = pos.size() - nonIdentityTermsChecked; + if(nonIdentityTermsChecked == 1 && identityTermsChecked == (exprs.size() - 1)) { + // If we checked all exprs and all are the identity except one return that term + return nonIdentityTerm; + } + + if(identityTermsChecked == exprs.size()) { + // If we checked every expression and + return identityVal; + } + + return IndexExpr(); +} + // Associative class definitions template<> bool isa(const Property& p) { return isa(p.ptr); diff --git a/src/index_notation/property_pointers.cpp b/src/index_notation/property_pointers.cpp index 1086f1d2d..84daa2be2 100644 --- a/src/index_notation/property_pointers.cpp +++ b/src/index_notation/property_pointers.cpp @@ -6,10 +6,12 @@ namespace taco { struct AnnihilatorPtr::Content { Literal annihilator; + std::vector positions; }; struct IdentityPtr::Content { Literal identity; + std::vector positions; }; // Property pointer definitions @@ -34,12 +36,22 @@ AnnihilatorPtr::AnnihilatorPtr() : PropertyPtr(), content(nullptr) { AnnihilatorPtr::AnnihilatorPtr(Literal annihilator) : PropertyPtr(), content(new Content) { content->annihilator = annihilator; + content->positions = std::vector(); +} + +AnnihilatorPtr::AnnihilatorPtr(Literal annihilator, std::vector& pos) : PropertyPtr(), content(new Content) { + content->annihilator = annihilator; + content->positions = pos; } const Literal& AnnihilatorPtr::annihilator() const { return content->annihilator; } +const std::vector & AnnihilatorPtr::positions() const { + return content->positions; +} + std::ostream& AnnihilatorPtr::print(std::ostream& os) const { os << "Annihilator("; if (annihilator().defined()) { @@ -65,10 +77,19 @@ IdentityPtr::IdentityPtr(Literal identity) : PropertyPtr(), content(new Content) content->identity = identity; } +IdentityPtr::IdentityPtr(Literal identity, std::vector &p) : PropertyPtr(), content(new Content) { + content->identity; + content->positions = p; +} + const Literal& IdentityPtr::identity() const { return content->identity; } +const std::vector & IdentityPtr::positions() const { + return content->positions; +} + std::ostream& IdentityPtr::print(std::ostream& os) const { os << "Identity("; if (identity().defined()) { diff --git a/src/index_notation/tensor_operator.cpp b/src/index_notation/tensor_operator.cpp index 379235963..ec9edf8ae 100644 --- a/src/index_notation/tensor_operator.cpp +++ b/src/index_notation/tensor_operator.cpp @@ -53,11 +53,24 @@ IterationAlgebra Op::inferAlgFromProperties(const std::vector& exprs) // Start with smallest regions first. So we first check for annihilator and positional annihilator if(findProperty(properties).defined()) { - Literal annihilator = findProperty(properties).annihilator(); + Annihilator annihilator = findProperty(properties); + IterationAlgebra alg = constructAnnihilatorAlg(exprs, annihilator); + if(alg.defined()) { + return alg; + } + } + + // Idempotence here ... + if(findProperty(properties).defined()) { + Identity identity = findProperty(properties); + IterationAlgebra alg = constructIdentityAlg(exprs, identity); + if(alg.defined()) { + return alg; + } } - return {}; + return constructDefaultAlgebra(exprs); } // Constructs an algebra that iterates over the entire space @@ -73,10 +86,60 @@ IterationAlgebra Op::constructDefaultAlgebra(const std::vector& exprs return Union(tensorsRegions, background); } -std::pair Op::constructAnnihilatorAlg(const std::vector &args, - taco::Annihilator annihilator) { - taco_iassert(args.size() > 1) << "Annihilator must be applied to operand with at least two arguments"; +IterationAlgebra Op::constructAnnihilatorAlg(const std::vector &args, taco::Annihilator annihilator) { + if(args.size () < 2) { + return IterationAlgebra(); + } + + Literal annVal = annihilator.annihilator(); + std::vector toIntersect; + + if(annihilator.positions().empty()) { + for(IndexExpr arg : args) { + if(equals(inferFill(arg), annVal)) { + toIntersect.push_back(arg); + } + } + } else { + for(size_t idx : annihilator.positions()) { + if(equals(inferFill(args[idx]), annVal)) { + toIntersect.push_back(args[idx]); + } + } + } + + if(toIntersect.empty()) { + return IterationAlgebra(); + } + + IterationAlgebra alg = toIntersect[0]; + for(size_t i = 1; i < toIntersect.size(); ++i) { + alg = Intersect(alg, toIntersect[i]); + } + return alg; +} + +IterationAlgebra Op::constructIdentityAlg(const std::vector &args, taco::Identity identity) { + if(args.size() < 2) { + return IterationAlgebra(); + } + + Literal idntyVal = identity.identity(); + + if(identity.positions().empty()) { + for(IndexExpr arg : args) { + if(!equals(inferFill(arg), idntyVal)) { + return IterationAlgebra(); + } + } + } + + IterationAlgebra alg(args[0]); + for(size_t i = 1; i < args.size(); ++i) { + alg = Union(alg, args[i]); + } + return alg; } } \ No newline at end of file diff --git a/src/ir/ir.cpp b/src/ir/ir.cpp index 5eb6c9af3..5feaa3d34 100644 --- a/src/ir/ir.cpp +++ b/src/ir/ir.cpp @@ -183,6 +183,8 @@ std::complex Literal::getComplexValue() const { return 0.0; } + + template bool compare(const Literal* literal, double val) { return literal->getValue() == static_cast(val); } @@ -789,6 +791,11 @@ Stmt Break::make() { return new Break; } +// Continue +Stmt Continue::make() { + return new Continue; +} + // Print Stmt Print::make(std::string fmt, std::vector params) { Print* pr = new Print; @@ -950,6 +957,8 @@ template<> void StmtNode::accept(IRVisitorStrict *v) const { v->visit((const BlankLine*)this); } template<> void StmtNode::accept(IRVisitorStrict *v) const { v->visit((const Break*)this); } +template<> void StmtNode::accept(IRVisitorStrict *v) + const { v->visit((const Continue*)this); } template<> void StmtNode::accept(IRVisitorStrict *v) const { v->visit((const Print*)this); } template<> void ExprNode::accept(IRVisitorStrict *v) diff --git a/src/ir/ir_generators.cpp b/src/ir/ir_generators.cpp index 23fe5edf6..11f3a1fee 100644 --- a/src/ir/ir_generators.cpp +++ b/src/ir/ir_generators.cpp @@ -14,7 +14,7 @@ Stmt compoundStore(Expr a, Expr i, Expr val, bool use_atomics, ParallelUnit atom return Store::make(a, i, add, use_atomics, atomic_parallel_unit); } -Stmt compoundAssign(Expr a, Expr val, bool use_atomics, ParallelUnit atomic_parallel_unit) { +Stmt addAssign(Expr a, Expr val, bool use_atomics, ParallelUnit atomic_parallel_unit) { Expr add = (val.type().getKind() == Datatype::Bool) ? Or::make(a, val) : Add::make(a, val); return Assign::make(a, add, use_atomics, atomic_parallel_unit); diff --git a/src/ir/ir_generators.h b/src/ir/ir_generators.h index 5b7015ba5..7f5af7157 100644 --- a/src/ir/ir_generators.h +++ b/src/ir/ir_generators.h @@ -14,7 +14,7 @@ class Stmt; Stmt compoundStore(Expr a, Expr i, Expr val, bool use_atomics=false, ParallelUnit atomic_parallel_unit=ParallelUnit::NotParallel); /// Generate `a += val;` -Stmt compoundAssign(Expr a, Expr val, bool use_atomics=false, ParallelUnit atomic_parallel_unit=ParallelUnit::NotParallel); +Stmt addAssign(Expr a, Expr val, bool use_atomics=false, ParallelUnit atomic_parallel_unit=ParallelUnit::NotParallel); /// Generate `exprs_0 && ... && exprs_n` Expr conjunction(std::vector exprs); diff --git a/src/ir/ir_printer.cpp b/src/ir/ir_printer.cpp index 6338430bd..ef4eae06d 100644 --- a/src/ir/ir_printer.cpp +++ b/src/ir/ir_printer.cpp @@ -561,7 +561,12 @@ void IRPrinter::visit(const BlankLine*) { void IRPrinter::visit(const Break*) { doIndent(); - stream << "continue;" << endl; // TODO: add continue statement + stream << "break;" << endl; +} + +void IRPrinter::visit(const Continue*) { + doIndent(); + stream << "continue;" << endl; } void IRPrinter::visit(const Print* op) { diff --git a/src/ir/ir_rewriter.cpp b/src/ir/ir_rewriter.cpp index 1a1c91f23..7eaa373e1 100644 --- a/src/ir/ir_rewriter.cpp +++ b/src/ir/ir_rewriter.cpp @@ -451,6 +451,10 @@ void IRRewriter::visit(const Break* op) { stmt = op; } +void IRRewriter::visit(const Continue* op) { + stmt = op; +} + void IRRewriter::visit(const Print* op) { vector params; bool paramsSame = true; diff --git a/src/ir/ir_visitor.cpp b/src/ir/ir_visitor.cpp index 8a1baf6cf..b2b3d9ccb 100644 --- a/src/ir/ir_visitor.cpp +++ b/src/ir/ir_visitor.cpp @@ -231,6 +231,9 @@ void IRVisitor::visit(const BlankLine*) { void IRVisitor::visit(const Break*) { } +void IRVisitor::visit(const Continue*) { +} + void IRVisitor::visit(const Print* op) { for (auto e: op->params) e.accept(this); diff --git a/src/ir/simplify.cpp b/src/ir/simplify.cpp index 2614ef6c1..06e27ef12 100644 --- a/src/ir/simplify.cpp +++ b/src/ir/simplify.cpp @@ -13,11 +13,6 @@ namespace taco { namespace ir { -template -Literal foldConstant() { - -} - struct ExpressionSimplifier : IRRewriter { using IRRewriter::visit; void visit(const Or* op) { diff --git a/src/lower/lowerer_impl.cpp b/src/lower/lowerer_impl.cpp index fe51c67c4..f8ceb3236 100644 --- a/src/lower/lowerer_impl.cpp +++ b/src/lower/lowerer_impl.cpp @@ -2,6 +2,7 @@ #include "taco/lower/lowerer_impl.h" #include "taco/index_notation/index_notation.h" +#include "taco/index_notation/tensor_operator.h" #include "taco/index_notation/index_notation_nodes.h" #include "taco/index_notation/index_notation_visitor.h" #include "taco/ir/ir.h" @@ -113,6 +114,7 @@ LowererImpl::lower(IndexStmt stmt, string name, bool assemble, bool compute) this->compute = compute; definedIndexVarsOrdered = {}; definedIndexVars = {}; + loopOrderAllowsShortCircuit = allForFreeLoopsBeforeAllReductionLoops(stmt); // Create result and parameter variables vector results = getResults(stmt); @@ -271,8 +273,20 @@ Stmt LowererImpl::lowerAssignment(Assignment assignment) // TODO: we don't need to mark all assigns/stores just when scattering/reducing } else { - taco_iassert(isa(assignment.getOperator())); - return compoundAssign(var, rhs, markAssignsAtomicDepth > 0 && !util::contains(whereTemps, result), atomicParallelUnit); + if (isa(assignment.getOperator())) { + return addAssign(var, rhs, markAssignsAtomicDepth > 0 && !util::contains(whereTemps, result), + atomicParallelUnit); + } + taco_iassert(isa(assignment.getOperator())); + + TensorOp op = to(assignment.getOperator()); + Expr assignOp = op.getFunc()({var, rhs}); + Stmt assign = Assign::make(var, assignOp, markAssignsAtomicDepth > 0 && !util::contains(whereTemps, result), + atomicParallelUnit); + + std::vector properties = op.getProperties(); + assign = Block::make(assign, emitEarlyExit(var, properties)); + return assign; } } // Assignments to tensor variables (non-scalar). @@ -285,7 +299,21 @@ Stmt LowererImpl::lowerAssignment(Assignment assignment) computeStmt = Store::make(values, loc, rhs, markAssignsAtomicDepth > 0, atomicParallelUnit); } else { - computeStmt = compoundStore(values, loc, rhs, markAssignsAtomicDepth > 0, atomicParallelUnit); + if (isa(assignment.getOperator())) { + computeStmt = compoundStore(values, loc, rhs, markAssignsAtomicDepth > 0, atomicParallelUnit); + } else { + + taco_iassert(isa(assignment.getOperator())); + + TensorOp op = to(assignment.getOperator()); + Expr assignOp = op.getFunc()({Load::make(values, loc), rhs}); + computeStmt = Store::make(values, loc, assignOp, + markAssignsAtomicDepth > 0 && !util::contains(whereTemps, result), + atomicParallelUnit); + + std::vector properties = op.getProperties(); + computeStmt = Block::make(computeStmt, emitEarlyExit(Load::make(values, loc), properties)); + } } taco_iassert(computeStmt.defined()); return computeStmt; @@ -368,7 +396,7 @@ Stmt LowererImpl::lowerForall(Forall forall) if (isa(ir::simplify(iterBounds[0])) && ir::simplify(iterBounds[0]).as()->equalsScalar(0)) { guardCondition = maxGuard; } - ir::Stmt guard = ir::IfThenElse::make(guardCondition, ir::Break::make()); + ir::Stmt guard = Block::make(IfThenElse::make(minGuard, Continue::make()), IfThenElse::make(maxGuard, Break::make())); recoverySteps.push_back(guard); } @@ -378,7 +406,7 @@ Stmt LowererImpl::lowerForall(Forall forall) // place underived guard if (emitUnderivedGuards && underivedBounds.count(varToRecover) && !provGraph.hasPosDescendant(varToRecover)) { Stmt guard = IfThenElse::make(Gte::make(indexVarToExprMap[varToRecover], underivedBounds[varToRecover][1]), - Break::make()); + Continue::make()); recoverySteps.push_back(guard); } } @@ -982,7 +1010,7 @@ Stmt LowererImpl::lowerForallFusedPosition(Forall forall, Iterator iterator, else { locateCoordVar = ir::Assign::make(coordVarUnknown, posVarUnknown); } - Stmt loopBody = ir::Block::make(compoundAssign(posVarUnknown, 1), locateCoordVar, loopToTrackUnderiveds); + Stmt loopBody = ir::Block::make(addAssign(posVarUnknown, 1), locateCoordVar, loopToTrackUnderiveds); if (posIteratorLevel.getParent().hasPosIter()) { // TODO: if level is unique or not loopToTrackUnderiveds = IfThenElse::make(loopcond, loopBody); } @@ -1733,6 +1761,14 @@ bool LowererImpl::generateComputeCode() const { return this->compute; } +Stmt LowererImpl::emitEarlyExit(Expr reductionExpr, std::vector& properties) { + if (loopOrderAllowsShortCircuit && findProperty(properties).defined()) { + Literal annh = findProperty(properties).annihilator(); + Expr isAnnihilator = ir::Eq::make(reductionExpr, lower(annh)); + return IfThenElse::make(isAnnihilator, Block::make(Break::make())); + } + return Stmt(); +} Expr LowererImpl::getTensorVar(TensorVar tensorVar) const { taco_iassert(util::contains(this->tensorVars, tensorVar)) << tensorVar; @@ -2226,7 +2262,7 @@ Stmt LowererImpl::reduceDuplicateCoordinates(Expr coordinate, // need a separate segend variable. segendVar = iterVar; if (alwaysReduce) { - result.push_back(compoundAssign(segendVar, 1)); + result.push_back(addAssign(segendVar, 1)); } } else { Expr segendInit = alwaysReduce ? ir::Add::make(iterVar, 1) : iterVar; @@ -2236,9 +2272,9 @@ Stmt LowererImpl::reduceDuplicateCoordinates(Expr coordinate, vector dedupStmts; if (reducedVal.defined()) { Expr partialVal = Load::make(tensorVals, segendVar); - dedupStmts.push_back(compoundAssign(reducedVal, partialVal)); + dedupStmts.push_back(addAssign(reducedVal, partialVal)); } - dedupStmts.push_back(compoundAssign(segendVar, 1)); + dedupStmts.push_back(addAssign(segendVar, 1)); Stmt dedupBody = Block::make(dedupStmts); ModeFunction posAccess = iterator.posAccess(segendVar, @@ -2389,7 +2425,7 @@ Stmt LowererImpl::codeToIncIteratorVars(Expr coordinate, IndexVar coordinateVar, Expr ivar = iterators[0].getIteratorVar(); if (iterators[0].isUnique()) { - return compoundAssign(ivar, 1); + return addAssign(ivar, 1); } // If iterator is over bottommost coordinate hierarchy level with @@ -2416,7 +2452,7 @@ Stmt LowererImpl::codeToIncIteratorVars(Expr coordinate, IndexVar coordinateVar, : ir::Cast::make(Eq::make(iterator.getCoordVar(), coordinate), ivar.type()); - result.push_back(compoundAssign(ivar, increment)); + result.push_back(addAssign(ivar, increment)); } else if (!iterator.isLeaf()) { result.push_back(Assign::make(ivar, iterator.getSegendVar())); } @@ -2428,7 +2464,7 @@ Stmt LowererImpl::codeToIncIteratorVars(Expr coordinate, IndexVar coordinateVar, bool isMerger = find(mergers.begin(), mergers.end(), iterator) != mergers.end(); if (isMerger) { Expr ivar = iterator.getIteratorVar(); - result.push_back(compoundAssign(ivar, 1)); + result.push_back(addAssign(ivar, 1)); } else { result.push_back(codeToLoadCoordinatesFromPosIterators(iterators, false)); @@ -2514,7 +2550,7 @@ Stmt LowererImpl::appendCoordinate(vector appenders, Expr coord) { } if (generateAssembleCode() || isLastAppender(appender)) { - appendStmts.push_back(compoundAssign(pos, 1)); + appendStmts.push_back(addAssign(pos, 1)); Stmt appendCode = Block::make(appendStmts); if (appenderChild.defined() && appenderChild.hasAppend()) { diff --git a/src/lower/merge_lattice.cpp b/src/lower/merge_lattice.cpp index e915dc164..ffe2fed07 100644 --- a/src/lower/merge_lattice.cpp +++ b/src/lower/merge_lattice.cpp @@ -267,6 +267,7 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA } void visit(const TensorOpNode* expr) { + taco_iassert(expr->iterAlg.defined()) << "Algebra must be defined" << endl; lattice = build(expr->iterAlg); // Now we need to store regions that should be kept when applying optimizations. diff --git a/test/op_factory.h b/test/op_factory.h index 95471673d..832c2d53e 100644 --- a/test/op_factory.h +++ b/test/op_factory.h @@ -60,9 +60,9 @@ struct ComplementIntersect { struct IntersectGenDeMorgan { IterationAlgebra operator()(const std::vector& regions) { - IterationAlgebra unions; - for(const auto& region : regions) { - unions = Union(unions, Complement(region)); + IterationAlgebra unions = Complement(regions[0]); + for(size_t i = 1; i < regions.size(); ++i) { + unions = Union(unions, Complement(regions[i])); } return Complement(unions); } @@ -104,7 +104,7 @@ struct unionEdge { struct BfsMaskAlg { IterationAlgebra operator()(const std::vector& regions) { std::vector r = regions; - return Intersect(Intersect(r[0], r[1]), Complement(r[2])); + return Intersect(r[0], Complement(r[1])); } }; @@ -134,9 +134,28 @@ struct GeneralAdd { } }; +struct MinImpl { + ir::Expr operator()(const std::vector &v) { + taco_iassert(v.size() >= 2) << "Min operator needs at least two operands"; + return ir::Min::make(v[0], v[1]); + } +}; + struct BfsLower { ir::Expr operator()(const std::vector &v) { - return ir::Mul::make(v[0], v[1]); + return v[0]; + } +}; + +struct OrImpl { + ir::Expr operator()(const std::vector &v) { + return ir::Or::make(v[0], v[1]); + } +}; + +struct AndImpl { + ir::Expr operator()(const std::vector &v) { + return ir::And::make(v[0], v[1]); } }; diff --git a/test/tests-index_notation.cpp b/test/tests-index_notation.cpp index 66c1effe1..5a219c487 100644 --- a/test/tests-index_notation.cpp +++ b/test/tests-index_notation.cpp @@ -1,5 +1,7 @@ #include "test.h" +#include "taco/index_notation/tensor_operator.h" #include "taco/index_notation/index_notation.h" +#include "op_factory.h" using namespace taco; @@ -202,3 +204,15 @@ INSTANTIATE_TEST_CASE_P(separate_reductions, concrete, tk += c(k))), forall(j, tj += b(j)))))); + + +Op scOr("Or", OrImpl(), {Annihilator((bool)1), Identity((bool)0)}); +Op scAnd("And", AndImpl(), {Annihilator((bool)0), Identity((bool)0)}); + +Op bfsMaskOp("bfsMask", BfsLower(), BfsMaskAlg()); +INSTANTIATE_TEST_CASE_P(tensorOpConcrete, concrete, + Values(ConcreteTest(a(i) = Reduction(scOr(), j, bfsMaskOp(scAnd(B(i, j), c(j)), c(i))), + forall(i, + forall(j, + Assignment(a(i), bfsMaskOp(scAnd(B(i, j), c(j)), c(i)), scOr()) + ))))); \ No newline at end of file diff --git a/test/tests-lower.cpp b/test/tests-lower.cpp index 4d9438895..30ced544d 100644 --- a/test/tests-lower.cpp +++ b/test/tests-lower.cpp @@ -121,8 +121,9 @@ struct TestCase { vector values(num); for (size_t i=0; i < components.size(); ++i) { auto& coordinates = components[i].first; + std::vector ordering = format.getModeOrdering(); for (size_t j=0; j < coordinates.size(); ++j) { - coords[j][i] = coordinates[j]; + coords[j][i] = coordinates[ordering[j]]; } values[i] = components[i].second; } @@ -1654,21 +1655,82 @@ TEST_STMT(lowerCompUnion, } ) +Op scOr("Or", OrImpl(), {Annihilator((double)1), Identity(Literal((double)0))}); +Op scAnd("And", AndImpl(), {Annihilator((double)0), Identity((double)1)}); Op bfsMaskOp("bfsMask", BfsLower(), BfsMaskAlg()); -TEST_STMT(bfsMask, + +TEST_STMT(BoolRing, forall(i, forall(j, - a(i) += bfsMaskOp(B(i, j), c(j), c(i)) + Assignment(a(i), bfsMaskOp(scAnd(B(i, j), c(j)), c(i)), scOr()) )), Values( Formats({{a, Format({dense})}, {B, Format({dense,sparse})}, {c, Format({dense})}}) ), { TestCase( - {{B, {{{0, 1}, 1.0}, {{1, 1}, 1.0}, {{1, 2}, 1.0}, {{4, 3}, 1.0}}}, + {{B, {{{0, 1}, 1.0}, {{1, 1}, 1.0}, {{1, 2}, 1.0}, {{3, 1}, 1.0}, {{4, 3}, 1.0}}}, + {c, {{{1}, 1.0}}}}, + + {{a, {{{0}, 1.0}, {{3}, 1.0}}}}) + } +) + +TEST_STMT(BoolRing2, + forall(j, + forall(i, + Assignment(a(i), bfsMaskOp(scAnd(B(i, j), c(j)), c(i)), scOr()) + )), + Values( + Formats({{a, Format({dense})}, {B, Format({dense,sparse}, {1, 0})}, {c, Format({sparse})}}) + ), + { + TestCase( + {{B, {{{0, 1}, 1.0}, {{1, 1}, 1.0}, {{3, 1}, 1.0}, {{1, 2}, 1.0}, {{4, 3}, 1.0}}}, + {c, {{{1}, 1.0}}}}, + + {{a, {{{0}, 1.0}, {{3}, 1.0}}}}) + } +) + +TEST_STMT(BoolRing3, + forall(j, + forall(i, + Assignment(a(i), scAnd(B(i, j), c(j)), scOr()) + )), + Values( + Formats({{a, Format({dense})}, {B, Format({dense,sparse}, {1, 0})}, {c, Format({sparse})}}) + ), + { + TestCase( + {{B, {{{0, 1}, 1.0}, {{1, 1}, 1.0}, {{3, 1}, 1.0}, {{1, 2}, 1.0}, {{4, 3}, 1.0}}}, {c, {{{1}, 1.0}}}}, - {{a, {{{0}, 1.0}}}}) + {{a, {{{0}, 1.0}, {{1}, 1.0}, {{3}, 1.0}}}}) + } +) + +Op customMin("Min", MinImpl(), {Identity(std::numeric_limits::infinity()) }); +Op Plus("Plus", GeneralAdd(), {Annihilator(std::numeric_limits::infinity())}); + +static TensorVar a_inf("a", vectype, Format(), std::numeric_limits::infinity()); +static TensorVar c_inf("c", vectype, Format(), std::numeric_limits::infinity()); +static TensorVar B_inf("B", mattype, Format(), std::numeric_limits::infinity()); + +TEST_STMT(MinPlusRing, + forall(i, + forall(j, + Assignment(a_inf(i), Plus(B_inf(i, j), c_inf(j)), customMin()) + )), + Values( + Formats({{a_inf, Format({dense})}, {B_inf, Format({dense,sparse})}, {c_inf, Format({dense})}}) + ), + { + TestCase( + {{B_inf, {{{0, 1}, 3.0}, {{1, 1}, 4.0}, {{1, 2}, 2.0}, {{3,2}, 5.0}, {{4, 3}, 1.0}}}, + {c_inf, {{{1}, 1.0}, {{2}, 6.0}}}}, + + {{a_inf, {{{0}, 4.0}, {{1}, 5.0}, {{3}, 11.0}}}}) } ) From dbcbe5d68fff3c831cb630c782032fc7088c407d Mon Sep 17 00:00:00 2001 From: Rawn Date: Fri, 8 May 2020 12:54:20 -0400 Subject: [PATCH 18/27] Reverted lower test framework packing --- test/tests-lower.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/tests-lower.cpp b/test/tests-lower.cpp index 30ced544d..ac1e8f866 100644 --- a/test/tests-lower.cpp +++ b/test/tests-lower.cpp @@ -123,7 +123,7 @@ struct TestCase { auto& coordinates = components[i].first; std::vector ordering = format.getModeOrdering(); for (size_t j=0; j < coordinates.size(); ++j) { - coords[j][i] = coordinates[ordering[j]]; + coords[j][i] = coordinates[j]; } values[i] = components[i].second; } @@ -1686,7 +1686,7 @@ TEST_STMT(BoolRing2, ), { TestCase( - {{B, {{{0, 1}, 1.0}, {{1, 1}, 1.0}, {{3, 1}, 1.0}, {{1, 2}, 1.0}, {{4, 3}, 1.0}}}, + {{B, {{{1, 0}, 1.0}, {{1, 1}, 1.0}, {{1, 3}, 1.0}, {{2, 1}, 1.0}, {{3, 4}, 1.0}}}, {c, {{{1}, 1.0}}}}, {{a, {{{0}, 1.0}, {{3}, 1.0}}}}) @@ -1703,7 +1703,7 @@ TEST_STMT(BoolRing3, ), { TestCase( - {{B, {{{0, 1}, 1.0}, {{1, 1}, 1.0}, {{3, 1}, 1.0}, {{1, 2}, 1.0}, {{4, 3}, 1.0}}}, + {{B, {{{1, 0}, 1.0}, {{1, 1}, 1.0}, {{1, 3}, 1.0}, {{2, 1}, 1.0}, {{3, 4}, 1.0}}}, {c, {{{1}, 1.0}}}}, {{a, {{{0}, 1.0}, {{1}, 1.0}, {{3}, 1.0}}}}) From 2a1bd2bd86d44e05d17729aef1329ed77f914f68 Mon Sep 17 00:00:00 2001 From: Rawn Date: Fri, 8 May 2020 22:59:36 -0400 Subject: [PATCH 19/27] Fixed bug with double assembly. Added test for boolean semi-ring and sparsifying a dense result using tensor ops --- src/lower/lowerer_impl.cpp | 20 +++++++---- src/lower/merge_lattice.cpp | 1 + test/op_factory.h | 1 + test/tests-lower.cpp | 18 ++++++++++ test/tests-scheduling-eval.cpp | 61 ++++++++++++++++++++++++++++++++++ 5 files changed, 95 insertions(+), 6 deletions(-) diff --git a/src/lower/lowerer_impl.cpp b/src/lower/lowerer_impl.cpp index f8ceb3236..99e918d56 100644 --- a/src/lower/lowerer_impl.cpp +++ b/src/lower/lowerer_impl.cpp @@ -1457,7 +1457,6 @@ Stmt LowererImpl::lowerForallBody(Expr coordinate, IndexStmt stmt, vector appenders, MergeLattice caseLattice, const set& reducedAccesses) { - Stmt initVals = resizeAndInitValues(appenders, reducedAccesses); // Inserter positions Stmt declInserterPosVars = declLocatePosVars(inserters); @@ -1470,9 +1469,10 @@ Stmt LowererImpl::lowerForallBody(Expr coordinate, IndexStmt stmt, captureNextLocatePos = false; } - // Code of loop body statement - Stmt body; if (caseLattice.anyModeIteratorIsLeaf() && caseLattice.points().size() > 1) { + + // Code of loop body statement + // Explicit zero checks needed std::vector stmts; // Need to emit checks based on case lattice @@ -1487,13 +1487,21 @@ Stmt LowererImpl::lowerForallBody(Expr coordinate, IndexStmt stmt, caseMap.insert({it, accessCase}); } } + + // This will lower the body for each case to actually compute. Therefore, we don't need to resize assembly arrays std::vector loweredCases = lowerCasesFromMap(caseMap, coordinate, stmt, caseLattice, reducedAccesses); + append(stmts, loweredCases); - body = Block::make(stmts); - } else { - body = lower(stmt); + Stmt body = Block::make(stmts); + + return Block::make(declInserterPosVars, declLocatorPosVars, body); } + Stmt initVals = resizeAndInitValues(appenders, reducedAccesses); + + // Code of loop body statement + Stmt body = lower(stmt); + // Code to append coordinates Stmt appendCoords = appendCoordinate(appenders, coordinate); diff --git a/src/lower/merge_lattice.cpp b/src/lower/merge_lattice.cpp index ffe2fed07..138569db9 100644 --- a/src/lower/merge_lattice.cpp +++ b/src/lower/merge_lattice.cpp @@ -102,6 +102,7 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA void visit(const ComplementNode* node) { taco_iassert(isa(node->a)) << "Demorgan's rule must be applied before lowering."; lattice = build(node->a); + vector points = flipPoints(lattice.points()); // Otherwise, all tensors are sparse diff --git a/test/op_factory.h b/test/op_factory.h index 832c2d53e..c9aa6266c 100644 --- a/test/op_factory.h +++ b/test/op_factory.h @@ -1,6 +1,7 @@ #ifndef TACO_OP_FACTORY_H #define TACO_OP_FACTORY_H +#include "taco/index_notation/tensor_operator.h" #include "taco/index_notation/index_notation.h" #include "taco/ir/ir.h" diff --git a/test/tests-lower.cpp b/test/tests-lower.cpp index ac1e8f866..4332ffc5c 100644 --- a/test/tests-lower.cpp +++ b/test/tests-lower.cpp @@ -1734,5 +1734,23 @@ TEST_STMT(MinPlusRing, } ) +Op sparsify("sparsify", identityFunc(), [](const std::vector& v) {return Union(v[0], Complement(v[1]));}); + +TEST_STMT(SparsifyTest, + forall(i, + a(i) = sparsify(b(i), i) + ), + Values( + Formats({{a, Format({sparse})}, {b, Format({dense})} }) + ), + { + TestCase( + {{b, {{{0}, 3.0}, {{2}, 4.0}, {{4}, 2.0}}}}, + + {{a, {{{0}, 3.0}, {{2}, 4.0}, {{4}, 2.0}}}}) + } +) + + }} diff --git a/test/tests-scheduling-eval.cpp b/test/tests-scheduling-eval.cpp index aca737ab7..c69871b29 100644 --- a/test/tests-scheduling-eval.cpp +++ b/test/tests-scheduling-eval.cpp @@ -8,6 +8,7 @@ #include "taco/index_notation/index_notation.h" #include "codegen/codegen.h" #include "taco/lower/lower.h" +#include "op_factory.h" using namespace taco; const IndexVar i("i"), j("j"), k("k"), l("l"), m("m"), n("n"); @@ -1718,4 +1719,64 @@ TEST(generate_figures, DISABLED_cpu) { source_file << source.str(); source_file.close(); } +} + +TEST(scheduling_eval, scheduledBoolRing) { + if (should_use_CUDA_codegen()) { + return; + } + int NUM_I = 102; + int NUM_J = 102; + float SPARSITY = .3; + + Tensor A("A", {NUM_I, NUM_J}, CSR); + Tensor x("x", {NUM_J}, {Dense}); + Tensor y("y", {NUM_I}, {Dense}); + + uint8_t one = 1; + uint8_t zero = 0; + + Op scOr("Or", OrImpl(), {Annihilator(one), Identity(zero)}); + Op scAnd("And", AndImpl(), {Annihilator(one), Identity(zero)}); + Op bfsMaskOp("bfsMask", BfsLower(), BfsMaskAlg()); + + srand(120); + for (int i = 0; i < NUM_I; i++) { + for (int j = 0; j < NUM_J; j++) { + float rand_float = (float)rand()/(float)(RAND_MAX); + if (rand_float < SPARSITY) { + A.insert({i, j}, one); + } + } + } + + for (int j = 0; j < NUM_J; j++) { + float rand_float = (float)rand()/(float)(RAND_MAX); + if (rand_float < SPARSITY) { + x.insert({j}, one); + } else { + x.insert({j}, zero); + } + } + + x.pack(); + A.pack(); + + y(i) = Reduction(scOr(), j, bfsMaskOp(scAnd(A(i, j), x(j)), x(i))); + + IndexStmt stmt = y.getAssignment().concretize(); + stmt = scheduleSpMVCPU(stmt); + + //printToFile("spmv_cpu", stmt); + + y.compile(stmt); + y.assemble(); + y.compute(); + + Tensor expected("expected", {NUM_I}, {Dense}); + expected(i) = Reduction(scOr(), j, bfsMaskOp(scAnd(A(i, j), x(j)), x(i))); + expected.compile(); + expected.assemble(); + expected.compute(); + ASSERT_TENSOR_EQ(expected, y); } \ No newline at end of file From d42b1f358616af37e4dbfdd526b2611bb4226f2e Mon Sep 17 00:00:00 2001 From: Rawn Date: Sat, 9 May 2020 21:19:14 -0400 Subject: [PATCH 20/27] Added test for push --- test/tests-scheduling-eval.cpp | 66 ++++++++++++++++++++++++++++++++-- 1 file changed, 64 insertions(+), 2 deletions(-) diff --git a/test/tests-scheduling-eval.cpp b/test/tests-scheduling-eval.cpp index c69871b29..51cd10392 100644 --- a/test/tests-scheduling-eval.cpp +++ b/test/tests-scheduling-eval.cpp @@ -1721,7 +1721,7 @@ TEST(generate_figures, DISABLED_cpu) { } } -TEST(scheduling_eval, scheduledBoolRing) { +TEST(scheduling_eval, bfsPullScheduled) { if (should_use_CUDA_codegen()) { return; } @@ -1737,7 +1737,7 @@ TEST(scheduling_eval, scheduledBoolRing) { uint8_t zero = 0; Op scOr("Or", OrImpl(), {Annihilator(one), Identity(zero)}); - Op scAnd("And", AndImpl(), {Annihilator(one), Identity(zero)}); + Op scAnd("And", AndImpl(), {Annihilator(zero), Identity(one)}); Op bfsMaskOp("bfsMask", BfsLower(), BfsMaskAlg()); srand(120); @@ -1779,4 +1779,66 @@ TEST(scheduling_eval, scheduledBoolRing) { expected.assemble(); expected.compute(); ASSERT_TENSOR_EQ(expected, y); +} + +TEST(scheduling_eval, bfsPushScheduled) { + if (should_use_CUDA_codegen()) { + return; + } + int NUM_I = 102; + int NUM_J = 102; + float SPARSITY = .3; + + Tensor A("A", {NUM_I, NUM_J}, CSC); + Tensor x("x", {NUM_J}, {compressed}); + Tensor y("y", {NUM_I}, {Dense}); + + uint8_t one = 1; + uint8_t zero = 0; + + Op scOr("Or", OrImpl(), {Annihilator(one), Identity(zero)}); + Op scAnd("And", AndImpl(), {Annihilator(zero), Identity(one)}); + Op bfsMaskOp("bfsMask", BfsLower(), BfsMaskAlg()); + + srand(120); + for (int i = 0; i < NUM_I; i++) { + for (int j = 0; j < NUM_J; j++) { + float rand_float = (float)rand()/(float)(RAND_MAX); + if (rand_float < SPARSITY) { + A.insert({i, j}, one); + } + } + } + + for (int j = 0; j < NUM_J; j++) { + float rand_float = (float)rand()/(float)(RAND_MAX); + if (rand_float < SPARSITY) { + x.insert({j}, one); + } else { + x.insert({j}, zero); + } + } + + x.pack(); + A.pack(); + + IndexExpr computeExpr = Reduction(scOr(), j, scAnd(A(i, j), x(j))); + + y(i) = computeExpr; + + IndexStmt stmt = y.getAssignment().concretize(); + stmt = stmt.reorder(i, j).parallelize(j, ParallelUnit::CPUThread, OutputRaceStrategy::Atomics); + + //printToFile("spmv_cpu", stmt); + + y.compile(stmt); + y.assemble(); + y.compute(); + + Tensor expected("expected", {NUM_I}, {Dense}); + expected(i) = computeExpr; + expected.compile(); + expected.assemble(); + expected.compute(); + ASSERT_TENSOR_EQ(expected, y); } \ No newline at end of file From cfde49af25b72ed70c2b961f6521bde11b13e4d8 Mon Sep 17 00:00:00 2001 From: Rawn Date: Sun, 31 Jan 2021 13:08:33 -0800 Subject: [PATCH 21/27] Rename some variables --- include/taco/index_notation/index_notation.h | 12 +-- .../index_notation/index_notation_nodes.h | 30 +++--- .../index_notation/index_notation_printer.h | 2 +- .../index_notation/index_notation_rewriter.h | 4 +- .../index_notation/index_notation_visitor.h | 8 +- include/taco/index_notation/tensor_operator.h | 49 +++++----- include/taco/lower/lowerer_impl.h | 2 +- src/index_notation/index_notation.cpp | 52 +++++------ src/index_notation/index_notation_nodes.cpp | 22 ++--- src/index_notation/index_notation_printer.cpp | 6 +- .../index_notation_rewriter.cpp | 8 +- src/index_notation/index_notation_visitor.cpp | 2 +- src/index_notation/tensor_operator.cpp | 50 +++++----- src/lower/expr_tools.cpp | 2 +- src/lower/lowerer_impl.cpp | 12 +-- src/lower/merge_lattice.cpp | 2 +- test/tests-index_notation.cpp | 6 +- test/tests-lower.cpp | 35 +++++-- test/tests-merge_lattice.cpp | 18 ++-- test/tests-scheduling-eval.cpp | 91 +++++-------------- 20 files changed, 191 insertions(+), 222 deletions(-) diff --git a/include/taco/index_notation/index_notation.h b/include/taco/index_notation/index_notation.h index 4f4251db1..a0151bd36 100644 --- a/include/taco/index_notation/index_notation.h +++ b/include/taco/index_notation/index_notation.h @@ -50,7 +50,7 @@ struct SubNode; struct MulNode; struct DivNode; struct CastNode; -struct TensorOpNode; +struct CallNode; struct CallIntrinsicNode; struct ReductionNode; struct IndexVarNode; @@ -406,11 +406,11 @@ class Cast : public IndexExpr { }; /// A call to an operator -class TensorOp: public IndexExpr { +class Call: public IndexExpr { public: - TensorOp() = default; - TensorOp(const TensorOpNode*); - TensorOp(const TensorOpNode*, std::string name); + Call() = default; + Call(const CallNode*); + Call(const CallNode*, std::string name); const std::vector& getArgs() const; const std::function&)> getFunc() const; @@ -420,7 +420,7 @@ class TensorOp: public IndexExpr { const std::map, std::function&)>> getDefs() const; const std::vector& getDefinedArgs() const; - typedef TensorOpNode Node; + typedef CallNode Node; private: std::string name; diff --git a/include/taco/index_notation/index_notation_nodes.h b/include/taco/index_notation/index_notation_nodes.h index 906292a21..a730ac32f 100644 --- a/include/taco/index_notation/index_notation_nodes.h +++ b/include/taco/index_notation/index_notation_nodes.h @@ -175,20 +175,20 @@ struct CallIntrinsicNode : public IndexExprNode { std::vector args; }; -struct TensorOpNode : public IndexExprNode { - typedef std::function&)> opImpl; - typedef std::function&)> algebraImpl; +struct CallNode : public IndexExprNode { + typedef std::function&)> OpImpl; + typedef std::function&)> AlgebraImpl; - TensorOpNode(std::string name, const std::vector& args, opImpl lowerFunc, - const IterationAlgebra& iterAlg, - const std::vector& properties, - const std::map, opImpl>& regionDefinitions, - const std::vector& definedRegions); + CallNode(std::string name, const std::vector& args, OpImpl lowerFunc, + const IterationAlgebra& iterAlg, + const std::vector& properties, + const std::map, OpImpl>& regionDefinitions, + const std::vector& definedRegions); - TensorOpNode(std::string name, const std::vector& args, opImpl lowerFunc, - const IterationAlgebra& iterAlg, - const std::vector& properties, - const std::map, opImpl>& regionDefinitions); + CallNode(std::string name, const std::vector& args, OpImpl lowerFunc, + const IterationAlgebra& iterAlg, + const std::vector& properties, + const std::map, OpImpl>& regionDefinitions); void accept(IndexExprVisitorStrict* v) const { v->visit(this); @@ -196,16 +196,16 @@ struct TensorOpNode : public IndexExprNode { std::string name; std::vector args; - opImpl defaultLowerFunc; + OpImpl defaultLowerFunc; IterationAlgebra iterAlg; std::vector properties; - std::map, opImpl> regionDefinitions; + std::map, OpImpl> regionDefinitions; // Needed to track which inputs have been exhausted so the lowerer can know which lower func to use std::vector definedRegions; private: - static Datatype inferReturnType(opImpl f, const std::vector& inputs) { + static Datatype inferReturnType(OpImpl f, const std::vector& inputs) { std::function getExprs = [](IndexExpr arg) { return ir::Var::make("t", arg.getDataType()); }; std::vector exprs = util::map(inputs, getExprs); diff --git a/include/taco/index_notation/index_notation_printer.h b/include/taco/index_notation/index_notation_printer.h index 3aceaaa20..7c32f25b1 100644 --- a/include/taco/index_notation/index_notation_printer.h +++ b/include/taco/index_notation/index_notation_printer.h @@ -25,7 +25,7 @@ class IndexNotationPrinter : public IndexNotationVisitorStrict { void visit(const MulNode*); void visit(const DivNode*); void visit(const CastNode*); - void visit(const TensorOpNode*); + void visit(const CallNode*); void visit(const CallIntrinsicNode*); void visit(const ReductionNode*); void visit(const IndexVarNode*); diff --git a/include/taco/index_notation/index_notation_rewriter.h b/include/taco/index_notation/index_notation_rewriter.h index 729349a58..caaa64773 100644 --- a/include/taco/index_notation/index_notation_rewriter.h +++ b/include/taco/index_notation/index_notation_rewriter.h @@ -32,7 +32,7 @@ class IndexExprRewriterStrict : public IndexExprVisitorStrict { virtual void visit(const MulNode* op) = 0; virtual void visit(const DivNode* op) = 0; virtual void visit(const CastNode* op) = 0; - virtual void visit(const TensorOpNode* op) = 0; + virtual void visit(const CallNode* op) = 0; virtual void visit(const CallIntrinsicNode* op) = 0; virtual void visit(const ReductionNode* op) = 0; virtual void visit(const IndexVarNode* op) = 0; @@ -95,7 +95,7 @@ class IndexNotationRewriter : public IndexNotationRewriterStrict { virtual void visit(const MulNode* op); virtual void visit(const DivNode* op); virtual void visit(const CastNode* op); - virtual void visit(const TensorOpNode* op); + virtual void visit(const CallNode* op); virtual void visit(const CallIntrinsicNode* op); virtual void visit(const ReductionNode* op); virtual void visit(const IndexVarNode* op); diff --git a/include/taco/index_notation/index_notation_visitor.h b/include/taco/index_notation/index_notation_visitor.h index 2cbde9489..adc0e6787 100644 --- a/include/taco/index_notation/index_notation_visitor.h +++ b/include/taco/index_notation/index_notation_visitor.h @@ -20,7 +20,7 @@ struct MulNode; struct DivNode; struct SqrtNode; struct CastNode; -struct TensorOpNode; +struct CallNode; struct CallIntrinsicNode; struct UnaryExprNode; struct BinaryExprNode; @@ -52,7 +52,7 @@ class IndexExprVisitorStrict { virtual void visit(const DivNode*) = 0; virtual void visit(const SqrtNode*) = 0; virtual void visit(const CastNode*) = 0; - virtual void visit(const TensorOpNode*) = 0; + virtual void visit(const CallNode*) = 0; virtual void visit(const CallIntrinsicNode*) = 0; virtual void visit(const ReductionNode*) = 0; virtual void visit(const IndexVarNode*) = 0; @@ -100,7 +100,7 @@ class IndexNotationVisitor : public IndexNotationVisitorStrict { virtual void visit(const DivNode* node); virtual void visit(const SqrtNode* node); virtual void visit(const CastNode* node); - virtual void visit(const TensorOpNode* node); + virtual void visit(const CallNode* node); virtual void visit(const CallIntrinsicNode* node); virtual void visit(const UnaryExprNode* node); virtual void visit(const BinaryExprNode* node); @@ -170,7 +170,7 @@ class Matcher : public IndexNotationVisitor { RULE(MulNode) RULE(DivNode) RULE(CastNode) - RULE(TensorOpNode) + RULE(CallNode) RULE(CallIntrinsicNode) RULE(ReductionNode) diff --git a/include/taco/index_notation/tensor_operator.h b/include/taco/index_notation/tensor_operator.h index d3aabe87e..e129ae971 100644 --- a/include/taco/index_notation/tensor_operator.h +++ b/include/taco/index_notation/tensor_operator.h @@ -13,55 +13,56 @@ namespace taco { -class Op { - -using opImpl = TensorOpNode::opImpl; -using algebraImpl = TensorOpNode::algebraImpl; +class Func { +// TODO: RENAME +using OpImpl = CallNode::OpImpl; +using AlgebraImpl = CallNode::AlgebraImpl; +// TODO: Make this part of callNode and call. Add generateIterationAlgebra() and generateImplementation() functions public: // Full construction - Op(opImpl lowererFunc, algebraImpl algebraFunc, std::vector properties, - std::map, opImpl> specialDefinitions = {}); + Func(OpImpl lowererFunc, AlgebraImpl algebraFunc, std::vector properties, + std::map, OpImpl> specialDefinitions = {}); - Op(std::string name, opImpl lowererFunc, algebraImpl algebraFunc, std::vector properties, - std::map, opImpl> specialDefinitions = {}); + Func(std::string name, OpImpl lowererFunc, AlgebraImpl algebraFunc, std::vector properties, + std::map, OpImpl> specialDefinitions = {}); // Construct without specifying algebra - Op(std::string name, opImpl lowererFunc, std::vector properties, - std::map, opImpl> specialDefinitions = {}); + Func(std::string name, OpImpl lowererFunc, std::vector properties, + std::map, OpImpl> specialDefinitions = {}); - Op(opImpl lowererFunc, std::vector properties, - std::map, opImpl> specialDefinitions = {}); + Func(OpImpl lowererFunc, std::vector properties, + std::map, OpImpl> specialDefinitions = {}); // Construct without properties - Op(std::string name, opImpl lowererFunc, algebraImpl algebraFunc, - std::map, opImpl> specialDefinitions = {}); + Func(std::string name, OpImpl lowererFunc, AlgebraImpl algebraFunc, + std::map, OpImpl> specialDefinitions = {}); - Op(opImpl lowererFunc, algebraImpl algebraFunc, std::map, opImpl> specialDefinitions = {}); + Func(OpImpl lowererFunc, AlgebraImpl algebraFunc, std::map, OpImpl> specialDefinitions = {}); // Construct without algebra or properties - Op(std::string name, opImpl lowererFunc, std::map, opImpl> specialDefinitions = {}); + Func(std::string name, OpImpl lowererFunc, std::map, OpImpl> specialDefinitions = {}); - explicit Op(opImpl lowererFunc, std::map, opImpl> specialDefinitions = {}); + explicit Func(OpImpl lowererFunc, std::map, OpImpl> specialDefinitions = {}); template - TensorOp operator()(IndexExprs&&... exprs) { + Call operator()(IndexExprs&&... exprs) { std::vector actualArgs{exprs...}; IterationAlgebra nodeAlgebra = algebraFunc == nullptr? inferAlgFromProperties(actualArgs): algebraFunc(actualArgs); - TensorOpNode* op = new TensorOpNode(name, actualArgs, lowererFunc, nodeAlgebra, properties, - regionDefinitions); + CallNode* op = new CallNode(name, actualArgs, lowererFunc, nodeAlgebra, properties, + regionDefinitions); - return TensorOp(op); + return Call(op); } private: std::string name; - opImpl lowererFunc; - algebraImpl algebraFunc; + OpImpl lowererFunc; + AlgebraImpl algebraFunc; std::vector properties; - std::map, opImpl> regionDefinitions; + std::map, OpImpl> regionDefinitions; IterationAlgebra inferAlgFromProperties(const std::vector& exprs); diff --git a/include/taco/lower/lowerer_impl.h b/include/taco/lower/lowerer_impl.h index e78e9f8bf..6053d3e92 100644 --- a/include/taco/lower/lowerer_impl.h +++ b/include/taco/lower/lowerer_impl.h @@ -220,7 +220,7 @@ class LowererImpl : public util::Uncopyable { virtual ir::Expr lowerIndexVar(IndexVar var); /// Lower a generic tensor operation expression - virtual ir::Expr lowerTensorOp(TensorOp op); + virtual ir::Expr lowerTensorOp(Call op); /// Lower a concrete index variable statement. ir::Stmt lower(IndexStmt stmt); diff --git a/src/index_notation/index_notation.cpp b/src/index_notation/index_notation.cpp index e77bb2a29..844314b9f 100644 --- a/src/index_notation/index_notation.cpp +++ b/src/index_notation/index_notation.cpp @@ -242,12 +242,12 @@ struct Equals : public IndexNotationVisitorStrict { eq = true; } - void visit(const TensorOpNode* anode) { - if (!isa(bExpr.ptr)) { + void visit(const CallNode* anode) { + if (!isa(bExpr.ptr)) { eq = false; return; } - auto bnode = to(bExpr.ptr); + auto bnode = to(bExpr.ptr); // Properties if (anode->properties.size() != bnode->properties.size()) { @@ -447,7 +447,7 @@ struct Equals : public IndexNotationVisitorStrict { eq = true; } - static bool checkRegionDefinitions(const TensorOpNode* anode, const TensorOpNode* bnode) { + static bool checkRegionDefinitions(const CallNode* anode, const CallNode* bnode) { // Check region definitions if (anode->regionDefinitions.size() != bnode->regionDefinitions.size()) { return false; @@ -479,14 +479,14 @@ struct Equals : public IndexNotationVisitorStrict { /// Checks if the iteration algebra structure is the same and the ordering of the index expressions /// nested under regions is the same for each op node. - static bool checkIterationAlg(const TensorOpNode* anode, const TensorOpNode* bnode) { + static bool checkIterationAlg(const CallNode* anode, const CallNode* bnode) { // Check IterationAlgebra structures if(!algStructureEqual(anode->iterAlg, bnode->iterAlg)) { return false; } struct OrderChecker : public IterationAlgebraVisitor { - explicit OrderChecker(const TensorOpNode* op) : op(op) {} + explicit OrderChecker(const CallNode* op) : op(op) {} std::vector& check() { op->iterAlg.accept(this); @@ -504,7 +504,7 @@ struct Equals : public IndexNotationVisitorStrict { } std::vector ordering; - const TensorOpNode* op; + const CallNode* op; }; std::vector aOrdering = OrderChecker(anode).check(); @@ -885,49 +885,49 @@ template <> Cast to(IndexExpr e) { return Cast(to(e.ptr)); } -// class TensorOp, most construction should happen from tensor_operator.h -TensorOp::TensorOp(const TensorOpNode* n) : IndexExpr(n) { +// class Call, most construction should happen from tensor_operator.h +Call::Call(const CallNode* n) : IndexExpr(n) { } -TensorOp::TensorOp(const TensorOpNode *n, std::string name) : IndexExpr(n), name(name) { +Call::Call(const CallNode *n, std::string name) : IndexExpr(n), name(name) { } -const std::vector& TensorOp::getArgs() const { +const std::vector& Call::getArgs() const { return getNode(*this)->args; } -const TensorOpNode::opImpl TensorOp::getFunc() const { +const CallNode::OpImpl Call::getFunc() const { return getNode(*this)->defaultLowerFunc; } -const IterationAlgebra& TensorOp::getAlgebra() const { +const IterationAlgebra& Call::getAlgebra() const { return getNode(*this)->iterAlg; } -const std::vector& TensorOp::getProperties() const { +const std::vector& Call::getProperties() const { return getNode(*this)->properties; } -const std::string TensorOp::getName() const { +const std::string Call::getName() const { return getNode(*this)->name; } -const std::map, TensorOpNode::opImpl> TensorOp::getDefs() const { +const std::map, CallNode::OpImpl> Call::getDefs() const { return getNode(*this)->regionDefinitions; } -const std::vector& TensorOp::getDefinedArgs() const { +const std::vector& Call::getDefinedArgs() const { return getNode(*this)->definedRegions; } -template <> bool isa(IndexExpr e) { - return isa(e.ptr); +template <> bool isa(IndexExpr e) { + return isa(e.ptr); } -template <> TensorOp to(IndexExpr e) { - taco_iassert(isa(e)); - return TensorOp(to(e.ptr)); +template <> Call to(IndexExpr e) { + taco_iassert(isa(e)); + return Call(to(e.ptr)); } // class CallIntrinsic @@ -2585,7 +2585,7 @@ struct Zero : public IndexNotationRewriterStrict { } } - void visit(const TensorOpNode* op) { + void visit(const CallNode* op) { std::vector args; std::vector rewrittenArgs; std::vector definedArgs; @@ -2632,8 +2632,8 @@ struct Zero : public IndexNotationRewriterStrict { if (rewritten) { const std::map subs = util::zipToMap(op->args, rewrittenArgs); IterationAlgebra newAlg = replaceAlgIndexExprs(op->iterAlg, subs); - expr = new TensorOpNode(op->name, args, op->defaultLowerFunc, newAlg, op->properties, - op->regionDefinitions, definedArgs); + expr = new CallNode(op->name, args, op->defaultLowerFunc, newAlg, op->properties, + op->regionDefinitions, definedArgs); } else { expr = op; @@ -2861,7 +2861,7 @@ struct fillValueInferrer : IndexExprRewriterStrict { expr = IndexExpr(); } - virtual void visit(const TensorOpNode* op) { + virtual void visit(const CallNode* op) { Annihilator annihilator = findProperty(op->properties); if(annihilator.defined()) { IndexExpr e = annihilator.annihilates(op->args); diff --git a/src/index_notation/index_notation_nodes.cpp b/src/index_notation/index_notation_nodes.cpp index 105d521c9..42a94bc98 100644 --- a/src/index_notation/index_notation_nodes.cpp +++ b/src/index_notation/index_notation_nodes.cpp @@ -29,18 +29,18 @@ CallIntrinsicNode::CallIntrinsicNode(const std::shared_ptr& func, func(func), args(args) { } -// class TensorOpNode - TensorOpNode::TensorOpNode(std::string name, const std::vector& args, opImpl defaultLowerFunc, - const IterationAlgebra &iterAlg, const std::vector &properties, - const std::map, opImpl>& regionDefinitions) - : TensorOpNode(name, args, defaultLowerFunc, iterAlg, properties, regionDefinitions, definedIndices(args)){ +// class CallNode + CallNode::CallNode(std::string name, const std::vector& args, OpImpl defaultLowerFunc, + const IterationAlgebra &iterAlg, const std::vector &properties, + const std::map, OpImpl>& regionDefinitions) + : CallNode(name, args, defaultLowerFunc, iterAlg, properties, regionDefinitions, definedIndices(args)){ } -// class TensorOpNode -TensorOpNode::TensorOpNode(std::string name, const std::vector& args, opImpl defaultLowerFunc, - const IterationAlgebra &iterAlg, const std::vector &properties, - const std::map, opImpl>& regionDefinitions, - const std::vector& definedRegions) +// class CallNode +CallNode::CallNode(std::string name, const std::vector& args, OpImpl defaultLowerFunc, + const IterationAlgebra &iterAlg, const std::vector &properties, + const std::map, OpImpl>& regionDefinitions, + const std::vector& definedRegions) : IndexExprNode(inferReturnType(defaultLowerFunc, args)), name(name), args(args), defaultLowerFunc(defaultLowerFunc), iterAlg(applyDemorgan(iterAlg)), properties(properties), regionDefinitions(regionDefinitions), definedRegions(definedRegions) { @@ -54,7 +54,7 @@ TensorOpNode::TensorOpNode(std::string name, const std::vector& args, // class ReductionNode ReductionNode::ReductionNode(IndexExpr op, IndexVar var, IndexExpr a) : IndexExprNode(a.getDataType()), op(op), var(var), a(a) { - taco_iassert(isa(op.ptr) || isa(op.ptr)); + taco_iassert(isa(op.ptr) || isa(op.ptr)); } IndexVarNode::IndexVarNode(const std::string& name, const Datatype& type) diff --git a/src/index_notation/index_notation_printer.cpp b/src/index_notation/index_notation_printer.cpp index a95278e1e..a9c959d8c 100644 --- a/src/index_notation/index_notation_printer.cpp +++ b/src/index_notation/index_notation_printer.cpp @@ -157,7 +157,7 @@ static inline void acceptJoin(IndexNotationPrinter* printer, } } -void IndexNotationPrinter::visit(const TensorOpNode* op) { +void IndexNotationPrinter::visit(const CallNode* op) { parentPrecedence = Precedence::FUNC; os << op->name << "("; acceptJoin(this, os, op->args, ", "); @@ -190,7 +190,7 @@ void IndexNotationPrinter::visit(const ReductionNode* op) { reductionName = "reduction(" + node->getOperatorString() + ")"; } - void visit(const TensorOpNode* node) { + void visit(const CallNode* node) { reductionName = node->name + "Reduce"; } }; @@ -213,7 +213,7 @@ void IndexNotationPrinter::visit(const AssignmentNode* op) { operatorName = node->getOperatorString(); } - void visit(const TensorOpNode* node) { + void visit(const CallNode* node) { operatorName = node->name; } }; diff --git a/src/index_notation/index_notation_rewriter.cpp b/src/index_notation/index_notation_rewriter.cpp index 5e206c4e0..b546662bc 100644 --- a/src/index_notation/index_notation_rewriter.cpp +++ b/src/index_notation/index_notation_rewriter.cpp @@ -107,7 +107,7 @@ void IndexNotationRewriter::visit(const CastNode* op) { } } -void IndexNotationRewriter::visit(const TensorOpNode* op) { +void IndexNotationRewriter::visit(const CallNode* op) { std::vector args; bool rewritten = false; for(auto& arg : op->args) { @@ -121,8 +121,8 @@ void IndexNotationRewriter::visit(const TensorOpNode* op) { if (rewritten) { const std::map subs = util::zipToMap(op->args, args); IterationAlgebra newAlg = replaceAlgIndexExprs(op->iterAlg, subs); - expr = new TensorOpNode(op->name, args, op->defaultLowerFunc, newAlg, op->properties, - op->regionDefinitions); + expr = new CallNode(op->name, args, op->defaultLowerFunc, newAlg, op->properties, + op->regionDefinitions); } else { expr = op; @@ -304,7 +304,7 @@ struct ReplaceRewriter : public IndexNotationRewriter { SUBSTITUTE_EXPR; } - void visit(const TensorOpNode* op) { + void visit(const CallNode* op) { SUBSTITUTE_EXPR; } diff --git a/src/index_notation/index_notation_visitor.cpp b/src/index_notation/index_notation_visitor.cpp index e954895f7..317c5b15c 100644 --- a/src/index_notation/index_notation_visitor.cpp +++ b/src/index_notation/index_notation_visitor.cpp @@ -69,7 +69,7 @@ void IndexNotationVisitor::visit(const CastNode* op) { op->a.accept(this); } -void IndexNotationVisitor::visit(const TensorOpNode* op) { +void IndexNotationVisitor::visit(const CallNode* op) { for (auto& arg : op->args) { arg.accept(this); } diff --git a/src/index_notation/tensor_operator.cpp b/src/index_notation/tensor_operator.cpp index ec9edf8ae..5e0939465 100644 --- a/src/index_notation/tensor_operator.cpp +++ b/src/index_notation/tensor_operator.cpp @@ -3,50 +3,50 @@ namespace taco { // Full construction -Op::Op(opImpl lowererFunc, algebraImpl algebraFunc, std::vector properties, - std::map, opImpl> specialDefinitions) - : name(util::uniqueName("Op")), lowererFunc(lowererFunc), algebraFunc(algebraFunc), +Func::Func(OpImpl lowererFunc, AlgebraImpl algebraFunc, std::vector properties, + std::map, OpImpl> specialDefinitions) + : name(util::uniqueName("Func")), lowererFunc(lowererFunc), algebraFunc(algebraFunc), properties(properties), regionDefinitions(specialDefinitions) { } -Op::Op(std::string name, opImpl lowererFunc, algebraImpl algebraFunc, std::vector properties, - std::map, opImpl> specialDefinitions) +Func::Func(std::string name, OpImpl lowererFunc, AlgebraImpl algebraFunc, std::vector properties, + std::map, OpImpl> specialDefinitions) : name(name), lowererFunc(lowererFunc), algebraFunc(algebraFunc), properties(properties), regionDefinitions(specialDefinitions) { } // Construct without specifying algebra -Op::Op(std::string name, opImpl lowererFunc, std::vector properties, - std::map, opImpl> specialDefinitions) - : Op(name, lowererFunc, nullptr, properties, specialDefinitions) { +Func::Func(std::string name, OpImpl lowererFunc, std::vector properties, + std::map, OpImpl> specialDefinitions) + : Func(name, lowererFunc, nullptr, properties, specialDefinitions) { } -Op::Op(opImpl lowererFunc, std::vector properties, - std::map, opImpl> specialDefinitions) - : Op(util::uniqueName("Op"), lowererFunc, nullptr, properties, specialDefinitions) { +Func::Func(OpImpl lowererFunc, std::vector properties, + std::map, OpImpl> specialDefinitions) + : Func(util::uniqueName("Func"), lowererFunc, nullptr, properties, specialDefinitions) { } // Construct without properties -Op::Op(std::string name, opImpl lowererFunc, algebraImpl algebraFunc, - std::map, opImpl> specialDefinitions) - : Op(name, lowererFunc, algebraFunc, {}, specialDefinitions) { +Func::Func(std::string name, OpImpl lowererFunc, AlgebraImpl algebraFunc, + std::map, OpImpl> specialDefinitions) + : Func(name, lowererFunc, algebraFunc, {}, specialDefinitions) { } -Op::Op(opImpl lowererFunc, algebraImpl algebraFunc, - std::map, opImpl> specialDefinitions) : - Op(util::uniqueName("Op"), lowererFunc, algebraFunc, {}, specialDefinitions) { +Func::Func(OpImpl lowererFunc, AlgebraImpl algebraFunc, + std::map, OpImpl> specialDefinitions) : + Func(util::uniqueName("Func"), lowererFunc, algebraFunc, {}, specialDefinitions) { } // Construct without algebra or properties -Op::Op(std::string name, opImpl lowererFunc, std::map, opImpl> specialDefinitions) - : Op(name, lowererFunc, nullptr, specialDefinitions) { +Func::Func(std::string name, OpImpl lowererFunc, std::map, OpImpl> specialDefinitions) + : Func(name, lowererFunc, nullptr, specialDefinitions) { } -Op::Op(opImpl lowererFunc, std::map, opImpl> specialDefinitions) - : Op(lowererFunc, nullptr, specialDefinitions) { +Func::Func(OpImpl lowererFunc, std::map, OpImpl> specialDefinitions) + : Func(lowererFunc, nullptr, specialDefinitions) { } -IterationAlgebra Op::inferAlgFromProperties(const std::vector& exprs) { +IterationAlgebra Func::inferAlgFromProperties(const std::vector& exprs) { if(properties.empty()) { return constructDefaultAlgebra(exprs); } @@ -74,7 +74,7 @@ IterationAlgebra Op::inferAlgFromProperties(const std::vector& exprs) } // Constructs an algebra that iterates over the entire space -IterationAlgebra Op::constructDefaultAlgebra(const std::vector& exprs) { +IterationAlgebra Func::constructDefaultAlgebra(const std::vector& exprs) { if(exprs.empty()) return Region(); IterationAlgebra tensorsRegions(exprs[0]); @@ -86,7 +86,7 @@ IterationAlgebra Op::constructDefaultAlgebra(const std::vector& exprs return Union(tensorsRegions, background); } -IterationAlgebra Op::constructAnnihilatorAlg(const std::vector &args, taco::Annihilator annihilator) { +IterationAlgebra Func::constructAnnihilatorAlg(const std::vector &args, taco::Annihilator annihilator) { if(args.size () < 2) { return IterationAlgebra(); } @@ -120,7 +120,7 @@ IterationAlgebra Op::constructAnnihilatorAlg(const std::vector &args, return alg; } -IterationAlgebra Op::constructIdentityAlg(const std::vector &args, taco::Identity identity) { +IterationAlgebra Func::constructIdentityAlg(const std::vector &args, taco::Identity identity) { if(args.size() < 2) { return IterationAlgebra(); } diff --git a/src/lower/expr_tools.cpp b/src/lower/expr_tools.cpp index 2d86bd966..74ea277f4 100644 --- a/src/lower/expr_tools.cpp +++ b/src/lower/expr_tools.cpp @@ -261,7 +261,7 @@ class SubExprVisitor : public IndexExprVisitorStrict { subExpr = binarySubExpr(op); } - void visit(const TensorOpNode* op) { + void visit(const CallNode* op) { taco_not_supported_yet; } diff --git a/src/lower/lowerer_impl.cpp b/src/lower/lowerer_impl.cpp index e49bf34f0..bd3f0f981 100644 --- a/src/lower/lowerer_impl.cpp +++ b/src/lower/lowerer_impl.cpp @@ -57,7 +57,7 @@ class LowererImpl::Visitor : public IndexNotationVisitorStrict { void visit(const SqrtNode* node) { expr = impl->lowerSqrt(node); } void visit(const CastNode* node) { expr = impl->lowerCast(node); } void visit(const CallIntrinsicNode* node) { expr = impl->lowerCallIntrinsic(node); } - void visit(const TensorOpNode* node) { expr = impl->lowerTensorOp(node); } + void visit(const CallNode* node) { expr = impl->lowerTensorOp(node); } void visit(const ReductionNode* node) { taco_ierror << "Reduction nodes not supported in concrete index notation"; } @@ -277,9 +277,9 @@ Stmt LowererImpl::lowerAssignment(Assignment assignment) return addAssign(var, rhs, markAssignsAtomicDepth > 0 && !util::contains(whereTemps, result), atomicParallelUnit); } - taco_iassert(isa(assignment.getOperator())); + taco_iassert(isa(assignment.getOperator())); - TensorOp op = to(assignment.getOperator()); + Call op = to(assignment.getOperator()); Expr assignOp = op.getFunc()({var, rhs}); Stmt assign = Assign::make(var, assignOp, markAssignsAtomicDepth > 0 && !util::contains(whereTemps, result), atomicParallelUnit); @@ -303,9 +303,9 @@ Stmt LowererImpl::lowerAssignment(Assignment assignment) computeStmt = compoundStore(values, loc, rhs, markAssignsAtomicDepth > 0, atomicParallelUnit); } else { - taco_iassert(isa(assignment.getOperator())); + taco_iassert(isa(assignment.getOperator())); - TensorOp op = to(assignment.getOperator()); + Call op = to(assignment.getOperator()); Expr assignOp = op.getFunc()({Load::make(values, loc), rhs}); computeStmt = Store::make(values, loc, assignOp, markAssignsAtomicDepth > 0 && !util::contains(whereTemps, result), @@ -1735,7 +1735,7 @@ Expr LowererImpl::lowerCallIntrinsic(CallIntrinsic call) { } -Expr LowererImpl::lowerTensorOp(TensorOp op) { +Expr LowererImpl::lowerTensorOp(Call op) { auto definedArgs = op.getDefinedArgs(); std::vector args; diff --git a/src/lower/merge_lattice.cpp b/src/lower/merge_lattice.cpp index 138569db9..1ea4ee0e7 100644 --- a/src/lower/merge_lattice.cpp +++ b/src/lower/merge_lattice.cpp @@ -267,7 +267,7 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA lattice = build(expr->a); } - void visit(const TensorOpNode* expr) { + void visit(const CallNode* expr) { taco_iassert(expr->iterAlg.defined()) << "Algebra must be defined" << endl; lattice = build(expr->iterAlg); diff --git a/test/tests-index_notation.cpp b/test/tests-index_notation.cpp index 5a219c487..282934246 100644 --- a/test/tests-index_notation.cpp +++ b/test/tests-index_notation.cpp @@ -206,10 +206,10 @@ INSTANTIATE_TEST_CASE_P(separate_reductions, concrete, -Op scOr("Or", OrImpl(), {Annihilator((bool)1), Identity((bool)0)}); -Op scAnd("And", AndImpl(), {Annihilator((bool)0), Identity((bool)0)}); +Func scOr("Or", OrImpl(), {Annihilator((bool)1), Identity((bool)0)}); +Func scAnd("And", AndImpl(), {Annihilator((bool)0), Identity((bool)0)}); -Op bfsMaskOp("bfsMask", BfsLower(), BfsMaskAlg()); +Func bfsMaskOp("bfsMask", BfsLower(), BfsMaskAlg()); INSTANTIATE_TEST_CASE_P(tensorOpConcrete, concrete, Values(ConcreteTest(a(i) = Reduction(scOr(), j, bfsMaskOp(scAnd(B(i, j), c(j)), c(i))), forall(i, diff --git a/test/tests-lower.cpp b/test/tests-lower.cpp index 4332ffc5c..19e5fec8d 100644 --- a/test/tests-lower.cpp +++ b/test/tests-lower.cpp @@ -1582,7 +1582,7 @@ TEST_STMT(vector_not, // Test tensorOps -Op testOp("testOp", MulAdd(), BC_BD_CD()); +Func testOp("testOp", MulAdd(), BC_BD_CD()); TEST_STMT(testOp1, forall(i, @@ -1606,7 +1606,7 @@ TEST_STMT(testOp1, ) -Op specialOp("specialOp", GeneralAdd(), BC_BD_CD(), {{{0,1}, MulRegionDef()}, {{0,2}, SubRegionDef()}}); +Func specialOp("specialOp", GeneralAdd(), BC_BD_CD(), {{{0, 1}, MulRegionDef()}, {{0, 2}, SubRegionDef()}}); TEST_STMT(lowerSpecialRegions1, forall(i, forall(j, @@ -1630,7 +1630,7 @@ TEST_STMT(lowerSpecialRegions1, } ) -Op compUnion("compUnion", GeneralAdd(), ComplementUnion()); +Func compUnion("compUnion", GeneralAdd(), ComplementUnion()); TEST_STMT(lowerCompUnion, forall(i, forall(j, @@ -1655,9 +1655,9 @@ TEST_STMT(lowerCompUnion, } ) -Op scOr("Or", OrImpl(), {Annihilator((double)1), Identity(Literal((double)0))}); -Op scAnd("And", AndImpl(), {Annihilator((double)0), Identity((double)1)}); -Op bfsMaskOp("bfsMask", BfsLower(), BfsMaskAlg()); +Func scOr("Or", OrImpl(), {Annihilator((double)1), Identity(Literal((double)0))}); +Func scAnd("And", AndImpl(), {Annihilator((double)0), Identity((double)1)}); +Func bfsMaskOp("bfsMask", BfsLower(), BfsMaskAlg()); TEST_STMT(BoolRing, forall(i, @@ -1710,8 +1710,8 @@ TEST_STMT(BoolRing3, } ) -Op customMin("Min", MinImpl(), {Identity(std::numeric_limits::infinity()) }); -Op Plus("Plus", GeneralAdd(), {Annihilator(std::numeric_limits::infinity())}); +Func customMin("Min", MinImpl(), {Identity(std::numeric_limits::infinity()) }); +Func Plus("Plus", GeneralAdd(), {Annihilator(std::numeric_limits::infinity())}); static TensorVar a_inf("a", vectype, Format(), std::numeric_limits::infinity()); static TensorVar c_inf("c", vectype, Format(), std::numeric_limits::infinity()); @@ -1734,7 +1734,7 @@ TEST_STMT(MinPlusRing, } ) -Op sparsify("sparsify", identityFunc(), [](const std::vector& v) {return Union(v[0], Complement(v[1]));}); +Func sparsify("sparsify", identityFunc(), [](const std::vector& v) {return Union(v[0], Complement(v[1]));}); TEST_STMT(SparsifyTest, forall(i, @@ -1751,6 +1751,23 @@ TEST_STMT(SparsifyTest, } ) +Func xorOp("xor", GeneralAdd(), xorGen()); + +TEST_STMT(XorTest, + forall(i, + c(i) = xorOp(a(i), b(i)) + ), + Values( + Formats({{a, Format({sparse})}, {b, Format({sparse})}, {c, Format({sparse})} }) + ), + { + TestCase( + {{a, {{{0}, 3.0}, {{2}, 4.0}, {{4}, 2.0}}}, + {b, {{{1}, 5.0}, {{2}, 4.0}, {{4}, 2.0}}}}, + + {{c, {{{0}, 3.0}, {{1}, 5.0}}}}) + } +) }} diff --git a/test/tests-merge_lattice.cpp b/test/tests-merge_lattice.cpp index e43898d10..97f88310b 100644 --- a/test/tests-merge_lattice.cpp +++ b/test/tests-merge_lattice.cpp @@ -615,8 +615,8 @@ INSTANTIATE_TEST_CASE_P(hashmap, merge_lattice, ) ); -Op intersectAdd("intersectAdd", GeneralAdd(), IntersectGen()); -Op intersectAddDeMorgan("intersectAddDeMorgan", GeneralAdd(), IntersectGenDeMorgan()); +Func intersectAdd("intersectAdd", GeneralAdd(), IntersectGen()); +Func intersectAddDeMorgan("intersectAddDeMorgan", GeneralAdd(), IntersectGenDeMorgan()); INSTANTIATE_TEST_CASE_P(deMorganIntersect, merge_lattice, Values( @@ -666,7 +666,7 @@ INSTANTIATE_TEST_CASE_P(deMorganIntersect, merge_lattice, ) ); -Op complementIntersect("complementIntersect", GeneralAdd(), ComplementIntersect()); +Func complementIntersect("complementIntersect", GeneralAdd(), ComplementIntersect()); INSTANTIATE_TEST_CASE_P(complementIntersect, merge_lattice, Values( @@ -776,7 +776,7 @@ INSTANTIATE_TEST_CASE_P(complementIntersect, merge_lattice, ); -Op complementUnion("complementUnion", GeneralAdd(), ComplementUnion()); +Func complementUnion("complementUnion", GeneralAdd(), ComplementUnion()); INSTANTIATE_TEST_CASE_P(complementUnion, merge_lattice, Values( Test(forall(i, rd = complementUnion(s1, s2)), @@ -908,7 +908,7 @@ INSTANTIATE_TEST_CASE_P(complementUnion, merge_lattice, ) ); -Op xorOp("xor", GeneralAdd(), xorGen()); +Func xorOp("xor", GeneralAdd(), xorGen()); INSTANTIATE_TEST_CASE_P(xorLattice, merge_lattice, Values(Test(forall(i, rd = xorOp(s1, s2)), MergeLattice({MergePoint({it(s1), it(s2)}, @@ -991,7 +991,7 @@ INSTANTIATE_TEST_CASE_P(xorLattice, merge_lattice, ) ); -Op identity("identity", identityFunc(), fullSpaceGen()); +Func identity("identity", identityFunc(), fullSpaceGen()); INSTANTIATE_TEST_CASE_P(singleCompUnion, merge_lattice, Values(Test(forall(i, rd = identity(s1)), MergeLattice({MergePoint({it(s1), i}, @@ -1017,9 +1017,9 @@ INSTANTIATE_TEST_CASE_P(singleCompUnion, merge_lattice, ) ); -Op emptyIdentity("emptyIdentity", identityFunc(), emptyGen()); -Op intersectEdgeCase("intersectEdgeCase", GeneralAdd(), intersectEdge()); -Op unionEdgeCase("unionEdgeCase", GeneralAdd(), unionEdge()); +Func emptyIdentity("emptyIdentity", identityFunc(), emptyGen()); +Func intersectEdgeCase("intersectEdgeCase", GeneralAdd(), intersectEdge()); +Func unionEdgeCase("unionEdgeCase", GeneralAdd(), unionEdge()); INSTANTIATE_TEST_CASE_P(edgeCases, merge_lattice, Values(Test(forall(i, rd = emptyIdentity(s1)), MergeLattice({MergePoint({it(s1)}, diff --git a/test/tests-scheduling-eval.cpp b/test/tests-scheduling-eval.cpp index 5bd1a41ef..aa3a3c483 100644 --- a/test/tests-scheduling-eval.cpp +++ b/test/tests-scheduling-eval.cpp @@ -27,7 +27,7 @@ void printToFile(string filename, IndexStmt stmt) { mkdir(file_path.c_str(), 0777); std::shared_ptr codegen = ir::CodeGen::init_default(source, ir::CodeGen::ImplementationGen); - ir::Stmt compute = lower(stmt, "compute", false, true); + ir::Stmt compute = lower(stmt, "compute", true, true); codegen->compile(compute, true); ofstream source_file; @@ -1730,16 +1730,18 @@ TEST(scheduling_eval, bfsPullScheduled) { int NUM_J = numVertices; float SPARSITY = .3; - Tensor A("A", {NUM_I, NUM_J}, CSR); - Tensor x("x", {NUM_J}, {Dense}); + Tensor A("A", {NUM_I, NUM_J}, CSC); + Tensor x("x", {NUM_J}, {Sparse}); + Tensor m("mask", {NUM_J}, {Dense}); Tensor y("y", {NUM_I}, {Dense}); + Tensor step("step"); uint16_t one = 1; uint16_t zero = 0; - Op scOr("Or", OrImpl(), {Annihilator(one), Identity(zero)}); - Op scAnd("And", AndImpl(), {Annihilator(zero), Identity(one)}); - Op bfsMaskOp("bfsMask", BfsLower(), BfsMaskAlg()); + Func scOr("Or", OrImpl(), {Annihilator(one), Identity(zero)}); + Func scAnd("And", AndImpl(), {Annihilator(zero), Identity(one)}); + Func bfsMaskOp("bfsMask", BfsLower(), BfsMaskAlg()); srand(120); for (int i = 0; i < NUM_I; i++) { @@ -1762,80 +1764,29 @@ TEST(scheduling_eval, bfsPullScheduled) { x.pack(); A.pack(); + m.pack(); - y(i) = Reduction(scOr(), j, bfsMaskOp(scAnd(A(i, j), x(j)), x(i))); - + y(i) = Reduction(scOr(), j, scAnd(A(i, j), x(j))); IndexStmt stmt = y.getAssignment().concretize(); - stmt = scheduleSpMVCPU(stmt); - - //printToFile("spmv_cpu", stmt); + stmt = stmt.reorder(i,j) + .parallelize(j, taco::ParallelUnit::CPUThread, taco::OutputRaceStrategy::Atomics); + printToFile("bfs_push", stmt); y.compile(stmt); y.assemble(); y.compute(); + Tensor s("s", {NUM_J}, {Sparse}); + Tensor d("d", {NUM_J}, {dense}); - Tensor expected("expected", {NUM_I}, {Dense}); - expected(i) = Reduction(scOr(), j, bfsMaskOp(scAnd(A(i, j), x(j)), x(i))); - expected.compile(); - expected.assemble(); - expected.compute(); - ASSERT_TENSOR_EQ(expected, y); -} - -TEST(scheduling_eval, bfsPushScheduled) { - if (should_use_CUDA_codegen()) { - return; - } - constexpr int numVertices = 30; - int NUM_I = numVertices; - int NUM_J = numVertices; - float SPARSITY = .3; - - Tensor A("A", {NUM_I, NUM_J}, CSC); - Tensor x("x", {NUM_J}, {compressed}); - Tensor y("y", {NUM_I}, {Dense}); - - int one = 1; - int zero = 0; - - Op scOr("Or", BitOrImpl(), {Annihilator(one), Identity(zero)}); - Op scAnd("And", AndImpl(), {Annihilator(zero), Identity(one)}); - Op bfsMaskOp("bfsMask", BfsLower(), BfsMaskAlg()); - - srand(120); - for (int i = 0; i < NUM_I; i++) { - for (int j = 0; j < NUM_J; j++) { - float rand_float = (float)rand()/(float)(RAND_MAX); - if (rand_float < SPARSITY) { - A.insert({i, j}, one); - } - } - } - - for (int j = 0; j < NUM_J; j++) { - float rand_float = (float)rand()/(float)(RAND_MAX); - if (rand_float < SPARSITY) { - x.insert({j}, one); - } - } - - x.pack(); - A.pack(); - y(i) = Reduction(scOr(), j, scAnd(A(i, j), x(j))); - - IndexStmt stmt = y.getAssignment().concretize(); - stmt = stmt.reorder(i, j) - .parallelize(j, ParallelUnit::CPUThread, OutputRaceStrategy::Atomics); - - //printToFile("spmv_cpu", stmt); + Func sparsifyOp("sparsify", identityFunc(), ComplementUnion()); + s(i) = sparsifyOp(d(i), i); + IndexStmt sparsify = s.getAssignment().concretize(); + printToFile("sparsify", sparsify); - y.compile(stmt); - y.assemble(); - y.compute(); - Tensor expected("expected", {NUM_I}, {Dense}); - expected(i) = Reduction(scOr(), j, scAnd(A(i, j), x(j))); + Tensor expected("expected", {NUM_I}, {Dense}); + expected(i) = Reduction(scOr(), j, bfsMaskOp(scAnd(A(i, j), x(j)), m(i))); expected.compile(); expected.assemble(); expected.compute(); From 40364381af44d5f19a992ab9a3f6381aecea204e Mon Sep 17 00:00:00 2001 From: Rawn Date: Sat, 6 Feb 2021 19:35:35 -0800 Subject: [PATCH 22/27] Rename opImpl and algImpl --- include/taco/index_notation/tensor_operator.h | 38 +++++++++---------- src/index_notation/tensor_operator.cpp | 28 +++++++------- 2 files changed, 33 insertions(+), 33 deletions(-) diff --git a/include/taco/index_notation/tensor_operator.h b/include/taco/index_notation/tensor_operator.h index e129ae971..fc5aa69a9 100644 --- a/include/taco/index_notation/tensor_operator.h +++ b/include/taco/index_notation/tensor_operator.h @@ -14,36 +14,36 @@ namespace taco { class Func { -// TODO: RENAME -using OpImpl = CallNode::OpImpl; -using AlgebraImpl = CallNode::AlgebraImpl; + +using FuncBodyGenerator = CallNode::OpImpl; +using FuncAlgebraGenerator = CallNode::AlgebraImpl; // TODO: Make this part of callNode and call. Add generateIterationAlgebra() and generateImplementation() functions public: // Full construction - Func(OpImpl lowererFunc, AlgebraImpl algebraFunc, std::vector properties, - std::map, OpImpl> specialDefinitions = {}); + Func(FuncBodyGenerator lowererFunc, FuncAlgebraGenerator algebraFunc, std::vector properties, + std::map, FuncBodyGenerator> specialDefinitions = {}); - Func(std::string name, OpImpl lowererFunc, AlgebraImpl algebraFunc, std::vector properties, - std::map, OpImpl> specialDefinitions = {}); + Func(std::string name, FuncBodyGenerator lowererFunc, FuncAlgebraGenerator algebraFunc, std::vector properties, + std::map, FuncBodyGenerator> specialDefinitions = {}); // Construct without specifying algebra - Func(std::string name, OpImpl lowererFunc, std::vector properties, - std::map, OpImpl> specialDefinitions = {}); + Func(std::string name, FuncBodyGenerator lowererFunc, std::vector properties, + std::map, FuncBodyGenerator> specialDefinitions = {}); - Func(OpImpl lowererFunc, std::vector properties, - std::map, OpImpl> specialDefinitions = {}); + Func(FuncBodyGenerator lowererFunc, std::vector properties, + std::map, FuncBodyGenerator> specialDefinitions = {}); // Construct without properties - Func(std::string name, OpImpl lowererFunc, AlgebraImpl algebraFunc, - std::map, OpImpl> specialDefinitions = {}); + Func(std::string name, FuncBodyGenerator lowererFunc, FuncAlgebraGenerator algebraFunc, + std::map, FuncBodyGenerator> specialDefinitions = {}); - Func(OpImpl lowererFunc, AlgebraImpl algebraFunc, std::map, OpImpl> specialDefinitions = {}); + Func(FuncBodyGenerator lowererFunc, FuncAlgebraGenerator algebraFunc, std::map, FuncBodyGenerator> specialDefinitions = {}); // Construct without algebra or properties - Func(std::string name, OpImpl lowererFunc, std::map, OpImpl> specialDefinitions = {}); + Func(std::string name, FuncBodyGenerator lowererFunc, std::map, FuncBodyGenerator> specialDefinitions = {}); - explicit Func(OpImpl lowererFunc, std::map, OpImpl> specialDefinitions = {}); + explicit Func(FuncBodyGenerator lowererFunc, std::map, FuncBodyGenerator> specialDefinitions = {}); template Call operator()(IndexExprs&&... exprs) { @@ -59,10 +59,10 @@ using AlgebraImpl = CallNode::AlgebraImpl; private: std::string name; - OpImpl lowererFunc; - AlgebraImpl algebraFunc; + FuncBodyGenerator lowererFunc; + FuncAlgebraGenerator algebraFunc; std::vector properties; - std::map, OpImpl> regionDefinitions; + std::map, FuncBodyGenerator> regionDefinitions; IterationAlgebra inferAlgFromProperties(const std::vector& exprs); diff --git a/src/index_notation/tensor_operator.cpp b/src/index_notation/tensor_operator.cpp index 5e0939465..0526e552f 100644 --- a/src/index_notation/tensor_operator.cpp +++ b/src/index_notation/tensor_operator.cpp @@ -3,46 +3,46 @@ namespace taco { // Full construction -Func::Func(OpImpl lowererFunc, AlgebraImpl algebraFunc, std::vector properties, - std::map, OpImpl> specialDefinitions) +Func::Func(FuncBodyGenerator lowererFunc, FuncAlgebraGenerator algebraFunc, std::vector properties, + std::map, FuncBodyGenerator> specialDefinitions) : name(util::uniqueName("Func")), lowererFunc(lowererFunc), algebraFunc(algebraFunc), properties(properties), regionDefinitions(specialDefinitions) { } -Func::Func(std::string name, OpImpl lowererFunc, AlgebraImpl algebraFunc, std::vector properties, - std::map, OpImpl> specialDefinitions) +Func::Func(std::string name, FuncBodyGenerator lowererFunc, FuncAlgebraGenerator algebraFunc, std::vector properties, + std::map, FuncBodyGenerator> specialDefinitions) : name(name), lowererFunc(lowererFunc), algebraFunc(algebraFunc), properties(properties), regionDefinitions(specialDefinitions) { } // Construct without specifying algebra -Func::Func(std::string name, OpImpl lowererFunc, std::vector properties, - std::map, OpImpl> specialDefinitions) +Func::Func(std::string name, FuncBodyGenerator lowererFunc, std::vector properties, + std::map, FuncBodyGenerator> specialDefinitions) : Func(name, lowererFunc, nullptr, properties, specialDefinitions) { } -Func::Func(OpImpl lowererFunc, std::vector properties, - std::map, OpImpl> specialDefinitions) +Func::Func(FuncBodyGenerator lowererFunc, std::vector properties, + std::map, FuncBodyGenerator> specialDefinitions) : Func(util::uniqueName("Func"), lowererFunc, nullptr, properties, specialDefinitions) { } // Construct without properties -Func::Func(std::string name, OpImpl lowererFunc, AlgebraImpl algebraFunc, - std::map, OpImpl> specialDefinitions) +Func::Func(std::string name, FuncBodyGenerator lowererFunc, FuncAlgebraGenerator algebraFunc, + std::map, FuncBodyGenerator> specialDefinitions) : Func(name, lowererFunc, algebraFunc, {}, specialDefinitions) { } -Func::Func(OpImpl lowererFunc, AlgebraImpl algebraFunc, - std::map, OpImpl> specialDefinitions) : +Func::Func(FuncBodyGenerator lowererFunc, FuncAlgebraGenerator algebraFunc, + std::map, FuncBodyGenerator> specialDefinitions) : Func(util::uniqueName("Func"), lowererFunc, algebraFunc, {}, specialDefinitions) { } // Construct without algebra or properties -Func::Func(std::string name, OpImpl lowererFunc, std::map, OpImpl> specialDefinitions) +Func::Func(std::string name, FuncBodyGenerator lowererFunc, std::map, FuncBodyGenerator> specialDefinitions) : Func(name, lowererFunc, nullptr, specialDefinitions) { } -Func::Func(OpImpl lowererFunc, std::map, OpImpl> specialDefinitions) +Func::Func(FuncBodyGenerator lowererFunc, std::map, FuncBodyGenerator> specialDefinitions) : Func(lowererFunc, nullptr, specialDefinitions) { } From 607c6c149656389a5046f950cac2625be8faa297 Mon Sep 17 00:00:00 2001 From: Rawn Date: Sat, 6 Feb 2021 20:33:10 -0800 Subject: [PATCH 23/27] Fixes bug that occurred during merge. One of the constructors of TensorVar was merged incorrectly so all formats were dense inside the lattice --- src/index_notation/index_notation.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/index_notation/index_notation.cpp b/src/index_notation/index_notation.cpp index b7400851d..949517b80 100644 --- a/src/index_notation/index_notation.cpp +++ b/src/index_notation/index_notation.cpp @@ -2139,7 +2139,7 @@ TensorVar::TensorVar(const Type& type, const Format& format, const Literal& fill } TensorVar::TensorVar(const string& name, const Type& type, const Format& format, const Literal& fill) - : TensorVar(-1, name, type, createDenseFormat(type), fill) { + : TensorVar(-1, name, type, format, fill) { } TensorVar::TensorVar(const int& id, const string& name, const Type& type, const Format& format, const Literal& fill) From e3c13f3cd1437901ab4bf2de92255ce496483f69 Mon Sep 17 00:00:00 2001 From: Rawn Date: Sat, 6 Feb 2021 22:52:21 -0800 Subject: [PATCH 24/27] Uses alloc_mem instead of malloc to allocate fill value. This caused the GPU backend to be completely non-functional --- include/taco/storage/storage.h | 9 +++------ include/taco/taco_tensor_t.h | 2 +- src/storage/pack.cpp | 2 +- src/storage/storage.cpp | 27 ++++++++------------------- src/taco_tensor_t.cpp | 15 ++++++++++++++- src/tensor.cpp | 10 +++++----- 6 files changed, 32 insertions(+), 33 deletions(-) diff --git a/include/taco/storage/storage.h b/include/taco/storage/storage.h index f14c3a7e4..a3bf987c3 100644 --- a/include/taco/storage/storage.h +++ b/include/taco/storage/storage.h @@ -52,13 +52,12 @@ class TensorStorage { /// Returns the value array that contains the tensor components. const Array& getValues() const; - /// Returns the fill array containing the tensor fill value. This is always - /// of size one. - const Array& getFill() const; - /// Returns the tensor component value array. Array getValues(); + /// Returns the full value attached to the tensor storage + Literal getFillValue(); + /// Returns the size of the storage in bytes. size_t getSizeInBytes(); @@ -71,8 +70,6 @@ class TensorStorage { /// Set the tensor component value array. void setValues(const Array& values); - /// Set the fill array. This should always be size 1 - void setFill(const Array& fill); private: struct Content; diff --git a/include/taco/taco_tensor_t.h b/include/taco/taco_tensor_t.h index d3f865b42..f1777510d 100644 --- a/include/taco/taco_tensor_t.h +++ b/include/taco/taco_tensor_t.h @@ -24,7 +24,7 @@ typedef struct taco_tensor_t { taco_tensor_t *init_taco_tensor_t(int32_t order, int32_t csize, int32_t* dimensions, int32_t* modeOrdering, - taco_mode_t* mode_types); + taco_mode_t* mode_types, void* fill_ptr); void deinit_taco_tensor_t(taco_tensor_t* t); diff --git a/src/storage/pack.cpp b/src/storage/pack.cpp index 852b230a0..3d0f69588 100644 --- a/src/storage/pack.cpp +++ b/src/storage/pack.cpp @@ -161,7 +161,7 @@ TensorStorage pack(Datatype componentType, } void* vals = malloc(maxSize * componentType.getNumBytes()); - const void* fillData = storage.getFill().getData(); + const void* fillData = storage.getFillValue().defined()? storage.getFillValue().getValPtr() : nullptr; int actual_size = packTensor(dimensions, coordinates, (char *) values, fillData, 0, numCoordinates, format.getModeFormats(), 0, &indices, (char *)vals, componentType, 0); diff --git a/src/storage/storage.cpp b/src/storage/storage.cpp index afceab2bd..9227420b2 100644 --- a/src/storage/storage.cpp +++ b/src/storage/storage.cpp @@ -27,8 +27,7 @@ struct TensorStorage::Content { Index index; Array values; - // Always an array of size 1 - Array fillValue; + Literal fillValue; Content(Datatype componentType, vector dimensions, Format format, Literal fill) : componentType(componentType), dimensions(dimensions), format(format), @@ -58,14 +57,10 @@ struct TensorStorage::Content { } } - void* fillData = malloc(componentType.getNumBytes()); - memcpy(fillData, fill.getValPtr(), componentType.getNumBytes()); - - fillValue = Array(componentType, fillData, 1, Array::Policy::Free); - + fillValue = fill; tensorData = init_taco_tensor_t(order, componentType.getNumBits(), dimensionsInt32.data(), modeOrdering.data(), - modeTypes.data()); + modeTypes.data(), fill.getValPtr()); } ~Content() { @@ -106,15 +101,14 @@ const Array& TensorStorage::getValues() const { return content->values; } -const Array& TensorStorage::getFill() const { - taco_iassert(content->fillValue.getSize() == 1); - return content->fillValue; -} - Array TensorStorage::getValues() { return content->values; } +Literal TensorStorage::getFillValue() { + return content->fillValue; +} + size_t TensorStorage::getSizeInBytes() { size_t indexSizeInBytes = 0; const auto& index = getIndex(); @@ -177,7 +171,7 @@ TensorStorage::operator struct taco_tensor_t*() const { } tensorData->vals = (uint8_t*)getValues().getData(); - tensorData->fill_value = (uint8_t*)getFill().getData(); + tensorData->fill_value = (uint8_t*) content->fillValue.getValPtr(); return content->tensorData; } @@ -190,11 +184,6 @@ void TensorStorage::setValues(const Array& values) { content->values = values; } -void TensorStorage::setFill(const Array &fill) { - taco_iassert(fill.getSize() == 1); - content->fillValue = fill; -} - bool equals(TensorStorage a, TensorStorage b) { return false; } diff --git a/src/taco_tensor_t.cpp b/src/taco_tensor_t.cpp index e37ea5b13..2337be2a9 100644 --- a/src/taco_tensor_t.cpp +++ b/src/taco_tensor_t.cpp @@ -26,9 +26,12 @@ void free_mem(void *ptr) { } } +// Note about fill: +// It is planned to allow the fill of a result to be null and for TACO to set this when it does compute. This is the +// case we currently expect the fill pointer to be null. taco_tensor_t* init_taco_tensor_t(int32_t order, int32_t csize, int32_t* dimensions, int32_t* modeOrdering, - taco_mode_t* mode_types) { + taco_mode_t* mode_types, void* fill_ptr) { taco_tensor_t* t = (taco_tensor_t *) alloc_mem(sizeof(taco_tensor_t)); t->order = order; t->dimensions = (int32_t *) alloc_mem(order * sizeof(int32_t)); @@ -37,6 +40,16 @@ taco_tensor_t* init_taco_tensor_t(int32_t order, int32_t csize, t->indices = (uint8_t ***) alloc_mem(order * sizeof(uint8_t***)); t->csize = csize; + int fill_bytes = csize / 8; + t->fill_value = (uint8_t*) alloc_mem(fill_bytes); + + if (fill_ptr) { + uint8_t* fill_inp = (uint8_t*) fill_ptr; + for (int i = 0; i < fill_bytes; ++i) { + t->fill_value[i] = fill_inp[i]; + } + } + for (int32_t i = 0; i < order; i++) { t->dimensions[i] = dimensions[i]; t->mode_ordering[i] = modeOrdering[i]; diff --git a/src/tensor.cpp b/src/tensor.cpp index 1450c90bb..c12e209c8 100644 --- a/src/tensor.cpp +++ b/src/tensor.cpp @@ -372,15 +372,16 @@ void TensorBase::pack() { std::vector bufferDim = {1}; std::vector bufferModeOrdering = {0}; std::vector bufferCoords(numCoordinates, 0); + + void* fillPtr = getStorage().getFillValue().defined()? getStorage().getFillValue().getValPtr() : nullptr; taco_tensor_t* bufferStorage = init_taco_tensor_t(1, csize, (int32_t*)bufferDim.data(), (int32_t*)bufferModeOrdering.data(), - (taco_mode_t*)bufferModeType.data()); + (taco_mode_t*)bufferModeType.data(), fillPtr); std::vector pos = {0, (int)numCoordinates}; bufferStorage->indices[0][0] = (uint8_t*)pos.data(); bufferStorage->indices[0][1] = (uint8_t*)bufferCoords.data(); bufferStorage->vals = (uint8_t*)content->coordinateBuffer->data(); - bufferStorage->fill_value = (uint8_t*)(getStorage().getFill().getData()); std::vector arguments = {content->storage, bufferStorage}; helperFuncs->callFuncPacked("pack", arguments.data()); @@ -440,18 +441,17 @@ void TensorBase::pack() { content->coordinateBuffer->clear(); content->coordinateBufferUsed = 0; - + void* fillPtr = getStorage().getFillValue().defined()? getStorage().getFillValue().getValPtr() : nullptr; std::vector bufferModeTypes(order, taco_mode_sparse); taco_tensor_t* bufferStorage = init_taco_tensor_t(order, csize, (int32_t*)dimensions.data(), (int32_t*)permutation.data(), - (taco_mode_t*)bufferModeTypes.data()); + (taco_mode_t*)bufferModeTypes.data(), fillPtr); std::vector pos = {0, (int)numCoordinates}; bufferStorage->indices[0][0] = (uint8_t*)pos.data(); for (int i = 0; i < order; ++i) { bufferStorage->indices[i][1] = (uint8_t*)coordinates[i].data(); } bufferStorage->vals = (uint8_t*)values; - bufferStorage->fill_value = (uint8_t*)(getStorage().getFill().getData()); // Pack nonzero components into required format std::vector arguments = {content->storage, bufferStorage}; From 05d3c93170c1f8fd6a68a6f5ac9b3d4cecb3f6a5 Mon Sep 17 00:00:00 2001 From: Rawn Date: Sun, 7 Feb 2021 00:35:35 -0800 Subject: [PATCH 25/27] Fixes lattice construction when scheduled index variables are used in an expression. Also, prevents generation of unnecessary merge loops being created at some loop levels --- src/lower/merge_lattice.cpp | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/src/lower/merge_lattice.cpp b/src/lower/merge_lattice.cpp index 1ea4ee0e7..969bdd6fa 100644 --- a/src/lower/merge_lattice.cpp +++ b/src/lower/merge_lattice.cpp @@ -155,8 +155,32 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationA } void visit(const IndexVarNode* varNode) { + // There are a few cases here... + + // 1) If var in the expression is the same as the var being lowered, we need to return a lattice + // with one point that iterates over the universe of the current dimension. + // Why: TACO needs to know if it needs to generate merge loops to deal with computing the current index var. + // Eg. b(i) + i where b is sparse. To preserve semantics, we need to merge the sparse iteration set of b + // with the implied dense space of i. + // Question: What if the user WANTS i to be 'sparse'? Just define a func where + is an intersection =) + // 2) The vars differ. This case actually has 2 subcases... + // a) The loop variable ('i' in this builder) is derived from the variable used in the expression + // ('var' defined below). In this case, return a mode iterator over the derived var ('i' in the builder) + // so taco can generate the correct merge loops for this level. + // b) The loop variable is not derived from the variable used in the expression. In this case, we just return + // an empty lattice as there is nothing that needs to be merged =) + // TODO: Add these cases to the test suite.... IndexVar var(varNode); - lattice = MergeLattice({MergePoint({iterators.modeIterator(var)}, {}, {})}); + taco_iassert(provGraph.isUnderived(var)); + if (var == i) { + lattice = MergeLattice({MergePoint({Iterator(var)}, {}, {})}); + } else { + if (provGraph.isDerivedFrom(i, var)) { + lattice = MergeLattice({MergePoint({iterators.modeIterator(i)}, {}, {})}); + } else { + lattice = MergeLattice({}); + } + } } void visit(const AccessNode* access) From a76408ac458c9acb5bdb14446ff30b73d0d18ee9 Mon Sep 17 00:00:00 2001 From: Rawn Date: Wed, 17 Feb 2021 13:53:04 -0800 Subject: [PATCH 26/27] Make IndexVars inherit from IndexExpr again --- include/taco/index_notation/index_notation.h | 37 ++++++-------------- include/taco/lower/lowerer_impl.h | 2 +- src/lower/lowerer_impl.cpp | 4 +-- 3 files changed, 14 insertions(+), 29 deletions(-) diff --git a/include/taco/index_notation/index_notation.h b/include/taco/index_notation/index_notation.h index 27c082a86..76822bc09 100644 --- a/include/taco/index_notation/index_notation.h +++ b/include/taco/index_notation/index_notation.h @@ -463,31 +463,6 @@ class CallIntrinsic : public IndexExpr { typedef CallIntrinsicNode Node; }; - -/// Index variables are used to index into tensors in index expressions, and -/// they represent iteration over the tensor modes they index into. -/// Index variables can also be used in computation -class IndexVar : public IndexExpr { -public: - IndexVar(); - IndexVar(const std::string& name); - IndexVar(const std::string& name, const Datatype& type); - IndexVar(const IndexVarNode *); - - /// Returns the name of the index variable. - std::string getName() const; - - // Need these to overshadow the comparisons in for the IndexExpr instrusive pointer - friend bool operator==(const IndexVar&, const IndexVar&); - friend bool operator<(const IndexVar&, const IndexVar&); - friend bool operator!=(const IndexVar&, const IndexVar&); - friend bool operator>=(const IndexVar&, const IndexVar&); - friend bool operator<=(const IndexVar&, const IndexVar&); - friend bool operator>(const IndexVar&, const IndexVar&); - - typedef IndexVarNode Node; -}; - std::ostream& operator<<(std::ostream&, const IndexVar&); /// Create calls to various intrinsics. @@ -928,17 +903,27 @@ class WindowedIndexVar : public util::Comparable, public Index /// Index variables are used to index into tensors in index expressions, and /// they represent iteration over the tensor modes they index into. -class IndexVar : public util::Comparable, public IndexVarInterface { +class IndexVar : public IndexExpr, public IndexVarInterface { + public: IndexVar(); ~IndexVar() = default; IndexVar(const std::string& name); + IndexVar(const std::string& name, const Datatype& type); + IndexVar(const IndexVarNode *); /// Returns the name of the index variable. std::string getName() const; + // Need these to overshadow the comparisons in for the IndexExpr instrusive pointer friend bool operator==(const IndexVar&, const IndexVar&); friend bool operator<(const IndexVar&, const IndexVar&); + friend bool operator!=(const IndexVar&, const IndexVar&); + friend bool operator>=(const IndexVar&, const IndexVar&); + friend bool operator<=(const IndexVar&, const IndexVar&); + friend bool operator>(const IndexVar&, const IndexVar&); + + typedef IndexVarNode Node; /// Indexing into an IndexVar returns a window into it. WindowedIndexVar operator()(int lo, int hi); diff --git a/include/taco/lower/lowerer_impl.h b/include/taco/lower/lowerer_impl.h index a31aff58a..e6d5b939e 100644 --- a/include/taco/lower/lowerer_impl.h +++ b/include/taco/lower/lowerer_impl.h @@ -436,7 +436,7 @@ class LowererImpl : public util::Uncopyable { /// Emit early exit ir::Stmt emitEarlyExit(ir::Expr reductionExpr, std::vector&); -======= + /// Expression that returns the beginning of a window to iterate over /// in a compressed iterator. It is used when operating over windows of /// tensors, instead of the full tensor. diff --git a/src/lower/lowerer_impl.cpp b/src/lower/lowerer_impl.cpp index 6a4512c97..6fc32467d 100644 --- a/src/lower/lowerer_impl.cpp +++ b/src/lower/lowerer_impl.cpp @@ -3096,7 +3096,7 @@ Expr LowererImpl::searchForStartOfWindowPosition(Iterator iterator, ir::Expr sta // for the beginning of the window. iterator.getWindowLowerBound(), }; - return Call::make("taco_binarySearchAfter", args, Datatype::UInt64); + return ir::Call::make("taco_binarySearchAfter", args, Datatype::UInt64); } Expr LowererImpl::searchForEndOfWindowPosition(Iterator iterator, ir::Expr start, ir::Expr end) { @@ -3109,7 +3109,7 @@ Expr LowererImpl::searchForEndOfWindowPosition(Iterator iterator, ir::Expr start // for the end of the window. iterator.getWindowUpperBound(), }; - return Call::make("taco_binarySearchAfter", args, Datatype::UInt64); + return ir::Call::make("taco_binarySearchAfter", args, Datatype::UInt64); } Stmt LowererImpl::upperBoundGuardForWindowPosition(Iterator iterator, ir::Expr access) { From 4e27678ace702c9305b52e317e46ccd5b2e41ae9 Mon Sep 17 00:00:00 2001 From: Rohan Yadav Date: Wed, 17 Feb 2021 16:19:17 -0800 Subject: [PATCH 27/27] lower: fix a bug introduced by merging windowing and array algebra This commit fixes a bug caused by merging together the windowing and array algebra work. In particular, the newly introduced deep equality defined on `Access` types did not include additions for the newly added components used in windowing. --- include/taco/index_notation/index_notation.h | 5 +++ .../index_notation/index_notation_nodes.h | 6 ++++ src/index_notation/index_notation.cpp | 32 +++++++++++++++++ src/lower/mode_access.cpp | 35 +++---------------- 4 files changed, 48 insertions(+), 30 deletions(-) diff --git a/include/taco/index_notation/index_notation.h b/include/taco/index_notation/index_notation.h index 76822bc09..640d87dd1 100644 --- a/include/taco/index_notation/index_notation.h +++ b/include/taco/index_notation/index_notation.h @@ -269,6 +269,11 @@ class Access : public IndexExpr { Assignment operator+=(const IndexExpr&); typedef AccessNode Node; + + // Equality and comparison are overridden on Access to perform a deep + // comparison of the access rather than a pointer check. + friend bool operator==(const Access& a, const Access& b); + friend bool operator<(const Access& a, const Access &b); }; diff --git a/include/taco/index_notation/index_notation_nodes.h b/include/taco/index_notation/index_notation_nodes.h index 3df5a1d32..d3284f0ab 100644 --- a/include/taco/index_notation/index_notation_nodes.h +++ b/include/taco/index_notation/index_notation_nodes.h @@ -28,6 +28,12 @@ struct AccessWindow { friend bool operator==(const AccessWindow& a, const AccessWindow& b) { return a.lo == b.lo && a.hi == b.hi; } + friend bool operator<(const AccessWindow& a, const AccessWindow& b) { + if (a.lo != b.lo) { + return a.lo < b.lo; + } + return a.hi < b.hi; + } }; struct AccessNode : public IndexExprNode { diff --git a/src/index_notation/index_notation.cpp b/src/index_notation/index_notation.cpp index efb330668..6b3fd71b4 100644 --- a/src/index_notation/index_notation.cpp +++ b/src/index_notation/index_notation.cpp @@ -999,6 +999,38 @@ int Access::getWindowUpperBound(int mode) const { return getNode(*this)->windowedModes.at(mode).hi; } +bool operator==(const Access& a, const Access& b) { + // Short-circuit for when the Access pointers are the same. + if (getNode(a) == getNode(b)) { + return true; + } + if (a.getTensorVar() != b.getTensorVar()) { + return false; + } + if (a.getIndexVars() != b.getIndexVars()) { + return false; + } + if (getNode(a)->windowedModes != getNode(b)->windowedModes) { + return false; + } + return true; +} + +bool operator<(const Access& a, const Access& b) { + // First branch on tensorVar. + if(a.getTensorVar() != b.getTensorVar()) { + return a.getTensorVar() < b.getTensorVar(); + } + + // Then branch on the indexVars used in the access. + if (a.getIndexVars() != b.getIndexVars()) { + return a.getIndexVars() < b.getIndexVars(); + } + + // Lastly, branch on the windows. + return getNode(a)->windowedModes < getNode(b)->windowedModes; +} + static void check(Assignment assignment) { auto lhs = assignment.getLhs(); auto tensorVar = lhs.getTensorVar(); diff --git a/src/lower/mode_access.cpp b/src/lower/mode_access.cpp index fcb0d19c4..f682e915e 100644 --- a/src/lower/mode_access.cpp +++ b/src/lower/mode_access.cpp @@ -13,43 +13,18 @@ size_t ModeAccess::getModePos() const { return mode; } -static bool accessEqual(const Access& a, const Access& b) { - return a == b || - (a.getTensorVar() == b.getTensorVar() && a.getIndexVars() == b.getIndexVars()); -} - bool operator==(const ModeAccess& a, const ModeAccess& b) { - return accessEqual(a.getAccess(), b.getAccess()) && a.getModePos() == b.getModePos(); + return a.getAccess() == b.getAccess() && a.getModePos() == b.getModePos(); } bool operator<(const ModeAccess& a, const ModeAccess& b) { - - // fast path for when access pointers are equal - if(a.getAccess() == b.getAccess()) { + // First break on the mode position. + if (a.getModePos() != b.getModePos()) { return a.getModePos() < b.getModePos(); } - // First break on tensorVars - if(a.getAccess().getTensorVar() != b.getAccess().getTensorVar()) { - return a.getAccess().getTensorVar() < b.getAccess().getTensorVar(); - } - - // Then break on the indexVars used in the access - std::vector aVars = a.getAccess().getIndexVars(); - std::vector bVars = b.getAccess().getIndexVars(); - - if(aVars.size() != bVars.size()) { - return aVars.size() < bVars.size(); - } - - for(size_t i = 0; i < aVars.size(); ++i) { - if(aVars[i] != bVars[i]) { - return aVars[i] < bVars[i]; - } - } - - // Finally, break on the mode position - return a.getModePos() < b.getModePos(); + // Then, return a deep comparison of the underlying access. + return a.getAccess()