Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

*: trying out fixing the code generation #449

Draft
wants to merge 2 commits into
base: array_algebra
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/taco/lower/lowerer_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,7 @@ class LowererImpl : public util::Uncopyable {
private:
bool assemble;
bool compute;
std::string funcname;
bool loopOrderAllowsShortCircuit = false;

int markAssignsAtomicDepth = 0;
Expand Down
6 changes: 5 additions & 1 deletion include/taco/util/scopedset.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <ostream>

#include "taco/error.h"
#include "taco/util/strings.h"

namespace taco {
namespace util {
Expand Down Expand Up @@ -56,7 +57,10 @@ class ScopedSet {
}

friend std::ostream& operator<<(std::ostream& os, ScopedSet<Key> sset) {
os << "ScopedSet:" << std::endl;
os << "ScopedSet: " << std::endl;
for (auto& s : sset.scopes) {
os << "scope: " << util::join(s) << std::endl;
}
return os;
}

Expand Down
109 changes: 95 additions & 14 deletions src/lower/lowerer_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ LowererImpl::lower(IndexStmt stmt, string name,
{
this->assemble = assemble;
this->compute = compute;
this->funcname = name;
definedIndexVarsOrdered = {};
definedIndexVars = {};
loopOrderAllowsShortCircuit = allForFreeLoopsBeforeAllReductionLoops(stmt);
Expand Down Expand Up @@ -568,6 +569,9 @@ Stmt LowererImpl::lowerForall(Forall forall)
}

MergeLattice caseLattice = MergeLattice::make(forall, iterators, provGraph, definedIndexVars, whereTempsToResult);
// std::cout << "case lattice: " << forall.getIndexVar() << " " << caseLattice << std::endl;
// std::cout << "merge lattice: " << forall.getIndexVar() << " " << caseLattice.getLoopLattice() << std::endl;

vector<Access> resultAccesses;
set<Access> reducedAccesses;
std::tie(resultAccesses, reducedAccesses) = getResultAccesses(forall);
Expand All @@ -586,7 +590,7 @@ Stmt LowererImpl::lowerForall(Forall forall)

Stmt loops;
// Emit a loop that iterates over over a single iterator (optimization)
if (caseLattice.iterators().size() == 1 && caseLattice.iterators()[0].isUnique()) {
if (caseLattice.iterators().size() == 1 && caseLattice.iterators()[0].isUnique() && false) {
MergeLattice loopLattice = caseLattice.getLoopLattice();

MergePoint point = loopLattice.points()[0];
Expand Down Expand Up @@ -664,10 +668,11 @@ Stmt LowererImpl::lowerForall(Forall forall)
loops = lowerMergeLattice(caseLattice, underivedAncestors[0],
forall.getStmt(), reducedAccesses);
}
// taco_iassert(loops.defined());
taco_iassert(loops.defined());

// std::cout << "LOOPS " << loops << std::endl;
if (!generateComputeCode() && !hasStores(loops)) {
// If assembly loop does not modify output arrays, then it can be safely
// If assembly loop does not modify output arrays, then it can be safely
// omitted.
loops = Stmt();
}
Expand Down Expand Up @@ -1386,11 +1391,14 @@ Stmt LowererImpl::lowerMergeLattice(MergeLattice caseLattice, IndexVar coordinat
bool resolvedCoordDeclared = !modeIteratorsNonMergers.empty();

vector<Stmt> mergeLoopsVec;
std::cout << "Lattice: " << caseLattice << std::endl;
for (MergePoint point : loopLattice.points()) {
// Each iteration of this loop generates a while loop for one of the merge
// points in the merge lattice.
IndexStmt zeroedStmt = zero(statement, getExhaustedAccesses(point, caseLattice));
std::cout << "Var: " << coordinateVar << " Merge Point: " << point << " Statement: " << statement << " Zeroed: " << zeroedStmt << std::endl;
MergeLattice sublattice = caseLattice.subLattice(point);
// std::cout << "sublattice: " << sublattice << std::endl;
Stmt mergeLoop = lowerMergePoint(sublattice, coordinate, coordinateVar, zeroedStmt, reducedAccesses, resolvedCoordDeclared);
mergeLoopsVec.push_back(mergeLoop);
}
Expand Down Expand Up @@ -1565,21 +1573,29 @@ Stmt LowererImpl::lowerMergeCases(ir::Expr coordinate, IndexVar coordinateVar, I
{
vector<Stmt> result;

if (caseLattice.anyModeIteratorIsLeaf() && caseLattice.needExplicitZeroChecks()) {
// Can check value array of some tensor
Stmt body = lowerMergeCasesWithExplicitZeroChecks(coordinate, coordinateVar, stmt, caseLattice, reducedAccesses);
result.push_back(body);
return Block::make(result);
}
// if (caseLattice.anyModeIteratorIsLeaf() && caseLattice.needExplicitZeroChecks()) {
// // Can check value array of some tensor
// Stmt body = lowerMergeCasesWithExplicitZeroChecks(coordinate, coordinateVar, stmt, caseLattice, reducedAccesses);
// result.push_back(body);
// return Block::make(result);
// }

// Emitting structural cases so unconditionally apply lattice optimizations.
MergeLattice loopLattice = caseLattice.getLoopLattice();
// MergeLattice loopLattice = caseLattice.getLoopLattice();
MergeLattice loopLattice = caseLattice;

// std::cout << "LoopLattice: " << loopLattice << " CaseLattice: " << caseLattice << std::endl;
std::cout << " CaseLattice: " << caseLattice << std::endl;

vector<Iterator> appenders;
vector<Iterator> inserters;
tie(appenders, inserters) = splitAppenderAndInserters(loopLattice.results());

if (loopLattice.iterators().size() == 1) {
auto skip = true;
// auto skip = !(this->funcname == "compute");
// std::cout << "skip value " << stmt << " " << skip << std::endl;

if (loopLattice.iterators().size() == 1 && !loopLattice.points()[0].isOmitter() && skip) {
// Just one iterator so no conditional
taco_iassert(!loopLattice.points()[0].isOmitter());
Stmt body = lowerForallBody(coordinate, stmt, {}, inserters,
Expand All @@ -1588,21 +1604,83 @@ Stmt LowererImpl::lowerMergeCases(ir::Expr coordinate, IndexVar coordinateVar, I
}
else if (!loopLattice.points().empty()) {
vector<pair<Expr,Stmt>> cases;
// std::cout << "Accessible iterators: " << this->accessibleIterators << std::endl;
for (MergePoint point : loopLattice.points()) {
// std::cout << "In the loop lattice lowering phase " << point << std::endl;

struct ReadyTensors : IndexNotationVisitor {
std::set<TensorVar> readyTensors;
std::set<IndexVar> definedIndexVars;

void visit(const AccessNode* node) {
bool ready = true;
for (auto& ivar : node->indexVars) {
if (!util::contains(definedIndexVars, ivar)) {
ready = false;
break;
}
}
if (ready) {
readyTensors.insert(node->tensorVar);
}
}
};

// if(point.isOmitter()) {
// continue;
// }
auto rt = ReadyTensors();
rt.definedIndexVars = this->definedIndexVars;
stmt.accept(&rt);

std::cout << "readyTensors: " << util::join(rt.readyTensors) << std::endl;

auto skipPoint = false;
for (auto& rl : loopLattice.points().front().locators()) {

auto rlit = rl.getTensor();
TensorVar rltv;
for (auto& kv : this->tensorVars) {
if (kv.second == rlit) {
rltv = kv.first;
break;
}
}

if (util::contains(rt.readyTensors, rltv)) {
std::cout << "Not considering tensorvar: " << rltv << std::endl;
continue;
}

if (!util::contains(point.locators(), rl)) {
std::cout << "skipping point: " << point << std::endl;
skipPoint = true;
}
}

if (skipPoint) continue;

if(point.isOmitter() && hasNoForAlls(stmt)) {
// std::cout << "omitting point: " << point << std::endl;
continue;
}

// Construct case expression
vector<Expr> coordComparisons = compareToResolvedCoordinate<Eq>(point.rangers(), coordinate, coordinateVar);
vector<Iterator> omittedRegionIterators = loopLattice.retrieveRegionIteratorsToOmit(point);
if (!point.isOmitter()) {
// omittedRegionIterators = filter(omittedRegionIterators, [](const Iterator& it) {
// auto iterTensorVar = it.getTensor();
//
//// auto tensor
//// this->tensorVars
//// this->tens
// return false;
// });
std::vector <Expr> neqComparisons = compareToResolvedCoordinate<Neq>(omittedRegionIterators, coordinate,
coordinateVar);
append(coordComparisons, neqComparisons);
}

// std::cout << util::join(coordComparisons) << std::endl;

coordComparisons = filter(coordComparisons, [](const Expr& e) { return e.defined(); });

// Construct case body
Expand Down Expand Up @@ -1683,13 +1761,16 @@ std::vector<ir::Stmt> LowererImpl::constructInnerLoopCasePreamble(ir::Expr coord
continue;
}
Expr caseName = Var::make(itAccesses[i].getTensorVar().getName() + "_isNonZero", taco::Bool);
taco_iassert(nonZeroCase.defined());
Stmt declaration = VarDecl::make(caseName, nonZeroCase);
result.push_back(declaration);
iteratorToConditionMap[tensorIterators[i]] = caseName;
}

for(size_t i = modeItersWithIndexCases.size(); i < valueComparisons.size(); ++i) {
if (!valueComparisons[i].defined()) continue;
Expr caseName = Var::make(itAccesses[i].getTensorVar().getName() + "_isNonZero", taco::Bool);
taco_iassert(valueComparisons[i].defined());
Stmt declaration = VarDecl::make(caseName, valueComparisons[i]);
result.push_back(declaration);
iteratorToConditionMap[tensorIterators[i]] = caseName;
Expand Down
9 changes: 9 additions & 0 deletions test/op_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ struct IntersectGenDeMorgan {

struct xorGen {
IterationAlgebra operator()(const std::vector<IndexExpr>& regions) {
// return Intersect(regions[0], regions[1]);
IterationAlgebra noIntersect = Complement(Intersect(regions[0], regions[1]));
return Intersect(noIntersect, Union(regions[0], regions[1]));
}
Expand Down Expand Up @@ -124,6 +125,14 @@ struct identityFunc {

struct GeneralAdd {
ir::Expr operator()(const std::vector<ir::Expr> &v) {
// return ir::Literal::make(int(v.size()));

// if (!v.size()) {
// return 0;
// }

// return 1;

taco_iassert(v.size() >= 2) << "Add operator needs at least two operands";
ir::Expr add = ir::Add::make(v[0], v[1]);

Expand Down
38 changes: 38 additions & 0 deletions test/tests-lower.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1853,4 +1853,42 @@ TEST_STMT(XorTestOrder2,
}
)

TEST(weija, weija) {
auto dim = 2;
// Tensor<int> A("A", {dim, dim}, {Dense, Sparse});
Tensor<int> A("A", {dim, dim}, CSC);

// Tensor<int> A("A", {dim, dim}, {Dense, Sparse});
Tensor<int> Z("Z", {dim}, {Sparse});
Tensor<int> B("B", {dim, dim}, {Dense, Dense});
// Tensor<int> B("B", {dim, dim}, {Dense, Sparse});
Tensor<int> C("C", {dim, dim}, {Dense, Dense});
Tensor<int> D("D", {dim, dim}, {Dense, Dense});
Tensor<int> E("E", {dim, dim}, {Dense, Dense});

A.insert({0, 0}, 1); A.insert({1, 1}, 1); A.pack();
B.insert({0, 0}, 1); B.insert({0, 1}, 1); B.pack();
Z.insert({0}, 1); Z.insert({1}, 1); Z.pack();

IndexVar i("i"), j("j"), k("k"), l("l"), m("m");
C(i, j) = xorOp(A(i, j), Z(j));
// C(i, j) = A(i, j) * Z(i);
auto stmt =C.getAssignment().concretize().reorder({j, i});
std::cout << stmt << std::endl;
C.compile(stmt);
// C(i, j) = xorOp(A(i, k), B(k, j));
// E(i, l) = xorOp(xorOp(A(i, k), B(k, j)), D(j, l));
// C(i, j) = xorOp(A(i, j), B(i, j));
// C(i, j) = A(i, j) * B(j, i);
// C(i, j) = xorOp(A(i, j), B(j, i));
// C.compile(C.getAssignment().concretize().reorder({i, j, k}));
// std::cout << "starting" << std::endl;
// C.compile();
std::cout << C.getSource() << std::endl;
C.evaluate();
// E.compile();
// std::cout << E.getSource() << std::endl;
std::cout << C << std::endl;
}

}}