Skip to content

Commit

Permalink
Merge pull request #412 from RawnH/array_algebra
Browse files Browse the repository at this point in the history
Array algebra
  • Loading branch information
stephenchouca authored Feb 18, 2021
2 parents bfdaa71 + b1f7f88 commit 7d84d5d
Show file tree
Hide file tree
Showing 61 changed files with 5,124 additions and 309 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ if(NOT EXISTS "${TACO_PROJECT_DIR}/python_bindings/pybind11/CMakeLists.txt")
endif()

if(PYTHON)
add_subdirectory(python_bindings)
message("-- Will build Python extension")
add_definitions(-DPYTHON)
endif(PYTHON)
Expand Down
1 change: 1 addition & 0 deletions include/taco.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include "taco/tensor.h"
#include "taco/format.h"
#include "taco/index_notation/tensor_operator.h"
#include "taco/index_notation/index_notation.h"

#endif
88 changes: 78 additions & 10 deletions include/taco/index_notation/index_notation.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
#include <set>
#include <map>
#include <utility>
#include <functional>

#include "taco/util/name_generator.h"
#include "taco/format.h"
#include "taco/error.h"
#include "taco/util/intrusive_ptr.h"
Expand All @@ -22,6 +24,7 @@
#include "taco/ir_tags.h"
#include "taco/lower/iterator.h"
#include "taco/index_notation/provenance_graph.h"
#include "taco/index_notation/properties.h"

namespace taco {

Expand All @@ -38,6 +41,8 @@ class IndexExpr;
class Assignment;
class Access;

class IterationAlgebra;

struct AccessNode;
struct AccessWindow;
struct LiteralNode;
Expand All @@ -48,8 +53,10 @@ struct SubNode;
struct MulNode;
struct DivNode;
struct CastNode;
struct CallNode;
struct CallIntrinsicNode;
struct ReductionNode;
struct IndexVarNode;

struct AssignmentNode;
struct YieldNode;
Expand Down Expand Up @@ -262,6 +269,11 @@ class Access : public IndexExpr {
Assignment operator+=(const IndexExpr&);

typedef AccessNode Node;

// Equality and comparison are overridden on Access to perform a deep
// comparison of the access rather than a pointer check.
friend bool operator==(const Access& a, const Access& b);
friend bool operator<(const Access& a, const Access &b);
};


Expand Down Expand Up @@ -289,11 +301,14 @@ class Literal : public IndexExpr {
Literal(std::complex<float>);
Literal(std::complex<double>);

static IndexExpr zero(Datatype);
static Literal zero(Datatype);

/// Returns the literal value.
template <typename T> T getVal() const;

/// Returns an untyped pointer to the literal value
void* getValPtr();

typedef LiteralNode Node;
};

Expand Down Expand Up @@ -413,6 +428,26 @@ class Cast : public IndexExpr {
typedef CastNode Node;
};

/// A call to an operator
class Call: public IndexExpr {
public:
Call() = default;
Call(const CallNode*);
Call(const CallNode*, std::string name);

const std::vector<IndexExpr>& getArgs() const;
const std::function<ir::Expr(const std::vector<ir::Expr>&)> getFunc() const;
const IterationAlgebra& getAlgebra() const;
const std::vector<Property>& getProperties() const;
const std::string getName() const;
const std::map<std::vector<int>, std::function<ir::Expr(const std::vector<ir::Expr>&)>> getDefs() const;
const std::vector<int>& getDefinedArgs() const;

typedef CallNode Node;

private:
std::string name;
};

/// A call to an intrinsic.
/// ```
Expand All @@ -433,6 +468,8 @@ class CallIntrinsic : public IndexExpr {
typedef CallIntrinsicNode Node;
};

std::ostream& operator<<(std::ostream&, const IndexVar&);

/// Create calls to various intrinsics.
IndexExpr mod(IndexExpr, IndexExpr);
IndexExpr abs(IndexExpr);
Expand Down Expand Up @@ -871,17 +908,27 @@ class WindowedIndexVar : public util::Comparable<WindowedIndexVar>, public Index

/// Index variables are used to index into tensors in index expressions, and
/// they represent iteration over the tensor modes they index into.
class IndexVar : public util::Comparable<IndexVar>, public IndexVarInterface {
class IndexVar : public IndexExpr, public IndexVarInterface {

public:
IndexVar();
~IndexVar() = default;
IndexVar(const std::string& name);
IndexVar(const std::string& name, const Datatype& type);
IndexVar(const IndexVarNode *);

/// Returns the name of the index variable.
std::string getName() const;

// Need these to overshadow the comparisons in for the IndexExpr instrusive pointer
friend bool operator==(const IndexVar&, const IndexVar&);
friend bool operator<(const IndexVar&, const IndexVar&);
friend bool operator!=(const IndexVar&, const IndexVar&);
friend bool operator>=(const IndexVar&, const IndexVar&);
friend bool operator<=(const IndexVar&, const IndexVar&);
friend bool operator>(const IndexVar&, const IndexVar&);

typedef IndexVarNode Node;

/// Indexing into an IndexVar returns a window into it.
WindowedIndexVar operator()(int lo, int hi);
Expand Down Expand Up @@ -927,11 +974,12 @@ SuchThat suchthat(IndexStmt stmt, std::vector<IndexVarRel> predicate);
class TensorVar : public util::Comparable<TensorVar> {
public:
TensorVar();
TensorVar(const Type& type);
TensorVar(const std::string& name, const Type& type);
TensorVar(const Type& type, const Format& format);
TensorVar(const std::string& name, const Type& type, const Format& format);
TensorVar(const int &id, const std::string& name, const Type& type, const Format& format);
TensorVar(const Type& type, const Literal& fill = Literal());
TensorVar(const std::string& name, const Type& type, const Literal& fill = Literal());
TensorVar(const Type& type, const Format& format, const Literal& fill = Literal());
TensorVar(const std::string& name, const Type& type, const Format& format, const Literal& fill = Literal());
TensorVar(const int &id, const std::string& name, const Type& type, const Format& format,
const Literal& fill = Literal());

/// Returns the ID of the tensor variable.
int getId() const;
Expand All @@ -952,6 +1000,12 @@ class TensorVar : public util::Comparable<TensorVar> {
/// and execute it's expression.
const Schedule& getSchedule() const;

/// Gets the fill value of the tensor variable. May be left undefined.
const Literal& getFill() const;

/// Set the fill value of the tensor variable
void setFill(const Literal& fill);

/// Set the name of the tensor variable.
void setName(std::string name);

Expand Down Expand Up @@ -1008,7 +1062,8 @@ bool isEinsumNotation(IndexStmt, std::string* reason=nullptr);
bool isReductionNotation(IndexStmt, std::string* reason=nullptr);

/// Check whether the statement is in the concrete index notation dialect.
/// This means every index variable has a forall node, there are no reduction
/// This means every index variable has a forall node, each index variable used
/// for computation is under a forall node for that variable, there are no reduction
/// nodes, and that every reduction variable use is nested inside a compound
/// assignment statement. You can optionally pass in a pointer to a string
/// that the reason why it is not concrete notation is printed to.
Expand All @@ -1030,7 +1085,12 @@ std::vector<TensorVar> getResults(IndexStmt stmt);
/// Returns the input tensors to the index statement, in the order they appear.
std::vector<TensorVar> getArguments(IndexStmt stmt);

/// Returns the temporaries in the index statement, in the order they appear.
/// Returns true iff all of the loops over free variables come before all of the loops over
/// reduction variables. Therefore, this returns true if the reduction controlled by the loops
/// does not a scatter.
bool allForFreeLoopsBeforeAllReductionLoops(IndexStmt stmt);

/// Returns the temporaries in the index statement, in the order they appear.
std::vector<TensorVar> getTemporaries(IndexStmt stmt);

// [Olivia]
Expand Down Expand Up @@ -1070,7 +1130,15 @@ IndexExpr zero(IndexExpr, const std::set<Access>& zeroed);
/// zero and then propagating and removing zeroes.
IndexStmt zero(IndexStmt, const std::set<Access>& zeroed);

/// Create an `other` tensor with the given name and format,
/// Infers the fill value of the input expression by applying properties if possible. If unable
/// to successfully infer the fill value of the result, returns the empty IndexExpr
IndexExpr inferFill(IndexExpr);

/// Returns true if there are no forall nodes in the indexStmt. Used to check
/// if the last loop is being lowered.
bool hasNoForAlls(IndexStmt);

/// Create an `other` tensor with the given name and format,
/// and return tensor(indexVars) = other(indexVars) if otherIsOnRight,
/// and otherwise returns other(indexVars) = tensor(indexVars).
IndexStmt generatePackStmt(TensorVar tensor,
Expand Down
84 changes: 83 additions & 1 deletion include/taco/index_notation/index_notation_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,18 @@

#include <vector>
#include <memory>
#include <numeric>

#include "taco/type.h"
#include "taco/util/collections.h"
#include "taco/util/comparable.h"
#include "taco/index_notation/index_notation.h"
#include "taco/index_notation/index_notation_nodes_abstract.h"
#include "taco/index_notation/index_notation_visitor.h"
#include "taco/index_notation/intrinsic.h"
#include "taco/util/strings.h"
#include "iteration_algebra.h"
#include "properties.h"

namespace taco {

Expand All @@ -23,6 +28,12 @@ struct AccessWindow {
friend bool operator==(const AccessWindow& a, const AccessWindow& b) {
return a.lo == b.lo && a.hi == b.hi;
}
friend bool operator<(const AccessWindow& a, const AccessWindow& b) {
if (a.lo != b.lo) {
return a.lo < b.lo;
}
return a.hi < b.hi;
}
};

struct AccessNode : public IndexExprNode {
Expand Down Expand Up @@ -68,7 +79,6 @@ struct LiteralNode : public IndexExprNode {
void* val;
};


struct UnaryExprNode : public IndexExprNode {
IndexExpr a;

Expand Down Expand Up @@ -188,6 +198,57 @@ struct CallIntrinsicNode : public IndexExprNode {
std::vector<IndexExpr> args;
};

struct CallNode : public IndexExprNode {
typedef std::function<ir::Expr(const std::vector<ir::Expr>&)> OpImpl;
typedef std::function<IterationAlgebra(const std::vector<IndexExpr>&)> AlgebraImpl;

CallNode(std::string name, const std::vector<IndexExpr>& args, OpImpl lowerFunc,
const IterationAlgebra& iterAlg,
const std::vector<Property>& properties,
const std::map<std::vector<int>, OpImpl>& regionDefinitions,
const std::vector<int>& definedRegions);

CallNode(std::string name, const std::vector<IndexExpr>& args, OpImpl lowerFunc,
const IterationAlgebra& iterAlg,
const std::vector<Property>& properties,
const std::map<std::vector<int>, OpImpl>& regionDefinitions);

void accept(IndexExprVisitorStrict* v) const {
v->visit(this);
}

std::string name;
std::vector<IndexExpr> args;
OpImpl defaultLowerFunc;
IterationAlgebra iterAlg;
std::vector<Property> properties;
std::map<std::vector<int>, OpImpl> regionDefinitions;

// Needed to track which inputs have been exhausted so the lowerer can know which lower func to use
std::vector<int> definedRegions;

private:
static Datatype inferReturnType(OpImpl f, const std::vector<IndexExpr>& inputs) {
std::function<ir::Expr(IndexExpr)> getExprs = [](IndexExpr arg) { return ir::Var::make("t", arg.getDataType()); };
std::vector<ir::Expr> exprs = util::map(inputs, getExprs);

if(exprs.empty()) {
return taco::Datatype();
}

return f(exprs).type();
}

static std::vector<int> definedIndices(std::vector<IndexExpr> args) {
std::vector<int> v;
for(int i = 0; i < (int) args.size(); ++i) {
if(args[i].defined()) {
v.push_back(i);
}
}
return v;
}
};

struct ReductionNode : public IndexExprNode {
ReductionNode(IndexExpr op, IndexVar var, IndexExpr a);
Expand All @@ -202,6 +263,27 @@ struct ReductionNode : public IndexExprNode {
IndexExpr a;
};

struct IndexVarNode : public IndexExprNode, public util::Comparable<IndexVarNode> {
IndexVarNode() = delete;
IndexVarNode(const std::string& name, const Datatype& type);

void accept(IndexExprVisitorStrict* v) const {
v->visit(this);
}

std::string getName() const;

friend bool operator==(const IndexVarNode& a, const IndexVarNode& b);
friend bool operator<(const IndexVarNode& a, const IndexVarNode& b);

private:
struct Content;
std::shared_ptr<Content> content;
};

struct IndexVarNode::Content {
std::string name;
};

// Index Statements
struct AssignmentNode : public IndexStmtNode {
Expand Down
2 changes: 2 additions & 0 deletions include/taco/index_notation/index_notation_printer.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@ class IndexNotationPrinter : public IndexNotationVisitorStrict {
void visit(const MulNode*);
void visit(const DivNode*);
void visit(const CastNode*);
void visit(const CallNode*);
void visit(const CallIntrinsicNode*);
void visit(const ReductionNode*);
void visit(const IndexVarNode*);

// Tensor Expressions
void visit(const AssignmentNode*);
Expand Down
4 changes: 4 additions & 0 deletions include/taco/index_notation/index_notation_rewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,10 @@ class IndexExprRewriterStrict : public IndexExprVisitorStrict {
virtual void visit(const MulNode* op) = 0;
virtual void visit(const DivNode* op) = 0;
virtual void visit(const CastNode* op) = 0;
virtual void visit(const CallNode* op) = 0;
virtual void visit(const CallIntrinsicNode* op) = 0;
virtual void visit(const ReductionNode* op) = 0;
virtual void visit(const IndexVarNode* op) = 0;
};


Expand Down Expand Up @@ -93,8 +95,10 @@ class IndexNotationRewriter : public IndexNotationRewriterStrict {
virtual void visit(const MulNode* op);
virtual void visit(const DivNode* op);
virtual void visit(const CastNode* op);
virtual void visit(const CallNode* op);
virtual void visit(const CallIntrinsicNode* op);
virtual void visit(const ReductionNode* op);
virtual void visit(const IndexVarNode* op);

virtual void visit(const AssignmentNode* op);
virtual void visit(const YieldNode* op);
Expand Down
Loading

0 comments on commit 7d84d5d

Please sign in to comment.