From 1107dd76e7a704700e735414eb6a57be9117dc12 Mon Sep 17 00:00:00 2001
From: Frank Hamand
Date: Fri, 6 Sep 2024 09:22:43 +0100
Subject: [PATCH] chore: refactor RedisLimiter to update async out of the hot path (#24818)

---
 rust/capture/src/limiters/redis.rs  | 187 +++++++++++-----------------
 rust/capture/src/server.rs          |   7 +-
 rust/capture/src/v0_endpoint.rs     |  31 +----
 rust/capture/tests/django_compat.rs |  10 +-
 4 files changed, 90 insertions(+), 145 deletions(-)

diff --git a/rust/capture/src/limiters/redis.rs b/rust/capture/src/limiters/redis.rs
index 132f09df77513..cc7e7d119d89b 100644
--- a/rust/capture/src/limiters/redis.rs
+++ b/rust/capture/src/limiters/redis.rs
@@ -1,5 +1,11 @@
 use metrics::gauge;
-use std::{collections::HashSet, ops::Sub, sync::Arc};
+use std::time::Duration as StdDuration;
+use std::{collections::HashSet, sync::Arc};
+use time::{Duration, OffsetDateTime};
+use tokio::sync::RwLock;
+use tokio::task;
+use tokio::time::interval;
+use tracing::instrument;

 use crate::redis::Client;

 /// 2. Capture should cope with redis being _totally down_, and fail open
 /// 3. We should not hit redis for every single request
 ///
-/// The solution here is to read from the cache until a time interval is hit, and then fetch new
-/// data. The write requires taking a lock that stalls all readers, though so long as redis reads
-/// stay fast we're ok.
+/// The solution here is to read from the cache and refresh the set from a background task.
+/// Readers are only locked out while the refreshed set is swapped in, and that lock is not
+/// taken until the redis response is already in hand, so it is held only very briefly.
 ///
 /// Some small delay between an account being limited and the limit taking effect is acceptable.
 /// However, ideally we should not allow requests from some pods but 429 from others.
-use thiserror::Error; -use time::{Duration, OffsetDateTime}; -use tokio::sync::RwLock; -use tracing::instrument; - -// todo: fetch from env const QUOTA_LIMITER_CACHE_KEY: &str = "@posthog/quota-limits/"; #[derive(Debug)] @@ -46,19 +46,12 @@ impl QuotaResource { } } -#[derive(Error, Debug)] -pub enum LimiterError { - #[error("updater already running - there can only be one")] - UpdaterRunning, -} - #[derive(Clone)] pub struct RedisLimiter { limited: Arc>>, redis: Arc, - redis_key_prefix: String, + key: String, interval: Duration, - updated: Arc>, } impl RedisLimiter { @@ -74,98 +67,67 @@ impl RedisLimiter { interval: Duration, redis: Arc, redis_key_prefix: Option, + resource: QuotaResource, ) -> anyhow::Result { let limited = Arc::new(RwLock::new(HashSet::new())); + let key_prefix = redis_key_prefix.unwrap_or_default(); - // Force an update immediately if we have any reasonable interval - let updated = OffsetDateTime::from_unix_timestamp(0)?; - let updated = Arc::new(RwLock::new(updated)); - - Ok(RedisLimiter { + let limiter = RedisLimiter { interval, limited, - updated, - redis, - redis_key_prefix: redis_key_prefix.unwrap_or_default(), - }) + redis: redis.clone(), + key: format!("{key_prefix}{QUOTA_LIMITER_CACHE_KEY}{}", resource.as_str()), + }; + + // Spawn a background task to periodically fetch data from Redis + limiter.spawn_background_update(); + + Ok(limiter) + } + + fn spawn_background_update(&self) { + let limited = Arc::clone(&self.limited); + let redis = Arc::clone(&self.redis); + let interval_duration = StdDuration::from_nanos(self.interval.whole_nanoseconds() as u64); + let key = self.key.clone(); + + // Spawn a task to periodically update the cache from Redis + task::spawn(async move { + let mut interval = interval(interval_duration); + loop { + match RedisLimiter::fetch_limited(&redis, &key).await { + Ok(set) => { + let set = HashSet::from_iter(set.iter().cloned()); + gauge!("capture_billing_limits_loaded_tokens",).set(set.len() as f64); + + let mut limited_lock = limited.write().await; + *limited_lock = set; + } + Err(e) => { + tracing::error!("Failed to update cache from Redis: {:?}", e); + } + } + + interval.tick().await; + } + }); } #[instrument(skip_all)] async fn fetch_limited( client: &Arc, - key_prefix: &str, - resource: &QuotaResource, + key: &String, ) -> anyhow::Result> { let now = OffsetDateTime::now_utc().unix_timestamp(); - let key = format!("{key_prefix}{QUOTA_LIMITER_CACHE_KEY}{}", resource.as_str()); client - .zrangebyscore(key, now.to_string(), String::from("+Inf")) + .zrangebyscore(key.to_string(), now.to_string(), String::from("+Inf")) .await } - #[instrument(skip_all, fields(key = key))] - pub async fn is_limited(&self, key: &str, resource: QuotaResource) -> bool { - // hold the read lock to clone it, very briefly. clone is ok because it's very small 🤏 - // rwlock can have many readers, but one writer. the writer will wait in a queue with all - // the readers, so we want to hold read locks for the smallest time possible to avoid - // writers waiting for too long. and vice versa. - let updated = { - let updated = self.updated.read().await; - *updated - }; - - let now = OffsetDateTime::now_utc(); - let since_update = now.sub(updated); - - // If an update is due, fetch the set from redis + cache it until the next update is due. - // Otherwise, return a value from the cache - // - // This update will block readers! Keep it fast. 
- if since_update > self.interval { - // open the update lock to change the update, and prevent anyone else from doing so - let mut updated = self.updated.write().await; - *updated = OffsetDateTime::now_utc(); - - let span = tracing::debug_span!("updating billing cache from redis"); - let _span = span.enter(); - - // a few requests might end up in here concurrently, but I don't think a few extra will - // be a big problem. If it is, we can rework the concurrency a bit. - // On prod atm we call this around 15 times per second at peak times, and it usually - // completes in <1ms. - - let set = Self::fetch_limited(&self.redis, &self.redis_key_prefix, &resource).await; - - tracing::debug!("fetched set from redis, caching"); - - if let Ok(set) = set { - let set = HashSet::from_iter(set.iter().cloned()); - gauge!( - "capture_billing_limits_loaded_tokens", - "resource" => resource.as_str(), - ) - .set(set.len() as f64); - - let mut limited = self.limited.write().await; - *limited = set; - - tracing::debug!("updated cache from redis"); - - limited.contains(key) - } else { - tracing::error!("failed to fetch from redis in time, failing open"); - // If we fail to fetch the set, something really wrong is happening. To avoid - // dropping events that we don't mean to drop, fail open and accept data. Better - // than angry customers :) - // - // TODO: Consider backing off our redis checks - false - } - } else { - let l = self.limited.read().await; - - l.contains(key) - } + #[instrument(skip_all, fields(value = value))] + pub async fn is_limited(&self, value: &str) -> bool { + let limited = self.limited.read().await; + limited.contains(value) } } @@ -185,15 +147,12 @@ mod tests { .zrangebyscore_ret("@posthog/quota-limits/events", vec![String::from("banana")]); let client = Arc::new(client); - let limiter = RedisLimiter::new(Duration::microseconds(1), client, None) + let limiter = RedisLimiter::new(Duration::seconds(1), client, None, QuotaResource::Events) .expect("Failed to create billing limiter"); + tokio::time::sleep(std::time::Duration::from_millis(30)).await; - assert!( - !limiter - .is_limited("not_limited", QuotaResource::Events) - .await, - ); - assert!(limiter.is_limited("banana", QuotaResource::Events).await); + assert!(!limiter.is_limited("not_limited").await); + assert!(limiter.is_limited("banana").await); } #[tokio::test] @@ -205,27 +164,27 @@ mod tests { let client = Arc::new(client); // Default lookup without prefix fails - let limiter = RedisLimiter::new(Duration::microseconds(1), client.clone(), None) - .expect("Failed to create billing limiter"); - assert!(!limiter.is_limited("banana", QuotaResource::Events).await); + let limiter = RedisLimiter::new( + Duration::seconds(1), + client.clone(), + None, + QuotaResource::Events, + ) + .expect("Failed to create billing limiter"); + tokio::time::sleep(std::time::Duration::from_millis(30)).await; + assert!(!limiter.is_limited("banana").await); // Limiter using the correct prefix let prefixed_limiter = RedisLimiter::new( Duration::microseconds(1), client, Some("prefix//".to_string()), + QuotaResource::Events, ) .expect("Failed to create billing limiter"); + tokio::time::sleep(std::time::Duration::from_millis(30)).await; - assert!( - !prefixed_limiter - .is_limited("not_limited", QuotaResource::Events) - .await, - ); - assert!( - prefixed_limiter - .is_limited("banana", QuotaResource::Events) - .await - ); + assert!(!prefixed_limiter.is_limited("not_limited").await); + assert!(prefixed_limiter.is_limited("banana").await); } } diff --git 
a/rust/capture/src/server.rs b/rust/capture/src/server.rs index 93ff3f646c3bc..bb6f7aaf5dd5b 100644 --- a/rust/capture/src/server.rs +++ b/rust/capture/src/server.rs @@ -6,10 +6,11 @@ use health::{ComponentStatus, HealthRegistry}; use time::Duration; use tokio::net::TcpListener; +use crate::config::CaptureMode; use crate::config::Config; use crate::limiters::overflow::OverflowLimiter; -use crate::limiters::redis::RedisLimiter; +use crate::limiters::redis::{QuotaResource, RedisLimiter}; use crate::redis::RedisClient; use crate::router; use crate::sinks::kafka::KafkaSink; @@ -28,6 +29,10 @@ where Duration::seconds(5), redis_client.clone(), config.redis_key_prefix, + match config.capture_mode { + CaptureMode::Events => QuotaResource::Events, + CaptureMode::Recordings => QuotaResource::Recordings, + }, ) .expect("failed to create billing limiter"); diff --git a/rust/capture/src/v0_endpoint.rs b/rust/capture/src/v0_endpoint.rs index 03b550cd9cdaf..e01b9cbf5bc22 100644 --- a/rust/capture/src/v0_endpoint.rs +++ b/rust/capture/src/v0_endpoint.rs @@ -13,7 +13,6 @@ use serde_json::json; use serde_json::Value; use tracing::instrument; -use crate::limiters::redis::QuotaResource; use crate::prometheus::report_dropped_events; use crate::v0_request::{Compression, ProcessingContext, RawRequest}; use crate::{ @@ -29,7 +28,6 @@ use crate::{ /// /// Because it must accommodate several shapes, it is inefficient in places. A v1 /// endpoint should be created, that only accepts the BatchedRequest payload shape. -#[allow(clippy::too_many_arguments)] async fn handle_common( state: &State, InsecureClientIp(ip): &InsecureClientIp, @@ -37,7 +35,6 @@ async fn handle_common( headers: &HeaderMap, method: &Method, path: &MatchedPath, - quota_resource: QuotaResource, body: Bytes, ) -> Result<(ProcessingContext, Vec), CaptureError> { let user_agent = headers @@ -119,7 +116,7 @@ async fn handle_common( let billing_limited = state .billing_limiter - .is_limited(context.token.as_str(), quota_resource) + .is_limited(context.token.as_str()) .await; if billing_limited { @@ -157,18 +154,7 @@ pub async fn event( path: MatchedPath, body: Bytes, ) -> Result, CaptureError> { - match handle_common( - &state, - &ip, - &meta, - &headers, - &method, - &path, - QuotaResource::Events, - body, - ) - .await - { + match handle_common(&state, &ip, &meta, &headers, &method, &path, body).await { Err(CaptureError::BillingLimit) => { // for v0 we want to just return ok 🙃 // this is because the clients are pretty dumb and will just retry over and over and @@ -227,18 +213,7 @@ pub async fn recording( path: MatchedPath, body: Bytes, ) -> Result, CaptureError> { - match handle_common( - &state, - &ip, - &meta, - &headers, - &method, - &path, - QuotaResource::Recordings, - body, - ) - .await - { + match handle_common(&state, &ip, &meta, &headers, &method, &path, body).await { Err(CaptureError::BillingLimit) => Ok(Json(CaptureResponse { status: CaptureResponseCode::Ok, quota_limited: Some(vec!["recordings".to_string()]), diff --git a/rust/capture/tests/django_compat.rs b/rust/capture/tests/django_compat.rs index d08be11c7506c..a5f81aa589c51 100644 --- a/rust/capture/tests/django_compat.rs +++ b/rust/capture/tests/django_compat.rs @@ -6,6 +6,7 @@ use base64::engine::general_purpose; use base64::Engine; use capture::api::{CaptureError, CaptureResponse, CaptureResponseCode, DataType, ProcessedEvent}; use capture::config::CaptureMode; +use capture::limiters::redis::QuotaResource; use capture::limiters::redis::RedisLimiter; use 
capture::redis::MockRedisClient; use capture::router::router; @@ -101,8 +102,13 @@ async fn it_matches_django_capture_behaviour() -> anyhow::Result<()> { let timesource = FixedTime { time: case.now }; let redis = Arc::new(MockRedisClient::new()); - let billing_limiter = RedisLimiter::new(Duration::weeks(1), redis.clone(), None) - .expect("failed to create billing limiter"); + let billing_limiter = RedisLimiter::new( + Duration::weeks(1), + redis.clone(), + None, + QuotaResource::Events, + ) + .expect("failed to create billing limiter"); let app = router( timesource,
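
To make the new control flow easier to follow outside the diff, here is a minimal, self-contained sketch of the pattern the patch adopts: a background tokio task periodically rebuilds a shared HashSet behind an RwLock, and the request hot path only takes a short read lock and does a set lookup. This is an illustration under assumptions, not the capture crate's actual code; CachedLimiter and fetch_limited_tokens are hypothetical names, and fetch_limited_tokens stands in for the real Redis ZRANGEBYSCORE query used by the limiter.

use std::{collections::HashSet, sync::Arc, time::Duration};
use tokio::{sync::RwLock, task, time::interval};

#[derive(Clone)]
struct CachedLimiter {
    limited: Arc<RwLock<HashSet<String>>>,
}

impl CachedLimiter {
    // Must be called from within a tokio runtime, since it spawns a task.
    fn new(refresh_every: Duration) -> Self {
        let limited = Arc::new(RwLock::new(HashSet::new()));
        let limiter = CachedLimiter {
            limited: limited.clone(),
        };

        // Background refresh: fetch outside the lock, then swap the new set in,
        // so readers are only blocked for the duration of the assignment.
        task::spawn(async move {
            let mut ticker = interval(refresh_every);
            loop {
                ticker.tick().await;
                // Stand-in for the real Redis ZRANGEBYSCORE lookup. On error we
                // keep the previous set, i.e. the limiter fails open.
                if let Ok(fresh) = fetch_limited_tokens().await {
                    *limited.write().await = fresh;
                }
            }
        });

        limiter
    }

    // Hot path: a brief read lock and a HashSet lookup, no Redis round-trip.
    async fn is_limited(&self, token: &str) -> bool {
        self.limited.read().await.contains(token)
    }
}

// Placeholder for the Redis query (an assumption for this sketch, not the PR's code).
async fn fetch_limited_tokens() -> Result<HashSet<String>, std::io::Error> {
    Ok(HashSet::new())
}

#[tokio::main]
async fn main() {
    let limiter = CachedLimiter::new(Duration::from_secs(5));
    assert!(!limiter.is_limited("some-token").await);
}

The trade-offs are the ones the doc comment in the patch describes: a newly applied limit can lag by up to one refresh interval, and if a refresh fails the previous set stays in place, so the limiter fails open instead of rejecting events.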