Skip to content

Commit

Permalink
Merge pull request #98 from svix/jplatte/batch-send
Browse files Browse the repository at this point in the history
Add support for sending multiple messages as a batch
  • Loading branch information
svix-jplatte authored Aug 16, 2024
2 parents 2efb242 + 3ffc2bc commit 6fbecdb
Show file tree
Hide file tree
Showing 10 changed files with 310 additions and 115 deletions.
14 changes: 10 additions & 4 deletions omniqueue/src/backends/azure_queue_storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,13 @@ impl AqsProducer {
}
}

impl_queue_producer!(AqsProducer, String);
impl_scheduled_queue_producer!(AqsProducer, String);
// `QueueProducer` implementation for the Azure Queue Storage producer.
// The payload type is `String`; `send_raw` and `send_serde_json` are generated
// by `omni_delegate!` (presumably forwarding to the inherent methods of the
// same name -- confirm against the macro definition).
impl crate::QueueProducer for AqsProducer {
type Payload = String;
omni_delegate!(send_raw, send_serde_json);
}
// `ScheduledQueueProducer` implementation for the Azure Queue Storage
// producer. Both scheduled-send methods are generated by `omni_delegate!`
// (presumably forwarding to the inherent methods of the same name -- confirm
// against the macro definition).
impl crate::ScheduledQueueProducer for AqsProducer {
omni_delegate!(send_raw_scheduled, send_serde_json_scheduled);
}

/// Note that blocking receives are not supported by Azure Queue Storage and
/// that message order is not guaranteed.
Expand Down Expand Up @@ -238,11 +243,12 @@ impl AqsConsumer {
}
}

impl_queue_consumer!(for AqsConsumer {
impl crate::QueueConsumer for AqsConsumer {
type Payload = String;
omni_delegate!(receive, receive_all);

/// Reports 32 as the maximum message count for this consumer; the value
/// comes from the Azure "Get Messages" REST documentation linked below.
fn max_messages(&self) -> Option<NonZeroUsize> {
// https://learn.microsoft.com/en-us/rest/api/storageservices/get-messages#uri-parameters
NonZeroUsize::new(32)
}
});
}
90 changes: 74 additions & 16 deletions omniqueue/src/backends/gcp_pubsub.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ use std::{
time::Duration,
};

use futures_util::StreamExt;
use futures_util::{future::try_join_all, StreamExt};
use google_cloud_googleapis::pubsub::v1::PubsubMessage;
use google_cloud_pubsub::{
client::{google_cloud_auth::credentials::CredentialsFile, Client, ClientConfig},
publisher::Publisher,
subscriber::ReceivedMessage,
subscription::Subscription,
};
Expand Down Expand Up @@ -125,17 +126,7 @@ impl GcpPubSubProducer {
})
}

#[tracing::instrument(
name = "send",
skip_all,
fields(payload_size = payload.len())
)]
pub async fn send_raw(&self, payload: &[u8]) -> Result<()> {
let msg = PubsubMessage {
data: payload.to_vec(),
..Default::default()
};

async fn publisher(&self) -> Result<Publisher> {
// N.b. defer the creation of a publisher/topic until needed. Helps recover when
// the topic does not yet exist, but will soon.
// Might be more expensive to recreate each time, but overall more reliable.
Expand All @@ -150,8 +141,23 @@ impl GcpPubSubProducer {
format!("topic {} does not exist", &self.topic_id).into(),
));
}

// FIXME: may need to expose `PublisherConfig` to caller so they can tweak this
let publisher = topic.new_publisher(None);
Ok(topic.new_publisher(None))
}

#[tracing::instrument(
name = "send",
skip_all,
fields(payload_size = payload.len())
)]
pub async fn send_raw(&self, payload: &[u8]) -> Result<()> {
let msg = PubsubMessage {
data: payload.to_vec(),
..Default::default()
};

let publisher = self.publisher().await?;
let awaiter = publisher.publish(msg).await;
awaiter.get().await.map_err(QueueError::generic)?;
Ok(())
Expand All @@ -170,7 +176,58 @@ impl std::fmt::Debug for GcpPubSubProducer {
}
}

impl_queue_producer!(GcpPubSubProducer, Payload);
impl crate::QueueProducer for GcpPubSubProducer {
type Payload = Payload;
omni_delegate!(send_raw, send_serde_json);

/// This method is overwritten for the Google Cloud Pub/Sub backend to be
/// more efficient than the default of sequentially publishing `payloads`.
#[tracing::instrument(name = "send_batch", skip_all)]
async fn send_raw_batch(
&self,
payloads: impl IntoIterator<Item: AsRef<Self::Payload> + Send, IntoIter: Send> + Send,
) -> Result<()> {
let msgs = payloads
.into_iter()
.map(|payload| PubsubMessage {
data: payload.as_ref().to_vec(),
..Default::default()
})
.collect();

let publisher = self.publisher().await?;
let awaiters = publisher.publish_bulk(msgs).await;
try_join_all(awaiters.into_iter().map(|a| a.get()))
.await
.map_err(QueueError::generic)?;
Ok(())
}

/// This method is overwritten for the Google Cloud Pub/Sub backend to be
/// more efficient than the default of sequentially publishing `payloads`.
#[tracing::instrument(name = "send_batch", skip_all)]
async fn send_serde_json_batch(
&self,
payloads: impl IntoIterator<Item: Serialize + Send, IntoIter: Send> + Send,
) -> Result<()> {
let msgs = payloads
.into_iter()
.map(|payload| {
Ok(PubsubMessage {
data: serde_json::to_vec(&payload)?,
..Default::default()
})
})
.collect::<Result<_>>()?;

let publisher = self.publisher().await?;
let awaiters = publisher.publish_bulk(msgs).await;
try_join_all(awaiters.into_iter().map(|a| a.get()))
.await
.map_err(QueueError::generic)?;
Ok(())
}
}

pub struct GcpPubSubConsumer {
client: Client,
Expand Down Expand Up @@ -254,9 +311,10 @@ async fn subscription(client: &Client, subscription_id: &str) -> Result<Subscrip
Ok(subscription)
}

impl_queue_consumer!(for GcpPubSubConsumer {
impl crate::QueueConsumer for GcpPubSubConsumer {
type Payload = Payload;
});
omni_delegate!(receive, receive_all);
}

struct GcpPubSubAcker {
recv_msg: ReceivedMessage,
Expand Down
14 changes: 10 additions & 4 deletions omniqueue/src/backends/in_memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,13 @@ impl InMemoryProducer {
}
}

impl_queue_producer!(InMemoryProducer, Vec<u8>);
impl_scheduled_queue_producer!(InMemoryProducer, Vec<u8>);
// `QueueProducer` implementation for the in-memory producer.
// The payload type is `Vec<u8>`; `send_raw` and `send_serde_json` are
// generated by `omni_delegate!` (presumably forwarding to the inherent methods
// of the same name -- confirm against the macro definition).
impl crate::QueueProducer for InMemoryProducer {
type Payload = Vec<u8>;
omni_delegate!(send_raw, send_serde_json);
}
// `ScheduledQueueProducer` implementation for the in-memory producer. Both
// scheduled-send methods are generated by `omni_delegate!` (presumably
// forwarding to the inherent methods of the same name -- confirm against the
// macro definition).
impl crate::ScheduledQueueProducer for InMemoryProducer {
omni_delegate!(send_raw_scheduled, send_serde_json_scheduled);
}

pub struct InMemoryConsumer {
rx: mpsc::UnboundedReceiver<Vec<u8>>,
Expand Down Expand Up @@ -142,9 +147,10 @@ impl InMemoryConsumer {
}
}

impl_queue_consumer!(for InMemoryConsumer {
impl crate::QueueConsumer for InMemoryConsumer {
type Payload = Vec<u8>;
});
omni_delegate!(receive, receive_all);
}

struct InMemoryAcker {
tx: mpsc::UnboundedSender<Vec<u8>>,
Expand Down
14 changes: 10 additions & 4 deletions omniqueue/src/backends/rabbitmq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -204,8 +204,13 @@ impl RabbitMqProducer {
}
}

impl_queue_producer!(RabbitMqProducer, Vec<u8>);
impl_scheduled_queue_producer!(RabbitMqProducer, Vec<u8>);
// `QueueProducer` implementation for the RabbitMQ producer.
// The payload type is `Vec<u8>`; `send_raw` and `send_serde_json` are
// generated by `omni_delegate!` (presumably forwarding to the inherent methods
// of the same name -- confirm against the macro definition).
impl crate::QueueProducer for RabbitMqProducer {
type Payload = Vec<u8>;
omni_delegate!(send_raw, send_serde_json);
}
// `ScheduledQueueProducer` implementation for the RabbitMQ producer. Both
// scheduled-send methods are generated by `omni_delegate!` (presumably
// forwarding to the inherent methods of the same name -- confirm against the
// macro definition).
impl crate::ScheduledQueueProducer for RabbitMqProducer {
omni_delegate!(send_raw_scheduled, send_serde_json_scheduled);
}

pub struct RabbitMqConsumer {
consumer: Consumer,
Expand Down Expand Up @@ -269,9 +274,10 @@ impl RabbitMqConsumer {
}
}

impl_queue_consumer!(for RabbitMqConsumer {
impl crate::QueueConsumer for RabbitMqConsumer {
type Payload = Vec<u8>;
});
omni_delegate!(receive, receive_all);
}

struct RabbitMqAcker {
acker: Option<LapinAcker>,
Expand Down
14 changes: 10 additions & 4 deletions omniqueue/src/backends/redis/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -573,8 +573,13 @@ impl<R: RedisConnection> RedisProducer<R> {
}
}

impl_queue_producer!(RedisProducer<R: RedisConnection>, Vec<u8>);
impl_scheduled_queue_producer!(RedisProducer<R: RedisConnection>, Vec<u8>);
// `QueueProducer` implementation for the Redis producer, generic over the
// connection type `R`. The payload type is `Vec<u8>`; `send_raw` and
// `send_serde_json` are generated by `omni_delegate!` (presumably forwarding
// to the inherent methods of the same name -- confirm against the macro
// definition).
impl<R: RedisConnection> crate::QueueProducer for RedisProducer<R> {
type Payload = Vec<u8>;
omni_delegate!(send_raw, send_serde_json);
}
// `ScheduledQueueProducer` implementation for the Redis producer, generic
// over the connection type `R`. Both scheduled-send methods are generated by
// `omni_delegate!` (presumably forwarding to the inherent methods of the same
// name -- confirm against the macro definition).
impl<R: RedisConnection> crate::ScheduledQueueProducer for RedisProducer<R> {
omni_delegate!(send_raw_scheduled, send_serde_json_scheduled);
}

fn unix_timestamp(time: SystemTime) -> Result<u64, SystemTimeError> {
Ok(time.duration_since(UNIX_EPOCH)?.as_secs())
Expand Down Expand Up @@ -647,6 +652,7 @@ impl<R: RedisConnection> RedisConsumer<R> {
}
}

impl_queue_consumer!(for RedisConsumer<R: RedisConnection> {
impl<R: RedisConnection> crate::QueueConsumer for RedisConsumer<R> {
type Payload = Vec<u8>;
});
omni_delegate!(receive, receive_all);
}
88 changes: 81 additions & 7 deletions omniqueue/src/backends/sqs.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use std::{
fmt::{self, Write},
future::Future,
num::NonZeroUsize,
time::Duration,
};

use aws_sdk_sqs::{
operation::delete_message::DeleteMessageError,
types::{error::ReceiptHandleIsInvalid, Message},
types::{error::ReceiptHandleIsInvalid, Message, SendMessageBatchRequestEntry},
Client,
};
use serde::Serialize;
Expand All @@ -20,6 +21,8 @@ use crate::{

/// https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/quotas-messages.html
const MAX_PAYLOAD_SIZE: usize = 262_144;
/// https://docs.aws.amazon.com/AWSSimpleQueueService/latest/APIReference/API_SendMessageBatch.html
const MAX_BATCH_SIZE: usize = 10;

#[derive(Clone, Debug, Eq, PartialEq)]
pub struct SqsConfig {
Expand Down Expand Up @@ -305,10 +308,80 @@ impl SqsProducer {
let payload = serde_json::to_string(payload)?;
self.send_raw_scheduled(&payload, delay).await
}

#[tracing::instrument(name = "send_batch", skip_all)]
async fn send_batch_inner<I>(
&self,
payloads: impl IntoIterator<Item = I, IntoIter: Send> + Send,
convert_payload: impl Fn(I) -> Result<String>,
) -> Result<()> {
// Convert payloads up front and collect to Vec to run the payload size
// check on everything before submitting the first batch.
let payloads: Vec<_> = payloads
.into_iter()
.map(convert_payload)
.collect::<Result<_>>()?;

for payload in &payloads {
if payload.len() > MAX_PAYLOAD_SIZE {
return Err(QueueError::PayloadTooLarge {
limit: MAX_PAYLOAD_SIZE,
actual: payload.len(),
});
}
}

for payloads in payloads.chunks(MAX_BATCH_SIZE) {
let entries = payloads
.iter()
.enumerate()
.map(|(i, payload)| {
SendMessageBatchRequestEntry::builder()
.message_body(payload)
.id(i.to_string())
.build()
.map_err(QueueError::generic)
})
.collect::<Result<_>>()?;

self.client
.send_message_batch()
.queue_url(&self.queue_dsn)
.set_entries(Some(entries))
.send()
.await
.map_err(aws_to_queue_error)?;
}

Ok(())
}
}

impl_queue_producer!(SqsProducer, String);
impl_scheduled_queue_producer!(SqsProducer, String);
impl crate::QueueProducer for SqsProducer {
type Payload = String;
omni_delegate!(send_raw, send_serde_json);

/// This method is overwritten for the SQS backend to be more efficient
/// than the default of sequentially publishing `payloads`.
fn send_raw_batch(
&self,
payloads: impl IntoIterator<Item: AsRef<Self::Payload> + Send, IntoIter: Send> + Send,
) -> impl Future<Output = Result<()>> {
self.send_batch_inner(payloads, |p| Ok(p.as_ref().into()))
}

/// This method is overwritten for the SQS backend to be more efficient
/// than the default of sequentially publishing `payloads`.
fn send_serde_json_batch(
&self,
payloads: impl IntoIterator<Item: Serialize + Send, IntoIter: Send> + Send,
) -> impl Future<Output = Result<()>> {
self.send_batch_inner(payloads, |p| Ok(serde_json::to_string(&p)?))
}
}
// `ScheduledQueueProducer` implementation for the SQS producer. Both
// scheduled-send methods are generated by `omni_delegate!` (presumably
// forwarding to the inherent methods of the same name -- confirm against the
// macro definition).
impl crate::ScheduledQueueProducer for SqsProducer {
omni_delegate!(send_raw_scheduled, send_serde_json_scheduled);
}

pub struct SqsConsumer {
client: Client,
Expand Down Expand Up @@ -369,15 +442,16 @@ impl SqsConsumer {
}
}

impl_queue_consumer!(for SqsConsumer {
impl crate::QueueConsumer for SqsConsumer {
type Payload = String;
omni_delegate!(receive, receive_all);

fn max_messages(&self) -> Option<NonZeroUsize> {
// Not very clearly documented, but this doc mentions "batch of 10 messages" a few times:
// https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/quotas-messages.html
// Not very clearly documented, but this doc mentions "batch of 10 messages" a
// few times: https://docs.aws.amazon.com/AWSSimpleQueueService/latest/SQSDeveloperGuide/quotas-messages.html
NonZeroUsize::new(10)
}
});
}

fn aws_to_queue_error<E>(err: aws_sdk_sqs::error::SdkError<E>) -> QueueError
where
Expand Down
Loading

0 comments on commit 6fbecdb

Please sign in to comment.