From 6cb483bf1e14ac9de1a28d33cb311b936f73136b Mon Sep 17 00:00:00 2001 From: fruffy Date: Thu, 6 Jun 2024 09:35:20 -0400 Subject: [PATCH] Replace boost::container::flat_map with a custom flat_map implementation in P4Tools. --- backends/p4tools/BUILD.bazel | 1 - backends/p4tools/common/lib/model.cpp | 2 - backends/p4tools/common/lib/model.h | 7 +- backends/p4tools/common/lib/symbolic_env.cpp | 2 - backends/p4tools/common/options.h | 1 - .../modules/testgen/lib/execution_state.cpp | 1 - .../modules/testgen/lib/final_state.cpp | 2 - ir/solver.h | 11 +- lib/flat_map.h | 314 ++++++++++++++++++ 9 files changed, 322 insertions(+), 19 deletions(-) create mode 100644 lib/flat_map.h diff --git a/backends/p4tools/BUILD.bazel b/backends/p4tools/BUILD.bazel index 383e198a1e1..33cbc4b582a 100644 --- a/backends/p4tools/BUILD.bazel +++ b/backends/p4tools/BUILD.bazel @@ -181,7 +181,6 @@ cc_binary( deps = [ ":testgen_lib", "//:lib", - "@boost//:filesystem", "@boost//:multiprecision", ], ) diff --git a/backends/p4tools/common/lib/model.cpp b/backends/p4tools/common/lib/model.cpp index 382be531640..f5cacf24d31 100644 --- a/backends/p4tools/common/lib/model.cpp +++ b/backends/p4tools/common/lib/model.cpp @@ -4,8 +4,6 @@ #include #include -#include - #include "frontends/p4/optimizeExpressions.h" #include "ir/indexed_vector.h" #include "ir/irutils.h" diff --git a/backends/p4tools/common/lib/model.h b/backends/p4tools/common/lib/model.h index ffc63aef3ea..b6fa7c0da0d 100644 --- a/backends/p4tools/common/lib/model.h +++ b/backends/p4tools/common/lib/model.h @@ -5,16 +5,17 @@ #include #include -#include - #include "ir/ir.h" #include "ir/solver.h" #include "ir/visitor.h" +#include "lib/map.h" namespace P4Tools { /// Symbolic maps map a state variable to a IR::Expression. -using SymbolicMapType = boost::container::flat_map; +using SymbolicMapType = + P4C::flat_map, + std::vector>>; /// Represents a solution found by the solver. A model is a concretized form of a symbolic /// environment. All the expressions in a Model must be of type IR::Literal. diff --git a/backends/p4tools/common/lib/symbolic_env.cpp b/backends/p4tools/common/lib/symbolic_env.cpp index e66ca624fc6..26bf02f063c 100644 --- a/backends/p4tools/common/lib/symbolic_env.cpp +++ b/backends/p4tools/common/lib/symbolic_env.cpp @@ -3,8 +3,6 @@ #include #include -#include - #include "backends/p4tools/common/lib/model.h" #include "ir/indexed_vector.h" #include "ir/vector.h" diff --git a/backends/p4tools/common/options.h b/backends/p4tools/common/options.h index a1b145a7433..66664c7ee22 100644 --- a/backends/p4tools/common/options.h +++ b/backends/p4tools/common/options.h @@ -1,7 +1,6 @@ #ifndef BACKENDS_P4TOOLS_COMMON_OPTIONS_H_ #define BACKENDS_P4TOOLS_COMMON_OPTIONS_H_ -// Boost #include #include #include diff --git a/backends/p4tools/modules/testgen/lib/execution_state.cpp b/backends/p4tools/modules/testgen/lib/execution_state.cpp index 6fb6c5db013..925525d3bfe 100644 --- a/backends/p4tools/modules/testgen/lib/execution_state.cpp +++ b/backends/p4tools/modules/testgen/lib/execution_state.cpp @@ -11,7 +11,6 @@ #include #include -#include #include #include "backends/p4tools/common/compiler/convert_hs_index.h" diff --git a/backends/p4tools/modules/testgen/lib/final_state.cpp b/backends/p4tools/modules/testgen/lib/final_state.cpp index a4eb7436ee1..c21069db95a 100644 --- a/backends/p4tools/modules/testgen/lib/final_state.cpp +++ b/backends/p4tools/modules/testgen/lib/final_state.cpp @@ -5,8 +5,6 @@ #include #include -#include - #include "backends/p4tools/common/lib/model.h" #include "backends/p4tools/common/lib/symbolic_env.h" #include "backends/p4tools/common/lib/trace_event.h" diff --git a/ir/solver.h b/ir/solver.h index 459be62fa28..6136104b399 100644 --- a/ir/solver.h +++ b/ir/solver.h @@ -4,12 +4,10 @@ #include #include -#include -#include - #include "ir/ir.h" #include "lib/castable.h" #include "lib/cstring.h" +#include "lib/flat_map.h" /// Represents a constraint that can be shipped to and asserted within a solver. // TODO: This should implement AbstractRepCheckedNode. @@ -23,10 +21,9 @@ struct SymbolicVarComp { }; /// This type maps symbolic variables to their value assigned by the solver. -using SymbolicMapping = boost::container::flat_map; - -using SymbolicSet = boost::container::flat_set; +using SymbolicMapping = + P4C::flat_map>>; /// Provides a higher-level interface for an SMT solver. class AbstractSolver : public ICastable { diff --git a/lib/flat_map.h b/lib/flat_map.h new file mode 100644 index 00000000000..9d0606511ab --- /dev/null +++ b/lib/flat_map.h @@ -0,0 +1,314 @@ +/* +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef LIB_FLAT_MAP_H_ +#define LIB_FLAT_MAP_H_ + +#include +#include +#include + +namespace P4C { + +/// A header-only implementation of a memory-efficient flat_map. +/// TODO: Replace this map with std::flat_map once available in C++23: +/// https://en.cppreference.com/w/cpp/container/flat_map +template , + typename Container = std::vector>> +struct flat_map { + using key_type = K; + using mapped_type = V; + using value_type = typename Container::value_type; + using key_compare = Compare; + + struct value_compare { + bool operator()(const value_type &lhs, const value_type &rhs) const { + return key_compare()(lhs.first, rhs.first); + } + }; + + using allocator_type = typename Container::allocator_type; + using reference = typename Container::reference; + using const_reference = typename Container::const_reference; + using pointer = typename Container::pointer; + using const_pointer = typename Container::const_pointer; + using iterator = typename Container::iterator; + using const_iterator = typename Container::const_iterator; + using reverse_iterator = typename Container::reverse_iterator; + using const_reverse_iterator = typename Container::const_reverse_iterator; + using difference_type = typename Container::difference_type; + using size_type = typename Container::size_type; + + flat_map() = default; + + template + flat_map(It begin, It end) { + insert(begin, end); + } + + flat_map(std::initializer_list il) : flat_map(il.begin(), il.end()) {} + + iterator begin() { return data_.begin(); } + iterator end() { return data_.end(); } + const_iterator begin() const { return data_.begin(); } + const_iterator end() const { return data_.end(); } + const_iterator cbegin() const { return data_.cbegin(); } + const_iterator cend() const { return data_.cend(); } + reverse_iterator rbegin() { return data_.rbegin(); } + reverse_iterator rend() { return data_.rend(); } + const_reverse_iterator rbegin() const { return data_.rbegin(); } + const_reverse_iterator rend() const { return data_.rend(); } + const_reverse_iterator crbegin() const { return data_.crbegin(); } + const_reverse_iterator crend() const { return data_.crend(); } + + bool empty() const { return data_.empty(); } + size_type size() const { return data_.size(); } + size_type max_size() const { return data_.max_size(); } + size_type capacity() const { return data_.capacity(); } + void reserve(size_type size) { data_.reserve(size); } + void shrink_to_fit() { data_.shrink_to_fit(); } + size_type bytes_used() const { return capacity() * sizeof(value_type) + sizeof(data_); } + + mapped_type &operator[](const key_type &key) { + KeyOrValueCompare comp; + auto lower = lower_bound(key); + if (lower == end() || comp(key, *lower)) + return data_.emplace(lower, key, mapped_type())->second; + + return lower->second; + } + + mapped_type &operator[](key_type &&key) { + KeyOrValueCompare comp; + auto lower = lower_bound(key); + if (lower == end() || comp(key, *lower)) + return data_.emplace(lower, std::move(key), mapped_type())->second; + + return lower->second; + } + + std::pair insert(value_type &&value) { return emplace(std::move(value)); } + + std::pair insert(const value_type &value) { return emplace(value); } + + iterator insert(const_iterator hint, value_type &&value) { + return emplace_hint(hint, std::move(value)); + } + + iterator insert(const_iterator hint, const value_type &value) { + return emplace_hint(hint, value); + } + + template + void insert(It begin, It end) { + // If we need to increase the capacity, utilize this fact and emplace + // the stuff. + for (; begin != end && size() == capacity(); ++begin) { + emplace(*begin); + } + if (begin == end) return; + + // If we don't need to increase capacity, then we can use a more efficient + // insert method where everything is just put in the same vector + // and then merge in place. + size_type size_before = data_.size(); + try { + for (size_t i = capacity(); i > size_before && begin != end; --i, ++begin) { + data_.emplace_back(*begin); + } + } catch (...) { + // If emplace_back throws an exception, the easiest way to make sure + // that our invariants are still in place is to resize to the state + // we were in before + for (size_t i = data_.size(); i > size_before; --i) { + data_.pop_back(); + } + throw; + } + + value_compare comp; + auto mid = data_.begin() + size_before; + std::stable_sort(mid, data_.end(), comp); + std::inplace_merge(data_.begin(), mid, data_.end(), comp); + data_.erase(std::unique(data_.begin(), data_.end(), std::not_fn(comp)), data_.end()); + + // Make sure that we inserted at least one element before + // recursing. Otherwise we'd recurse too often if we were to insert the + // same element many times + if (data_.size() == size_before) { + for (; begin != end; ++begin) { + if (emplace(*begin).second) { + ++begin; + break; + } + } + } + + // Insert the remaining elements that didn't fit by calling this function recursively. + return insert(begin, end); + } + + void insert(std::initializer_list il) { insert(il.begin(), il.end()); } + + iterator erase(iterator it) { return data_.erase(it); } + + iterator erase(const_iterator it) { return erase(iterator_const_cast(it)); } + + size_type erase(const key_type &key) { + auto found = find(key); + if (found == end()) return 0; + erase(found); + return 1; + } + + iterator erase(const_iterator first, const_iterator last) { + return data_.erase(iterator_const_cast(first), iterator_const_cast(last)); + } + + void swap(flat_map &other) { data_.swap(other.data_); } + + void clear() { data_.clear(); } + + template + std::pair emplace(First &&first, Args &&...args) { + KeyOrValueCompare comp; + auto lower_bound = std::lower_bound(data_.begin(), data_.end(), first, comp); + if (lower_bound == data_.end() || comp(first, *lower_bound)) + return { + data_.emplace(lower_bound, std::forward(first), std::forward(args)...), + true}; + + return {lower_bound, false}; + } + + std::pair emplace() { return emplace(value_type()); } + + template + iterator emplace_hint(const_iterator hint, First &&first, Args &&...args) { + KeyOrValueCompare comp; + if (hint == cend() || comp(first, *hint)) { + if (hint == cbegin() || comp(*(hint - 1), first)) + return data_.emplace(iterator_const_cast(hint), std::forward(first), + std::forward(args)...); + + return emplace(std::forward(first), std::forward(args)...).first; + } else if (!comp(*hint, first)) { + return begin() + (hint - cbegin()); + } + + return emplace(std::forward(first), std::forward(args)...).first; + } + + iterator emplace_hint(const_iterator hint) { return emplace_hint(hint, value_type()); } + + key_compare key_comp() const { return key_compare(); } + value_compare value_comp() const { return value_compare(); } + + template + iterator find(const T &key) { + return binary_find(begin(), end(), key, KeyOrValueCompare()); + } + template + const_iterator find(const T &key) const { + return binary_find(begin(), end(), key, KeyOrValueCompare()); + } + template + size_type count(const T &key) const { + return std::binary_search(begin(), end(), key, KeyOrValueCompare()) ? 1 : 0; + } + template + iterator lower_bound(const T &key) { + return std::lower_bound(begin(), end(), key, KeyOrValueCompare()); + } + template + const_iterator lower_bound(const T &key) const { + return std::lower_bound(begin(), end(), key, KeyOrValueCompare()); + } + template + iterator upper_bound(const T &key) { + return std::upper_bound(begin(), end(), key, KeyOrValueCompare()); + } + template + const_iterator upper_bound(const T &key) const { + return std::upper_bound(begin(), end(), key, KeyOrValueCompare()); + } + template + std::pair equal_range(const T &key) { + return std::equal_range(begin(), end(), key, KeyOrValueCompare()); + } + template + std::pair equal_range(const T &key) const { + return std::equal_range(begin(), end(), key, KeyOrValueCompare()); + } + allocator_type get_allocator() const { return data_.get_allocator(); } + + bool operator==(const flat_map &other) const { return data_ == other.data_; } + bool operator!=(const flat_map &other) const { return !(*this == other); } + bool operator<(const flat_map &other) const { return data_ < other.data_; } + bool operator>(const flat_map &other) const { return other < *this; } + bool operator<=(const flat_map &other) const { return !(other < *this); } + bool operator>=(const flat_map &other) const { return !(*this < other); } + + private: + Container data_; + + iterator iterator_const_cast(const_iterator it) { return begin() + (it - cbegin()); } + + struct KeyOrValueCompare { + bool operator()(const key_type &lhs, const key_type &rhs) const { + return key_compare()(lhs, rhs); + } + bool operator()(const key_type &lhs, const value_type &rhs) const { + return key_compare()(lhs, rhs.first); + } + template + bool operator()(const key_type &lhs, const T &rhs) const { + return key_compare()(lhs, rhs); + } + template + bool operator()(const T &lhs, const key_type &rhs) const { + return key_compare()(lhs, rhs); + } + bool operator()(const value_type &lhs, const key_type &rhs) const { + return key_compare()(lhs.first, rhs); + } + bool operator()(const value_type &lhs, const value_type &rhs) const { + return key_compare()(lhs.first, rhs.first); + } + template + bool operator()(const value_type &lhs, const T &rhs) const { + return key_compare()(lhs.first, rhs); + } + template + bool operator()(const T &lhs, const value_type &rhs) const { + return key_compare()(lhs, rhs.first); + } + }; + + template + static It binary_find(It begin, It end, const T &value, const Comp &cmp) { + auto lower_bound = std::lower_bound(begin, end, value, cmp); + if (lower_bound == end || cmp(value, *lower_bound)) return end; + + return lower_bound; + } +}; + +template +void swap(flat_map &lhs, flat_map &rhs) { + lhs.swap(rhs); +} + +} // namespace P4C + +#endif // LIB_FLAT_MAP_H_