Skip to content

Commit

Permalink
Adds temporary insertion for SpMM
Browse files Browse the repository at this point in the history
Not a permanent solution, as it is based on fragile pattern matching. A better solution is described in a comment for future work.
  • Loading branch information
fredrikbk committed May 29, 2019
1 parent 9c13afc commit b7aecb6
Show file tree
Hide file tree
Showing 3 changed files with 190 additions and 31 deletions.
6 changes: 6 additions & 0 deletions include/taco/index_notation/transformations.h
Original file line number Diff line number Diff line change
Expand Up @@ -140,5 +140,11 @@ IndexStmt parallelizeOuterLoop(IndexStmt stmt);
*/
IndexStmt reorderLoopsTopologically(IndexStmt stmt);

/**
 * Insert where statements with temporaries into the following statement kinds:
 * 1. The result is scattered into but does not support random insert.
 */
IndexStmt insertTemporaries(IndexStmt stmt);

}
#endif
138 changes: 134 additions & 4 deletions src/index_notation/transformations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "taco/index_notation/index_notation_rewriter.h"
#include "taco/index_notation/index_notation_nodes.h"
#include "taco/error/error_messages.h"
#include "taco/util/collections.h"

#include <iostream>
#include <taco/lower/iterator.h>
Expand All @@ -14,23 +15,28 @@ using namespace std;

namespace taco {


// class Transformation
// Wraps a Reorder behind the type-erased Transformation interface.
Transformation::Transformation(Reorder reorder)
: transformation(new Reorder(reorder)) {
}


// Wraps a Precompute behind the type-erased Transformation interface.
Transformation::Transformation(Precompute precompute)
: transformation(new Precompute(precompute)) {
}


// Wraps a Parallelize behind the type-erased Transformation interface.
Transformation::Transformation(Parallelize parallelize)
: transformation(new Parallelize(parallelize)) {
}


// Forwards to the wrapped concrete transformation's apply.
IndexStmt Transformation::apply(IndexStmt stmt, string* reason) const {
return transformation->apply(stmt, reason);
}


std::ostream& operator<<(std::ostream& os, const Transformation& t) {
t.transformation->print(os);
return os;
Expand All @@ -43,19 +49,23 @@ struct Reorder::Content {
IndexVar j;
};


// Stores the two index variables whose loops are to be reordered.
Reorder::Reorder(IndexVar i, IndexVar j) : content(new Content) {
content->i = i;
content->j = j;
}


// Returns the first index variable of the reorder.
IndexVar Reorder::geti() const {
return content->i;
}


// Returns the second index variable of the reorder.
IndexVar Reorder::getj() const {
return content->j;
}


IndexStmt Reorder::apply(IndexStmt stmt, string* reason) const {
INIT_REASON(reason);

Expand Down Expand Up @@ -111,10 +121,12 @@ IndexStmt Reorder::apply(IndexStmt stmt, string* reason) const {
return ReorderRewriter(*this, reason).reorder(stmt);
}


// Prints the transformation in the form reorder(i, j).
void Reorder::print(std::ostream& os) const {
os << "reorder(" << geti() << ", " << getj() << ")";
}


std::ostream& operator<<(std::ostream& os, const Reorder& reorder) {
reorder.print(os);
return os;
Expand All @@ -129,9 +141,11 @@ struct Precompute::Content {
TensorVar workspace;
};


// Constructs an undefined Precompute; defined() will return false.
Precompute::Precompute() : content(nullptr) {
}


Precompute::Precompute(IndexExpr expr, IndexVar i, IndexVar iw,
TensorVar workspace) : content(new Content) {
content->expr = expr;
Expand All @@ -140,22 +154,27 @@ Precompute::Precompute(IndexExpr expr, IndexVar i, IndexVar iw,
content->workspace = workspace;
}


// Returns the sub-expression to be precomputed into the workspace.
IndexExpr Precompute::getExpr() const {
return content->expr;
}


// Returns the index variable of the original expression.
IndexVar Precompute::geti() const {
return content->i;
}


// Returns the index variable used to iterate over the workspace.
IndexVar Precompute::getiw() const {
return content->iw;
}


// Returns the tensor variable that serves as the workspace.
TensorVar Precompute::getWorkspace() const {
return content->workspace;
}


static bool containsExpr(Assignment assignment, IndexExpr expr) {
struct ContainsVisitor : public IndexNotationVisitor {
using IndexNotationVisitor::visit;
Expand Down Expand Up @@ -192,6 +211,7 @@ static bool containsExpr(Assignment assignment, IndexExpr expr) {
return visitor.contains;
}


static Assignment getAssignmentContainingExpr(IndexStmt stmt, IndexExpr expr) {
Assignment assignment;
match(stmt,
Expand All @@ -205,6 +225,7 @@ static Assignment getAssignmentContainingExpr(IndexStmt stmt, IndexExpr expr) {
return assignment;
}


IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const {
INIT_REASON(reason);

Expand Down Expand Up @@ -247,15 +268,18 @@ IndexStmt Precompute::apply(IndexStmt stmt, std::string* reason) const {
return rewriter.rewrite(stmt);
}


// Prints the transformation in the form precompute(expr, i, iw, workspace).
void Precompute::print(std::ostream& os) const {
os << "precompute(" << getExpr() << ", " << geti() << ", "
<< getiw() << ", " << getWorkspace() << ")";
}


// True iff this Precompute was constructed with arguments (content allocated).
bool Precompute::defined() const {
return content != nullptr;
}


std::ostream& operator<<(std::ostream& os, const Precompute& precompute) {
precompute.print(os);
return os;
Expand All @@ -267,17 +291,21 @@ struct Parallelize::Content {
IndexVar i;
};


// Constructs an undefined Parallelize (no index variable set).
Parallelize::Parallelize() : content(nullptr) {
}


// Stores the index variable whose loop is to be parallelized.
Parallelize::Parallelize(IndexVar i) : content(new Content) {
content->i = i;
}


// Returns the index variable to parallelize over.
IndexVar Parallelize::geti() const {
return content->i;
}


IndexStmt Parallelize::apply(IndexStmt stmt, std::string* reason) const {
INIT_REASON(reason);

Expand Down Expand Up @@ -340,10 +368,12 @@ IndexStmt Parallelize::apply(IndexStmt stmt, std::string* reason) const {
return rewritten;
}


// Prints the transformation in the form parallelize(i).
void Parallelize::print(std::ostream& os) const {
os << "parallelize(" << geti() << ")";
}


std::ostream& operator<<(std::ostream& os, const Parallelize& parallelize) {
parallelize.print(os);
return os;
Expand Down Expand Up @@ -391,8 +421,10 @@ static vector<pair<IndexVar, bool>> varOrderFromTensorLevels(set<pair<IndexVar,
return varOrder;
}


// Takes in varOrders from many tensors and creates a map of dependencies between IndexVars
static map<IndexVar, set<IndexVar>> depsFromVarOrders(map<string, vector<pair<IndexVar, bool>>> varOrders) {
static map<IndexVar, set<IndexVar>>
depsFromVarOrders(map<string, vector<pair<IndexVar,bool>>> varOrders) {
map<IndexVar, set<IndexVar>> deps;
for (pair<string, vector<pair<IndexVar, bool>>> varOrderPair : varOrders) {
vector<pair<IndexVar, bool>> varOrder = varOrderPair.second;
Expand All @@ -411,9 +443,10 @@ static map<IndexVar, set<IndexVar>> depsFromVarOrders(map<string, vector<pair<In
return deps;
}

static
vector<IndexVar> topologicallySort(map<IndexVar,set<IndexVar>> tensorDeps,
vector<IndexVar> originalOrder){

static vector<IndexVar>
topologicallySort(map<IndexVar,set<IndexVar>> tensorDeps,
vector<IndexVar> originalOrder){
vector<IndexVar> sortedVars;
unsigned long countVars = originalOrder.size();
while (sortedVars.size() < countVars) {
Expand Down Expand Up @@ -448,6 +481,7 @@ vector<IndexVar> topologicallySort(map<IndexVar,set<IndexVar>> tensorDeps,
return sortedVars;
}


IndexStmt reorderLoopsTopologically(IndexStmt stmt) {
// Collect tensorLevelVars which stores the pairs of IndexVar and tensor
// level that each tensor is accessed at
Expand Down Expand Up @@ -554,4 +588,100 @@ IndexStmt reorderLoopsTopologically(IndexStmt stmt) {
return rewriter.rewrite(stmt);
}

// Returns true iff the two index-variable lists are identical: same length
// and pairwise equal. The previous zip-based version only compared up to the
// length of the shorter list, so an access with extra trailing index
// variables (e.g. {i,j,k} vs the expected {i,j}) incorrectly compared equal.
// std::vector's operator== performs exactly the size + element-wise check.
static bool compare(std::vector<IndexVar> vars1, std::vector<IndexVar> vars2) {
  return vars1 == vars2;
}

// TODO Temporary function to insert workspaces into SpMM kernels
// Syntactically matches the index-notation pattern
//   forall(i, forall(k, forall(j, A(i,j) <op>= B(i,k) * C(k,j))))
// with A, B, C all stored CSR, and rewrites it so each row of the result is
// accumulated into a dense vector workspace before being copied into A.
// Returns `stmt` unchanged if any step of the (deliberately fragile) match
// fails; see the commit message / insertTemporaries for the planned general
// solution.
static IndexStmt optimizeSpMM(IndexStmt stmt) {
// Outermost loop must be a forall over the result's row variable i.
if (!isa<Forall>(stmt)) {
return stmt;
}
Forall foralli = to<Forall>(stmt);
IndexVar i = foralli.getIndexVar();

// Middle loop: the reduction (summation) variable k.
if (!isa<Forall>(foralli.getStmt())) {
return stmt;
}
Forall forallk = to<Forall>(foralli.getStmt());
IndexVar k = forallk.getIndexVar();

// Innermost loop: the result's column variable j.
if (!isa<Forall>(forallk.getStmt())) {
return stmt;
}
Forall forallj = to<Forall>(forallk.getStmt());
IndexVar j = forallj.getIndexVar();

// The loop nest body must be a single assignment ...
if (!isa<Assignment>(forallj.getStmt())) {
return stmt;
}
Assignment assignment = to<Assignment>(forallj.getStmt());

// ... whose right-hand side is a product of two operands.
// NOTE(review): the assignment operator itself (= vs +=) is not inspected
// here — confirm that both forms should be rewritten.
if (!isa<Mul>(assignment.getRhs())) {
return stmt;
}
Mul mul = to<Mul>(assignment.getRhs());

// Both factors must be plain tensor accesses (no nested expressions).
taco_iassert(isa<Access>(assignment.getLhs()));
if (!isa<Access>(mul.getA())) {
return stmt;
}
if (!isa<Access>(mul.getB())) {
return stmt;
}

Access Aaccess = to<Access>(assignment.getLhs());
Access Baccess = to<Access>(mul.getA());
Access Caccess = to<Access>(mul.getB());

// The accesses must index exactly as A(i,j) = B(i,k) * C(k,j).
if (!compare(Aaccess.getIndexVars(), {i,j}) ||
!compare(Baccess.getIndexVars(), {i,k}) ||
!compare(Caccess.getIndexVars(), {k,j})) {
return stmt;
}

TensorVar A = Aaccess.getTensorVar();
TensorVar B = Baccess.getTensorVar();
TensorVar C = Caccess.getTensorVar();

// Only rewrite when all three matrices are CSR: A's compressed column mode
// is what makes direct scattering into the result expensive.
if (A.getFormat() != CSR ||
B.getFormat() != CSR ||
C.getFormat() != CSR) {
return stmt;
}

// It's an SpMM statement so return an optimized SpMM statement
// Producer side accumulates row i into the dense workspace w; consumer side
// then copies w into row i of A.
TensorVar w("w", Type(Float64, {Dimension()}), dense);
return forall(i,
where(forall(j,
A(i,j) = w(j)),
forall(k,
forall(j,
w(j) += B(i,k) * C(k,j)))));
}

/// Insert where statements with workspace temporaries into statements whose
/// result is scattered into but does not support random insert. Statements
/// that need no workspace (or are not yet handled) are returned unchanged.
IndexStmt insertTemporaries(IndexStmt stmt)
{
  // Special case: pattern-matched SpMM rewrite. A changed statement means
  // the pattern applied and a workspace has already been inserted.
  const IndexStmt optimized = optimizeSpMM(stmt);
  if (optimized != stmt) {
    return optimized;
  }

  // TODO Implement general workspacing when scattering into sparse result modes

  // Result dimensions that are indexed by free variables dominated by a
  // reduction variable are scattered into. If any of these are compressed
  // then we introduce a dense workspace to scatter into instead. The where
  // statement must push the reduction loop into the producer side, leaving
  // only the free variable loops on the consumer side.

  //vector<IndexVar> reductionVars = getReductionVars(stmt);
  //...

  return stmt;
}

}
Loading

0 comments on commit b7aecb6

Please sign in to comment.