From b9ef6d9a52a977a8d04580b3d4d4e84b4ca0b55d Mon Sep 17 00:00:00 2001 From: Kyle Cripps Date: Fri, 12 Jul 2024 10:02:13 -0700 Subject: [PATCH 1/5] Map written LocationSets to program locations (loc_t) instead of IR::Expression*s Signed-off-by: Kyle Cripps --- frontends/p4/def_use.cpp | 45 +++++++++-- frontends/p4/def_use.h | 169 ++++++++++++++++++++++++--------------- 2 files changed, 142 insertions(+), 72 deletions(-) diff --git a/frontends/p4/def_use.cpp b/frontends/p4/def_use.cpp index f4d1842c182..06b4e52b59b 100644 --- a/frontends/p4/def_use.cpp +++ b/frontends/p4/def_use.cpp @@ -638,8 +638,10 @@ bool ComputeWriteSet::preorder(const IR::SelectExpression *expression) { visit(expression->select); visit(&expression->selectCases); auto l = getWrites(expression->select); - for (auto c : expression->selectCases) { - auto s = getWrites(c->keyset); + const loc_t *selectCasesLoc = getLoc(&expression->selectCases, getChildContext()); + for (auto *c : expression->selectCases) { + const loc_t *selectCaseLoc = getLoc(c, selectCasesLoc); + auto s = getWrites(c->keyset, selectCaseLoc); l = l->join(s); } expressionWrites(expression, l); @@ -673,7 +675,8 @@ bool ComputeWriteSet::preorder(const IR::MethodCallExpression *expression) { lhs = save; auto mi = MethodInstance::resolve(expression, storageMap->refMap, storageMap->typeMap); if (auto bim = mi->to()) { - auto base = getWrites(bim->appliedTo); + const loc_t *methodLoc = getLoc(expression->method, getChildContext()); + auto base = getWrites(bim->appliedTo, methodLoc); cstring name = bim->name.name; if (name == IR::Type_Header::setInvalid || name == IR::Type_Header::setValid) { // modifies only the valid field. @@ -712,7 +715,7 @@ bool ComputeWriteSet::preorder(const IR::MethodCallExpression *expression) { LOG3("Analyzing callees of " << expression << DBPrint::Brief << callees << DBPrint::Reset << indent); ProgramPoint pt(callingContext, expression); - ComputeWriteSet cw(this, pt, currentDefinitions); + ComputeWriteSet cw(this, pt, currentDefinitions, cached_locs); cw.setCalledBy(this); for (auto c : callees) (void)c->getNode()->apply(cw); currentDefinitions = cw.currentDefinitions; @@ -735,7 +738,8 @@ bool ComputeWriteSet::preorder(const IR::MethodCallExpression *expression) { visit(arg); lhs = save; if (p->direction == IR::Direction::Out || p->direction == IR::Direction::InOut) { - auto val = getWrites(arg->expression); + const loc_t *argLoc = getLoc(arg, getChildContext()); + auto val = getWrites(arg->expression, argLoc); result = result->join(val); } } @@ -759,6 +763,35 @@ void ComputeWriteSet::visitVirtualMethods(const IR::IndexedVectornode, getLoc(ctxt->parent)}; + return &*cached_locs.insert(tmp).first; +} + +// Returns program location of a child node n, given the context of the +// currently being visited node. +// Use to get loc if n is direct child of currently being visited node. +const P4::ComputeWriteSet::loc_t *ComputeWriteSet::getLoc(const IR::Node *n, + const Visitor::Context *ctxt) { + for (auto *p = ctxt; p; p = p->parent) + if (p->node == n) return getLoc(p); + auto rv = getLoc(ctxt); + loc_t tmp{n, rv}; + return &*cached_locs.insert(tmp).first; +} + // Symbolic execution of the parser bool ComputeWriteSet::preorder(const IR::P4Parser *parser) { LOG3("CWS Visiting " << dbp(parser)); @@ -784,7 +817,7 @@ bool ComputeWriteSet::preorder(const IR::P4Parser *parser) { // but we use the same data structures ProgramPoint pt(state); currentDefinitions = allDefinitions->getDefinitions(pt); - ComputeWriteSet cws(this, pt, currentDefinitions); + ComputeWriteSet cws(this, pt, currentDefinitions, cached_locs); cws.setCalledBy(this); (void)state->apply(cws); diff --git a/frontends/p4/def_use.h b/frontends/p4/def_use.h index 4cc8056ffa8..fb80923a932 100644 --- a/frontends/p4/def_use.h +++ b/frontends/p4/def_use.h @@ -24,6 +24,7 @@ limitations under the License. #include "lib/alloc_trace.h" #include "lib/hash.h" #include "lib/hvec_map.h" +#include "lib/ordered_map.h" #include "lib/ordered_set.h" #include "typeMap.h" @@ -476,6 +477,78 @@ class AllDefinitions : public IHasDbPrint { */ class ComputeWriteSet : public Inspector, public IHasDbPrint { + public: + // A location in the program. Includes the context from the visitor, which needs to + // be copied out of the Visitor::Context objects, as they are allocated on the stack and + // will become invalid as the IR traversal continues + struct loc_t { + const IR::Node *node; + const loc_t *parent; + bool operator<(const loc_t &a) const { + if (node != a.node) return node->id < a.node->id; + if (!parent || !a.parent) return parent != nullptr; + return *parent < *a.parent; + } + }; + + explicit ComputeWriteSet(AllDefinitions *allDefinitions) + : allDefinitions(allDefinitions), + currentDefinitions(nullptr), + returnedDefinitions(nullptr), + exitDefinitions(new Definitions()), + storageMap(allDefinitions->storageMap), + lhs(false), + virtualMethod(false), + cached_locs(*new std::set) { + CHECK_NULL(allDefinitions); + visitDagOnce = false; + } + + // expressions + bool preorder(const IR::Literal *expression) override; + bool preorder(const IR::Slice *expression) override; + bool preorder(const IR::TypeNameExpression *expression) override; + bool preorder(const IR::PathExpression *expression) override; + bool preorder(const IR::Member *expression) override; + bool preorder(const IR::ArrayIndex *expression) override; + bool preorder(const IR::Operation_Binary *expression) override; + bool preorder(const IR::Mux *expression) override; + bool preorder(const IR::SelectExpression *expression) override; + bool preorder(const IR::ListExpression *expression) override; + bool preorder(const IR::Operation_Unary *expression) override; + bool preorder(const IR::MethodCallExpression *expression) override; + bool preorder(const IR::DefaultExpression *expression) override; + bool preorder(const IR::Expression *expression) override; + bool preorder(const IR::InvalidHeader *expression) override; + bool preorder(const IR::InvalidHeaderUnion *expression) override; + bool preorder(const IR::P4ListExpression *expression) override; + bool preorder(const IR::HeaderStackExpression *expression) override; + bool preorder(const IR::StructExpression *expression) override; + // statements + bool preorder(const IR::P4Parser *parser) override; + bool preorder(const IR::P4Control *control) override; + bool preorder(const IR::P4Action *action) override; + bool preorder(const IR::P4Table *table) override; + bool preorder(const IR::Function *function) override; + bool preorder(const IR::AssignmentStatement *statement) override; + bool preorder(const IR::ReturnStatement *statement) override; + bool preorder(const IR::ExitStatement *statement) override; + bool preorder(const IR::BreakStatement *statement) override; + bool handleJump(const char *tok, Definitions *&defs); + bool preorder(const IR::ContinueStatement *statement) override; + bool preorder(const IR::IfStatement *statement) override; + bool preorder(const IR::ForStatement *statement) override; + bool preorder(const IR::ForInStatement *statement) override; + bool preorder(const IR::BlockStatement *statement) override; + bool preorder(const IR::SwitchStatement *statement) override; + bool preorder(const IR::EmptyStatement *statement) override; + bool preorder(const IR::MethodCallStatement *statement) override; + + const LocationSet *writtenLocations(const IR::Expression *expression) { + expression->apply(*this); + return getWrites(expression); + } + protected: AllDefinitions *allDefinitions; /// Result computed by this pass. Definitions *currentDefinitions; /// Before statement currently processed. @@ -487,8 +560,8 @@ class ComputeWriteSet : public Inspector, public IHasDbPrint { const StorageMap *storageMap; /// if true we are processing an expression on the lhs of an assignment bool lhs; - /// For each expression the location set it writes - hvec_map writes; + /// For each program location the location set it writes + ordered_map writes; bool virtualMethod; /// True if we are analyzing a virtual method AllocTrace memuse; alloc_trace_cb_t nested_trace; @@ -496,7 +569,8 @@ class ComputeWriteSet : public Inspector, public IHasDbPrint { /// Creates new visitor, but with same underlying data structures. /// Needed to visit some program fragments repeatedly. - ComputeWriteSet(const ComputeWriteSet *source, ProgramPoint context, Definitions *definitions) + ComputeWriteSet(const ComputeWriteSet *source, ProgramPoint context, Definitions *definitions, + std::set &cached_locs) : allDefinitions(source->allDefinitions), currentDefinitions(definitions), returnedDefinitions(nullptr), @@ -506,10 +580,14 @@ class ComputeWriteSet : public Inspector, public IHasDbPrint { callingContext(context), storageMap(source->storageMap), lhs(false), - virtualMethod(false) { + virtualMethod(false), + cached_locs(cached_locs) { visitDagOnce = false; } void visitVirtualMethods(const IR::IndexedVector &locals); + const loc_t *getLoc(const IR::Node *n, const loc_t *parentLoc); + const loc_t *getLoc(const Visitor::Context *ctxt); + const loc_t *getLoc(const IR::Node *n, const Visitor::Context *ctxt); void enterScope(const IR::ParameterList *parameters, const IR::IndexedVector *locals, ProgramPoint startPoint, bool clear = true); @@ -518,25 +596,39 @@ class ComputeWriteSet : public Inspector, public IHasDbPrint { Definitions *getDefinitionsAfter(const IR::ParserState *state); bool setDefinitions(Definitions *defs, const IR::Node *who = nullptr, bool overwrite = false); ProgramPoint getProgramPoint(const IR::Node *node = nullptr) const; - const LocationSet *getWrites(const IR::Expression *expression) const { - auto result = ::get(writes, expression); + // Get writes of a node that is a direct child of the currently being visited node. + const LocationSet *getWrites(const IR::Expression *expression) { + const loc_t &exprLoc = *getLoc(expression, getChildContext()); + auto result = ::get(writes, exprLoc); BUG_CHECK(result != nullptr, "No location set known for %1%", expression); return result; } + // Get writes of a node that is not a direct child of the currently being visited node. + // In this case, parentLoc is the loc of expression's direct parent node. + const LocationSet *getWrites(const IR::Expression *expression, const loc_t *parentLoc) { + const loc_t &exprLoc = *getLoc(expression, parentLoc); + auto result = ::get(writes, exprLoc); + BUG_CHECK(result != nullptr, "No location set known for %1%", expression); + return result; + } + // Register writes of expression, which is expected to be the currently visited node. void expressionWrites(const IR::Expression *expression, const LocationSet *loc) { CHECK_NULL(expression); CHECK_NULL(loc); LOG3(expression << dbp(expression) << " writes " << loc); - if (auto it = writes.find(expression); it != writes.end()) { + const Context *ctx = getChildContext(); + BUG_CHECK(ctx->node == expression, "Expected ctx->node == expression."); + const loc_t &exprLoc = *getLoc(ctx); + if (auto it = writes.find(exprLoc); it != writes.end()) { BUG_CHECK(*it->second == *loc || expression->is(), "Expression %1% write set already set", expression); } else { - writes.emplace(expression, loc); + writes.emplace(exprLoc, loc); } } void dbprint(std::ostream &out) const override { if (writes.empty()) out << "No writes"; - for (auto &it : writes) out << it.first << " writes " << it.second << Log::endl; + for (auto &it : writes) out << it.first.node << " writes " << it.second << Log::endl; } profile_t init_apply(const IR::Node *root) override { auto rv = Inspector::init_apply(root); @@ -555,63 +647,8 @@ class ComputeWriteSet : public Inspector, public IHasDbPrint { } } - public: - explicit ComputeWriteSet(AllDefinitions *allDefinitions) - : allDefinitions(allDefinitions), - currentDefinitions(nullptr), - returnedDefinitions(nullptr), - exitDefinitions(new Definitions()), - storageMap(allDefinitions->storageMap), - lhs(false), - virtualMethod(false) { - CHECK_NULL(allDefinitions); - visitDagOnce = false; - } - - // expressions - bool preorder(const IR::Literal *expression) override; - bool preorder(const IR::Slice *expression) override; - bool preorder(const IR::TypeNameExpression *expression) override; - bool preorder(const IR::PathExpression *expression) override; - bool preorder(const IR::Member *expression) override; - bool preorder(const IR::ArrayIndex *expression) override; - bool preorder(const IR::Operation_Binary *expression) override; - bool preorder(const IR::Mux *expression) override; - bool preorder(const IR::SelectExpression *expression) override; - bool preorder(const IR::ListExpression *expression) override; - bool preorder(const IR::Operation_Unary *expression) override; - bool preorder(const IR::MethodCallExpression *expression) override; - bool preorder(const IR::DefaultExpression *expression) override; - bool preorder(const IR::Expression *expression) override; - bool preorder(const IR::InvalidHeader *expression) override; - bool preorder(const IR::InvalidHeaderUnion *expression) override; - bool preorder(const IR::P4ListExpression *expression) override; - bool preorder(const IR::HeaderStackExpression *expression) override; - bool preorder(const IR::StructExpression *expression) override; - // statements - bool preorder(const IR::P4Parser *parser) override; - bool preorder(const IR::P4Control *control) override; - bool preorder(const IR::P4Action *action) override; - bool preorder(const IR::P4Table *table) override; - bool preorder(const IR::Function *function) override; - bool preorder(const IR::AssignmentStatement *statement) override; - bool preorder(const IR::ReturnStatement *statement) override; - bool preorder(const IR::ExitStatement *statement) override; - bool preorder(const IR::BreakStatement *statement) override; - bool handleJump(const char *tok, Definitions *&defs); - bool preorder(const IR::ContinueStatement *statement) override; - bool preorder(const IR::IfStatement *statement) override; - bool preorder(const IR::ForStatement *statement) override; - bool preorder(const IR::ForInStatement *statement) override; - bool preorder(const IR::BlockStatement *statement) override; - bool preorder(const IR::SwitchStatement *statement) override; - bool preorder(const IR::EmptyStatement *statement) override; - bool preorder(const IR::MethodCallStatement *statement) override; - - const LocationSet *writtenLocations(const IR::Expression *expression) { - expression->apply(*this); - return getWrites(expression); - } + private: + std::set &cached_locs; }; } // namespace P4 From 655b9c2a292f47dcf7b236e01a5c6f0dd3f84423 Mon Sep 17 00:00:00 2001 From: Kyle Cripps Date: Fri, 12 Jul 2024 14:47:25 -0700 Subject: [PATCH 2/5] Use hvec_map instead of ordered_map for writes and use std::unordered_set instead of std::set for cached_locs Signed-off-by: Kyle Cripps --- frontends/p4/def_use.cpp | 14 +++++++----- frontends/p4/def_use.h | 46 ++++++++++++++++++++++++---------------- 2 files changed, 37 insertions(+), 23 deletions(-) diff --git a/frontends/p4/def_use.cpp b/frontends/p4/def_use.cpp index 06b4e52b59b..70b2b203a11 100644 --- a/frontends/p4/def_use.cpp +++ b/frontends/p4/def_use.cpp @@ -763,18 +763,23 @@ void ComputeWriteSet::visitVirtualMethods(const IR::IndexedVectorid); + + return Util::Hash{}(node->id, parent->hash()); +} + // Returns program location of n, given the program location of n's direct parent. // Use to get loc if n is indirect child (e.g. grandchild) of currently being visited node. // In this case parentLoc is the loc of n's direct parent. -const P4::ComputeWriteSet::loc_t *ComputeWriteSet::getLoc(const IR::Node *n, - const loc_t *parentLoc) { +const P4::loc_t *ComputeWriteSet::getLoc(const IR::Node *n, const loc_t *parentLoc) { loc_t tmp{n, parentLoc}; return &*cached_locs.insert(tmp).first; } // Returns program location given the context of the currently being visited node. // Use to get loc of currently being visited node. -const P4::ComputeWriteSet::loc_t *ComputeWriteSet::getLoc(const Visitor::Context *ctxt) { +const P4::loc_t *ComputeWriteSet::getLoc(const Visitor::Context *ctxt) { if (!ctxt) return nullptr; loc_t tmp{ctxt->node, getLoc(ctxt->parent)}; return &*cached_locs.insert(tmp).first; @@ -783,8 +788,7 @@ const P4::ComputeWriteSet::loc_t *ComputeWriteSet::getLoc(const Visitor::Context // Returns program location of a child node n, given the context of the // currently being visited node. // Use to get loc if n is direct child of currently being visited node. -const P4::ComputeWriteSet::loc_t *ComputeWriteSet::getLoc(const IR::Node *n, - const Visitor::Context *ctxt) { +const P4::loc_t *ComputeWriteSet::getLoc(const IR::Node *n, const Visitor::Context *ctxt) { for (auto *p = ctxt; p; p = p->parent) if (p->node == n) return getLoc(p); auto rv = getLoc(ctxt); diff --git a/frontends/p4/def_use.h b/frontends/p4/def_use.h index fb80923a932..0800e091851 100644 --- a/frontends/p4/def_use.h +++ b/frontends/p4/def_use.h @@ -24,15 +24,29 @@ limitations under the License. #include "lib/alloc_trace.h" #include "lib/hash.h" #include "lib/hvec_map.h" -#include "lib/ordered_map.h" #include "lib/ordered_set.h" #include "typeMap.h" namespace P4 { +class ComputeWriteSet; class StorageFactory; class LocationSet; +// A location in the program. Includes the context from the visitor, which needs to +// be copied out of the Visitor::Context objects, as they are allocated on the stack and +// will become invalid as the IR traversal continues +struct loc_t { + const IR::Node *node; + const loc_t *parent; + bool operator==(const loc_t &a) const { + if (node != a.node) return false; + if (parent && a.parent) return *parent == *a.parent; + return parent == a.parent; + } + std::size_t hash() const; +}; + /// Abstraction for something that is has a left value (variable, parameter) class StorageLocation : public IHasDbPrint, public ICastable { static unsigned crtid; @@ -336,6 +350,14 @@ struct hash { typedef std::size_t result_type; result_type operator()(argument_type const &s) const { return s.hash(); } }; + +template <> +struct hash { + typedef P4::loc_t argument_type; + typedef std::size_t result_type; + result_type operator()(argument_type const &loc) const { return loc.hash(); } +}; + } // namespace std namespace Util { @@ -478,19 +500,6 @@ class AllDefinitions : public IHasDbPrint { class ComputeWriteSet : public Inspector, public IHasDbPrint { public: - // A location in the program. Includes the context from the visitor, which needs to - // be copied out of the Visitor::Context objects, as they are allocated on the stack and - // will become invalid as the IR traversal continues - struct loc_t { - const IR::Node *node; - const loc_t *parent; - bool operator<(const loc_t &a) const { - if (node != a.node) return node->id < a.node->id; - if (!parent || !a.parent) return parent != nullptr; - return *parent < *a.parent; - } - }; - explicit ComputeWriteSet(AllDefinitions *allDefinitions) : allDefinitions(allDefinitions), currentDefinitions(nullptr), @@ -499,7 +508,7 @@ class ComputeWriteSet : public Inspector, public IHasDbPrint { storageMap(allDefinitions->storageMap), lhs(false), virtualMethod(false), - cached_locs(*new std::set) { + cached_locs(*new std::unordered_set) { CHECK_NULL(allDefinitions); visitDagOnce = false; } @@ -561,7 +570,7 @@ class ComputeWriteSet : public Inspector, public IHasDbPrint { /// if true we are processing an expression on the lhs of an assignment bool lhs; /// For each program location the location set it writes - ordered_map writes; + hvec_map writes; bool virtualMethod; /// True if we are analyzing a virtual method AllocTrace memuse; alloc_trace_cb_t nested_trace; @@ -570,7 +579,7 @@ class ComputeWriteSet : public Inspector, public IHasDbPrint { /// Creates new visitor, but with same underlying data structures. /// Needed to visit some program fragments repeatedly. ComputeWriteSet(const ComputeWriteSet *source, ProgramPoint context, Definitions *definitions, - std::set &cached_locs) + std::unordered_set &cached_locs) : allDefinitions(source->allDefinitions), currentDefinitions(definitions), returnedDefinitions(nullptr), @@ -648,7 +657,8 @@ class ComputeWriteSet : public Inspector, public IHasDbPrint { } private: - std::set &cached_locs; + // TODO: Make absl::flat_hash_set instead? + std::unordered_set &cached_locs; }; } // namespace P4 From 867382a36e81f90b42903c100dcb06e051736d55 Mon Sep 17 00:00:00 2001 From: Kyle Cripps Date: Fri, 12 Jul 2024 22:21:37 -0700 Subject: [PATCH 3/5] Short-circuit when parents are equal Signed-off-by: Kyle Cripps --- frontends/p4/def_use.h | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/frontends/p4/def_use.h b/frontends/p4/def_use.h index 0800e091851..4343c2a153a 100644 --- a/frontends/p4/def_use.h +++ b/frontends/p4/def_use.h @@ -41,8 +41,9 @@ struct loc_t { const loc_t *parent; bool operator==(const loc_t &a) const { if (node != a.node) return false; - if (parent && a.parent) return *parent == *a.parent; - return parent == a.parent; + if (parent == a.parent) return true; + if (!parent || !a.parent) return false; + return *parent == *a.parent; } std::size_t hash() const; }; From 8e55157d1159a5aa169ec12e1b90ca39b8a85b74 Mon Sep 17 00:00:00 2001 From: Kyle Cripps Date: Wed, 17 Jul 2024 14:22:29 -0700 Subject: [PATCH 4/5] code cleanup Signed-off-by: Kyle Cripps --- frontends/p4/def_use.h | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/frontends/p4/def_use.h b/frontends/p4/def_use.h index 4343c2a153a..fcbf5de1107 100644 --- a/frontends/p4/def_use.h +++ b/frontends/p4/def_use.h @@ -347,16 +347,12 @@ class ProgramPoint : public IHasDbPrint { namespace std { template <> struct hash { - typedef P4::ProgramPoint argument_type; - typedef std::size_t result_type; - result_type operator()(argument_type const &s) const { return s.hash(); } + std::size_t operator()(const P4::ProgramPoint &s) const { return s.hash(); } }; template <> struct hash { - typedef P4::loc_t argument_type; - typedef std::size_t result_type; - result_type operator()(argument_type const &loc) const { return loc.hash(); } + std::size_t operator()(const P4::loc_t &loc) const { return loc.hash(); } }; } // namespace std @@ -658,7 +654,7 @@ class ComputeWriteSet : public Inspector, public IHasDbPrint { } private: - // TODO: Make absl::flat_hash_set instead? + // TODO: Make absl::node_hash_set instead? std::unordered_set &cached_locs; }; From 6858f16cc9157540ee9b3ef3dd0c532b6c6afcef Mon Sep 17 00:00:00 2001 From: Kyle Cripps Date: Wed, 17 Jul 2024 15:43:27 -0700 Subject: [PATCH 5/5] Memoize loc_t hashes Signed-off-by: Kyle Cripps --- frontends/p4/def_use.cpp | 10 +++++++--- frontends/p4/def_use.h | 1 + 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/frontends/p4/def_use.cpp b/frontends/p4/def_use.cpp index 70b2b203a11..2abff7cebe3 100644 --- a/frontends/p4/def_use.cpp +++ b/frontends/p4/def_use.cpp @@ -764,9 +764,13 @@ void ComputeWriteSet::visitVirtualMethods(const IR::IndexedVectorid); - - return Util::Hash{}(node->id, parent->hash()); + if (!computedHash) { + if (!parent) + computedHash = Util::Hash{}(node->id); + else + computedHash = Util::Hash{}(node->id, parent->hash()); + } + return computedHash; } // Returns program location of n, given the program location of n's direct parent. diff --git a/frontends/p4/def_use.h b/frontends/p4/def_use.h index fcbf5de1107..aa46ddfb1d4 100644 --- a/frontends/p4/def_use.h +++ b/frontends/p4/def_use.h @@ -39,6 +39,7 @@ class LocationSet; struct loc_t { const IR::Node *node; const loc_t *parent; + mutable std::size_t computedHash = 0; bool operator==(const loc_t &a) const { if (node != a.node) return false; if (parent == a.parent) return true;