From e0ed48ce495d9e08020d72278ad7b092defdbbb9 Mon Sep 17 00:00:00 2001 From: Owen Nelson Date: Thu, 26 Oct 2023 16:07:25 -0700 Subject: [PATCH] refactor: switch broadcast for mpsc Unsure if this is going to be needed. Can avoid redoing/adding mem backend by simply giving usize::MAX for capacity, or whatever is allowable for broadcast. The underlying channel is a private detail. --- omniqueue/src/backends/memory_queue.rs | 28 +++++++++++++++----------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/omniqueue/src/backends/memory_queue.rs b/omniqueue/src/backends/memory_queue.rs index 212972e..5522237 100644 --- a/omniqueue/src/backends/memory_queue.rs +++ b/omniqueue/src/backends/memory_queue.rs @@ -3,7 +3,7 @@ use std::{any::TypeId, collections::HashMap}; use async_trait::async_trait; use serde::Serialize; -use tokio::sync::broadcast; +use tokio::sync::mpsc; use crate::{ decoding::DecoderRegistry, @@ -30,7 +30,7 @@ impl QueueBackend for MemoryQueueBackend { custom_encoders: EncoderRegistry>, custom_decoders: DecoderRegistry>, ) -> Result<(MemoryQueueProducer, MemoryQueueConsumer), QueueError> { - let (tx, rx) = broadcast::channel(config); + let (tx, rx) = mpsc::channel(config); Ok(( MemoryQueueProducer { @@ -62,7 +62,7 @@ impl QueueBackend for MemoryQueueBackend { pub struct MemoryQueueProducer { registry: EncoderRegistry>, - tx: broadcast::Sender>, + tx: mpsc::Sender>, } #[async_trait] @@ -76,7 +76,7 @@ impl QueueProducer for MemoryQueueProducer { async fn send_raw(&self, payload: &Self::Payload) -> Result<(), QueueError> { self.tx .send(payload.clone()) - .map(|_| ()) + .await .map_err(QueueError::generic) } @@ -98,7 +98,7 @@ impl ScheduledProducer for MemoryQueueProducer { tokio::spawn(async move { tracing::trace!("MemoryQueue: event sent > (delay: {:?})", delay); tokio::time::sleep(delay).await; - if tx.send(payload).is_err() { + if tx.send(payload).await.is_err() { tracing::error!("Receiver dropped"); } }); @@ -108,8 +108,8 @@ impl ScheduledProducer for MemoryQueueProducer { pub struct MemoryQueueConsumer { registry: DecoderRegistry>, - rx: broadcast::Receiver>, - tx: broadcast::Sender>, + rx: mpsc::Receiver>, + tx: mpsc::Sender>, } impl MemoryQueueConsumer { @@ -131,7 +131,11 @@ impl QueueConsumer for MemoryQueueConsumer { type Payload = Vec; async fn receive(&mut self) -> Result { - let payload = self.rx.recv().await.map_err(QueueError::generic)?; + let payload = self + .rx + .recv() + .await + .ok_or_else(|| QueueError::Generic("recv failed".into()))?; Ok(self.wrap_payload(payload)) } @@ -143,9 +147,9 @@ impl QueueConsumer for MemoryQueueConsumer { let mut out = Vec::with_capacity(max_messages); let start = Instant::now(); match tokio::time::timeout(deadline, self.rx.recv()).await { - Ok(Ok(x)) => out.push(self.wrap_payload(x)), + Ok(Some(x)) => out.push(self.wrap_payload(x)), // Timeouts and stream termination - Err(_) | Ok(Err(_)) => return Ok(out), + Err(_) | Ok(None) => return Ok(out), } if max_messages > 1 { @@ -163,7 +167,7 @@ impl QueueConsumer for MemoryQueueConsumer { } pub struct MemoryQueueAcker { - tx: broadcast::Sender>, + tx: mpsc::Sender>, payload_copy: Option>, already_acked_or_nacked: bool, } @@ -190,7 +194,7 @@ impl Acker for MemoryQueueAcker { .take() .ok_or(QueueError::CannotAckOrNackTwice)?, ) - .map(|_| ()) + .await .map_err(QueueError::generic) } }