diff --git a/include/taco/index_notation/index_notation.h b/include/taco/index_notation/index_notation.h index 76822bc09..640d87dd1 100644 --- a/include/taco/index_notation/index_notation.h +++ b/include/taco/index_notation/index_notation.h @@ -269,6 +269,11 @@ class Access : public IndexExpr { Assignment operator+=(const IndexExpr&); typedef AccessNode Node; + + // Equality and comparison are overridden on Access to perform a deep + // comparison of the access rather than a pointer check. + friend bool operator==(const Access& a, const Access& b); + friend bool operator<(const Access& a, const Access &b); }; diff --git a/include/taco/index_notation/index_notation_nodes.h b/include/taco/index_notation/index_notation_nodes.h index 3df5a1d32..d3284f0ab 100644 --- a/include/taco/index_notation/index_notation_nodes.h +++ b/include/taco/index_notation/index_notation_nodes.h @@ -28,6 +28,12 @@ struct AccessWindow { friend bool operator==(const AccessWindow& a, const AccessWindow& b) { return a.lo == b.lo && a.hi == b.hi; } + friend bool operator<(const AccessWindow& a, const AccessWindow& b) { + if (a.lo != b.lo) { + return a.lo < b.lo; + } + return a.hi < b.hi; + } }; struct AccessNode : public IndexExprNode { diff --git a/src/index_notation/index_notation.cpp b/src/index_notation/index_notation.cpp index efb330668..6b3fd71b4 100644 --- a/src/index_notation/index_notation.cpp +++ b/src/index_notation/index_notation.cpp @@ -999,6 +999,38 @@ int Access::getWindowUpperBound(int mode) const { return getNode(*this)->windowedModes.at(mode).hi; } +bool operator==(const Access& a, const Access& b) { + // Short-circuit for when the Access pointers are the same. + if (getNode(a) == getNode(b)) { + return true; + } + if (a.getTensorVar() != b.getTensorVar()) { + return false; + } + if (a.getIndexVars() != b.getIndexVars()) { + return false; + } + if (getNode(a)->windowedModes != getNode(b)->windowedModes) { + return false; + } + return true; +} + +bool operator<(const Access& a, const Access& b) { + // First branch on tensorVar. + if(a.getTensorVar() != b.getTensorVar()) { + return a.getTensorVar() < b.getTensorVar(); + } + + // Then branch on the indexVars used in the access. + if (a.getIndexVars() != b.getIndexVars()) { + return a.getIndexVars() < b.getIndexVars(); + } + + // Lastly, branch on the windows. + return getNode(a)->windowedModes < getNode(b)->windowedModes; +} + static void check(Assignment assignment) { auto lhs = assignment.getLhs(); auto tensorVar = lhs.getTensorVar(); diff --git a/src/lower/mode_access.cpp b/src/lower/mode_access.cpp index fcb0d19c4..f682e915e 100644 --- a/src/lower/mode_access.cpp +++ b/src/lower/mode_access.cpp @@ -13,43 +13,18 @@ size_t ModeAccess::getModePos() const { return mode; } -static bool accessEqual(const Access& a, const Access& b) { - return a == b || - (a.getTensorVar() == b.getTensorVar() && a.getIndexVars() == b.getIndexVars()); -} - bool operator==(const ModeAccess& a, const ModeAccess& b) { - return accessEqual(a.getAccess(), b.getAccess()) && a.getModePos() == b.getModePos(); + return a.getAccess() == b.getAccess() && a.getModePos() == b.getModePos(); } bool operator<(const ModeAccess& a, const ModeAccess& b) { - - // fast path for when access pointers are equal - if(a.getAccess() == b.getAccess()) { + // First break on the mode position. + if (a.getModePos() != b.getModePos()) { return a.getModePos() < b.getModePos(); } - // First break on tensorVars - if(a.getAccess().getTensorVar() != b.getAccess().getTensorVar()) { - return a.getAccess().getTensorVar() < b.getAccess().getTensorVar(); - } - - // Then break on the indexVars used in the access - std::vector aVars = a.getAccess().getIndexVars(); - std::vector bVars = b.getAccess().getIndexVars(); - - if(aVars.size() != bVars.size()) { - return aVars.size() < bVars.size(); - } - - for(size_t i = 0; i < aVars.size(); ++i) { - if(aVars[i] != bVars[i]) { - return aVars[i] < bVars[i]; - } - } - - // Finally, break on the mode position - return a.getModePos() < b.getModePos(); + // Then, return a deep comparison of the underlying access. + return a.getAccess()