Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Consumer::receive_all #18

Merged
merged 1 commit into from
Oct 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 40 additions & 14 deletions omniqueue/src/backends/gcp_pubsub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ use google_cloud_pubsub::subscription::Subscription;
use serde::Serialize;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use std::{any::TypeId, collections::HashMap};

pub struct GcpPubSubBackend;
Expand Down Expand Up @@ -217,31 +218,56 @@ async fn subscription(client: &Client, subscription_id: &str) -> Result<Subscrip
Ok(subscription)
}

#[async_trait]
impl QueueConsumer for GcpPubSubConsumer {
type Payload = Payload;

async fn receive(&mut self) -> Result<Delivery, QueueError> {
let subscription = subscription(&self.client, &self.subscription_id).await?;
let mut stream = subscription
.subscribe(None)
.await
.map_err(QueueError::generic)?;

let mut recv_msg = stream.next().await.ok_or_else(|| QueueError::NoData)?;
impl GcpPubSubConsumer {
fn wrap_recv_msg(&self, mut recv_msg: ReceivedMessage) -> Delivery {
// FIXME: would be nice to avoid having to move the data out here.
// While it's possible to ack via a subscription and an ack_id, nack is only
// possible via a `ReceiveMessage`. This means we either need to hold 2 copies of
// the payload, or move the bytes out so they can be returned _outside of the Acker_.
let payload = recv_msg.message.data.drain(..).collect();
Ok(Delivery {

Delivery {
decoders: self.registry.clone(),
acker: Box::new(GcpPubSubAcker {
recv_msg,
subscription_id: self.subscription_id.clone(),
}),
payload: Some(payload),
})
}
}
}

#[async_trait]
impl QueueConsumer for GcpPubSubConsumer {
type Payload = Payload;

async fn receive(&mut self) -> Result<Delivery, QueueError> {
let subscription = subscription(&self.client, &self.subscription_id).await?;
let mut stream = subscription
.subscribe(None)
.await
.map_err(QueueError::generic)?;

let recv_msg = stream.next().await.ok_or_else(|| QueueError::NoData)?;

Ok(self.wrap_recv_msg(recv_msg))
}

async fn receive_all(
&mut self,
max_messages: usize,
deadline: Duration,
) -> Result<Vec<Delivery>, QueueError> {
let subscription = subscription(&self.client, &self.subscription_id).await?;
match tokio::time::timeout(deadline, subscription.pull(max_messages as _, None)).await {
Ok(messages) => Ok(messages
.map_err(QueueError::generic)?
.into_iter()
.map(|m| self.wrap_recv_msg(m))
.collect()),
// Timeout
Err(_) => Ok(vec![]),
}
}
}

Expand Down
180 changes: 171 additions & 9 deletions omniqueue/src/backends/memory_queue.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::time::{Duration, Instant};
use std::{any::TypeId, collections::HashMap};

use async_trait::async_trait;
Expand Down Expand Up @@ -90,22 +91,53 @@ pub struct MemoryQueueConsumer {
tx: broadcast::Sender<Vec<u8>>,
}

#[async_trait]
impl QueueConsumer for MemoryQueueConsumer {
type Payload = Vec<u8>;

async fn receive(&mut self) -> Result<Delivery, QueueError> {
let payload = self.rx.recv().await.map_err(QueueError::generic)?;

Ok(Delivery {
impl MemoryQueueConsumer {
fn wrap_payload(&self, payload: Vec<u8>) -> Delivery {
Delivery {
payload: Some(payload.clone()),
decoders: self.registry.clone(),
acker: Box::new(MemoryQueueAcker {
tx: self.tx.clone(),
payload_copy: Some(payload),
alredy_acked_or_nacked: false,
}),
})
}
}
}

#[async_trait]
impl QueueConsumer for MemoryQueueConsumer {
type Payload = Vec<u8>;

async fn receive(&mut self) -> Result<Delivery, QueueError> {
let payload = self.rx.recv().await.map_err(QueueError::generic)?;
Ok(self.wrap_payload(payload))
}

async fn receive_all(
&mut self,
max_messages: usize,
deadline: Duration,
) -> Result<Vec<Delivery>, QueueError> {
let mut out = Vec::with_capacity(max_messages);
let start = Instant::now();
match tokio::time::timeout(deadline, self.rx.recv()).await {
Ok(Ok(x)) => out.push(self.wrap_payload(x)),
// Timeouts and stream termination
Err(_) | Ok(Err(_)) => return Ok(out),
}

if max_messages > 1 {
// `try_recv` will break the loop if no ready items are already buffered in the channel.
// This should allow us to opportunistically fill up the buffer in the remaining time.
while let Ok(x) = self.rx.try_recv() {
out.push(self.wrap_payload(x));
if out.len() >= max_messages || start.elapsed() >= deadline {
break;
}
}
svix-gabriel marked this conversation as resolved.
Show resolved Hide resolved
}
Ok(out)
}
}

Expand Down Expand Up @@ -146,6 +178,7 @@ impl Acker for MemoryQueueAcker {
#[cfg(test)]
mod tests {
use serde::{Deserialize, Serialize};
use std::time::{Duration, Instant};

use crate::{
queue::{consumer::QueueConsumer, producer::QueueProducer, QueueBuilder},
Expand Down Expand Up @@ -233,4 +266,133 @@ mod tests {
TypeA { a: 12 }
);
}

#[derive(Debug, Deserialize, Serialize, PartialEq)]
pub struct ExType {
a: u8,
}

/// Consumer will return immediately if there are fewer than max messages to start with.
#[tokio::test]
async fn test_send_recv_all_partial() {
let payload = ExType { a: 2 };

let (p, mut c) = QueueBuilder::<MemoryQueueBackend, _>::new(16)
.build_pair()
.await
.unwrap();

p.send_serde_json(&payload).await.unwrap();
let deadline = Duration::from_secs(1);

let now = Instant::now();
let mut xs = c.receive_all(2, deadline).await.unwrap();
assert_eq!(xs.len(), 1);
let d = xs.remove(0);
assert_eq!(d.payload_serde_json::<ExType>().unwrap().unwrap(), payload);
d.ack().await.unwrap();
assert!(now.elapsed() <= deadline);
}

/// Consumer should yield items immediately if there's a full batch ready on the first poll.
#[tokio::test]
async fn test_send_recv_all_full() {
let payload1 = ExType { a: 1 };
let payload2 = ExType { a: 2 };

let (p, mut c) = QueueBuilder::<MemoryQueueBackend, _>::new(16)
.build_pair()
.await
.unwrap();

p.send_serde_json(&payload1).await.unwrap();
p.send_serde_json(&payload2).await.unwrap();
let deadline = Duration::from_secs(1);

let now = Instant::now();
let mut xs = c.receive_all(2, deadline).await.unwrap();
assert_eq!(xs.len(), 2);
let d1 = xs.remove(0);
assert_eq!(
d1.payload_serde_json::<ExType>().unwrap().unwrap(),
payload1
);
d1.ack().await.unwrap();

let d2 = xs.remove(0);
assert_eq!(
d2.payload_serde_json::<ExType>().unwrap().unwrap(),
payload2
);
d2.ack().await.unwrap();
// N.b. it's still possible this could turn up false if the test runs too slow.
assert!(now.elapsed() < deadline);
}

/// Consumer will return the full batch immediately, but also return immediately if a partial batch is ready.
#[tokio::test]
async fn test_send_recv_all_full_then_partial() {
let payload1 = ExType { a: 1 };
let payload2 = ExType { a: 2 };
let payload3 = ExType { a: 3 };

let (p, mut c) = QueueBuilder::<MemoryQueueBackend, _>::new(16)
.build_pair()
.await
.unwrap();

p.send_serde_json(&payload1).await.unwrap();
p.send_serde_json(&payload2).await.unwrap();
p.send_serde_json(&payload3).await.unwrap();

let deadline = Duration::from_secs(1);
let now1 = Instant::now();
let mut xs = c.receive_all(2, deadline).await.unwrap();
assert_eq!(xs.len(), 2);
let d1 = xs.remove(0);
assert_eq!(
d1.payload_serde_json::<ExType>().unwrap().unwrap(),
payload1
);
d1.ack().await.unwrap();

let d2 = xs.remove(0);
assert_eq!(
d2.payload_serde_json::<ExType>().unwrap().unwrap(),
payload2
);
d2.ack().await.unwrap();
assert!(now1.elapsed() < deadline);

// 2nd call
let now2 = Instant::now();
let mut ys = c.receive_all(2, deadline).await.unwrap();
assert_eq!(ys.len(), 1);
let d3 = ys.remove(0);
assert_eq!(
d3.payload_serde_json::<ExType>().unwrap().unwrap(),
payload3
);
d3.ack().await.unwrap();
assert!(now2.elapsed() <= deadline);
}

/// Consumer will NOT wait indefinitely for at least one item.
#[tokio::test]
async fn test_send_recv_all_late_arriving_items() {
let (_p, mut c) = QueueBuilder::<MemoryQueueBackend, _>::new(16)
.build_pair()
.await
.unwrap();

let deadline = Duration::from_secs(1);
let now = Instant::now();
let xs = c.receive_all(2, deadline).await.unwrap();
let elapsed = now.elapsed();

assert_eq!(xs.len(), 0);
// Elapsed should be around the deadline, ballpark
assert!(elapsed >= deadline);
assert!(elapsed <= deadline + Duration::from_millis(200));
}
}
Loading
Loading