Skip to content

Commit

Permalink
Merge pull request #1 from rohany/windowing-array-algebra-fix
Browse files Browse the repository at this point in the history
lower: fix a bug introduced by merging windowing and array algebra
  • Loading branch information
rawnhenry authored Feb 18, 2021
2 parents a76408a + 4e27678 commit b1f7f88
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 30 deletions.
5 changes: 5 additions & 0 deletions include/taco/index_notation/index_notation.h
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,11 @@ class Access : public IndexExpr {
Assignment operator+=(const IndexExpr&);

typedef AccessNode Node;

// Equality and comparison are overridden on Access to perform a deep
// comparison of the access rather than a pointer check.
friend bool operator==(const Access& a, const Access& b);
friend bool operator<(const Access& a, const Access &b);
};


Expand Down
6 changes: 6 additions & 0 deletions include/taco/index_notation/index_notation_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ struct AccessWindow {
friend bool operator==(const AccessWindow& a, const AccessWindow& b) {
return a.lo == b.lo && a.hi == b.hi;
}
friend bool operator<(const AccessWindow& a, const AccessWindow& b) {
if (a.lo != b.lo) {
return a.lo < b.lo;
}
return a.hi < b.hi;
}
};

struct AccessNode : public IndexExprNode {
Expand Down
32 changes: 32 additions & 0 deletions src/index_notation/index_notation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -999,6 +999,38 @@ int Access::getWindowUpperBound(int mode) const {
return getNode(*this)->windowedModes.at(mode).hi;
}

bool operator==(const Access& a, const Access& b) {
// Short-circuit for when the Access pointers are the same.
if (getNode(a) == getNode(b)) {
return true;
}
if (a.getTensorVar() != b.getTensorVar()) {
return false;
}
if (a.getIndexVars() != b.getIndexVars()) {
return false;
}
if (getNode(a)->windowedModes != getNode(b)->windowedModes) {
return false;
}
return true;
}

bool operator<(const Access& a, const Access& b) {
// First branch on tensorVar.
if(a.getTensorVar() != b.getTensorVar()) {
return a.getTensorVar() < b.getTensorVar();
}

// Then branch on the indexVars used in the access.
if (a.getIndexVars() != b.getIndexVars()) {
return a.getIndexVars() < b.getIndexVars();
}

// Lastly, branch on the windows.
return getNode(a)->windowedModes < getNode(b)->windowedModes;
}

static void check(Assignment assignment) {
auto lhs = assignment.getLhs();
auto tensorVar = lhs.getTensorVar();
Expand Down
35 changes: 5 additions & 30 deletions src/lower/mode_access.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,43 +13,18 @@ size_t ModeAccess::getModePos() const {
return mode;
}

static bool accessEqual(const Access& a, const Access& b) {
return a == b ||
(a.getTensorVar() == b.getTensorVar() && a.getIndexVars() == b.getIndexVars());
}

bool operator==(const ModeAccess& a, const ModeAccess& b) {
return accessEqual(a.getAccess(), b.getAccess()) && a.getModePos() == b.getModePos();
return a.getAccess() == b.getAccess() && a.getModePos() == b.getModePos();
}

bool operator<(const ModeAccess& a, const ModeAccess& b) {

// fast path for when access pointers are equal
if(a.getAccess() == b.getAccess()) {
// First break on the mode position.
if (a.getModePos() != b.getModePos()) {
return a.getModePos() < b.getModePos();
}

// First break on tensorVars
if(a.getAccess().getTensorVar() != b.getAccess().getTensorVar()) {
return a.getAccess().getTensorVar() < b.getAccess().getTensorVar();
}

// Then break on the indexVars used in the access
std::vector<IndexVar> aVars = a.getAccess().getIndexVars();
std::vector<IndexVar> bVars = b.getAccess().getIndexVars();

if(aVars.size() != bVars.size()) {
return aVars.size() < bVars.size();
}

for(size_t i = 0; i < aVars.size(); ++i) {
if(aVars[i] != bVars[i]) {
return aVars[i] < bVars[i];
}
}

// Finally, break on the mode position
return a.getModePos() < b.getModePos();
// Then, return a deep comparison of the underlying access.
return a.getAccess() <b.getAccess();
}

std::ostream &operator<<(std::ostream &os, const ModeAccess & modeAccess) {
Expand Down

0 comments on commit b1f7f88

Please sign in to comment.