diff --git a/include/taco.h b/include/taco.h index 016960ffc..9293b3414 100644 --- a/include/taco.h +++ b/include/taco.h @@ -3,6 +3,7 @@ #include "taco/tensor.h" #include "taco/format.h" +#include "taco/index_notation/tensor_operator.h" #include "taco/index_notation/index_notation.h" #endif diff --git a/include/taco/index_notation/index_notation.h b/include/taco/index_notation/index_notation.h index 0646398b0..640d87dd1 100644 --- a/include/taco/index_notation/index_notation.h +++ b/include/taco/index_notation/index_notation.h @@ -9,7 +9,9 @@ #include #include #include +#include +#include "taco/util/name_generator.h" #include "taco/format.h" #include "taco/error.h" #include "taco/util/intrusive_ptr.h" @@ -22,6 +24,7 @@ #include "taco/ir_tags.h" #include "taco/lower/iterator.h" #include "taco/index_notation/provenance_graph.h" +#include "taco/index_notation/properties.h" namespace taco { @@ -38,6 +41,8 @@ class IndexExpr; class Assignment; class Access; +class IterationAlgebra; + struct AccessNode; struct AccessWindow; struct LiteralNode; @@ -48,8 +53,10 @@ struct SubNode; struct MulNode; struct DivNode; struct CastNode; +struct CallNode; struct CallIntrinsicNode; struct ReductionNode; +struct IndexVarNode; struct AssignmentNode; struct YieldNode; @@ -262,6 +269,11 @@ class Access : public IndexExpr { Assignment operator+=(const IndexExpr&); typedef AccessNode Node; + + // Equality and comparison are overridden on Access to perform a deep + // comparison of the access rather than a pointer check. + friend bool operator==(const Access& a, const Access& b); + friend bool operator<(const Access& a, const Access &b); }; @@ -289,11 +301,14 @@ class Literal : public IndexExpr { Literal(std::complex); Literal(std::complex); - static IndexExpr zero(Datatype); + static Literal zero(Datatype); /// Returns the literal value. template T getVal() const; + /// Returns an untyped pointer to the literal value + void* getValPtr(); + typedef LiteralNode Node; }; @@ -413,6 +428,26 @@ class Cast : public IndexExpr { typedef CastNode Node; }; +/// A call to an operator +class Call: public IndexExpr { +public: + Call() = default; + Call(const CallNode*); + Call(const CallNode*, std::string name); + + const std::vector& getArgs() const; + const std::function&)> getFunc() const; + const IterationAlgebra& getAlgebra() const; + const std::vector& getProperties() const; + const std::string getName() const; + const std::map, std::function&)>> getDefs() const; + const std::vector& getDefinedArgs() const; + + typedef CallNode Node; + +private: + std::string name; +}; /// A call to an intrinsic. /// ``` @@ -433,6 +468,8 @@ class CallIntrinsic : public IndexExpr { typedef CallIntrinsicNode Node; }; +std::ostream& operator<<(std::ostream&, const IndexVar&); + /// Create calls to various intrinsics. IndexExpr mod(IndexExpr, IndexExpr); IndexExpr abs(IndexExpr); @@ -871,17 +908,27 @@ class WindowedIndexVar : public util::Comparable, public Index /// Index variables are used to index into tensors in index expressions, and /// they represent iteration over the tensor modes they index into. -class IndexVar : public util::Comparable, public IndexVarInterface { +class IndexVar : public IndexExpr, public IndexVarInterface { + public: IndexVar(); ~IndexVar() = default; IndexVar(const std::string& name); + IndexVar(const std::string& name, const Datatype& type); + IndexVar(const IndexVarNode *); /// Returns the name of the index variable. 
std::string getName() const; + // Need these to overshadow the comparisons in for the IndexExpr instrusive pointer friend bool operator==(const IndexVar&, const IndexVar&); friend bool operator<(const IndexVar&, const IndexVar&); + friend bool operator!=(const IndexVar&, const IndexVar&); + friend bool operator>=(const IndexVar&, const IndexVar&); + friend bool operator<=(const IndexVar&, const IndexVar&); + friend bool operator>(const IndexVar&, const IndexVar&); + + typedef IndexVarNode Node; /// Indexing into an IndexVar returns a window into it. WindowedIndexVar operator()(int lo, int hi); @@ -927,11 +974,12 @@ SuchThat suchthat(IndexStmt stmt, std::vector predicate); class TensorVar : public util::Comparable { public: TensorVar(); - TensorVar(const Type& type); - TensorVar(const std::string& name, const Type& type); - TensorVar(const Type& type, const Format& format); - TensorVar(const std::string& name, const Type& type, const Format& format); - TensorVar(const int &id, const std::string& name, const Type& type, const Format& format); + TensorVar(const Type& type, const Literal& fill = Literal()); + TensorVar(const std::string& name, const Type& type, const Literal& fill = Literal()); + TensorVar(const Type& type, const Format& format, const Literal& fill = Literal()); + TensorVar(const std::string& name, const Type& type, const Format& format, const Literal& fill = Literal()); + TensorVar(const int &id, const std::string& name, const Type& type, const Format& format, + const Literal& fill = Literal()); /// Returns the ID of the tensor variable. int getId() const; @@ -952,6 +1000,12 @@ class TensorVar : public util::Comparable { /// and execute it's expression. const Schedule& getSchedule() const; + /// Gets the fill value of the tensor variable. May be left undefined. + const Literal& getFill() const; + + /// Set the fill value of the tensor variable + void setFill(const Literal& fill); + /// Set the name of the tensor variable. void setName(std::string name); @@ -1008,7 +1062,8 @@ bool isEinsumNotation(IndexStmt, std::string* reason=nullptr); bool isReductionNotation(IndexStmt, std::string* reason=nullptr); /// Check whether the statement is in the concrete index notation dialect. -/// This means every index variable has a forall node, there are no reduction +/// This means every index variable has a forall node, each index variable used +/// for computation is under a forall node for that variable, there are no reduction /// nodes, and that every reduction variable use is nested inside a compound /// assignment statement. You can optionally pass in a pointer to a string /// that the reason why it is not concrete notation is printed to. @@ -1030,7 +1085,12 @@ std::vector getResults(IndexStmt stmt); /// Returns the input tensors to the index statement, in the order they appear. std::vector getArguments(IndexStmt stmt); -/// Returns the temporaries in the index statement, in the order they appear. +/// Returns true iff all of the loops over free variables come before all of the loops over +/// reduction variables. Therefore, this returns true if the reduction controlled by the loops +/// does not a scatter. +bool allForFreeLoopsBeforeAllReductionLoops(IndexStmt stmt); + + /// Returns the temporaries in the index statement, in the order they appear. std::vector getTemporaries(IndexStmt stmt); // [Olivia] @@ -1070,7 +1130,15 @@ IndexExpr zero(IndexExpr, const std::set& zeroed); /// zero and then propagating and removing zeroes. 
IndexStmt zero(IndexStmt, const std::set& zeroed); -/// Create an `other` tensor with the given name and format, +/// Infers the fill value of the input expression by applying properties if possible. If unable +/// to successfully infer the fill value of the result, returns the empty IndexExpr +IndexExpr inferFill(IndexExpr); + +/// Returns true if there are no forall nodes in the indexStmt. Used to check +/// if the last loop is being lowered. +bool hasNoForAlls(IndexStmt); + +/// Create an `other` tensor with the given name and format, /// and return tensor(indexVars) = other(indexVars) if otherIsOnRight, /// and otherwise returns other(indexVars) = tensor(indexVars). IndexStmt generatePackStmt(TensorVar tensor, diff --git a/include/taco/index_notation/index_notation_nodes.h b/include/taco/index_notation/index_notation_nodes.h index d18ee4a3d..d3284f0ab 100644 --- a/include/taco/index_notation/index_notation_nodes.h +++ b/include/taco/index_notation/index_notation_nodes.h @@ -3,13 +3,18 @@ #include #include +#include #include "taco/type.h" +#include "taco/util/collections.h" +#include "taco/util/comparable.h" #include "taco/index_notation/index_notation.h" #include "taco/index_notation/index_notation_nodes_abstract.h" #include "taco/index_notation/index_notation_visitor.h" #include "taco/index_notation/intrinsic.h" #include "taco/util/strings.h" +#include "iteration_algebra.h" +#include "properties.h" namespace taco { @@ -23,6 +28,12 @@ struct AccessWindow { friend bool operator==(const AccessWindow& a, const AccessWindow& b) { return a.lo == b.lo && a.hi == b.hi; } + friend bool operator<(const AccessWindow& a, const AccessWindow& b) { + if (a.lo != b.lo) { + return a.lo < b.lo; + } + return a.hi < b.hi; + } }; struct AccessNode : public IndexExprNode { @@ -68,7 +79,6 @@ struct LiteralNode : public IndexExprNode { void* val; }; - struct UnaryExprNode : public IndexExprNode { IndexExpr a; @@ -188,6 +198,57 @@ struct CallIntrinsicNode : public IndexExprNode { std::vector args; }; +struct CallNode : public IndexExprNode { + typedef std::function&)> OpImpl; + typedef std::function&)> AlgebraImpl; + + CallNode(std::string name, const std::vector& args, OpImpl lowerFunc, + const IterationAlgebra& iterAlg, + const std::vector& properties, + const std::map, OpImpl>& regionDefinitions, + const std::vector& definedRegions); + + CallNode(std::string name, const std::vector& args, OpImpl lowerFunc, + const IterationAlgebra& iterAlg, + const std::vector& properties, + const std::map, OpImpl>& regionDefinitions); + + void accept(IndexExprVisitorStrict* v) const { + v->visit(this); + } + + std::string name; + std::vector args; + OpImpl defaultLowerFunc; + IterationAlgebra iterAlg; + std::vector properties; + std::map, OpImpl> regionDefinitions; + + // Needed to track which inputs have been exhausted so the lowerer can know which lower func to use + std::vector definedRegions; + +private: + static Datatype inferReturnType(OpImpl f, const std::vector& inputs) { + std::function getExprs = [](IndexExpr arg) { return ir::Var::make("t", arg.getDataType()); }; + std::vector exprs = util::map(inputs, getExprs); + + if(exprs.empty()) { + return taco::Datatype(); + } + + return f(exprs).type(); + } + + static std::vector definedIndices(std::vector args) { + std::vector v; + for(int i = 0; i < (int) args.size(); ++i) { + if(args[i].defined()) { + v.push_back(i); + } + } + return v; + } +}; struct ReductionNode : public IndexExprNode { ReductionNode(IndexExpr op, IndexVar var, IndexExpr a); @@ -202,6 +263,27 @@ 
struct ReductionNode : public IndexExprNode { IndexExpr a; }; +struct IndexVarNode : public IndexExprNode, public util::Comparable { + IndexVarNode() = delete; + IndexVarNode(const std::string& name, const Datatype& type); + + void accept(IndexExprVisitorStrict* v) const { + v->visit(this); + } + + std::string getName() const; + + friend bool operator==(const IndexVarNode& a, const IndexVarNode& b); + friend bool operator<(const IndexVarNode& a, const IndexVarNode& b); + +private: + struct Content; + std::shared_ptr content; +}; + +struct IndexVarNode::Content { + std::string name; +}; // Index Statements struct AssignmentNode : public IndexStmtNode { diff --git a/include/taco/index_notation/index_notation_printer.h b/include/taco/index_notation/index_notation_printer.h index 61b080fc0..7c32f25b1 100644 --- a/include/taco/index_notation/index_notation_printer.h +++ b/include/taco/index_notation/index_notation_printer.h @@ -25,8 +25,10 @@ class IndexNotationPrinter : public IndexNotationVisitorStrict { void visit(const MulNode*); void visit(const DivNode*); void visit(const CastNode*); + void visit(const CallNode*); void visit(const CallIntrinsicNode*); void visit(const ReductionNode*); + void visit(const IndexVarNode*); // Tensor Expressions void visit(const AssignmentNode*); diff --git a/include/taco/index_notation/index_notation_rewriter.h b/include/taco/index_notation/index_notation_rewriter.h index 3551aac5e..caaa64773 100644 --- a/include/taco/index_notation/index_notation_rewriter.h +++ b/include/taco/index_notation/index_notation_rewriter.h @@ -32,8 +32,10 @@ class IndexExprRewriterStrict : public IndexExprVisitorStrict { virtual void visit(const MulNode* op) = 0; virtual void visit(const DivNode* op) = 0; virtual void visit(const CastNode* op) = 0; + virtual void visit(const CallNode* op) = 0; virtual void visit(const CallIntrinsicNode* op) = 0; virtual void visit(const ReductionNode* op) = 0; + virtual void visit(const IndexVarNode* op) = 0; }; @@ -93,8 +95,10 @@ class IndexNotationRewriter : public IndexNotationRewriterStrict { virtual void visit(const MulNode* op); virtual void visit(const DivNode* op); virtual void visit(const CastNode* op); + virtual void visit(const CallNode* op); virtual void visit(const CallIntrinsicNode* op); virtual void visit(const ReductionNode* op); + virtual void visit(const IndexVarNode* op); virtual void visit(const AssignmentNode* op); virtual void visit(const YieldNode* op); diff --git a/include/taco/index_notation/index_notation_visitor.h b/include/taco/index_notation/index_notation_visitor.h index 97a70adc2..adc0e6787 100644 --- a/include/taco/index_notation/index_notation_visitor.h +++ b/include/taco/index_notation/index_notation_visitor.h @@ -20,10 +20,12 @@ struct MulNode; struct DivNode; struct SqrtNode; struct CastNode; +struct CallNode; struct CallIntrinsicNode; struct UnaryExprNode; struct BinaryExprNode; struct ReductionNode; +struct IndexVarNode; struct AssignmentNode; struct YieldNode; @@ -50,8 +52,10 @@ class IndexExprVisitorStrict { virtual void visit(const DivNode*) = 0; virtual void visit(const SqrtNode*) = 0; virtual void visit(const CastNode*) = 0; + virtual void visit(const CallNode*) = 0; virtual void visit(const CallIntrinsicNode*) = 0; virtual void visit(const ReductionNode*) = 0; + virtual void visit(const IndexVarNode*) = 0; }; class IndexStmtVisitorStrict { @@ -96,10 +100,12 @@ class IndexNotationVisitor : public IndexNotationVisitorStrict { virtual void visit(const DivNode* node); virtual void visit(const SqrtNode* node); 
virtual void visit(const CastNode* node); + virtual void visit(const CallNode* node); virtual void visit(const CallIntrinsicNode* node); virtual void visit(const UnaryExprNode* node); virtual void visit(const BinaryExprNode* node); virtual void visit(const ReductionNode* node); + virtual void visit(const IndexVarNode* node); // Index Statments virtual void visit(const AssignmentNode* node); @@ -164,11 +170,13 @@ class Matcher : public IndexNotationVisitor { RULE(MulNode) RULE(DivNode) RULE(CastNode) + RULE(CallNode) RULE(CallIntrinsicNode) RULE(ReductionNode) RULE(BinaryExprNode) RULE(UnaryExprNode) + RULE(IndexVarNode) RULE(AssignmentNode) RULE(YieldNode) diff --git a/include/taco/index_notation/iteration_algebra.h b/include/taco/index_notation/iteration_algebra.h new file mode 100644 index 000000000..a3b8fa34e --- /dev/null +++ b/include/taco/index_notation/iteration_algebra.h @@ -0,0 +1,243 @@ +#ifndef TACO_ITERATION_ALGEBRA_H +#define TACO_ITERATION_ALGEBRA_H + +#include "taco/index_notation/index_notation.h" +#include "taco/util/uncopyable.h" +#include "taco/util/comparable.h" +#include "taco/util/intrusive_ptr.h" + +namespace taco { + +class IterationAlgebraVisitorStrict; +class IndexExpr; + +struct IterationAlgebraNode; +struct RegionNode; +struct ComplementNode; +struct IntersectNode; +struct UnionNode; + +/// The iteration algebra class describes a set expression composed of complements, intersections and unions on +/// IndexExprs to describe the spaces in a Venn Diagram where computation will occur. +/// This algebra is used to generate merge lattices to co-iterate over tensors in an expression. +class IterationAlgebra : public util::IntrusivePtr { +public: + IterationAlgebra(); + IterationAlgebra(const IterationAlgebraNode* n); + IterationAlgebra(IndexExpr expr); + + void accept(IterationAlgebraVisitorStrict* v) const; +}; + +std::ostream& operator<<(std::ostream&, const IterationAlgebra&); + +/// A region in an Iteration space. Given a Tensor A, this produces values everywhere the tensorVar or access is defined. +class Region: public IterationAlgebra { +public: + Region(); + Region(IndexExpr expr); + Region(const RegionNode*); +}; + +/// This complements an iteration space algebra expression. Thus, it will flip the segments that are produced and +/// omitted in the input segment. +/// Example: Given a segment A which produces values where A is defined and omits values outside of A, +/// complement(A) will not compute where A is defined but compute over the background of A. +class Complement: public IterationAlgebra { +public: + Complement(IterationAlgebra alg); + Complement(const ComplementNode* n); +}; + +/// This intersects two iteration space algebra expressions. This instructs taco to compute over areas where BOTH +/// set expressions produce values and ignore all other segments. +/// +/// Examples +/// +/// Given two tensors A and B: +/// Intersect(A, B) will produce values where both A and B are defined. An example of an operation with this property +/// is multiplication where both tensors are sparse over 0. +/// +/// Intersect(Complement(A), B) will produce values where only B is defined. This pattern can be useful for filtering +/// one tensor based on the values of another. +class Intersect: public IterationAlgebra { +public: + Intersect(IterationAlgebra, IterationAlgebra); + Intersect(const IterationAlgebraNode*); +}; + +/// This takes the union of two iteration space algebra expressions. 
This instructs taco to compute over areas where +/// either set expression produces a value or both set expressions produce a value and ignore all other segments. +/// +/// Examples +/// +/// Given two tensors A and B: +/// Union(A, B) will produce values where either A or B is defined. Addition is an example of a union operator. +/// +/// Union(Complement(A), B) will produce values wherever A is not defined. In the places A is not defined, the compiler +/// will replace the value of A in the indexExpression with the fill value of the tensor A. Likewise, when B is not +/// defined, the compiler will replace the value of B in the index expression with the fill value of B. +class Union: public IterationAlgebra { +public: + Union(IterationAlgebra, IterationAlgebra); + Union(const IterationAlgebraNode*); +}; + +/// A node in the iteration space algebra +struct IterationAlgebraNode: public util::Manageable, + private util::Uncopyable { +public: + IterationAlgebraNode() {} + + virtual ~IterationAlgebraNode() = default; + virtual void accept(IterationAlgebraVisitorStrict*) const = 0; +}; + +/// A binary node in the iteration space algebra. Used for Unions and Intersects +struct BinaryIterationAlgebraNode: public IterationAlgebraNode { + IterationAlgebra a; + IterationAlgebra b; +protected: + BinaryIterationAlgebraNode(IterationAlgebra a, IterationAlgebra b) : IterationAlgebraNode(), a(a), b(b) {} +}; + +/// A node which is wrapped by Region. @see Region +struct RegionNode: public IterationAlgebraNode { +public: + RegionNode() : IterationAlgebraNode() {} + RegionNode(IndexExpr expr) : IterationAlgebraNode(), expr_(expr) {} + void accept(IterationAlgebraVisitorStrict*) const; + const IndexExpr expr() const; +private: + IndexExpr expr_; +}; + +/// A node which is wrapped by Complement. @see Complement +struct ComplementNode: public IterationAlgebraNode { + IterationAlgebra a; +public: + ComplementNode(IterationAlgebra a) : IterationAlgebraNode(), a(a) {} + + void accept(IterationAlgebraVisitorStrict*) const; +}; + +/// A node which is wrapped by Intersect. @see Intersect +struct IntersectNode: public BinaryIterationAlgebraNode { +public: + IntersectNode(IterationAlgebra a, IterationAlgebra b) : BinaryIterationAlgebraNode(a, b) {} + + void accept(IterationAlgebraVisitorStrict*) const; + + const std::string algebraString() const; +}; + +/// A node which is wrapped by Union. 
@see Union +struct UnionNode: public BinaryIterationAlgebraNode { +public: + UnionNode(IterationAlgebra a, IterationAlgebra b) : BinaryIterationAlgebraNode(a, b) {} + + void accept(IterationAlgebraVisitorStrict*) const; + + const std::string algebraString() const; +}; + +/// Visits an iteration space algebra expression +class IterationAlgebraVisitorStrict { +public: + virtual ~IterationAlgebraVisitorStrict() {} + void visit(const IterationAlgebra& alg); + + virtual void visit(const RegionNode*) = 0; + virtual void visit(const ComplementNode*) = 0; + virtual void visit(const IntersectNode*) = 0; + virtual void visit(const UnionNode*) = 0; +}; + +// Default Iteration Algebra visitor +class IterationAlgebraVisitor : public IterationAlgebraVisitorStrict { +public: + virtual ~IterationAlgebraVisitor() {} + using IterationAlgebraVisitorStrict::visit; + + virtual void visit(const RegionNode* n); + virtual void visit(const ComplementNode*); + virtual void visit(const IntersectNode*); + virtual void visit(const UnionNode*); +}; + +/// Rewrites an iteration algebra expression +class IterationAlgebraRewriterStrict : public IterationAlgebraVisitorStrict { +public: + virtual ~IterationAlgebraRewriterStrict() {} + IterationAlgebra rewrite(IterationAlgebra); + +protected: + /// Assign new algebra in visit method to replace the algebra nodes visited + IterationAlgebra alg; + + using IterationAlgebraVisitorStrict::visit; + + virtual void visit(const RegionNode*) = 0; + virtual void visit(const ComplementNode*) = 0; + virtual void visit(const IntersectNode*) = 0; + virtual void visit(const UnionNode*) = 0; +}; + +class IterationAlgebraRewriter : public IterationAlgebraRewriterStrict { +public: + virtual ~IterationAlgebraRewriter() {} + +protected: + using IterationAlgebraRewriterStrict::visit; + + virtual void visit(const RegionNode* n); + virtual void visit(const ComplementNode*); + virtual void visit(const IntersectNode*); + virtual void visit(const UnionNode*); +}; + +/// Returns true if algebra e is of type E. +template +inline bool isa(const IterationAlgebraNode* e) { + return e != nullptr && dynamic_cast(e) != nullptr; +} + +/// Casts the algebraNode e to type E. +template +inline const E* to(const IterationAlgebraNode* e) { + taco_iassert(isa(e)) << + "Cannot convert " << typeid(e).name() << " to " << typeid(E).name(); + return static_cast(e); +} + +/// Return true if the iteration algebra is of the given subtype. The subtypes +/// are Region, Complement, Union and Intersect. +template bool isa(IterationAlgebra); + +/// Casts the iteration algebra to the given subtype. Assumes S is a subtype and +/// the subtypes are Region, Complement, Union and Intersect. +template SubType to(IterationAlgebra); + +/// Returns true if the structure of the iteration algebra is the same. +/// This means that intersections, unions, complements and regions appear +/// in the same places but the IndexExpressions these operations are applied +/// to are not necessarily the same. +bool algStructureEqual(const IterationAlgebra&, const IterationAlgebra&); + +/// Returns true if the iterations algebras passed in have the same structure +/// and the Index Expressions that they operate on are the same. +bool algEqual(const IterationAlgebra&, const IterationAlgebra&); + +/// Applies demorgan's laws to the algebra passed in and returns a new algebra +/// which describes the same space but with complements appearing only around region +/// nodes. 
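// Illustrative sketch (not part of the patch): given the constructors above, applyDemorgan
// is expected to push complements inward so they wrap only Region nodes, e.g. rewriting
// Complement(Intersect(a, b)) into Union(Complement(a), Complement(b)). The tensor and
// index-variable names below are hypothetical.
//
//   TensorVar A("A", Type(type<double>(), {3, 3}), Format({Sparse, Sparse}));
//   TensorVar B("B", Type(type<double>(), {3, 3}), Format({Sparse, Sparse}));
//   IndexVar i, j;
//   IterationAlgebra alg    = Complement(Intersect(Region(A(i, j)), Region(B(i, j))));
//   IterationAlgebra pushed = applyDemorgan(alg);
//   // pushed ~ Union(Complement(Region(A(i, j))), Complement(Region(B(i, j))))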
+IterationAlgebra applyDemorgan(IterationAlgebra alg); + +/// Rewrites the algebra to replace the IndexExprs in the algebra with new index exprs as +/// specified by the input map. If the map does not contain an indexExpr, it is kept the +/// same as the input algebra. +IterationAlgebra replaceAlgIndexExprs(IterationAlgebra alg, const std::map&); +} + + +#endif // TACO_ITERATION_ALGEBRA_H diff --git a/include/taco/index_notation/iteration_algebra_printer.h b/include/taco/index_notation/iteration_algebra_printer.h new file mode 100644 index 000000000..6f97b6b49 --- /dev/null +++ b/include/taco/index_notation/iteration_algebra_printer.h @@ -0,0 +1,34 @@ +#ifndef TACO_ITERATION_ALGEBRA_PRINTER_H +#define TACO_ITERATION_ALGEBRA_PRINTER_H + +#include +#include "taco/index_notation/iteration_algebra.h" + +namespace taco { + +// Iteration Algebra Printer +class IterationAlgebraPrinter : IterationAlgebraVisitorStrict { +public: + IterationAlgebraPrinter(std::ostream& os); + void print(const IterationAlgebra& alg); + void visit(const RegionNode* n); + void visit(const ComplementNode* n); + void visit(const IntersectNode* n); + void visit(const UnionNode* n); + +private: + std::ostream& os; + enum class Precedence { + COMPLEMENT = 3, + INTERSECT = 4, + UNION = 5, + TOP = 20 + }; + + Precedence parentPrecedence; + + template + void visitBinary(Node n, Precedence precedence); +}; +} +#endif //TACO_ITERATION_ALGEBRA_PRINTER_H diff --git a/include/taco/index_notation/properties.h b/include/taco/index_notation/properties.h new file mode 100644 index 000000000..1504dd8d1 --- /dev/null +++ b/include/taco/index_notation/properties.h @@ -0,0 +1,92 @@ +#ifndef TACO_PROPERTIES_H +#define TACO_PROPERTIES_H + +#include "taco/index_notation/property_pointers.h" +#include "taco/util/intrusive_ptr.h" + +namespace taco { + +class IndexExpr; + +/// A class containing properties about an operation +class Property : public util::IntrusivePtr { +public: + Property(); + explicit Property(const PropertyPtr* p); + + bool equals(const Property& p) const; + std::ostream& print(std::ostream&) const; +}; + +std::ostream& operator<<(std::ostream&, const Property&); + +/// A class wrapping the annihilator property pointer +class Annihilator : public Property { +public: + explicit Annihilator(Literal); + Annihilator(Literal, std::vector&); + explicit Annihilator(const PropertyPtr*); + + const Literal& annihilator() const; + const std::vector& positions() const; + IndexExpr annihilates(const std::vector&) const; + + typedef AnnihilatorPtr Ptr; +}; + +/// A class wrapping an identity property pointer +class Identity : public Property { +public: + explicit Identity(Literal); + Identity(Literal, std::vector&); + explicit Identity(const PropertyPtr*); + + const Literal& identity() const; + const std::vector& positions() const; + IndexExpr simplify(const std::vector&) const; + + typedef IdentityPtr Ptr; +}; + +/// A class wrapping an associative property pointer +class Associative : public Property { +public: + Associative(); + explicit Associative(const PropertyPtr*); + + typedef AssociativePtr Ptr; +}; + +/// A class wrapping a commutative property pointer +class Commutative : public Property { +public: + Commutative(); + explicit Commutative(const std::vector&); + explicit Commutative(const PropertyPtr*); + + const std::vector& ordering() const; + + typedef CommutativePtr Ptr; +}; + +/// Returns true if property p is of type P. +template bool isa(const Property& p); + +/// Casts the Property p to type P. 
+template <typename P> P to(const Property& p);
+
+/// Finds and returns the property of type P if it exists in the vector. If
+/// the property does not exist, returns an undefined instance of the property
+/// requested.
+/// The vector of properties should not contain duplicates so this is sufficient.
+template <typename P>
+inline const P findProperty(const std::vector<Property> &properties) {
+  for (const auto &p: properties) {
+    if (isa<P>(p)) return to<P>(p);
+  }
+  return P(nullptr);
+}
+
+}
+
+#endif //TACO_PROPERTIES_H
diff --git a/include/taco/index_notation/property_pointers.h b/include/taco/index_notation/property_pointers.h
new file mode 100644
index 000000000..81bee0492
--- /dev/null
+++ b/include/taco/index_notation/property_pointers.h
@@ -0,0 +1,100 @@
+#ifndef TACO_PROPERTY_POINTERS_H
+#define TACO_PROPERTY_POINTERS_H
+
+#include
+#include
+#include
+#include
+#include
+
+#include "taco/error.h"
+#include "taco/util/comparable.h"
+
+namespace taco {
+
+class Literal;
+struct PropertyPtr;
+
+/// A pointer to the property data. This will be wrapped in an auxiliary class
+/// to allow a user to create a vector of properties. Needed since properties
+/// have different methods and data.
+struct PropertyPtr : public util::Manageable<PropertyPtr>,
+                     private util::Uncopyable {
+public:
+  PropertyPtr();
+  virtual ~PropertyPtr();
+  virtual std::ostream& print(std::ostream& os) const;
+  virtual bool equals(const PropertyPtr* p) const;
+};
+
+/// Pointer class for annihilators
+struct AnnihilatorPtr : public PropertyPtr {
+  AnnihilatorPtr();
+  AnnihilatorPtr(Literal);
+  AnnihilatorPtr(Literal, std::vector<int>&);
+
+  const Literal& annihilator() const;
+  const std::vector<int>& positions() const;
+
+  virtual std::ostream& print(std::ostream& os) const;
+  virtual bool equals(const PropertyPtr* p) const;
+
+  struct Content;
+  std::shared_ptr<Content> content;
+};
+
+/// Pointer class for identities
+struct IdentityPtr : public PropertyPtr {
+public:
+  IdentityPtr();
+  IdentityPtr(Literal);
+  IdentityPtr(Literal, std::vector<int>&);
+
+  const Literal& identity() const;
+  const std::vector<int>& positions() const;
+
+  virtual std::ostream& print(std::ostream& os) const;
+  virtual bool equals(const PropertyPtr* p) const;
+
+  struct Content;
+  std::shared_ptr<Content> content;
+};
+
+/// Pointer class for associativity
+struct AssociativePtr : public PropertyPtr {
+  AssociativePtr();
+  virtual std::ostream& print(std::ostream& os) const;
+  virtual bool equals(const PropertyPtr* p) const;
+};
+
+/// Pointer class for commutativity
+struct CommutativePtr : public PropertyPtr {
+  CommutativePtr();
+  CommutativePtr(const std::vector<int>&);
+  const std::vector<int> ordering_;
+  virtual std::ostream& print(std::ostream& os) const;
+  virtual bool equals(const PropertyPtr* p) const;
+};
+
+template <typename P>
+inline bool isa(const PropertyPtr* p) {
+  return p != nullptr && dynamic_cast<const P*>(p) != nullptr;
+}
+
+template <typename P>
+inline const P* to(const PropertyPtr* p) {
+  taco_iassert(isa<P>(p)) <<
+      "Cannot convert " << typeid(p).name() << " to " << typeid(P).name();
+  return static_cast<const P*>(p);
+}
+
+template <typename P>
+inline const typename P::Ptr* getPtr(const P& propertyPtr) {
+  taco_iassert(isa<typename P::Ptr>(propertyPtr.ptr));
+  return static_cast<const typename P::Ptr*>(propertyPtr.ptr);
+}
+
+
+}
+
+#endif //TACO_PROPERTY_POINTERS_H
diff --git a/include/taco/index_notation/tensor_operator.h b/include/taco/index_notation/tensor_operator.h
new file mode 100644
index 000000000..fc5aa69a9
--- /dev/null
+++ b/include/taco/index_notation/tensor_operator.h
@@ -0,0 +1,85 @@
+#ifndef TACO_OPS_H
+#define TACO_OPS_H
+
+#include
+#include
+#include
+
+#include "taco/ir/ir.h"
+#include "taco/util/collections.h"
+#include "taco/index_notation/properties.h"
+#include "taco/index_notation/index_notation.h"
+#include "taco/index_notation/index_notation_nodes.h"
+
+namespace taco {
+
+class Func {
+
+using FuncBodyGenerator = CallNode::OpImpl;
+using FuncAlgebraGenerator = CallNode::AlgebraImpl;
+
+// TODO: Make this part of CallNode and Call. Add generateIterationAlgebra() and generateImplementation() functions
+public:
+  // Full construction
+  Func(FuncBodyGenerator lowererFunc, FuncAlgebraGenerator algebraFunc, std::vector<Property> properties,
+       std::map<std::vector<int>, FuncBodyGenerator> specialDefinitions = {});
+
+  Func(std::string name, FuncBodyGenerator lowererFunc, FuncAlgebraGenerator algebraFunc, std::vector<Property> properties,
+       std::map<std::vector<int>, FuncBodyGenerator> specialDefinitions = {});
+
+  // Construct without specifying algebra
+  Func(std::string name, FuncBodyGenerator lowererFunc, std::vector<Property> properties,
+       std::map<std::vector<int>, FuncBodyGenerator> specialDefinitions = {});
+
+  Func(FuncBodyGenerator lowererFunc, std::vector<Property> properties,
+       std::map<std::vector<int>, FuncBodyGenerator> specialDefinitions = {});
+
+  // Construct without properties
+  Func(std::string name, FuncBodyGenerator lowererFunc, FuncAlgebraGenerator algebraFunc,
+       std::map<std::vector<int>, FuncBodyGenerator> specialDefinitions = {});
+
+  Func(FuncBodyGenerator lowererFunc, FuncAlgebraGenerator algebraFunc, std::map<std::vector<int>, FuncBodyGenerator> specialDefinitions = {});
+
+  // Construct without algebra or properties
+  Func(std::string name, FuncBodyGenerator lowererFunc, std::map<std::vector<int>, FuncBodyGenerator> specialDefinitions = {});
+
+  explicit Func(FuncBodyGenerator lowererFunc, std::map<std::vector<int>, FuncBodyGenerator> specialDefinitions = {});
+
+  template <typename... IndexExprs>
+  Call operator()(IndexExprs&&... exprs) {
+    std::vector<IndexExpr> actualArgs{exprs...};
+
+    IterationAlgebra nodeAlgebra = algebraFunc == nullptr ? inferAlgFromProperties(actualArgs) : algebraFunc(actualArgs);
+
+    CallNode* op = new CallNode(name, actualArgs, lowererFunc, nodeAlgebra, properties,
+                                regionDefinitions);
+
+    return Call(op);
+  }
+
+private:
+  std::string name;
+  FuncBodyGenerator lowererFunc;
+  FuncAlgebraGenerator algebraFunc;
+  std::vector<Property> properties;
+  std::map<std::vector<int>, FuncBodyGenerator> regionDefinitions;
+
+  IterationAlgebra inferAlgFromProperties(const std::vector<IndexExpr>& exprs);
+
+  // Constructs an algebra for iterating over the operator assuming the annihilator
+  // of the expression is the input to this function.
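// Illustrative sketch (not part of the patch): if the operator's annihilator equals the
// arguments' fill values (e.g. multiplication over tensors filled with 0), the constructed
// algebra is expected to iterate only where every argument is defined, roughly
//   Intersect(Region(args[0]), Region(args[1]));
// whereas the identity-based variant below yields a union, as with addition over a
// fill value of 0:
//   Union(Region(args[0]), Region(args[1]));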
+ IterationAlgebra constructAnnihilatorAlg(const std::vector& args, Annihilator annihilator); + + IterationAlgebra constructIdentityAlg(const std::vector& args, Identity identity); + + + // Constructs an algebra that iterates over the entire space + static IterationAlgebra constructDefaultAlgebra(const std::vector& exprs); +}; + +} +#endif //TACO_OPS_H + + diff --git a/include/taco/ir/ir.h b/include/taco/ir/ir.h index f852f26b1..f38b5e8ac 100644 --- a/include/taco/ir/ir.h +++ b/include/taco/ir/ir.h @@ -78,6 +78,7 @@ enum class TensorProperty { ModeTypes, Indices, Values, + FillValue, ValuesSize }; @@ -235,6 +236,16 @@ struct Literal : public ExprNode { return *static_cast(value.get()); } + Expr promote(Datatype dt) const { + taco_iassert(max_type(dt, type) == dt); + if(type == dt) { + return Literal::make(getTypedVal(), dt); + } + + + + } + TypedComponentVal getTypedVal() const { return *value; } diff --git a/include/taco/lower/iterator.h b/include/taco/lower/iterator.h index 8e054ff02..6d1bca2b0 100644 --- a/include/taco/lower/iterator.h +++ b/include/taco/lower/iterator.h @@ -56,7 +56,6 @@ class Iterator : public util::Comparable { /// Get the child of this iterator in its iterator list. const Iterator getChild() const; - /// Returns true if the iterator iterates over the dimension. bool isDimensionIterator() const; diff --git a/include/taco/lower/lowerer_impl.h b/include/taco/lower/lowerer_impl.h index 52fd91165..e6d5b939e 100644 --- a/include/taco/lower/lowerer_impl.h +++ b/include/taco/lower/lowerer_impl.h @@ -78,6 +78,7 @@ class LowererImpl : public util::Uncopyable { std::vector locaters, std::vector inserters, std::vector appenders, + MergeLattice caseLattice, std::set reducedAccesses, ir::Stmt recoveryStmt); @@ -87,6 +88,7 @@ class LowererImpl : public util::Uncopyable { std::vector locaters, std::vector inserters, std::vector appenders, + MergeLattice caseLattice, std::set reducedAccesses, ir::Stmt recoveryStmt); @@ -97,6 +99,7 @@ class LowererImpl : public util::Uncopyable { std::vector locaters, std::vector inserters, std::vector appenders, + MergeLattice caseLattice, std::set reducedAccesses, ir::Stmt recoveryStmt); @@ -107,6 +110,7 @@ class LowererImpl : public util::Uncopyable { std::vector locaters, std::vector inserters, std::vector appenders, + MergeLattice caseLattice, std::set reducedAccesses, ir::Stmt recoveryStmt); @@ -114,6 +118,7 @@ class LowererImpl : public util::Uncopyable { std::vector locaters, std::vector inserters, std::vector appenders, + MergeLattice caseLattice, std::set reducedAccesses, ir::Stmt recoveryStmt); @@ -177,6 +182,7 @@ class LowererImpl : public util::Uncopyable { std::vector locaters, std::vector inserters, std::vector appenders, + MergeLattice caseLattice, const std::set& reducedAccesses); @@ -222,6 +228,11 @@ class LowererImpl : public util::Uncopyable { /// Lower an intrinsic function call expression. virtual ir::Expr lowerCallIntrinsic(CallIntrinsic call); + /// Lower an IndexVar expression + virtual ir::Expr lowerIndexVar(IndexVar var); + + /// Lower a generic tensor operation expression + virtual ir::Expr lowerTensorOp(Call op); /// Lower a concrete index variable statement. ir::Stmt lower(IndexStmt stmt); @@ -229,7 +240,6 @@ class LowererImpl : public util::Uncopyable { /// Lower a concrete index variable expression. ir::Expr lower(IndexExpr expr); - /// Check whether the lowerer should generate code to assemble result indices. 
bool generateAssembleCode() const; @@ -318,7 +328,7 @@ class LowererImpl : public util::Uncopyable { * Generate code to zero-initialize values array in range * [begin * size, (begin + 1) * size). */ - ir::Stmt zeroInitValues(ir::Expr tensor, ir::Expr begin, ir::Expr size); + ir::Stmt initValues(ir::Expr tensor, ir::Expr initVal, ir::Expr begin, ir::Expr size); /// Declare position variables and initialize them with a locate. ir::Stmt declLocatePosVars(std::vector iterators); @@ -378,6 +388,55 @@ class LowererImpl : public util::Uncopyable { /// Expression that evaluates to true if none of the iterators are exhausted ir::Expr checkThatNoneAreExhausted(std::vector iterators); + /// Lowers a merge lattice to cases assuming there are no more loops to be emitted in stmt. + /// Will emit checks for explicit zeros for each mode iterator and each locator in the lattice. + ir::Stmt lowerMergeCasesWithExplicitZeroChecks(ir::Expr coordinate, IndexVar coordinateVar, IndexStmt stmt, + MergeLattice lattice, const std::set& reducedAccesses); + + /// Constructs cases comparing the coordVar for each iterator to the resolved coordinate. + /// Returns a vector where coordComparisons[i] corresponds to a case for iters[i] + /// If no case can be formed for a given iterator, an undefined expr is appended where a case would normally be. + template + std::vector compareToResolvedCoordinate(const std::vector& iters, ir::Expr resolvedCoordinate, + IndexVar coordinateVar) { + std::vector coordComparisons; + + for (Iterator iterator : iters) { + if (!(provGraph.isCoordVariable(iterator.getIndexVar()) && + provGraph.isDerivedFrom(iterator.getIndexVar(), coordinateVar))) { + coordComparisons.push_back(C::make(iterator.getCoordVar(), resolvedCoordinate)); + } else { + coordComparisons.push_back(ir::Expr()); + } + } + + return coordComparisons; + } + + /// Makes the preamble of booleans used in case checks for the inner most loop of the computations + /// The iterator to condition map contains the name of the boolean indicating if each corresponding mode iterator + /// and each locator is non-zero. This function populates this map so the caller can user the boolean names to emit + /// checks for each lattice point. + std::vector constructInnerLoopCasePreamble(ir::Expr coordinate, IndexVar coordinateVar, + MergeLattice lattice, + std::map& iteratorToConditionMap); + + /// Lowers merge cases in the lattice using a map to know what expr to emit for each iterator in the lattice. + /// The map must be of iterators to exprs of boolean types + std::vector lowerCasesFromMap(std::map iteratorToCondition, + ir::Expr coordinate, IndexStmt stmt, const MergeLattice& lattice, + const std::set& reducedAccesses); + + /// Constructs an expression which checks if this access is "zero" + ir::Expr constructCheckForAccessZero(Access); + + /// Filters out a list of iterators and returns those the lowerer should explicitly check for zeros. + /// For now, we only check mode iterators. + std::vector getModeIterators(const std::vector&); + + /// Emit early exit + ir::Stmt emitEarlyExit(ir::Expr reductionExpr, std::vector&); + /// Expression that returns the beginning of a window to iterate over /// in a compressed iterator. It is used when operating over windows of /// tensors, instead of the full tensor. 
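// Illustrative sketch (not part of the patch): the kind of loop initValues is expected to
// emit, generalizing the old zeroInitValues from 0 to the tensor's fill value. Variable
// names are hypothetical and the actual emitted IR may differ.
//
//   for (int32_t p = begin * size; p < (begin + 1) * size; p++) {
//     a_vals[p] = a_fill_value;
//   }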
@@ -407,6 +466,7 @@ class LowererImpl : public util::Uncopyable { private: bool assemble; bool compute; + bool loopOrderAllowsShortCircuit = false; int markAssignsAtomicDepth = 0; ParallelUnit atomicParallelUnit; diff --git a/include/taco/lower/merge_lattice.h b/include/taco/lower/merge_lattice.h index 9f1592d5f..49f0051bb 100644 --- a/include/taco/lower/merge_lattice.h +++ b/include/taco/lower/merge_lattice.h @@ -49,6 +49,22 @@ class MergeLattice { static MergeLattice make(Forall forall, Iterators iterators, ProvenanceGraph provGraph, std::set definedIndexVars, std::map whereTempsToResult = {}); + + /** + * Removes lattice points whose iterators are identical to the iterators of an earlier point, since we have + * already iterated over this sub-space. + */ + static std::vector removePointsWithIdenticalIterators(const std::vector&); + + /** + * remove lattice points that lack any of the full iterators of the first point, since when a full iterator exhausts + * we have iterated over the whole space. + */ + static std::vector removePointsThatLackFullIterators(const std::vector&); + + /// Returns true if we need to emit checks for explicit zeros in the lattice given. + bool needExplicitZeroChecks(); + /** * Returns the sub-lattice rooted at the given merge point. */ @@ -64,6 +80,11 @@ class MergeLattice { */ const std::vector& iterators() const; + /** + * Retrieve all the locators in this lattice. + */ + const std::vector& locators() const; + /** * Returns iterators that have been exhausted prior to the merge point. */ @@ -80,8 +101,34 @@ class MergeLattice { */ bool exact() const; + /** + * True if any of the mode iterators in the lattice is a leaf iterator. + * + * This method checks both the iterators and locators since they both contain mode iterators. + */ + bool anyModeIteratorIsLeaf() const; + + /** + * Get a list of iterators that should be omitted at this merge point. + */ + std::vector retrieveRegionIteratorsToOmit(const MergePoint& point) const; + + /** + * Returns a set of sets of tensor iterators. A merge point with a tensorRegion in this set should not + * be removed from the lattice. + * + * Needed so that special regions are kept when applying optimizations that remove merge points. + */ + std::set> getTensorRegionsToKeep() const; + + /** + * Removes points from the lattice that would duplicate iteration over the input tensors. + */ + MergeLattice getLoopLattice() const; + private: std::vector points_; + std::set> regionsToKeep; public: /** @@ -89,7 +136,7 @@ class MergeLattice { * is primarily intended for testing purposes and most construction should * happen through `MergeLattice::make`. 
*/ - MergeLattice(std::vector points); + MergeLattice(std::vector points, std::set> regionsToKeep = {}); }; std::ostream& operator<<(std::ostream&, const MergeLattice&); @@ -146,6 +193,16 @@ class MergePoint { */ const std::vector& results() const; + /** + * Returns the iterators that iterate over or locate into tensors + */ + const std::set tensorRegion() const; + + /** + * Returns true if this merge point may leave out the tensors it iterates + */ + bool isOmitter() const; + private: struct Content; std::shared_ptr content_; @@ -158,7 +215,8 @@ class MergePoint { */ MergePoint(const std::vector& iterators, const std::vector& locators, - const std::vector& results); + const std::vector& results, + bool omitPoint = false); }; std::ostream& operator<<(std::ostream&, const MergePoint&); diff --git a/include/taco/storage/pack.h b/include/taco/storage/pack.h index 437a441ee..e39536c56 100644 --- a/include/taco/storage/pack.h +++ b/include/taco/storage/pack.h @@ -17,6 +17,8 @@ namespace taco { +class Literal; + namespace ir { class Stmt; } @@ -25,12 +27,13 @@ TensorStorage pack(Datatype datatype, const std::vector& dimensions, const Format& format, const std::vector& coordinates, - const void* values); - + const void* values, + const Literal& fill); template TensorStorage pack(std::vector dimensions, Format format, - const std::vector,V>>& components){ + const std::vector,V>>& components, + const Literal& fill){ size_t order = dimensions.size(); size_t nnz = components.size(); @@ -45,7 +48,7 @@ TensorStorage pack(std::vector dimensions, Format format, } } - return pack(type(), dimensions, format, coordinates, values.data()); + return pack(type(), dimensions, format, coordinates, values.data(), fill); } } diff --git a/include/taco/storage/storage.h b/include/taco/storage/storage.h index 61116a00d..a3bf987c3 100644 --- a/include/taco/storage/storage.h +++ b/include/taco/storage/storage.h @@ -18,6 +18,7 @@ class Type; class Datatype; class Index; class Array; +class Literal; /// Storage for a tensor object. Tensor storage consists of a value array that /// contains the tensor values and one index per mode. The type of each @@ -28,7 +29,7 @@ class TensorStorage { /// Construct tensor storage for the given format. TensorStorage(Datatype componentType, const std::vector& dimensions, - Format format); + Format format, Literal fillVal); /// Returns the tensor storage format. const Format& getFormat() const; @@ -54,6 +55,9 @@ class TensorStorage { /// Returns the tensor component value array. Array getValues(); + /// Returns the full value attached to the tensor storage + Literal getFillValue(); + /// Returns the size of the storage in bytes. size_t getSizeInBytes(); @@ -66,6 +70,7 @@ class TensorStorage { /// Set the tensor component value array. 
void setValues(const Array& values); + private: struct Content; std::shared_ptr content; diff --git a/include/taco/taco_tensor_t.h b/include/taco/taco_tensor_t.h index 20d78bb51..f1777510d 100644 --- a/include/taco/taco_tensor_t.h +++ b/include/taco/taco_tensor_t.h @@ -18,12 +18,13 @@ typedef struct taco_tensor_t { taco_mode_t* mode_types; // mode storage types uint8_t*** indices; // tensor index data (per mode) uint8_t* vals; // tensor values + uint8_t* fill_value; // tensor fill value int32_t vals_size; // values array size } taco_tensor_t; taco_tensor_t *init_taco_tensor_t(int32_t order, int32_t csize, - int32_t* dimensions, int32_t* modeOrdering, - taco_mode_t* mode_types); + int32_t* dimensions, int32_t* modeOrdering, + taco_mode_t* mode_types, void* fill_ptr); void deinit_taco_tensor_t(taco_tensor_t* t); diff --git a/include/taco/tensor.h b/include/taco/tensor.h index 25186c815..3be1f200d 100644 --- a/include/taco/tensor.h +++ b/include/taco/tensor.h @@ -62,19 +62,27 @@ class TensorBase { /// Create a tensor with the given dimensions. The format defaults to sparse /// in every mode. TensorBase(Datatype ctype, std::vector dimensions, - ModeFormat modeType = ModeFormat::compressed); + ModeFormat modeType = ModeFormat::compressed, Literal fill = Literal()); /// Create a tensor with the given dimensions and format. - TensorBase(Datatype ctype, std::vector dimensions, Format format); + TensorBase(Datatype ctype, std::vector dimensions, Format format, Literal fill = Literal()); /// Create a tensor with the given data type, dimensions and format. The /// format defaults to sparse in every mode. - TensorBase(std::string name, Datatype ctype, std::vector dimensions, - ModeFormat modeType = ModeFormat::compressed); - + TensorBase(std::string name, Datatype ctype, std::vector dimensions, + ModeFormat modeType = ModeFormat::compressed, Literal fill = Literal()); + + /// Create a tensor with the given dimensions and fill value. The format + /// defaults to sparse in every mode. + TensorBase(Datatype ctype, std::vector dimensions, Literal fill); + + /// Create a tensor with the given data type, dimensions and fill value. The + /// format defaults to sparse in every mode. + TensorBase(std::string name, Datatype ctype, std::vector dimensions, Literal fill); + /// Create a tensor with the given data type, dimensions and format. TensorBase(std::string name, Datatype ctype, std::vector dimensions, - Format format); + Format format, Literal fill = Literal()); /* --- Metadata Methods --- */ @@ -465,8 +473,10 @@ class TensorBase { /// Get the taco_tensor_t representation of this tensor. taco_tensor_t* getTacoTensorT(); - /* --- Friend Functions --- */ + /// Get the fill value of this tensor. + Literal getFillValue() const; + /* --- Friend Functions --- */ /// True iff two tensors have the same type and the same values. friend bool equals(const TensorBase&, const TensorBase&); @@ -562,18 +572,29 @@ class Tensor : public TensorBase { /// Create a tensor with the given dimensions. The format defaults to sparse /// in every mode. - Tensor(std::vector dimensions, ModeFormat modeType = ModeFormat::compressed); + Tensor(std::vector dimensions, ModeFormat modeType = ModeFormat::compressed, + CType fill = CType()); /// Create a tensor with the given dimensions and format - Tensor(std::vector dimensions, Format format); + Tensor(std::vector dimensions, Format format, CType fill = CType()); /// Create a tensor with the given name, dimensions and format. The format /// defaults to sparse in every mode. 
- Tensor(std::string name, std::vector dimensions, - ModeFormat modeType = ModeFormat::compressed); + Tensor(std::string name, std::vector dimensions, + ModeFormat modeType = ModeFormat::compressed, + CType fill = CType()); + + /// Create a tensor with the given dimensions and fill value. The format + /// defaults to sparse in every mode. + Tensor(std::vector dimensions, CType fill); + + /// Create a tensor with the given data type, dimensions and fill value. The + /// format defaults to sparse in every mode. + Tensor(std::string name, std::vector dimensions, CType fill); /// Create a tensor with the given name, dimensions and format - Tensor(std::string name, std::vector dimensions, Format format); + Tensor(std::string name, std::vector dimensions, Format format, + CType fill = CType()); /// Create a tensor from a TensorBase instance. The Tensor and TensorBase /// objects will reference the same underlying tensor so it is a shallow copy. @@ -893,10 +914,10 @@ struct TensorBase::Content { unsigned int uniqueId; Content(std::string name, Datatype dataType, const std::vector& dimensions, - Format format) + Format format, Literal fill) : dataType(dataType), dimensions(dimensions), - storage(TensorStorage(dataType, dimensions, format)), - tensorVar(TensorVar(util::getUniqueId(), name, Type(dataType,convert(dimensions)),format)) { + storage(TensorStorage(dataType, dimensions, format, fill)), + tensorVar(TensorVar(util::getUniqueId(), name, Type(dataType,convert(dimensions)),format, fill)) { uniqueId = tensorVar.getId(); } }; @@ -1060,21 +1081,30 @@ template Tensor::Tensor(CType value) : TensorBase(value) {} template -Tensor::Tensor(std::vector dimensions, ModeFormat modeType) - : TensorBase(type(), dimensions, modeType) {} +Tensor::Tensor(std::vector dimensions, ModeFormat modeType, CType fill) + : TensorBase(type(), dimensions, modeType, fill) {} + +template +Tensor::Tensor(std::vector dimensions, Format format, CType fill) + : TensorBase(type(), dimensions, format, fill) {} + +template +Tensor::Tensor(std::string name, std::vector dimensions, + ModeFormat modeType, CType fill) + : TensorBase(name, type(), dimensions, modeType, fill) {} template -Tensor::Tensor(std::vector dimensions, Format format) - : TensorBase(type(), dimensions, format) {} +Tensor::Tensor(std::vector dimensions, CType fill) + : TensorBase(type(), dimensions, fill) {} template -Tensor::Tensor(std::string name, std::vector dimensions, - ModeFormat modeType) - : TensorBase(name, type(), dimensions, modeType) {} +Tensor::Tensor(std::string name, std::vector dimensions, CType fill) + : TensorBase(name, type(), dimensions, fill) {} template -Tensor::Tensor(std::string name, std::vector dimensions, Format format) - : TensorBase(name, type(), dimensions, format) {} +Tensor::Tensor(std::string name, std::vector dimensions, Format format, + CType fill) + : TensorBase(name, type(), dimensions, format, fill) {} template Tensor::Tensor(const TensorBase& tensor) : TensorBase(tensor) { diff --git a/include/taco/util/collections.h b/include/taco/util/collections.h index 6edddf380..e0a07ba78 100644 --- a/include/taco/util/collections.h +++ b/include/taco/util/collections.h @@ -97,6 +97,20 @@ std::vector remove(const std::vector& vector, return result; } +template +std::vector removeDuplicates(const std::vector& vector) { + std::set seen; + std::vector result; + + for(const V& v: vector) { + if(!contains(seen, v)) { + seen.insert(v); + result.push_back(v); + } + } + return result; +} + template std::vector filter(const std::vector& vector, 
T test) { std::vector result; @@ -119,6 +133,16 @@ size_t count(const std::vector& vector, T test) { return count; } +template +std::map zipToMap(const std::vector& keys, const std::vector& values) { + std::map result; + size_t limit = std::min(keys.size(), values.size()); + for(size_t i = 0; i < limit; ++i) { + result.insert({keys[i], values[i]}); + } + return result; +} + /** * Split the vector into two vectors where elements in the first pass the test * and elements in the second do not. @@ -149,9 +173,9 @@ bool all(const C& collection, T test) { return true; } -template -bool any(const std::vector& vector, T test) { - for (auto& element : vector) { +template +bool any(const C& collection, T test) { + for (auto& element : collection) { if (test(element)) { return true; } diff --git a/include/taco/util/functions.h b/include/taco/util/functions.h new file mode 100644 index 000000000..1893147ae --- /dev/null +++ b/include/taco/util/functions.h @@ -0,0 +1,21 @@ +#ifndef TACO_FUNCTIONAL_H +#define TACO_FUNCTIONAL_H + +#include + +namespace taco { +namespace util { + +template +Fnptr functorAddress(std::function f) { + return *f.template target(); +} + +template +bool targetPtrEqual(std::function f, std::function g) { + return functorAddress(f) != nullptr && functorAddress(f) == functorAddress(g); +} + +} +} +#endif //TACO_FUNCTIONAL_H diff --git a/src/codegen/codegen.cpp b/src/codegen/codegen.cpp index f0c09d98a..a59ac419f 100644 --- a/src/codegen/codegen.cpp +++ b/src/codegen/codegen.cpp @@ -244,6 +244,10 @@ string CodeGen::unpackTensorProperty(string varname, const GetProperty* op, } else if (op->property == TensorProperty::ValuesSize) { ret << "int " << varname << " = " << tensor->name << "->vals_size;\n"; return ret.str(); + } else if (op->property == TensorProperty::FillValue) { + ret << printType(tensor->type, false) << " " << varname << " = "; + ret << "*((" <type, true) << ")(" << tensor->name << "->fill_value));\n"; + return ret.str(); } string tp; @@ -281,6 +285,8 @@ string CodeGen::packTensorProperty(string varname, Expr tnsr, } else if (property == TensorProperty::ValuesSize) { ret << tensor->name << "->vals_size = " << varname << ";\n"; return ret.str(); + } else if (property == TensorProperty::FillValue) { + return ""; } string tp; diff --git a/src/codegen/codegen_c.cpp b/src/codegen/codegen_c.cpp index f48f34f2a..123877110 100644 --- a/src/codegen/codegen_c.cpp +++ b/src/codegen/codegen_c.cpp @@ -48,6 +48,7 @@ const string cHeaders = " taco_mode_t* mode_types; // mode storage types\n" " uint8_t*** indices; // tensor index data (per mode)\n" " uint8_t* vals; // tensor values\n" + " uint8_t* fill_value; // tensor fill value\n" " int32_t vals_size; // values array size\n" "} taco_tensor_t;\n" "#endif\n" diff --git a/src/codegen/codegen_cuda.cpp b/src/codegen/codegen_cuda.cpp index 971410b98..dcbad4551 100644 --- a/src/codegen/codegen_cuda.cpp +++ b/src/codegen/codegen_cuda.cpp @@ -49,6 +49,7 @@ const string cHeaders = " taco_mode_t* mode_types; // mode storage types\n" " uint8_t*** indices; // tensor index data (per mode)\n" " uint8_t* vals; // tensor values\n" + " uint8_t* fill_value; // tensor fill value\n" " int32_t vals_size; // values array size\n" "} taco_tensor_t;\n" "#endif\n" diff --git a/src/index_notation/index_notation.cpp b/src/index_notation/index_notation.cpp index 4b090d84b..6b3fd71b4 100644 --- a/src/index_notation/index_notation.cpp +++ b/src/index_notation/index_notation.cpp @@ -14,6 +14,7 @@ #include "taco/type.h" #include "taco/format.h" +#include 
"taco/index_notation/properties.h" #include "taco/index_notation/intrinsic.h" #include "taco/index_notation/schedule.h" #include "taco/index_notation/transformations.h" @@ -21,13 +22,13 @@ #include "taco/index_notation/index_notation_rewriter.h" #include "taco/index_notation/index_notation_printer.h" #include "taco/ir/ir.h" -#include "taco/lower/lower.h" #include "taco/codegen/module.h" #include "taco/util/name_generator.h" #include "taco/util/scopedmap.h" #include "taco/util/strings.h" #include "taco/util/collections.h" +#include "taco/util/functions.h" using namespace std; @@ -105,6 +106,71 @@ std::ostream& operator<<(std::ostream& os, const IndexExpr& expr) { return os; } +static bool checkRegionDefinitions(const CallNode* anode, const CallNode* bnode) { + // Check region definitions + if (anode->regionDefinitions.size() != bnode->regionDefinitions.size()) { + return false; + } + + auto& aDefs = anode->regionDefinitions; + auto& bDefs = bnode->regionDefinitions; + for (auto itA = aDefs.begin(), itB = bDefs.begin(); itA != aDefs.end(); ++itA, ++itB) { + if(itA->first != itB->first) { + return false; + } + + std::vector aArgs; + std::vector bArgs; + for(int idx : itA->first) { + taco_iassert((size_t)idx < anode->args.size()); // We already know anode->args.size == bnode->args.size + aArgs.push_back(anode->args[idx]); + bArgs.push_back(bnode->args[idx]); + } + + // TODO lower and check IR + if(!util::targetPtrEqual(itA->second, itB->second)) { + return false; + } + } + + return true; +} + +/// Checks if the iteration algebra structure is the same and the ordering of the index expressions +/// nested under regions is the same for each op node. +static bool checkIterationAlg(const CallNode* anode, const CallNode* bnode) { + // Check IterationAlgebra structures + if(!algStructureEqual(anode->iterAlg, bnode->iterAlg)) { + return false; + } + + struct OrderChecker : public IterationAlgebraVisitor { + explicit OrderChecker(const CallNode* op) : op(op) {} + + std::vector& check() { + op->iterAlg.accept(this); + return ordering; + } + + using IterationAlgebraVisitor::visit; + + void visit(const RegionNode* region) { + const IndexExpr& e = region->expr(); + auto it = std::find(op->args.begin(), op->args.end(), e); + taco_iassert(it != op->args.end()) << "Iteration algebra region expressions must be in arguments"; + size_t loc = it - op->args.begin(); + ordering.push_back(loc); + } + + std::vector ordering; + const CallNode* op; + }; + + std::vector aOrdering = OrderChecker(anode).check(); + std::vector bOrdering = OrderChecker(bnode).check(); + return aOrdering == bOrdering; +} + struct Isomorphic : public IndexNotationVisitorStrict { bool eq = false; IndexExpr bExpr; @@ -165,6 +231,21 @@ struct Isomorphic : public IndexNotationVisitorStrict { using IndexNotationVisitorStrict::visit; + void visit(const IndexVarNode* anode) { + if(!isa(bExpr.ptr)) { + eq = false; + return; + } + + auto bnode = to(bExpr.ptr); + if(anode != bnode) { + eq = false; + return; + } + + eq = true; + } + void visit(const AccessNode* anode) { if (!isa(bExpr.ptr)) { eq = false; @@ -412,6 +493,69 @@ struct Isomorphic : public IndexNotationVisitorStrict { } eq = true; } + + void visit(const CallNode* anode) { + if (!isa(bExpr.ptr)) { + eq = false; + return; + } + auto bnode = to(bExpr.ptr); + + // Properties + if (anode->properties.size() != bnode->properties.size()) { + eq = false; + return; + } + + for(const auto& a_prop : anode->properties) { + bool found = false; + for(const auto& b_prop : bnode->properties) { + 
if(a_prop.equals(b_prop)) { + found = true; + break; + } + } + if (!found) { + eq = false; + return; + } + } + + // Exhausted regions + if (anode->definedRegions != bnode->definedRegions) { + eq = false; + return; + } + + // Lower function + // TODO: For now just check that the function pointers are the same. + if(!util::targetPtrEqual(anode->defaultLowerFunc, bnode->defaultLowerFunc)) { + eq = false; + return; + } + + // Check arguments + if (anode->args.size() != bnode->args.size()) { + eq = false; + return; + } + + for (size_t i = 0; i < anode->args.size(); ++i) { + if (!check(anode->args[i], bnode->args[i])) { + eq = false; + return; + } + } + + // Algebra + if (!checkIterationAlg(anode, bnode)) { + eq = false; + return; + } + + // Special definitions + eq = checkRegionDefinitions(anode, bnode); + } }; bool isomorphic(IndexExpr a, IndexExpr b) { @@ -453,6 +597,21 @@ struct Equals : public IndexNotationVisitorStrict { using IndexNotationVisitorStrict::visit; + void visit(const IndexVarNode* anode) { + if(!isa(bExpr.ptr)) { + eq = false; + return; + } + + auto bnode = to(bExpr.ptr); + if(anode != bnode) { + eq = false; + return; + } + + eq = true; + } + void visit(const AccessNode* anode) { if (!isa(bExpr.ptr)) { eq = false; @@ -555,6 +714,69 @@ struct Equals : public IndexNotationVisitorStrict { eq = true; } + void visit(const CallNode* anode) { + if (!isa(bExpr.ptr)) { + eq = false; + return; + } + auto bnode = to(bExpr.ptr); + + // Properties + if (anode->properties.size() != bnode->properties.size()) { + eq = false; + return; + } + + for(const auto& a_prop : anode->properties) { + bool found = false; + for(const auto& b_prop : bnode->properties) { + if(a_prop.equals(b_prop)) { + found = true; + break; + } + } + if (!found) { + eq = false; + return; + } + } + + // Exhausted regions + if (anode->definedRegions != bnode->definedRegions) { + eq = false; + return; + } + + // Lower function + // TODO: For now just check that the function pointers are the same. + if(!util::targetPtrEqual(anode->defaultLowerFunc, bnode->defaultLowerFunc)) { + eq = false; + return; + } + + // Check arguments + if (anode->args.size() != bnode->args.size()) { + eq = false; + return; + } + + for (size_t i = 0; i < anode->args.size(); ++i) { + if (!equals(anode->args[i], bnode->args[i])) { + eq = false; + return; + } + } + + // Algebra + if (!checkIterationAlg(anode, bnode)) { + eq = false; + return; + } + + // Special definitions + eq = checkRegionDefinitions(anode, bnode); + } + void visit(const CallIntrinsicNode* anode) { if (!isa(bExpr.ptr)) { eq = false; @@ -777,6 +999,38 @@ int Access::getWindowUpperBound(int mode) const { return getNode(*this)->windowedModes.at(mode).hi; } +bool operator==(const Access& a, const Access& b) { + // Short-circuit for when the Access pointers are the same. + if (getNode(a) == getNode(b)) { + return true; + } + if (a.getTensorVar() != b.getTensorVar()) { + return false; + } + if (a.getIndexVars() != b.getIndexVars()) { + return false; + } + if (getNode(a)->windowedModes != getNode(b)->windowedModes) { + return false; + } + return true; +} + +bool operator<(const Access& a, const Access& b) { + // First branch on tensorVar. + if(a.getTensorVar() != b.getTensorVar()) { + return a.getTensorVar() < b.getTensorVar(); + } + + // Then branch on the indexVars used in the access. + if (a.getIndexVars() != b.getIndexVars()) { + return a.getIndexVars() < b.getIndexVars(); + } + + // Lastly, branch on the windows. 
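+  // (The windowed-modes map is compared with its built-in lexicographic operator<; any consistent total order suffices here, presumably so Access objects can serve as keys in ordered containers.)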
+ return getNode(a)->windowedModes < getNode(b)->windowedModes; +} + static void check(Assignment assignment) { auto lhs = assignment.getLhs(); auto tensorVar = lhs.getTensorVar(); @@ -887,7 +1141,7 @@ Literal::Literal(std::complex val) : Literal(new LiteralNode(val)) { Literal::Literal(std::complex val) : Literal(new LiteralNode(val)) { } -IndexExpr Literal::zero(Datatype type) { +Literal Literal::zero(Datatype type) { switch (type.getKind()) { case Datatype::Bool: return Literal(false); case Datatype::UInt8: return Literal(uint8_t(0)); @@ -905,7 +1159,7 @@ IndexExpr Literal::zero(Datatype type) { default: taco_ierror << "unsupported type"; }; - return IndexExpr(); + return Literal(); } template T Literal::getVal() const { @@ -928,6 +1182,10 @@ template double Literal::getVal() const; template std::complex Literal::getVal() const; template std::complex Literal::getVal() const; +void* Literal::getValPtr() { + return getNode(*this)->val; +} + template <> bool isa(IndexExpr e) { return isa(e.ptr); } @@ -1112,6 +1370,50 @@ template <> Cast to(IndexExpr e) { return Cast(to(e.ptr)); } +// class Call, most construction should happen from tensor_operator.h +Call::Call(const CallNode* n) : IndexExpr(n) { +} + +Call::Call(const CallNode *n, std::string name) : IndexExpr(n), name(name) { +} + +const std::vector& Call::getArgs() const { + return getNode(*this)->args; +} + +const CallNode::OpImpl Call::getFunc() const { + return getNode(*this)->defaultLowerFunc; +} + +const IterationAlgebra& Call::getAlgebra() const { + return getNode(*this)->iterAlg; +} + +const std::vector& Call::getProperties() const { + return getNode(*this)->properties; +} + +const std::string Call::getName() const { + return getNode(*this)->name; +} + +const std::map, CallNode::OpImpl> Call::getDefs() const { + return getNode(*this)->regionDefinitions; +} + +const std::vector& Call::getDefinedArgs() const { + return getNode(*this)->definedRegions; +} + + +template <> bool isa(IndexExpr e) { + return isa(e.ptr); +} + +template <> Call to(IndexExpr e) { + taco_iassert(isa(e)); + return Call(to(e.ptr)); +} // class CallIntrinsic CallIntrinsic::CallIntrinsic(const CallIntrinsicNode* n) : IndexExpr(n) { @@ -1471,9 +1773,9 @@ IndexStmt IndexStmt::pos(IndexVar i, IndexVar ipos, Access access) const { // check access is correct ProvenanceGraph provGraph = ProvenanceGraph(*this); vector underivedParentAncestors = provGraph.getUnderivedAncestors(i); - int max_mode = 0; + size_t max_mode = 0; for (IndexVar underived : underivedParentAncestors) { - int mode_index = 0; // which of the access index vars match? + size_t mode_index = 0; // which of the access index vars match? 
for (auto var : access.getIndexVars()) { if (var == underived) { break; @@ -1826,12 +2128,23 @@ template <> SuchThat to(IndexStmt s) { // class IndexVar IndexVar::IndexVar() : IndexVar(util::uniqueName('i')) {} -IndexVar::IndexVar(const std::string& name) : content(new Content) { - content->name = name; +IndexVar::IndexVar(const std::string& name) : IndexVar(name, Datatype::Int32) {} + +IndexVar::IndexVar(const std::string& name, const Datatype& type) : IndexVar(new IndexVarNode(name, type)) {} + +IndexVar::IndexVar(const IndexVarNode* n) : IndexExpr(n) {} + +template <> bool isa(IndexExpr e) { + return isa(e.ptr); +} + +template <> IndexVar to(IndexExpr e) { + taco_iassert(isa(e)); + return IndexVar(to(e.ptr)); } std::string IndexVar::getName() const { - return content->name; + return getNode(*this)->getName(); } WindowedIndexVar IndexVar::operator()(int lo, int hi) { @@ -1839,11 +2152,27 @@ WindowedIndexVar IndexVar::operator()(int lo, int hi) { } bool operator==(const IndexVar& a, const IndexVar& b) { - return a.content == b.content; + return *getNode(a) == *getNode(b); } bool operator<(const IndexVar& a, const IndexVar& b) { - return a.content < b.content; + return *getNode(a) < *getNode(b); +} + +bool operator!=(const IndexVar& a , const IndexVar& b) { + return *getNode(a) != *getNode(b); +} + +bool operator>=(const IndexVar& a, const IndexVar& b) { + return *getNode(a) >= *getNode(b); +} + +bool operator<=(const IndexVar& a, const IndexVar& b) { + return *getNode(a) <= *getNode(b); +} + +bool operator>(const IndexVar& a , const IndexVar& b) { + return *getNode(a) > *getNode(b); } std::ostream& operator<<(std::ostream& os, const std::shared_ptr& var) { @@ -1889,6 +2218,7 @@ struct TensorVar::Content { Type type; Format format; Schedule schedule; + Literal fill; }; TensorVar::TensorVar() : content(nullptr) { @@ -1898,28 +2228,29 @@ static Format createDenseFormat(const Type& type) { return Format(vector(type.getOrder(), ModeFormat(Dense))); } -TensorVar::TensorVar(const Type& type) -: TensorVar(type, createDenseFormat(type)) { +TensorVar::TensorVar(const Type& type, const Literal& fill) +: TensorVar(type, createDenseFormat(type), fill) { } -TensorVar::TensorVar(const std::string& name, const Type& type) -: TensorVar(-1, name, type, createDenseFormat(type)) { +TensorVar::TensorVar(const std::string& name, const Type& type, const Literal& fill) +: TensorVar(-1, name, type, createDenseFormat(type), fill) { } -TensorVar::TensorVar(const Type& type, const Format& format) - : TensorVar(-1, util::uniqueName('A'), type, format) { +TensorVar::TensorVar(const Type& type, const Format& format, const Literal& fill) + : TensorVar(-1, util::uniqueName('A'), type, format, fill) { } -TensorVar::TensorVar(const string& name, const Type& type, const Format& format) - : TensorVar(-1, name, type, format) { +TensorVar::TensorVar(const string& name, const Type& type, const Format& format, const Literal& fill) + : TensorVar(-1, name, type, format, fill) { } -TensorVar::TensorVar(const int& id, const string& name, const Type& type, const Format& format) +TensorVar::TensorVar(const int& id, const string& name, const Type& type, const Format& format, const Literal& fill) : content(new Content) { content->id = id; content->name = name; content->type = type; content->format = format; + content->fill = fill.defined()? 
fill : Literal::zero(type.getDataType()); } int TensorVar::getId() const { @@ -1959,6 +2290,14 @@ const Schedule& TensorVar::getSchedule() const { return content->schedule; } +const Literal& TensorVar::getFill() const { + return content->fill; +} + +void TensorVar::setFill(const Literal &fill) { + content->fill = fill; +} + void TensorVar::setName(std::string name) { content->name = name; } @@ -2177,13 +2516,22 @@ bool isConcreteNotation(IndexStmt stmt, std::string* reason) { std::function([&](const AccessNode* op) { for (auto& var : op->indexVars) { // non underived variables may appear in temporaries, but we don't check these - if (!boundVars.contains(var) && provGraph.isUnderived(var) && (provGraph.isFullyDerived(var) || !provGraph.isRecoverable(var, definedVars))) { + if (!boundVars.contains(var) && provGraph.isUnderived(var) && + (provGraph.isFullyDerived(var) || !provGraph.isRecoverable(var, definedVars))) { *reason = "all variables in concrete notation must be bound by a " "forall statement"; isConcrete = false; } } }), + std::function([&](const IndexVarNode* op) { + IndexVar var(op); + if (!boundVars.contains(var) && provGraph.isUnderived(var) && + (provGraph.isFullyDerived(var) || !provGraph.isRecoverable(var, definedVars))) { + *reason = "index variables used in compute statements must be nested under a forall"; + isConcrete = false; + } + }), std::function([&](const WhereNode* op, Matcher* ctx) { bool alreadyInProducer = inWhereProducer; inWhereProducer = true; @@ -2342,14 +2690,16 @@ IndexStmt makeConcreteNotation(IndexStmt stmt) { // that's not a reduction vector topLevelReductions; IndexExpr rhs = node->rhs; + IndexExpr reductionOp; while (isa(rhs)) { Reduction reduction = to(rhs); topLevelReductions.push_back(reduction.getVar()); rhs = reduction.getExpr(); + reductionOp = reduction.getOp(); } if (rhs != node->rhs) { - stmt = Assignment(node->lhs, rhs, Add()); + stmt = Assignment(node->lhs, rhs, reductionOp); for (auto& i : util::reverse(topLevelReductions)) { stmt = forall(i, stmt); } @@ -2440,6 +2790,46 @@ vector getArguments(IndexStmt stmt) { return result; } +bool allForFreeLoopsBeforeAllReductionLoops(IndexStmt stmt) { + + struct LoopOrderGetter : IndexNotationVisitor { + + std::vector loopOrder; + std::set freeVars; + + using IndexNotationVisitor::visit; + + void visit(const AssignmentNode *op) { + for (const auto &var : op->lhs.getIndexVars()) { + freeVars.insert(var); + } + IndexNotationVisitor::visit(op); + } + + void visit(const ForallNode *op) { + loopOrder.push_back(op->indexVar); + IndexNotationVisitor::visit(op); + } + }; + + + LoopOrderGetter getter; + getter.visit(stmt); + + bool seenReductionVar = false; + for (auto &var : getter.loopOrder) { + if (util::contains(getter.freeVars, var)) { + if (seenReductionVar) { + // A reduction loop came before a loop over a free var + return false; + } + } else { + seenReductionVar = true; + } + } + return true; + } + std::map getTemporaryLocations(IndexStmt stmt) { map temporaryLocs; Forall f = Forall(); @@ -2657,6 +3047,10 @@ struct Zero : public IndexNotationRewriterStrict { expr = op; } + void visit(const IndexVarNode* op) { + expr = op; + } + template IndexExpr visitUnaryOp(const T *op) { IndexExpr a = rewrite(op->a); @@ -2760,6 +3154,62 @@ struct Zero : public IndexNotationRewriterStrict { } } + void visit(const CallNode* op) { + std::vector args; + std::vector rewrittenArgs; + std::vector definedArgs; + bool rewritten = false; + + Annihilator annihilator = findProperty(op->properties); + + // TODO: Check exhausted 
default against result default + for(int argIdx = 0; argIdx < (int) op->args.size(); ++argIdx) { + IndexExpr arg = op->args[argIdx]; + IndexExpr rewrittenArg = rewrite(arg); + rewrittenArgs.push_back(rewrittenArg); + + if (rewrittenArg.defined()) { + definedArgs.push_back(argIdx); + } else { + // TODO: fill value instead of 0 + rewrittenArg = Literal::zero(arg.getDataType()); + } + + args.push_back(rewrittenArg); + if (arg != rewrittenArg) { + rewritten = true; + } + } + + if(annihilator.defined()) { + IndexExpr e = annihilator.annihilates(args); + if(e.defined()) { + expr = e; + return; + } + } + + Identity identity = findProperty(op->properties); + if(identity.defined()) { + IndexExpr e = identity.simplify(args); + if(e.defined()) { + expr = e; + return; + } + } + + if (rewritten) { + const std::map subs = util::zipToMap(op->args, rewrittenArgs); + IterationAlgebra newAlg = replaceAlgIndexExprs(op->iterAlg, subs); + expr = new CallNode(op->name, args, op->defaultLowerFunc, newAlg, op->properties, + op->regionDefinitions, definedArgs); + } + else { + expr = op; + } + + } + void visit(const CallIntrinsicNode* op) { std::vector args; std::vector zeroArgs; @@ -2890,6 +3340,150 @@ IndexStmt zero(IndexStmt stmt, const std::set& zeroed) { return Zero(zeroed).rewrite(stmt); } +// Attempts to infer the fill value of a given expression. If we cannot infer the value, an empty expression +// is returned +struct fillValueInferrer : IndexExprRewriterStrict { + public: + virtual void visit(const AccessNode* op) { + expr = op->tensorVar.getFill(); + }; + + virtual void visit(const LiteralNode* op) { + expr = op; + } + + virtual void visit(const NegNode* op) { + IndexExpr a = rewrite(op->a); + if(equals(a, Literal::zero(a.getDataType()))) { + expr = a; + return; + } + expr = IndexExpr(); + } + + virtual void visit(const AddNode* op) { + IndexExpr a = rewrite(op->a); + IndexExpr b = rewrite(op->b); + + if(equals(a, Literal::zero(a.getDataType())) && isa(b)) { + expr = b; + return; + } + + if(equals(b, Literal::zero(b.getDataType())) && isa(a)) { + expr = a; + return; + } + + expr = IndexExpr(); + } + + virtual void visit(const SubNode* op) { + IndexExpr a = rewrite(op->a); + IndexExpr b = rewrite(op->b); + + if(equals(b, Literal::zero(b.getDataType())) && isa(a)) { + expr = a; + return; + } + + expr = IndexExpr(); + } + + virtual void visit(const MulNode* op) { + IndexExpr a = rewrite(op->a); + IndexExpr b = rewrite(op->b); + + if(equals(a, Literal::zero(a.getDataType()))) { + expr = a; + return; + } + + if(equals(b, Literal::zero(b.getDataType()))) { + expr = b; + return; + } + + expr = IndexExpr(); + } + + virtual void visit(const DivNode* op) { + IndexExpr a = rewrite(op->a); + IndexExpr b = rewrite(op->b); + + if(equals(a, Literal::zero(a.getDataType()))) { + expr = a; + return; + } + + expr = IndexExpr(); + } + + virtual void visit(const SqrtNode* op) { + IndexExpr a = rewrite(op->a); + if(equals(a, Literal::zero(a.getDataType()))) { + expr = a; + return; + } + expr = IndexExpr(); + } + + virtual void visit(const CastNode* op) { + expr = IndexExpr(); + } + + virtual void visit(const CallNode* op) { + Annihilator annihilator = findProperty(op->properties); + if(annihilator.defined()) { + IndexExpr e = annihilator.annihilates(op->args); + if(e.defined()) { + expr = e; + return; + } + } + + Identity identity = findProperty(op->properties); + if(identity.defined()) { + IndexExpr e = identity.simplify(op->args); + if(e.defined()) { + expr = e; + return; + } + } + + expr = IndexExpr(); + } + + virtual void 
visit(const CallIntrinsicNode*) { + // TODO Implement or remove this + taco_not_supported_yet; + } + + virtual void visit(const ReductionNode*) { + expr = IndexExpr(); + } + + virtual void visit(const IndexVarNode*) { + expr = IndexExpr(); + } + }; + + +IndexExpr inferFill(IndexExpr expr) { + return fillValueInferrer().rewrite(expr); +} + +bool hasNoForAlls(IndexStmt stmt) { + + bool noForAlls = true; + match(stmt, + std::function([&](const ForallNode* op) { + noForAlls = false; + }) + ); + return noForAlls; +} + IndexStmt generatePackStmt(TensorVar tensor, std::string otherName, Format otherFormat, std::vector indexVars, @@ -2922,5 +3516,4 @@ IndexStmt generatePackCOOStmt(TensorVar tensor, return generatePackStmt(tensor, tensorName + "_COO", bufferFormat, indexVars, otherIsOnRight); } - } diff --git a/src/index_notation/index_notation_nodes.cpp b/src/index_notation/index_notation_nodes.cpp index bc67e9218..42a94bc98 100644 --- a/src/index_notation/index_notation_nodes.cpp +++ b/src/index_notation/index_notation_nodes.cpp @@ -29,11 +29,54 @@ CallIntrinsicNode::CallIntrinsicNode(const std::shared_ptr& func, func(func), args(args) { } +// class CallNode + CallNode::CallNode(std::string name, const std::vector& args, OpImpl defaultLowerFunc, + const IterationAlgebra &iterAlg, const std::vector &properties, + const std::map, OpImpl>& regionDefinitions) + : CallNode(name, args, defaultLowerFunc, iterAlg, properties, regionDefinitions, definedIndices(args)){ + } + +// class CallNode +CallNode::CallNode(std::string name, const std::vector& args, OpImpl defaultLowerFunc, + const IterationAlgebra &iterAlg, const std::vector &properties, + const std::map, OpImpl>& regionDefinitions, + const std::vector& definedRegions) + : IndexExprNode(inferReturnType(defaultLowerFunc, args)), name(name), args(args), defaultLowerFunc(defaultLowerFunc), + iterAlg(applyDemorgan(iterAlg)), properties(properties), regionDefinitions(regionDefinitions), + definedRegions(definedRegions) { + + taco_iassert(defaultLowerFunc != nullptr); + for (const auto& pair: regionDefinitions) { + taco_iassert(args.size() >= pair.first.size()); + } +} // class ReductionNode ReductionNode::ReductionNode(IndexExpr op, IndexVar var, IndexExpr a) : IndexExprNode(a.getDataType()), op(op), var(var), a(a) { - taco_iassert(isa(op.ptr)); + taco_iassert(isa(op.ptr) || isa(op.ptr)); +} + +IndexVarNode::IndexVarNode(const std::string& name, const Datatype& type) + : IndexExprNode(type), content(new Content) { + + if (!type.isInt() && !type.isUInt()) { + taco_not_supported_yet << ". 
IndexVars must be integral type."; + } + + content->name = name; +} + +std::string IndexVarNode::getName() const { + return content->name; +} + +bool operator==(const IndexVarNode& a, const IndexVarNode& b) { + return a.content->name == b.content->name; +} + +bool operator<(const IndexVarNode& a, const IndexVarNode& b) { + return a.content->name < b.content->name; } } diff --git a/src/index_notation/index_notation_printer.cpp b/src/index_notation/index_notation_printer.cpp index ba633731d..f603a9b70 100644 --- a/src/index_notation/index_notation_printer.cpp +++ b/src/index_notation/index_notation_printer.cpp @@ -25,6 +25,10 @@ void IndexNotationPrinter::visit(const AccessNode* op) { } } +void IndexNotationPrinter::visit(const IndexVarNode* op) { + os << op->getName(); +} + void IndexNotationPrinter::visit(const LiteralNode* op) { switch (op->getDataType().getKind()) { case Datatype::Bool: @@ -158,6 +162,13 @@ static inline void acceptJoin(IndexNotationPrinter* printer, } } +void IndexNotationPrinter::visit(const CallNode* op) { + parentPrecedence = Precedence::FUNC; + os << op->name << "("; + acceptJoin(this, os, op->args, ", "); + os << ")"; +} + void IndexNotationPrinter::visit(const CallIntrinsicNode* op) { parentPrecedence = Precedence::FUNC; os << op->func->getName(); @@ -183,6 +194,10 @@ void IndexNotationPrinter::visit(const ReductionNode* op) { void visit(const BinaryExprNode* node) { reductionName = "reduction(" + node->getOperatorString() + ")"; } + + void visit(const CallNode* node) { + reductionName = node->name + "Reduce"; + } }; parentPrecedence = Precedence::REDUCTION; os << ReductionName().get(op->op) << "(" << op->var << ", "; @@ -202,6 +217,10 @@ void IndexNotationPrinter::visit(const AssignmentNode* op) { void visit(const BinaryExprNode* node) { operatorName = node->getOperatorString(); } + + void visit(const CallNode* node) { + operatorName = node->name; + } }; op->lhs.accept(this); diff --git a/src/index_notation/index_notation_rewriter.cpp b/src/index_notation/index_notation_rewriter.cpp index 35d111457..32ad92ab0 100644 --- a/src/index_notation/index_notation_rewriter.cpp +++ b/src/index_notation/index_notation_rewriter.cpp @@ -42,6 +42,10 @@ void IndexNotationRewriter::visit(const AccessNode* op) { expr = op; } +void IndexNotationRewriter::visit(const IndexVarNode* op) { + expr = op; +} + template IndexExpr visitUnaryOp(const T *op, IndexNotationRewriter *rw) { IndexExpr a = rw->rewrite(op->a); @@ -103,6 +107,28 @@ void IndexNotationRewriter::visit(const CastNode* op) { } } +void IndexNotationRewriter::visit(const CallNode* op) { + std::vector args; + bool rewritten = false; + for(auto& arg : op->args) { + IndexExpr rewrittenArg = rewrite(arg); + args.push_back(rewrittenArg); + if (arg != rewrittenArg) { + rewritten = true; + } + } + + if (rewritten) { + const std::map subs = util::zipToMap(op->args, args); + IterationAlgebra newAlg = replaceAlgIndexExprs(op->iterAlg, subs); + expr = new CallNode(op->name, args, op->defaultLowerFunc, newAlg, op->properties, + op->regionDefinitions); + } + else { + expr = op; + } +} + void IndexNotationRewriter::visit(const CallIntrinsicNode* op) { std::vector args; bool rewritten = false; @@ -246,6 +272,10 @@ struct ReplaceRewriter : public IndexNotationRewriter { SUBSTITUTE_EXPR; } + void visit(const IndexVarNode* op) { + SUBSTITUTE_EXPR; + } + void visit(const LiteralNode* op) { SUBSTITUTE_EXPR; } @@ -274,6 +304,14 @@ struct ReplaceRewriter : public IndexNotationRewriter { SUBSTITUTE_EXPR; } + void visit(const CallNode* op) { + 
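// Call expressions can now be replaced wholesale via the substitution map, like the other IndexExpr nodes. +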
SUBSTITUTE_EXPR; + } + + void visit(const CallIntrinsicNode* op) { + SUBSTITUTE_EXPR; + } + void visit(const ReductionNode* op) { SUBSTITUTE_EXPR; } @@ -334,6 +372,16 @@ struct ReplaceIndexVars : public IndexNotationRewriter { } } + void visit(const IndexVarNode* op) { + + IndexVar var(op); + if(util::contains(substitutions, var)) { + expr = substitutions.at(var); + } else { + expr = var; + } + } + // TODO: Replace in assignments }; diff --git a/src/index_notation/index_notation_visitor.cpp b/src/index_notation/index_notation_visitor.cpp index ec77d99ba..317c5b15c 100644 --- a/src/index_notation/index_notation_visitor.cpp +++ b/src/index_notation/index_notation_visitor.cpp @@ -35,6 +35,9 @@ IndexNotationVisitor::~IndexNotationVisitor() { void IndexNotationVisitor::visit(const AccessNode* op) { } +void IndexNotationVisitor::visit(const IndexVarNode *op) { +} + void IndexNotationVisitor::visit(const LiteralNode* op) { } @@ -66,6 +69,12 @@ void IndexNotationVisitor::visit(const CastNode* op) { op->a.accept(this); } +void IndexNotationVisitor::visit(const CallNode* op) { + for (auto& arg : op->args) { + arg.accept(this); + } +} + void IndexNotationVisitor::visit(const CallIntrinsicNode* op) { for (auto& arg : op->args) { arg.accept(this); diff --git a/src/index_notation/iteration_algebra.cpp b/src/index_notation/iteration_algebra.cpp new file mode 100644 index 000000000..ebab586c9 --- /dev/null +++ b/src/index_notation/iteration_algebra.cpp @@ -0,0 +1,324 @@ +#include "taco/util/collections.h" +#include "taco/index_notation/iteration_algebra.h" +#include "taco/index_notation/iteration_algebra_printer.h" + +namespace taco { + +// Iteration Algebra Definitions + +IterationAlgebra::IterationAlgebra() : IterationAlgebra(nullptr) {} +IterationAlgebra::IterationAlgebra(const IterationAlgebraNode* n) : util::IntrusivePtr(n) {} +IterationAlgebra::IterationAlgebra(IndexExpr expr) : IterationAlgebra(new RegionNode(expr)) {} + +void IterationAlgebra::accept(IterationAlgebraVisitorStrict *v) const { + ptr->accept(v); +} + +std::ostream& operator<<(std::ostream& os, const IterationAlgebra& algebra) { + if(!algebra.defined()) return os << "{}"; + IterationAlgebraPrinter printer(os); + printer.print(algebra); + return os; +} + +// Definitions for Iteration Algebra + +// Region +Region::Region() : IterationAlgebra(new RegionNode) {} +Region::Region(IndexExpr expr) : IterationAlgebra(expr) {} +Region::Region(const taco::RegionNode *n) : IterationAlgebra(n) {} + +template <> bool isa(IterationAlgebra alg) { + return isa(alg.ptr); +} + +template <> Region to(IterationAlgebra alg) { + taco_iassert(isa(alg)); + return Region(to(alg.ptr)); +} + +// Complement +Complement::Complement(const ComplementNode* n): IterationAlgebra(n) { +} + +Complement::Complement(IterationAlgebra alg) : Complement(new ComplementNode(alg)) { +} + + +template <> bool isa(IterationAlgebra alg) { + return isa(alg.ptr); +} + +template <> Complement to(IterationAlgebra alg) { + taco_iassert(isa(alg)); + return Complement(to(alg.ptr)); +} + +// Intersect +Intersect::Intersect(IterationAlgebra a, IterationAlgebra b) : Intersect(new IntersectNode(a, b)) {} +Intersect::Intersect(const IterationAlgebraNode* n) : IterationAlgebra(n) {} + +template <> bool isa(IterationAlgebra alg) { + return isa(alg.ptr); +} + +template <> Intersect to(IterationAlgebra alg) { + taco_iassert(isa(alg)); + return Intersect(to(alg.ptr)); +} + +// Union +Union::Union(IterationAlgebra a, IterationAlgebra b) : Union(new UnionNode(a, b)) {} +Union::Union(const 
IterationAlgebraNode* n) : IterationAlgebra(n) {} + +template <> bool isa(IterationAlgebra alg) { + return isa(alg.ptr); +} + +template <> Union to(IterationAlgebra alg) { + taco_iassert(isa(alg)); + return Union(to(alg.ptr)); +} + + +// Node method definitions start here: + +// Definitions for RegionNode +void RegionNode::accept(IterationAlgebraVisitorStrict *v) const { + v->visit(this); +} + +const IndexExpr RegionNode::expr() const { + return expr_; +} + +// Definitions for ComplementNode +void ComplementNode::accept(IterationAlgebraVisitorStrict *v) const { + v->visit(this); +} + +// Definitions for IntersectNode +void IntersectNode::accept(IterationAlgebraVisitorStrict *v) const { + v->visit(this); +} + +const std::string IntersectNode::algebraString() const { + return "*"; +} + +// Definitions for UnionNode +void UnionNode::accept(IterationAlgebraVisitorStrict *v) const { + v->visit(this); +} + +const std::string UnionNode::algebraString() const { + return "U"; +} + +// Visitor definitions start here: + +// IterationAlgebraVisitorStrict definitions +void IterationAlgebraVisitorStrict::visit(const IterationAlgebra &alg) { + alg.accept(this); +} + +// Default IterationAlgebraVisitor definitions +void IterationAlgebraVisitor::visit(const RegionNode *n) { +} + +void IterationAlgebraVisitor::visit(const ComplementNode *n) { + n->a.accept(this); +} + +void IterationAlgebraVisitor::visit(const IntersectNode *n) { + n->a.accept(this); + n->b.accept(this); +} + +void IterationAlgebraVisitor::visit(const UnionNode *n) { + n->a.accept(this); + n->b.accept(this); +} + +// IterationAlgebraRewriter definitions start here: +IterationAlgebra IterationAlgebraRewriterStrict::rewrite(IterationAlgebra iter_alg) { + if(iter_alg.defined()) { + iter_alg.accept(this); + iter_alg = alg; + } + else { + iter_alg = IterationAlgebra(); + } + + alg = IterationAlgebra(); + return iter_alg; +} + +// Default IterationAlgebraRewriter definitions +void IterationAlgebraRewriter::visit(const RegionNode *n) { + alg = n; +} + +void IterationAlgebraRewriter::visit(const ComplementNode *n) { + IterationAlgebra a = rewrite(n->a); + if(n-> a == a) { + alg = n; + } else { + alg = new ComplementNode(a); + } +} + +void IterationAlgebraRewriter::visit(const IntersectNode *n) { + IterationAlgebra a = rewrite(n->a); + IterationAlgebra b = rewrite(n->b); + + if(n->a == a && n->b == b) { + alg = n; + } else { + alg = new IntersectNode(a, b); + } +} + +void IterationAlgebraRewriter::visit(const UnionNode *n) { + IterationAlgebra a = rewrite(n->a); + IterationAlgebra b = rewrite(n->b); + + if(n->a == a && n->b == b) { + alg = n; + } else { + alg = new UnionNode(a, b); + } +} + +struct AlgComparer : public IterationAlgebraVisitorStrict { + + bool eq = false; + IterationAlgebra bAlg; + bool checkIndexExprs; + + explicit AlgComparer(bool checkIndexExprs) : checkIndexExprs(checkIndexExprs) { + } + + bool compare(const IterationAlgebra& a, const IterationAlgebra& b) { + bAlg = b; + a.accept(this); + return eq; + } + + void visit(const RegionNode* node) { + if(!isa(bAlg.ptr)) { + eq = false; + return; + } + + auto bnode = to(bAlg.ptr); + if (checkIndexExprs && !equals(node->expr(), bnode->expr())) { + eq = false; + return; + } + eq = true; + } + + void visit(const ComplementNode* node) { + if (!isa(bAlg.ptr)) { + eq = false; + return; + } + + auto bNode = to(bAlg.ptr); + eq = AlgComparer(checkIndexExprs).compare(node->a, bNode->a); + } + + template + bool binaryCheck(const T* anode, IterationAlgebra b) { + if (!isa(b.ptr)) { + return false; 
+ } + auto bnode = to(b.ptr); + return AlgComparer(checkIndexExprs).compare(anode->a, bnode->a) && + AlgComparer(checkIndexExprs).compare(anode->b, bnode->b); + } + + + void visit(const IntersectNode* node) { + eq = binaryCheck(node, bAlg); + } + + void visit(const UnionNode* node) { + eq = binaryCheck(node, bAlg); + } + +}; + +bool algStructureEqual(const IterationAlgebra& a, const IterationAlgebra& b) { + return AlgComparer(false).compare(a, b); +} + +bool algEqual(const IterationAlgebra& a, const IterationAlgebra& b) { + return AlgComparer(true).compare(a, b); +} + +class DeMorganApplier : public IterationAlgebraRewriterStrict { + + void visit(const RegionNode* n) { + alg = Complement(n); + } + + void visit(const ComplementNode* n) { + alg = applyDemorgan(n->a); + } + + template + IterationAlgebra binaryVisit(Node n) { + IterationAlgebra a = applyDemorgan(Complement(n->a)); + IterationAlgebra b = applyDemorgan(Complement(n->b)); + return new ComplementedNode(a, b); + } + + void visit(const IntersectNode* n) { + alg = binaryVisit(n); + } + + void visit(const UnionNode* n) { + alg = binaryVisit(n); + } +}; + +struct DeMorganDispatcher : public IterationAlgebraRewriter { + + using IterationAlgebraRewriter::visit; + + void visit(const ComplementNode *n) { + alg = DeMorganApplier().rewrite(n->a); + } +}; + +IterationAlgebra applyDemorgan(IterationAlgebra alg) { + return DeMorganDispatcher().rewrite(alg); +} + +class IndexExprReplacer : public IterationAlgebraRewriter { + +public: + IndexExprReplacer(const std::map& substitutions) : substitutions(substitutions) { + } + +private: + using IterationAlgebraRewriter::visit; + + void visit(const RegionNode* node) { + if (util::contains(substitutions, node->expr())) { + alg = new RegionNode(substitutions.at(node->expr())); + return; + } + alg = node; + } + + const std::map substitutions; +}; + +IterationAlgebra replaceAlgIndexExprs(IterationAlgebra alg, const std::map& substitutions) { + return IndexExprReplacer(substitutions).rewrite(alg); +} + +} \ No newline at end of file diff --git a/src/index_notation/iteration_algebra_printer.cpp b/src/index_notation/iteration_algebra_printer.cpp new file mode 100644 index 000000000..d582b1f5b --- /dev/null +++ b/src/index_notation/iteration_algebra_printer.cpp @@ -0,0 +1,55 @@ +#include "taco/index_notation/iteration_algebra_printer.h" + +namespace taco { + +// Iteration Algebra Printer +IterationAlgebraPrinter::IterationAlgebraPrinter(std::ostream& os) : os(os) {} + +void IterationAlgebraPrinter::print(const IterationAlgebra& alg) { + parentPrecedence = Precedence::TOP; + alg.accept(this); +} + +void IterationAlgebraPrinter::visit(const RegionNode* n) { + os << n->expr(); +} + +void IterationAlgebraPrinter::visit(const ComplementNode* n) { + Precedence precedence = Precedence::COMPLEMENT; + bool parenthesize = precedence > parentPrecedence; + parentPrecedence = precedence; + os << "~"; + if (parenthesize) { + os << "("; + } + n->a.accept(this); + if (parenthesize) { + os << ")"; + } +} + +void IterationAlgebraPrinter::visit(const IntersectNode* n) { + visitBinary(n, Precedence::INTERSECT); +} + +void IterationAlgebraPrinter::visit(const UnionNode* n) { + visitBinary(n, Precedence::UNION); +} + +template +void IterationAlgebraPrinter::visitBinary(Node n, Precedence precedence) { + bool parenthesize = precedence > parentPrecedence; + if (parenthesize) { + os << "("; + } + parentPrecedence = precedence; + n->a.accept(this); + os << " " << n->algebraString() << " "; + parentPrecedence = precedence; + 
n->b.accept(this); + if (parenthesize) { + os << ")"; + } +} + +} \ No newline at end of file diff --git a/src/index_notation/properties.cpp b/src/index_notation/properties.cpp new file mode 100644 index 000000000..aca434443 --- /dev/null +++ b/src/index_notation/properties.cpp @@ -0,0 +1,194 @@ +#include "taco/index_notation/properties.h" +#include "taco/index_notation/index_notation.h" + +namespace taco { + +// Property class definitions +Property::Property() : util::IntrusivePtr(nullptr) { +} + +Property::Property(const PropertyPtr* p) : util::IntrusivePtr(p) { +} + +bool Property::equals(const Property &p) const { + if(!defined() && !p.defined()) { + return true; + } + + if(defined() && p.defined()) { + return ptr->equals(p.ptr); + } + + return false; +} + +std::ostream & Property::print(std::ostream& os) const { + if(!defined()) { + os << "Property(undef)"; + return os; + } + return ptr->print(os); +} + +std::ostream& operator<<(std::ostream& os, const Property& p) { + return p.print(os); +} + +// Annihilator class definitions +template<> bool isa(const Property& p) { + return isa(p.ptr); +} + +template<> Annihilator to(const Property& p) { + taco_iassert(isa(p)); + return Annihilator(to(p.ptr)); +} + +Annihilator::Annihilator(Literal annihilator) : Annihilator(new AnnihilatorPtr(annihilator)) { +} + +Annihilator::Annihilator(Literal annihilator, std::vector &p) : Annihilator(new AnnihilatorPtr(annihilator, p)) { +} + +Annihilator::Annihilator(const PropertyPtr* p) : Property(p) { +} + +const Literal& Annihilator::annihilator() const { + taco_iassert(defined()); + return getPtr(*this)->annihilator(); +} + +const std::vector & Annihilator::positions() const { + taco_iassert(defined()); + return getPtr(*this)->positions(); +} + +IndexExpr Annihilator::annihilates(const std::vector& exprs) const { + taco_iassert(defined()); + Literal a = annihilator(); + std::vector pos = positions(); + if (pos.empty()) { + for(int i = 0; i < (int)exprs.size(); ++i) { + pos.push_back(i); + } + } + + for(const auto& idx : pos) { + taco_uassert(idx < (int)exprs.size()) << "Not enough args in expression"; + if(::taco::equals(exprs[idx], a)) { + return a; + } + } + + // We could not simplify. + return IndexExpr(); +} + +// Identity class definitions +template<> bool isa(const Property& p) { + return isa(p.ptr); +} + +template<> Identity to(const Property& p) { + taco_iassert(isa(p)); + return Identity(to(p.ptr)); +} + +Identity::Identity(Literal identity) : Identity(new IdentityPtr(identity)) { +} + +Identity::Identity(const PropertyPtr* p) : Property(p) { +} + +Identity::Identity(Literal identity, std::vector& positions) : Identity(new IdentityPtr(identity, positions)) { +} + +const Literal& Identity::identity() const { + taco_iassert(defined()); + return getPtr(*this)->identity(); +} + +const std::vector& Identity::positions() const { + taco_iassert(defined()); + return getPtr(*this)->positions(); +} + +IndexExpr Identity::simplify(const std::vector& exprs) const { + // If only one term is not the identity, replace expr with just that term. + // If all terms are identity, replace with identity. 
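+  // For example, taking 0 as the identity of an add-like operator (hypothetical arguments):
+  //   {a, 0, 0} simplifies to a, {0, 0} simplifies to 0, and {a, b, 0} is left unsimplified.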
+ Literal identityVal = identity(); + size_t nonIdentityTermsChecked = 0; + IndexExpr nonIdentityTerm; + + std::vector pos = positions(); + if (pos.empty()) { + for(int i = 0; i < (int)exprs.size(); ++i) { + pos.push_back(i); + } + } + + + for(const auto& idx : pos) { + if(!::taco::equals(identityVal, exprs[idx])) { + nonIdentityTerm = exprs[idx]; + ++nonIdentityTermsChecked; + } + if(nonIdentityTermsChecked > 1) { + return IndexExpr(); + } + } + + size_t identityTermsChecked = pos.size() - nonIdentityTermsChecked; + if(nonIdentityTermsChecked == 1 && identityTermsChecked == (exprs.size() - 1)) { + // If we checked all exprs and all are the identity except one return that term + return nonIdentityTerm; + } + + if(identityTermsChecked == exprs.size()) { + // If we checked every expression and + return identityVal; + } + + return IndexExpr(); +} + +// Associative class definitions +template<> bool isa(const Property& p) { + return isa(p.ptr); +} + +template<> Associative to(const Property& p) { + taco_iassert(isa(p)); + return Associative(to(p.ptr)); +} + +Associative::Associative() : Associative(new AssociativePtr) { +} + +Associative::Associative(const PropertyPtr* p) : Property(p) { +} + +// Commutative class definitions +template<> bool isa(const Property& p) { + return isa(p.ptr); +} + +template<> Commutative to(const Property& p) { + taco_iassert(isa(p)); + return Commutative(to(p.ptr)); +} + +Commutative::Commutative() : Commutative(new CommutativePtr) { +} + +Commutative::Commutative(const std::vector& ordering) : Commutative(new CommutativePtr(ordering)) { +} + +Commutative::Commutative(const PropertyPtr* p) : Property(p) { +} + +const std::vector & Commutative::ordering() const { + return getPtr(*this)->ordering_; +} + +} \ No newline at end of file diff --git a/src/index_notation/property_pointers.cpp b/src/index_notation/property_pointers.cpp new file mode 100644 index 000000000..84daa2be2 --- /dev/null +++ b/src/index_notation/property_pointers.cpp @@ -0,0 +1,142 @@ +#include "taco/index_notation/property_pointers.h" +#include "taco/index_notation/index_notation.h" +#include "taco/util/strings.h" + +namespace taco { + +struct AnnihilatorPtr::Content { + Literal annihilator; + std::vector positions; +}; + +struct IdentityPtr::Content { + Literal identity; + std::vector positions; +}; + +// Property pointer definitions +PropertyPtr::PropertyPtr() { +} + +PropertyPtr::~PropertyPtr() { +} + +std::ostream& PropertyPtr::print(std::ostream& os) const { + os << "Property()"; + return os; +} + +bool PropertyPtr::equals(const PropertyPtr* p) const { + return this == p; +} + +// Annihilator pointer definitions +AnnihilatorPtr::AnnihilatorPtr() : PropertyPtr(), content(nullptr) { +} + +AnnihilatorPtr::AnnihilatorPtr(Literal annihilator) : PropertyPtr(), content(new Content) { + content->annihilator = annihilator; + content->positions = std::vector(); +} + +AnnihilatorPtr::AnnihilatorPtr(Literal annihilator, std::vector& pos) : PropertyPtr(), content(new Content) { + content->annihilator = annihilator; + content->positions = pos; +} + +const Literal& AnnihilatorPtr::annihilator() const { + return content->annihilator; +} + +const std::vector & AnnihilatorPtr::positions() const { + return content->positions; +} + +std::ostream& AnnihilatorPtr::print(std::ostream& os) const { + os << "Annihilator("; + if (annihilator().defined()) { + os << annihilator(); + } else { + os << "undef"; + } + os << ")"; + return os; +} + +bool AnnihilatorPtr::equals(const PropertyPtr* p) const { + if(!isa(p)) 
return false; + const AnnihilatorPtr* a = to(p); + return ::taco::equals(annihilator(), a->annihilator()); +} + +// Identity pointer definitions +IdentityPtr::IdentityPtr() : PropertyPtr(), content(nullptr) { +} + +IdentityPtr::IdentityPtr(Literal identity) : PropertyPtr(), content(new Content) { + content->identity = identity; +} + +IdentityPtr::IdentityPtr(Literal identity, std::vector &p) : PropertyPtr(), content(new Content) { + content->identity = identity; + content->positions = p; +} + +const Literal& IdentityPtr::identity() const { + return content->identity; +} + +const std::vector & IdentityPtr::positions() const { + return content->positions; +} + +std::ostream& IdentityPtr::print(std::ostream& os) const { + os << "Identity("; + if (identity().defined()) { + os << identity(); + } else { + os << "undef"; + } + os << ")"; + return os; +} + +bool IdentityPtr::equals(const PropertyPtr* p) const { + if(!isa(p)) return false; + const IdentityPtr* idnty = to(p); + return ::taco::equals(identity(), idnty->identity()); +} + +// Associative pointer definitions +AssociativePtr::AssociativePtr() : PropertyPtr() { +} + +std::ostream& AssociativePtr::print(std::ostream& os) const { + os << "Associative()"; + return os; +} + +bool AssociativePtr::equals(const PropertyPtr* p) const { + return isa(p); +} + +// CommutativePtr definitions +CommutativePtr::CommutativePtr() : PropertyPtr() { +} + +CommutativePtr::CommutativePtr(const std::vector& ordering) : ordering_(ordering) { +} + +std::ostream& CommutativePtr::print(std::ostream& os) const { + os << "Commutative("; + os << "{" << util::join(ordering_) << "})"; + return os; +} + +bool CommutativePtr::equals(const PropertyPtr* p) const { + if(!isa(p)) return false; + const CommutativePtr* idnty = to(p); + return ordering_ == idnty->ordering_; +} + +} diff --git a/src/index_notation/tensor_operator.cpp b/src/index_notation/tensor_operator.cpp new file mode 100644 index 000000000..0526e552f --- /dev/null +++ b/src/index_notation/tensor_operator.cpp @@ -0,0 +1,145 @@ +#include "taco/index_notation/tensor_operator.h" + +namespace taco { + +// Full construction +Func::Func(FuncBodyGenerator lowererFunc, FuncAlgebraGenerator algebraFunc, std::vector properties, + std::map, FuncBodyGenerator> specialDefinitions) + : name(util::uniqueName("Func")), lowererFunc(lowererFunc), algebraFunc(algebraFunc), + properties(properties), regionDefinitions(specialDefinitions) { +} + +Func::Func(std::string name, FuncBodyGenerator lowererFunc, FuncAlgebraGenerator algebraFunc, std::vector properties, + std::map, FuncBodyGenerator> specialDefinitions) + : name(name), lowererFunc(lowererFunc), algebraFunc(algebraFunc), properties(properties), + regionDefinitions(specialDefinitions) { +} + +// Construct without specifying algebra +Func::Func(std::string name, FuncBodyGenerator lowererFunc, std::vector properties, + std::map, FuncBodyGenerator> specialDefinitions) + : Func(name, lowererFunc, nullptr, properties, specialDefinitions) { +} + +Func::Func(FuncBodyGenerator lowererFunc, std::vector properties, + std::map, FuncBodyGenerator> specialDefinitions) + : Func(util::uniqueName("Func"), lowererFunc, nullptr, properties, specialDefinitions) { +} + +// Construct without properties +Func::Func(std::string name, FuncBodyGenerator lowererFunc, FuncAlgebraGenerator algebraFunc, + std::map, FuncBodyGenerator> specialDefinitions) + : Func(name, lowererFunc, algebraFunc, {}, specialDefinitions) { +} + +Func::Func(FuncBodyGenerator lowererFunc, FuncAlgebraGenerator algebraFunc, + std::map, 
FuncBodyGenerator> specialDefinitions) : + Func(util::uniqueName("Func"), lowererFunc, algebraFunc, {}, specialDefinitions) { +} + +// Construct without algebra or properties +Func::Func(std::string name, FuncBodyGenerator lowererFunc, std::map, FuncBodyGenerator> specialDefinitions) + : Func(name, lowererFunc, nullptr, specialDefinitions) { +} + +Func::Func(FuncBodyGenerator lowererFunc, std::map, FuncBodyGenerator> specialDefinitions) + : Func(lowererFunc, nullptr, specialDefinitions) { +} + +IterationAlgebra Func::inferAlgFromProperties(const std::vector& exprs) { + if(properties.empty()) { + return constructDefaultAlgebra(exprs); + } + + // Start with smallest regions first. So we first check for annihilator and positional annihilator + if(findProperty(properties).defined()) { + Annihilator annihilator = findProperty(properties); + IterationAlgebra alg = constructAnnihilatorAlg(exprs, annihilator); + if(alg.defined()) { + return alg; + } + } + + // Idempotence here ... + + if(findProperty(properties).defined()) { + Identity identity = findProperty(properties); + IterationAlgebra alg = constructIdentityAlg(exprs, identity); + if(alg.defined()) { + return alg; + } + } + + return constructDefaultAlgebra(exprs); +} + +// Constructs an algebra that iterates over the entire space +IterationAlgebra Func::constructDefaultAlgebra(const std::vector& exprs) { + if(exprs.empty()) return Region(); + + IterationAlgebra tensorsRegions(exprs[0]); + for(size_t i = 1; i < exprs.size(); ++i) { + tensorsRegions = Union(tensorsRegions, exprs[i]); + } + + IterationAlgebra background = Complement(tensorsRegions); + return Union(tensorsRegions, background); +} + +IterationAlgebra Func::constructAnnihilatorAlg(const std::vector &args, taco::Annihilator annihilator) { + if(args.size () < 2) { + return IterationAlgebra(); + } + + Literal annVal = annihilator.annihilator(); + std::vector toIntersect; + + if(annihilator.positions().empty()) { + for(IndexExpr arg : args) { + if(equals(inferFill(arg), annVal)) { + toIntersect.push_back(arg); + } + } + } else { + for(size_t idx : annihilator.positions()) { + if(equals(inferFill(args[idx]), annVal)) { + toIntersect.push_back(args[idx]); + } + } + } + + if(toIntersect.empty()) { + return IterationAlgebra(); + } + + IterationAlgebra alg = toIntersect[0]; + for(size_t i = 1; i < toIntersect.size(); ++i) { + alg = Intersect(alg, toIntersect[i]); + } + + return alg; +} + +IterationAlgebra Func::constructIdentityAlg(const std::vector &args, taco::Identity identity) { + if(args.size() < 2) { + return IterationAlgebra(); + } + + Literal idntyVal = identity.identity(); + + if(identity.positions().empty()) { + for(IndexExpr arg : args) { + if(!equals(inferFill(arg), idntyVal)) { + return IterationAlgebra(); + } + } + } + + IterationAlgebra alg(args[0]); + for(size_t i = 1; i < args.size(); ++i) { + alg = Union(alg, args[i]); + } + return alg; +} + +} \ No newline at end of file diff --git a/src/index_notation/transformations.cpp b/src/index_notation/transformations.cpp index 5310455f6..d7f73329b 100644 --- a/src/index_notation/transformations.cpp +++ b/src/index_notation/transformations.cpp @@ -236,9 +236,10 @@ IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const { TensorVar ws = precompute.getWorkspace(); IndexExpr e = precompute.getExpr(); IndexVar iw = precompute.getiw(); + const std::map index_var_substitutions = {{i,iw}}; IndexStmt consumer = forall(i, replace(s, {{e, ws(i)}})); - IndexStmt producer = forall(iw, ws(iw) = replace(e, {{i,iw}})); + IndexStmt 
producer = forall(iw, ws(iw) = replace(e, index_var_substitutions)); Where where(consumer, producer); stmt = where; diff --git a/src/ir/ir.cpp b/src/ir/ir.cpp index 2623b27cd..bf0106c99 100644 --- a/src/ir/ir.cpp +++ b/src/ir/ir.cpp @@ -183,6 +183,8 @@ std::complex Literal::getComplexValue() const { return 0.0; } + + template bool compare(const Literal* literal, double val) { return literal->getValue() == static_cast(val); } @@ -870,6 +872,9 @@ Expr GetProperty::make(Expr tensor, TensorProperty property, int mode) { case TensorProperty::ValuesSize: gp->name = tensorVar->name + "_vals_size"; break; + case TensorProperty::FillValue: + gp->name = tensorVar->name + "_fill_value"; + break; } return gp; diff --git a/src/ir/ir_generators.cpp b/src/ir/ir_generators.cpp index 23fe5edf6..11f3a1fee 100644 --- a/src/ir/ir_generators.cpp +++ b/src/ir/ir_generators.cpp @@ -14,7 +14,7 @@ Stmt compoundStore(Expr a, Expr i, Expr val, bool use_atomics, ParallelUnit atom return Store::make(a, i, add, use_atomics, atomic_parallel_unit); } -Stmt compoundAssign(Expr a, Expr val, bool use_atomics, ParallelUnit atomic_parallel_unit) { +Stmt addAssign(Expr a, Expr val, bool use_atomics, ParallelUnit atomic_parallel_unit) { Expr add = (val.type().getKind() == Datatype::Bool) ? Or::make(a, val) : Add::make(a, val); return Assign::make(a, add, use_atomics, atomic_parallel_unit); diff --git a/src/ir/ir_generators.h b/src/ir/ir_generators.h index 5b7015ba5..7f5af7157 100644 --- a/src/ir/ir_generators.h +++ b/src/ir/ir_generators.h @@ -14,7 +14,7 @@ class Stmt; Stmt compoundStore(Expr a, Expr i, Expr val, bool use_atomics=false, ParallelUnit atomic_parallel_unit=ParallelUnit::NotParallel); /// Generate `a += val;` -Stmt compoundAssign(Expr a, Expr val, bool use_atomics=false, ParallelUnit atomic_parallel_unit=ParallelUnit::NotParallel); +Stmt addAssign(Expr a, Expr val, bool use_atomics=false, ParallelUnit atomic_parallel_unit=ParallelUnit::NotParallel); /// Generate `exprs_0 && ... && exprs_n` Expr conjunction(std::vector exprs); diff --git a/src/lower/expr_tools.cpp b/src/lower/expr_tools.cpp index ded5c53dd..74ea277f4 100644 --- a/src/lower/expr_tools.cpp +++ b/src/lower/expr_tools.cpp @@ -189,6 +189,15 @@ class SubExprVisitor : public IndexExprVisitorStrict { using IndexExprVisitorStrict::visit; + void visit(const IndexVarNode* op) { + IndexVar var(op); + if (util::contains(vars, var)) { + subExpr = op; + return; + } + subExpr = IndexExpr(); + } + void visit(const AccessNode* op) { // If any variable is in the set of index variables, then the expression // has not been emitted at a previous level, so we keep it. 
@@ -252,6 +261,10 @@ class SubExprVisitor : public IndexExprVisitorStrict { subExpr = binarySubExpr(op); } + void visit(const CallNode* op) { + taco_not_supported_yet; + } + void visit(const CastNode* op) { taco_not_supported_yet; } diff --git a/src/lower/iterator.cpp b/src/lower/iterator.cpp index aad82a3b3..fbf352929 100644 --- a/src/lower/iterator.cpp +++ b/src/lower/iterator.cpp @@ -48,12 +48,12 @@ Iterator::Iterator(std::shared_ptr content) : content(content) { Iterator::Iterator(IndexVar indexVar, bool isFull) : content(new Content) { content->indexVar = indexVar; - content->coordVar = Var::make(indexVar.getName(), Int()); - content->posVar = Var::make(indexVar.getName() + "_pos", Int()); + content->coordVar = Var::make(indexVar.getName(), indexVar.getDataType()); + content->posVar = Var::make(indexVar.getName() + "_pos", indexVar.getDataType()); if (!isFull) { - content->beginVar = Var::make(indexVar.getName() + "_begin", Int()); - content->endVar = Var::make(indexVar.getName() + "_end", Int()); + content->beginVar = Var::make(indexVar.getName() + "_begin", indexVar.getDataType()); + content->endVar = Var::make(indexVar.getName() + "_end", indexVar.getDataType()); } } @@ -79,12 +79,12 @@ Iterator::Iterator(IndexVar indexVar, Expr tensor, Mode mode, Iterator parent, if (useNameForPos) { posNamePrefix = name; } - content->posVar = Var::make(name, Int()); - content->endVar = Var::make("p" + modeName + "_end", Int()); - content->beginVar = Var::make("p" + modeName + "_begin", Int()); + content->posVar = Var::make(name, indexVar.getDataType()); + content->endVar = Var::make("p" + modeName + "_end", indexVar.getDataType()); + content->beginVar = Var::make("p" + modeName + "_begin", indexVar.getDataType()); - content->coordVar = Var::make(name, Int()); - content->segendVar = Var::make(modeName + "_segend", Int()); + content->coordVar = Var::make(name, indexVar.getDataType()); + content->segendVar = Var::make(modeName + "_segend", indexVar.getDataType()); content->validVar = Var::make("v" + modeName, Bool); } @@ -394,7 +394,6 @@ struct Iterators::Content { map modeIterators; }; - Iterators::Iterators() : content(new Content) { @@ -431,10 +430,12 @@ Iterators::Iterators(IndexStmt stmt, const map& tensorVars) { ProvenanceGraph provGraph = ProvenanceGraph(stmt); set underivedAdded; + set computeVars; // Create dimension iterators match(stmt, function([&](auto n, auto m) { - content->modeIterators.insert({n->indexVar, Iterator(n->indexVar, !provGraph.hasCoordBounds(n->indexVar) && provGraph.isCoordVariable(n->indexVar))}); + content->modeIterators.insert({n->indexVar, Iterator(n->indexVar, !provGraph.hasCoordBounds(n->indexVar) + && provGraph.isCoordVariable(n->indexVar))}); for (const IndexVar& underived : provGraph.getUnderivedAncestors(n->indexVar)) { if (!underivedAdded.count(underived)) { content->modeIterators.insert({underived, underived}); @@ -442,6 +443,9 @@ Iterators::Iterators(IndexStmt stmt, const map& tensorVars) } } m->match(n->stmt); + }), + function([&](const IndexVarNode* var) { + }) ); @@ -455,7 +459,7 @@ Iterators::Iterators(IndexStmt stmt, const map& tensorVars) }), function([&](auto n, auto m) { m->match(n->rhs); - m->match(n->lhs); + m->match(n->lhs); }) ); @@ -534,7 +538,7 @@ Iterator Iterators::levelIterator(ModeAccess modeAccess) const taco_iassert(content != nullptr); taco_iassert(util::contains(content->levelIterators, modeAccess)) << "Cannot find " << modeAccess << " in " - << util::join(content->levelIterators); + << util::join(content->levelIterators) << "\n" << 
modeAccess.getAccess(); return content->levelIterators.at(modeAccess); } diff --git a/src/lower/lower.cpp b/src/lower/lower.cpp index e24406543..ac4613410 100644 --- a/src/lower/lower.cpp +++ b/src/lower/lower.cpp @@ -57,7 +57,7 @@ ir::Stmt lower(IndexStmt stmt, std::string name, // if (!messages.empty()) { // std::cerr << "Verifier messages:\n" << messages << "\n"; // } - + return lowered; } diff --git a/src/lower/lowerer_impl.cpp b/src/lower/lowerer_impl.cpp index 25a3593ab..6fc32467d 100644 --- a/src/lower/lowerer_impl.cpp +++ b/src/lower/lowerer_impl.cpp @@ -2,6 +2,7 @@ #include "taco/lower/lowerer_impl.h" #include "taco/index_notation/index_notation.h" +#include "taco/index_notation/tensor_operator.h" #include "taco/index_notation/index_notation_nodes.h" #include "taco/index_notation/index_notation_visitor.h" #include "taco/ir/ir.h" @@ -56,9 +57,11 @@ class LowererImpl::Visitor : public IndexNotationVisitorStrict { void visit(const SqrtNode* node) { expr = impl->lowerSqrt(node); } void visit(const CastNode* node) { expr = impl->lowerCast(node); } void visit(const CallIntrinsicNode* node) { expr = impl->lowerCallIntrinsic(node); } + void visit(const CallNode* node) { expr = impl->lowerTensorOp(node); } void visit(const ReductionNode* node) { taco_ierror << "Reduction nodes not supported in concrete index notation"; } + void visit(const IndexVarNode* node) { expr = impl->lowerIndexVar(node); } }; LowererImpl::LowererImpl() : visitor(new Visitor(this)) { @@ -112,6 +115,7 @@ LowererImpl::lower(IndexStmt stmt, string name, this->compute = compute; definedIndexVarsOrdered = {}; definedIndexVars = {}; + loopOrderAllowsShortCircuit = allForFreeLoopsBeforeAllReductionLoops(stmt); // Create result and parameter variables vector results = getResults(stmt); @@ -243,7 +247,6 @@ LowererImpl::lower(IndexStmt stmt, string name, } } } - // Allocate and initialize append and insert mode indices Stmt initializeResults = initResultArrays(resultAccesses, inputAccesses, reducedAccesses); @@ -293,8 +296,20 @@ Stmt LowererImpl::lowerAssignment(Assignment assignment) return Assign::make(var, rhs); } else { - taco_iassert(isa(assignment.getOperator())); - return compoundAssign(var, rhs, markAssignsAtomicDepth > 0 && !util::contains(whereTemps, result), atomicParallelUnit); + if (isa(assignment.getOperator())) { + return addAssign(var, rhs, markAssignsAtomicDepth > 0 && !util::contains(whereTemps, result), + atomicParallelUnit); + } + taco_iassert(isa(assignment.getOperator())); + + Call op = to(assignment.getOperator()); + Expr assignOp = op.getFunc()({var, rhs}); + Stmt assign = Assign::make(var, assignOp, markAssignsAtomicDepth > 0 && !util::contains(whereTemps, result), + atomicParallelUnit); + + std::vector properties = op.getProperties(); + assign = Block::make(assign, emitEarlyExit(var, properties)); + return assign; } } // Assignments to tensor variables (non-scalar). 
@@ -306,7 +321,21 @@ Stmt LowererImpl::lowerAssignment(Assignment assignment) computeStmt = Store::make(values, loc, rhs); } else { - computeStmt = compoundStore(values, loc, rhs, markAssignsAtomicDepth > 0, atomicParallelUnit); + if (isa(assignment.getOperator())) { + computeStmt = compoundStore(values, loc, rhs, markAssignsAtomicDepth > 0, atomicParallelUnit); + } else { + + taco_iassert(isa(assignment.getOperator())); + + Call op = to(assignment.getOperator()); + Expr assignOp = op.getFunc()({Load::make(values, loc), rhs}); + computeStmt = Store::make(values, loc, assignOp, + markAssignsAtomicDepth > 0 && !util::contains(whereTemps, result), + atomicParallelUnit); + + std::vector properties = op.getProperties(); + computeStmt = Block::make(computeStmt, emitEarlyExit(Load::make(values, loc), properties)); + } } taco_iassert(computeStmt.defined()); } @@ -422,6 +451,7 @@ Stmt LowererImpl::lowerForall(Forall forall) if (isa(ir::simplify(iterBounds[0])) && ir::simplify(iterBounds[0]).as()->equalsScalar(0)) { guardCondition = maxGuard; } + ir::Stmt guard = ir::IfThenElse::make(guardCondition, ir::Continue::make()); recoverySteps.push_back(guard); } @@ -487,7 +517,7 @@ Stmt LowererImpl::lowerForall(Forall forall) parallelUnitSizes[forall.getParallelUnit()] = ir::Sub::make(bounds[1], bounds[0]); } - MergeLattice lattice = MergeLattice::make(forall, iterators, provGraph, definedIndexVars, whereTempsToResult); + MergeLattice caseLattice = MergeLattice::make(forall, iterators, provGraph, definedIndexVars, whereTempsToResult); vector resultAccesses; set reducedAccesses; std::tie(resultAccesses, reducedAccesses) = getResultAccesses(forall); @@ -506,11 +536,11 @@ Stmt LowererImpl::lowerForall(Forall forall) Stmt loops; // Emit a loop that iterates over over a single iterator (optimization) - if (lattice.iterators().size() == 1 && lattice.iterators()[0].isUnique()) { - taco_iassert(lattice.points().size() == 1); + if (caseLattice.iterators().size() == 1 && caseLattice.iterators()[0].isUnique()) { + MergeLattice loopLattice = caseLattice.getLoopLattice(); - MergePoint point = lattice.points()[0]; - Iterator iterator = lattice.iterators()[0]; + MergePoint point = loopLattice.points()[0]; + Iterator iterator = loopLattice.iterators()[0]; vector locators = point.locators(); vector appenders; @@ -554,21 +584,21 @@ Stmt LowererImpl::lowerForall(Forall forall) } if (!isWhereProducer && hasPosDescendant && underivedAncestors.size() > 1 && provGraph.isPosVariable(iterator.getIndexVar()) && posDescendant == forall.getIndexVar()) { - loops = lowerForallFusedPosition(forall, iterator, locators, - inserters, appenders, reducedAccesses, recoveryStmt); + loops = lowerForallFusedPosition(forall, iterator, locators, inserters, appenders, caseLattice, + reducedAccesses, recoveryStmt); } else if (canAccelWithSparseIteration) { - loops = lowerForallDenseAcceleration(forall, locators, inserters, appenders, reducedAccesses, recoveryStmt); + loops = lowerForallDenseAcceleration(forall, locators, inserters, appenders, caseLattice, reducedAccesses, recoveryStmt); } // Emit dimension coordinate iteration loop else if (iterator.isDimensionIterator()) { - loops = lowerForallDimension(forall, point.locators(), - inserters, appenders, reducedAccesses, recoveryStmt); + loops = lowerForallDimension(forall, point.locators(), inserters, appenders, caseLattice, + reducedAccesses, recoveryStmt); } // Emit position iteration loop else if (iterator.hasPosIter()) { - loops = lowerForallPosition(forall, iterator, locators, - inserters, 
appenders, reducedAccesses, recoveryStmt); + loops = lowerForallPosition(forall, iterator, locators, inserters, appenders, caseLattice, + reducedAccesses, recoveryStmt); } // Emit coordinate iteration loop else { @@ -581,7 +611,7 @@ Stmt LowererImpl::lowerForall(Forall forall) else { std::vector underivedAncestors = provGraph.getUnderivedAncestors(forall.getIndexVar()); taco_iassert(underivedAncestors.size() == 1); // TODO: add support for fused coordinate of pos loop - loops = lowerMergeLattice(lattice, underivedAncestors[0], + loops = lowerMergeLattice(caseLattice, underivedAncestors[0], forall.getStmt(), reducedAccesses); } // taco_iassert(loops.defined()); @@ -897,6 +927,7 @@ Stmt LowererImpl::lowerForallDimension(Forall forall, vector locators, vector inserters, vector appenders, + MergeLattice caseLattice, set reducedAccesses, ir::Stmt recoveryStmt) { @@ -907,8 +938,8 @@ Stmt LowererImpl::lowerForallDimension(Forall forall, atomicParallelUnit = forall.getParallelUnit(); } - Stmt body = lowerForallBody(coordinate, forall.getStmt(), - locators, inserters, appenders, reducedAccesses); + Stmt body = lowerForallBody(coordinate, forall.getStmt(), locators, inserters, + appenders, caseLattice, reducedAccesses); if (forall.getParallelUnit() != ParallelUnit::NotParallel && forall.getOutputRaceStrategy() == OutputRaceStrategy::Atomics) { markAssignsAtomicDepth--; @@ -940,6 +971,7 @@ Stmt LowererImpl::lowerForallDimension(Forall forall, vector locators, vector inserters, vector appenders, + MergeLattice caseLattice, set reducedAccesses, ir::Stmt recoveryStmt) { @@ -968,7 +1000,7 @@ Stmt LowererImpl::lowerForallDimension(Forall forall, } Stmt declareVar = VarDecl::make(coordinate, Load::make(indexList, loopVar)); - Stmt body = lowerForallBody(coordinate, forall.getStmt(), locators, inserters, appenders, reducedAccesses); + Stmt body = lowerForallBody(coordinate, forall.getStmt(), locators, inserters, appenders, caseLattice, reducedAccesses); Stmt resetGuard = ir::Store::make(bitGuard, coordinate, ir::Literal::make(false), markAssignsAtomicDepth > 0, atomicParallelUnit); body = Block::make(declareVar, body, resetGuard); @@ -999,6 +1031,7 @@ Stmt LowererImpl::lowerForallCoordinate(Forall forall, Iterator iterator, vector locators, vector inserters, vector appenders, + MergeLattice caseLattice, set reducedAccesses, ir::Stmt recoveryStmt) { taco_not_supported_yet; @@ -1009,6 +1042,7 @@ Stmt LowererImpl::lowerForallPosition(Forall forall, Iterator iterator, vector locators, vector inserters, vector appenders, + MergeLattice caseLattice, set reducedAccesses, ir::Stmt recoveryStmt) { @@ -1035,8 +1069,7 @@ Stmt LowererImpl::lowerForallPosition(Forall forall, Iterator iterator, markAssignsAtomicDepth++; } - Stmt body = lowerForallBody(coordinate, forall.getStmt(), - locators, inserters, appenders, reducedAccesses); + Stmt body = lowerForallBody(coordinate, forall.getStmt(), locators, inserters, appenders, caseLattice, reducedAccesses); if (forall.getParallelUnit() != ParallelUnit::NotParallel && forall.getOutputRaceStrategy() == OutputRaceStrategy::Atomics) { markAssignsAtomicDepth--; @@ -1111,6 +1144,7 @@ Stmt LowererImpl::lowerForallFusedPosition(Forall forall, Iterator iterator, vector locators, vector inserters, vector appenders, + MergeLattice caseLattice, set reducedAccesses, ir::Stmt recoveryStmt) { @@ -1189,7 +1223,7 @@ Stmt LowererImpl::lowerForallFusedPosition(Forall forall, Iterator iterator, else { locateCoordVar = ir::Assign::make(coordVarUnknown, posVarUnknown); } - Stmt loopBody = 
ir::Block::make(compoundAssign(posVarUnknown, 1), locateCoordVar, loopToTrackUnderiveds); + Stmt loopBody = ir::Block::make(addAssign(posVarUnknown, 1), locateCoordVar, loopToTrackUnderiveds); if (posIteratorLevel.getParent().hasPosIter()) { // TODO: if level is unique or not loopToTrackUnderiveds = IfThenElse::make(loopcond, loopBody); } @@ -1205,7 +1239,7 @@ Stmt LowererImpl::lowerForallFusedPosition(Forall forall, Iterator iterator, } Stmt body = lowerForallBody(coordinate, forall.getStmt(), - locators, inserters, appenders, reducedAccesses); + locators, inserters, appenders, caseLattice, reducedAccesses); if (forall.getParallelUnit() != ParallelUnit::NotParallel && forall.getOutputRaceStrategy() == OutputRaceStrategy::Atomics) { markAssignsAtomicDepth--; @@ -1271,31 +1305,34 @@ Stmt LowererImpl::lowerForallFusedPosition(Forall forall, Iterator iterator, } -Stmt LowererImpl::lowerMergeLattice(MergeLattice lattice, IndexVar coordinateVar, +Stmt LowererImpl::lowerMergeLattice(MergeLattice caseLattice, IndexVar coordinateVar, IndexStmt statement, const std::set& reducedAccesses) { + // Lower merge lattice always gets called from lowerForAll. So we want loop lattice + MergeLattice loopLattice = caseLattice.getLoopLattice(); + Expr coordinate = getCoordinateVar(coordinateVar); - vector appenders = filter(lattice.results(), + vector appenders = filter(loopLattice.results(), [](Iterator it){return it.hasAppend();}); - vector mergers = lattice.points()[0].mergers(); - Stmt iteratorVarInits = codeToInitializeIteratorVars(lattice.iterators(), lattice.points()[0].rangers(), mergers, coordinate, coordinateVar); + vector mergers = loopLattice.points()[0].mergers(); + Stmt iteratorVarInits = codeToInitializeIteratorVars(loopLattice.iterators(), loopLattice.points()[0].rangers(), mergers, coordinate, coordinateVar); // if modeiteratornonmerger then will be declared in codeToInitializeIteratorVars auto modeIteratorsNonMergers = - filter(lattice.points()[0].iterators(), [mergers](Iterator it){ + filter(loopLattice.points()[0].iterators(), [mergers](Iterator it){ bool isMerger = find(mergers.begin(), mergers.end(), it) != mergers.end(); return it.isDimensionIterator() && !isMerger; }); bool resolvedCoordDeclared = !modeIteratorsNonMergers.empty(); vector mergeLoopsVec; - for (MergePoint point : lattice.points()) { + for (MergePoint point : loopLattice.points()) { // Each iteration of this loop generates a while loop for one of the merge // points in the merge lattice. 
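// Illustrative note: for a statement such as a(i) = b(i) + c(i) with b and c
// compressed, the lattice points are {b,c}, {b} and {c}, so this loop emits
// three while loops roughly of the form (loop variable names are illustrative)
//
//   while (pb < b_end && pc < c_end) { ... }   // merge point {b,c}
//   while (pb < b_end)               { ... }   // merge point {b}
//   while (pc < c_end)               { ... }   // merge point {c}
//
// Note that the statement handed to each loop is zeroed using exhausted
// accesses computed against the full case lattice, while the set of points
// iterated here comes from the filtered loop lattice.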
- IndexStmt zeroedStmt = zero(statement, getExhaustedAccesses(point,lattice)); - MergeLattice sublattice = lattice.subLattice(point); + IndexStmt zeroedStmt = zero(statement, getExhaustedAccesses(point, caseLattice)); + MergeLattice sublattice = caseLattice.subLattice(point); Stmt mergeLoop = lowerMergePoint(sublattice, coordinate, coordinateVar, zeroedStmt, reducedAccesses, resolvedCoordDeclared); mergeLoopsVec.push_back(mergeLoop); } @@ -1403,58 +1440,245 @@ Stmt LowererImpl::resolveCoordinate(std::vector mergers, ir::Expr coor } Stmt LowererImpl::lowerMergeCases(ir::Expr coordinate, IndexVar coordinateVar, IndexStmt stmt, - MergeLattice lattice, + MergeLattice caseLattice, const std::set& reducedAccesses) { vector result; + if (caseLattice.anyModeIteratorIsLeaf() && caseLattice.needExplicitZeroChecks()) { + // Can check value array of some tensor + Stmt body = lowerMergeCasesWithExplicitZeroChecks(coordinate, coordinateVar, stmt, caseLattice, reducedAccesses); + result.push_back(body); + return Block::make(result); + } + + // Emitting structural cases so unconditionally apply lattice optimizations. + MergeLattice loopLattice = caseLattice.getLoopLattice(); + vector appenders; vector inserters; - tie(appenders, inserters) = splitAppenderAndInserters(lattice.results()); + tie(appenders, inserters) = splitAppenderAndInserters(loopLattice.results()); - // Just one iterator so no conditionals - if (lattice.iterators().size() == 1) { - Stmt body = lowerForallBody(coordinate, stmt, {}, inserters, - appenders, reducedAccesses); + if (loopLattice.iterators().size() == 1) { + // Just one iterator so no conditional + taco_iassert(!loopLattice.points()[0].isOmitter()); + Stmt body = lowerForallBody(coordinate, stmt, {}, inserters, + appenders, loopLattice, reducedAccesses); result.push_back(body); } - else { + else if (!loopLattice.points().empty()) { vector> cases; - for (MergePoint point : lattice.points()) { + for (MergePoint point : loopLattice.points()) { - // Construct case expression - vector coordComparisons; - for (Iterator iterator : point.rangers()) { - if (!(provGraph.isCoordVariable(iterator.getIndexVar()) && provGraph.isDerivedFrom(iterator.getIndexVar(), coordinateVar))) { - coordComparisons.push_back(Eq::make(iterator.getCoordVar(), coordinate)); - } + if(point.isOmitter()) { + continue; } + // Construct case expression + vector coordComparisons = compareToResolvedCoordinate(point.rangers(), coordinate, coordinateVar); + vector omittedRegionIterators = loopLattice.retrieveRegionIteratorsToOmit(point); + std::vector neqComparisons = compareToResolvedCoordinate(omittedRegionIterators, coordinate, coordinateVar); + append(coordComparisons, neqComparisons); + + coordComparisons = filter(coordComparisons, [](const Expr& e) { return e.defined(); }); + // Construct case body - IndexStmt zeroedStmt = zero(stmt, getExhaustedAccesses(point, lattice)); + IndexStmt zeroedStmt = zero(stmt, getExhaustedAccesses(point, loopLattice)); Stmt body = lowerForallBody(coordinate, zeroedStmt, {}, - inserters, appenders, reducedAccesses); + inserters, appenders, MergeLattice({point}), reducedAccesses); if (coordComparisons.empty()) { Stmt body = lowerForallBody(coordinate, stmt, {}, inserters, - appenders, reducedAccesses); + appenders, MergeLattice({point}), reducedAccesses); result.push_back(body); break; } cases.push_back({taco::ir::conjunction(coordComparisons), body}); } - result.push_back(Case::make(cases, lattice.exact())); + result.push_back(Case::make(cases, loopLattice.exact())); } return 
Block::make(result); } +ir::Expr LowererImpl::constructCheckForAccessZero(Access access) { + Expr tensorValue = lower(access); + IndexExpr zeroVal = Literal::zero(tensorValue.type()); //TODO ARRAY Generalize + return Neq::make(tensorValue, lower(zeroVal)); +} + +std::vector LowererImpl::getModeIterators(const std::vector& iters) { + // For now only check mode iterators. + return filter(iters, [](const Iterator& it){return it.isModeIterator();}); +} + +std::vector LowererImpl::constructInnerLoopCasePreamble(ir::Expr coordinate, IndexVar coordinateVar, + MergeLattice lattice, + map& iteratorToConditionMap) { + vector result; + + // First, get mode iterator coordinate comparisons + std::vector modeIterators = getModeIterators(lattice.iterators()); + vector coordComparisonsForModeIters = compareToResolvedCoordinate(modeIterators, coordinate, coordinateVar); + + std::vector modeItersWithIndexCases; + std::vector coordComparisons; + for(size_t i = 0; i < coordComparisonsForModeIters.size(); ++i) { + Expr expr = coordComparisonsForModeIters[i]; + if (expr.defined()) { + modeItersWithIndexCases.push_back(modeIterators[i]); + coordComparisons.push_back(expr); + } + } + + // Construct tensor iterators with modeIterators first then locate iterators to keep a mapping between vector indices + vector tensorIterators = combine(modeItersWithIndexCases, lattice.locators()); + tensorIterators = getModeIterators(tensorIterators); + + // Get value comparisons for all tensor iterators + vector itAccesses; + vector valueComparisons; + for(auto it : tensorIterators) { + Access itAccess = iterators.modeAccess(it).getAccess(); + itAccesses.push_back(itAccess); + if(it.isLeaf()) { + valueComparisons.push_back(constructCheckForAccessZero(itAccess)); + } else { + valueComparisons.push_back(Expr()); + } + } + + // Construct isNonZero cases + for(size_t i = 0; i < coordComparisons.size(); ++i) { + Expr nonZeroCase; + if(coordComparisons[i].defined() && valueComparisons[i].defined()) { + nonZeroCase = conjunction({coordComparisons[i], valueComparisons[i]}); + } else if (valueComparisons[i].defined()) { + nonZeroCase = valueComparisons[i]; + } else if (coordComparisons[i].defined()) { + nonZeroCase = coordComparisons[i]; + } else { + continue; + } + Expr caseName = Var::make(itAccesses[i].getTensorVar().getName() + "_isNonZero", taco::Bool); + Stmt declaration = VarDecl::make(caseName, nonZeroCase); + result.push_back(declaration); + iteratorToConditionMap[tensorIterators[i]] = caseName; + } + + for(size_t i = modeItersWithIndexCases.size(); i < valueComparisons.size(); ++i) { + Expr caseName = Var::make(itAccesses[i].getTensorVar().getName() + "_isNonZero", taco::Bool); + Stmt declaration = VarDecl::make(caseName, valueComparisons[i]); + result.push_back(declaration); + iteratorToConditionMap[tensorIterators[i]] = caseName; + } + + return result; +} + +vector LowererImpl::lowerCasesFromMap(map iteratorToCondition, + ir::Expr coordinate, IndexStmt stmt, const MergeLattice& lattice, + const std::set& reducedAccesses) { + + vector appenders; + vector inserters; + tie(appenders, inserters) = splitAppenderAndInserters(lattice.results()); + + std::vector result; + vector> cases; + for (MergePoint point : lattice.points()) { + + if(point.isOmitter()) { + continue; + } + + // Construct case expression + vector isNonZeroComparisions; + for(auto& it : combine(point.rangers(), point.locators())) { + if(util::contains(iteratorToCondition, it)) { + taco_iassert(iteratorToCondition.at(it).type() == taco::Bool) << "Map must have 
boolean types"; + isNonZeroComparisions.push_back(iteratorToCondition.at(it)); + } + } + + function getNegatedComparison = [&](const Iterator& it) {return ir::Neg::make(iteratorToCondition.at(it));}; + vector omittedRegionIterators = lattice.retrieveRegionIteratorsToOmit(point); + for(auto& it : omittedRegionIterators) { + if(util::contains(iteratorToCondition, it)) { + isNonZeroComparisions.push_back(ir::Neg::make(iteratorToCondition.at(it))); + } + } + + // Construct case body + IndexStmt zeroedStmt = zero(stmt, getExhaustedAccesses(point, lattice)); + Stmt body = lowerForallBody(coordinate, zeroedStmt, {}, + inserters, appenders, MergeLattice({point}), reducedAccesses); + if (isNonZeroComparisions.empty()) { + Stmt body = lowerForallBody(coordinate, stmt, {}, inserters, + appenders, MergeLattice({point}), reducedAccesses); + result.push_back(body); + break; + } + cases.push_back({taco::ir::conjunction(isNonZeroComparisions), body}); + } + + vector inputs = combine(lattice.iterators(), lattice.locators()); + inputs = getModeIterators(inputs); + + if(!lattice.exact() && util::any(inserters, [](Iterator it){return it.isFull();}) && hasNoForAlls(stmt) + && any(inputs, [](Iterator it){return it.isFull();})) { + // Currently, if the lattice is not exact, the output is full and any of the inputs are full, we initialize + // the result tensor + vector stmts; + for(auto& it : inserters) { + if(it.isFull()) { + Access access = iterators.modeAccess(it).getAccess(); + IndexStmt initStmt = Assignment(access, Literal::zero(access.getDataType())); + Stmt initialization = lowerForallBody(coordinate, initStmt, {}, inserters, + appenders, MergeLattice({}), reducedAccesses); + stmts.push_back(initialization); + } + } + Stmt backgroundInit = Block::make(stmts); + cases.push_back({Expr((bool) true), backgroundInit}); + result.push_back(Case::make(cases, true)); + } else { + result.push_back(Case::make(cases, lattice.exact())); + } + return result; +} + +/// Lowers a merge lattice to cases assuming there are no more loops to be emitted in stmt. 
+Stmt LowererImpl::lowerMergeCasesWithExplicitZeroChecks(ir::Expr coordinate, IndexVar coordinateVar, IndexStmt stmt, + MergeLattice lattice, const std::set& reducedAccesses) { + + vector result; + if (lattice.points().size() == 1 && lattice.iterators().size() == 1) { + // Just one iterator so no conditional + vector appenders; + vector inserters; + tie(appenders, inserters) = splitAppenderAndInserters(lattice.results()); + taco_iassert(!lattice.points()[0].isOmitter()); + Stmt body = lowerForallBody(coordinate, stmt, {}, inserters, + appenders, lattice, reducedAccesses); + result.push_back(body); + } else if (!lattice.points().empty()) { + map iteratorToConditionMap; + + vector preamble = constructInnerLoopCasePreamble(coordinate, coordinateVar, lattice, iteratorToConditionMap); + util::append(result, preamble); + vector cases = lowerCasesFromMap(iteratorToConditionMap, coordinate, stmt, lattice, reducedAccesses); + util::append(result, cases); + } + + return Block::make(result); +} Stmt LowererImpl::lowerForallBody(Expr coordinate, IndexStmt stmt, vector locators, vector inserters, vector appenders, + MergeLattice caseLattice, const set& reducedAccesses) { - Stmt initVals = resizeAndInitValues(appenders, reducedAccesses); // Inserter positions Stmt declInserterPosVars = declLocatePosVars(inserters); @@ -1467,6 +1691,36 @@ Stmt LowererImpl::lowerForallBody(Expr coordinate, IndexStmt stmt, captureNextLocatePos = false; } + if (caseLattice.anyModeIteratorIsLeaf() && caseLattice.points().size() > 1) { + + // Code of loop body statement + // Explicit zero checks needed + std::vector stmts; + + // Need to emit checks based on case lattice + vector modeIterators = getModeIterators(combine(caseLattice.iterators(), caseLattice.locators())); + std::map caseMap; + for(auto it : modeIterators) { + if(it.isLeaf()) { + // Only emit explicit 0 checks for leaf iterators since these are the only iterators can can access tensor + // values array + Access itAccess = iterators.modeAccess(it).getAccess(); + Expr accessCase = constructCheckForAccessZero(itAccess); + caseMap.insert({it, accessCase}); + } + } + + // This will lower the body for each case to actually compute. 
Therefore, we don't need to resize assembly arrays + std::vector loweredCases = lowerCasesFromMap(caseMap, coordinate, stmt, caseLattice, reducedAccesses); + + append(stmts, loweredCases); + Stmt body = Block::make(stmts); + + return Block::make(declInserterPosVars, declLocatorPosVars, body); + } + + Stmt initVals = resizeAndInitValues(appenders, reducedAccesses); + // Code of loop body statement Stmt body = lower(stmt); @@ -1787,6 +2041,12 @@ Expr LowererImpl::lowerAccess(Access access) { } } +Expr LowererImpl::lowerIndexVar(IndexVar var) { + taco_iassert(util::contains(indexVarToExprMap, var)); + taco_iassert(provGraph.isRecoverable(var, definedIndexVars)); + return indexVarToExprMap.at(var); +} + Expr LowererImpl::lowerLiteral(Literal literal) { switch (literal.getDataType().getKind()) { @@ -1849,6 +2109,7 @@ Expr LowererImpl::lowerSub(Sub sub) { Expr LowererImpl::lowerMul(Mul mul) { + const IndexExpr t = mul.getA(); Expr a = lower(mul.getA()); Expr b = lower(mul.getB()); return (mul.getDataType().getKind() == Datatype::Bool) @@ -1880,6 +2141,25 @@ Expr LowererImpl::lowerCallIntrinsic(CallIntrinsic call) { } +Expr LowererImpl::lowerTensorOp(Call op) { + auto definedArgs = op.getDefinedArgs(); + std::vector args; + + if(util::contains(op.getDefs(), definedArgs)) { + auto lowerFunc = op.getDefs().at(definedArgs); + for (auto& argIdx : definedArgs) { + args.push_back(lower(op.getArgs()[argIdx])); + } + return lowerFunc(args); + } + + for(const auto& arg : op.getArgs()) { + args.push_back(lower(arg)); + } + + return op.getFunc()(args); +} + Stmt LowererImpl::lower(IndexStmt stmt) { return visitor->lower(stmt); } @@ -1899,6 +2179,14 @@ bool LowererImpl::generateComputeCode() const { return this->compute; } +Stmt LowererImpl::emitEarlyExit(Expr reductionExpr, std::vector& properties) { + if (loopOrderAllowsShortCircuit && findProperty(properties).defined()) { + Literal annh = findProperty(properties).annihilator(); + Expr isAnnihilator = ir::Eq::make(reductionExpr, lower(annh)); + return IfThenElse::make(isAnnihilator, Block::make(Break::make())); + } + return Stmt(); +} Expr LowererImpl::getTensorVar(TensorVar tensorVar) const { taco_iassert(util::contains(this->tensorVars, tensorVar)) << tensorVar; @@ -2032,6 +2320,7 @@ Stmt LowererImpl::initResultArrays(vector writes, taco_iassert(!iterators.empty()); Expr tensor = getTensorVar(write.getTensorVar()); + Expr fill = GetProperty::make(tensor, TensorProperty::FillValue); Expr valuesArr = GetProperty::make(tensor, TensorProperty::Values); bool clearValuesAllocation = false; @@ -2112,7 +2401,7 @@ Stmt LowererImpl::initResultArrays(vector writes, // iteration of all the iterators is not full. We can check this by seeing if we can recover a // full iterator from our set of iterators. Expr size = generateAssembleCode() ? getCapacityVar(tensor) : parentSize; - result.push_back(zeroInitValues(tensor, 0, size)); + result.push_back(initValues(tensor, fill, 0, size)); } } return result.empty() ? 
Stmt() : Block::blanks(result); @@ -2208,6 +2497,7 @@ Stmt LowererImpl::initResultArrays(IndexVar var, vector writes, vector result; for (auto& write : writes) { Expr tensor = getTensorVar(write.getTensorVar()); + Expr fill = GetProperty::make(tensor, TensorProperty::FillValue); Expr values = GetProperty::make(tensor, TensorProperty::Values); vector iterators = getIteratorsFrom(var, getIterators(write)); @@ -2267,7 +2557,7 @@ Stmt LowererImpl::initResultArrays(IndexVar var, vector writes, util::contains(reducedAccesses, write)) { // Zero-initialize values array if might not assign to every element // in values array during compute - result.push_back(zeroInitValues(tensor, resultParentPos, stride)); + result.push_back(initValues(tensor, fill, resultParentPos, stride)); } } } @@ -2313,18 +2603,20 @@ Stmt LowererImpl::resizeAndInitValues(const std::vector& appenders, } -Stmt LowererImpl::zeroInitValues(Expr tensor, Expr begin, Expr size) { +Stmt LowererImpl::initValues(Expr tensor, Expr initVal, Expr begin, Expr size) { Expr lower = simplify(ir::Mul::make(begin, size)); Expr upper = simplify(ir::Mul::make(ir::Add::make(begin, 1), size)); Expr p = Var::make("p" + util::toString(tensor), Int()); Expr values = GetProperty::make(tensor, TensorProperty::Values); - Stmt zeroInit = Store::make(values, p, ir::Literal::zero(tensor.type())); + Stmt zeroInit = Store::make(values, p, initVal); LoopKind parallel = (isa(size) && to(size)->getIntValue() < (1 << 10)) ? LoopKind::Serial : LoopKind::Static_Chunked; if (should_use_CUDA_codegen() && util::contains(parallelUnitSizes, ParallelUnit::GPUBlock)) { return ir::VarDecl::make(ir::Var::make("status", Int()), - ir::Call::make("cudaMemset", {values, ir::Literal::make(0, Int()), ir::Mul::make(ir::Sub::make(upper, lower), ir::Literal::make(values.type().getNumBytes()))}, Int())); + ir::Call::make("cudaMemset", {values, ir::Literal::make(0, Int()), + ir::Mul::make(ir::Sub::make(upper, lower), + ir::Literal::make(values.type().getNumBytes()))}, Int())); } return For::make(p, lower, upper, 1, zeroInit, parallel); } @@ -2402,7 +2694,7 @@ Stmt LowererImpl::reduceDuplicateCoordinates(Expr coordinate, // need a separate segend variable. segendVar = iterVar; if (alwaysReduce) { - result.push_back(compoundAssign(segendVar, 1)); + result.push_back(addAssign(segendVar, 1)); } } else { Expr segendInit = alwaysReduce ? ir::Add::make(iterVar, 1) : iterVar; @@ -2412,9 +2704,9 @@ Stmt LowererImpl::reduceDuplicateCoordinates(Expr coordinate, vector dedupStmts; if (reducedVal.defined()) { Expr partialVal = Load::make(tensorVals, segendVar); - dedupStmts.push_back(compoundAssign(reducedVal, partialVal)); + dedupStmts.push_back(addAssign(reducedVal, partialVal)); } - dedupStmts.push_back(compoundAssign(segendVar, 1)); + dedupStmts.push_back(addAssign(segendVar, 1)); Stmt dedupBody = Block::make(dedupStmts); ModeFunction posAccess = iterator.posAccess(segendVar, @@ -2465,7 +2757,7 @@ Stmt LowererImpl::codeToInitializeIteratorVar(Iterator iterator, vector appenders, Expr coord) { } if (generateAssembleCode() || isLastAppender(appender)) { - appendStmts.push_back(compoundAssign(pos, 1)); + appendStmts.push_back(addAssign(pos, 1)); Stmt appendCode = Block::make(appendStmts); if (appenderChild.defined() && appenderChild.hasAppend()) { @@ -2804,7 +3096,7 @@ Expr LowererImpl::searchForStartOfWindowPosition(Iterator iterator, ir::Expr sta // for the beginning of the window. 
iterator.getWindowLowerBound(), }; - return Call::make("taco_binarySearchAfter", args, Datatype::UInt64); + return ir::Call::make("taco_binarySearchAfter", args, Datatype::UInt64); } Expr LowererImpl::searchForEndOfWindowPosition(Iterator iterator, ir::Expr start, ir::Expr end) { @@ -2817,7 +3109,7 @@ Expr LowererImpl::searchForEndOfWindowPosition(Iterator iterator, ir::Expr start // for the end of the window. iterator.getWindowUpperBound(), }; - return Call::make("taco_binarySearchAfter", args, Datatype::UInt64); + return ir::Call::make("taco_binarySearchAfter", args, Datatype::UInt64); } Stmt LowererImpl::upperBoundGuardForWindowPosition(Iterator iterator, ir::Expr access) { diff --git a/src/lower/merge_lattice.cpp b/src/lower/merge_lattice.cpp index 1032cc822..969bdd6fa 100644 --- a/src/lower/merge_lattice.cpp +++ b/src/lower/merge_lattice.cpp @@ -12,16 +12,19 @@ #include "mode_access.h" #include "taco/util/collections.h" #include "taco/util/strings.h" +#include "taco/index_notation/iteration_algebra.h" #include "taco/util/scopedmap.h" using namespace std; namespace taco { -class MergeLatticeBuilder : public IndexNotationVisitorStrict { +class MergeLatticeBuilder : public IndexNotationVisitorStrict, public IterationAlgebraVisitorStrict { public: - MergeLatticeBuilder(IndexVar i, Iterators iterators, ProvenanceGraph provGraph, std::set definedIndexVars, std::map whereTempsToResult = {}) - : i(i), iterators(iterators), provGraph(provGraph), definedIndexVars(definedIndexVars), whereTempsToResult(whereTempsToResult) {} + MergeLatticeBuilder(IndexVar i, Iterators iterators, ProvenanceGraph provGraph, std::set definedIndexVars, + std::map whereTempsToResult = {}) + : i(i), iterators(iterators), provGraph(provGraph), definedIndexVars(definedIndexVars), + whereTempsToResult(whereTempsToResult) {} MergeLattice build(IndexStmt stmt) { stmt.accept(this); @@ -37,6 +40,13 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict { return l; } + MergeLattice build(IterationAlgebra alg) { + alg.accept(this); + MergeLattice l = lattice; + lattice = MergeLattice({}); + return l; + } + Iterator getIterator(Access access, IndexVar accessVar) { // must have matching underived ancestor map accessUnderivedAncestorsToLoc; @@ -73,13 +83,114 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict { std::set definedIndexVars; map latticesOfTemporaries; std::map whereTempsToResult; + map seenMergePoints; MergeLattice modeIterationLattice() { return MergeLattice({MergePoint({iterators.modeIterator(i)}, {}, {})}); } + void visit(const RegionNode* node) { + if(!node->expr().defined()) { + // Region is empty so return empty lattice + lattice = MergeLattice({}); + return; + } + + lattice = build(node->expr()); + } + + void visit(const ComplementNode* node) { + taco_iassert(isa(node->a)) << "Demorgan's rule must be applied before lowering."; + lattice = build(node->a); + + vector points = flipPoints(lattice.points()); + + // Otherwise, all tensors are sparse + points = includeMissingProducerPoints(points); + + // Add dimension point + Iterator dimIter = iterators.modeIterator(i); + points = includeDimensionIterator(points, dimIter); + + bool needsDimPoint = true; + for(const auto& point: points) { + if(point.locators().empty() && point.iterators().size() == 1 && point.iterators()[0] == dimIter) { + needsDimPoint = false; + break; + } + } + + if(needsDimPoint) { + points.push_back(MergePoint({dimIter}, {}, {})); + } + + lattice = MergeLattice(points); + } + + void visit(const IntersectNode* node) { + 
MergeLattice a = build(node->a); + MergeLattice b = build(node->b); + + if (a.points().size() > 0 && b.points().size() > 0) { + lattice = intersectLattices(a, b); + } else { + // If any side of an intersection is empty, the entire intersection must be empty + lattice = MergeLattice({}); + } + } + + void visit(const UnionNode* node) { + MergeLattice a = build(node->a); + MergeLattice b = build(node->b); + if (a.points().size() > 0 && b.points().size() > 0) { + lattice = unionLattices(a, b); + } + // Scalar operands + else if (a.points().size() > 0) { + lattice = a; + } + else if (b.points().size() > 0) { + lattice = b; + } + } + + void visit(const IndexVarNode* varNode) { + // There are a few cases here... + + // 1) If var in the expression is the same as the var being lowered, we need to return a lattice + // with one point that iterates over the universe of the current dimension. + // Why: TACO needs to know if it needs to generate merge loops to deal with computing the current index var. + // Eg. b(i) + i where b is sparse. To preserve semantics, we need to merge the sparse iteration set of b + // with the implied dense space of i. + // Question: What if the user WANTS i to be 'sparse'? Just define a func where + is an intersection =) + // 2) The vars differ. This case actually has 2 subcases... + // a) The loop variable ('i' in this builder) is derived from the variable used in the expression + // ('var' defined below). In this case, return a mode iterator over the derived var ('i' in the builder) + // so taco can generate the correct merge loops for this level. + // b) The loop variable is not derived from the variable used in the expression. In this case, we just return + // an empty lattice as there is nothing that needs to be merged =) + // TODO: Add these cases to the test suite.... + IndexVar var(varNode); + taco_iassert(provGraph.isUnderived(var)); + if (var == i) { + lattice = MergeLattice({MergePoint({Iterator(var)}, {}, {})}); + } else { + if (provGraph.isDerivedFrom(i, var)) { + lattice = MergeLattice({MergePoint({iterators.modeIterator(i)}, {}, {})}); + } else { + lattice = MergeLattice({}); + } + } + } + void visit(const AccessNode* access) { + // TODO: Case where Access is used in computation but not iteration algebra + if(seenMergePoints.find(access) != seenMergePoints.end()) { + lattice = MergeLattice({seenMergePoints.at(access)}); + return; + } + if (util::contains(latticesOfTemporaries, access->tensorVar)) { // If the accessed tensor variable is a temporary with an associated merge // lattice then we return that lattice. 
@@ -142,6 +253,8 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict { : MergePoint(pointIterators, {}, {}); lattice = MergeLattice({point}); } + + seenMergePoints.insert({access, lattice.points()[0]}); } void visit(const LiteralNode* node) { @@ -155,63 +268,19 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict { } void visit(const AddNode* node) { - MergeLattice a = build(node->a); - MergeLattice b = build(node->b); - if (a.points().size() > 0 && b.points().size() > 0) { - lattice = unionLattices(a, b); - } - // Scalar operands - else if (a.points().size() > 0) { - lattice = a; - } - else if (b.points().size() > 0) { - lattice = b; - } + lattice = build(new UnionNode(Region(node->a), Region(node->b))); } void visit(const SubNode* expr) { - MergeLattice a = build(expr->a); - MergeLattice b = build(expr->b); - if (a.points().size() > 0 && b.points().size() > 0) { - lattice = unionLattices(a, b); - } - // Scalar operands - else if (a.points().size() > 0) { - lattice = a; - } - else if (b.points().size() > 0) { - lattice = b; - } + lattice = build(new UnionNode(Region(expr->a), Region(expr->b))); } void visit(const MulNode* expr) { - MergeLattice a = build(expr->a); - MergeLattice b = build(expr->b); - if (a.points().size() > 0 && b.points().size() > 0) { - lattice = intersectLattices(a, b); - } - // Scalar operands - else if (a.points().size() > 0) { - lattice = a; - } - else if (b.points().size() > 0) { - lattice = b; - } + lattice = build(new IntersectNode(Region(expr->a), Region(expr->b))); } void visit(const DivNode* expr) { - MergeLattice a = build(expr->a); - MergeLattice b = build(expr->b); - if (a.points().size() > 0 && b.points().size() > 0) { - lattice = intersectLattices(a, b); - } - // Scalar operands - else if (a.points().size() > 0) { - lattice = a; - } - else if (b.points().size() > 0) { - lattice = b; - } + lattice = build(new IntersectNode(Region(expr->a), Region(expr->b))); } void visit(const SqrtNode* expr) { @@ -222,6 +291,34 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict { lattice = build(expr->a); } + void visit(const CallNode* expr) { + taco_iassert(expr->iterAlg.defined()) << "Algebra must be defined" << endl; + lattice = build(expr->iterAlg); + + // Now we need to store regions that should be kept when applying optimizations. + // Can't remove regions described by special regions since the lowerer must emit checks for those in + // all cases. 
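// Illustrative note: if an operator supplies a special-case lowering for the
// argument set {0, 1} (say, a definition that only reads B and C), the tensor
// region {B, C} is recorded in regionsToKeep below. A non-empty set of kept
// regions makes needExplicitZeroChecks() return true, so (together with a leaf
// mode iterator) the lowerer keeps the full case lattice and emits explicit
// value checks instead of pruning those merge points away.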
+ const auto regionDefs = expr->regionDefinitions; + const vector inputs = expr->args; + set> regionsToKeep; + + for(auto& it : regionDefs) { + vector region = it.first; + set regionToKeep; + for(auto idx : region) { + match(inputs[idx], + function([&](const AccessNode* n) { + set tensorRegion = seenMergePoints.at(n).tensorRegion(); + regionToKeep.insert(tensorRegion.begin(), tensorRegion.end()); + }) + ); + } + regionsToKeep.insert(regionToKeep); + } + + lattice = MergeLattice(lattice.points(), regionsToKeep); + } + void visit(const CallIntrinsicNode* expr) { const auto zeroPreservingArgsSets = expr->func->zeroPreservingArgs(expr->args); @@ -292,9 +389,10 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict { vector points; for (auto &point : lattice.points()) { points.push_back(MergePoint(point.iterators(), point.locators(), - vector(resultIterators.begin(), resultIterators.end()))); + vector(resultIterators.begin(), resultIterators.end()), + point.isOmitter())); } - lattice = MergeLattice(points); + lattice = MergeLattice(points, lattice.getTensorRegionsToKeep()); } } @@ -331,6 +429,98 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict { taco_not_supported_yet; } + vector + enumerateChildrenPoints(const MergePoint& point, const map, MergePoint>& originalPoints, + set>& seen) { + set pointIters(point.iterators().begin(), point.iterators().end()); + set pointLocs(point.locators().begin(), point.locators().end()); + + set regions = point.tensorRegion(); + set currentRegion = point.tensorRegion(); + + vector result; + for(const auto& tensorIt: regions) { + currentRegion.erase(tensorIt); + + if(util::contains(seen, currentRegion)) { + currentRegion.insert(tensorIt); + continue; + } + + if(util::contains(originalPoints, currentRegion)) { + result.push_back(originalPoints.at(currentRegion)); + } + else if(!currentRegion.empty()){ + MergePoint mp({}, {}, {}); + for(const auto& it: currentRegion) { + mp = unionPoints(mp, seenMergePoints.at(iterators.modeAccess(it).getAccess())); + } + + vector newIters; + vector newLocators = mp.locators(); + for(const auto& it: mp.iterators()) { + if(util::contains(pointLocs, it)) { + newLocators.push_back(it); + } + else { + newIters.push_back(it); + } + } + + result.push_back(MergePoint(newIters, newLocators, point.results())); + } + + seen.insert(currentRegion); + currentRegion.insert(tensorIt); + } + return result; + } + + vector + includeMissingProducerPoints(const vector& points) { + if(points.empty()) return points; + + map, MergePoint> originalPoints; + set> seen; + for(const auto& point: points) { + originalPoints.insert({point.tensorRegion(), point}); + } + + vector frontier = {points[0]}; + vector exactLattice; + + while(!frontier.empty()) { + vector nextFrontier; + for (const auto &frontierPoint: frontier) { + exactLattice.push_back(frontierPoint); + util::append(nextFrontier, enumerateChildrenPoints(frontierPoint, originalPoints, seen)); + } + + frontier = nextFrontier; + } + + return exactLattice; + } + + static vector + includeDimensionIterator(const vector& points, const Iterator& dimIter) { + vector results; + for (auto& point : points) { + vector iterators = point.iterators(); + if (!any(iterators, [](Iterator it){ return it.isDimensionIterator(); })) { + taco_iassert(point.iterators().size() > 0); + results.push_back(MergePoint(combine(iterators, {dimIter}), + point.locators(), + point.results(), + point.isOmitter())); + } + else { + results.push_back(point); + } + } + return results; + } + /** * The intersection of 
two lattices is the result of merging all the * combinations of merge points from the two lattices. @@ -351,7 +541,26 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict { } } - return MergeLattice(points); + // Correctness: ensures that points produced on BOTH the left and the + // right lattices are produced in the final intersection. + // Needed since some subPoints may omit leading to erroneous + // omit intersection points. + points = correctPointTypesAfterIntersect(left.points(), right.points(), points); + + // Correctness: Deduplicate regions that are described by multiple lattice + // points and resolves conflicts arising between omitters and + // producers + points = removeDuplicatedTensorRegions(points, true); + + // Optimization: Removed a subLattice of points if the entire subLattice is + // made of only omitters + // points = removeUnnecessaryOmitterPoints(points); + + set> toKeep = left.getTensorRegionsToKeep(); + set> toKeepRight = right.getTensorRegionsToKeep(); + + toKeep.insert(toKeepRight.begin(), toKeepRight.end()); + return MergeLattice(points, toKeep); } /** @@ -375,6 +584,26 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict { // Append the merge points of b util::append(points, right.points()); + struct pointSort { + bool operator()(const MergePoint& a, const MergePoint& b) { + size_t left_size = a.iterators().size() + a.locators().size(); + size_t right_size = b.iterators().size() + b.locators().size(); + return left_size > right_size; + } + } pointSorter; + + std::sort(points.begin(), points.end(), pointSorter); + + // Correctness: This ensures that points omitted on BOTH the left and the + // right lattices are omitted in the Union. Needed since some + // subpoints may produce leading to erroneous producer regions + points = correctPointTypesAfterUnion(left.points(), right.points(), points); + + + // Correctness: Deduplicate regions that are described by multiple lattice + // points and resolves conflicts arising between omitters and + // producers + points = removeDuplicatedTensorRegions(points, false); // Optimization: insert a dimension iterator if one of the iterators in the // iterate set is not ordered. @@ -384,17 +613,14 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict { // are subsets of some other iterator. points = moveLocateSubsetIteratorsToLocateSet(points); - // Optimization: remove lattice points that lack any of the full iterators - // of the first point, since when a full iterator exhausts we - // have iterated over the whole space. - points = removePointsThatLackFullIterators(points); + // Optimization: Removes a subLattice of points if the entire subLattice is + // made of only omitters + // points = removeUnnecessaryOmitterPoints(points); + set> toKeep = left.getTensorRegionsToKeep(); + set> toKeepRight = right.getTensorRegionsToKeep(); - // Optimization: remove lattice points whose iterators are identical to the - // iterators of an earlier point, since we have already iterated - // over this sub-space. 
- points = removePointsWithIdenticalIterators(points); - - return MergeLattice(points); + toKeep.insert(toKeepRight.begin(), toKeepRight.end()); + return MergeLattice(points, toKeep); } /** @@ -428,7 +654,7 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict { vector results = combine(left.results(), right.results()); - return MergePoint(iterators, locators, results); + return MergePoint(iterators, locators, results, left.isOmitter() || right.isOmitter()); } /** @@ -446,7 +672,7 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict { // Remove duplicate iterators. iterators = deduplicateDimensionIterators(iterators); - return MergePoint(iterators, locaters, results); + return MergePoint(iterators, locaters, results, left.isOmitter() && right.isOmitter()); } static bool locateFromLeft(MergeLattice left, MergeLattice right) @@ -480,7 +706,7 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict { } static vector - insertDimensionIteratorIfNotOrdered(vector points) + insertDimensionIteratorIfNotOrdered(const vector& points) { vector results; for (auto& point : points) { @@ -491,7 +717,8 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict { Iterator dimension(iterators[0].getIndexVar()); results.push_back(MergePoint(combine(iterators, {dimension}), point.locators(), - point.results())); + point.results(), + point.isOmitter())); } else { results.push_back(point); @@ -501,7 +728,7 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict { } static vector - moveLocateSubsetIteratorsToLocateSet(vector points) + moveLocateSubsetIteratorsToLocateSet(const vector& points) { vector full = filter(points[0].iterators(), [](Iterator it){ return it.isFull(); }); @@ -524,74 +751,156 @@ class MergeLatticeBuilder : public IndexNotationVisitorStrict { }); result.push_back(MergePoint(iterators, combine(point.locators(), locators), - point.results())); + point.results(), + point.isOmitter())); } return result; } - static vector - removePointsThatLackFullIterators(vector points) + static vector + deduplicateDimensionIterators(const vector& iterators) { - vector result; - vector fullIterators = filter(points[0].iterators(), - [](Iterator it){return it.isFull();}); - for (auto& point : points) { - bool missingFullIterator = false; - for (auto& fullIterator : fullIterators) { - if (!util::contains(point.iterators(), fullIterator)) { - missingFullIterator = true; - break; + vector deduplicates; + + // Remove all but one of the dense iterators, which are all the same. 
+ bool dimensionIteratorFound = false; + for (auto& iterator : iterators) { + if (iterator.isDimensionIterator()) { + if (!dimensionIteratorFound) { + deduplicates.push_back(iterator); + dimensionIteratorFound = true; } } - if (!missingFullIterator) { - result.push_back(point); + else { + deduplicates.push_back(iterator); + } + } + return deduplicates; + } + + static vector + flipPoints(const vector& points) { + vector flippedPoints; + for(const auto& mp: points) { + MergePoint flippedPoint(mp.iterators(), mp.locators(), mp.results(), !mp.isOmitter()); + flippedPoints.push_back(flippedPoint); + } + return flippedPoints; + } + + static set> + getProducerOrOmitterRegions(const std::vector& points, bool getOmitters) { + set> result; + + for(const auto& point: points) { + if(point.isOmitter() == getOmitters) { + set region = point.tensorRegion(); + result.insert(region); } } return result; } static vector - removePointsWithIdenticalIterators(vector points) - { + correctPointTypes(const vector& left, const vector& right, + const vector& points, bool preserveOmit) { vector result; - set> iteratorSets; + set> leftSet = getProducerOrOmitterRegions(left, preserveOmit); + set> rightSet = getProducerOrOmitterRegions(right, preserveOmit); + for (auto& point : points) { - set iteratorSet(point.iterators().begin(), - point.iterators().end()); - if (util::contains(iteratorSets, iteratorSet)) { + set iteratorSet = point.tensorRegion(); + + MergePoint newPoint = point; + if(util::contains(leftSet, iteratorSet) && util::contains(rightSet, iteratorSet)) { + // Both regions produce/omit, so we ensure that this is preserved + newPoint = MergePoint(point.iterators(), point.locators(), point.results(), preserveOmit); + } + result.push_back(newPoint); + } + + return result; + } + + static vector + correctPointTypesAfterIntersect(const vector& left, const vector& right, + const vector& points) { + return correctPointTypes(left, right, points, false); + } + + static vector + correctPointTypesAfterUnion(const vector& left, const vector& right, + const vector& points) { + return correctPointTypes(left, right, points, true); + } + + static vector + removeDuplicatedTensorRegions(const vector& points, bool preserveOmitters) { + + set> producerRegions = getProducerOrOmitterRegions(points, false); + set> omitterRegions = getProducerOrOmitterRegions(points, true); + + vector result; + + set> regionSets; + for (auto& point : points) { + set region = point.tensorRegion(); + + if(util::contains(regionSets, region)) { continue; } - result.push_back(point); - iteratorSets.insert(iteratorSet); + + MergePoint p = point; + if(util::contains(producerRegions, region) && util::contains(omitterRegions, region)) { + // If a region is marked as both produce and omit resolve the ambiguity based on the preserve + // omitters flag. + p = MergePoint(point.iterators(), point.locators(), point.results(), preserveOmitters); + } + + result.push_back(p); + regionSets.insert(region); } + return result; } - static vector - deduplicateDimensionIterators(const vector& iterators) - { - vector deduplicates; + static vector + removeUnnecessaryOmitterPoints(const vector& points) { + vector filteredPoints; - // Remove all but one of the dense iterators, which are all the same. 
- bool dimensionIteratorFound = false; - for (auto& iterator : iterators) { - if (iterator.isDimensionIterator()) { - if (!dimensionIteratorFound) { - deduplicates.push_back(iterator); - dimensionIteratorFound = true; + MergeLattice l(points); + set> removed; + + for(const auto& point : points) { + + if(util::contains(removed, point.tensorRegion())) { + continue; + } + + MergeLattice subLattice = l.subLattice(point); + + if(util::all(subLattice.points(), [](const MergePoint p) {return p.isOmitter();})) { + for(const auto& p : subLattice.points()) { + removed.insert(p.tensorRegion()); } } - else { - deduplicates.push_back(iterator); + } + + for(const auto& point : points) { + if(!util::contains(removed, point.tensorRegion())) { + filteredPoints.push_back(point); } } - return deduplicates; + + return filteredPoints; } + }; // class MergeLattice -MergeLattice::MergeLattice(vector points) : points_(points) +MergeLattice::MergeLattice(vector points, set> regionsToKeep) : points_(points), + regionsToKeep(regionsToKeep) { } @@ -599,6 +908,7 @@ MergeLattice MergeLattice::make(Forall forall, Iterators iterators, ProvenanceGr { // Can emit merge lattice once underived ancestor can be recovered IndexVar indexVar = forall.getIndexVar(); + MergeLatticeBuilder builder(indexVar, iterators, provGraph, definedIndexVars, whereTempsToResult); vector underivedAncestors = provGraph.getUnderivedAncestors(indexVar); @@ -609,7 +919,59 @@ MergeLattice MergeLattice::make(Forall forall, Iterators iterators, ProvenanceGr } MergeLattice lattice = builder.build(forall.getStmt()); - return lattice; + + // Can't remove points if lattice contains omitters since we lose merge cases during lowering. + if(lattice.anyModeIteratorIsLeaf() && lattice.needExplicitZeroChecks()) { + return lattice; + } + + // Loop lattice and case lattice are identical so simplify here + return lattice.getLoopLattice(); +} + +std::vector +MergeLattice::removePointsThatLackFullIterators(const std::vector& points) +{ + vector result; + vector fullIterators = filter(points[0].iterators(), + [](Iterator it){return it.isFull();}); + for (auto& point : points) { + bool missingFullIterator = false; + for (auto& fullIterator : fullIterators) { + if (!util::contains(point.iterators(), fullIterator)) { + missingFullIterator = true; + break; + } + } + if (!missingFullIterator) { + result.push_back(point); + } + } + return result; +} + +std::vector +MergeLattice::removePointsWithIdenticalIterators(const std::vector& points) +{ + vector result; + set> producerIteratorSets; + for (auto& point : points) { + set iteratorSet(point.iterators().begin(), + point.iterators().end()); + if (util::contains(producerIteratorSets, iteratorSet)) { + continue; + } + result.push_back(point); + producerIteratorSets.insert(iteratorSet); + } + return result; +} + +bool MergeLattice::needExplicitZeroChecks() { + if(util::any(points(), [](const MergePoint& mp) {return mp.isOmitter();})) { + return true; + } + return !getTensorRegionsToKeep().empty(); } MergeLattice MergeLattice::subLattice(MergePoint lp) const { @@ -639,12 +1001,20 @@ const vector& MergeLattice::iterators() const { return points()[0].iterators(); } +const vector& MergeLattice::locators() const { + // The iterators merged by a lattice are those merged by the first point + taco_iassert(points().size() > 0) << "No merge points in the merge lattice"; + return points()[0].locators(); +} + set MergeLattice::exhausted(MergePoint point) { - set notExhaustedIters(point.iterators().begin(), - point.iterators().end()); + set 
notExhaustedIters = point.tensorRegion(); set exhausted; - for (auto& iterator : iterators()) { + vector modeIterators = combine(iterators(), locators()); + modeIterators = filter(modeIterators, [](const Iterator& it) {return it.isModeIterator();}); + + for (auto& iterator : modeIterators) { if (!util::contains(notExhaustedIters, iterator)) { exhausted.insert(iterator); } @@ -662,6 +1032,11 @@ bool MergeLattice::exact() const { // A lattice is full if any merge point iterates over only full iterators // or if each sparse iterator is uniquely iterated by some lattice point. set uniquelyMergedIterators; + + if (util::any(points(), [](const MergePoint& m) {return m.isOmitter();})) { + return false; + } + for (auto& point : this->points()) { if (all(point.iterators(), [](Iterator it) {return it.isFull();})) { return true; @@ -682,6 +1057,57 @@ bool MergeLattice::exact() const { return true; } +bool MergeLattice::anyModeIteratorIsLeaf() const { + if(points().empty()) { + return false; + } + vector latticeIters = util::combine(iterators(), locators()); + return util::any(latticeIters, [](const Iterator& it) {return it.isModeIterator() && it.isLeaf();}); +} + +std::vector MergeLattice::retrieveRegionIteratorsToOmit(const MergePoint &point) const { + vector omittedIterators; + set pointRegion = point.tensorRegion(); + set seen; + const size_t levelOfPoint = pointRegion.size(); + + if(point.isOmitter()) { + seen = set(pointRegion.begin(), pointRegion.end()); + omittedIterators = vector(seen.begin(), seen.end()); + } + + // Look at all points above + for(const auto& mp: points()) { + if((mp.tensorRegion().size() > levelOfPoint) && mp.isOmitter()) { + // Grab the omitted tensors + set parentRegion = mp.tensorRegion(); + std::vector parentItersToOmit; + set_difference(parentRegion.begin(), parentRegion.end(), + pointRegion.begin(), pointRegion.end(), + back_inserter(parentItersToOmit)); + + // Add iterators not in present point to the iterators to omit + for(const auto& it : parentItersToOmit) { + if(!util::contains(seen, it)) { + seen.insert(it); + omittedIterators.push_back(it); + } + } + } + } + + return omittedIterators; +} + +set> MergeLattice::getTensorRegionsToKeep() const { + return regionsToKeep; +} + +MergeLattice MergeLattice::getLoopLattice() const { + std::vector p = removePointsThatLackFullIterators(points()); + return removePointsWithIdenticalIterators(p); +} + ostream& operator<<(ostream& os, const MergeLattice& ml) { return os << util::join(ml.points(), ", "); } @@ -710,19 +1136,22 @@ struct MergePoint::Content { std::vector iterators; std::vector locators; std::vector results; + bool omitPoint; }; MergePoint::MergePoint(const vector& iterators, const vector& locators, - const vector& results) : content_(new Content) { + const vector& results, + bool omitPoint) : content_(new Content) { taco_uassert(all(iterators, [](Iterator it){ return it.hasLocate() || it.isOrdered(); })) << "Merge points do not support iterators that do not have locate and " << "that are not ordered."; - content_->iterators = iterators; - content_->locators = locators; - content_->results = results; + content_->iterators = util::removeDuplicates(iterators); + content_->locators = util::removeDuplicates(locators); + content_->results = util::removeDuplicates(results); + content_->omitPoint = omitPoint; } const vector& MergePoint::iterators() const { @@ -790,6 +1219,18 @@ const std::vector& MergePoint::results() const { return content_->results; } +const std::set MergePoint::tensorRegion() const { + std::vector 
iterators = filter(content_->iterators, + [](Iterator it) {return !it.isDimensionIterator();}); + + append(iterators, content_->locators); + return set(iterators.begin(), iterators.end()); +} + +bool MergePoint::isOmitter() const { + return content_->omitPoint; +} + ostream& operator<<(ostream& os, const MergePoint& mlp) { os << "["; os << util::join(mlp.iterators(), ", "); @@ -801,6 +1242,14 @@ ostream& operator<<(ostream& os, const MergePoint& mlp) { os << "|"; if (mlp.results().size() > 0) os << " "; os << util::join(mlp.results(), ", "); + + os << "|"; + if(mlp.isOmitter()) { + os << " O "; + } else { + os << " P "; + } + os << "]"; return os; } @@ -821,6 +1270,7 @@ bool operator==(const MergePoint& a, const MergePoint& b) { if (!compare(a.iterators(), b.iterators())) return false; if (!compare(a.locators(), b.locators())) return false; if (!compare(a.results(), b.results())) return false; + if ((a.isOmitter() != b.isOmitter())) return false; return true; } diff --git a/src/lower/mode_access.cpp b/src/lower/mode_access.cpp index 8aea151bc..f682e915e 100644 --- a/src/lower/mode_access.cpp +++ b/src/lower/mode_access.cpp @@ -18,10 +18,13 @@ bool operator==(const ModeAccess& a, const ModeAccess& b) { } bool operator<(const ModeAccess& a, const ModeAccess& b) { - if (a.getAccess() == b.getAccess()) { + // First break on the mode position. + if (a.getModePos() != b.getModePos()) { return a.getModePos() < b.getModePos(); } - return a.getAccess() < b.getAccess(); + + // Then, return a deep comparison of the underlying access. + return a.getAccess() & dimensions, const vector& coords, - char* vals, + char* vals, const void* fill, size_t begin, size_t end, const vector& modeTypes, size_t i, std::vector>* indices, @@ -119,16 +122,19 @@ TensorStorage pack(Datatype componentType, const std::vector& dimensions, const Format& format, const std::vector& coordinates, - const void * values) { + const void * values, + const Literal& fill) { + taco_iassert(dimensions.size() == (size_t)format.getOrder()); taco_iassert(coordinates.size() == (size_t)format.getOrder()); taco_iassert(sameSize(coordinates)); taco_iassert(dimensions.size() > 0) << "Scalar packing not supported"; + taco_iassert(fill.getDataType() == componentType) << "Component type must match value type"; size_t order = dimensions.size(); size_t numCoordinates = coordinates[0].size(); - TensorStorage storage(componentType, dimensions, format); + TensorStorage storage(componentType, dimensions, format, fill); // Create vectors to store pointers to indices/index sizes vector> indices; @@ -155,7 +161,8 @@ TensorStorage pack(Datatype componentType, } void* vals = malloc(maxSize * componentType.getNumBytes()); - int actual_size = packTensor(dimensions, coordinates, (char *) values, 0, + const void* fillData = storage.getFillValue().defined()? 
storage.getFillValue().getValPtr() : nullptr; + int actual_size = packTensor(dimensions, coordinates, (char *) values, fillData, 0, numCoordinates, format.getModeFormats(), 0, &indices, (char *)vals, componentType, 0); vals = realloc(vals, actual_size); diff --git a/src/storage/storage.cpp b/src/storage/storage.cpp index aeb216054..9227420b2 100644 --- a/src/storage/storage.cpp +++ b/src/storage/storage.cpp @@ -10,6 +10,7 @@ #include "taco/storage/index.h" #include "taco/storage/array.h" #include "taco/util/strings.h" +#include "taco/index_notation/index_notation.h" using namespace std; @@ -26,11 +27,14 @@ struct TensorStorage::Content { Index index; Array values; - Content(Datatype componentType, vector dimensions, Format format) + Literal fillValue; + + Content(Datatype componentType, vector dimensions, Format format, Literal fill) : componentType(componentType), dimensions(dimensions), format(format), index(format) { int order = (int)dimensions.size(); + taco_iassert(fill.getDataType() == componentType) << "Fill value must be of same type as data array"; taco_iassert(order <= INT_MAX && componentType.getNumBits() <= INT_MAX); taco_uassert(order == format.getOrder()) << "The number of format mode types (" << format.getOrder() << ") " << @@ -53,9 +57,10 @@ struct TensorStorage::Content { } } + fillValue = fill; tensorData = init_taco_tensor_t(order, componentType.getNumBits(), - dimensionsInt32.data(), modeOrdering.data(), - modeTypes.data()); + dimensionsInt32.data(), modeOrdering.data(), + modeTypes.data(), fill.getValPtr()); } ~Content() { @@ -63,9 +68,9 @@ struct TensorStorage::Content { } }; -TensorStorage::TensorStorage(Datatype componentType, - const vector& dimensions, Format format) - : content(new Content(componentType, dimensions, format)) { +TensorStorage::TensorStorage(Datatype componentType, const vector& dimensions, + Format format, Literal fillVal) + : content(new Content(componentType, dimensions, format, fillVal)) { } const Format& TensorStorage::getFormat() const { @@ -100,6 +105,10 @@ Array TensorStorage::getValues() { return content->values; } +Literal TensorStorage::getFillValue() { + return content->fillValue; +} + size_t TensorStorage::getSizeInBytes() { size_t indexSizeInBytes = 0; const auto& index = getIndex(); @@ -162,6 +171,7 @@ TensorStorage::operator struct taco_tensor_t*() const { } tensorData->vals = (uint8_t*)getValues().getData(); + tensorData->fill_value = (uint8_t*) content->fillValue.getValPtr(); return content->tensorData; } diff --git a/src/taco_tensor_t.cpp b/src/taco_tensor_t.cpp index e37ea5b13..2337be2a9 100644 --- a/src/taco_tensor_t.cpp +++ b/src/taco_tensor_t.cpp @@ -26,9 +26,12 @@ void free_mem(void *ptr) { } } +// Note about fill: +// It is planned to allow the fill of a result to be null and for TACO to set this when it does compute. This is the +// case we currently expect the fill pointer to be null. 
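// Illustrative sketch of a call site (hypothetical, not from the patch): a
// caller packing a double tensor with fill value 1.0 passes the raw bytes of
// that literal, with csize giving the component size in bits:
//
//   double fill = 1.0;
//   taco_tensor_t* t = init_taco_tensor_t(order, /*csize=*/64, dims.data(),
//                                         ordering.data(), modeTypes.data(),
//                                         /*fill_ptr=*/&fill);
//
// fill_bytes = csize / 8 below then copies exactly sizeof(double) bytes into
// t->fill_value; when fill_ptr is null (the planned result-tensor case
// described above), the allocation is left for TACO to fill in later.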
taco_tensor_t* init_taco_tensor_t(int32_t order, int32_t csize,
                                  int32_t* dimensions, int32_t* modeOrdering,
-                                 taco_mode_t* mode_types) {
+                                 taco_mode_t* mode_types, void* fill_ptr) {
  taco_tensor_t* t = (taco_tensor_t *) alloc_mem(sizeof(taco_tensor_t));
  t->order = order;
  t->dimensions = (int32_t *) alloc_mem(order * sizeof(int32_t));
@@ -37,6 +40,16 @@ taco_tensor_t* init_taco_tensor_t(int32_t order, int32_t csize,
  t->indices = (uint8_t ***) alloc_mem(order * sizeof(uint8_t***));
  t->csize = csize;

+ int fill_bytes = csize / 8;
+ t->fill_value = (uint8_t*) alloc_mem(fill_bytes);
+
+ if (fill_ptr) {
+   uint8_t* fill_inp = (uint8_t*) fill_ptr;
+   for (int i = 0; i < fill_bytes; ++i) {
+     t->fill_value[i] = fill_inp[i];
+   }
+ }
+
  for (int32_t i = 0; i < order; i++) {
    t->dimensions[i] = dimensions[i];
    t->mode_ordering[i] = modeOrdering[i];
diff --git a/src/tensor.cpp b/src/tensor.cpp
index ce2e4190d..59b095527 100644
--- a/src/tensor.cpp
+++ b/src/tensor.cpp
@@ -57,23 +57,31 @@ TensorBase::TensorBase(Datatype ctype)
 }

 TensorBase::TensorBase(std::string name, Datatype ctype)
-    : TensorBase(name, ctype, {}, Format()) {
+    : TensorBase(name, ctype, {}, Format(), Literal::zero(ctype)) {
 }

-TensorBase::TensorBase(Datatype ctype, vector<int> dimensions,
-                       ModeFormat modeType)
-    : TensorBase(util::uniqueName('A'), ctype, dimensions,
-                 std::vector<ModeFormat>(dimensions.size(), modeType)) {
+TensorBase::TensorBase(Datatype ctype, vector<int> dimensions,
+                       ModeFormat modeType, Literal fill)
+    : TensorBase(util::uniqueName('A'), ctype, dimensions,
+                 std::vector<ModeFormat>(dimensions.size(), modeType), fill) {
 }

-TensorBase::TensorBase(Datatype ctype, vector<int> dimensions, Format format)
-    : TensorBase(util::uniqueName('A'), ctype, dimensions, format) {
+TensorBase::TensorBase(Datatype ctype, vector<int> dimensions, Format format, Literal fill)
+    : TensorBase(util::uniqueName('A'), ctype, dimensions, format, fill) {
 }

-TensorBase::TensorBase(std::string name, Datatype ctype,
-                       std::vector<int> dimensions, ModeFormat modeType)
-    : TensorBase(name, ctype, dimensions,
-                 std::vector<ModeFormat>(dimensions.size(), modeType)) {
+TensorBase::TensorBase(std::string name, Datatype ctype,
+                       std::vector<int> dimensions, ModeFormat modeType, Literal fill)
+    : TensorBase(name, ctype, dimensions,
+                 std::vector<ModeFormat>(dimensions.size(), modeType), fill) {
+}
+
+TensorBase::TensorBase(Datatype ctype, std::vector<int> dimensions, Literal fill)
+    : TensorBase(ctype, dimensions, ModeFormat::compressed, fill) {
+}
+
+TensorBase::TensorBase(std::string name, Datatype ctype, std::vector<int> dimensions, Literal fill)
+    : TensorBase(name, ctype, dimensions, ModeFormat::compressed, fill) {
 }

 static Format initFormat(Format format) {
@@ -102,12 +110,19 @@ static Format initFormat(Format format) {
 }

 TensorBase::TensorBase(string name, Datatype ctype, vector<int> dimensions,
-                       Format format)
-    : content(new Content(name, ctype, dimensions, initFormat(format))) {
+                       Format format, Literal fill) {
+
+  // Default the fill to zero when it is left undefined. This is done here rather than in the
+  // default arguments because the ctype needed to construct the zero Literal is not available there.
+  fill = fill.defined()?
fill : Literal::zero(ctype); + content = shared_ptr(new Content(name, ctype, dimensions, initFormat(format), fill)); + taco_uassert((size_t)format.getOrder() == dimensions.size()) << "The number of format mode types (" << format.getOrder() << ") " << "must match the tensor order (" << dimensions.size() << ")."; + taco_uassert(ctype == fill.getDataType()) << "Fill value must be of the same type as the tensor."; + content->allocSize = 1 << 20; vector modeIndices(format.getOrder()); @@ -357,12 +372,15 @@ void TensorBase::pack() { std::vector bufferDim = {1}; std::vector bufferModeOrdering = {0}; std::vector bufferCoords(numCoordinates, 0); + + void* fillPtr = getStorage().getFillValue().defined()? getStorage().getFillValue().getValPtr() : nullptr; taco_tensor_t* bufferStorage = init_taco_tensor_t(1, csize, (int32_t*)bufferDim.data(), (int32_t*)bufferModeOrdering.data(), - (taco_mode_t*)bufferModeType.data()); + (taco_mode_t*)bufferModeType.data(), fillPtr); std::vector pos = {0, (int)numCoordinates}; bufferStorage->indices[0][0] = (uint8_t*)pos.data(); bufferStorage->indices[0][1] = (uint8_t*)bufferCoords.data(); + bufferStorage->vals = (uint8_t*)content->coordinateBuffer->data(); std::vector arguments = {content->storage, bufferStorage}; @@ -423,11 +441,11 @@ void TensorBase::pack() { content->coordinateBuffer->clear(); content->coordinateBufferUsed = 0; - + void* fillPtr = getStorage().getFillValue().defined()? getStorage().getFillValue().getValPtr() : nullptr; std::vector bufferModeTypes(order, taco_mode_sparse); taco_tensor_t* bufferStorage = init_taco_tensor_t(order, csize, (int32_t*)dimensions.data(), (int32_t*)permutation.data(), - (taco_mode_t*)bufferModeTypes.data()); + (taco_mode_t*)bufferModeTypes.data(), fillPtr); std::vector pos = {0, (int)numCoordinates}; bufferStorage->indices[0][0] = (uint8_t*)pos.data(); for (int i = 0; i < order; ++i) { @@ -606,6 +624,7 @@ void TensorBase::compile() { stmt = parallelizeOuterLoop(stmt); compile(stmt, content->assembleWhileCompute); } + void TensorBase::compile(taco::IndexStmt stmt, bool assembleWhileCompute) { if (!needsCompile()) { return; @@ -642,6 +661,11 @@ taco_tensor_t* TensorBase::getTacoTensorT() { return getStorage(); } + +Literal TensorBase::getFillValue() const { + return content->tensorVar.getFill(); +} + void TensorBase::syncValues() { if (content->needsPack) { pack(); @@ -1032,6 +1056,11 @@ bool equals(const TensorBase& a, const TensorBase& b) { return false; } + // Fill values must be the same + if (!equals(a.getFillValue(), b.getFillValue())) { + return false; + } + // Orders must be the same if (a.getOrder() != b.getOrder()) { return false; diff --git a/test/op_factory.h b/test/op_factory.h new file mode 100644 index 000000000..1871528fa --- /dev/null +++ b/test/op_factory.h @@ -0,0 +1,186 @@ +#ifndef TACO_OP_FACTORY_H +#define TACO_OP_FACTORY_H + +#include "taco/index_notation/tensor_operator.h" +#include "taco/index_notation/index_notation.h" +#include "taco/ir/ir.h" + + +namespace taco { + +// Algebras +struct BC_BD_CD { + IterationAlgebra operator()(const std::vector &v) { + IterationAlgebra r1 = Intersect(v[0], v[1]); + IterationAlgebra r2 = Intersect(v[0], v[2]); + IterationAlgebra r3 = Intersect(v[1], v[2]); + + IterationAlgebra omit = Complement(Intersect(Intersect(v[0], v[1]), v[2])); + return Intersect(Union(Union(r1, r2), r3), omit); + } +}; + +struct ComplementUnion { + IterationAlgebra operator()(const std::vector& regions) { + taco_iassert(regions.size() >= 2); + IterationAlgebra unions = Complement(regions[0]); 
+ for(size_t i = 1; i < regions.size(); ++i) { + unions = Union(unions, regions[i]); + } + return unions; + } +}; + +struct IntersectGen { + IterationAlgebra operator()(const std::vector& regions) { + if (regions.size() < 2) { + return IterationAlgebra(); + } + + IterationAlgebra intersections = regions[0]; + for(size_t i = 1; i < regions.size(); ++i) { + intersections = Intersect(intersections, regions[i]); + } + return intersections; + } +}; + +struct ComplementIntersect { + IterationAlgebra operator()(const std::vector& regions) { + if (regions.size() < 2) { + return IterationAlgebra(); + } + + IterationAlgebra intersections = Complement(regions[0]); + for(size_t i = 1; i < regions.size(); ++i) { + intersections = Intersect(intersections, regions[i]); + } + return intersections; + } +}; + +struct IntersectGenDeMorgan { + IterationAlgebra operator()(const std::vector& regions) { + IterationAlgebra unions = Complement(regions[0]); + for(size_t i = 1; i < regions.size(); ++i) { + unions = Union(unions, Complement(regions[i])); + } + return Complement(unions); + } +}; + +struct xorGen { + IterationAlgebra operator()(const std::vector& regions) { + IterationAlgebra noIntersect = Complement(Intersect(regions[0], regions[1])); + return Intersect(noIntersect, Union(regions[0], regions[1])); + } +}; + +struct fullSpaceGen { + IterationAlgebra operator()(const std::vector& regions) { + return Union(Complement(regions[0]), regions[0]); + } +}; + +struct emptyGen { + IterationAlgebra operator()(const std::vector& regions) { + return Intersect(Complement(regions[0]), regions[0]); + } +}; + +struct intersectEdge { + IterationAlgebra operator()(const std::vector& regions) { + std::vector r = regions; + return Intersect(Complement(Intersect(r[0], r[1])), Intersect(r[0], r[1])); + } +}; + +struct unionEdge { + IterationAlgebra operator()(const std::vector& regions) { + std::vector r = regions; + return Union(Complement(Intersect(r[0], r[1])), Intersect(r[0], r[1])); + } +}; + +struct BfsMaskAlg { + IterationAlgebra operator()(const std::vector& regions) { + std::vector r = regions; + return Intersect(r[0], Complement(r[1])); + } +}; + +// Lowerers +struct MulAdd { + ir::Expr operator()(const std::vector &v) { + return ir::Add::make(ir::Mul::make(v[0], v[1]), v[2]); + } +}; + +struct identityFunc { + ir::Expr operator()(const std::vector &v) { + return v[0]; + } +}; + +struct GeneralAdd { + ir::Expr operator()(const std::vector &v) { + taco_iassert(v.size() >= 2) << "Add operator needs at least two operands"; + ir::Expr add = ir::Add::make(v[0], v[1]); + + for (size_t idx = 2; idx < v.size(); ++idx) { + add = ir::Add::make(add, v[idx]); + } + + return add; + } +}; + +struct MinImpl { + ir::Expr operator()(const std::vector &v) { + taco_iassert(v.size() >= 2) << "Min operator needs at least two operands"; + return ir::Min::make(v[0], v[1]); + } +}; + +struct BfsLower { + ir::Expr operator()(const std::vector &v) { + return v[0]; + } +}; + +struct OrImpl { + ir::Expr operator()(const std::vector &v) { + return ir::Or::make(v[0], v[1]); + } +}; + +struct BitOrImpl { + ir::Expr operator()(const std::vector &v) { + return ir::BitOr::make(v[0], v[1]); + } +}; + +struct AndImpl { + ir::Expr operator()(const std::vector &v) { + return ir::And::make(v[0], v[1]); + } +}; + +// Special definitions +struct MulRegionDef { + ir::Expr operator()(const std::vector &v) { + taco_iassert(v.size() == 2) << "Add operator needs at least two operands"; + return ir::Mul::make(v[0], v[1]); + } +}; + +struct SubRegionDef { + 
ir::Expr operator()(const std::vector &v) { + taco_iassert(v.size() == 2) << "Sub def needs two operands"; + return ir::Sub::make(v[1], v[0]); + } +}; + + +} +#endif //TACO_OP_FACTORY_H diff --git a/test/test-iteration_algebra.cpp b/test/test-iteration_algebra.cpp new file mode 100644 index 000000000..b95513881 --- /dev/null +++ b/test/test-iteration_algebra.cpp @@ -0,0 +1,225 @@ +#include "test.h" +#include "taco/index_notation/index_notation.h" +#include "taco/index_notation/iteration_algebra.h" + +using namespace taco; + +const IndexVar i("i"), j("j"), k("k"); + +Type vec(type(), {3}); +TensorVar v1("v1", vec), v2("v2", vec), v3("v3", vec); + +TEST(iteration_algebra, region) { + Access access = v1(i); + IterationAlgebra alg = access; + ASSERT_TRUE(isa(alg)); + ASSERT_TRUE(isa(alg.ptr)); + Region region = to(alg); + const RegionNode* node = to(alg.ptr); + ASSERT_FALSE(isa(alg)); + ASSERT_EQ(access, node->expr()); +} + +TEST(iteration_algebra, Complement) { + IndexExpr expr = v1(i) + v2(i); + IterationAlgebra alg = Complement(expr); + ASSERT_TRUE(isa(alg)); + ASSERT_TRUE(isa(alg.ptr)); + Complement complement = to(alg); + const ComplementNode* n = to(alg.ptr); + ASSERT_FALSE(isa(n)); + ASSERT_TRUE(algEqual(expr, n->a)); + + ASSERT_TRUE(isa(n->a.ptr)); + const RegionNode* r = to(n->a.ptr); + ASSERT_EQ(expr, r->expr()); +} + +TEST(iteration_algebra, Union) { + IndexExpr exprA = v1(i); + IndexExpr exprB = v2(i); + IterationAlgebra alg = Union(exprA, exprB); + ASSERT_TRUE(isa(alg)); + ASSERT_TRUE(isa(alg.ptr)); + Union u = to(alg); + const UnionNode* n = to(alg.ptr); + ASSERT_FALSE(isa(n)); + ASSERT_TRUE(algEqual(exprA, n->a)); + ASSERT_TRUE(algEqual(exprB, n->b)); + + ASSERT_TRUE(isa(n->a.ptr)); + const RegionNode* r1 = to(n->a.ptr); + ASSERT_TRUE(isa(n->b.ptr)); + const RegionNode* r2 = to(n->b.ptr); + + ASSERT_EQ(exprA, r1->expr()); + ASSERT_EQ(exprB, r2->expr()); +} + +TEST(iteration_algebra, Intersect) { + IndexExpr exprA = v1(i); + IndexExpr exprB = v2(i); + IterationAlgebra alg = Intersect(exprA, exprB); + ASSERT_TRUE(isa(alg)); + ASSERT_TRUE(isa(alg.ptr)); + Intersect i = to(alg); + const IntersectNode* n = to(alg.ptr); + ASSERT_FALSE(isa(n)); + ASSERT_TRUE(algEqual(exprA, n->a)); + ASSERT_TRUE(algEqual(exprB, n->b)); + + ASSERT_TRUE(isa(n->a.ptr)); + const RegionNode* r1 = to(n->a.ptr); + ASSERT_TRUE(isa(n->b.ptr)); + const RegionNode* r2 = to(n->b.ptr); + + ASSERT_EQ(exprA, r1->expr()); + ASSERT_EQ(exprB, r2->expr()); +} + +TEST(iteration_algebra, comparatorRegion) { + IterationAlgebra alg1(v1(i)); + IterationAlgebra alg2(v2(j)); + ASSERT_TRUE(algStructureEqual(alg1, alg2)); + ASSERT_FALSE(algEqual(alg1, alg2)); + + ASSERT_TRUE(algEqual(alg1, alg1)); + ASSERT_TRUE(algStructureEqual(alg1, alg1)); +} + +TEST(iteration_algebra, comparatorComplement) { + IterationAlgebra alg1 = Complement(v2(i)); + IterationAlgebra alg2 = Complement(v3(j)); +// ASSERT_TRUE(algStructureEqual(alg1, alg2)); + ASSERT_FALSE(algEqual(alg1, alg2)); + + ASSERT_TRUE(algStructureEqual(alg1, alg1)); + ASSERT_TRUE(algEqual(alg1, alg1)); +} + +TEST(iteration_algebra, comparatorIntersect) { + IterationAlgebra alg1 = Intersect(v1(i), v2(i)); + IterationAlgebra alg2 = Intersect(v1(j), v3(j)); + ASSERT_TRUE(algStructureEqual(alg1, alg2)); + ASSERT_FALSE(algEqual(alg1, alg2)); + + ASSERT_TRUE(algStructureEqual(alg1, alg1)); + ASSERT_TRUE(algEqual(alg1, alg1)); +} + +TEST(iteration_algebra, comparatorUnion) { + IterationAlgebra alg1 = Union(v1(i), v2(i)); + IterationAlgebra alg2 = Union(v1(j), v3(j)); + 
ASSERT_TRUE(algStructureEqual(alg1, alg2)); + ASSERT_FALSE(algEqual(alg1, alg2)); + + ASSERT_TRUE(algStructureEqual(alg1, alg1)); + ASSERT_TRUE(algEqual(alg1, alg1)); +} + +TEST(iteration_algebra, comparatorMix) { + IterationAlgebra alg1 = Union(Intersect(v1(i), v2(i)), Complement(v3(i))); + IterationAlgebra alg2 = Union(Intersect(v1(j), v2(j)), Complement(v3(j))); + ASSERT_TRUE(algStructureEqual(alg1, alg2)); + ASSERT_FALSE(algEqual(alg1, alg2)); + + ASSERT_TRUE(algStructureEqual(alg1, alg1)); + ASSERT_TRUE(algEqual(alg1, alg1)); +} + +TEST(iteration_algebra, deMorganRegion) { + IterationAlgebra alg(v1(i)); + IterationAlgebra simplified = applyDemorgan(alg); + + ASSERT_TRUE(algEqual(alg, simplified)); + ASSERT_TRUE(algStructureEqual(alg, simplified)); +} + +TEST(iteration_algebra, deMorganComplement) { + IterationAlgebra alg = Complement(v1(i)); + IterationAlgebra simplified = applyDemorgan(alg); + + ASSERT_TRUE(algEqual(alg, simplified)); + ASSERT_TRUE(algStructureEqual(alg, simplified)); +} + +TEST(iteration_algebra, DeMorganNestedComplements) { + IterationAlgebra alg = v1(i); + for(int cnt = 0; cnt < 10; ++cnt) { + if(cnt % 2 == 0) { + IterationAlgebra simplified = applyDemorgan(alg); + IterationAlgebra expectedEven = v1(i); + ASSERT_TRUE(algEqual(expectedEven, simplified)); + } + else { + IterationAlgebra simplified = applyDemorgan(alg); + IterationAlgebra expectedOdd = Complement(v1(i)); + ASSERT_TRUE(algEqual(expectedOdd, simplified)); + } + alg = Complement(alg); + } +} + +TEST(iteration_algebra, deMorganIntersect) { + IterationAlgebra alg = Intersect(v1(i), v2(i)); + IterationAlgebra simplified = applyDemorgan(alg); + + ASSERT_TRUE(algEqual(alg, simplified)); + ASSERT_TRUE(algStructureEqual(alg, simplified)); +} + +TEST(iteration_algebra, deMorganUnion) { + IterationAlgebra alg = Union(v1(i), v2(i)); + IterationAlgebra simplified = applyDemorgan(alg); + + ASSERT_TRUE(algEqual(alg, simplified)); + ASSERT_TRUE(algStructureEqual(alg, simplified)); +} + +TEST(iteration_algebra, UnionComplement) { + IterationAlgebra alg = Union(v1(i), Complement(v2(i))); + IterationAlgebra simplified = applyDemorgan(alg); + + ASSERT_TRUE(algEqual(alg, simplified)); + ASSERT_TRUE(algStructureEqual(alg, simplified)); +} + +TEST(iteration_algebra, flipUnionToIntersect) { + IterationAlgebra alg = Complement(Union(v1(i), v2(i))); + IterationAlgebra simplified = applyDemorgan(alg); + + ASSERT_FALSE(algEqual(alg, simplified)); + ASSERT_FALSE(algStructureEqual(alg, simplified)); + + IterationAlgebra expected = Intersect(Complement(v1(i)), Complement(v2(i))); + ASSERT_TRUE(algEqual(simplified, expected)); + ASSERT_TRUE(algStructureEqual(simplified, expected)); +} + +TEST(iteration_algebra, flipIntersectToUnion) { + IterationAlgebra alg = Complement(Intersect(v1(i), v2(i))); + IterationAlgebra simplified = applyDemorgan(alg); + + ASSERT_FALSE(algEqual(alg, simplified)); + ASSERT_FALSE(algStructureEqual(alg, simplified)); + + IterationAlgebra expected = Union(Complement(v1(i)), Complement(v2(i))); + ASSERT_TRUE(algEqual(simplified, expected)); + ASSERT_TRUE(algStructureEqual(simplified, expected)); +} + +TEST(iteration_algebra, hiddenIntersect) { + IterationAlgebra alg = Complement(Union(Complement(v1(i)), Complement(v2(i)))); + IterationAlgebra simplified = applyDemorgan(alg); + + IterationAlgebra expected = Intersect(v1(i), v2(i)); + ASSERT_TRUE(algEqual(expected, simplified)); +} + +TEST(iteration_algebra, hiddenUnion) { + IterationAlgebra alg = Complement(Intersect(Complement(v1(i)), Complement(v2(i)))); + 
IterationAlgebra simplified = applyDemorgan(alg); + + IterationAlgebra expected = Union(v1(i), v2(i)); + ASSERT_TRUE(algEqual(expected, simplified)); +} \ No newline at end of file diff --git a/test/test_properties.cpp b/test/test_properties.cpp new file mode 100644 index 000000000..8e17af052 --- /dev/null +++ b/test/test_properties.cpp @@ -0,0 +1,108 @@ +#include "test.h" +#include "taco/index_notation/index_notation.h" +#include "taco/index_notation/properties.h" + +using namespace taco; + +TEST(properties, annihilator) { + Literal z(0); + Annihilator a(nullptr); + Annihilator zero(z); + ASSERT_FALSE(a.defined()); + ASSERT_TRUE(zero.defined()); + + ASSERT_EQ(zero.annihilator(), z); + ASSERT_TRUE(equals(zero.annihilator(), z)); + + ASSERT_TRUE(zero.equals(Annihilator(Literal(0)))); + ASSERT_FALSE(zero.equals(a)); +} + +TEST(properties, identity) { + Literal z(0); + Identity a(nullptr); + Identity zero(z); + ASSERT_FALSE(a.defined()); + ASSERT_TRUE(zero.defined()); + + ASSERT_EQ(zero.identity(), z); + ASSERT_TRUE(equals(zero.identity(), Literal(0))); + + ASSERT_TRUE(zero.equals(Identity(Literal(0)))); + ASSERT_FALSE(zero.equals(a)); +} + +TEST(properties, associative) { + Associative a; + Associative undef(nullptr); + + ASSERT_TRUE(a.equals(a)); + ASSERT_FALSE(a.equals(undef)); + ASSERT_TRUE(a.defined()); + ASSERT_FALSE(undef.defined()); +} + +TEST(properties, commutative) { + Commutative com; + Commutative specific({0, 1}); + Commutative specific2({1, 2}); + Commutative undef(nullptr); + + ASSERT_TRUE(specific.defined()); + ASSERT_TRUE(com.defined()); + ASSERT_FALSE(undef.defined()); + + ASSERT_NE(specific.ordering(), specific2.ordering()); + ASSERT_EQ(specific.ordering(), std::vector({0, 1})); + ASSERT_TRUE(specific.equals(specific)); +} + +TEST(properties, property_conversion) { + Property annh = Annihilator(10); + Property identity = Identity(40); + Property assoc = Associative(); + Property com = Commutative({0,1,2}); + + ASSERT_TRUE(isa(annh)); + ASSERT_FALSE(isa(annh)); + Annihilator a = to(annh); + ASSERT_TRUE(equals(a.annihilator(), Literal(10))); + + ASSERT_TRUE(isa(identity)); + ASSERT_FALSE(isa(identity)); + Identity idnty = to(identity); + ASSERT_TRUE(equals(idnty.identity(), Literal(40))); + + ASSERT_TRUE(isa(assoc)); + ASSERT_FALSE(isa(assoc)); + Associative assc = to(assoc); + ASSERT_TRUE(assc.defined()); + + ASSERT_TRUE(isa(com)); + ASSERT_FALSE(isa(com)); + Commutative comm = to(com); + ASSERT_EQ(comm.ordering(), std::vector({0,1,2})); +} + +TEST(properties, findProperty) { + Annihilator a(10); + Identity i(10); + Associative as; + Commutative c({0, 1}); + + std::vector properties({a, i, as, c}); + ASSERT_TRUE(a.equals(findProperty(properties))); + ASSERT_TRUE(i.equals(findProperty(properties))); + ASSERT_TRUE(as.equals(findProperty(properties))); + ASSERT_TRUE(c.equals(findProperty(properties))); + + std::vector partialProperties({a, c}); + ASSERT_FALSE(i.equals(findProperty(partialProperties))); + ASSERT_FALSE(as.equals(findProperty(partialProperties))); + + ASSERT_FALSE(findProperty(partialProperties).defined()); + ASSERT_FALSE(findProperty(partialProperties).defined()); + + ASSERT_TRUE(properties[0].equals(Annihilator(10))); + ASSERT_FALSE(properties[0].equals(i)); +} \ No newline at end of file diff --git a/test/tests-api.cpp b/test/tests-api.cpp index da4f7082d..d1afee89c 100644 --- a/test/tests-api.cpp +++ b/test/tests-api.cpp @@ -121,6 +121,8 @@ TEST_P(apiget, api) { ASSERT_ARRAY_EQ(GetParam().getExpectedValues(), {(double*)storage.getValues().getData(), 
storage.getIndex().getSize()}); + + ASSERT_TRUE(equals(tensor.getFillValue(), Literal::zero(tensor.getComponentType()))); } TEST_P(apiwrb, api) { diff --git a/test/tests-expr.cpp b/test/tests-expr.cpp index 9877e7c5c..537615800 100644 --- a/test/tests-expr.cpp +++ b/test/tests-expr.cpp @@ -160,3 +160,45 @@ TEST(expr, redefine) { a.evaluate(); ASSERT_EQ(a.begin()->second, 42.0); } + +TEST(expr, indexVarSimple) { + Tensor a("actual", {3,3}, Dense); + a(i, j) = i + j; + a.evaluate(); + + Tensor expected("expected", a.getDimensions(), Dense); + + for(int i = 0; i < a.getDimensions()[0]; ++i) { + for(int j = 0; j < a.getDimensions()[1]; ++j) { + expected.insert({i, j}, i+j); + } + } + expected.pack(); + + ASSERT_TENSOR_EQ(expected, a); +} + +TEST(expr, indexVarMix) { + Tensor a("actual", {3, 3}, dense); + Tensor b("input", {3, 3}, compressed); + + Tensor expected("expected", a.getDimensions(), Dense); + const int n = a.getDimensions()[0]; + const int m = a.getDimensions()[1]; + + for(int i = 0; i < n; ++i) { + b.insert({i, i}, 2); + } + b.pack(); + + a(i, j) = b(i, j) * (i * m + j); + a.evaluate(); + + for(int i = 0; i < n; ++i) { + int flattened_idx = i * m + i; + expected.insert({i, i}, 2 * flattened_idx); + } + expected.pack(); + + ASSERT_TENSOR_EQ(expected, a); +} \ No newline at end of file diff --git a/test/tests-index_notation.cpp b/test/tests-index_notation.cpp index 9b08b86c8..d49af7dd3 100644 --- a/test/tests-index_notation.cpp +++ b/test/tests-index_notation.cpp @@ -1,5 +1,7 @@ #include "test.h" +#include "taco/index_notation/tensor_operator.h" #include "taco/index_notation/index_notation.h" +#include "op_factory.h" using namespace taco; @@ -252,3 +254,16 @@ INSTANTIATE_TEST_CASE_P(separate_reductions, concrete, forall(k, tk += c(k))), forall(j, tj += b(j)))))); + + + +Func scOr("Or", OrImpl(), {Annihilator((bool)1), Identity((bool)0)}); +Func scAnd("And", AndImpl(), {Annihilator((bool)0), Identity((bool)0)}); + +Func bfsMaskOp("bfsMask", BfsLower(), BfsMaskAlg()); +INSTANTIATE_TEST_CASE_P(tensorOpConcrete, concrete, + Values(ConcreteTest(a(i) = Reduction(scOr(), j, bfsMaskOp(scAnd(B(i, j), c(j)), c(i))), + forall(i, + forall(j, + Assignment(a(i), bfsMaskOp(scAnd(B(i, j), c(j)), c(i)), scOr()) + ))))); \ No newline at end of file diff --git a/test/tests-indexexpr.cpp b/test/tests-indexexpr.cpp index 72d28c5b4..cdca2b2eb 100644 --- a/test/tests-indexexpr.cpp +++ b/test/tests-indexexpr.cpp @@ -73,3 +73,12 @@ TEST(indexexpr, div) { ASSERT_TRUE(equals(div.getA(), b(i))); ASSERT_TRUE(equals(div.getB(), Literal(2))); } + +TEST(indexexpr, indexvar) { + IndexExpr expr = i; + ASSERT_TRUE(isa(expr)); + ASSERT_TRUE(isa(expr.ptr)); + IndexVar var = to(expr); + ASSERT_EQ(type(), var.getDataType()); + ASSERT_EQ("i", var.getName()); +} diff --git a/test/tests-lower.cpp b/test/tests-lower.cpp index e876b7f0c..19e5fec8d 100644 --- a/test/tests-lower.cpp +++ b/test/tests-lower.cpp @@ -8,6 +8,7 @@ #include "taco/index_notation/index_notation.h" #include "taco/index_notation/index_notation_rewriter.h" #include "taco/index_notation/index_notation_nodes.h" +#include "taco/index_notation/tensor_operator.h" #include "taco/index_notation/kernel.h" #include "taco/codegen/module.h" #include "taco/storage/storage.h" @@ -16,6 +17,8 @@ #include "taco/format.h" #include "taco/util/strings.h" +#include "op_factory.h" + namespace taco { namespace test { @@ -39,6 +42,8 @@ static TensorVar b("b", vectype, Format()); static TensorVar c("c", vectype, Format()); static TensorVar d("d", vectype, Format()); +static 
TensorVar fill_10("fillA", vectype, Format(), Literal((double) 10)); + static TensorVar w("w", vectype, dense); static TensorVar A("A", mattype, Format()); @@ -80,7 +85,7 @@ struct TestCase { TensorStorage getResult(TensorVar var, Format format) const { auto dimensions = getDimensions(var); - TensorStorage storage(type(), dimensions, format); + TensorStorage storage(type(), dimensions, format, var.getFill()); // TODO: Get rid of this and lower to use dimensions instead vector modeIndices(format.getOrder()); @@ -97,11 +102,12 @@ struct TestCase { static TensorStorage pack(Format format, const vector& dims, - const vector,double>>& components){ + const vector,double>>& components, + Literal fill){ size_t order = dims.size(); size_t num = components.size(); if (order == 0) { - TensorStorage storage = TensorStorage(type(), {}, format); + TensorStorage storage = TensorStorage(type(), {}, format, fill); Array array = makeArray(type(), 1); *((double*)array.getData()) = components[0].second; storage.setValues(array); @@ -115,23 +121,24 @@ struct TestCase { vector values(num); for (size_t i=0; i < components.size(); ++i) { auto& coordinates = components[i].first; + std::vector ordering = format.getModeOrdering(); for (size_t j=0; j < coordinates.size(); ++j) { coords[j][i] = coordinates[j]; } values[i] = components[i].second; } - return taco::pack(type(), dims, format, coords, values.data()); + return taco::pack(type(), dims, format, coords, values.data(), fill); } } TensorStorage getArgument(TensorVar var, Format format) const { taco_iassert(contains(inputs, var)) << var; - return pack(format, getDimensions(var), inputs.at(var)); + return pack(format, getDimensions(var), inputs.at(var), var.getFill()); } TensorStorage getExpected(TensorVar var, Format format) const { taco_iassert(contains(expected, var)) << var; - return pack(format, getDimensions(var), expected.at(var)); + return pack(format, getDimensions(var), expected.at(var), var.getFill()); } }; @@ -199,7 +206,7 @@ map formatVars(const std::vector& vars, // Default format is dense in all dimensions format = Format(vector(var.getOrder(), dense)); } - formatted.insert({var, TensorVar(var.getName(), var.getType(), format)}); + formatted.insert({var, TensorVar(var.getName(), var.getType(), format, var.getFill())}); } return formatted; } @@ -1556,4 +1563,211 @@ TEST_STMT(vector_not, } ) +//TEST_STMT(vector_add_fill, +// forall(i, +// fill_10(i) = b(i) * c(i) +// ), +// Values( +// Formats({{fill_10, dense}, {b,sparse}, {c, sparse}}), +// Formats({{fill_10, sparse}, {b,sparse}, {c, sparse}}) +// ), +// { +// TestCase( +// {{b, {{{0}, 2.0}}}, +// {c, {{{0}, 3.0}, {{1}, 6.0}}}}, +// +// {{fill_10, {{{0}, 5.0}, {{1}, 6.0}}}}) +// } +//) + +// Test tensorOps + +Func testOp("testOp", MulAdd(), BC_BD_CD()); + +TEST_STMT(testOp1, + forall(i, + a(i) = testOp(b(i), c(i), d(i)) + ), + Values( + Formats({{a,sparse}, {b,sparse}, {c,sparse}, {d, sparse}}), + Formats({{a,dense}, {b,dense}, {c,dense}, {d, dense}}), + Formats({{a,dense}, {b,sparse}, {c,dense}, {d, sparse}}), + Formats({{a,dense}, {b,sparse}, {c,dense}, {d, dense}}), + Formats({{a,dense}, {b,sparse}, {c,sparse}, {d, sparse}}) + ), + { + TestCase( + {{b, {{{0}, 2.0}, {{1}, 2.0}, {{4}, 4.0}}}, + {c, {{{0}, 3.0}, {{2}, 3.0}, {{4}, 6.0}}}, + {d, {{{1}, 1.0}, {{2}, 4.0}, {{4}, 5.0}}}}, + + {{a, {{{0}, 6.0}, {{1}, 1.0}, {{2}, 4.0}}}}) + } +) + + +Func specialOp("specialOp", GeneralAdd(), BC_BD_CD(), {{{0, 1}, MulRegionDef()}, {{0, 2}, SubRegionDef()}}); +TEST_STMT(lowerSpecialRegions1, + forall(i, + 
forall(j, + A(i, j) = specialOp(B(i, j), C(i, j), D(i, j)) + )), + Values( + Formats({{A, Format({dense,dense})}, {B, Format({dense,dense})}, {C, Format({dense,dense})}, + {D, Format({dense,dense})}}), + Formats({{A, Format({dense,sparse})}, {B, Format({dense,sparse})}, {C, Format({dense,sparse})}, + {D, Format({dense,sparse})}}), + Formats({{A, Format({sparse,sparse})}, {B, Format({sparse,sparse})}, {C, Format({sparse,sparse})}, + {D, Format({sparse,sparse})}}) + ), + { + TestCase( + {{B, {{{0, 1}, 2.0}, {{1, 1}, 3.0}, {{1, 2}, 2.0}, {{4, 3}, 4.0}}}, + {C, {{{0, 1}, 3.0}, {{2, 1}, 3.0}, {{2, 2}, 4.0}, {{4, 3}, 6.0}}}, + {D, {{{1, 2}, 1.0}, {{2, 1}, 4.0}, {{3, 3}, 5.0}, {{4, 3}, 5.0}}}}, + + {{A, {{{0, 1}, 6.0}, {{1, 2}, -1.0}, {{2, 1}, 7.0}}}}) + } +) + +Func compUnion("compUnion", GeneralAdd(), ComplementUnion()); +TEST_STMT(lowerCompUnion, + forall(i, + forall(j, + A(i, j) = compUnion(B(i, j), C(i, j), D(i, j)) + )), + Values( + Formats({{A, Format({dense,dense})}, {B, Format({dense,dense})}, {C, Format({dense,dense})}, + {D, Format({dense,dense})}}), + Formats({{A, Format({dense,sparse})}, {B, Format({dense,sparse})}, {C, Format({dense,sparse})}, + {D, Format({dense,sparse})}}), + Formats({{A, Format({sparse,sparse})}, {B, Format({sparse,sparse})}, {C, Format({sparse,sparse})}, + {D, Format({sparse,sparse})}}) + ), + { + TestCase( + {{B, {{{0, 1}, 2.0}, {{1, 1}, 3.0}, {{1, 2}, 2.0}, {{4, 3}, 4.0}}}, + {C, {{{0, 1}, 3.0}, {{2, 1}, 3.0}, {{2, 2}, 4.0}, {{4, 3}, 6.0}}}, + {D, {{{1, 2}, 1.0}, {{2, 1}, 4.0}, {{3, 3}, 5.0}, {{4, 3}, 5.0}}}}, + + {{A, {{{0, 1}, 5.0}, {{1, 2}, 3.0}, {{2, 1}, 7.0}, + {{2, 2}, 4.0}, {{3, 3}, 5.0}, {{4, 3}, 15.0}}}}) + } +) + +Func scOr("Or", OrImpl(), {Annihilator((double)1), Identity(Literal((double)0))}); +Func scAnd("And", AndImpl(), {Annihilator((double)0), Identity((double)1)}); +Func bfsMaskOp("bfsMask", BfsLower(), BfsMaskAlg()); + +TEST_STMT(BoolRing, + forall(i, + forall(j, + Assignment(a(i), bfsMaskOp(scAnd(B(i, j), c(j)), c(i)), scOr()) + )), + Values( + Formats({{a, Format({dense})}, {B, Format({dense,sparse})}, {c, Format({dense})}}) + ), + { + TestCase( + {{B, {{{0, 1}, 1.0}, {{1, 1}, 1.0}, {{1, 2}, 1.0}, {{3, 1}, 1.0}, {{4, 3}, 1.0}}}, + {c, {{{1}, 1.0}}}}, + + {{a, {{{0}, 1.0}, {{3}, 1.0}}}}) + } +) + +TEST_STMT(BoolRing2, + forall(j, + forall(i, + Assignment(a(i), bfsMaskOp(scAnd(B(i, j), c(j)), c(i)), scOr()) + )), + Values( + Formats({{a, Format({dense})}, {B, Format({dense,sparse}, {1, 0})}, {c, Format({sparse})}}) + ), + { + TestCase( + {{B, {{{1, 0}, 1.0}, {{1, 1}, 1.0}, {{1, 3}, 1.0}, {{2, 1}, 1.0}, {{3, 4}, 1.0}}}, + {c, {{{1}, 1.0}}}}, + + {{a, {{{0}, 1.0}, {{3}, 1.0}}}}) + } +) + +TEST_STMT(BoolRing3, + forall(j, + forall(i, + Assignment(a(i), scAnd(B(i, j), c(j)), scOr()) + )), + Values( + Formats({{a, Format({dense})}, {B, Format({dense,sparse}, {1, 0})}, {c, Format({sparse})}}) + ), + { + TestCase( + {{B, {{{1, 0}, 1.0}, {{1, 1}, 1.0}, {{1, 3}, 1.0}, {{2, 1}, 1.0}, {{3, 4}, 1.0}}}, + {c, {{{1}, 1.0}}}}, + + {{a, {{{0}, 1.0}, {{1}, 1.0}, {{3}, 1.0}}}}) + } +) + +Func customMin("Min", MinImpl(), {Identity(std::numeric_limits::infinity()) }); +Func Plus("Plus", GeneralAdd(), {Annihilator(std::numeric_limits::infinity())}); + +static TensorVar a_inf("a", vectype, Format(), std::numeric_limits::infinity()); +static TensorVar c_inf("c", vectype, Format(), std::numeric_limits::infinity()); +static TensorVar B_inf("B", mattype, Format(), std::numeric_limits::infinity()); + +TEST_STMT(MinPlusRing, + forall(i, + forall(j, + 
Assignment(a_inf(i), Plus(B_inf(i, j), c_inf(j)), customMin()) + )), + Values( + Formats({{a_inf, Format({dense})}, {B_inf, Format({dense,sparse})}, {c_inf, Format({dense})}}) + ), + { + TestCase( + {{B_inf, {{{0, 1}, 3.0}, {{1, 1}, 4.0}, {{1, 2}, 2.0}, {{3,2}, 5.0}, {{4, 3}, 1.0}}}, + {c_inf, {{{1}, 1.0}, {{2}, 6.0}}}}, + + {{a_inf, {{{0}, 4.0}, {{1}, 5.0}, {{3}, 11.0}}}}) + } +) + +Func sparsify("sparsify", identityFunc(), [](const std::vector& v) {return Union(v[0], Complement(v[1]));}); + +TEST_STMT(SparsifyTest, + forall(i, + a(i) = sparsify(b(i), i) + ), + Values( + Formats({{a, Format({sparse})}, {b, Format({dense})} }) + ), + { + TestCase( + {{b, {{{0}, 3.0}, {{2}, 4.0}, {{4}, 2.0}}}}, + + {{a, {{{0}, 3.0}, {{2}, 4.0}, {{4}, 2.0}}}}) + } +) + +Func xorOp("xor", GeneralAdd(), xorGen()); + +TEST_STMT(XorTest, + forall(i, + c(i) = xorOp(a(i), b(i)) + ), + Values( + Formats({{a, Format({sparse})}, {b, Format({sparse})}, {c, Format({sparse})} }) + ), + { + TestCase( + {{a, {{{0}, 3.0}, {{2}, 4.0}, {{4}, 2.0}}}, + {b, {{{1}, 5.0}, {{2}, 4.0}, {{4}, 2.0}}}}, + + {{c, {{{0}, 3.0}, {{1}, 5.0}}}}) + } +) + + }} diff --git a/test/tests-merge_lattice.cpp b/test/tests-merge_lattice.cpp index 346ffcf4a..7fabd2a48 100644 --- a/test/tests-merge_lattice.cpp +++ b/test/tests-merge_lattice.cpp @@ -9,6 +9,8 @@ #include "lower/mode_access.h" #include "taco/ir/ir.h" #include "taco/lower/mode_format_impl.h" +#include "taco/index_notation/tensor_operator.h" +#include "op_factory.h" using namespace std; @@ -275,6 +277,30 @@ INSTANTIATE_TEST_CASE_P(add, merge_lattice, {it(rd)}) }) ), + Test(forall(i, rd = s1 + (s2 + s3)), + MergeLattice({MergePoint({it(s1), it(s2), it(s3)}, + {}, + {it(rd)}), + MergePoint({it(s1), it(s2)}, + {}, + {it(rd)}), + MergePoint({it(s1), it(s3)}, + {}, + {it(rd)}), + MergePoint({it(s2), it(s3)}, + {}, + {it(rd)}), + MergePoint({it(s1)}, + {}, + {it(rd)}), + MergePoint({it(s2)}, + {}, + {it(rd)}), + MergePoint({it(s3)}, + {}, + {it(rd)}) + }) + ), Test(forall(i, rd = d1 + s2), MergeLattice({MergePoint({i, it(s2)}, {it(d1)}, @@ -589,6 +615,462 @@ INSTANTIATE_TEST_CASE_P(hashmap, merge_lattice, ) ); +Func intersectAdd("intersectAdd", GeneralAdd(), IntersectGen()); +Func intersectAddDeMorgan("intersectAddDeMorgan", GeneralAdd(), IntersectGenDeMorgan()); + +INSTANTIATE_TEST_CASE_P(deMorganIntersect, merge_lattice, + Values( + Test(forall(i, rd = intersectAdd(s1, s2)), + MergeLattice({MergePoint({it(s1), it(s2)}, + {}, + {it(rd)}) + }) + ), + Test(forall(i, rd = intersectAddDeMorgan(s1, s2)), + MergeLattice({MergePoint({it(s1), it(s2)}, + {}, + {it(rd)}) + }) + ), + Test(forall(i, rd = intersectAdd(d1, d2)), + MergeLattice({MergePoint({i}, + {it(d1), it(d2)}, + {it(rd)}) + }) + ), + Test(forall(i, rd = intersectAddDeMorgan(d1, d2)), + MergeLattice({MergePoint({i}, + {it(d1), it(d2)}, + {it(rd)}) + }) + ), + Test(forall(i, rd = intersectAddDeMorgan(h1, h2)), + MergeLattice({MergePoint({it(h1)}, + {it(h2)}, + {it(rd)}) + }) + ), + Test(forall(i, rd = intersectAddDeMorgan(d1, h1)), + MergeLattice({MergePoint({it(h1)}, + {it(d1)}, + {it(rd)}) + }) + ), + Test(forall(i, rd = intersectAddDeMorgan(d1, h1, s1)), + MergeLattice({MergePoint({it(s1)}, + {it(h1), it(d1)}, + {it(rd)}) + }) + ) + + ) +); + +Func complementIntersect("complementIntersect", GeneralAdd(), ComplementIntersect()); + +INSTANTIATE_TEST_CASE_P(complementIntersect, merge_lattice, + Values( + Test(forall(i, rd = complementIntersect(s1, s2)), + MergeLattice({MergePoint({it(s1), it(s2)}, + {}, + {it(rd)}, + true), + + 
MergePoint({it(s2)}, + {}, + {it(rd)}) + }) + ), + Test(forall(i, rd = complementIntersect(d1, d2)), + MergeLattice({MergePoint({i}, + {it(d1), it(d2)}, + {it(rd)}, + true), + MergePoint({i}, + {it(d2)}, + {it(rd)}) + }) + ), + Test(forall(i, rd = complementIntersect(s1, d1)), + MergeLattice({MergePoint({it(s1), i}, + {it(d1)}, + {it(rd)}, + true), + MergePoint({i}, + {it(d1)}, + {it(rd)}) + }) + ), + Test(forall(i, rd = complementIntersect(d1, s1)), + MergeLattice({MergePoint({it(s1)}, + {it(d1)}, + {it(rd)}, + true), + MergePoint({{it(s1)}, + {}, + {it(rd)}}) + }) + ), + Test(forall(i, rd = complementIntersect(h1, h2)), + MergeLattice({MergePoint({it(h2)}, + {it(h1)}, + {it(rd)}, + true), + MergePoint({{it(h2)}, + {}, + {it(rd)}}) + }) + ), + Test(forall(i, rd = complementIntersect(h1, s1)), + MergeLattice({MergePoint({it(s1)}, + {it(h1)}, + {it(rd)}, + true), + MergePoint({{it(s1)}, + {}, + {it(rd)}}) + }) + ), + Test(forall(i, rd = complementIntersect(s1, s2, s3)), + MergeLattice({MergePoint({it(s1), it(s2), it(s3)}, + {}, + {it(rd)}, + true), + MergePoint({it(s2), it(s3)}, + {}, + {it(rd)}) + }) + ), + Test(forall(i, rd = complementIntersect(d1, h1, s1)), + MergeLattice({MergePoint({it(s1)}, + {it(h1), it(d1)}, + {it(rd)}, + true), + MergePoint({it(s1)}, + {it(h1)}, + {it(rd)}) + }) + ), + Test(forall(i, rd = complementIntersect(h1, d1, s1)), + MergeLattice({MergePoint({it(s1)}, + {it(h1), it(d1)}, + {it(rd)}, + true), + MergePoint({it(s1)}, + {it(d1)}, + {it(rd)}) + }) + ), + Test(forall(i, rd = complementIntersect(d1, d2, d3)), + MergeLattice({MergePoint({i}, + {it(d1), it(d2), it(d3)}, + {it(rd)}, + true), + MergePoint({i}, + {it(d2), it(d3)}, + {it(rd)}) + }) + ) + + ) +); + + +Func complementUnion("complementUnion", GeneralAdd(), ComplementUnion()); +INSTANTIATE_TEST_CASE_P(complementUnion, merge_lattice, + Values( + Test(forall(i, rd = complementUnion(s1, s2)), + MergeLattice({MergePoint({it(s1), i, it(s2)}, + {}, + {it(rd)}), + MergePoint({i, it(s2)}, + {}, + {it(rd)}), + MergePoint({it(s1), i}, + {}, + {it(rd)}, + true), + MergePoint({i}, + {}, + {it(rd)}) + }) + ), + Test(forall(i, rd = complementUnion(d1, d2)), + MergeLattice({MergePoint({i}, + {it(d1), it(d2)}, + {it(rd)}), + MergePoint({i}, + {it(d2)}, + {it(rd)}), + MergePoint({i}, + {it(d1)}, + {it(rd)}, + true), + MergePoint({i}, + {}, + {it(rd)}) + }) + ), + Test(forall(i, rd = complementUnion(s1, d1)), + MergeLattice({MergePoint({it(s1), i}, + {it(d1)}, + {it(rd)}), + MergePoint({i}, + {it(d1)}, + {it(rd)}), + MergePoint({it(s1), i}, + {}, + {it(rd)}, + true), + MergePoint({i}, + {}, + {it(rd)}) + }) + ), + Test(forall(i, rd = complementUnion(d1, s1)), + MergeLattice({MergePoint({i, it(s1)}, + {it(d1)}, + {it(rd)}), + MergePoint({i, it(s1)}, + {}, + {it(rd)}), + MergePoint({i}, + {it(d1)}, + {it(rd)}, + true), + MergePoint({i}, + {}, + {it(rd)}) + }) + ), + Test(forall(i, rd = complementUnion(h1, h2)), + MergeLattice({MergePoint({i}, + {it(h1), it(h2)}, + {it(rd)}), + MergePoint({i}, + {it(h2)}, + {it(rd)}), + MergePoint({i}, + {it(h1)}, + {it(rd)}, + true), + MergePoint({i}, + {}, + {it(rd)}) + }) + ), + Test(forall(i, rd = complementUnion(h1, s1)), + MergeLattice({MergePoint({i, it(s1)}, + {it(h1)}, + {it(rd)}), + MergePoint({i, it(s1)}, + {}, + {it(rd)}), + MergePoint({i}, + {it(h1)}, + {it(rd)}, + true), + MergePoint({i}, + {}, + {it(rd)}) + }) + ), + Test(forall(i, rd = complementUnion(s1, s2, s3)), + MergeLattice({MergePoint({it(s1), i, it(s2), it(s3)}, + {}, + {it(rd)}), + MergePoint({i, it(s2), it(s3)}, + {}, 
+ {it(rd)}), + MergePoint({it(s1), i, it(s3)}, + {}, + {it(rd)}), + MergePoint({it(s1), i, it(s2)}, + {}, + {it(rd)}), + MergePoint({i, it(s3)}, + {}, + {it(rd)}), + MergePoint({i, it(s2)}, + {}, + {it(rd)}), + MergePoint({it(s1), i}, + {}, + {it(rd)}, + true), + MergePoint({i}, + {}, + {it(rd)}) + }) + + ) + + ) +); + +Func xorOp("xor", GeneralAdd(), xorGen()); +INSTANTIATE_TEST_CASE_P(xorLattice, merge_lattice, + Values(Test(forall(i, rd = xorOp(s1, s2)), + MergeLattice({MergePoint({it(s1), it(s2)}, + {}, + {it(rd)}, + true), + MergePoint({it(s1)}, + {}, + {it(rd)}), + MergePoint({it(s2)}, + {}, + {it(rd)}) + }) + ), + Test(forall(i, rd = xorOp(d1, d2)), + MergeLattice({MergePoint({i}, + {it(d1), it(d2)}, + {it(rd)}, + true), + MergePoint({i}, + {it(d1)}, + {it(rd)}), + MergePoint({i}, + {it(d2)}, + {it(rd)}) + }) + ), + Test(forall(i, rd = xorOp(h1, h2)), + MergeLattice({MergePoint({i}, + {it(h1), it(h2)}, + {it(rd)}, + true), + MergePoint({i}, + {it(h1)}, + {it(rd)}), + MergePoint({i}, + {it(h2)}, + {it(rd)}) + }) + ), + Test(forall(i, rd = xorOp(d1, s1)), + MergeLattice({MergePoint({i, it(s1)}, + {it(d1)}, + {it(rd)}, + true), + MergePoint({i}, + {it(d1)}, + {it(rd)}), + MergePoint({i, it(s1)}, + {}, + {it(rd)}) + }) + ), + Test(forall(i, rd = xorOp(h1, s1)), + MergeLattice({MergePoint({i, it(s1)}, + {it(h1)}, + {it(rd)}, + true), + MergePoint({i}, + {it(h1)}, + {it(rd)}), + MergePoint({i, it(s1)}, + {}, + {it(rd)}) + }) + ), + Test(forall(i, rd = xorOp(h1, d1)), + MergeLattice({MergePoint({i}, + {it(d1), it(h1)}, + {it(rd)}, + true), + MergePoint({i}, + {it(h1)}, + {it(rd)}), + MergePoint({i}, + {it(d1)}, + {it(rd)}) + }) + ) + ) +); + +Func identity("identity", identityFunc(), fullSpaceGen()); +INSTANTIATE_TEST_CASE_P(singleCompUnion, merge_lattice, + Values(Test(forall(i, rd = identity(s1)), + MergeLattice({MergePoint({it(s1), i}, + {}, + {it(rd)}), + MergePoint({i}, + {}, + {it(rd)}) + }) + ), + Test(forall(i, rd = identity(d1)), + MergeLattice({MergePoint({i}, + {it(d1)}, + {it(rd)}) + }) + ), + Test(forall(i, rd = identity(h1)), + MergeLattice({MergePoint({i}, + {it(h1)}, + {it(rd)}) + }) + ) + ) +); + +Func emptyIdentity("emptyIdentity", identityFunc(), emptyGen()); +Func intersectEdgeCase("intersectEdgeCase", GeneralAdd(), intersectEdge()); +Func unionEdgeCase("unionEdgeCase", GeneralAdd(), unionEdge()); +INSTANTIATE_TEST_CASE_P(edgeCases, merge_lattice, + Values(Test(forall(i, rd = emptyIdentity(s1)), + MergeLattice({MergePoint({it(s1)}, + {}, + {it(rd)}, + true) + }) + ), + Test(forall(i, rd = emptyIdentity(d1)), + MergeLattice({MergePoint({i}, + {it(d1)}, + {it(rd)}, + true) + }) + ), + Test(forall(i, rd = emptyIdentity(h1)), + MergeLattice({MergePoint({it(h1)}, + {it(h1)}, + {it(rd)}, + true) + }) + ), + Test(forall(i, rd = intersectEdgeCase(s1, s2)), + MergeLattice({MergePoint({it(s1), it(s2)}, + {}, + {it(rd)}, + true) + }) + ), + Test(forall(i, rd = unionEdgeCase(s1, s2)), + MergeLattice({MergePoint({it(s1), i, it(s2)}, + {}, + {it(rd)}), + MergePoint({it(s1), i}, + {}, + {it(rd)}), + MergePoint({i, it(s2)}, + {}, + {it(rd)}), + MergePoint({i}, + {}, + {it(rd)}) + }) + ) + + ) +); + + + + IndexVar i1, i2; TEST(merge_lattice, split) { diff --git a/test/tests-scheduling-eval.cpp b/test/tests-scheduling-eval.cpp index eeaf720ca..0c9a15531 100644 --- a/test/tests-scheduling-eval.cpp +++ b/test/tests-scheduling-eval.cpp @@ -8,6 +8,7 @@ #include "taco/index_notation/index_notation.h" #include "codegen/codegen.h" #include "taco/lower/lower.h" +#include "op_factory.h" using 
namespace taco; const IndexVar i("i"), j("j"), k("k"), l("l"), m("m"), n("n"); @@ -26,7 +27,7 @@ void printToFile(string filename, IndexStmt stmt) { mkdir(file_path.c_str(), 0777); std::shared_ptr codegen = ir::CodeGen::init_default(source, ir::CodeGen::ImplementationGen); - ir::Stmt compute = lower(stmt, "compute", false, true); + ir::Stmt compute = lower(stmt, "compute", true, true); codegen->compile(compute, true); ofstream source_file; @@ -1172,6 +1173,38 @@ TEST(scheduling_eval, mttkrpGPU) { ASSERT_TENSOR_EQ(expected, A); } +TEST(scheduling_eval, indexVarSplit) { + + Tensor a("A", {4, 4}, dense); + Tensor b("B", {4, 4}, compressed); + + Tensor expected("C", a.getDimensions(), Dense); + const int n = a.getDimensions()[0]; + const int m = a.getDimensions()[1]; + + for(int i = 0; i < n; ++i) { + b.insert({i, i}, 2); + } + b.pack(); + + a(i, j) = b(i, j) * (i * m + j); + IndexStmt stmt = a.getAssignment().concretize(); + IndexVar j0("j0"), j1("j1"); + stmt = stmt.split(j, j0, j1, 2); + + a.compile(stmt); + a.assemble(); + a.compute(); + + for(int i = 0; i < n; ++i) { + int flattened_idx = i * m + i; + expected.insert({i, i}, 2 * flattened_idx); + } + expected.pack(); + + ASSERT_TENSOR_EQ(expected, a); +} + TEST(generate_evaluation_files, DISABLED_cpu) { if (should_use_CUDA_codegen()) { return; @@ -1731,3 +1764,75 @@ TEST(generate_figures, DISABLED_cpu) { source_file.close(); } } + +TEST(scheduling_eval, bfsPullScheduled) { + if (should_use_CUDA_codegen()) { + return; + } + constexpr int numVertices = 10; + int NUM_I = numVertices; + int NUM_J = numVertices; + float SPARSITY = .3; + + Tensor A("A", {NUM_I, NUM_J}, CSC); + Tensor x("x", {NUM_J}, {Sparse}); + Tensor m("mask", {NUM_J}, {Dense}); + Tensor y("y", {NUM_I}, {Dense}); + Tensor step("step"); + + uint16_t one = 1; + uint16_t zero = 0; + + Func scOr("Or", OrImpl(), {Annihilator(one), Identity(zero)}); + Func scAnd("And", AndImpl(), {Annihilator(zero), Identity(one)}); + Func bfsMaskOp("bfsMask", BfsLower(), BfsMaskAlg()); + + srand(120); + for (int i = 0; i < NUM_I; i++) { + for (int j = 0; j < NUM_J; j++) { + float rand_float = (float)rand()/(float)(RAND_MAX); + if (rand_float < SPARSITY) { + A.insert({i, j}, one); + } + } + } + + for (int j = 0; j < NUM_J; j++) { + float rand_float = (float)rand()/(float)(RAND_MAX); + if (rand_float < SPARSITY) { + x.insert({j}, one); + } else { + x.insert({j}, zero); + } + } + + x.pack(); + A.pack(); + m.pack(); + + y(i) = Reduction(scOr(), j, scAnd(A(i, j), x(j))); + IndexStmt stmt = y.getAssignment().concretize(); + + stmt = stmt.reorder(i,j) + .parallelize(j, taco::ParallelUnit::CPUThread, taco::OutputRaceStrategy::Atomics); + printToFile("bfs_push", stmt); + y.compile(stmt); + y.assemble(); + y.compute(); + + Tensor s("s", {NUM_J}, {Sparse}); + Tensor d("d", {NUM_J}, {dense}); + + Func sparsifyOp("sparsify", identityFunc(), ComplementUnion()); + s(i) = sparsifyOp(d(i), i); + IndexStmt sparsify = s.getAssignment().concretize(); + printToFile("sparsify", sparsify); + + + Tensor expected("expected", {NUM_I}, {Dense}); + expected(i) = Reduction(scOr(), j, bfsMaskOp(scAnd(A(i, j), x(j)), m(i))); + expected.compile(); + expected.assemble(); + expected.compute(); + ASSERT_TENSOR_EQ(expected, y); +} \ No newline at end of file diff --git a/test/tests-scheduling.cpp b/test/tests-scheduling.cpp index 93d5d58bc..56398a86e 100644 --- a/test/tests-scheduling.cpp +++ b/test/tests-scheduling.cpp @@ -858,3 +858,66 @@ TEST(scheduling_eval_test, spmv_fuse) { expected.compute(); ASSERT_TENSOR_EQ(expected, y); } 
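Editorial illustration (not part of the patch): the end-to-end flow that the Func-based cases above drive through TEST_STMT and the scheduling tests, written against the plain Tensor API. The tensor names, dimensions, and values are hypothetical; GeneralAdd() and xorGen() are the helpers from test/op_factory.h, and the expected result mirrors the XorTest case in test/tests-lower.cpp.

#include "taco.h"
#include "taco/index_notation/tensor_operator.h"
#include "op_factory.h"  // assumed to be on the include path, as in the tests
using namespace taco;

void xorSketch() {
  // xorGen() builds the algebra (B union C) minus (B intersect C); GeneralAdd()
  // supplies the scalar expression computed over the surviving region.
  Func xorOp("xor", GeneralAdd(), xorGen());

  Tensor<double> b("b", {8}, Format({sparse}));
  Tensor<double> c("c", {8}, Format({sparse}));
  Tensor<double> out("out", {8}, Format({sparse}));
  b.insert({0}, 3.0); b.insert({2}, 4.0);
  c.insert({1}, 5.0); c.insert({2}, 4.0);
  b.pack(); c.pack();

  IndexVar i("i");
  out(i) = xorOp(b(i), c(i));
  out.evaluate();  // out keeps only coordinates 0 and 1; the value shared at
                   // coordinate 2 falls in the omitted intersection region
}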
+ +TEST(scheduling_eval_test, indexVarSplit) { + + Tensor a("A", {4, 4}, dense); + Tensor b("B", {4, 4}, compressed); + + Tensor expected("C", a.getDimensions(), Dense); + const int n = a.getDimensions()[0]; + const int m = a.getDimensions()[1]; + + for(int i = 0; i < n; ++i) { + b.insert({i, i}, 2); + } + b.pack(); + + a(i, j) = b(i, j) * (i * m + j); + IndexStmt stmt = a.getAssignment().concretize(); + IndexVar j0("j0"), j1("j1"); + stmt = stmt.split(j, j0, j1, 2); + + a.compile(stmt); + a.assemble(); + a.compute(); + + for(int i = 0; i < n; ++i) { + int flattened_idx = i * m + i; + expected.insert({i, i}, 2 * flattened_idx); + } + expected.pack(); + + ASSERT_TENSOR_EQ(expected, a); +} + +TEST(scheduling_eval_test, indexVarReorder) { + + Tensor a("A", {4, 4}, dense); + Tensor b("B", {4, 4}, dense); + + Tensor expected("C", a.getDimensions(), Dense); + const int n = a.getDimensions()[0]; + const int m = a.getDimensions()[1]; + + for(int i = 0; i < n; ++i) { + b.insert({i, i}, 2); + } + b.pack(); + + a(i, j) = b(i, j) * (i * m + j); + IndexStmt stmt = a.getAssignment().concretize(); + stmt = stmt.reorder(i, j); + + a.compile(stmt); + a.assemble(); + a.compute(); + + for(int i = 0; i < n; ++i) { + int flattened_idx = i * m + i; + expected.insert({i, i}, 2 * flattened_idx); + } + expected.pack(); + + ASSERT_TENSOR_EQ(expected, a); +} \ No newline at end of file diff --git a/test/tests-tensor.cpp b/test/tests-tensor.cpp index 1601c805a..c31ea2855 100644 --- a/test/tests-tensor.cpp +++ b/test/tests-tensor.cpp @@ -10,6 +10,12 @@ using namespace taco; +template +void testFill(T fillVal) { + Tensor a({2,2}, fillVal); + ASSERT_TRUE(equals(a.getFillValue(), Literal((T) fillVal))); +} + TEST(tensor, double_scalar) { Tensor a(4.2); ASSERT_DOUBLE_EQ(4.2, a.begin()->second); @@ -82,6 +88,25 @@ TEST(tensor, duplicates_scalar) { ASSERT_TRUE(++val == a.end()); } +TEST(tensor, scalar_type_correct) { + Tensor a; + ASSERT_EQ(a.getComponentType(), Int32); +} + +TEST(tensor, non_zero_fill) { + testFill(3); + testFill(34762); + testFill((1 << 10)); + testFill((1 << 30)); + testFill((1ULL << 42)); + testFill((1 << 10)); + testFill(-1); + testFill((1ULL << 42)); + testFill(std::numeric_limits::min()); + testFill(std::numeric_limits::max()); + testFill(std::numeric_limits::infinity()); +} + TEST(tensor, transpose) { TensorData testData = TensorData({5, 3, 2}, { {{0,0,0}, 0.0},