diff --git a/server/Cargo.lock b/server/Cargo.lock index d879ea4b1..427abc427 100644 --- a/server/Cargo.lock +++ b/server/Cargo.lock @@ -2489,6 +2489,20 @@ dependencies = [ "memchr", ] +[[package]] +name = "omniqueue" +version = "0.1.0" +source = "git+https://github.com/svix/omniqueue-rs.git?rev=32bf5f17209b76ab33902ed149f1890a80dda32a#32bf5f17209b76ab33902ed149f1890a80dda32a" +dependencies = [ + "async-trait", + "futures", + "serde", + "serde_json", + "thiserror", + "tokio", + "tracing", +] + [[package]] name = "once_cell" version = "1.19.0" @@ -4263,6 +4277,7 @@ dependencies = [ "jwt-simple", "lapin", "num_enum", + "omniqueue", "once_cell", "openssl", "opentelemetry", diff --git a/server/svix-server/Cargo.toml b/server/svix-server/Cargo.toml index 42db2e007..95902dbc8 100644 --- a/server/svix-server/Cargo.toml +++ b/server/svix-server/Cargo.toml @@ -72,6 +72,7 @@ urlencoding = "2.1.2" form_urlencoded = "1.1.0" lapin = "2.1.1" sentry = { version = "0.32.2", features = ["tracing"] } +omniqueue = { git = "https://github.com/svix/omniqueue-rs.git", rev = "32bf5f17209b76ab33902ed149f1890a80dda32a", default-features = false, features = ["memory_queue"] } [target.'cfg(not(target_env = "msvc"))'.dependencies] tikv-jemallocator = { version = "0.5", optional = true } diff --git a/server/svix-server/src/core/cache/mod.rs b/server/svix-server/src/core/cache/mod.rs index 9427729f1..dec81d387 100644 --- a/server/svix-server/src/core/cache/mod.rs +++ b/server/svix-server/src/core/cache/mod.rs @@ -8,7 +8,7 @@ use axum::async_trait; use enum_dispatch::enum_dispatch; use serde::{de::DeserializeOwned, Serialize}; -use crate::core::run_with_retries::run_with_retries; +use crate::core::retry::run_with_retries; pub mod memory; pub mod none; diff --git a/server/svix-server/src/core/mod.rs b/server/svix-server/src/core/mod.rs index 3660b7d81..e36e0731f 100644 --- a/server/svix-server/src/core/mod.rs +++ b/server/svix-server/src/core/mod.rs @@ -8,7 +8,7 @@ pub mod message_app; pub mod 
operational_webhooks; pub mod otel_spans; pub mod permissions; -pub mod run_with_retries; +pub mod retry; pub mod security; pub mod types; pub mod webhook_http_client; diff --git a/server/svix-server/src/core/retry.rs b/server/svix-server/src/core/retry.rs new file mode 100644 index 000000000..2af0e58a9 --- /dev/null +++ b/server/svix-server/src/core/retry.rs @@ -0,0 +1,67 @@ +use std::{future::Future, time::Duration}; + +use tracing::warn; + +pub async fn run_with_retries< + T, + E: std::error::Error, + F: Future>, + FN: FnMut() -> F, +>( + mut fun: FN, + should_retry: impl Fn(&E) -> bool, + retry_schedule: &[Duration], +) -> Result { + let mut retry = Retry::new(should_retry, retry_schedule); + loop { + if let Some(result) = retry.run(&mut fun).await { + return result; + } + } +} + +/// A state machine for retrying an asynchronous operation. +/// +/// Unfortunately needed to get around Rust's lack of `AttachedFn*` traits. +/// For usage, check the implementation of `run_with_retries`. +pub struct Retry<'a, Re> { + retry_schedule: &'a [Duration], + should_retry: Re, +} + +impl<'a, Re> Retry<'a, Re> { + pub fn new(should_retry: Re, retry_schedule: &'a [Duration]) -> Self { + Self { + retry_schedule, + should_retry, + } + } + + pub async fn run(&mut self, f: F) -> Option> + where + E: std::error::Error, + F: FnOnce() -> Fut, + Fut: Future>, + Re: Fn(&E) -> bool, + { + match f().await { + // If the function succeeded, we're done + Ok(t) => Some(Ok(t)), + Err(e) => { + let should_retry = &self.should_retry; + if self.retry_schedule.is_empty() || !should_retry(&e) { + // If we already used up all the retries or should_retry returns false, + // return the latest error and stop retrying. + self.retry_schedule = &[]; + Some(Err(e)) + } else { + // Otherwise, wait and let the caller call retry.run() again. 
+ warn!("Retrying after error: {e}"); + tokio::time::sleep(self.retry_schedule[0]).await; + self.retry_schedule = &self.retry_schedule[1..]; + None + } + } + } + } +} diff --git a/server/svix-server/src/core/run_with_retries.rs b/server/svix-server/src/core/run_with_retries.rs deleted file mode 100644 index 6477504ef..000000000 --- a/server/svix-server/src/core/run_with_retries.rs +++ /dev/null @@ -1,30 +0,0 @@ -use std::{future::Future, time::Duration}; - -pub async fn run_with_retries< - T, - E: std::error::Error, - F: Future>, - FN: FnMut() -> F, ->( - mut fun: FN, - should_retry: impl Fn(&E) -> bool, - retry_schedule: &[Duration], -) -> Result { - for duration in retry_schedule { - match fun().await { - Ok(ret) => return Ok(ret), - Err(e) => { - if should_retry(&e) { - tracing::warn!("Retrying after error {}", e); - tokio::time::sleep(*duration).await; - } else { - return Err(e); - } - } - } - } - - // Loop sleeps after a failed attempt so you need this last fun call to avoid a fencepost error - // with durations between tries. 
- fun().await -} diff --git a/server/svix-server/src/error.rs b/server/svix-server/src/error.rs index c49f8c55f..8468b59b6 100644 --- a/server/svix-server/src/error.rs +++ b/server/svix-server/src/error.rs @@ -158,6 +158,13 @@ impl From for Error { } } +impl From for Error { + #[track_caller] + fn from(value: omniqueue::QueueError) -> Self { + Error::queue(value) + } +} + impl From> for Error { #[track_caller] fn from(value: bb8::RunError) -> Self { diff --git a/server/svix-server/src/queue/memory.rs b/server/svix-server/src/queue/memory.rs deleted file mode 100644 index 2df813bd4..000000000 --- a/server/svix-server/src/queue/memory.rs +++ /dev/null @@ -1,80 +0,0 @@ -use std::{sync::Arc, time::Duration}; - -use axum::async_trait; -use chrono::Utc; -use tokio::{sync::mpsc, time::sleep}; - -use crate::error::Error; -use crate::error::Result; - -use super::{ - Acker, QueueTask, TaskQueueConsumer, TaskQueueDelivery, TaskQueueProducer, TaskQueueReceive, - TaskQueueSend, -}; - -pub async fn new_pair() -> (TaskQueueProducer, TaskQueueConsumer) { - let (tx, rx) = mpsc::unbounded_channel::(); - ( - TaskQueueProducer::Memory(MemoryQueueProducer { tx }), - TaskQueueConsumer::Memory(MemoryQueueConsumer { rx }), - ) -} - -#[derive(Clone, Debug)] -pub struct MemoryQueueProducer { - tx: mpsc::UnboundedSender, -} - -#[async_trait] -impl TaskQueueSend for MemoryQueueProducer { - async fn send(&self, msg: Arc, delay: Option) -> Result<()> { - let timestamp = delay.map(|delay| Utc::now() + chrono::Duration::from_std(delay).unwrap()); - let delivery = TaskQueueDelivery::from_arc(msg, timestamp, Acker::Memory(self.clone())); - - if let Some(delay) = delay { - let tx = self.tx.clone(); - tokio::spawn(async move { - // We just assume memory queue always works, so we can defer the error handling - tracing::trace!("MemoryQueue: event sent > (delay: {:?})", delay); - sleep(delay).await; - if tx.send(delivery).is_err() { - tracing::error!("Receiver dropped"); - } - }); - } else if 
self.tx.send(delivery).is_err() { - tracing::error!("Receiver dropped"); - } - - Ok(()) - } -} - -pub struct MemoryQueueConsumer { - rx: mpsc::UnboundedReceiver, -} - -#[async_trait] -impl TaskQueueReceive for MemoryQueueConsumer { - async fn receive_all(&mut self) -> Result> { - let mut deliveries = tokio::select! { - _ = tokio::time::sleep(Duration::from_secs(30)) => return Ok(Vec::new()), - recv = self.rx.recv() => { - if let Some(delivery) = recv { - tracing::trace!("MemoryQueue: event recv <"); - vec![delivery] - } else { - return Err(Error::queue("Failed to fetch from queue")) - } - } - }; - - // possible errors are `Empty` or `Disconnected`. Either way, - // we want to return the deliveries that could be received. - // If it was Disconnected, the next call to receive_all will fail - while let Ok(delivery) = self.rx.try_recv() { - deliveries.push(delivery); - } - - Ok(deliveries) - } -} diff --git a/server/svix-server/src/queue/mod.rs b/server/svix-server/src/queue/mod.rs index db6d00124..7a056d213 100644 --- a/server/svix-server/src/queue/mod.rs +++ b/server/svix-server/src/queue/mod.rs @@ -3,6 +3,15 @@ use std::{sync::Arc, time::Duration}; use axum::async_trait; use chrono::{DateTime, Utc}; use lapin::options::{BasicAckOptions, BasicNackOptions}; +use omniqueue::{ + backends::memory_queue::MemoryQueueBackend, + queue::{ + consumer::{DynConsumer, QueueConsumer}, + producer::QueueProducer, + Delivery, QueueBackend as _, + }, + scheduled::ScheduledProducer, +}; use serde::{Deserialize, Serialize}; use svix_ksuid::*; @@ -10,18 +19,14 @@ use crate::error::Traceable; use crate::{ cfg::{Configuration, QueueBackend}, core::{ - run_with_retries::run_with_retries, + retry::{run_with_retries, Retry}, types::{ApplicationId, EndpointId, MessageAttemptTriggerType, MessageId}, }, error::{Error, ErrorType, Result}, }; -use self::{ - memory::{MemoryQueueConsumer, MemoryQueueProducer}, - redis::{RedisQueueConsumer, RedisQueueInner, RedisQueueProducer}, -}; +use 
self::redis::{RedisQueueConsumer, RedisQueueInner, RedisQueueProducer}; -pub mod memory; pub mod rabbitmq; pub mod redis; @@ -48,7 +53,17 @@ pub async fn new_pair( let pool = crate::redis::new_redis_pool_clustered(dsn, cfg).await; redis::new_pair(pool, prefix).await } - QueueBackend::Memory => memory::new_pair().await, + QueueBackend::Memory => { + let (producer, consumer) = MemoryQueueBackend::builder(()) + .build_pair() + .await + .expect("building in-memory queue can't fail"); + + ( + TaskQueueProducer::Omni(Arc::new(producer.into_dyn_scheduled(Default::default()))), + TaskQueueConsumer::Omni(consumer.into_dyn(Default::default())), + ) + } QueueBackend::RabbitMq(dsn) => { let prefix = prefix.unwrap_or(""); let queue = format!("{prefix}-message-queue"); @@ -135,9 +150,9 @@ impl QueueTask { #[derive(Clone)] pub enum TaskQueueProducer { - Memory(MemoryQueueProducer), Redis(RedisQueueProducer), RabbitMq(rabbitmq::Producer), + Omni(Arc), } impl TaskQueueProducer { @@ -146,9 +161,14 @@ impl TaskQueueProducer { run_with_retries( || async { match self { - TaskQueueProducer::Memory(q) => q.send(task.clone(), delay).await, TaskQueueProducer::Redis(q) => q.send(task.clone(), delay).await, TaskQueueProducer::RabbitMq(q) => q.send(task.clone(), delay).await, + TaskQueueProducer::Omni(q) => if let Some(delay) = delay { + q.send_serde_json_scheduled(task.as_ref(), delay).await + } else { + q.send_serde_json(task.as_ref()).await + } + .map_err(Into::into), } }, should_retry, @@ -160,16 +180,26 @@ impl TaskQueueProducer { pub enum TaskQueueConsumer { Redis(RedisQueueConsumer), - Memory(MemoryQueueConsumer), RabbitMq(rabbitmq::Consumer), + Omni(DynConsumer), } impl TaskQueueConsumer { pub async fn receive_all(&mut self) -> Result> { match self { TaskQueueConsumer::Redis(q) => q.receive_all().await.trace(), - TaskQueueConsumer::Memory(q) => q.receive_all().await.trace(), TaskQueueConsumer::RabbitMq(q) => q.receive_all().await.trace(), + TaskQueueConsumer::Omni(q) => { + const 
MAX_MESSAGES: usize = 128; + // FIXME(onelson): need to figure out what deadline/duration to use here + q.receive_all(MAX_MESSAGES, Duration::from_secs(30)) + .await + .map_err(Into::into) + .trace()? + .into_iter() + .map(TryInto::try_into) + .collect() + } } } } @@ -177,9 +207,9 @@ impl TaskQueueConsumer { /// Used by TaskQueueDeliveries to Ack/Nack itself #[derive(Debug)] enum Acker { - Memory(MemoryQueueProducer), Redis(Arc), RabbitMQ(lapin::message::Delivery), + Omni(Delivery), } #[derive(Debug)] @@ -202,54 +232,110 @@ impl TaskQueueDelivery { pub async fn ack(self) -> Result<()> { tracing::trace!("ack {}", self.id); - run_with_retries( - || async { - match &self.acker { - Acker::Memory(_) => Ok(()), // nothing to do - Acker::Redis(q) => q.ack(&self.id, &self.task).await.trace(), - Acker::RabbitMQ(delivery) => { - delivery - .ack(BasicAckOptions { - multiple: false, // Only ack this message, not others - }) - .await - .map_err(Into::into) + + let mut retry = Retry::new(should_retry, RETRY_SCHEDULE); + let mut acker = Some(self.acker); + loop { + if let Some(result) = retry + .run(|| async { + let acker_ref = acker + .as_ref() + .expect("acker is always Some when trying to ack"); + match acker_ref { + Acker::Redis(q) => q.ack(&self.id, &self.task).await.trace(), + Acker::RabbitMQ(delivery) => { + delivery + .ack(BasicAckOptions { + multiple: false, // Only ack this message, not others + }) + .await + .map_err(Into::into) + } + Acker::Omni(_) => match acker.take() { + Some(Acker::Omni(delivery)) => { + delivery.ack().await.map_err(|(e, delivery)| { + // Put the delivery back in acker before retrying, to 
+ acker = Some(Acker::Omni(delivery)); + e.into() + }) + } + _ => unreachable!(), + }, } - } - }, - should_retry, - RETRY_SCHEDULE, - ) - .await + }) + .await + { + return result; + } + } } pub async fn nack(self) -> Result<()> { tracing::trace!("nack {}", self.id); - run_with_retries( - || async { - match &self.acker { - Acker::Memory(q) => { - tracing::debug!("nack {}", self.id); - q.send(self.task.clone(), None).await.trace() - } - Acker::Redis(q) => q.nack(&self.id, &self.task).await.trace(), - Acker::RabbitMQ(delivery) => { - // See https://www.rabbitmq.com/confirms.html#consumer-nacks-requeue - delivery - .nack(BasicNackOptions { - requeue: true, - multiple: false, // Only nack this message, not others - }) - .await - .map_err(Into::into) + let mut retry = Retry::new(should_retry, RETRY_SCHEDULE); + let mut acker = Some(self.acker); + loop { + if let Some(result) = retry + .run(|| async { + let acker_ref = acker + .as_ref() + .expect("acker is always Some when trying to ack"); + match acker_ref { + Acker::Redis(q) => q.nack(&self.id, &self.task).await.trace(), + Acker::RabbitMQ(delivery) => { + // See https://www.rabbitmq.com/confirms.html#consumer-nacks-requeue + + delivery + .nack(BasicNackOptions { + requeue: true, + multiple: false, // Only nack this message, not others + }) + .await + .map_err(Into::into) + } + Acker::Omni(_) => match acker.take() { + Some(Acker::Omni(delivery)) => { + delivery + .nack() + .await + .map_err(|(e, delivery)| { + // Put the delivery back in acker before retrying, to + // satisfy the expect above. + acker = Some(Acker::Omni(delivery)); + e.into() + }) + .trace() + } + _ => unreachable!(), + }, } - } - }, - should_retry, - RETRY_SCHEDULE, - ) - .await + }) + .await + { + return result; + } + } + } } + +impl TryFrom for TaskQueueDelivery { + type Error = Error; + fn try_from(value: Delivery) -> Result { + Ok(TaskQueueDelivery { + // FIXME(onelson): ksuid for the id? 
+ // Since ack/nack is all handled internally by the omniqueue delivery, maybe it + // doesn't matter. + id: "".to_string(), + task: Arc::new( + value + .payload_serde_json() + .map_err(|_| Error::queue("Failed to decode queue task"))? + .ok_or_else(|| Error::queue("Unexpected empty delivery"))?, + ), + acker: Acker::Omni(value), + }) } } @@ -262,148 +348,3 @@ trait TaskQueueSend: Sync + Send { trait TaskQueueReceive { async fn receive_all(&mut self) -> Result>; } - -#[cfg(test)] -mod tests { - use super::*; - - // TODO: Test Redis impl too - - /// Creates a [`MessageTask`] with filler information and the given MessageId inner String - fn mock_message(message_id: String) -> QueueTask { - MessageTask::new_task( - MessageId(message_id), - ApplicationId("TestEndpointID".to_owned()), - EndpointId("TestEndpointID".to_owned()), - MessageAttemptTriggerType::Scheduled, - ) - } - - /// Sends a message with the given TaskQueueProducer reference and asserts that the result is OK - async fn assert_send(tx: &TaskQueueProducer, message_id: &str) { - assert!(tx - .send(mock_message(message_id.to_owned()), None) - .await - .is_ok()); - } - - /// Receives a message with the given TaskQueueConsumer mutable reference and asserts that it is - /// equal to the mock message with the given message_id. 
- async fn assert_recv(rx: &mut TaskQueueConsumer, message_id: &str) { - assert_eq!( - *rx.receive_all().await.unwrap().first().unwrap().task, - mock_message(message_id.to_owned()) - ) - } - - #[tokio::test] - async fn test_single_producer_single_consumer() { - let (tx_mem, mut rx_mem) = memory::new_pair().await; - - let msg_id = "TestMessageID1"; - - assert_send(&tx_mem, msg_id).await; - assert_recv(&mut rx_mem, msg_id).await; - } - - #[tokio::test] - async fn test_multiple_producer_single_consumer() { - let (tx_mem, mut rx_mem) = memory::new_pair().await; - - let msg_1 = "TestMessageID1"; - let msg_2 = "TestMessageID2"; - - tokio::spawn({ - let tx_mem = tx_mem.clone(); - async move { - assert_send(&tx_mem, msg_1).await; - } - }); - tokio::spawn(async move { - assert_send(&tx_mem, msg_2).await; - }); - - let tasks = rx_mem.receive_all().await.unwrap(); - assert_eq!(*tasks[0].task, mock_message(msg_1.to_owned())); - assert_eq!(*tasks[1].task, mock_message(msg_2.to_owned())); - } - - #[tokio::test] - async fn test_delay() { - let (tx_mem, mut rx_mem) = memory::new_pair().await; - - let msg_1 = "TestMessageID1"; - let msg_2 = "TestMessageID2"; - - assert!(tx_mem - .send( - mock_message(msg_1.to_owned()), - Some(Duration::from_millis(200)) - ) - .await - .is_ok()); - assert_send(&tx_mem, msg_2).await; - - assert_recv(&mut rx_mem, msg_2).await; - assert_recv(&mut rx_mem, msg_1).await; - } - - #[tokio::test] - async fn test_ack() { - let (tx_mem, mut rx_mem) = memory::new_pair().await; - assert!(tx_mem - .send(mock_message("test".to_owned()), None) - .await - .is_ok()); - - let recv = rx_mem - .receive_all() - .await - .unwrap() - .into_iter() - .next() - .unwrap(); - - assert_eq!(*recv.task, mock_message("test".to_owned())); - - assert!(recv.ack().await.is_ok()); - - tokio::select! 
{ - _ = rx_mem.receive_all() => { - panic!("`rx_mem` received second message"); - } - - // FIXME: Find out correct timeout duration - _ = tokio::time::sleep(Duration::from_millis(500)) => {} - } - } - - #[tokio::test] - async fn test_nack() { - let (tx_mem, mut rx_mem) = memory::new_pair().await; - assert!(tx_mem - .send(mock_message("test".to_owned()), None) - .await - .is_ok()); - - let recv = rx_mem - .receive_all() - .await - .unwrap() - .into_iter() - .next() - .unwrap(); - assert_eq!(*recv.task, mock_message("test".to_owned())); - - assert!(recv.nack().await.is_ok()); - - tokio::select! { - _ = rx_mem.receive_all() => {} - - // FIXME: Find out correct timeout duration - _ = tokio::time::sleep(Duration::from_millis(500)) => { - panic!("`rx_mem` did not receive second message"); - } - } - } -}