Skip to content

Commit

Permalink
refactor: switch broadcast for mpsc
Browse files Browse the repository at this point in the history
Unsure if this is going to be needed.
Can avoid redoing/adding mem backend by simply giving usize::MAX for
capacity, or whatever is allowable for broadcast. The underlying channel
is a private detail.
  • Loading branch information
svix-onelson committed Feb 2, 2024
1 parent a95ce9a commit e0ed48c
Showing 1 changed file with 16 additions and 12 deletions.
28 changes: 16 additions & 12 deletions omniqueue/src/backends/memory_queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::{any::TypeId, collections::HashMap};

use async_trait::async_trait;
use serde::Serialize;
use tokio::sync::broadcast;
use tokio::sync::mpsc;

use crate::{
decoding::DecoderRegistry,
Expand All @@ -30,7 +30,7 @@ impl QueueBackend for MemoryQueueBackend {
custom_encoders: EncoderRegistry<Vec<u8>>,
custom_decoders: DecoderRegistry<Vec<u8>>,
) -> Result<(MemoryQueueProducer, MemoryQueueConsumer), QueueError> {
let (tx, rx) = broadcast::channel(config);
let (tx, rx) = mpsc::channel(config);

Ok((
MemoryQueueProducer {
Expand Down Expand Up @@ -62,7 +62,7 @@ impl QueueBackend for MemoryQueueBackend {

pub struct MemoryQueueProducer {
registry: EncoderRegistry<Vec<u8>>,
tx: broadcast::Sender<Vec<u8>>,
tx: mpsc::Sender<Vec<u8>>,
}

#[async_trait]
Expand All @@ -76,7 +76,7 @@ impl QueueProducer for MemoryQueueProducer {
async fn send_raw(&self, payload: &Self::Payload) -> Result<(), QueueError> {
self.tx
.send(payload.clone())
.map(|_| ())
.await
.map_err(QueueError::generic)
}

Expand All @@ -98,7 +98,7 @@ impl ScheduledProducer for MemoryQueueProducer {
tokio::spawn(async move {
tracing::trace!("MemoryQueue: event sent > (delay: {:?})", delay);
tokio::time::sleep(delay).await;
if tx.send(payload).is_err() {
if tx.send(payload).await.is_err() {
tracing::error!("Receiver dropped");
}
});
Expand All @@ -108,8 +108,8 @@ impl ScheduledProducer for MemoryQueueProducer {

pub struct MemoryQueueConsumer {
registry: DecoderRegistry<Vec<u8>>,
rx: broadcast::Receiver<Vec<u8>>,
tx: broadcast::Sender<Vec<u8>>,
rx: mpsc::Receiver<Vec<u8>>,
tx: mpsc::Sender<Vec<u8>>,
}

impl MemoryQueueConsumer {
Expand All @@ -131,7 +131,11 @@ impl QueueConsumer for MemoryQueueConsumer {
type Payload = Vec<u8>;

async fn receive(&mut self) -> Result<Delivery, QueueError> {
let payload = self.rx.recv().await.map_err(QueueError::generic)?;
let payload = self
.rx
.recv()
.await
.ok_or_else(|| QueueError::Generic("recv failed".into()))?;
Ok(self.wrap_payload(payload))
}

Expand All @@ -143,9 +147,9 @@ impl QueueConsumer for MemoryQueueConsumer {
let mut out = Vec::with_capacity(max_messages);
let start = Instant::now();
match tokio::time::timeout(deadline, self.rx.recv()).await {
Ok(Ok(x)) => out.push(self.wrap_payload(x)),
Ok(Some(x)) => out.push(self.wrap_payload(x)),
// Timeouts and stream termination
Err(_) | Ok(Err(_)) => return Ok(out),
Err(_) | Ok(None) => return Ok(out),
}

if max_messages > 1 {
Expand All @@ -163,7 +167,7 @@ impl QueueConsumer for MemoryQueueConsumer {
}

pub struct MemoryQueueAcker {
tx: broadcast::Sender<Vec<u8>>,
tx: mpsc::Sender<Vec<u8>>,
payload_copy: Option<Vec<u8>>,
already_acked_or_nacked: bool,
}
Expand All @@ -190,7 +194,7 @@ impl Acker for MemoryQueueAcker {
.take()
.ok_or(QueueError::CannotAckOrNackTwice)?,
)
.map(|_| ())
.await
.map_err(QueueError::generic)
}
}
Expand Down

0 comments on commit e0ed48c

Please sign in to comment.