Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalizes TACO Merge lattice construction #390

Merged
merged 34 commits into from
Feb 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
535d777
IndexVars can be used in computation. Need to add more tests and allo…
rawnhenry Jan 13, 2020
f9c676d
Merge branch 'scheduling-language' of https://github.com/tensor-compi…
rawnhenry Jan 13, 2020
3c4560f
Added tests
rawnhenry Jan 28, 2020
42e7f99
Moved output of indexVar closer to class definition. Added test for C…
rawnhenry Feb 4, 2020
3cd7944
Added printer for iteration algebra and refactored some code
rawnhenry Feb 4, 2020
e459616
Added include to iteration_algebra_printer.h
rawnhenry Feb 4, 2020
a65a362
Adds basic functionality to MergeLattices and lowerer for general cod…
rawnhenry Feb 21, 2020
9788524
Started adding front end for defining new operators
rawnhenry Feb 28, 2020
9144c88
Rework user facing API
rawnhenry Feb 29, 2020
4f82ccb
Added TensorOpNode and made tests for new iteration algebra functions
rawnhenry Mar 5, 2020
2316c38
Reformat some code and fixed some bugs in lowering. Added one test fo…
rawnhenry Mar 5, 2020
c75b227
Added some missing files to git
rawnhenry Mar 5, 2020
67a4e6b
Redesign of property class. Properties are now wrapped for functions …
rawnhenry Mar 7, 2020
70f6c65
Merge and fix conflicts
rawnhenry Mar 7, 2020
6371d49
Bug fixes. Added more tests for new lattice machinery. Fixed issues w…
rawnhenry Mar 22, 2020
895bed9
Added more tests for lattice construction. Moved code that applies la…
rawnhenry Mar 23, 2020
09b5040
Added fill value to tensor. Differentiated between case and loop latt…
rawnhenry Apr 9, 2020
008c385
Moved tensorOp to its own Cpp file. Fixed bug with lowering not emitt…
rawnhenry Apr 12, 2020
68d3431
Allowed code for any reductions. Fixed bugs. Added a trivial fill val…
rawnhenry May 8, 2020
dbcbe5d
Reverted lower test framework packing
rawnhenry May 8, 2020
2a1bd2b
Fixed bug with double assembly. Added test for boolean semi-ring and …
rawnhenry May 9, 2020
d42b1f3
Added test for push
rawnhenry May 10, 2020
2d53d70
Merge from master to fix OMP issue
rawnhenry May 10, 2020
cfde49a
Rename some variables
rawnhenry Jan 31, 2021
5b23d14
Attempt to merge master into this branch. There is a bug in merge lat…
rawnhenry Feb 1, 2021
4036438
Rename opImpl and algImpl
rawnhenry Feb 7, 2021
607c6c1
Fixes bug that occurred during merge. One of the constructors of Tens…
rawnhenry Feb 7, 2021
e3c13f3
Uses alloc_mem instead of malloc to allocate fill value. This caused …
rawnhenry Feb 7, 2021
05d3c93
Fixes lattice construction when scheduled index variables are used in…
rawnhenry Feb 7, 2021
19466ea
Merge branch 'master' into array_algebra
rawnhenry Feb 7, 2021
8dfb79e
Merge branch 'master' of https://github.com/tensor-compiler/taco into…
rawnhenry Feb 17, 2021
a76408a
Make IndexVars inherit from IndexExpr again
rawnhenry Feb 17, 2021
4e27678
lower: fix a bug introduced by merging windowing and array algebra
rohany Feb 18, 2021
b1f7f88
Merge pull request #1 from rohany/windowing-array-algebra-fix
rawnhenry Feb 18, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/taco.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "taco/tensor.h"
#include "taco/format.h"
#include "taco/index_notation/tensor_operator.h"
#include "taco/index_notation/index_notation.h"

#endif
88 changes: 78 additions & 10 deletions include/taco/index_notation/index_notation.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
#include <set>
#include <map>
#include <utility>
#include <functional>

#include "taco/util/name_generator.h"
#include "taco/format.h"
#include "taco/error.h"
#include "taco/util/intrusive_ptr.h"
Expand All @@ -22,6 +24,7 @@
#include "taco/ir_tags.h"
#include "taco/lower/iterator.h"
#include "taco/index_notation/provenance_graph.h"
#include "taco/index_notation/properties.h"

namespace taco {

Expand All @@ -38,6 +41,8 @@ class IndexExpr;
class Assignment;
class Access;

class IterationAlgebra;

struct AccessNode;
struct AccessWindow;
struct LiteralNode;
Expand All @@ -48,8 +53,10 @@ struct SubNode;
struct MulNode;
struct DivNode;
struct CastNode;
struct CallNode;
struct CallIntrinsicNode;
struct ReductionNode;
struct IndexVarNode;

struct AssignmentNode;
struct YieldNode;
Expand Down Expand Up @@ -262,6 +269,11 @@ class Access : public IndexExpr {
Assignment operator+=(const IndexExpr&);

typedef AccessNode Node;

// Equality and comparison are overridden on Access to perform a deep
// comparison of the access rather than a pointer check.
friend bool operator==(const Access& a, const Access& b);
friend bool operator<(const Access& a, const Access &b);
};


Expand Down Expand Up @@ -289,11 +301,14 @@ class Literal : public IndexExpr {
Literal(std::complex<float>);
Literal(std::complex<double>);

static IndexExpr zero(Datatype);
static Literal zero(Datatype);

/// Returns the literal value.
template <typename T> T getVal() const;

/// Returns an untyped pointer to the literal value
void* getValPtr();

typedef LiteralNode Node;
};

Expand Down Expand Up @@ -413,6 +428,26 @@ class Cast : public IndexExpr {
typedef CastNode Node;
};

/// A call to an operator
class Call: public IndexExpr {
public:
Call() = default;
Call(const CallNode*);
Call(const CallNode*, std::string name);

const std::vector<IndexExpr>& getArgs() const;
const std::function<ir::Expr(const std::vector<ir::Expr>&)> getFunc() const;
const IterationAlgebra& getAlgebra() const;
const std::vector<Property>& getProperties() const;
const std::string getName() const;
const std::map<std::vector<int>, std::function<ir::Expr(const std::vector<ir::Expr>&)>> getDefs() const;
const std::vector<int>& getDefinedArgs() const;

typedef CallNode Node;

private:
std::string name;
};

/// A call to an intrinsic.
/// ```
Expand All @@ -433,6 +468,8 @@ class CallIntrinsic : public IndexExpr {
typedef CallIntrinsicNode Node;
};

std::ostream& operator<<(std::ostream&, const IndexVar&);

/// Create calls to various intrinsics.
IndexExpr mod(IndexExpr, IndexExpr);
IndexExpr abs(IndexExpr);
Expand Down Expand Up @@ -871,17 +908,27 @@ class WindowedIndexVar : public util::Comparable<WindowedIndexVar>, public Index

/// Index variables are used to index into tensors in index expressions, and
/// they represent iteration over the tensor modes they index into.
class IndexVar : public util::Comparable<IndexVar>, public IndexVarInterface {
class IndexVar : public IndexExpr, public IndexVarInterface {

public:
IndexVar();
~IndexVar() = default;
IndexVar(const std::string& name);
IndexVar(const std::string& name, const Datatype& type);
IndexVar(const IndexVarNode *);

/// Returns the name of the index variable.
std::string getName() const;

// Need these to overshadow the comparisons in for the IndexExpr instrusive pointer
friend bool operator==(const IndexVar&, const IndexVar&);
friend bool operator<(const IndexVar&, const IndexVar&);
friend bool operator!=(const IndexVar&, const IndexVar&);
friend bool operator>=(const IndexVar&, const IndexVar&);
friend bool operator<=(const IndexVar&, const IndexVar&);
friend bool operator>(const IndexVar&, const IndexVar&);

typedef IndexVarNode Node;

/// Indexing into an IndexVar returns a window into it.
WindowedIndexVar operator()(int lo, int hi);
Expand Down Expand Up @@ -927,11 +974,12 @@ SuchThat suchthat(IndexStmt stmt, std::vector<IndexVarRel> predicate);
class TensorVar : public util::Comparable<TensorVar> {
public:
TensorVar();
TensorVar(const Type& type);
TensorVar(const std::string& name, const Type& type);
TensorVar(const Type& type, const Format& format);
TensorVar(const std::string& name, const Type& type, const Format& format);
TensorVar(const int &id, const std::string& name, const Type& type, const Format& format);
TensorVar(const Type& type, const Literal& fill = Literal());
TensorVar(const std::string& name, const Type& type, const Literal& fill = Literal());
TensorVar(const Type& type, const Format& format, const Literal& fill = Literal());
TensorVar(const std::string& name, const Type& type, const Format& format, const Literal& fill = Literal());
TensorVar(const int &id, const std::string& name, const Type& type, const Format& format,
const Literal& fill = Literal());

/// Returns the ID of the tensor variable.
int getId() const;
Expand All @@ -952,6 +1000,12 @@ class TensorVar : public util::Comparable<TensorVar> {
/// and execute it's expression.
const Schedule& getSchedule() const;

/// Gets the fill value of the tensor variable. May be left undefined.
const Literal& getFill() const;

/// Set the fill value of the tensor variable
void setFill(const Literal& fill);

/// Set the name of the tensor variable.
void setName(std::string name);

Expand Down Expand Up @@ -1008,7 +1062,8 @@ bool isEinsumNotation(IndexStmt, std::string* reason=nullptr);
bool isReductionNotation(IndexStmt, std::string* reason=nullptr);

/// Check whether the statement is in the concrete index notation dialect.
/// This means every index variable has a forall node, there are no reduction
/// This means every index variable has a forall node, each index variable used
/// for computation is under a forall node for that variable, there are no reduction
/// nodes, and that every reduction variable use is nested inside a compound
/// assignment statement. You can optionally pass in a pointer to a string
/// that the reason why it is not concrete notation is printed to.
Expand All @@ -1030,7 +1085,12 @@ std::vector<TensorVar> getResults(IndexStmt stmt);
/// Returns the input tensors to the index statement, in the order they appear.
std::vector<TensorVar> getArguments(IndexStmt stmt);

/// Returns the temporaries in the index statement, in the order they appear.
/// Returns true iff all of the loops over free variables come before all of the loops over
/// reduction variables. Therefore, this returns true if the reduction controlled by the loops
/// does not a scatter.
bool allForFreeLoopsBeforeAllReductionLoops(IndexStmt stmt);

/// Returns the temporaries in the index statement, in the order they appear.
std::vector<TensorVar> getTemporaries(IndexStmt stmt);

// [Olivia]
Expand Down Expand Up @@ -1070,7 +1130,15 @@ IndexExpr zero(IndexExpr, const std::set<Access>& zeroed);
/// zero and then propagating and removing zeroes.
IndexStmt zero(IndexStmt, const std::set<Access>& zeroed);

/// Create an `other` tensor with the given name and format,
/// Infers the fill value of the input expression by applying properties if possible. If unable
/// to successfully infer the fill value of the result, returns the empty IndexExpr
IndexExpr inferFill(IndexExpr);

/// Returns true if there are no forall nodes in the indexStmt. Used to check
/// if the last loop is being lowered.
bool hasNoForAlls(IndexStmt);

/// Create an `other` tensor with the given name and format,
/// and return tensor(indexVars) = other(indexVars) if otherIsOnRight,
/// and otherwise returns other(indexVars) = tensor(indexVars).
IndexStmt generatePackStmt(TensorVar tensor,
Expand Down
84 changes: 83 additions & 1 deletion include/taco/index_notation/index_notation_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,18 @@

#include <vector>
#include <memory>
#include <numeric>

#include "taco/type.h"
#include "taco/util/collections.h"
#include "taco/util/comparable.h"
#include "taco/index_notation/index_notation.h"
#include "taco/index_notation/index_notation_nodes_abstract.h"
#include "taco/index_notation/index_notation_visitor.h"
#include "taco/index_notation/intrinsic.h"
#include "taco/util/strings.h"
#include "iteration_algebra.h"
#include "properties.h"

namespace taco {

Expand All @@ -23,6 +28,12 @@ struct AccessWindow {
friend bool operator==(const AccessWindow& a, const AccessWindow& b) {
return a.lo == b.lo && a.hi == b.hi;
}
friend bool operator<(const AccessWindow& a, const AccessWindow& b) {
if (a.lo != b.lo) {
return a.lo < b.lo;
}
return a.hi < b.hi;
}
};

struct AccessNode : public IndexExprNode {
Expand Down Expand Up @@ -68,7 +79,6 @@ struct LiteralNode : public IndexExprNode {
void* val;
};


struct UnaryExprNode : public IndexExprNode {
IndexExpr a;

Expand Down Expand Up @@ -188,6 +198,57 @@ struct CallIntrinsicNode : public IndexExprNode {
std::vector<IndexExpr> args;
};

struct CallNode : public IndexExprNode {
typedef std::function<ir::Expr(const std::vector<ir::Expr>&)> OpImpl;
typedef std::function<IterationAlgebra(const std::vector<IndexExpr>&)> AlgebraImpl;

CallNode(std::string name, const std::vector<IndexExpr>& args, OpImpl lowerFunc,
const IterationAlgebra& iterAlg,
const std::vector<Property>& properties,
const std::map<std::vector<int>, OpImpl>& regionDefinitions,
const std::vector<int>& definedRegions);

CallNode(std::string name, const std::vector<IndexExpr>& args, OpImpl lowerFunc,
const IterationAlgebra& iterAlg,
const std::vector<Property>& properties,
const std::map<std::vector<int>, OpImpl>& regionDefinitions);

void accept(IndexExprVisitorStrict* v) const {
v->visit(this);
}

std::string name;
std::vector<IndexExpr> args;
OpImpl defaultLowerFunc;
IterationAlgebra iterAlg;
std::vector<Property> properties;
std::map<std::vector<int>, OpImpl> regionDefinitions;

// Needed to track which inputs have been exhausted so the lowerer can know which lower func to use
std::vector<int> definedRegions;

private:
static Datatype inferReturnType(OpImpl f, const std::vector<IndexExpr>& inputs) {
std::function<ir::Expr(IndexExpr)> getExprs = [](IndexExpr arg) { return ir::Var::make("t", arg.getDataType()); };
std::vector<ir::Expr> exprs = util::map(inputs, getExprs);

if(exprs.empty()) {
return taco::Datatype();
}

return f(exprs).type();
}

static std::vector<int> definedIndices(std::vector<IndexExpr> args) {
std::vector<int> v;
for(int i = 0; i < (int) args.size(); ++i) {
if(args[i].defined()) {
v.push_back(i);
}
}
return v;
}
};

struct ReductionNode : public IndexExprNode {
ReductionNode(IndexExpr op, IndexVar var, IndexExpr a);
Expand All @@ -202,6 +263,27 @@ struct ReductionNode : public IndexExprNode {
IndexExpr a;
};

struct IndexVarNode : public IndexExprNode, public util::Comparable<IndexVarNode> {
IndexVarNode() = delete;
IndexVarNode(const std::string& name, const Datatype& type);

void accept(IndexExprVisitorStrict* v) const {
v->visit(this);
}

std::string getName() const;

friend bool operator==(const IndexVarNode& a, const IndexVarNode& b);
friend bool operator<(const IndexVarNode& a, const IndexVarNode& b);

private:
struct Content;
std::shared_ptr<Content> content;
};

struct IndexVarNode::Content {
std::string name;
};

// Index Statements
struct AssignmentNode : public IndexStmtNode {
Expand Down
2 changes: 2 additions & 0 deletions include/taco/index_notation/index_notation_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ class IndexNotationPrinter : public IndexNotationVisitorStrict {
void visit(const MulNode*);
void visit(const DivNode*);
void visit(const CastNode*);
void visit(const CallNode*);
void visit(const CallIntrinsicNode*);
void visit(const ReductionNode*);
void visit(const IndexVarNode*);

// Tensor Expressions
void visit(const AssignmentNode*);
Expand Down
4 changes: 4 additions & 0 deletions include/taco/index_notation/index_notation_rewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ class IndexExprRewriterStrict : public IndexExprVisitorStrict {
virtual void visit(const MulNode* op) = 0;
virtual void visit(const DivNode* op) = 0;
virtual void visit(const CastNode* op) = 0;
virtual void visit(const CallNode* op) = 0;
virtual void visit(const CallIntrinsicNode* op) = 0;
virtual void visit(const ReductionNode* op) = 0;
virtual void visit(const IndexVarNode* op) = 0;
};


Expand Down Expand Up @@ -93,8 +95,10 @@ class IndexNotationRewriter : public IndexNotationRewriterStrict {
virtual void visit(const MulNode* op);
virtual void visit(const DivNode* op);
virtual void visit(const CastNode* op);
virtual void visit(const CallNode* op);
virtual void visit(const CallIntrinsicNode* op);
virtual void visit(const ReductionNode* op);
virtual void visit(const IndexVarNode* op);

virtual void visit(const AssignmentNode* op);
virtual void visit(const YieldNode* op);
Expand Down
Loading