diff --git a/bazel/foreign_cc/0004-local-rate-limit-bucket-backport.patch b/bazel/foreign_cc/0004-local-rate-limit-bucket-backport.patch new file mode 100755 index 00000000..9e8e66d0 --- /dev/null +++ b/bazel/foreign_cc/0004-local-rate-limit-bucket-backport.patch @@ -0,0 +1,3746 @@ +diff --git a/api/envoy/extensions/common/ratelimit/v3/ratelimit.proto b/api/envoy/extensions/common/ratelimit/v3/ratelimit.proto +index 9b5d9a7b91..73d729adc2 100644 +--- api/envoy/extensions/common/ratelimit/v3/ratelimit.proto ++++ api/envoy/extensions/common/ratelimit/v3/ratelimit.proto +@@ -130,3 +130,15 @@ message LocalRateLimitDescriptor { + // Token Bucket algorithm for local ratelimiting. + type.v3.TokenBucket token_bucket = 2 [(validate.rules).message = {required: true}]; + } ++ ++// Configuration used to enable local cluster level rate limiting where the token buckets ++// will be shared across all the Envoy instances in the local cluster. ++// A share will be calculated based on the membership of the local cluster dynamically ++// and the configuration. When the limiter refills the token bucket, the share will be ++// applied. By default, the token bucket will be shared evenly. ++// ++// See :ref:`local cluster name ++// ` for more context ++// about the local cluster. ++message LocalClusterRateLimit { ++} +diff --git a/api/envoy/extensions/filters/http/local_ratelimit/v3/local_rate_limit.proto b/api/envoy/extensions/filters/http/local_ratelimit/v3/local_rate_limit.proto +index c253d04973..a32475f352 100644 +--- api/envoy/extensions/filters/http/local_ratelimit/v3/local_rate_limit.proto ++++ api/envoy/extensions/filters/http/local_ratelimit/v3/local_rate_limit.proto +@@ -22,7 +22,7 @@ option (udpa.annotations.file_status).package_version_status = ACTIVE; + // Local Rate limit :ref:`configuration overview `. 
+ // [#extension: envoy.filters.http.local_ratelimit] + +-// [#next-free-field: 16] ++// [#next-free-field: 17] + message LocalRateLimit { + // The human readable prefix to use when emitting stats. + string stat_prefix = 1 [(validate.rules).string = {min_len: 1}]; +@@ -110,6 +110,23 @@ message LocalRateLimit { + // If unspecified, the default value is false. + bool local_rate_limit_per_downstream_connection = 11; + ++ // Enables the local cluster level rate limiting, iff this is set explicitly. For example, ++ // given an Envoy gateway that contains N Envoy instances and a rate limit rule X tokens ++ // per second. If this is set, the total rate limit of whole gateway will always be X tokens ++ // per second regardless of how N changes. If this is not set, the total rate limit of whole ++ // gateway will be N * X tokens per second. ++ // ++ // .. note:: ++ // This should never be set if the ``local_rate_limit_per_downstream_connection`` is set to ++ // true. Because if per connection rate limiting is enabled, we assume that the token buckets ++ // should never be shared across Envoy instances. ++ // ++ // .. note:: ++ // This only works when the :ref:`local cluster name ++ // ` is set and ++ // the related cluster is defined in the bootstrap configuration. ++ common.ratelimit.v3.LocalClusterRateLimit local_cluster_rate_limit = 16; ++ + // Defines the standard version to use for X-RateLimit headers emitted by the filter. + // + // Disabled by default. +diff --git a/envoy/ratelimit/ratelimit.h b/envoy/ratelimit/ratelimit.h +index 2839cfa86c..5854372136 100644 +--- envoy/ratelimit/ratelimit.h ++++ envoy/ratelimit/ratelimit.h +@@ -49,23 +49,38 @@ struct Descriptor { + absl::optional limit_ = absl::nullopt; + }; + +-/** +- * A single token bucket. See token_bucket.proto. +- */ +-struct TokenBucket { +- uint32_t max_tokens_; +- uint32_t tokens_per_fill_; +- absl::Duration fill_interval_; +-}; +- + /** + * A single rate limit request descriptor. See ratelimit.proto. 
+ */ + struct LocalDescriptor { + std::vector entries_; +- friend bool operator==(const LocalDescriptor& lhs, const LocalDescriptor& rhs) { +- return lhs.entries_ == rhs.entries_; ++ ++ friend bool operator==(const LocalDescriptor& a, const LocalDescriptor& b) { ++ return a.entries_ == b.entries_; + } ++ struct Hash { ++ using is_transparent = void; // NOLINT(readability-identifier-naming) ++ size_t operator()(const LocalDescriptor& d) const { ++ return absl::Hash>()(d.entries_); ++ } ++ }; ++ struct Equal { ++ using is_transparent = void; // NOLINT(readability-identifier-naming) ++ bool operator()(const LocalDescriptor& a, const LocalDescriptor& b) const { ++ return a.entries_ == b.entries_; ++ } ++ }; ++ ++ std::string toString() const { ++ return absl::StrJoin(entries_, ", ", [](std::string* out, const auto& e) { ++ absl::StrAppend(out, e.key_, "=", e.value_); ++ }); ++ } ++ ++ /** ++ * Local descriptor map. ++ */ ++ template using Map = absl::flat_hash_map; + }; + + /* +diff --git a/source/common/common/token_bucket_impl.cc b/source/common/common/token_bucket_impl.cc +index f813d426d0..b3d4f10e78 100644 +--- source/common/common/token_bucket_impl.cc ++++ source/common/common/token_bucket_impl.cc +@@ -1,5 +1,6 @@ + #include "source/common/common/token_bucket_impl.h" + ++#include + #include + + namespace Envoy { +@@ -7,6 +8,7 @@ namespace Envoy { + namespace { + // The minimal fill rate will be one second every year. 
+ constexpr double kMinFillRate = 1.0 / (365 * 24 * 60 * 60); ++ + } // namespace + + TokenBucketImpl::TokenBucketImpl(uint64_t max_tokens, TimeSource& time_source, double fill_rate) +@@ -56,4 +58,44 @@ void TokenBucketImpl::maybeReset(uint64_t num_tokens) { + last_fill_ = time_source_.monotonicTime(); + } + ++AtomicTokenBucketImpl::AtomicTokenBucketImpl(uint64_t max_tokens, TimeSource& time_source, ++ double fill_rate, bool init_fill) ++ : max_tokens_(max_tokens), fill_rate_(std::max(std::abs(fill_rate), kMinFillRate)), ++ time_source_(time_source) { ++ auto time_in_seconds = timeNowInSeconds(); ++ if (init_fill) { ++ time_in_seconds -= max_tokens_ / fill_rate_; ++ } ++ time_in_seconds_.store(time_in_seconds, std::memory_order_relaxed); ++} ++ ++bool AtomicTokenBucketImpl::consume() { ++ constexpr auto consumed_cb = [](double total_tokens) -> double { ++ return total_tokens >= 1 ? 1 : 0; ++ }; ++ return consume(consumed_cb) == 1; ++} ++ ++uint64_t AtomicTokenBucketImpl::consume(uint64_t tokens, bool allow_partial) { ++ const auto consumed_cb = [tokens, allow_partial](double total_tokens) { ++ const auto consumed = static_cast(tokens); ++ if (total_tokens >= consumed) { ++ return consumed; // There are enough tokens to consume. ++ } ++ // If allow_partial is true, consume all available tokens. ++ return allow_partial ? 
std::max(0, std::floor(total_tokens)) : 0; ++ }; ++ return static_cast(consume(consumed_cb)); ++} ++ ++double AtomicTokenBucketImpl::remainingTokens() const { ++ const double time_now = timeNowInSeconds(); ++ const double time_old = time_in_seconds_.load(std::memory_order_relaxed); ++ return std::min(max_tokens_, (time_now - time_old) * fill_rate_); ++} ++ ++double AtomicTokenBucketImpl::timeNowInSeconds() const { ++ return std::chrono::duration(time_source_.monotonicTime().time_since_epoch()).count(); ++} ++ + } // namespace Envoy +diff --git a/source/common/common/token_bucket_impl.h b/source/common/common/token_bucket_impl.h +index 96ac238e37..673f0a74e2 100644 +--- source/common/common/token_bucket_impl.h ++++ source/common/common/token_bucket_impl.h +@@ -35,4 +35,92 @@ private: + TimeSource& time_source_; + }; + ++/** ++ * Atomic token bucket. This class is thread-safe. ++ */ ++class AtomicTokenBucketImpl { ++public: ++ /** ++ * @param max_tokens supplies the maximum number of tokens in the bucket. ++ * @param time_source supplies the time source. ++ * @param fill_rate supplies the number of tokens that will return to the bucket on each second. ++ * The default is 1. ++ * @param init_fill supplies whether the bucket should be initialized with max_tokens. ++ */ ++ explicit AtomicTokenBucketImpl(uint64_t max_tokens, TimeSource& time_source, ++ double fill_rate = 1.0, bool init_fill = true); ++ ++ // This reference https://github.com/facebook/folly/blob/main/folly/TokenBucket.h. 
++ template double consume(const GetConsumedTokens& cb) { ++ const double time_now = timeNowInSeconds(); ++ ++ double time_old = time_in_seconds_.load(std::memory_order_relaxed); ++ double time_new{}; ++ double consumed{}; ++ do { ++ const double total_tokens = std::min(max_tokens_, (time_now - time_old) * fill_rate_); ++ if (consumed = cb(total_tokens); consumed == 0) { ++ return 0; ++ } ++ ++ // There are two special cases that should rarely happen in practice but we will not ++ // prevent them in this common template method: ++ // The consumed is negative. It means the token is added back to the bucket. ++ // The consumed is larger than total_tokens. It means the bucket is overflowed and future ++ // tokens are consumed. ++ ++ // Move the time_in_seconds_ forward by the number of tokens consumed. ++ const double total_tokens_new = total_tokens - consumed; ++ time_new = time_now - (total_tokens_new / fill_rate_); ++ } while ( ++ !time_in_seconds_.compare_exchange_weak(time_old, time_new, std::memory_order_relaxed)); ++ ++ return consumed; ++ } ++ ++ /** ++ * Consumes one tokens from the bucket. ++ * @return true if the token is consumed, false otherwise. ++ */ ++ bool consume(); ++ ++ /** ++ * Consumes multiple tokens from the bucket. ++ * @param tokens the number of tokens to consume. ++ * @param allow_partial whether to allow partial consumption. ++ * @return the number of tokens consumed. ++ */ ++ uint64_t consume(uint64_t tokens, bool allow_partial); ++ ++ /** ++ * Get the maximum number of tokens in the bucket. The actual maximum number of tokens in the ++ * bucket may be changed with the factor. ++ * @return the maximum number of tokens in the bucket. ++ */ ++ double maxTokens() const { return max_tokens_; } ++ ++ /** ++ * Get the fill rate of the bucket. This is a constant for the lifetime of the bucket. But note ++ * the actual used fill rate will multiply the dynamic factor. ++ * @return the fill rate of the bucket. 
++ */ ++ double fillRate() const { return fill_rate_; } ++ ++ /** ++ * Get the remaining number of tokens in the bucket. This is a snapshot and may change after the ++ * call. ++ * @return the remaining number of tokens in the bucket. ++ */ ++ double remainingTokens() const; ++ ++private: ++ double timeNowInSeconds() const; ++ ++ const double max_tokens_; ++ const double fill_rate_; ++ ++ std::atomic time_in_seconds_{}; ++ TimeSource& time_source_; ++}; ++ + } // namespace Envoy +diff --git a/source/common/runtime/runtime_features.cc b/source/common/runtime/runtime_features.cc +index 260a08820c..f22d9cf58b 100644 +--- source/common/runtime/runtime_features.cc ++++ source/common/runtime/runtime_features.cc +@@ -73,6 +73,7 @@ RUNTIME_GUARD(envoy_reloadable_features_lowercase_scheme); + RUNTIME_GUARD(envoy_reloadable_features_no_downgrade_to_canonical_name); + RUNTIME_GUARD(envoy_reloadable_features_no_extension_lookup_by_name); + RUNTIME_GUARD(envoy_reloadable_features_no_full_scan_certs_on_sni_mismatch); ++// RUNTIME_GUARD(envoy_reloadable_features_no_timer_based_rate_limit_token_bucket); + RUNTIME_GUARD(envoy_reloadable_features_normalize_host_for_preresolve_dfp_dns); + RUNTIME_GUARD(envoy_reloadable_features_oauth_make_token_cookie_httponly); + RUNTIME_GUARD(envoy_reloadable_features_oauth_use_standard_max_age_value); +@@ -113,6 +114,9 @@ RUNTIME_GUARD(envoy_restart_features_use_fast_protobuf_hash); + + // Begin false flags. Most of them should come with a TODO to flip true. + ++//backporting this so TODO: Remove all this forked code before adding to 1.31 ++FALSE_RUNTIME_GUARD(envoy_reloadable_features_no_timer_based_rate_limit_token_bucket); ++ + // TODO(birenroy) Flip this to true after resolving issues. + // Ignore the automated "remove this flag" issue: we should keep this for 1 year. 
+ FALSE_RUNTIME_GUARD(envoy_reloadable_features_http2_use_oghttp2); +diff --git a/source/common/singleton/manager_impl.cc b/source/common/singleton/manager_impl.cc +index a71fd7f3a9..925b1c6050 100644 +--- source/common/singleton/manager_impl.cc ++++ source/common/singleton/manager_impl.cc +@@ -9,7 +9,7 @@ namespace Envoy { + namespace Singleton { + + InstanceSharedPtr ManagerImpl::get(const std::string& name, SingletonFactoryCb cb, bool pin) { +- ASSERT(run_tid_ == thread_factory_.currentThreadId()); ++ ASSERT_IS_MAIN_OR_TEST_THREAD(); + + ENVOY_BUG(Registry::FactoryRegistry::getFactory(name) != nullptr, + "invalid singleton name '" + name + "'. Make sure it is registered."); +diff --git a/source/common/singleton/manager_impl.h b/source/common/singleton/manager_impl.h +index cc715f55c1..a541f90695 100644 +--- source/common/singleton/manager_impl.h ++++ source/common/singleton/manager_impl.h +@@ -17,8 +17,7 @@ namespace Singleton { + */ + class ManagerImpl : public Manager, NonCopyable { + public: +- explicit ManagerImpl(Thread::ThreadFactory& thread_factory) +- : thread_factory_(thread_factory), run_tid_(thread_factory.currentThreadId()) {} ++ ManagerImpl() = default; + + // Singleton::Manager + InstanceSharedPtr get(const std::string& name, SingletonFactoryCb cb, bool pin) override; +@@ -26,9 +25,6 @@ public: + private: + absl::node_hash_map> singletons_; + std::vector pinned_singletons_; +- +- Thread::ThreadFactory& thread_factory_; +- const Thread::ThreadId run_tid_; + }; + + } // namespace Singleton +diff --git a/source/extensions/filters/common/local_ratelimit/BUILD b/source/extensions/filters/common/local_ratelimit/BUILD +index 2e6af5b6da..5cea645b04 100644 +--- source/extensions/filters/common/local_ratelimit/BUILD ++++ source/extensions/filters/common/local_ratelimit/BUILD +@@ -17,6 +17,7 @@ envoy_cc_library( + "//envoy/event:timer_interface", + "//envoy/ratelimit:ratelimit_interface", + "//source/common/common:thread_synchronizer_lib", ++ 
"//source/common/common:token_bucket_impl_lib", + "//source/common/protobuf:utility_lib", + "@envoy_api//envoy/extensions/common/ratelimit/v3:pkg_cc_proto", + ], +diff --git a/source/extensions/filters/common/local_ratelimit/local_ratelimit_impl.cc b/source/extensions/filters/common/local_ratelimit/local_ratelimit_impl.cc +index 603a61eca0..dad17d9961 100644 +--- source/extensions/filters/common/local_ratelimit/local_ratelimit_impl.cc ++++ source/extensions/filters/common/local_ratelimit/local_ratelimit_impl.cc +@@ -1,6 +1,7 @@ + #include "source/extensions/filters/common/local_ratelimit/local_ratelimit_impl.h" + + #include ++#include + + #include "envoy/runtime/runtime.h" + +@@ -13,75 +14,226 @@ namespace Filters { + namespace Common { + namespace LocalRateLimit { + ++SINGLETON_MANAGER_REGISTRATION(local_ratelimit_share_provider_manager); ++ ++class DefaultEvenShareMonitor : public ShareProviderManager::ShareMonitor { ++public: ++ double getTokensShareFactor() const override { return share_factor_.load(); } ++ double onLocalClusterUpdate(const Upstream::Cluster& cluster) override { ++ ASSERT_IS_MAIN_OR_TEST_THREAD(); ++ const auto num = cluster.info()->endpointStats().membership_total_.value(); ++ const double new_share_factor = num == 0 ? 1.0 : 1.0 / num; ++ share_factor_.store(new_share_factor); ++ return new_share_factor; ++ } ++ ++private: ++ std::atomic share_factor_{1.0}; ++}; ++ ++ShareProviderManager::ShareProviderManager(Event::Dispatcher& main_dispatcher, ++ const Upstream::Cluster& cluster) ++ : main_dispatcher_(main_dispatcher), cluster_(cluster) { ++ // It's safe to capture the local cluster reference here because the local cluster is ++ // guaranteed to be static cluster and should never be removed. 
++ handle_ = cluster_.prioritySet().addMemberUpdateCb([this](const auto&, const auto&) { ++ share_monitor_->onLocalClusterUpdate(cluster_); ++ return absl::OkStatus(); ++ }); ++ share_monitor_ = std::make_shared(); ++ share_monitor_->onLocalClusterUpdate(cluster_); ++} ++ ++ShareProviderManager::~ShareProviderManager() { ++ // Ensure the callback is unregistered on the main dispatcher thread. ++ main_dispatcher_.post([h = std::move(handle_)]() {}); ++} ++ ++ShareProviderSharedPtr ++ShareProviderManager::getShareProvider(const ProtoLocalClusterRateLimit&) const { ++ // TODO(wbpcode): we may want to support custom share provider in the future based on the ++ // configuration. ++ return share_monitor_; ++} ++ ++ShareProviderManagerSharedPtr ShareProviderManager::singleton(Event::Dispatcher& dispatcher, ++ Upstream::ClusterManager& cm, ++ Singleton::Manager& manager) { ++ return manager.getTyped( ++ SINGLETON_MANAGER_REGISTERED_NAME(local_ratelimit_share_provider_manager), ++ [&dispatcher, &cm]() -> Singleton::InstanceSharedPtr { ++ const auto& local_cluster_name = cm.localClusterName(); ++ if (!local_cluster_name.has_value()) { ++ return nullptr; ++ } ++ auto cluster = cm.clusters().getCluster(local_cluster_name.value()); ++ if (!cluster.has_value()) { ++ return nullptr; ++ } ++ return ShareProviderManagerSharedPtr{ ++ new ShareProviderManager(dispatcher, cluster.value().get())}; ++ }); ++} ++ ++TimerTokenBucket::TimerTokenBucket(uint32_t max_tokens, uint32_t tokens_per_fill, ++ std::chrono::milliseconds fill_interval, uint64_t multiplier, ++ LocalRateLimiterImpl& parent) ++ : multiplier_(multiplier), parent_(parent), max_tokens_(max_tokens), ++ tokens_per_fill_(tokens_per_fill), fill_interval_(fill_interval), ++ // Calculate the fill rate in tokens per second. 
++ fill_rate_(tokens_per_fill / ++ std::chrono::duration_cast>(fill_interval).count()) { ++ tokens_ = max_tokens; ++ fill_time_ = parent_.time_source_.monotonicTime(); ++} ++ ++absl::optional TimerTokenBucket::remainingFillInterval() const { ++ using namespace std::literals; ++ ++ const auto time_after_last_fill = std::chrono::duration_cast( ++ parent_.time_source_.monotonicTime() - fill_time_.load()); ++ ++ // Note that the fill timer may be delayed because other tasks are running on the main thread. ++ // So it's possible that the time_after_last_fill is greater than fill_interval_. ++ if (time_after_last_fill >= fill_interval_) { ++ return {}; ++ } ++ ++ return absl::ToInt64Seconds(absl::FromChrono(fill_interval_) - ++ absl::Seconds((time_after_last_fill) / 1s)); ++} ++ ++bool TimerTokenBucket::consume(double) { ++ // Relaxed consistency is used for all operations because we don't care about ordering, just the ++ // final atomic correctness. ++ uint32_t expected_tokens = tokens_.load(std::memory_order_relaxed); ++ do { ++ // expected_tokens is either initialized above or reloaded during the CAS failure below. ++ if (expected_tokens == 0) { ++ return false; ++ } ++ ++ // Testing hook. ++ parent_.synchronizer_.syncPoint("allowed_pre_cas"); ++ ++ // Loop while the weak CAS fails trying to subtract 1 from expected. ++ } while (!tokens_.compare_exchange_weak(expected_tokens, expected_tokens - 1, ++ std::memory_order_relaxed)); ++ ++ // We successfully decremented the counter by 1. ++ return true; ++} ++ ++void TimerTokenBucket::onFillTimer(uint64_t refill_counter, double factor) { ++ // Descriptors are refilled every Nth timer hit where N is the ratio of the ++ // descriptor refill interval over the global refill interval. For example, ++ // if the descriptor refill interval is 150ms and the global refill ++ // interval is 50ms, this descriptor is refilled every 3rd call. 
++ if (refill_counter % multiplier_ != 0) { ++ return; ++ } ++ ++ const uint32_t tokens_per_fill = std::ceil(tokens_per_fill_ * factor); ++ ++ // Relaxed consistency is used for all operations because we don't care about ordering, just the ++ // final atomic correctness. ++ uint32_t expected_tokens = tokens_.load(std::memory_order_relaxed); ++ uint32_t new_tokens_value{}; ++ do { ++ // expected_tokens is either initialized above or reloaded during the CAS failure below. ++ new_tokens_value = std::min(max_tokens_, expected_tokens + tokens_per_fill); ++ ++ // Testing hook. ++ parent_.synchronizer_.syncPoint("on_fill_timer_pre_cas"); ++ ++ // Loop while the weak CAS fails trying to update the tokens value. ++ } while ( ++ !tokens_.compare_exchange_weak(expected_tokens, new_tokens_value, std::memory_order_relaxed)); ++ ++ // Update fill time at last. ++ fill_time_ = parent_.time_source_.monotonicTime(); ++} ++ ++AtomicTokenBucket::AtomicTokenBucket(uint32_t max_tokens, uint32_t tokens_per_fill, ++ std::chrono::milliseconds fill_interval, ++ TimeSource& time_source) ++ : token_bucket_(max_tokens, time_source, ++ // Calculate the fill rate in tokens per second. ++ tokens_per_fill / std::chrono::duration(fill_interval).count()) {} ++ ++bool AtomicTokenBucket::consume(double factor) { ++ ASSERT(!(factor <= 0.0 || factor > 1.0)); ++ auto cb = [tokens = 1.0 / factor](double total) { return total < tokens ? 0.0 : tokens; }; ++ return token_bucket_.consume(cb) != 0.0; ++} ++ + LocalRateLimiterImpl::LocalRateLimiterImpl( + const std::chrono::milliseconds fill_interval, const uint32_t max_tokens, + const uint32_t tokens_per_fill, Event::Dispatcher& dispatcher, + const Protobuf::RepeatedPtrField< + envoy::extensions::common::ratelimit::v3::LocalRateLimitDescriptor>& descriptors, +- bool always_consume_default_token_bucket) ++ bool always_consume_default_token_bucket, ShareProviderSharedPtr shared_provider) + : fill_timer_(fill_interval > std::chrono::milliseconds(0) + ? 
dispatcher.createTimer([this] { onFillTimer(); }) + : nullptr), +- time_source_(dispatcher.timeSource()), +- always_consume_default_token_bucket_(always_consume_default_token_bucket) { ++ time_source_(dispatcher.timeSource()), share_provider_(std::move(shared_provider)), ++ always_consume_default_token_bucket_(always_consume_default_token_bucket), ++ no_timer_based_rate_limit_token_bucket_(Runtime::runtimeFeatureEnabled( ++ "envoy.reloadable_features.no_timer_based_rate_limit_token_bucket")) { + if (fill_timer_ && fill_interval < std::chrono::milliseconds(50)) { + throw EnvoyException("local rate limit token bucket fill timer must be >= 50ms"); + } + +- token_bucket_.max_tokens_ = max_tokens; +- token_bucket_.tokens_per_fill_ = tokens_per_fill; +- token_bucket_.fill_interval_ = absl::FromChrono(fill_interval); +- tokens_.tokens_ = max_tokens; +- tokens_.fill_time_ = time_source_.monotonicTime(); ++ if (no_timer_based_rate_limit_token_bucket_) { ++ default_token_bucket_ = std::make_shared(max_tokens, tokens_per_fill, ++ fill_interval, time_source_); ++ } else { ++ default_token_bucket_ = ++ std::make_shared(max_tokens, tokens_per_fill, fill_interval, 1, *this); ++ } + +- if (fill_timer_) { +- fill_timer_->enableTimer(fill_interval); ++ if (fill_timer_ && default_token_bucket_->fillInterval().count() > 0 && ++ !no_timer_based_rate_limit_token_bucket_) { ++ fill_timer_->enableTimer(default_token_bucket_->fillInterval()); + } + + for (const auto& descriptor : descriptors) { +- LocalDescriptorImpl new_descriptor; ++ RateLimit::LocalDescriptor new_descriptor; ++ new_descriptor.entries_.reserve(descriptor.entries_size()); + for (const auto& entry : descriptor.entries()) { + new_descriptor.entries_.push_back({entry.key(), entry.value()}); + } +- RateLimit::TokenBucket per_descriptor_token_bucket; +- per_descriptor_token_bucket.fill_interval_ = +- absl::Milliseconds(PROTOBUF_GET_MS_OR_DEFAULT(descriptor.token_bucket(), fill_interval, 0)); +- if 
(per_descriptor_token_bucket.fill_interval_ % token_bucket_.fill_interval_ != +- absl::ZeroDuration()) { ++ ++ const auto per_descriptor_max_tokens = descriptor.token_bucket().max_tokens(); ++ const auto per_descriptor_tokens_per_fill = ++ PROTOBUF_GET_WRAPPED_OR_DEFAULT(descriptor.token_bucket(), tokens_per_fill, 1); ++ const auto per_descriptor_fill_interval = std::chrono::milliseconds( ++ PROTOBUF_GET_MS_OR_DEFAULT(descriptor.token_bucket(), fill_interval, 0)); ++ ++ if (per_descriptor_fill_interval.count() % fill_interval.count() != 0) { + throw EnvoyException( + "local rate descriptor limit is not a multiple of token bucket fill timer"); + } + // Save the multiplicative factor to control the descriptor refill frequency. +- new_descriptor.multiplier_ = +- per_descriptor_token_bucket.fill_interval_ / token_bucket_.fill_interval_; +- per_descriptor_token_bucket.max_tokens_ = descriptor.token_bucket().max_tokens(); +- per_descriptor_token_bucket.tokens_per_fill_ = +- PROTOBUF_GET_WRAPPED_OR_DEFAULT(descriptor.token_bucket(), tokens_per_fill, 1); +- new_descriptor.token_bucket_ = per_descriptor_token_bucket; +- +- auto token_state = std::make_shared(); +- token_state->tokens_ = per_descriptor_token_bucket.max_tokens_; +- token_state->fill_time_ = time_source_.monotonicTime(); +- new_descriptor.token_state_ = token_state; ++ const auto per_descriptor_multiplier = per_descriptor_fill_interval / fill_interval; ++ ++ RateLimitTokenBucketSharedPtr per_descriptor_token_bucket; ++ if (no_timer_based_rate_limit_token_bucket_) { ++ per_descriptor_token_bucket = std::make_shared( ++ per_descriptor_max_tokens, per_descriptor_tokens_per_fill, per_descriptor_fill_interval, ++ time_source_); ++ } else { ++ per_descriptor_token_bucket = std::make_shared( ++ per_descriptor_max_tokens, per_descriptor_tokens_per_fill, per_descriptor_fill_interval, ++ per_descriptor_multiplier, *this); ++ } + +- auto result = descriptors_.emplace(new_descriptor); ++ auto result = ++ 
descriptors_.emplace(std::move(new_descriptor), std::move(per_descriptor_token_bucket)); + if (!result.second) { + throw EnvoyException(absl::StrCat("duplicate descriptor in the local rate descriptor: ", +- result.first->toString())); ++ result.first->first.toString())); + } +- sorted_descriptors_.push_back(new_descriptor); +- } +- // If a request is limited by a descriptor, it should not consume tokens from the remaining +- // matched descriptors, so we sort the descriptors by tokens per second, as a result, in most +- // cases the strictest descriptor will be consumed first. However, it can not solve the +- // problem perfectly. +- if (!sorted_descriptors_.empty()) { +- std::sort(sorted_descriptors_.begin(), sorted_descriptors_.end(), +- [this](LocalDescriptorImpl a, LocalDescriptorImpl b) -> bool { +- const int a_token_fill_per_second = tokensFillPerSecond(a); +- const int b_token_fill_per_second = tokensFillPerSecond(b); +- return a_token_fill_per_second < b_token_fill_per_second; +- }); + } + } + +@@ -98,153 +250,74 @@ void LocalRateLimiterImpl::onFillTimer() { + // descriptors tokens from being refilled at the first time hit, regardless of its fill + // interval configuration. + refill_counter_++; +- onFillTimerHelper(tokens_, token_bucket_); +- onFillTimerDescriptorHelper(); +- fill_timer_->enableTimer(absl::ToChronoMilliseconds(token_bucket_.fill_interval_)); +-} ++ const double share_factor = ++ share_provider_ != nullptr ? share_provider_->getTokensShareFactor() : 1.0; + +-void LocalRateLimiterImpl::onFillTimerHelper(TokenState& tokens, +- const RateLimit::TokenBucket& bucket) { +- // Relaxed consistency is used for all operations because we don't care about ordering, just the +- // final atomic correctness. +- uint32_t expected_tokens = tokens.tokens_.load(std::memory_order_relaxed); +- uint32_t new_tokens_value; +- do { +- // expected_tokens is either initialized above or reloaded during the CAS failure below. 
+- new_tokens_value = std::min(bucket.max_tokens_, expected_tokens + bucket.tokens_per_fill_); +- +- // Testing hook. +- synchronizer_.syncPoint("on_fill_timer_pre_cas"); +- +- // Loop while the weak CAS fails trying to update the tokens value. +- } while (!tokens.tokens_.compare_exchange_weak(expected_tokens, new_tokens_value, +- std::memory_order_relaxed)); +- +- // Update fill time at last. +- tokens.fill_time_ = time_source_.monotonicTime(); +-} +- +-void LocalRateLimiterImpl::onFillTimerDescriptorHelper() { ++ default_token_bucket_->onFillTimer(refill_counter_, share_factor); + for (const auto& descriptor : descriptors_) { +- // Descriptors are refilled every Nth timer hit where N is the ratio of the +- // descriptor refill interval over the global refill interval. For example, +- // if the descriptor refill interval is 150ms and the global refill +- // interval is 50ms, this descriptor is refilled every 3rd call. +- if (refill_counter_ % descriptor.multiplier_ == 0) { +- onFillTimerHelper(*descriptor.token_state_, descriptor.token_bucket_); +- } ++ descriptor.second->onFillTimer(refill_counter_, share_factor); + } +-} +- +-bool LocalRateLimiterImpl::requestAllowedHelper(const TokenState& tokens) const { +- // Relaxed consistency is used for all operations because we don't care about ordering, just the +- // final atomic correctness. +- uint32_t expected_tokens = tokens.tokens_.load(std::memory_order_relaxed); +- do { +- // expected_tokens is either initialized above or reloaded during the CAS failure below. +- if (expected_tokens == 0) { +- return false; +- } + +- // Testing hook. +- synchronizer_.syncPoint("allowed_pre_cas"); +- +- // Loop while the weak CAS fails trying to subtract 1 from expected. +- } while (!tokens.tokens_.compare_exchange_weak(expected_tokens, expected_tokens - 1, +- std::memory_order_relaxed)); +- +- // We successfully decremented the counter by 1. 
+- return true; ++ fill_timer_->enableTimer(default_token_bucket_->fillInterval()); + } + +-OptRef LocalRateLimiterImpl::descriptorHelper( ++LocalRateLimiterImpl::Result LocalRateLimiterImpl::requestAllowed( + absl::Span request_descriptors) const { +- if (!descriptors_.empty() && !request_descriptors.empty()) { +- // The override rate limit descriptor is selected by the first full match from the request +- // descriptors. +- for (const auto& request_descriptor : request_descriptors) { +- auto it = descriptors_.find(request_descriptor); +- if (it != descriptors_.end()) { +- return *it; +- } +- } +- } +- return {}; +-} + +-bool LocalRateLimiterImpl::requestAllowed( +- absl::Span request_descriptors) const { +- // Matched descriptors will be sorted by tokens per second and tokens consumed in order. +- // In most cases, if one of them is limited the remaining descriptors will not consume +- // their tokens. +- bool matched_descriptor = false; +- if (!descriptors_.empty() && !request_descriptors.empty()) { +- for (const auto& descriptor : sorted_descriptors_) { +- for (const auto& request_descriptor : request_descriptors) { +- if (descriptor == request_descriptor) { +- matched_descriptor = true; +- // Descriptor token is not enough. +- if (!requestAllowedHelper(*descriptor.token_state_)) { +- return false; +- } +- break; +- } +- } ++ // In most cases the request descriptors has only few elements. We use a inlined vector to ++ // avoid heap allocation. ++ absl::InlinedVector matched_descriptors; ++ ++ // Find all matched descriptors. ++ for (const auto& request_descriptor : request_descriptors) { ++ auto iter = descriptors_.find(request_descriptor); ++ if (iter != descriptors_.end()) { ++ matched_descriptors.push_back(iter->second.get()); + } + } + +- if (!matched_descriptor || always_consume_default_token_bucket_) { +- // Since global tokens are not sorted, it should be larger than other descriptors. 
+- return requestAllowedHelper(tokens_); ++ if (matched_descriptors.size() > 1) { ++ // Sort the matched descriptors by token bucket fill rate to ensure the descriptor with the ++ // smallest fill rate is consumed first. ++ std::sort(matched_descriptors.begin(), matched_descriptors.end(), ++ [](const RateLimitTokenBucket* lhs, const RateLimitTokenBucket* rhs) { ++ return lhs->fillRate() < rhs->fillRate(); ++ }); + } +- return true; +-} +- +-int LocalRateLimiterImpl::tokensFillPerSecond(LocalDescriptorImpl& descriptor) { +- return descriptor.token_bucket_.tokens_per_fill_ / +- (absl::ToInt64Seconds(descriptor.token_bucket_.fill_interval_) +- ? absl::ToInt64Seconds(descriptor.token_bucket_.fill_interval_) +- : 1); +-} + +-uint32_t LocalRateLimiterImpl::maxTokens( +- absl::Span request_descriptors) const { +- auto descriptor = descriptorHelper(request_descriptors); ++ const double share_factor = ++ share_provider_ != nullptr ? share_provider_->getTokensShareFactor() : 1.0; + +- return descriptor.has_value() ? descriptor.value().get().token_bucket_.max_tokens_ +- : token_bucket_.max_tokens_; +-} +- +-uint32_t LocalRateLimiterImpl::remainingTokens( +- absl::Span request_descriptors) const { +- auto descriptor = descriptorHelper(request_descriptors); ++ // See if the request is forbidden by any of the matched descriptors. ++ for (auto descriptor : matched_descriptors) { ++ if (!descriptor->consume(share_factor)) { ++ // If the request is forbidden by a descriptor, return the result and the descriptor ++ // token bucket. ++ return {false, makeOptRefFromPtr(descriptor)}; ++ } ++ } + +- return descriptor.has_value() +- ? descriptor.value().get().token_state_->tokens_.load(std::memory_order_relaxed) +- : tokens_.tokens_.load(std::memory_order_relaxed); +-} ++ // See if the request is forbidden by the default token bucket. 
++ if (matched_descriptors.empty() || always_consume_default_token_bucket_) { ++ if (const bool result = default_token_bucket_->consume(share_factor); !result) { ++ // If the request is forbidden by the default token bucket, return the result and the ++ // default token bucket. ++ return {false, makeOptRefFromPtr(default_token_bucket_.get())}; ++ } + +-int64_t LocalRateLimiterImpl::remainingFillInterval( +- absl::Span request_descriptors) const { +- using namespace std::literals; ++ // If the request is allowed then return the result the token bucket. The descriptor ++ // token bucket will be selected as priority if it exists. ++ return {true, makeOptRefFromPtr(matched_descriptors.empty() ++ ? default_token_bucket_.get() ++ : matched_descriptors[0])}; ++ }; + +- auto current_time = time_source_.monotonicTime(); +- auto descriptor = descriptorHelper(request_descriptors); +- // Remaining time to next fill = fill interval - (current time - last fill time). +- if (descriptor.has_value()) { +- ASSERT(std::chrono::duration_cast( +- current_time - descriptor.value().get().token_state_->fill_time_) <= +- absl::ToChronoMilliseconds(descriptor.value().get().token_bucket_.fill_interval_)); +- return absl::ToInt64Seconds( +- descriptor.value().get().token_bucket_.fill_interval_ - +- absl::Seconds((current_time - descriptor.value().get().token_state_->fill_time_) / 1s)); +- } +- return absl::ToInt64Seconds(token_bucket_.fill_interval_ - +- absl::Seconds((current_time - tokens_.fill_time_) / 1s)); ++ ASSERT(!matched_descriptors.empty()); ++ return {true, makeOptRefFromPtr(matched_descriptors[0])}; + } + + } // namespace LocalRateLimit + } // namespace Common + } // namespace Filters + } // namespace Extensions +-} // namespace Envoy ++} // namespace Envoy +\ No newline at end of file +diff --git a/source/extensions/filters/common/local_ratelimit/local_ratelimit_impl.h b/source/extensions/filters/common/local_ratelimit/local_ratelimit_impl.h +index c0cc182a49..ad4810b102 100644 
+--- source/extensions/filters/common/local_ratelimit/local_ratelimit_impl.h ++++ source/extensions/filters/common/local_ratelimit/local_ratelimit_impl.h +@@ -1,13 +1,17 @@ + #pragma once + + #include ++#include + + #include "envoy/event/dispatcher.h" + #include "envoy/event/timer.h" + #include "envoy/extensions/common/ratelimit/v3/ratelimit.pb.h" + #include "envoy/ratelimit/ratelimit.h" ++#include "envoy/singleton/instance.h" ++#include "envoy/upstream/cluster_manager.h" + + #include "source/common/common/thread_synchronizer.h" ++#include "source/common/common/token_bucket_impl.h" + #include "source/common/protobuf/protobuf.h" + + namespace Envoy { +@@ -16,82 +20,156 @@ namespace Filters { + namespace Common { + namespace LocalRateLimit { + ++using ProtoLocalClusterRateLimit = envoy::extensions::common::ratelimit::v3::LocalClusterRateLimit; ++ ++class ShareProvider { ++public: ++ virtual ~ShareProvider() = default; ++ // The share of the tokens. This method should be thread-safe. ++ virtual double getTokensShareFactor() const PURE; ++}; ++using ShareProviderSharedPtr = std::shared_ptr; ++ ++class ShareProviderManager; ++using ShareProviderManagerSharedPtr = std::shared_ptr; ++ ++class ShareProviderManager : public Singleton::Instance { ++public: ++ ShareProviderSharedPtr getShareProvider(const ProtoLocalClusterRateLimit& config) const; ++ ~ShareProviderManager() override; ++ ++ static ShareProviderManagerSharedPtr singleton(Event::Dispatcher& dispatcher, ++ Upstream::ClusterManager& cm, ++ Singleton::Manager& manager); ++ ++ class ShareMonitor : public ShareProvider { ++ public: ++ virtual double onLocalClusterUpdate(const Upstream::Cluster& cluster) PURE; ++ }; ++ using ShareMonitorSharedPtr = std::shared_ptr; ++ ++private: ++ ShareProviderManager(Event::Dispatcher& main_dispatcher, const Upstream::Cluster& cluster); ++ ++ Event::Dispatcher& main_dispatcher_; ++ const Upstream::Cluster& cluster_; ++ Envoy::Common::CallbackHandlePtr handle_; ++ 
ShareMonitorSharedPtr share_monitor_; ++}; ++using ShareProviderManagerSharedPtr = std::shared_ptr; ++ ++class TokenBucketContext { ++public: ++ virtual ~TokenBucketContext() = default; ++ ++ virtual uint32_t maxTokens() const PURE; ++ virtual uint32_t remainingTokens() const PURE; ++ virtual absl::optional remainingFillInterval() const PURE; ++}; ++ ++class RateLimitTokenBucket : public TokenBucketContext { ++public: ++ virtual bool consume(double factor = 1.0) PURE; ++ virtual void onFillTimer(uint64_t refill_counter, double factor = 1.0) PURE; ++ virtual std::chrono::milliseconds fillInterval() const PURE; ++ virtual double fillRate() const PURE; ++}; ++using RateLimitTokenBucketSharedPtr = std::shared_ptr; ++ ++class LocalRateLimiterImpl; ++ ++// Token bucket that implements based on the periodic timer. ++class TimerTokenBucket : public RateLimitTokenBucket { ++public: ++ TimerTokenBucket(uint32_t max_tokens, uint32_t tokens_per_fill, ++ std::chrono::milliseconds fill_interval, uint64_t multiplier, ++ LocalRateLimiterImpl& parent); ++ ++ // RateLimitTokenBucket ++ bool consume(double factor) override; ++ void onFillTimer(uint64_t refill_counter, double factor) override; ++ std::chrono::milliseconds fillInterval() const override { return fill_interval_; } ++ double fillRate() const override { return fill_rate_; } ++ uint32_t maxTokens() const override { return max_tokens_; } ++ uint32_t remainingTokens() const override { return tokens_.load(); } ++ absl::optional remainingFillInterval() const override; ++ ++ // Descriptor refill interval is a multiple of the timer refill interval. ++ // For example, if the descriptor refill interval is 150ms and the global ++ // refill interval is 50ms, the value is 3. Every 3rd invocation of ++ // the global timer, the descriptor is refilled. 
++ const uint64_t multiplier_{}; ++ LocalRateLimiterImpl& parent_; ++ std::atomic tokens_{}; ++ std::atomic fill_time_{}; ++ ++ const uint32_t max_tokens_{}; ++ const uint32_t tokens_per_fill_{}; ++ const std::chrono::milliseconds fill_interval_{}; ++ const double fill_rate_{}; ++}; ++ ++class AtomicTokenBucket : public RateLimitTokenBucket { ++public: ++ AtomicTokenBucket(uint32_t max_tokens, uint32_t tokens_per_fill, ++ std::chrono::milliseconds fill_interval, TimeSource& time_source); ++ ++ // RateLimitTokenBucket ++ bool consume(double factor) override; ++ void onFillTimer(uint64_t, double) override {} ++ std::chrono::milliseconds fillInterval() const override { return {}; } ++ double fillRate() const override { return token_bucket_.fillRate(); } ++ uint32_t maxTokens() const override { return static_cast(token_bucket_.maxTokens()); } ++ uint32_t remainingTokens() const override { ++ return static_cast(token_bucket_.remainingTokens()); ++ } ++ absl::optional remainingFillInterval() const override { return {}; } ++ ++private: ++ AtomicTokenBucketImpl token_bucket_; ++}; ++ + class LocalRateLimiterImpl { + public: ++ struct Result { ++ bool allowed{}; ++ OptRef token_bucket_context{}; ++ }; ++ + LocalRateLimiterImpl( + const std::chrono::milliseconds fill_interval, const uint32_t max_tokens, + const uint32_t tokens_per_fill, Event::Dispatcher& dispatcher, + const Protobuf::RepeatedPtrField< + envoy::extensions::common::ratelimit::v3::LocalRateLimitDescriptor>& descriptors, +- bool always_consume_default_token_bucket = true); ++ bool always_consume_default_token_bucket = true, ++ ShareProviderSharedPtr shared_provider = nullptr); + ~LocalRateLimiterImpl(); + +- bool requestAllowed(absl::Span request_descriptors) const; +- uint32_t maxTokens(absl::Span request_descriptors) const; +- uint32_t remainingTokens(absl::Span request_descriptors) const; +- int64_t +- remainingFillInterval(absl::Span request_descriptors) const; ++ Result requestAllowed(absl::Span 
request_descriptors) const; + + private: +- struct TokenState { +- mutable std::atomic tokens_; +- MonotonicTime fill_time_; +- }; +- // Refill counter is incremented per each refill timer hit. +- uint64_t refill_counter_{0}; +- struct LocalDescriptorImpl : public RateLimit::LocalDescriptor { +- std::shared_ptr token_state_; +- RateLimit::TokenBucket token_bucket_; +- // Descriptor refill interval is a multiple of the timer refill interval. +- // For example, if the descriptor refill interval is 150ms and the global +- // refill interval is 50ms, the value is 3. Every 3rd invocation of +- // the global timer, the descriptor is refilled. +- uint64_t multiplier_; +- std::string toString() const { +- std::vector entries; +- entries.reserve(entries_.size()); +- for (const auto& entry : entries_) { +- entries.push_back(absl::StrCat(entry.key_, "=", entry.value_)); +- } +- return absl::StrJoin(entries, ", "); +- } +- }; +- struct LocalDescriptorHash { +- using is_transparent = void; // NOLINT(readability-identifier-naming) +- size_t operator()(const RateLimit::LocalDescriptor& d) const { +- return absl::Hash>()(d.entries_); +- } +- }; +- struct LocalDescriptorEqual { +- using is_transparent = void; // NOLINT(readability-identifier-naming) +- size_t operator()(const RateLimit::LocalDescriptor& a, +- const RateLimit::LocalDescriptor& b) const { +- return a.entries_ == b.entries_; +- } +- }; +- + void onFillTimer(); +- void onFillTimerHelper(TokenState& state, const RateLimit::TokenBucket& bucket); +- void onFillTimerDescriptorHelper(); +- OptRef +- descriptorHelper(absl::Span request_descriptors) const; +- bool requestAllowedHelper(const TokenState& tokens) const; +- int tokensFillPerSecond(LocalDescriptorImpl& descriptor); +- +- RateLimit::TokenBucket token_bucket_; ++ ++ RateLimitTokenBucketSharedPtr default_token_bucket_; ++ + const Event::TimerPtr fill_timer_; + TimeSource& time_source_; +- TokenState tokens_; +- absl::flat_hash_set descriptors_; +- std::vector 
sorted_descriptors_; ++ RateLimit::LocalDescriptor::Map descriptors_; ++ // Refill counter is incremented per each refill timer hit. ++ uint64_t refill_counter_{0}; ++ ++ ShareProviderSharedPtr share_provider_; ++ + mutable Thread::ThreadSynchronizer synchronizer_; // Used for testing only. + const bool always_consume_default_token_bucket_{}; ++ const bool no_timer_based_rate_limit_token_bucket_{}; + + friend class LocalRateLimiterImplTest; ++ friend class TimerTokenBucket; + }; + + } // namespace LocalRateLimit + } // namespace Common + } // namespace Filters + } // namespace Extensions +-} // namespace Envoy ++} // namespace Envoy +\ No newline at end of file +diff --git a/source/extensions/filters/http/local_ratelimit/BUILD b/source/extensions/filters/http/local_ratelimit/BUILD +index 88d4042e98..d4ed3daa7e 100644 +--- source/extensions/filters/http/local_ratelimit/BUILD ++++ source/extensions/filters/http/local_ratelimit/BUILD +@@ -45,4 +45,4 @@ envoy_cc_extension( + "//source/extensions/filters/http/common:factory_base_lib", + "@envoy_api//envoy/extensions/filters/http/local_ratelimit/v3:pkg_cc_proto", + ], +-) ++) +\ No newline at end of file +diff --git a/source/extensions/filters/http/local_ratelimit/config.cc b/source/extensions/filters/http/local_ratelimit/config.cc +index cbf719cae4..f1e3951c4f 100644 +--- source/extensions/filters/http/local_ratelimit/config.cc ++++ source/extensions/filters/http/local_ratelimit/config.cc +@@ -19,7 +19,8 @@ Http::FilterFactoryCb LocalRateLimitFilterConfig::createFilterFactoryFromProtoTy + + FilterConfigSharedPtr filter_config = std::make_shared( + proto_config, server_context.localInfo(), server_context.mainThreadDispatcher(), +- context.scope(), server_context.runtime()); ++ server_context.clusterManager(), server_context.singletonManager(), context.scope(), ++ server_context.runtime()); + return [filter_config](Http::FilterChainFactoryCallbacks& callbacks) -> void { + 
callbacks.addStreamFilter(std::make_shared(filter_config)); + }; +@@ -29,9 +30,9 @@ Router::RouteSpecificFilterConfigConstSharedPtr + LocalRateLimitFilterConfig::createRouteSpecificFilterConfigTyped( + const envoy::extensions::filters::http::local_ratelimit::v3::LocalRateLimit& proto_config, + Server::Configuration::ServerFactoryContext& context, ProtobufMessage::ValidationVisitor&) { +- return std::make_shared(proto_config, context.localInfo(), +- context.mainThreadDispatcher(), context.scope(), +- context.runtime(), true); ++ return std::make_shared( ++ proto_config, context.localInfo(), context.mainThreadDispatcher(), context.clusterManager(), ++ context.singletonManager(), context.scope(), context.runtime(), true); + } + + /** +@@ -44,4 +45,4 @@ LEGACY_REGISTER_FACTORY(LocalRateLimitFilterConfig, + } // namespace LocalRateLimitFilter + } // namespace HttpFilters + } // namespace Extensions +-} // namespace Envoy ++} // namespace Envoy +\ No newline at end of file +diff --git a/source/extensions/filters/http/local_ratelimit/config.h b/source/extensions/filters/http/local_ratelimit/config.h +index 96ae7caa5b..b009c9bcdb 100644 +--- source/extensions/filters/http/local_ratelimit/config.h ++++ source/extensions/filters/http/local_ratelimit/config.h +@@ -31,4 +31,4 @@ private: + } // namespace LocalRateLimitFilter + } // namespace HttpFilters + } // namespace Extensions +-} // namespace Envoy ++} // namespace Envoy +\ No newline at end of file +diff --git a/source/extensions/filters/http/local_ratelimit/local_ratelimit.cc b/source/extensions/filters/http/local_ratelimit/local_ratelimit.cc +index 6938c8b1d6..ad9ba13720 100644 +--- source/extensions/filters/http/local_ratelimit/local_ratelimit.cc ++++ source/extensions/filters/http/local_ratelimit/local_ratelimit.cc +@@ -23,7 +23,8 @@ const std::string& PerConnectionRateLimiter::key() { + + FilterConfig::FilterConfig( + const envoy::extensions::filters::http::local_ratelimit::v3::LocalRateLimit& config, +- const 
LocalInfo::LocalInfo& local_info, Event::Dispatcher& dispatcher, Stats::Scope& scope, ++ const LocalInfo::LocalInfo& local_info, Event::Dispatcher& dispatcher, ++ Upstream::ClusterManager& cm, Singleton::Manager& singleton_manager, Stats::Scope& scope, + Runtime::Loader& runtime, const bool per_route) + : dispatcher_(dispatcher), status_(toErrorCode(config.status().code())), + stats_(generateStats(config.stat_prefix(), scope)), +@@ -37,9 +38,6 @@ FilterConfig::FilterConfig( + config.has_always_consume_default_token_bucket() + ? config.always_consume_default_token_bucket().value() + : true), +- rate_limiter_(new Filters::Common::LocalRateLimit::LocalRateLimiterImpl( +- fill_interval_, max_tokens_, tokens_per_fill_, dispatcher, descriptors_, +- always_consume_default_token_bucket_)), + local_info_(local_info), runtime_(runtime), + filter_enabled_( + config.has_filter_enabled() +@@ -51,10 +49,12 @@ FilterConfig::FilterConfig( + ? absl::optional( + Envoy::Runtime::FractionalPercent(config.filter_enforced(), runtime_)) + : absl::nullopt), +- response_headers_parser_( +- Envoy::Router::HeaderParser::configure(config.response_headers_to_add())), +- request_headers_parser_(Envoy::Router::HeaderParser::configure( +- config.request_headers_to_add_when_not_enforced())), ++ response_headers_parser_(THROW_OR_RETURN_VALUE( ++ Envoy::Router::HeaderParser::configure(config.response_headers_to_add()), ++ Router::HeaderParserPtr)), ++ request_headers_parser_(THROW_OR_RETURN_VALUE( ++ Envoy::Router::HeaderParser::configure(config.request_headers_to_add_when_not_enforced()), ++ Router::HeaderParserPtr)), + stage_(static_cast(config.stage())), + has_descriptors_(!config.descriptors().empty()), + enable_x_rate_limit_headers_(config.enable_x_ratelimit_headers() == +@@ -72,26 +72,36 @@ FilterConfig::FilterConfig( + if (per_route && !config.has_token_bucket()) { + throw EnvoyException("local rate limit token bucket must be set for per filter configs"); + } +-} + +-bool 
FilterConfig::requestAllowed( +- absl::Span request_descriptors) const { +- return rate_limiter_->requestAllowed(request_descriptors); +-} ++ Filters::Common::LocalRateLimit::ShareProviderSharedPtr share_provider; ++ if (config.has_local_cluster_rate_limit()) { ++ if (rate_limit_per_connection_) { ++ throw EnvoyException("local_cluster_rate_limit is set and " ++ "local_rate_limit_per_downstream_connection is set to true"); ++ } ++ if (!cm.localClusterName().has_value()) { ++ throw EnvoyException("local_cluster_rate_limit is set but no local cluster name is present"); ++ } + +-uint32_t +-FilterConfig::maxTokens(absl::Span request_descriptors) const { +- return rate_limiter_->maxTokens(request_descriptors); +-} ++ // If the local cluster name is set then the relevant cluster must exist or the cluster ++ // manager will fail to initialize. ++ share_provider_manager_ = Filters::Common::LocalRateLimit::ShareProviderManager::singleton( ++ dispatcher, cm, singleton_manager); ++ if (!share_provider_manager_) { ++ throw EnvoyException("local_cluster_rate_limit is set but no local cluster is present"); ++ } + +-uint32_t FilterConfig::remainingTokens( +- absl::Span request_descriptors) const { +- return rate_limiter_->remainingTokens(request_descriptors); ++ share_provider = share_provider_manager_->getShareProvider(config.local_cluster_rate_limit()); ++ } ++ ++ rate_limiter_ = std::make_unique( ++ fill_interval_, max_tokens_, tokens_per_fill_, dispatcher, descriptors_, ++ always_consume_default_token_bucket_, std::move(share_provider)); + } + +-int64_t FilterConfig::remainingFillInterval( ++Filters::Common::LocalRateLimit::LocalRateLimiterImpl::Result FilterConfig::requestAllowed( + absl::Span request_descriptors) const { +- return rate_limiter_->remainingFillInterval(request_descriptors); ++ return rate_limiter_->requestAllowed(request_descriptors); + } + + LocalRateLimitStats FilterConfig::generateStats(const std::string& prefix, Stats::Scope& scope) { +@@ -108,107 +118,93 
@@ bool FilterConfig::enforced() const { + } + + Http::FilterHeadersStatus Filter::decodeHeaders(Http::RequestHeaderMap& headers, bool) { +- const auto* config = getConfig(); ++ const auto* route_config = ++ Http::Utility::resolveMostSpecificPerFilterConfig(decoder_callbacks_); ++ ++ // We can never assume that the configuration/route will not change between ++ // decodeHeaders() and encodeHeaders(). ++ // Store the configuration because we will use it later in encodeHeaders(). ++ if (route_config != nullptr) { ++ ASSERT(used_config_ == config_.get()); ++ used_config_ = route_config; // Overwrite the used configuration. ++ } + +- if (!config->enabled()) { ++ if (!used_config_->enabled()) { + return Http::FilterHeadersStatus::Continue; + } + +- config->stats().enabled_.inc(); ++ used_config_->stats().enabled_.inc(); + + std::vector descriptors; +- if (config->hasDescriptors()) { ++ if (used_config_->hasDescriptors()) { + populateDescriptors(descriptors, headers); + } + +- // Store descriptors which is used to generate x-ratelimit-* headers in encoding response headers. +- stored_descriptors_ = descriptors; +- + if (ENVOY_LOG_CHECK_LEVEL(debug)) { + for (const auto& request_descriptor : descriptors) { +- for (const Envoy::RateLimit::DescriptorEntry& entry : request_descriptor.entries_) { +- ENVOY_LOG(debug, "populate descriptors: key={} value={}", entry.key_, entry.value_); +- } ++ ENVOY_LOG(debug, "populate descriptor: {}", request_descriptor.toString()); + } + } + +- if (requestAllowed(descriptors)) { +- config->stats().ok_.inc(); ++ auto result = requestAllowed(descriptors); ++ // The global limiter, route limiter, or connection level limiter are all have longer life ++ // than the request, so we can safely store the token bucket context reference. 
++ token_bucket_context_ = result.token_bucket_context; ++ ++ if (result.allowed) { ++ used_config_->stats().ok_.inc(); + return Http::FilterHeadersStatus::Continue; + } + +- config->stats().rate_limited_.inc(); ++ used_config_->stats().rate_limited_.inc(); + +- if (!config->enforced()) { +- config->requestHeadersParser().evaluateHeaders(headers, decoder_callbacks_->streamInfo()); ++ if (!used_config_->enforced()) { ++ used_config_->requestHeadersParser().evaluateHeaders(headers, decoder_callbacks_->streamInfo()); + return Http::FilterHeadersStatus::Continue; + } + +- config->stats().enforced_.inc(); ++ used_config_->stats().enforced_.inc(); + + decoder_callbacks_->sendLocalReply( +- config->status(), "local_rate_limited", +- [this, config](Http::HeaderMap& headers) { +- config->responseHeadersParser().evaluateHeaders(headers, decoder_callbacks_->streamInfo()); ++ used_config_->status(), "local_rate_limited", ++ [this](Http::HeaderMap& headers) { ++ used_config_->responseHeadersParser().evaluateHeaders(headers, ++ decoder_callbacks_->streamInfo()); + }, +- config->rateLimitedGrpcStatus(), "local_rate_limited"); ++ used_config_->rateLimitedGrpcStatus(), "local_rate_limited"); + decoder_callbacks_->streamInfo().setResponseFlag(StreamInfo::CoreResponseFlag::RateLimited); + + return Http::FilterHeadersStatus::StopIteration; + } + + Http::FilterHeadersStatus Filter::encodeHeaders(Http::ResponseHeaderMap& headers, bool) { +- const auto* config = getConfig(); +- +- if (config->enabled() && config->enableXRateLimitHeaders()) { +- ASSERT(stored_descriptors_.has_value()); +- auto limit = maxTokens(stored_descriptors_.value()); +- auto remaining = remainingTokens(stored_descriptors_.value()); +- auto reset = remainingFillInterval(stored_descriptors_.value()); +- +- headers.addReferenceKey( +- HttpFilters::Common::RateLimit::XRateLimitHeaders::get().XRateLimitLimit, limit); ++ // We can never assume the decodeHeaders() was called before encodeHeaders(). 
++ if (used_config_->enableXRateLimitHeaders() && token_bucket_context_.has_value()) { + headers.addReferenceKey( +- HttpFilters::Common::RateLimit::XRateLimitHeaders::get().XRateLimitRemaining, remaining); ++ HttpFilters::Common::RateLimit::XRateLimitHeaders::get().XRateLimitLimit, ++ token_bucket_context_->maxTokens()); + headers.addReferenceKey( +- HttpFilters::Common::RateLimit::XRateLimitHeaders::get().XRateLimitReset, reset); ++ HttpFilters::Common::RateLimit::XRateLimitHeaders::get().XRateLimitRemaining, ++ token_bucket_context_->remainingTokens()); ++ const auto reset = token_bucket_context_->remainingFillInterval(); ++ if (reset.has_value()) { ++ headers.addReferenceKey( ++ HttpFilters::Common::RateLimit::XRateLimitHeaders::get().XRateLimitReset, reset.value()); ++ } + } + + return Http::FilterHeadersStatus::Continue; + } + +-bool Filter::requestAllowed(absl::Span request_descriptors) { +- const auto* config = getConfig(); +- return config->rateLimitPerConnection() ++Filters::Common::LocalRateLimit::LocalRateLimiterImpl::Result ++Filter::requestAllowed(absl::Span request_descriptors) { ++ return used_config_->rateLimitPerConnection() + ? getPerConnectionRateLimiter().requestAllowed(request_descriptors) +- : config->requestAllowed(request_descriptors); +-} +- +-uint32_t Filter::maxTokens(absl::Span request_descriptors) { +- const auto* config = getConfig(); +- return config->rateLimitPerConnection() +- ? getPerConnectionRateLimiter().maxTokens(request_descriptors) +- : config->maxTokens(request_descriptors); +-} +- +-uint32_t Filter::remainingTokens(absl::Span request_descriptors) { +- const auto* config = getConfig(); +- return config->rateLimitPerConnection() +- ? getPerConnectionRateLimiter().remainingTokens(request_descriptors) +- : config->remainingTokens(request_descriptors); +-} +- +-int64_t +-Filter::remainingFillInterval(absl::Span request_descriptors) { +- const auto* config = getConfig(); +- return config->rateLimitPerConnection() +- ? 
getPerConnectionRateLimiter().remainingFillInterval(request_descriptors) +- : config->remainingFillInterval(request_descriptors); ++ : used_config_->requestAllowed(request_descriptors); + } + + const Filters::Common::LocalRateLimit::LocalRateLimiterImpl& Filter::getPerConnectionRateLimiter() { +- const auto* config = getConfig(); +- ASSERT(config->rateLimitPerConnection()); ++ ASSERT(used_config_->rateLimitPerConnection()); + + auto typed_state = + decoder_callbacks_->streamInfo().filterState()->getDataReadOnly( +@@ -216,9 +212,9 @@ const Filters::Common::LocalRateLimit::LocalRateLimiterImpl& Filter::getPerConne + + if (typed_state == nullptr) { + auto limiter = std::make_shared( +- config->fillInterval(), config->maxTokens(), config->tokensPerFill(), +- decoder_callbacks_->dispatcher(), config->descriptors(), +- config->consumeDefaultTokenBucket()); ++ used_config_->fillInterval(), used_config_->maxTokens(), used_config_->tokensPerFill(), ++ decoder_callbacks_->dispatcher(), used_config_->descriptors(), ++ used_config_->consumeDefaultTokenBucket()); + + decoder_callbacks_->streamInfo().filterState()->setData( + PerConnectionRateLimiter::key(), limiter, StreamInfo::FilterState::StateType::ReadOnly, +@@ -259,35 +255,23 @@ void Filter::populateDescriptors(std::vector& descri + void Filter::populateDescriptors(const Router::RateLimitPolicy& rate_limit_policy, + std::vector& descriptors, + Http::RequestHeaderMap& headers) { +- const auto* config = getConfig(); + for (const Router::RateLimitPolicyEntry& rate_limit : +- rate_limit_policy.getApplicableRateLimit(config->stage())) { ++ rate_limit_policy.getApplicableRateLimit(used_config_->stage())) { + const std::string& disable_key = rate_limit.disableKey(); + + if (!disable_key.empty()) { + continue; + } +- rate_limit.populateLocalDescriptors(descriptors, config->localInfo().clusterName(), headers, +- decoder_callbacks_->streamInfo()); +- } +-} +- +-const FilterConfig* Filter::getConfig() const { +- const auto* config = 
+- Http::Utility::resolveMostSpecificPerFilterConfig(decoder_callbacks_); +- if (config) { +- return config; ++ rate_limit.populateLocalDescriptors(descriptors, used_config_->localInfo().clusterName(), ++ headers, decoder_callbacks_->streamInfo()); + } +- +- return config_.get(); + } + + VhRateLimitOptions Filter::getVirtualHostRateLimitOption(const Router::RouteConstSharedPtr& route) { + if (route->routeEntry()->includeVirtualHostRateLimits()) { + vh_rate_limits_ = VhRateLimitOptions::Include; + } else { +- const auto* config = getConfig(); +- switch (config->virtualHostRateLimits()) { ++ switch (used_config_->virtualHostRateLimits()) { + PANIC_ON_PROTO_ENUM_SENTINEL_VALUES; + case envoy::extensions::common::ratelimit::v3::INCLUDE: + vh_rate_limits_ = VhRateLimitOptions::Include; +@@ -306,4 +290,4 @@ VhRateLimitOptions Filter::getVirtualHostRateLimitOption(const Router::RouteCons + } // namespace LocalRateLimitFilter + } // namespace HttpFilters + } // namespace Extensions +-} // namespace Envoy ++} // namespace Envoy +\ No newline at end of file +diff --git a/source/extensions/filters/http/local_ratelimit/local_ratelimit.h b/source/extensions/filters/http/local_ratelimit/local_ratelimit.h +index e816da64e3..9788eba17c 100644 +--- source/extensions/filters/http/local_ratelimit/local_ratelimit.h ++++ source/extensions/filters/http/local_ratelimit/local_ratelimit.h +@@ -73,6 +73,7 @@ class FilterConfig : public Router::RouteSpecificFilterConfig { + public: + FilterConfig(const envoy::extensions::filters::http::local_ratelimit::v3::LocalRateLimit& config, + const LocalInfo::LocalInfo& local_info, Event::Dispatcher& dispatcher, ++ Upstream::ClusterManager& cm, Singleton::Manager& singleton_manager, + Stats::Scope& scope, Runtime::Loader& runtime, bool per_route = false); + ~FilterConfig() override { + // Ensure that the LocalRateLimiterImpl instance will be destroyed on the thread where its inner +@@ -84,11 +85,8 @@ public: + } + const LocalInfo::LocalInfo& 
localInfo() const { return local_info_; } + Runtime::Loader& runtime() { return runtime_; } +- bool requestAllowed(absl::Span request_descriptors) const; +- uint32_t maxTokens(absl::Span request_descriptors) const; +- uint32_t remainingTokens(absl::Span request_descriptors) const; +- int64_t +- remainingFillInterval(absl::Span request_descriptors) const; ++ Filters::Common::LocalRateLimit::LocalRateLimiterImpl::Result ++ requestAllowed(absl::Span request_descriptors) const; + bool enabled() const; + bool enforced() const; + LocalRateLimitStats& stats() const { return stats_; } +@@ -139,6 +137,7 @@ private: + descriptors_; + const bool rate_limit_per_connection_; + const bool always_consume_default_token_bucket_{}; ++ Filters::Common::LocalRateLimit::ShareProviderManagerSharedPtr share_provider_manager_; + std::unique_ptr rate_limiter_; + const LocalInfo::LocalInfo& local_info_; + Runtime::Loader& runtime_; +@@ -161,7 +160,7 @@ using FilterConfigSharedPtr = std::shared_ptr; + */ + class Filter : public Http::PassThroughFilter, Logger::Loggable { + public: +- Filter(FilterConfigSharedPtr config) : config_(config) {} ++ Filter(FilterConfigSharedPtr config) : config_(config), used_config_(config_.get()) {} + + // Http::StreamDecoderFilter + Http::FilterHeadersStatus decodeHeaders(Http::RequestHeaderMap& headers, +@@ -181,19 +180,19 @@ private: + Http::RequestHeaderMap& headers); + VhRateLimitOptions getVirtualHostRateLimitOption(const Router::RouteConstSharedPtr& route); + const Filters::Common::LocalRateLimit::LocalRateLimiterImpl& getPerConnectionRateLimiter(); +- bool requestAllowed(absl::Span request_descriptors); +- uint32_t maxTokens(absl::Span request_descriptors); +- uint32_t remainingTokens(absl::Span request_descriptors); +- int64_t remainingFillInterval(absl::Span request_descriptors); ++ Filters::Common::LocalRateLimit::LocalRateLimiterImpl::Result ++ requestAllowed(absl::Span request_descriptors); + +- const FilterConfig* getConfig() const; + 
FilterConfigSharedPtr config_; ++ // Actual config used for the current request. Is config_ by default, but can be overridden by ++ // per-route config. ++ const FilterConfig* used_config_{}; ++ OptRef token_bucket_context_; + +- absl::optional> stored_descriptors_; + VhRateLimitOptions vh_rate_limits_; + }; + + } // namespace LocalRateLimitFilter + } // namespace HttpFilters + } // namespace Extensions +-} // namespace Envoy ++} // namespace Envoy +\ No newline at end of file +diff --git a/source/extensions/filters/listener/local_ratelimit/local_ratelimit.cc b/source/extensions/filters/listener/local_ratelimit/local_ratelimit.cc +index 7aa4561b80..85f810ad40 100644 +--- source/extensions/filters/listener/local_ratelimit/local_ratelimit.cc ++++ source/extensions/filters/listener/local_ratelimit/local_ratelimit.cc +@@ -21,7 +21,7 @@ FilterConfig::FilterConfig( + Protobuf::RepeatedPtrField< + envoy::extensions::common::ratelimit::v3::LocalRateLimitDescriptor>()) {} + +-bool FilterConfig::canCreateConnection() { return rate_limiter_.requestAllowed({}); } ++bool FilterConfig::canCreateConnection() { return rate_limiter_.requestAllowed({}).allowed; } + + LocalRateLimitStats FilterConfig::generateStats(const std::string& prefix, Stats::Scope& scope) { + const std::string final_prefix = "listener_local_ratelimit." 
+ prefix; +diff --git a/source/extensions/filters/network/local_ratelimit/local_ratelimit.cc b/source/extensions/filters/network/local_ratelimit/local_ratelimit.cc +index 8c6af4c820..09594692dd 100644 +--- source/extensions/filters/network/local_ratelimit/local_ratelimit.cc ++++ source/extensions/filters/network/local_ratelimit/local_ratelimit.cc +@@ -94,7 +94,7 @@ LocalRateLimitStats Config::generateStats(const std::string& prefix, Stats::Scop + return {ALL_LOCAL_RATE_LIMIT_STATS(POOL_COUNTER_PREFIX(scope, final_prefix))}; + } + +-bool Config::canCreateConnection() { return rate_limiter_->requestAllowed({}); } ++bool Config::canCreateConnection() { return rate_limiter_->requestAllowed({}).allowed; } + + Network::FilterStatus Filter::onNewConnection() { + if (!config_->enabled()) { +diff --git a/source/server/config_validation/server.cc b/source/server/config_validation/server.cc +index 03bbd16fae..569f677341 100644 +--- source/server/config_validation/server.cc ++++ source/server/config_validation/server.cc +@@ -11,7 +11,6 @@ + #include "source/common/listener_manager/listener_info_impl.h" + #include "source/common/local_info/local_info_impl.h" + #include "source/common/protobuf/utility.h" +-#include "source/common/singleton/manager_impl.h" + #include "source/common/stats/tag_producer_impl.h" + #include "source/common/tls/context_manager_impl.h" + #include "source/common/version/version.h" +@@ -59,7 +58,6 @@ ValidationInstance::ValidationInstance( + api_(new Api::ValidationImpl(thread_factory, store, time_system, file_system, + random_generator_, bootstrap_, process_context)), + dispatcher_(api_->allocateDispatcher("main_thread")), +- singleton_manager_(new Singleton::ManagerImpl(api_->threadFactory())), + access_log_manager_(options.fileFlushIntervalMsec(), *api_, *dispatcher_, access_log_lock, + store), + grpc_context_(stats_store_.symbolTable()), http_context_(stats_store_.symbolTable()), +diff --git a/source/server/config_validation/server.h 
b/source/server/config_validation/server.h +index 777b5a1ff6..af367feeb6 100644 +--- source/server/config_validation/server.h ++++ source/server/config_validation/server.h +@@ -23,6 +23,7 @@ + #include "source/common/router/rds_impl.h" + #include "source/common/runtime/runtime_impl.h" + #include "source/common/secret/secret_manager_impl.h" ++#include "source/common/singleton/manager_impl.h" + #include "source/common/thread_local/thread_local_impl.h" + #include "source/server/config_validation/admin.h" + #include "source/server/config_validation/api.h" +@@ -100,7 +101,7 @@ public: + void shutdown() override; + bool isShutdown() override { return false; } + void shutdownAdmin() override {} +- Singleton::Manager& singletonManager() override { return *singleton_manager_; } ++ Singleton::Manager& singletonManager() override { return singleton_manager_; } + OverloadManager& overloadManager() override { return *overload_manager_; } + bool healthCheckFailed() override { return false; } + const Options& options() override { return options_; } +@@ -175,7 +176,7 @@ private: + std::unique_ptr ssl_context_manager_; + Event::DispatcherPtr dispatcher_; + std::unique_ptr admin_; +- Singleton::ManagerPtr singleton_manager_; ++ Singleton::ManagerImpl singleton_manager_; + std::unique_ptr runtime_; + Random::RandomGeneratorImpl random_generator_; + Configuration::MainImpl config_; +diff --git a/source/server/server.cc b/source/server/server.cc +index 254cbc0860..6b14538680 100644 +--- source/server/server.cc ++++ source/server/server.cc +@@ -45,7 +45,6 @@ + #include "source/common/runtime/runtime_impl.h" + #include "source/common/runtime/runtime_keys.h" + #include "source/common/signal/fatal_error_handler.h" +-#include "source/common/singleton/manager_impl.h" + #include "source/common/stats/stats_matcher_impl.h" + #include "source/common/stats/tag_producer_impl.h" + #include "source/common/stats/thread_local_store.h" +@@ -99,7 +98,6 @@ InstanceBase::InstanceBase(Init::Manager& 
init_manager, const Options& options, + dispatcher_(api_->allocateDispatcher("main_thread")), + access_log_manager_(options.fileFlushIntervalMsec(), *api_, *dispatcher_, access_log_lock, + store), +- singleton_manager_(new Singleton::ManagerImpl(api_->threadFactory())), + handler_(getHandler(*dispatcher_)), worker_factory_(thread_local_, *api_, hooks), + mutex_tracer_(options.mutexTracingEnabled() ? &Envoy::MutexTracerImpl::getOrCreateTracer() + : nullptr), +diff --git a/source/server/server.h b/source/server/server.h +index ae9b1ea95b..3060531879 100644 +--- source/server/server.h ++++ source/server/server.h +@@ -37,6 +37,7 @@ + #include "source/common/router/context_impl.h" + #include "source/common/runtime/runtime_impl.h" + #include "source/common/secret/secret_manager_impl.h" ++#include "source/common/singleton/manager_impl.h" + #include "source/common/upstream/health_discovery_service.h" + + #ifdef ENVOY_ADMIN_FUNCTIONALITY +@@ -278,7 +279,7 @@ public: + void shutdown() override; + bool isShutdown() final { return shutdown_; } + void shutdownAdmin() override; +- Singleton::Manager& singletonManager() override { return *singleton_manager_; } ++ Singleton::Manager& singletonManager() override { return singleton_manager_; } + bool healthCheckFailed() override; + const Options& options() override { return options_; } + time_t startTimeCurrentEpoch() override { return start_time_; } +@@ -379,8 +380,8 @@ private: + std::unique_ptr ssl_context_manager_; + Event::DispatcherPtr dispatcher_; + AccessLog::AccessLogManagerImpl access_log_manager_; +- std::unique_ptr admin_; +- Singleton::ManagerPtr singleton_manager_; ++ std::shared_ptr admin_; ++ Singleton::ManagerImpl singleton_manager_; + Network::ConnectionHandlerPtr handler_; + std::unique_ptr runtime_; + ProdWorkerFactory worker_factory_; +diff --git a/test/common/common/BUILD b/test/common/common/BUILD +index ede00f4bc1..5a2be7a86e 100644 +--- test/common/common/BUILD ++++ test/common/common/BUILD +@@ -335,6 +335,7 
@@ envoy_cc_test( + deps = [ + "//source/common/common:token_bucket_impl_lib", + "//test/test_common:simulated_time_system_lib", ++ "//test/test_common:test_time_lib", + "//test/test_common:utility_lib", + ], + ) +diff --git a/test/common/common/token_bucket_impl_test.cc b/test/common/common/token_bucket_impl_test.cc +index 66308cfded..ce20de2cfb 100644 +--- test/common/common/token_bucket_impl_test.cc ++++ test/common/common/token_bucket_impl_test.cc +@@ -3,6 +3,7 @@ + #include "source/common/common/token_bucket_impl.h" + + #include "test/test_common/simulated_time_system.h" ++#include "test/test_common/test_time.h" + + #include "gtest/gtest.h" + +@@ -126,4 +127,172 @@ TEST_F(TokenBucketImplTest, YearlyMinRefillRate) { + EXPECT_EQ(1, token_bucket.consume(1, false)); + } + ++class AtomicTokenBucketImplTest : public testing::Test { ++protected: ++ Event::SimulatedTimeSystem time_system_; ++}; ++ ++// Verifies TokenBucket initialization. ++TEST_F(AtomicTokenBucketImplTest, Initialization) { ++ AtomicTokenBucketImpl token_bucket{1, time_system_, -1.0}; ++ ++ EXPECT_EQ(1, token_bucket.fillRate()); ++ EXPECT_EQ(1, token_bucket.maxTokens()); ++ EXPECT_EQ(1, token_bucket.remainingTokens()); ++ ++ EXPECT_EQ(1, token_bucket.consume(1, false)); ++ EXPECT_EQ(0, token_bucket.consume(1, false)); ++ EXPECT_EQ(false, token_bucket.consume()); ++} ++ ++// Verifies TokenBucket's maximum capacity. ++TEST_F(AtomicTokenBucketImplTest, MaxBucketSize) { ++ AtomicTokenBucketImpl token_bucket{3, time_system_, 1}; ++ ++ EXPECT_EQ(1, token_bucket.fillRate()); ++ EXPECT_EQ(3, token_bucket.maxTokens()); ++ EXPECT_EQ(3, token_bucket.remainingTokens()); ++ ++ EXPECT_EQ(3, token_bucket.consume(3, false)); ++ time_system_.setMonotonicTime(std::chrono::seconds(10)); ++ EXPECT_EQ(0, token_bucket.consume(4, false)); ++ EXPECT_EQ(3, token_bucket.consume(3, false)); ++} ++ ++// Verifies that TokenBucket can consume tokens. 
++TEST_F(AtomicTokenBucketImplTest, Consume) { ++ AtomicTokenBucketImpl token_bucket{10, time_system_, 1}; ++ ++ EXPECT_EQ(0, token_bucket.consume(20, false)); ++ EXPECT_EQ(9, token_bucket.consume(9, false)); ++ ++ // consume() == consume(1, false) ++ EXPECT_EQ(true, token_bucket.consume()); ++ ++ time_system_.setMonotonicTime(std::chrono::milliseconds(999)); ++ EXPECT_EQ(0, token_bucket.consume(1, false)); ++ ++ time_system_.setMonotonicTime(std::chrono::milliseconds(5999)); ++ EXPECT_EQ(0, token_bucket.consume(6, false)); ++ ++ time_system_.setMonotonicTime(std::chrono::milliseconds(6000)); ++ EXPECT_EQ(6, token_bucket.consume(6, false)); ++ EXPECT_EQ(0, token_bucket.consume(1, false)); ++} ++ ++// Verifies that TokenBucket can refill tokens. ++TEST_F(AtomicTokenBucketImplTest, Refill) { ++ AtomicTokenBucketImpl token_bucket{1, time_system_, 0.5}; ++ EXPECT_EQ(1, token_bucket.consume(1, false)); ++ ++ time_system_.setMonotonicTime(std::chrono::milliseconds(500)); ++ EXPECT_EQ(0, token_bucket.consume(1, false)); ++ time_system_.setMonotonicTime(std::chrono::milliseconds(1500)); ++ EXPECT_EQ(0, token_bucket.consume(1, false)); ++ time_system_.setMonotonicTime(std::chrono::milliseconds(2000)); ++ EXPECT_EQ(1, token_bucket.consume(1, false)); ++} ++ ++// Test partial consumption of tokens. ++TEST_F(AtomicTokenBucketImplTest, PartialConsumption) { ++ AtomicTokenBucketImpl token_bucket{16, time_system_, 16}; ++ EXPECT_EQ(16, token_bucket.consume(18, true)); ++ time_system_.advanceTimeWait(std::chrono::milliseconds(62)); ++ EXPECT_EQ(0, token_bucket.consume(1, true)); ++ time_system_.advanceTimeWait(std::chrono::milliseconds(1)); ++ EXPECT_EQ(1, token_bucket.consume(2, true)); ++} ++ ++// Validate that a minimal refresh time is 1 year. ++TEST_F(AtomicTokenBucketImplTest, YearlyMinRefillRate) { ++ constexpr uint64_t seconds_per_year = 365 * 24 * 60 * 60; ++ // Set the fill rate to be 2 years. 
++ AtomicTokenBucketImpl token_bucket{1, time_system_, 1.0 / (seconds_per_year * 2)}; ++ ++ // Consume first token. ++ EXPECT_EQ(1, token_bucket.consume(1, false)); ++ ++ // Less than a year should still have no tokens. ++ time_system_.setMonotonicTime(std::chrono::seconds(seconds_per_year - 1)); ++ EXPECT_EQ(0, token_bucket.consume(1, false)); ++ time_system_.setMonotonicTime(std::chrono::seconds(seconds_per_year)); ++ EXPECT_EQ(1, token_bucket.consume(1, false)); ++} ++ ++TEST_F(AtomicTokenBucketImplTest, ConsumeNegativeTokens) { ++ AtomicTokenBucketImpl token_bucket{10, time_system_, 1}; ++ ++ EXPECT_EQ(3, token_bucket.consume([](double) { return 3; })); ++ EXPECT_EQ(7, token_bucket.remainingTokens()); ++ EXPECT_EQ(-3, token_bucket.consume([](double) { return -3; })); ++ EXPECT_EQ(10, token_bucket.remainingTokens()); ++} ++ ++TEST_F(AtomicTokenBucketImplTest, ConsumeSuperLargeTokens) { ++ AtomicTokenBucketImpl token_bucket{10, time_system_, 1}; ++ ++ EXPECT_EQ(100, token_bucket.consume([](double) { return 100; })); ++ EXPECT_EQ(-90, token_bucket.remainingTokens()); ++} ++ ++TEST_F(AtomicTokenBucketImplTest, MultipleThreadsConsume) { ++ // Real time source to ensure we will not fall into endless loop. ++ Event::TestRealTimeSystem real_time_source; ++ ++ AtomicTokenBucketImpl token_bucket{1200, time_system_, 1.0}; ++ ++ // Exhaust all tokens. 
++ EXPECT_EQ(1200, token_bucket.consume(1200, false)); ++ EXPECT_EQ(0, token_bucket.consume(1, false)); ++ ++ std::vector threads; ++ auto timeout_point = real_time_source.monotonicTime() + std::chrono::seconds(30); ++ ++ size_t thread_1_token = 0; ++ threads.push_back(std::thread([&] { ++ while (thread_1_token < 300 && real_time_source.monotonicTime() < timeout_point) { ++ thread_1_token += token_bucket.consume(1, false); ++ } ++ })); ++ ++ size_t thread_2_token = 0; ++ threads.push_back(std::thread([&] { ++ while (thread_2_token < 300 && real_time_source.monotonicTime() < timeout_point) { ++ thread_2_token += token_bucket.consume(1, false); ++ } ++ })); ++ ++ size_t thread_3_token = 0; ++ threads.push_back(std::thread([&] { ++ while (thread_3_token < 300 && real_time_source.monotonicTime() < timeout_point) { ++ const size_t left = 300 - thread_3_token; ++ thread_3_token += token_bucket.consume(std::min(left, 2), true); ++ } ++ })); ++ ++ size_t thread_4_token = 0; ++ threads.push_back(std::thread([&] { ++ while (thread_4_token < 300 && real_time_source.monotonicTime() < timeout_point) { ++ const size_t left = 300 - thread_4_token; ++ thread_4_token += token_bucket.consume(std::min(left, 3), true); ++ } ++ })); ++ ++ // Fill the buckets by changing the time. 
++ for (size_t i = 0; i < 200; i++) { ++ time_system_.advanceTimeWait(std::chrono::seconds(1)); ++ } ++ for (size_t i = 0; i < 100; i++) { ++ time_system_.advanceTimeWait(std::chrono::seconds(10)); ++ } ++ ++ for (auto& thread : threads) { ++ thread.join(); ++ } ++ ++ EXPECT_EQ(1200, thread_1_token + thread_2_token + thread_3_token + thread_4_token); ++ ++ EXPECT_EQ(0, token_bucket.consume(1, false)); ++} ++ + } // namespace Envoy +diff --git a/test/common/http/http_server_properties_cache_manager_test.cc b/test/common/http/http_server_properties_cache_manager_test.cc +index f6b8fbdc31..0d2fe3d44f 100644 +--- test/common/http/http_server_properties_cache_manager_test.cc ++++ test/common/http/http_server_properties_cache_manager_test.cc +@@ -30,7 +30,7 @@ public: + manager_ = factory_->get(); + } + +- Singleton::ManagerImpl singleton_manager_{Thread::threadFactoryForTest()}; ++ Singleton::ManagerImpl singleton_manager_; + NiceMock context_; + testing::NiceMock tls_; + std::unique_ptr factory_; +diff --git a/test/common/singleton/manager_impl_test.cc b/test/common/singleton/manager_impl_test.cc +index d1a8513fbd..0a943cfb6a 100644 +--- test/common/singleton/manager_impl_test.cc ++++ test/common/singleton/manager_impl_test.cc +@@ -12,7 +12,7 @@ namespace { + + // Must be a dedicated function so that TID is within the death test. 
+ static void deathTestWorker() { +- ManagerImpl manager(Thread::threadFactoryForTest()); ++ ManagerImpl manager; + + manager.get( + "foo", [] { return nullptr; }, false); +@@ -39,7 +39,7 @@ TEST(SingletonRegistration, category) { + } + + TEST(SingletonManagerImplTest, Basic) { +- ManagerImpl manager(Thread::threadFactoryForTest()); ++ ManagerImpl manager; + + std::shared_ptr singleton = std::make_shared(); + EXPECT_EQ(singleton, manager.get( +@@ -53,7 +53,7 @@ TEST(SingletonManagerImplTest, Basic) { + } + + TEST(SingletonManagerImplTest, NonConstructingGetTyped) { +- ManagerImpl manager(Thread::threadFactoryForTest()); ++ ManagerImpl manager; + + // Access without first constructing should be null. + EXPECT_EQ(nullptr, manager.getTyped("test_singleton")); +@@ -73,7 +73,7 @@ TEST(SingletonManagerImplTest, NonConstructingGetTyped) { + TEST(SingletonManagerImplTest, PinnedSingleton) { + + { +- ManagerImpl manager(Thread::threadFactoryForTest()); ++ ManagerImpl manager; + TestSingleton* singleton_ptr{}; + + // Register a singleton and get it. +@@ -94,7 +94,7 @@ TEST(SingletonManagerImplTest, PinnedSingleton) { + } + + { +- ManagerImpl manager(Thread::threadFactoryForTest()); ++ ManagerImpl manager; + TestSingleton* singleton_ptr{}; + + // Register a pinned singleton and get it. 
+diff --git a/test/common/upstream/test_cluster_manager.h b/test/common/upstream/test_cluster_manager.h +index e9836c75db..e3c6662135 100644 +--- test/common/upstream/test_cluster_manager.h ++++ test/common/upstream/test_cluster_manager.h +@@ -151,7 +151,7 @@ public: + NiceMock& admin_ = server_context_.admin_; + NiceMock secret_manager_; + NiceMock& log_manager_ = server_context_.access_log_manager_; +- Singleton::ManagerImpl singleton_manager_{Thread::threadFactoryForTest()}; ++ Singleton::ManagerImpl singleton_manager_; + NiceMock validation_visitor_; + NiceMock random_; + Api::ApiPtr api_; +diff --git a/test/extensions/common/async_files/async_file_handle_thread_pool_test.cc b/test/extensions/common/async_files/async_file_handle_thread_pool_test.cc +index 538462fdfa..b4e63ed44d 100644 +--- test/extensions/common/async_files/async_file_handle_thread_pool_test.cc ++++ test/extensions/common/async_files/async_file_handle_thread_pool_test.cc +@@ -64,7 +64,7 @@ public: + class AsyncFileHandleTest : public testing::Test, public AsyncFileHandleHelpers { + public: + void SetUp() override { +- singleton_manager_ = std::make_unique(Thread::threadFactoryForTest()); ++ singleton_manager_ = std::make_unique(); + factory_ = AsyncFileManagerFactory::singleton(singleton_manager_.get()); + envoy::extensions::common::async_files::v3::AsyncFileManagerConfig config; + config.mutable_thread_pool()->set_thread_count(1); +@@ -77,7 +77,7 @@ public: + void SetUp() override { + EXPECT_CALL(mock_posix_file_operations_, supportsAllPosixFileOperations()) + .WillRepeatedly(Return(true)); +- singleton_manager_ = std::make_unique(Thread::threadFactoryForTest()); ++ singleton_manager_ = std::make_unique(); + factory_ = AsyncFileManagerFactory::singleton(singleton_manager_.get()); + envoy::extensions::common::async_files::v3::AsyncFileManagerConfig config; + config.mutable_thread_pool()->set_thread_count(1); +diff --git a/test/extensions/common/async_files/async_file_manager_factory_test.cc 
b/test/extensions/common/async_files/async_file_manager_factory_test.cc +index 2914d157c3..f2afd82435 100644 +--- test/extensions/common/async_files/async_file_manager_factory_test.cc ++++ test/extensions/common/async_files/async_file_manager_factory_test.cc +@@ -30,7 +30,7 @@ using ::testing::StrictMock; + class AsyncFileManagerFactoryTest : public ::testing::Test { + public: + void SetUp() override { +- singleton_manager_ = std::make_unique(Thread::threadFactoryForTest()); ++ singleton_manager_ = std::make_unique(); + factory_ = AsyncFileManagerFactory::singleton(singleton_manager_.get()); + EXPECT_CALL(mock_posix_file_operations_, supportsAllPosixFileOperations()) + .WillRepeatedly(Return(true)); +diff --git a/test/extensions/common/async_files/async_file_manager_thread_pool_test.cc b/test/extensions/common/async_files/async_file_manager_thread_pool_test.cc +index 3ece365190..d0c262ddf9 100644 +--- test/extensions/common/async_files/async_file_manager_thread_pool_test.cc ++++ test/extensions/common/async_files/async_file_manager_thread_pool_test.cc +@@ -119,7 +119,7 @@ private: + class AsyncFileManagerTest : public testing::Test { + public: + void SetUp() override { +- singleton_manager_ = std::make_unique(Thread::threadFactoryForTest()); ++ singleton_manager_ = std::make_unique(); + factory_ = AsyncFileManagerFactory::singleton(singleton_manager_.get()); + } + +@@ -179,7 +179,7 @@ public: + void SetUp() override { + envoy::extensions::common::async_files::v3::AsyncFileManagerConfig config; + config.mutable_thread_pool()->set_thread_count(1); +- singleton_manager_ = std::make_unique(Thread::threadFactoryForTest()); ++ singleton_manager_ = std::make_unique(); + auto factory = AsyncFileManagerFactory::singleton(singleton_manager_.get()); + manager_ = factory->getAsyncFileManager(config); + } +diff --git a/test/extensions/common/async_files/async_file_manager_thread_pool_with_mocks_test.cc 
b/test/extensions/common/async_files/async_file_manager_thread_pool_with_mocks_test.cc +index feeba117a1..03a854da6b 100644 +--- test/extensions/common/async_files/async_file_manager_thread_pool_with_mocks_test.cc ++++ test/extensions/common/async_files/async_file_manager_thread_pool_with_mocks_test.cc +@@ -41,7 +41,7 @@ public: + .WillRepeatedly(Return(true)); + envoy::extensions::common::async_files::v3::AsyncFileManagerConfig config; + config.mutable_thread_pool()->set_thread_count(1); +- singleton_manager_ = std::make_unique(Thread::threadFactoryForTest()); ++ singleton_manager_ = std::make_unique(); + factory_ = AsyncFileManagerFactory::singleton(singleton_manager_.get()); + manager_ = factory_->getAsyncFileManager(config, &mock_posix_file_operations_); + } +diff --git a/test/extensions/filters/common/local_ratelimit/BUILD b/test/extensions/filters/common/local_ratelimit/BUILD +index 96bd5d38a4..85f2f74d5c 100644 +--- test/extensions/filters/common/local_ratelimit/BUILD ++++ test/extensions/filters/common/local_ratelimit/BUILD +@@ -12,7 +12,12 @@ envoy_cc_test( + name = "local_ratelimit_test", + srcs = ["local_ratelimit_test.cc"], + deps = [ ++ "//source/common/singleton:manager_impl_lib", + "//source/extensions/filters/common/local_ratelimit:local_ratelimit_lib", + "//test/mocks/event:event_mocks", ++ "//test/mocks/upstream:cluster_manager_mocks", ++ "//test/mocks/upstream:cluster_priority_set_mocks", ++ "//test/test_common:test_runtime_lib", ++ "//test/test_common:utility_lib", + ], + ) +diff --git a/test/extensions/filters/common/local_ratelimit/local_ratelimit_test.cc b/test/extensions/filters/common/local_ratelimit/local_ratelimit_test.cc +index 9c1ead992a..7228c16516 100644 +--- test/extensions/filters/common/local_ratelimit/local_ratelimit_test.cc ++++ test/extensions/filters/common/local_ratelimit/local_ratelimit_test.cc +@@ -1,6 +1,11 @@ ++#include "source/common/singleton/manager_impl.h" + #include 
"source/extensions/filters/common/local_ratelimit/local_ratelimit_impl.h" + + #include "test/mocks/event/mocks.h" ++#include "test/mocks/upstream/cluster_manager.h" ++#include "test/mocks/upstream/cluster_priority_set.h" ++#include "test/test_common/test_runtime.h" ++#include "test/test_common/thread_factory_for_test.h" + #include "test/test_common/utility.h" + + #include "gmock/gmock.h" +@@ -15,6 +20,97 @@ namespace Filters { + namespace Common { + namespace LocalRateLimit { + ++class WrapperedProvider { ++public: ++ WrapperedProvider(ShareProviderSharedPtr provider) : provider_(provider) {} ++ ++ uint32_t tokensPerFill(uint32_t origin_tokens_per_fill) const { ++ return std::ceil(origin_tokens_per_fill * provider_->getTokensShareFactor()); ++ } ++ ++ ShareProviderSharedPtr provider_; ++}; ++ ++TEST(ShareProviderManagerTest, ShareProviderManagerTest) { ++ NiceMock cm; ++ NiceMock dispatcher; ++ Singleton::ManagerImpl manager; ++ ++ NiceMock priority_set; ++ cm.local_cluster_name_ = "local_cluster"; ++ cm.initializeClusters({"local_cluster"}, {}); ++ ++ const auto* mock_local_cluster = cm.active_clusters_.at("local_cluster").get(); ++ ++ EXPECT_CALL(*mock_local_cluster, prioritySet()).WillOnce(ReturnRef(priority_set)); ++ EXPECT_CALL(priority_set, addMemberUpdateCb(_)); ++ ++ // Set the membership total to 2. ++ mock_local_cluster->info_->endpoint_stats_.membership_total_.set(2); ++ ++ ShareProviderManagerSharedPtr share_provider_manager = ++ ShareProviderManager::singleton(dispatcher, cm, manager); ++ EXPECT_NE(share_provider_manager, nullptr); ++ ++ auto provider = std::make_shared( ++ share_provider_manager->getShareProvider(ProtoLocalClusterRateLimit())); ++ ++ EXPECT_EQ(1, provider->tokensPerFill(1)); // At least 1 token per fill. ++ EXPECT_EQ(1, provider->tokensPerFill(2)); ++ EXPECT_EQ(2, provider->tokensPerFill(4)); ++ EXPECT_EQ(4, provider->tokensPerFill(8)); ++ ++ // Set the membership total to 4. 
++ mock_local_cluster->info_->endpoint_stats_.membership_total_.set(4); ++ priority_set.runUpdateCallbacks(0, {}, {}); ++ ++ EXPECT_EQ(1, provider->tokensPerFill(1)); // At least 1 token per fill. ++ EXPECT_EQ(1, provider->tokensPerFill(4)); ++ EXPECT_EQ(2, provider->tokensPerFill(8)); ++ EXPECT_EQ(4, provider->tokensPerFill(16)); ++ ++ // Set the membership total to 0. ++ mock_local_cluster->info_->endpoint_stats_.membership_total_.set(0); ++ priority_set.runUpdateCallbacks(0, {}, {}); ++ ++ EXPECT_EQ(1, provider->tokensPerFill(1)); // At least 1 token per fill. ++ EXPECT_EQ(2, provider->tokensPerFill(2)); ++ EXPECT_EQ(4, provider->tokensPerFill(4)); ++ EXPECT_EQ(8, provider->tokensPerFill(8)); ++ ++ // Set the membership total to 1. ++ mock_local_cluster->info_->endpoint_stats_.membership_total_.set(1); ++ priority_set.runUpdateCallbacks(0, {}, {}); ++ ++ EXPECT_EQ(1, provider->tokensPerFill(1)); // At least 1 token per fill. ++ EXPECT_EQ(2, provider->tokensPerFill(2)); ++ EXPECT_EQ(4, provider->tokensPerFill(4)); ++ EXPECT_EQ(8, provider->tokensPerFill(8)); ++ ++ // Destroy the share provider manager. ++ // This is used to ensure the share provider is still safe to use even ++ // the share provider manager is destroyed. But note this should never ++ // happen in real production because the share provider manager should ++ // have longer life cycle than the limiter. ++ share_provider_manager.reset(); ++ ++ // Set the membership total to 4 again. ++ mock_local_cluster->info_->endpoint_stats_.membership_total_.set(4); ++ priority_set.runUpdateCallbacks(0, {}, {}); ++ ++ // The provider should still work but the value should not change. ++ EXPECT_EQ(1, provider->tokensPerFill(1)); // At least 1 token per fill. 
++ EXPECT_EQ(2, provider->tokensPerFill(2)); ++ EXPECT_EQ(4, provider->tokensPerFill(4)); ++ EXPECT_EQ(8, provider->tokensPerFill(8)); ++} ++ ++class MockShareProvider : public ShareProvider { ++public: ++ MockShareProvider() = default; ++ MOCK_METHOD(double, getTokensShareFactor, (), (const)); ++}; ++ + class LocalRateLimiterImplTest : public testing::Test { + public: + void initializeTimer() { +@@ -24,12 +120,27 @@ public: + } + + void initialize(const std::chrono::milliseconds fill_interval, const uint32_t max_tokens, +- const uint32_t tokens_per_fill) { ++ const uint32_t tokens_per_fill, ShareProviderSharedPtr share_provider = nullptr) { + + initializeTimer(); + +- rate_limiter_ = std::make_shared( +- fill_interval, max_tokens, tokens_per_fill, dispatcher_, descriptors_); ++ rate_limiter_ = ++ std::make_shared(fill_interval, max_tokens, tokens_per_fill, ++ dispatcher_, descriptors_, true, share_provider); ++ } ++ ++ void initializeWithAtomicTokenBucket(const std::chrono::milliseconds fill_interval, ++ const uint32_t max_tokens, const uint32_t tokens_per_fill, ++ ShareProviderSharedPtr share_provider = nullptr) { ++ ++ ++ TestScopedRuntime runtime; ++ runtime.mergeValues( ++ {{"envoy.reloadable_features.no_timer_based_rate_limit_token_bucket", "true"}}); ++ ++ rate_limiter_ = ++ std::make_shared(fill_interval, max_tokens, tokens_per_fill, ++ dispatcher_, descriptors_, true, share_provider); + } + + Thread::ThreadSynchronizer& synchronizer() { return rate_limiter_->synchronizer_; } +@@ -68,15 +179,15 @@ TEST_F(LocalRateLimiterImplTest, CasEdgeCases) { + synchronizer().barrierOn("on_fill_timer_pre_cas"); + + // This should succeed. +- EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); ++ EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_).allowed); + + // Now signal the thread to continue which should cause a CAS failure and the loop to repeat. 
+ synchronizer().signal("on_fill_timer_pre_cas"); + t1.join(); + + // 1 -> 0 tokens +- EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); +- EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_)); ++ EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_).allowed); + } + + // This tests the case in which two allowed checks race. +@@ -87,12 +198,13 @@ TEST_F(LocalRateLimiterImplTest, CasEdgeCases) { + + // Start a thread and see if we are under limit. This will wait pre-CAS. + synchronizer().waitOn("allowed_pre_cas"); +- std::thread t1([&] { EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_)); }); ++ std::thread t1( ++ [&] { EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_).allowed); }); + // Wait until the thread is actually waiting. + synchronizer().barrierOn("allowed_pre_cas"); + + // Consume a token on this thread, which should cause the CAS to fail on the other thread. 
+- EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); ++ EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_).allowed); + synchronizer().signal("allowed_pre_cas"); + t1.join(); + } +@@ -103,17 +215,17 @@ TEST_F(LocalRateLimiterImplTest, TokenBucket) { + initialize(std::chrono::milliseconds(200), 1, 1); + + // 1 -> 0 tokens +- EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); +- EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_)); +- EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_)); ++ EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_).allowed); + + // 0 -> 1 tokens + EXPECT_CALL(*fill_timer_, enableTimer(std::chrono::milliseconds(200), nullptr)); + fill_timer_->invokeCallback(); + + // 1 -> 0 tokens +- EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); +- EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_)); ++ EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_).allowed); + + // 0 -> 1 tokens + EXPECT_CALL(*fill_timer_, enableTimer(std::chrono::milliseconds(200), nullptr)); +@@ -124,8 +236,8 @@ TEST_F(LocalRateLimiterImplTest, TokenBucket) { + fill_timer_->invokeCallback(); + + // 1 -> 0 tokens +- EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); +- EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_)); ++ EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_).allowed); + } + + // Verify token bucket functionality with max tokens and tokens per fill > 1. 
+@@ -133,25 +245,59 @@ TEST_F(LocalRateLimiterImplTest, TokenBucketMultipleTokensPerFill) { + initialize(std::chrono::milliseconds(200), 2, 2); + + // 2 -> 0 tokens +- EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); +- EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); +- EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_)); ++ EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_).allowed); + + // 0 -> 2 tokens + EXPECT_CALL(*fill_timer_, enableTimer(std::chrono::milliseconds(200), nullptr)); + fill_timer_->invokeCallback(); + + // 2 -> 1 tokens +- EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); ++ EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_).allowed); + + // 1 -> 2 tokens + EXPECT_CALL(*fill_timer_, enableTimer(std::chrono::milliseconds(200), nullptr)); + fill_timer_->invokeCallback(); + + // 2 -> 0 tokens +- EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); +- EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); +- EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_)); ++ EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++} ++ ++// Verify token bucket functionality with max tokens and tokens per fill > 1 and ++// share provider is used. ++TEST_F(LocalRateLimiterImplTest, TokenBucketMultipleTokensPerFillWithShareProvider) { ++ auto share_provider = std::make_shared(); ++ EXPECT_CALL(*share_provider, getTokensShareFactor()) ++ .WillRepeatedly(testing::Invoke([]() -> double { return 0.5; })); ++ ++ // Final tokens per fill is 2/2 = 1. 
++ initialize(std::chrono::milliseconds(200), 2, 2, share_provider); ++ ++ // The limiter will be initialized with max tokens and it will not be shared. ++ // So, the initial tokens is 2. ++ // 2 -> 0 tokens ++ EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ ++ // The tokens per fill will be handled by the share provider and it will be 1. ++ // 0 -> 1 tokens ++ EXPECT_CALL(*fill_timer_, enableTimer(std::chrono::milliseconds(200), nullptr)); ++ fill_timer_->invokeCallback(); ++ ++ // 1 -> 0 tokens ++ EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ ++ // 0 -> 1 tokens ++ EXPECT_CALL(*fill_timer_, enableTimer(std::chrono::milliseconds(200), nullptr)); ++ fill_timer_->invokeCallback(); ++ ++ // 1 -> 0 tokens ++ EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_).allowed); + } + + // Verify token bucket functionality with max tokens > tokens per fill. 
+@@ -159,17 +305,17 @@ TEST_F(LocalRateLimiterImplTest, TokenBucketMaxTokensGreaterThanTokensPerFill) { + initialize(std::chrono::milliseconds(200), 2, 1); + + // 2 -> 0 tokens +- EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); +- EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); +- EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_)); ++ EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_).allowed); + + // 0 -> 1 tokens + EXPECT_CALL(*fill_timer_, enableTimer(std::chrono::milliseconds(200), nullptr)); + fill_timer_->invokeCallback(); + + // 1 -> 0 tokens +- EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); +- EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_)); ++ EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_).allowed); + } + + // Verify token bucket status of max tokens, remaining tokens and remaining fill interval. 
+@@ -178,34 +324,42 @@ TEST_F(LocalRateLimiterImplTest, TokenBucketStatus) { + + // 2 -> 1 tokens + EXPECT_CALL(*fill_timer_, enableTimer(std::chrono::milliseconds(3000), nullptr)); +- EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); +- EXPECT_EQ(rate_limiter_->maxTokens(route_descriptors_), 2); +- EXPECT_EQ(rate_limiter_->remainingTokens(route_descriptors_), 1); +- EXPECT_EQ(rate_limiter_->remainingFillInterval(route_descriptors_), 3); ++ auto rate_limit_result = rate_limiter_->requestAllowed(route_descriptors_); ++ EXPECT_TRUE(rate_limit_result.allowed); ++ ++ EXPECT_EQ(rate_limit_result.token_bucket_context->maxTokens(), 2); ++ EXPECT_EQ(rate_limit_result.token_bucket_context->remainingTokens(), 1); ++ EXPECT_EQ(rate_limit_result.token_bucket_context->remainingFillInterval().value(), 3.0); + + // 1 -> 0 tokens + dispatcher_.globalTimeSystem().advanceTimeAndRun(std::chrono::milliseconds(1000), dispatcher_, + Envoy::Event::Dispatcher::RunType::NonBlock); +- EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_)); +- EXPECT_EQ(rate_limiter_->maxTokens(route_descriptors_), 2); +- EXPECT_EQ(rate_limiter_->remainingTokens(route_descriptors_), 0); +- EXPECT_EQ(rate_limiter_->remainingFillInterval(route_descriptors_), 2); ++ EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ ++ // Note that the route descriptors are not changed so we can reuse the same token bucket context. 
++ EXPECT_EQ(rate_limit_result.token_bucket_context->maxTokens(), 2); ++ EXPECT_EQ(rate_limit_result.token_bucket_context->remainingTokens(), 0); ++ EXPECT_EQ(rate_limit_result.token_bucket_context->remainingFillInterval().value(), 2.0); + + // 0 -> 0 tokens + dispatcher_.globalTimeSystem().advanceTimeAndRun(std::chrono::milliseconds(1000), dispatcher_, + Envoy::Event::Dispatcher::RunType::NonBlock); +- EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_)); +- EXPECT_EQ(rate_limiter_->maxTokens(route_descriptors_), 2); +- EXPECT_EQ(rate_limiter_->remainingTokens(route_descriptors_), 0); +- EXPECT_EQ(rate_limiter_->remainingFillInterval(route_descriptors_), 1); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ ++ // Note that the route descriptors are not changed so we can reuse the same token bucket context. ++ EXPECT_EQ(rate_limit_result.token_bucket_context->maxTokens(), 2); ++ EXPECT_EQ(rate_limit_result.token_bucket_context->remainingTokens(), 0); ++ EXPECT_EQ(rate_limit_result.token_bucket_context->remainingFillInterval().value(), 1.0); + + // 0 -> 2 tokens + dispatcher_.globalTimeSystem().advanceTimeAndRun(std::chrono::milliseconds(1000), dispatcher_, + Envoy::Event::Dispatcher::RunType::NonBlock); + fill_timer_->invokeCallback(); +- EXPECT_EQ(rate_limiter_->maxTokens(route_descriptors_), 2); +- EXPECT_EQ(rate_limiter_->remainingTokens(route_descriptors_), 2); +- EXPECT_EQ(rate_limiter_->remainingFillInterval(route_descriptors_), 3); ++ ++ // Note that the route descriptors are not changed so we can reuse the same token bucket context. 
++ EXPECT_EQ(rate_limit_result.token_bucket_context->maxTokens(), 2); ++ EXPECT_EQ(rate_limit_result.token_bucket_context->remainingTokens(), 2); ++ EXPECT_EQ(rate_limit_result.token_bucket_context->remainingFillInterval().value(), 3.0); + } + + class LocalRateLimiterDescriptorImplTest : public LocalRateLimiterImplTest { +@@ -213,11 +367,26 @@ public: + void initializeWithDescriptor(const std::chrono::milliseconds fill_interval, + const uint32_t max_tokens, const uint32_t tokens_per_fill) { + ++ // TestScopedRuntime runtime; ++ // runtime.mergeValues( ++ // {{"envoy.reloadable_features.no_timer_based_rate_limit_token_bucket", "false"}}); ++ + initializeTimer(); + + rate_limiter_ = std::make_shared( + fill_interval, max_tokens, tokens_per_fill, dispatcher_, descriptors_); + } ++ ++ void initializeWithAtomicTokenBucketDescriptor(const std::chrono::milliseconds fill_interval, ++ const uint32_t max_tokens, ++ const uint32_t tokens_per_fill) { ++ TestScopedRuntime runtime; ++ runtime.mergeValues( ++ {{"envoy.reloadable_features.no_timer_based_rate_limit_token_bucket", "true"}}); ++ rate_limiter_ = std::make_shared( ++ fill_interval, max_tokens, tokens_per_fill, dispatcher_, descriptors_); ++ } ++ + static constexpr absl::string_view single_descriptor_config_yaml = R"( + entries: + - key: foo2 +@@ -237,7 +406,7 @@ public: + token_bucket: + max_tokens: 1 + tokens_per_fill: 1 +- fill_interval: 0.05s ++ fill_interval: 1s + )"; + + // Default token bucket +@@ -299,15 +468,15 @@ TEST_F(LocalRateLimiterDescriptorImplTest, CasEdgeCasesDescriptor) { + synchronizer().barrierOn("on_fill_timer_pre_cas"); + + // This should succeed. +- EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); ++ EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_).allowed); + + // Now signal the thread to continue which should cause a CAS failure and the loop to repeat. 
+ synchronizer().signal("on_fill_timer_pre_cas"); + t1.join(); + + // 1 -> 0 tokens +- EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); +- EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_)); ++ EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_).allowed); + } + + // This tests the case in which two allowed checks race. +@@ -318,12 +487,12 @@ TEST_F(LocalRateLimiterDescriptorImplTest, CasEdgeCasesDescriptor) { + + // Start a thread and see if we are under limit. This will wait pre-CAS. + synchronizer().waitOn("allowed_pre_cas"); +- std::thread t1([&] { EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_)); }); ++ std::thread t1([&] { EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_).allowed); }); + // Wait until the thread is actually waiting. + synchronizer().barrierOn("allowed_pre_cas"); + + // Consume a token on this thread, which should cause the CAS to fail on the other thread. +- EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); ++ EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_).allowed); + synchronizer().signal("allowed_pre_cas"); + t1.join(); + } +@@ -334,9 +503,9 @@ TEST_F(LocalRateLimiterDescriptorImplTest, TokenBucketDescriptor2) { + *descriptors_.Add()); + initializeWithDescriptor(std::chrono::milliseconds(50), 1, 1); + +- EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); +- EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_)); +- EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_)); ++ EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_).allowed); + dispatcher_.globalTimeSystem().advanceTimeAndRun(std::chrono::milliseconds(100), dispatcher_, + Envoy::Event::Dispatcher::RunType::NonBlock); + } +@@ -348,8 +517,8 @@ TEST_F(LocalRateLimiterDescriptorImplTest, TokenBucketDescriptor) { + 
initializeWithDescriptor(std::chrono::milliseconds(50), 1, 1); + + // 1 -> 0 tokens +- EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); +- EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_)); ++ EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_).allowed); + + // 0 -> 1 tokens + for (int i = 0; i < 2; i++) { +@@ -360,8 +529,8 @@ TEST_F(LocalRateLimiterDescriptorImplTest, TokenBucketDescriptor) { + } + + // 1 -> 0 tokens +- EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); +- EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_)); ++ EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_).allowed); + + // 0 -> 1 tokens + for (int i = 0; i < 2; i++) { +@@ -380,8 +549,8 @@ TEST_F(LocalRateLimiterDescriptorImplTest, TokenBucketDescriptor) { + } + + // 1 -> 0 tokens +- EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); +- EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_)); ++ EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_).allowed); + } + + // Verify token bucket functionality with request per unit > 1. 
+@@ -391,9 +560,9 @@ TEST_F(LocalRateLimiterDescriptorImplTest, TokenBucketMultipleTokensPerFillDescr + initializeWithDescriptor(std::chrono::milliseconds(50), 2, 2); + + // 2 -> 0 tokens +- EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); +- EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); +- EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_)); ++ EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_).allowed); ++ EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_).allowed); + + // 0 -> 2 tokens + for (int i = 0; i < 2; i++) { +@@ -404,7 +573,7 @@ TEST_F(LocalRateLimiterDescriptorImplTest, TokenBucketMultipleTokensPerFillDescr + } + + // 2 -> 1 tokens +- EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); ++ EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_).allowed); + + // 1 -> 2 tokens + for (int i = 0; i < 2; i++) { +@@ -415,34 +584,34 @@ TEST_F(LocalRateLimiterDescriptorImplTest, TokenBucketMultipleTokensPerFillDescr + } + + // 2 -> 0 tokens +- EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); +- EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); +- EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_)); ++ EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_).allowed); ++ EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_).allowed); + } + + // Verify token bucket functionality with multiple descriptors. 
+ TEST_F(LocalRateLimiterDescriptorImplTest, TokenBucketDifferentDescriptorDifferentRateLimits) { + TestUtility::loadFromYaml(multiple_descriptor_config_yaml, *descriptors_.Add()); +- TestUtility::loadFromYaml(fmt::format(single_descriptor_config_yaml, 1, 1, "1000s"), ++ TestUtility::loadFromYaml(fmt::format(single_descriptor_config_yaml, 1, 1, "2s"), + *descriptors_.Add()); +- initializeWithDescriptor(std::chrono::milliseconds(50), 3, 1); ++ initializeWithDescriptor(std::chrono::milliseconds(1000), 3, 1); + + // 1 -> 0 tokens for descriptor_ and descriptor2_ +- EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor2_)); +- EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor2_)); +- EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); +- EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_)); ++ EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor2_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor2_).allowed); ++ EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_).allowed); + + // 0 -> 1 tokens for descriptor2_ +- dispatcher_.globalTimeSystem().advanceTimeAndRun(std::chrono::milliseconds(50), dispatcher_, ++ dispatcher_.globalTimeSystem().advanceTimeAndRun(std::chrono::milliseconds(1000), dispatcher_, + Envoy::Event::Dispatcher::RunType::NonBlock); +- EXPECT_CALL(*fill_timer_, enableTimer(std::chrono::milliseconds(50), nullptr)); ++ EXPECT_CALL(*fill_timer_, enableTimer(std::chrono::milliseconds(1000), nullptr)); + fill_timer_->invokeCallback(); + + // 1 -> 0 tokens for descriptor2_ and 0 only for descriptor_ +- EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor2_)); +- EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor2_)); +- EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_)); ++ EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor2_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor2_).allowed); ++ 
EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_).allowed); + } + + // Verify token bucket functionality with multiple descriptors sorted. +@@ -456,10 +625,10 @@ TEST_F(LocalRateLimiterDescriptorImplTest, + {{{"foo2", "bar2"}}}}; + + // Descriptors are sorted as descriptor2 < descriptor < global +- EXPECT_TRUE(rate_limiter_->requestAllowed(descriptors)); +- EXPECT_FALSE(rate_limiter_->requestAllowed(descriptors)); ++ EXPECT_TRUE(rate_limiter_->requestAllowed(descriptors).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(descriptors).allowed); + // Request limited by descriptor2 will not consume tokens from descriptor. +- EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); ++ EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_).allowed); + } + + // Verify token bucket status of max tokens, remaining tokens and remaining fill interval. +@@ -469,10 +638,14 @@ TEST_F(LocalRateLimiterDescriptorImplTest, TokenBucketDescriptorStatus) { + initializeWithDescriptor(std::chrono::milliseconds(1000), 2, 2); + + // 2 -> 1 tokens +- EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); +- EXPECT_EQ(rate_limiter_->maxTokens(descriptor_), 2); +- EXPECT_EQ(rate_limiter_->remainingTokens(descriptor_), 1); +- EXPECT_EQ(rate_limiter_->remainingFillInterval(descriptor_), 3); ++ auto rate_limit_result = rate_limiter_->requestAllowed(descriptor_); ++ ++ EXPECT_TRUE(rate_limit_result.allowed); ++ ++ // Note that the route descriptors are not changed so we can reuse the same token bucket context. 
++ EXPECT_EQ(rate_limit_result.token_bucket_context->maxTokens(), 2); ++ EXPECT_EQ(rate_limit_result.token_bucket_context->remainingTokens(), 1); ++ EXPECT_EQ(rate_limit_result.token_bucket_context->remainingFillInterval().value(), 3.0); + + dispatcher_.globalTimeSystem().advanceTimeAndRun(std::chrono::milliseconds(1000), dispatcher_, + Envoy::Event::Dispatcher::RunType::NonBlock); +@@ -480,10 +653,11 @@ TEST_F(LocalRateLimiterDescriptorImplTest, TokenBucketDescriptorStatus) { + fill_timer_->invokeCallback(); + + // 1 -> 0 tokens +- EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); +- EXPECT_EQ(rate_limiter_->maxTokens(descriptor_), 2); +- EXPECT_EQ(rate_limiter_->remainingTokens(descriptor_), 0); +- EXPECT_EQ(rate_limiter_->remainingFillInterval(descriptor_), 2); ++ EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_).allowed); ++ // Note that the route descriptors are not changed so we can reuse the same token bucket context. ++ EXPECT_EQ(rate_limit_result.token_bucket_context->maxTokens(), 2); ++ EXPECT_EQ(rate_limit_result.token_bucket_context->remainingTokens(), 0); ++ EXPECT_EQ(rate_limit_result.token_bucket_context->remainingFillInterval().value(), 2.0); + + dispatcher_.globalTimeSystem().advanceTimeAndRun(std::chrono::milliseconds(1000), dispatcher_, + Envoy::Event::Dispatcher::RunType::NonBlock); +@@ -491,10 +665,11 @@ TEST_F(LocalRateLimiterDescriptorImplTest, TokenBucketDescriptorStatus) { + fill_timer_->invokeCallback(); + + // 0 -> 0 tokens +- EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_)); +- EXPECT_EQ(rate_limiter_->maxTokens(descriptor_), 2); +- EXPECT_EQ(rate_limiter_->remainingTokens(descriptor_), 0); +- EXPECT_EQ(rate_limiter_->remainingFillInterval(descriptor_), 1); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_).allowed); ++ // Note that the route descriptors are not changed so we can reuse the same token bucket context. 
++ EXPECT_EQ(rate_limit_result.token_bucket_context->maxTokens(), 2); ++ EXPECT_EQ(rate_limit_result.token_bucket_context->remainingTokens(), 0); ++ EXPECT_EQ(rate_limit_result.token_bucket_context->remainingFillInterval().value(), 1.0); + + dispatcher_.globalTimeSystem().advanceTimeAndRun(std::chrono::milliseconds(1000), dispatcher_, + Envoy::Event::Dispatcher::RunType::NonBlock); +@@ -502,9 +677,10 @@ TEST_F(LocalRateLimiterDescriptorImplTest, TokenBucketDescriptorStatus) { + fill_timer_->invokeCallback(); + + // 0 -> 2 tokens +- EXPECT_EQ(rate_limiter_->maxTokens(descriptor_), 2); +- EXPECT_EQ(rate_limiter_->remainingTokens(descriptor_), 2); +- EXPECT_EQ(rate_limiter_->remainingFillInterval(descriptor_), 3); ++ // Note that the route descriptors are not changed so we can reuse the same token bucket context. ++ EXPECT_EQ(rate_limit_result.token_bucket_context->maxTokens(), 2); ++ EXPECT_EQ(rate_limit_result.token_bucket_context->remainingTokens(), 2); ++ EXPECT_EQ(rate_limit_result.token_bucket_context->remainingFillInterval().value(), 3.0); + } + + // Verify token bucket status of max tokens, remaining tokens and remaining fill interval with +@@ -513,47 +689,343 @@ TEST_F(LocalRateLimiterDescriptorImplTest, TokenBucketDifferentDescriptorStatus) + TestUtility::loadFromYaml(multiple_descriptor_config_yaml, *descriptors_.Add()); + TestUtility::loadFromYaml(fmt::format(single_descriptor_config_yaml, 2, 2, "3s"), + *descriptors_.Add()); +- initializeWithDescriptor(std::chrono::milliseconds(50), 2, 1); ++ initializeWithDescriptor(std::chrono::milliseconds(1000), 20, 20); + + // 2 -> 1 tokens for descriptor_ +- EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_)); +- EXPECT_EQ(rate_limiter_->maxTokens(descriptor_), 2); +- EXPECT_EQ(rate_limiter_->remainingTokens(descriptor_), 1); +- EXPECT_EQ(rate_limiter_->remainingFillInterval(descriptor_), 3); ++ auto rate_limit_result = rate_limiter_->requestAllowed(descriptor_); + +- // 1 -> 0 tokens for descriptor_ and 
descriptor2_ +- EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor2_)); +- EXPECT_EQ(rate_limiter_->maxTokens(descriptor2_), 1); +- EXPECT_EQ(rate_limiter_->remainingTokens(descriptor2_), 0); +- EXPECT_EQ(rate_limiter_->remainingFillInterval(descriptor2_), 0); ++ EXPECT_TRUE(rate_limit_result.allowed); ++ EXPECT_EQ(rate_limit_result.token_bucket_context->maxTokens(), 2); ++ EXPECT_EQ(rate_limit_result.token_bucket_context->remainingTokens(), 1); ++ EXPECT_EQ(rate_limit_result.token_bucket_context->remainingFillInterval().value(), 3); ++ ++ // 1 -> 0 tokens for descriptor_ ++ EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_).allowed); ++ EXPECT_EQ(rate_limit_result.token_bucket_context->maxTokens(), 2); ++ EXPECT_EQ(rate_limit_result.token_bucket_context->remainingTokens(), 0); ++ EXPECT_EQ(rate_limit_result.token_bucket_context->remainingFillInterval().value(), 3); ++ ++ // 1 -> 0 tokens for descriptor2_ ++ auto rate_limit_result2 = rate_limiter_->requestAllowed(descriptor2_); ++ EXPECT_TRUE(rate_limit_result2.allowed); ++ ++ EXPECT_EQ(rate_limit_result2.token_bucket_context->maxTokens(), 1); ++ EXPECT_EQ(rate_limit_result2.token_bucket_context->remainingTokens(), 0); ++ EXPECT_EQ(rate_limit_result2.token_bucket_context->remainingFillInterval().value(), 1); + + // 0 -> 0 tokens for descriptor_ and descriptor2_ +- EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor2_)); +- EXPECT_EQ(rate_limiter_->maxTokens(descriptor2_), 1); +- EXPECT_EQ(rate_limiter_->remainingTokens(descriptor2_), 0); +- EXPECT_EQ(rate_limiter_->remainingFillInterval(descriptor2_), 0); +- EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_)); +- EXPECT_EQ(rate_limiter_->maxTokens(descriptor_), 2); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor2_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_).allowed); + + // 0 -> 1 tokens for descriptor2_ +- dispatcher_.globalTimeSystem().advanceTimeAndRun(std::chrono::milliseconds(50), dispatcher_, ++ 
dispatcher_.globalTimeSystem().advanceTimeAndRun(std::chrono::milliseconds(1000), dispatcher_, + Envoy::Event::Dispatcher::RunType::NonBlock); +- EXPECT_CALL(*fill_timer_, enableTimer(std::chrono::milliseconds(50), nullptr)); ++ EXPECT_CALL(*fill_timer_, enableTimer(std::chrono::milliseconds(1000), nullptr)); + fill_timer_->invokeCallback(); +- EXPECT_EQ(rate_limiter_->maxTokens(descriptor2_), 1); +- EXPECT_EQ(rate_limiter_->remainingTokens(descriptor2_), 1); +- EXPECT_EQ(rate_limiter_->remainingFillInterval(descriptor2_), 0); ++ ++ EXPECT_EQ(rate_limit_result2.token_bucket_context->maxTokens(), 1); ++ EXPECT_EQ(rate_limit_result2.token_bucket_context->remainingTokens(), 1); ++ EXPECT_EQ(rate_limit_result2.token_bucket_context->remainingFillInterval().value(), 1); + + // 0 -> 2 tokens for descriptor_ +- for (int i = 0; i < 60; i++) { +- dispatcher_.globalTimeSystem().advanceTimeAndRun(std::chrono::milliseconds(50), dispatcher_, ++ for (int i = 0; i < 2; i++) { ++ dispatcher_.globalTimeSystem().advanceTimeAndRun(std::chrono::milliseconds(1000), dispatcher_, + Envoy::Event::Dispatcher::RunType::NonBlock); +- EXPECT_CALL(*fill_timer_, enableTimer(std::chrono::milliseconds(50), nullptr)); ++ EXPECT_CALL(*fill_timer_, enableTimer(std::chrono::milliseconds(1000), nullptr)); + fill_timer_->invokeCallback(); + } +- EXPECT_EQ(rate_limiter_->maxTokens(descriptor_), 2); +- EXPECT_EQ(rate_limiter_->remainingTokens(descriptor_), 2); +- EXPECT_EQ(rate_limiter_->remainingFillInterval(descriptor_), 3); ++ ++ EXPECT_EQ(rate_limit_result.token_bucket_context->maxTokens(), 2); ++ EXPECT_EQ(rate_limit_result.token_bucket_context->remainingTokens(), 2); ++ EXPECT_EQ(rate_limit_result.token_bucket_context->remainingFillInterval().value(), 3.0); ++} ++ ++// Verify token bucket functionality with a single token. 
++TEST_F(LocalRateLimiterImplTest, AtomicTokenBucket) { ++ initializeWithAtomicTokenBucket(std::chrono::milliseconds(200), 1, 1); ++ ++ // 1 -> 0 tokens ++ EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ ++ // 0 -> 1 tokens ++ dispatcher_.globalTimeSystem().advanceTimeWait(std::chrono::milliseconds(200)); ++ ++ // 1 -> 0 tokens ++ EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ ++ // 0 -> 1 tokens ++ dispatcher_.globalTimeSystem().advanceTimeWait(std::chrono::milliseconds(200)); ++ ++ // 1 -> 1 tokens ++ dispatcher_.globalTimeSystem().advanceTimeWait(std::chrono::milliseconds(200)); ++ ++ // 1 -> 0 tokens ++ EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++} ++ ++// Verify token bucket functionality with max tokens and tokens per fill > 1. 
++TEST_F(LocalRateLimiterImplTest, AtomicTokenBucketMultipleTokensPerFill) { ++ initializeWithAtomicTokenBucket(std::chrono::milliseconds(200), 2, 2); ++ ++ // 2 -> 0 tokens ++ EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ ++ // 0 -> 2 tokens ++ dispatcher_.globalTimeSystem().advanceTimeWait(std::chrono::milliseconds(200)); ++ ++ // 2 -> 1 tokens ++ EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ ++ // 1 -> 2 tokens ++ dispatcher_.globalTimeSystem().advanceTimeWait(std::chrono::milliseconds(200)); ++ ++ // 2 -> 0 tokens ++ EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++} ++ ++// Verify token bucket functionality with max tokens and tokens per fill > 1 and ++// share provider is used. ++TEST_F(LocalRateLimiterImplTest, AtomicTokenBucketMultipleTokensPerFillWithShareProvider) { ++ auto share_provider = std::make_shared(); ++ EXPECT_CALL(*share_provider, getTokensShareFactor()) ++ .WillRepeatedly(testing::Invoke([]() -> double { return 0.5; })); ++ ++ initializeWithAtomicTokenBucket(std::chrono::milliseconds(200), 2, 2, share_provider); ++ ++ // Every request will consume 1 / factor = 2 tokens. ++ ++ // The limiter will be initialized with max tokens and will be consumed at once. 
++ // 2 -> 0 tokens ++ EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ ++ // 0 -> 2 tokens ++ dispatcher_.globalTimeSystem().advanceTimeWait(std::chrono::milliseconds(200)); ++ ++ // 2 -> 0 tokens ++ EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ ++ // 0 -> 2 tokens ++ dispatcher_.globalTimeSystem().advanceTimeWait(std::chrono::milliseconds(200)); ++ ++ // 2 -> 0 tokens ++ EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++} ++ ++// Verify token bucket functionality with max tokens > tokens per fill. ++TEST_F(LocalRateLimiterImplTest, AtomicTokenBucketMaxTokensGreaterThanTokensPerFill) { ++ initializeWithAtomicTokenBucket(std::chrono::milliseconds(200), 2, 1); ++ ++ // 2 -> 0 tokens ++ EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ ++ // 0 -> 1 tokens ++ dispatcher_.globalTimeSystem().advanceTimeWait(std::chrono::milliseconds(200)); ++ ++ // 1 -> 0 tokens ++ EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++} ++ ++// Verify token bucket status of max tokens, remaining tokens and remaining fill interval. 
++TEST_F(LocalRateLimiterImplTest, AtomicTokenBucketStatus) { ++ initializeWithAtomicTokenBucket(std::chrono::milliseconds(3000), 2, 2); ++ ++ // 2 -> 1 tokens ++ auto rate_limit_result = rate_limiter_->requestAllowed(route_descriptors_); ++ EXPECT_TRUE(rate_limit_result.allowed); ++ ++ EXPECT_EQ(rate_limit_result.token_bucket_context->maxTokens(), 2); ++ EXPECT_EQ(rate_limit_result.token_bucket_context->remainingTokens(), 1); ++ ++ // 1 -> 0 tokens ++ EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ ++ // 0 -> 1 tokens ++ dispatcher_.globalTimeSystem().advanceTimeWait(std::chrono::milliseconds(1500)); ++ ++ // Note that the route descriptors are not changed so we can reuse the same token bucket context. ++ EXPECT_EQ(rate_limit_result.token_bucket_context->maxTokens(), 2); ++ EXPECT_EQ(rate_limit_result.token_bucket_context->remainingTokens(), 1); ++ ++ // 1 -> 0 tokens ++ EXPECT_TRUE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ ++ // 0 -> 0 tokens ++ EXPECT_FALSE(rate_limiter_->requestAllowed(route_descriptors_).allowed); ++ ++ // Note that the route descriptors are not changed so we can reuse the same token bucket context. ++ EXPECT_EQ(rate_limit_result.token_bucket_context->maxTokens(), 2); ++ EXPECT_EQ(rate_limit_result.token_bucket_context->remainingTokens(), 0); ++ ++ // 0 -> 2 tokens ++ dispatcher_.globalTimeSystem().advanceTimeWait(std::chrono::milliseconds(3000)); ++ ++ // Note that the route descriptors are not changed so we can reuse the same token bucket context. 
++ EXPECT_EQ(rate_limit_result.token_bucket_context->maxTokens(), 2); ++ EXPECT_EQ(rate_limit_result.token_bucket_context->remainingTokens(), 2); ++} ++ ++TEST_F(LocalRateLimiterDescriptorImplTest, AtomicTokenBucketDescriptorBase) { ++ TestUtility::loadFromYaml(fmt::format(single_descriptor_config_yaml, 1, 1, "0.1s"), ++ *descriptors_.Add()); ++ initializeWithAtomicTokenBucketDescriptor(std::chrono::milliseconds(50), 1, 1); ++ ++ EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_).allowed); ++} ++ ++TEST_F(LocalRateLimiterDescriptorImplTest, AtomicTokenBucketDescriptor) { ++ TestUtility::loadFromYaml(fmt::format(single_descriptor_config_yaml, 1, 1, "0.1s"), ++ *descriptors_.Add()); ++ initializeWithAtomicTokenBucketDescriptor(std::chrono::milliseconds(50), 1, 1); ++ ++ // 1 -> 0 tokens ++ EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_).allowed); ++ ++ // 0 -> 1 tokens ++ dispatcher_.globalTimeSystem().advanceTimeWait(std::chrono::milliseconds(100)); ++ ++ // 1 -> 0 tokens ++ EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_).allowed); ++ ++ // 0 -> 1 tokens ++ dispatcher_.globalTimeSystem().advanceTimeWait(std::chrono::milliseconds(100)); ++ ++ // 1 -> 1 tokens ++ dispatcher_.globalTimeSystem().advanceTimeWait(std::chrono::milliseconds(100)); ++ ++ // 1 -> 0 tokens ++ EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_).allowed); ++} ++ ++// Verify token bucket functionality with request per unit > 1. 
++TEST_F(LocalRateLimiterDescriptorImplTest, AtomicTokenBucketMultipleTokensPerFillDescriptor) { ++ TestUtility::loadFromYaml(fmt::format(single_descriptor_config_yaml, 2, 2, "0.1s"), ++ *descriptors_.Add()); ++ initializeWithAtomicTokenBucketDescriptor(std::chrono::milliseconds(50), 2, 2); ++ ++ // 2 -> 0 tokens ++ EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_).allowed); ++ EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_).allowed); ++ ++ // 0 -> 2 tokens ++ dispatcher_.globalTimeSystem().advanceTimeWait(std::chrono::milliseconds(100)); ++ ++ // 2 -> 1 tokens ++ EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_).allowed); ++ ++ // 1 -> 2 tokens ++ dispatcher_.globalTimeSystem().advanceTimeWait(std::chrono::milliseconds(50)); ++ ++ // 2 -> 0 tokens ++ EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_).allowed); ++ EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_).allowed); ++} ++ ++// Verify token bucket functionality with multiple descriptors. 
++TEST_F(LocalRateLimiterDescriptorImplTest, ++ AtomicTokenBucketDifferentDescriptorDifferentRateLimits) { ++ TestUtility::loadFromYaml(multiple_descriptor_config_yaml, *descriptors_.Add()); ++ TestUtility::loadFromYaml(fmt::format(single_descriptor_config_yaml, 1, 1, "2s"), ++ *descriptors_.Add()); ++ initializeWithAtomicTokenBucketDescriptor(std::chrono::milliseconds(1000), 3, 3); ++ ++ // 1 -> 0 tokens for descriptor_ and descriptor2_ ++ EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor2_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor2_).allowed); ++ EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_).allowed); ++ ++ // 0 -> 1 tokens for descriptor2_ ++ dispatcher_.globalTimeSystem().advanceTimeWait(std::chrono::milliseconds(1000)); ++ ++ // 1 -> 0 tokens for descriptor2_ and 0 only for descriptor_ ++ EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor2_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor2_).allowed); ++ EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_).allowed); ++} ++ ++// Verify token bucket functionality with multiple descriptors sorted. ++TEST_F(LocalRateLimiterDescriptorImplTest, ++ AtomicTokenBucketDifferentDescriptorDifferentRateLimitsSorted) { ++ TestUtility::loadFromYaml(multiple_descriptor_config_yaml, *descriptors_.Add()); ++ TestUtility::loadFromYaml(fmt::format(single_descriptor_config_yaml, 2, 2, "1s"), ++ *descriptors_.Add()); ++ initializeWithAtomicTokenBucketDescriptor(std::chrono::milliseconds(50), 3, 3); ++ ++ std::vector descriptors{{{{"hello", "world"}, {"foo", "bar"}}}, ++ {{{"foo2", "bar2"}}}}; ++ ++ // Descriptors are sorted as descriptor2 < descriptor < global ++ // Descriptor2 from 1 -> 0 tokens ++ // Descriptor from 2 -> 1 tokens ++ EXPECT_TRUE(rate_limiter_->requestAllowed(descriptors).allowed); ++ // Request limited by descriptor2 and won't consume tokens from descriptor. 
++ // Descriptor2 from 0 -> 0 tokens ++ // Descriptor from 1 -> 1 tokens ++ EXPECT_FALSE(rate_limiter_->requestAllowed(descriptors).allowed); ++ // Descriptor from 1 -> 0 tokens ++ EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_).allowed); ++ // Descriptor from 0 -> 0 tokens ++ EXPECT_FALSE(rate_limiter_->requestAllowed(descriptor_).allowed); ++} ++ ++// Verify token bucket status of max tokens, remaining tokens and remaining fill interval. ++TEST_F(LocalRateLimiterDescriptorImplTest, AtomicTokenBucketDescriptorStatus) { ++ TestUtility::loadFromYaml(fmt::format(single_descriptor_config_yaml, 2, 2, "3s"), ++ *descriptors_.Add()); ++ initializeWithAtomicTokenBucketDescriptor(std::chrono::milliseconds(1000), 2, 2); ++ ++ // 2 -> 1 tokens ++ auto rate_limit_result = rate_limiter_->requestAllowed(descriptor_); ++ ++ EXPECT_TRUE(rate_limit_result.allowed); ++ ++ // Note that the route descriptors are not changed so we can reuse the same token bucket context. ++ EXPECT_EQ(rate_limit_result.token_bucket_context->maxTokens(), 2); ++ EXPECT_EQ(rate_limit_result.token_bucket_context->remainingTokens(), 1); ++ ++ dispatcher_.globalTimeSystem().advanceTimeWait(std::chrono::milliseconds(500)); ++ ++ // 1 -> 0 tokens ++ EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_).allowed); ++ // Note that the route descriptors are not changed so we can reuse the same token bucket context. ++ EXPECT_EQ(rate_limit_result.token_bucket_context->maxTokens(), 2); ++ EXPECT_EQ(rate_limit_result.token_bucket_context->remainingTokens(), 0); ++ ++ // 0 -> 1 tokens. 1500ms passed and 1 token will be added. ++ dispatcher_.globalTimeSystem().advanceTimeWait(std::chrono::milliseconds(1000)); ++ ++ // 1 -> 0 tokens ++ EXPECT_TRUE(rate_limiter_->requestAllowed(descriptor_).allowed); ++ // Note that the route descriptors are not changed so we can reuse the same token bucket context. 
++ EXPECT_EQ(rate_limit_result.token_bucket_context->maxTokens(), 2); ++ EXPECT_EQ(rate_limit_result.token_bucket_context->remainingTokens(), 0); ++ ++ dispatcher_.globalTimeSystem().advanceTimeWait(std::chrono::milliseconds(3000)); ++ ++ // 0 -> 2 tokens ++ // Note that the route descriptors are not changed so we can reuse the same token bucket context. ++ EXPECT_EQ(rate_limit_result.token_bucket_context->maxTokens(), 2); ++ EXPECT_EQ(rate_limit_result.token_bucket_context->remainingTokens(), 2); + } + + } // Namespace LocalRateLimit +diff --git a/test/extensions/filters/http/local_ratelimit/BUILD b/test/extensions/filters/http/local_ratelimit/BUILD +index 5d997f8345..3175a577e0 100644 +--- test/extensions/filters/http/local_ratelimit/BUILD ++++ test/extensions/filters/http/local_ratelimit/BUILD +@@ -16,10 +16,14 @@ envoy_extension_cc_test( + srcs = ["filter_test.cc"], + extension_names = ["envoy.filters.http.local_ratelimit"], + deps = [ ++ "//source/common/singleton:manager_impl_lib", + "//source/extensions/filters/http/local_ratelimit:local_ratelimit_lib", + "//test/common/stream_info:test_util", + "//test/mocks/http:http_mocks", + "//test/mocks/local_info:local_info_mocks", ++ "//test/mocks/upstream:cluster_manager_mocks", ++ "//test/test_common:test_runtime_lib", ++ "//test/test_common:utility_lib", + "@envoy_api//envoy/extensions/filters/http/local_ratelimit/v3:pkg_cc_proto", + ], + ) +diff --git a/test/extensions/filters/http/local_ratelimit/config_test.cc b/test/extensions/filters/http/local_ratelimit/config_test.cc +index 37c3a991e2..ba02999bee 100644 +--- test/extensions/filters/http/local_ratelimit/config_test.cc ++++ test/extensions/filters/http/local_ratelimit/config_test.cc +@@ -2,6 +2,7 @@ + #include "source/extensions/filters/http/local_ratelimit/local_ratelimit.h" + + #include "test/mocks/server/mocks.h" ++#include "test/mocks/upstream/priority_set.h" + + #include "gmock/gmock.h" + #include "gtest/gtest.h" +@@ -63,7 +64,7 @@ 
response_headers_to_add: + const auto route_config = factory.createRouteSpecificFilterConfig( + *proto_config, context, ProtobufMessage::getNullValidationVisitor()); + const auto* config = dynamic_cast(route_config.get()); +- EXPECT_TRUE(config->requestAllowed({})); ++ EXPECT_TRUE(config->requestAllowed({}).allowed); + } + + TEST(Factory, EnabledEnforcedDisabledByDefault) { +@@ -221,7 +222,7 @@ descriptors: + const auto route_config = factory.createRouteSpecificFilterConfig( + *proto_config, context, ProtobufMessage::getNullValidationVisitor()); + const auto* config = dynamic_cast(route_config.get()); +- EXPECT_TRUE(config->requestAllowed({})); ++ EXPECT_TRUE(config->requestAllowed({}).allowed); + } + + TEST(Factory, RouteSpecificFilterConfigWithDescriptorsTimerNotDivisible) { +@@ -306,12 +307,148 @@ response_headers_to_add: + + NiceMock context; + +- EXPECT_CALL(context.dispatcher_, createTimer_(_)); + EXPECT_THROW(factory.createRouteSpecificFilterConfig(*proto_config, context, + ProtobufMessage::getNullValidationVisitor()), + EnvoyException); + } + ++TEST(Factory, LocalClusterRateLimitAndLocalRateLimitPerDownstreamConnection) { ++ const std::string config_yaml = R"( ++stat_prefix: test ++token_bucket: ++ max_tokens: 1 ++ tokens_per_fill: 1 ++ fill_interval: 1000s ++filter_enabled: ++ runtime_key: test_enabled ++ default_value: ++ numerator: 100 ++ denominator: HUNDRED ++filter_enforced: ++ runtime_key: test_enforced ++ default_value: ++ numerator: 100 ++ denominator: HUNDRED ++local_cluster_rate_limit: {} ++local_rate_limit_per_downstream_connection: true ++)"; ++ ++ LocalRateLimitFilterConfig factory; ++ ProtobufTypes::MessagePtr proto_config = factory.createEmptyRouteConfigProto(); ++ TestUtility::loadFromYaml(config_yaml, *proto_config); ++ ++ NiceMock context; ++ ++ EXPECT_THROW_WITH_MESSAGE( ++ factory.createRouteSpecificFilterConfig(*proto_config, context, ++ ProtobufMessage::getNullValidationVisitor()), ++ EnvoyException, ++ "local_cluster_rate_limit is 
set and local_rate_limit_per_downstream_connection is set to " ++ "true"); ++} ++ ++TEST(Factory, LocalClusterRateLimitAndWithoutLocalClusterName) { ++ const std::string config_yaml = R"( ++stat_prefix: test ++token_bucket: ++ max_tokens: 1 ++ tokens_per_fill: 1 ++ fill_interval: 1000s ++filter_enabled: ++ runtime_key: test_enabled ++ default_value: ++ numerator: 100 ++ denominator: HUNDRED ++filter_enforced: ++ runtime_key: test_enforced ++ default_value: ++ numerator: 100 ++ denominator: HUNDRED ++local_cluster_rate_limit: {} ++)"; ++ ++ LocalRateLimitFilterConfig factory; ++ ProtobufTypes::MessagePtr proto_config = factory.createEmptyRouteConfigProto(); ++ TestUtility::loadFromYaml(config_yaml, *proto_config); ++ ++ NiceMock context; ++ ++ EXPECT_THROW_WITH_MESSAGE( ++ factory.createRouteSpecificFilterConfig(*proto_config, context, ++ ProtobufMessage::getNullValidationVisitor()), ++ EnvoyException, "local_cluster_rate_limit is set but no local cluster name is present"); ++} ++ ++TEST(Factory, LocalClusterRateLimitAndWithoutLocalCluster) { ++ const std::string config_yaml = R"( ++stat_prefix: test ++token_bucket: ++ max_tokens: 1 ++ tokens_per_fill: 1 ++ fill_interval: 1000s ++filter_enabled: ++ runtime_key: test_enabled ++ default_value: ++ numerator: 100 ++ denominator: HUNDRED ++filter_enforced: ++ runtime_key: test_enforced ++ default_value: ++ numerator: 100 ++ denominator: HUNDRED ++local_cluster_rate_limit: {} ++)"; ++ ++ LocalRateLimitFilterConfig factory; ++ ProtobufTypes::MessagePtr proto_config = factory.createEmptyRouteConfigProto(); ++ TestUtility::loadFromYaml(config_yaml, *proto_config); ++ ++ NiceMock context; ++ context.cluster_manager_.local_cluster_name_ = "local_cluster"; ++ ++ EXPECT_THROW_WITH_MESSAGE( ++ factory.createRouteSpecificFilterConfig(*proto_config, context, ++ ProtobufMessage::getNullValidationVisitor()), ++ EnvoyException, "local_cluster_rate_limit is set but no local cluster is present"); ++} ++ ++TEST(Factory, 
LocalClusterRateLimit) { ++ const std::string config_yaml = R"( ++stat_prefix: test ++token_bucket: ++ max_tokens: 1 ++ tokens_per_fill: 1 ++ fill_interval: 1000s ++filter_enabled: ++ runtime_key: test_enabled ++ default_value: ++ numerator: 100 ++ denominator: HUNDRED ++filter_enforced: ++ runtime_key: test_enforced ++ default_value: ++ numerator: 100 ++ denominator: HUNDRED ++local_cluster_rate_limit: {} ++)"; ++ ++ LocalRateLimitFilterConfig factory; ++ ProtobufTypes::MessagePtr proto_config = factory.createEmptyRouteConfigProto(); ++ TestUtility::loadFromYaml(config_yaml, *proto_config); ++ ++ NiceMock context; ++ context.cluster_manager_.local_cluster_name_ = "local_cluster"; ++ context.cluster_manager_.initializeClusters({"local_cluster"}, {}); ++ ++ NiceMock priority_set; ++ const auto* local_cluster = context.cluster_manager_.active_clusters_.at("local_cluster").get(); ++ EXPECT_CALL(*local_cluster, prioritySet()).WillOnce(ReturnRef(priority_set)); ++ ++ EXPECT_CALL(context.dispatcher_, createTimer_(_)); ++ EXPECT_NO_THROW(factory.createRouteSpecificFilterConfig( ++ *proto_config, context, ProtobufMessage::getNullValidationVisitor())); ++} ++ + } // namespace LocalRateLimitFilter + } // namespace HttpFilters + } // namespace Extensions +diff --git a/test/extensions/filters/http/local_ratelimit/filter_test.cc b/test/extensions/filters/http/local_ratelimit/filter_test.cc +index f5b7e121a1..5a40a97752 100644 +--- test/extensions/filters/http/local_ratelimit/filter_test.cc ++++ test/extensions/filters/http/local_ratelimit/filter_test.cc +@@ -1,9 +1,13 @@ + #include "envoy/extensions/filters/http/local_ratelimit/v3/local_rate_limit.pb.h" + ++#include "source/common/singleton/manager_impl.h" + #include "source/extensions/filters/http/local_ratelimit/local_ratelimit.h" + + #include "test/mocks/http/mocks.h" + #include "test/mocks/local_info/mocks.h" ++#include "test/mocks/upstream/cluster_manager.h" ++#include "test/test_common/test_runtime.h" ++#include 
"test/test_common/thread_factory_for_test.h" + + #include "gmock/gmock.h" + #include "gtest/gtest.h" +@@ -75,8 +79,9 @@ public: + + envoy::extensions::filters::http::local_ratelimit::v3::LocalRateLimit config; + TestUtility::loadFromYaml(yaml, config); +- config_ = std::make_shared(config, local_info_, dispatcher_, *stats_.rootScope(), +- runtime_, per_route); ++ config_ = ++ std::make_shared(config, local_info_, dispatcher_, cm_, singleton_manager_, ++ *stats_.rootScope(), runtime_, per_route); + filter_ = std::make_shared(config_); + filter_->setDecoderFilterCallbacks(decoder_callbacks_); + +@@ -100,6 +105,9 @@ public: + NiceMock dispatcher_; + NiceMock runtime_; + NiceMock local_info_; ++ NiceMock cm_; ++ Singleton::ManagerImpl singleton_manager_; ++ + std::shared_ptr config_; + std::shared_ptr filter_; + std::shared_ptr filter_2_; +@@ -293,6 +301,31 @@ TEST_F(FilterTest, RequestRateLimitedXRateLimitHeaders) { + auto request_headers = Http::TestRequestHeaderMapImpl(); + auto response_headers = Http::TestResponseHeaderMapImpl(); + ++ EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(request_headers, false)); ++ EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->encodeHeaders(response_headers, false)); ++ EXPECT_EQ("1", response_headers.get_("x-ratelimit-limit")); ++ EXPECT_EQ("0", response_headers.get_("x-ratelimit-remaining")); ++ EXPECT_EQ(Http::FilterHeadersStatus::StopIteration, ++ filter_2_->decodeHeaders(request_headers, false)); ++ EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_2_->encodeHeaders(response_headers, false)); ++ EXPECT_EQ("1", response_headers.get_("x-ratelimit-limit")); ++ EXPECT_EQ("0", response_headers.get_("x-ratelimit-remaining")); ++ EXPECT_EQ(2U, findCounter("test.http_local_rate_limit.enabled")); ++ EXPECT_EQ(1U, findCounter("test.http_local_rate_limit.enforced")); ++ EXPECT_EQ(1U, findCounter("test.http_local_rate_limit.ok")); ++ EXPECT_EQ(1U, findCounter("test.http_local_rate_limit.rate_limited")); ++} 
++ ++TEST_F(FilterTest, RequestRateLimitedXRateLimitHeadersWithTimerBasedTokenBucket) { ++ TestScopedRuntime runtime; ++ runtime.mergeValues( ++ {{"envoy.reloadable_features.no_timer_based_rate_limit_token_bucket", "false"}}); ++ ++ setup(fmt::format(config_yaml, "false", "1", "false", "DRAFT_VERSION_03")); ++ ++ auto request_headers = Http::TestRequestHeaderMapImpl(); ++ auto response_headers = Http::TestResponseHeaderMapImpl(); ++ + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->decodeHeaders(request_headers, false)); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->encodeHeaders(response_headers, false)); + EXPECT_EQ("1", response_headers.get_("x-ratelimit-limit")); +@@ -310,6 +343,7 @@ TEST_F(FilterTest, RequestRateLimitedXRateLimitHeaders) { + EXPECT_EQ(1U, findCounter("test.http_local_rate_limit.rate_limited")); + } + ++ + static constexpr absl::string_view descriptor_config_yaml = R"( + stat_prefix: test + token_bucket: +@@ -539,7 +573,6 @@ TEST_F(DescriptorFilterTest, RouteDescriptorRequestRatelimited) { + + TEST_F(DescriptorFilterTest, RouteDescriptorNotFound) { + setUpTest(fmt::format(descriptor_config_yaml, "1", "\"OFF\"", "1", "0")); +- + EXPECT_CALL(decoder_callbacks_.route_->route_entry_.rate_limit_policy_, + getApplicableRateLimit(0)); + +@@ -662,6 +695,33 @@ TEST_F(DescriptorFilterTest, RouteDescriptorRequestRatelimitedXRateLimitHeaders) + auto request_headers = Http::TestRequestHeaderMapImpl(); + auto response_headers = Http::TestResponseHeaderMapImpl(); + ++ EXPECT_EQ(Http::FilterHeadersStatus::StopIteration, ++ filter_->decodeHeaders(request_headers, false)); ++ EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->encodeHeaders(response_headers, false)); ++ EXPECT_EQ("0", response_headers.get_("x-ratelimit-limit")); ++ EXPECT_EQ("0", response_headers.get_("x-ratelimit-remaining")); ++ EXPECT_EQ(1U, findCounter("test.http_local_rate_limit.enabled")); ++ EXPECT_EQ(1U, findCounter("test.http_local_rate_limit.enforced")); ++ 
EXPECT_EQ(1U, findCounter("test.http_local_rate_limit.rate_limited")); ++} ++ ++TEST_F(DescriptorFilterTest, ++ RouteDescriptorRequestRatelimitedXRateLimitHeadersWithTimerTokenBucket) { ++ TestScopedRuntime runtime; ++ runtime.mergeValues( ++ {{"envoy.reloadable_features.no_timer_based_rate_limit_token_bucket", "false"}}); ++ ++ setUpTest(fmt::format(descriptor_config_yaml, "0", "DRAFT_VERSION_03", "0", "0")); ++ ++ EXPECT_CALL(decoder_callbacks_.route_->route_entry_.rate_limit_policy_, ++ getApplicableRateLimit(0)); ++ ++ EXPECT_CALL(route_rate_limit_, populateLocalDescriptors(_, _, _, _)) ++ .WillOnce(testing::SetArgReferee<0>(descriptor_)); ++ ++ auto request_headers = Http::TestRequestHeaderMapImpl(); ++ auto response_headers = Http::TestResponseHeaderMapImpl(); ++ + EXPECT_EQ(Http::FilterHeadersStatus::StopIteration, + filter_->decodeHeaders(request_headers, false)); + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter_->encodeHeaders(response_headers, false)); +diff --git a/test/extensions/filters/http/local_ratelimit/local_ratelimit_integration_test.cc b/test/extensions/filters/http/local_ratelimit/local_ratelimit_integration_test.cc +index 7c9834d9d3..2171ea59eb 100644 +--- test/extensions/filters/http/local_ratelimit/local_ratelimit_integration_test.cc ++++ test/extensions/filters/http/local_ratelimit/local_ratelimit_integration_test.cc +@@ -1,3 +1,5 @@ ++#include "source/extensions/filters/common/local_ratelimit/local_ratelimit_impl.h" ++ + #include "test/integration/http_protocol_integration.h" + + #include "gtest/gtest.h" +@@ -18,8 +20,8 @@ protected: + const std::string& initial_route_config) { + // Set this flag to true to create fake upstream for xds_cluster. + create_xds_upstream_ = true; +- // Create static clusters. +- createClusters(); ++ // Create static XDS cluster. 
++ createXdsCluster(); + + config_helper_.prependFilter(filter_config); + +@@ -60,7 +62,38 @@ protected: + registerTestServerPorts({"http"}); + } + +- void createClusters() { ++ void initializeFilterWithLocalCluster(const std::string& filter_config, ++ const std::string& initial_local_cluster_endpoints) { ++ config_helper_.prependFilter(filter_config); ++ ++ // Set this flag to true to create fake upstream for xds_cluster. ++ create_xds_upstream_ = true; ++ // Create static XDS cluster. ++ createXdsCluster(); ++ ++ // Create local cluster. ++ createLocalCluster(); ++ ++ on_server_init_function_ = [&]() { ++ AssertionResult result = ++ fake_upstreams_[1]->waitForHttpConnection(*dispatcher_, xds_connection_); ++ RELEASE_ASSERT(result, result.message()); ++ result = xds_connection_->waitForNewStream(*dispatcher_, xds_stream_); ++ RELEASE_ASSERT(result, result.message()); ++ xds_stream_->startGrpcStream(); ++ ++ EXPECT_TRUE(compareSotwDiscoveryRequest(Config::TypeUrl::get().ClusterLoadAssignment, "", ++ {"local_cluster"}, true)); ++ sendSotwDiscoveryResponse( ++ Config::TypeUrl::get().ClusterLoadAssignment, ++ {TestUtility::parseYaml( ++ initial_local_cluster_endpoints)}, ++ "1"); ++ }; ++ initialize(); ++ } ++ ++ void createXdsCluster() { + config_helper_.addConfigModifier([](envoy::config::bootstrap::v3::Bootstrap& bootstrap) { + auto* xds_cluster = bootstrap.mutable_static_resources()->add_clusters(); + xds_cluster->MergeFrom(bootstrap.static_resources().clusters()[0]); +@@ -69,6 +102,34 @@ protected: + }); + } + ++ void createLocalCluster() { ++ config_helper_.addConfigModifier([](envoy::config::bootstrap::v3::Bootstrap& bootstrap) { ++ // Set local cluster name to "local_cluster". ++ bootstrap.mutable_cluster_manager()->set_local_cluster_name("local_cluster"); ++ ++ // Create local cluster. 
++ auto* local_cluster = bootstrap.mutable_static_resources()->add_clusters(); ++ local_cluster->MergeFrom(bootstrap.static_resources().clusters()[0]); ++ local_cluster->set_name("local_cluster"); ++ local_cluster->clear_load_assignment(); ++ ++ // This should be EDS cluster to load endpoints dynamically. ++ local_cluster->set_type(::envoy::config::cluster::v3::Cluster::EDS); ++ local_cluster->mutable_eds_cluster_config()->set_service_name("local_cluster"); ++ local_cluster->mutable_eds_cluster_config()->mutable_eds_config()->set_resource_api_version( ++ envoy::config::core::v3::ApiVersion::V3); ++ envoy::config::core::v3::ApiConfigSource* eds_api_config_source = ++ local_cluster->mutable_eds_cluster_config() ++ ->mutable_eds_config() ++ ->mutable_api_config_source(); ++ eds_api_config_source->set_api_type(envoy::config::core::v3::ApiConfigSource::GRPC); ++ eds_api_config_source->set_transport_api_version(envoy::config::core::v3::V3); ++ envoy::config::core::v3::GrpcService* grpc_service = ++ eds_api_config_source->add_grpc_services(); ++ grpc_service->mutable_envoy_grpc()->set_cluster_name("xds_cluster"); ++ }); ++ } ++ + void cleanUpXdsConnection() { + if (xds_connection_ != nullptr) { + AssertionResult result = xds_connection_->close(); +@@ -107,6 +168,61 @@ typed_config: + local_rate_limit_per_downstream_connection: {} + )EOF"; + ++ const std::string filter_config_with_local_cluster_rate_limit_ = ++ R"EOF( ++name: envoy.filters.http.local_ratelimit ++typed_config: ++ "@type": type.googleapis.com/envoy.extensions.filters.http.local_ratelimit.v3.LocalRateLimit ++ stat_prefix: http_local_rate_limiter ++ token_bucket: ++ max_tokens: 1 ++ tokens_per_fill: 1 ++ fill_interval: 1000s ++ filter_enabled: ++ runtime_key: local_rate_limit_enabled ++ default_value: ++ numerator: 100 ++ denominator: HUNDRED ++ filter_enforced: ++ runtime_key: local_rate_limit_enforced ++ default_value: ++ numerator: 100 ++ denominator: HUNDRED ++ response_headers_to_add: ++ - append_action: 
OVERWRITE_IF_EXISTS_OR_ADD ++ header: ++ key: x-local-rate-limit ++ value: 'true' ++ local_cluster_rate_limit: {} ++)EOF"; ++ ++ const std::string initial_local_cluster_endpoints_ = R"EOF( ++cluster_name: local_cluster ++endpoints: ++- lb_endpoints: ++ - endpoint: ++ address: ++ socket_address: ++ address: 127.0.0.1 ++ port_value: 80 ++)EOF"; ++ ++ const std::string update_local_cluster_endpoints_ = R"EOF( ++cluster_name: local_cluster ++endpoints: ++- lb_endpoints: ++ - endpoint: ++ address: ++ socket_address: ++ address: 127.0.0.1 ++ port_value: 80 ++ - endpoint: ++ address: ++ socket_address: ++ address: 127.0.0.1 ++ port_value: 81 ++)EOF"; ++ + const std::string initial_route_config_ = R"EOF( + name: basic_routes + virtual_hosts: +@@ -315,5 +431,36 @@ TEST_P(LocalRateLimitFilterIntegrationTest, BasicTestPerRouteAndRds) { + cleanUpXdsConnection(); + } + ++TEST_P(LocalRateLimitFilterIntegrationTest, TestLocalClusterRateLimit) { ++ initializeFilterWithLocalCluster(filter_config_with_local_cluster_rate_limit_, ++ initial_local_cluster_endpoints_); ++ ++ auto share_provider_manager = ++ test_server_->server() ++ .singletonManager() ++ .getTyped( ++ "local_ratelimit_share_provider_manager_singleton"); ++ ASSERT(share_provider_manager != nullptr); ++ auto share_provider = share_provider_manager->getShareProvider({}); ++ ++ test_server_->waitForGaugeEq("cluster.local_cluster.membership_total", 1); ++ simTime().advanceTimeWait(std::chrono::milliseconds(1)); ++ ++ EXPECT_EQ(1.0, share_provider->getTokensShareFactor()); ++ ++ sendSotwDiscoveryResponse( ++ Config::TypeUrl::get().ClusterLoadAssignment, ++ {TestUtility::parseYaml( ++ update_local_cluster_endpoints_)}, ++ "2"); ++ ++ test_server_->waitForGaugeEq("cluster.local_cluster.membership_total", 2); ++ simTime().advanceTimeWait(std::chrono::milliseconds(1)); ++ ++ EXPECT_EQ(0.5, share_provider->getTokensShareFactor()); ++ ++ cleanUpXdsConnection(); ++} ++ + } // namespace + } // namespace Envoy +diff --git 
a/test/extensions/filters/listener/local_ratelimit/local_ratelimit_test.cc b/test/extensions/filters/listener/local_ratelimit/local_ratelimit_test.cc +index 2a6999077f..c8c446f3c9 100644 +--- test/extensions/filters/listener/local_ratelimit/local_ratelimit_test.cc ++++ test/extensions/filters/listener/local_ratelimit/local_ratelimit_test.cc +@@ -56,14 +56,9 @@ public: + NiceMock io_handle_; + }; + +- uint64_t initialize(const std::string& filter_yaml, bool expect_timer_create = true) { ++ uint64_t initialize(const std::string& filter_yaml) { + envoy::extensions::filters::listener::local_ratelimit::v3::LocalRateLimit proto_config; + TestUtility::loadFromYaml(filter_yaml, proto_config); +- fill_timer_ = new Event::MockTimer(&dispatcher_); +- if (expect_timer_create) { +- EXPECT_CALL(*fill_timer_, enableTimer(_, nullptr)); +- EXPECT_CALL(*fill_timer_, disableTimer()); +- } + config_ = std::make_shared(proto_config, dispatcher_, *stats_store_.rootScope(), + runtime_); + return proto_config.token_bucket().max_tokens(); +@@ -72,7 +67,6 @@ public: + NiceMock dispatcher_; + Stats::IsolatedStoreImpl stats_store_; + NiceMock runtime_; +- Event::MockTimer* fill_timer_{}; + FilterConfigSharedPtr config_; + }; + +@@ -96,6 +90,9 @@ token_bucket: + + // Basic rate limit case. + TEST_F(LocalRateLimitTest, RateLimit) { ++ TestScopedRuntime runtime; ++ runtime.mergeValues( ++ {{"envoy.reloadable_features.no_timer_based_rate_limit_token_bucket", "true"}}); + initialize(R"EOF( + stat_prefix: local_rate_limit_stats + token_bucket: +@@ -116,8 +113,7 @@ token_bucket: + ->value()); + + // Refill the bucket. +- EXPECT_CALL(*fill_timer_, enableTimer(std::chrono::milliseconds(3000), nullptr)); +- fill_timer_->invokeCallback(); ++ dispatcher_.globalTimeSystem().advanceTimeWait(std::chrono::milliseconds(3000)); + + // Third socket is allowed after refill. 
+ ActiveFilter active_filter3(config_); +diff --git a/test/extensions/filters/network/local_ratelimit/local_ratelimit_fuzz_test.cc b/test/extensions/filters/network/local_ratelimit/local_ratelimit_fuzz_test.cc +index ee00dbb61b..45bcd52280 100644 +--- test/extensions/filters/network/local_ratelimit/local_ratelimit_fuzz_test.cc ++++ test/extensions/filters/network/local_ratelimit/local_ratelimit_fuzz_test.cc +@@ -61,7 +61,7 @@ DEFINE_PROTO_FUZZER( + // default time system in GlobalTimeSystem. + dispatcher.time_system_ = std::make_unique(); + Stats::IsolatedStoreImpl stats_store; +- Singleton::ManagerImpl singleton_manager(Thread::threadFactoryForTest()); ++ Singleton::ManagerImpl singleton_manager; + static NiceMock runtime; + Event::MockTimer* fill_timer = new Event::MockTimer(&dispatcher); + envoy::extensions::filters::network::local_ratelimit::v3::LocalRateLimit proto_config = +diff --git a/test/extensions/filters/network/local_ratelimit/local_ratelimit_integration_test.cc b/test/extensions/filters/network/local_ratelimit/local_ratelimit_integration_test.cc +index ffa3c6e348..56132bcdfb 100644 +--- test/extensions/filters/network/local_ratelimit/local_ratelimit_integration_test.cc ++++ test/extensions/filters/network/local_ratelimit/local_ratelimit_integration_test.cc +@@ -57,7 +57,7 @@ typed_config: + token_bucket: + max_tokens: 1 + # Set fill_interval to effectively infinite so we only get max_tokens to start and never re-fill. +- fill_interval: 100000s ++ fill_interval: 1000s + )EOF"); + + IntegrationTcpClientPtr tcp_client = makeTcpConnection(lookupPort("listener_0")); +@@ -88,7 +88,7 @@ typed_config: + token_bucket: + max_tokens: 2 + # Set fill_interval to effectively infinite so we only get max_tokens to start and never re-fill. +- fill_interval: 100000s ++ fill_interval: 1000s + )EOF"); + + // Clone the whole listener, which includes the `share_key`. 
+diff --git a/test/extensions/filters/network/local_ratelimit/local_ratelimit_test.cc b/test/extensions/filters/network/local_ratelimit/local_ratelimit_test.cc +index b4e43d7065..94af0da7d1 100644 +--- test/extensions/filters/network/local_ratelimit/local_ratelimit_test.cc ++++ test/extensions/filters/network/local_ratelimit/local_ratelimit_test.cc +@@ -12,7 +12,6 @@ + #include "gmock/gmock.h" + #include "gtest/gtest.h" + +-using testing::_; + using testing::InSequence; + using testing::NiceMock; + using testing::Return; +@@ -24,16 +23,11 @@ namespace LocalRateLimitFilter { + + class LocalRateLimitTestBase : public testing::Test, public Event::TestUsingSimulatedTime { + public: +- LocalRateLimitTestBase() : singleton_manager_(Thread::threadFactoryForTest()) {} ++ LocalRateLimitTestBase() = default; + +- uint64_t initialize(const std::string& filter_yaml, bool expect_timer_create = true) { ++ uint64_t initialize(const std::string& filter_yaml) { + envoy::extensions::filters::network::local_ratelimit::v3::LocalRateLimit proto_config; + TestUtility::loadFromYamlAndValidate(filter_yaml, proto_config); +- fill_timer_ = new Event::MockTimer(&dispatcher_); +- if (expect_timer_create) { +- EXPECT_CALL(*fill_timer_, enableTimer(_, nullptr)); +- EXPECT_CALL(*fill_timer_, disableTimer()); +- } + config_ = std::make_shared(proto_config, dispatcher_, *stats_store_.rootScope(), + runtime_, singleton_manager_); + return proto_config.token_bucket().max_tokens(); +@@ -43,7 +37,6 @@ public: + Stats::IsolatedStoreImpl stats_store_; + NiceMock runtime_; + Singleton::ManagerImpl singleton_manager_; +- Event::MockTimer* fill_timer_{}; + ConfigSharedPtr config_; + }; + +@@ -102,8 +95,7 @@ token_bucket: + ->value()); + + // Refill the bucket. +- EXPECT_CALL(*fill_timer_, enableTimer(std::chrono::milliseconds(200), nullptr)); +- fill_timer_->invokeCallback(); ++ dispatcher_.globalTimeSystem().advanceTimeWait(std::chrono::milliseconds(200)); + + // Third connection is OK. 
+ ActiveFilter active_filter3(config_); +@@ -145,11 +137,6 @@ public: + envoy::extensions::filters::network::local_ratelimit::v3::LocalRateLimit proto_config; + TestUtility::loadFromYamlAndValidate(filter_yaml2, proto_config); + const uint64_t config2_tokens = proto_config.token_bucket().max_tokens(); +- if (!expect_sharing) { +- auto timer = new Event::MockTimer(&dispatcher_); +- EXPECT_CALL(*timer, enableTimer(_, nullptr)); +- EXPECT_CALL(*timer, disableTimer()); +- } + config2_ = std::make_shared(proto_config, dispatcher_, *stats_store_.rootScope(), + runtime_, singleton_manager_); + +diff --git a/test/extensions/http/cache/file_system_http_cache/file_system_http_cache_test.cc b/test/extensions/http/cache/file_system_http_cache/file_system_http_cache_test.cc +index 0ee731596b..52aeddd352 100644 +--- test/extensions/http/cache/file_system_http_cache/file_system_http_cache_test.cc ++++ test/extensions/http/cache/file_system_http_cache/file_system_http_cache_test.cc +@@ -262,7 +262,7 @@ CacheConfig varyAllowListConfig() { + + class MockSingletonManager : public Singleton::ManagerImpl { + public: +- MockSingletonManager() : Singleton::ManagerImpl(Thread::threadFactoryForTest()) { ++ MockSingletonManager() { + // By default just act like a real SingletonManager, but allow overrides. 
+ ON_CALL(*this, get(_, _, _)) + .WillByDefault(std::bind(&MockSingletonManager::realGet, this, std::placeholders::_1, +diff --git a/test/extensions/key_value/file_based/alternate_protocols_cache_impl_test.cc b/test/extensions/key_value/file_based/alternate_protocols_cache_impl_test.cc +index 9363f64872..0787fe7c4e 100644 +--- test/extensions/key_value/file_based/alternate_protocols_cache_impl_test.cc ++++ test/extensions/key_value/file_based/alternate_protocols_cache_impl_test.cc +@@ -26,7 +26,7 @@ public: + singleton_manager_, tls_, data); + manager_ = factory_->get(); + } +- Singleton::ManagerImpl singleton_manager_{Thread::threadFactoryForTest()}; ++ Singleton::ManagerImpl singleton_manager_; + NiceMock context_; + testing::NiceMock tls_; + std::unique_ptr factory_; +diff --git a/test/extensions/transport_sockets/alts/config_test.cc b/test/extensions/transport_sockets/alts/config_test.cc +index a848115fac..c237f6ff27 100644 +--- test/extensions/transport_sockets/alts/config_test.cc ++++ test/extensions/transport_sockets/alts/config_test.cc +@@ -18,7 +18,7 @@ namespace { + + TEST(UpstreamAltsConfigTest, CreateSocketFactory) { + NiceMock factory_context; +- Singleton::ManagerImpl singleton_manager{Thread::threadFactoryForTest()}; ++ Singleton::ManagerImpl singleton_manager; + EXPECT_CALL(factory_context.server_context_, singletonManager()) + .WillRepeatedly(ReturnRef(singleton_manager)); + UpstreamAltsTransportSocketConfigFactory factory; +@@ -39,7 +39,7 @@ TEST(UpstreamAltsConfigTest, CreateSocketFactory) { + + TEST(DownstreamAltsConfigTest, CreateSocketFactory) { + NiceMock factory_context; +- Singleton::ManagerImpl singleton_manager{Thread::threadFactoryForTest()}; ++ Singleton::ManagerImpl singleton_manager; + EXPECT_CALL(factory_context.server_context_, singletonManager()) + .WillRepeatedly(ReturnRef(singleton_manager)); + DownstreamAltsTransportSocketConfigFactory factory; +diff --git a/test/mocks/server/instance.cc b/test/mocks/server/instance.cc +index 
ebee702e5d..036f31fa48 100644 +--- test/mocks/server/instance.cc ++++ test/mocks/server/instance.cc +@@ -13,8 +13,7 @@ using ::testing::ReturnRef; + + MockInstance::MockInstance() + : secret_manager_(std::make_unique(admin_.getConfigTracker())), +- cluster_manager_(timeSource()), +- singleton_manager_(new Singleton::ManagerImpl(Thread::threadFactoryForTest())), ++ cluster_manager_(timeSource()), singleton_manager_(new Singleton::ManagerImpl()), + grpc_context_(stats_store_.symbolTable()), http_context_(stats_store_.symbolTable()), + router_context_(stats_store_.symbolTable()), quic_stat_names_(stats_store_.symbolTable()), + stats_config_(std::make_shared>()), +diff --git a/test/mocks/server/server_factory_context.cc b/test/mocks/server/server_factory_context.cc +index fb12fbe4f8..e1377b5b42 100644 +--- test/mocks/server/server_factory_context.cc ++++ test/mocks/server/server_factory_context.cc +@@ -8,9 +8,8 @@ using ::testing::Return; + using ::testing::ReturnRef; + + MockServerFactoryContext::MockServerFactoryContext() +- : singleton_manager_(new Singleton::ManagerImpl(Thread::threadFactoryForTest())), +- http_context_(store_.symbolTable()), grpc_context_(store_.symbolTable()), +- router_context_(store_.symbolTable()) { ++ : singleton_manager_(new Singleton::ManagerImpl()), http_context_(store_.symbolTable()), ++ grpc_context_(store_.symbolTable()), router_context_(store_.symbolTable()) { + ON_CALL(*this, clusterManager()).WillByDefault(ReturnRef(cluster_manager_)); + ON_CALL(*this, mainThreadDispatcher()).WillByDefault(ReturnRef(dispatcher_)); + ON_CALL(*this, drainDecision()).WillByDefault(ReturnRef(drain_manager_)); +diff --git a/test/mocks/server/transport_socket_factory_context.cc b/test/mocks/server/transport_socket_factory_context.cc +index d06c2711f7..09859ef971 100644 +--- test/mocks/server/transport_socket_factory_context.cc ++++ test/mocks/server/transport_socket_factory_context.cc +@@ -12,8 +12,7 @@ namespace Configuration { + using 
::testing::ReturnRef; + + MockTransportSocketFactoryContext::MockTransportSocketFactoryContext() +- : secret_manager_(std::make_unique(config_tracker_)), +- singleton_manager_(Thread::threadFactoryForTest()) { ++ : secret_manager_(std::make_unique(config_tracker_)) { + ON_CALL(*this, serverFactoryContext()).WillByDefault(ReturnRef(server_context_)); + ON_CALL(*this, clusterManager()).WillByDefault(ReturnRef(cluster_manager_)); + ON_CALL(*this, messageValidationVisitor()) +diff --git a/test/mocks/upstream/cluster_manager_factory.h b/test/mocks/upstream/cluster_manager_factory.h +index 9cf678b4f2..c5ce18aaec 100644 +--- test/mocks/upstream/cluster_manager_factory.h ++++ test/mocks/upstream/cluster_manager_factory.h +@@ -51,7 +51,7 @@ public: + + private: + NiceMock secret_manager_; +- Singleton::ManagerImpl singleton_manager_{Thread::threadFactoryForTest()}; ++ Singleton::ManagerImpl singleton_manager_; + }; + } // namespace Upstream + } // namespace Envoy diff --git a/bazel/repositories.bzl b/bazel/repositories.bzl index 2e0eb491..f5def418 100644 --- a/bazel/repositories.bzl +++ b/bazel/repositories.bzl @@ -70,6 +70,7 @@ def envoy_gloo_dependencies(): "@envoy_gloo//bazel/foreign_cc:0001-otel-exporter-status-code-fix.patch", "@envoy_gloo//bazel/foreign_cc:0002-ratelimit-filter-state-hits-addend.patch", "@envoy_gloo//bazel/foreign_cc:0003-deallocate-slots-on-worker-threads.patch", # https://github.com/envoyproxy/envoy/pull/33395 + "@envoy_gloo//bazel/foreign_cc:0004-local-rate-limit-bucket-backport.patch", ]) _repository_impl("json", build_file = "@envoy_gloo//bazel/external:json.BUILD") _repository_impl("inja", build_file = "@envoy_gloo//bazel/external:inja.BUILD") diff --git a/changelog/v1.30.4-patch5/local-rate-limit-token-backports.yaml b/changelog/v1.30.4-patch5/local-rate-limit-token-backports.yaml new file mode 100644 index 00000000..6c47a94e --- /dev/null +++ b/changelog/v1.30.4-patch5/local-rate-limit-token-backports.yaml @@ -0,0 +1,6 @@ +changelog: +- type: 
FIX + issueLink: https://github.com/solo-io/gloo/issues/9564 + resolvesIssue: false + description: >- + Backport local rate limiting but make it disabled by default. Requires envoy_reloadable_features_no_timer_based_rate_limit_token_bucket to be set to true. \ No newline at end of file diff --git a/ci/cloudbuild.yaml b/ci/cloudbuild.yaml index 88ae2cf1..1b56274f 100644 --- a/ci/cloudbuild.yaml +++ b/ci/cloudbuild.yaml @@ -12,10 +12,11 @@ steps: secretEnv: - 'GCP_SERVICE_ACCOUNT_KEY' -- name: 'envoyproxy/envoy-build-ubuntu:41c5a05d708972d703661b702a63ef5060125c33' - id: 'static_analysis' - allowFailure: true - args: ['ci/static_analysis.sh'] +- name: 'envoyproxy/envoy-build-ubuntu:f94a38f62220a2b017878b790b6ea98a0f6c5f9c' + id: 'do_upstream_ci' # Run the upstream Envoy local rate limit tests (common, http, listener, network) to validate the backported patch. + args: ['ci/do_ci.sh', 'dev','@envoy//test/extensions/filters/common/local_ratelimit/...', +'@envoy//test/extensions/filters/http/local_ratelimit/...','@envoy//test/extensions/filters/listener/local_ratelimit/...', + '@envoy//test/extensions/filters/network/local_ratelimit/...'] volumes: - name: 'vol-build' path: '/build' diff --git a/source/extensions/filters/http/transformation/inja_transformer.cc b/source/extensions/filters/http/transformation/inja_transformer.cc index a3788d5d..d1e0dca5 100644 --- a/source/extensions/filters/http/transformation/inja_transformer.cc +++ b/source/extensions/filters/http/transformation/inja_transformer.cc @@ -797,7 +797,7 @@ InjaTransformer::InjaTransformer(const TransformationTemplate &transformation, } try { merge_templates_.emplace_back(std::make_tuple(name, tmpl.override_empty(), instance_->parse(tmpl.tmpl().text()))); - } catch (const std::exception) { + } catch (const std::exception &) { throw EnvoyException( fmt::format("Failed to parse merge_body_key template for key: ({})", name)); }