Support redis max_receives config option
This adds a `max_receives` option to both Redis queue implementations.
This is the first step in supporting deadletter queuing.

`InternalPayload` type aliases have been added to represent Omniqueue
items that track the number of times a message has been received
(`num_receives`). `num_receives` is incremented whenever an item is
re-queued from the pending/processing queues, until it hits
`max_receives`, at which point the message is abandoned. Later we will
support moving such messages to an optional deadletter queue.
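
For illustration, here is a minimal, self-contained sketch of the list-entry
layout these changes introduce (`KSUID#num_receives|payload`) and of how
parsing bumps the receive count. The `encode`/`decode` names and the `main`
demo are hypothetical; the actual logic lives in the new
`internal_to_list_payload` and `internal_from_list` helpers shown in the diff
below.

fn encode(id: &str, payload: &[u8], num_receives: usize) -> Vec<u8> {
    // Layout: `<id>#<num_receives>|<payload>`.
    let count = num_receives.to_string();
    let mut out = Vec::with_capacity(id.len() + count.len() + payload.len() + 2);
    out.extend(id.as_bytes());
    out.push(b'#');
    out.extend(count.as_bytes());
    out.push(b'|');
    out.extend(payload);
    out
}

// Returns (payload, num_receives), counting the receive that is happening
// right now; entries without a `#` (the old `ID|payload` format) count as
// their first receive.
fn decode(entry: &[u8]) -> Option<(&[u8], usize)> {
    let pipe = entry.iter().position(|&b| b == b'|')?;
    let hash = entry.iter().position(|&b| b == b'#').filter(|&p| p < pipe);
    let num_receives = match hash {
        Some(h) => {
            let count = std::str::from_utf8(&entry[h + 1..pipe]).ok()?;
            count.parse::<usize>().ok()? + 1
        }
        None => 1,
    };
    Some((&entry[pipe + 1..], num_receives))
}

fn main() {
    let entry = encode("9m4e2mr0ui3e8a215n4g", br#"{"hello":"world"}"#, 0);
    let (payload, num_receives) = decode(&entry).unwrap();
    assert_eq!(payload, &br#"{"hello":"world"}"#[..]);
    assert_eq!(num_receives, 1); // first delivery
}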

I originally tried adding a proper struct for `InternalPayload`, but
it didn't clarify anything and was less memory-efficient, so the
relevant logic is captured in simple functions and a macro instead.
jaymell committed Aug 30, 2024
1 parent 6fbecdb commit a49ab47
Showing 6 changed files with 468 additions and 85 deletions.
97 changes: 74 additions & 23 deletions omniqueue/src/backends/redis/fallback.rs
@@ -9,7 +9,10 @@ use svix_ksuid::{KsuidLike as _, KsuidMs};
use time::OffsetDateTime;
use tracing::{error, trace};

use super::{from_key, to_key, RawPayload, RedisConnection, RedisConsumer, RedisProducer};
use super::{
internal_from_list, internal_to_list_payload, InternalPayloadOwned, RawPayload,
RedisConnection, RedisConsumer, RedisProducer,
};
use crate::{queue::Acker, Delivery, QueueError, Result};

pub(super) async fn send_raw<R: RedisConnection>(
@@ -21,7 +24,7 @@ pub(super) async fn send_raw<R: RedisConnection>(
.get()
.await
.map_err(QueueError::generic)?
.lpush(&producer.queue_key, to_key(payload))
.lpush(&producer.queue_key, internal_to_list_payload((payload, 0)))
.await
.map_err(QueueError::generic)
}
@@ -45,7 +48,7 @@ async fn receive_with_timeout<R: RedisConnection>(
consumer: &RedisConsumer<R>,
timeout: Duration,
) -> Result<Option<Delivery>> {
let key: Option<Vec<u8>> = consumer
let payload: Option<Vec<u8>> = consumer
.redis
.get()
.await
@@ -61,29 +64,52 @@ async fn receive_with_timeout<R: RedisConnection>(
.await
.map_err(QueueError::generic)?;

key.map(|key| make_delivery(consumer, &key)).transpose()
match payload {
Some(old_payload) => {
let (payload, num_receives) = internal_from_list(&old_payload)?;
Some(internal_to_delivery(
(payload.to_vec(), num_receives),
consumer,
old_payload,
))
.transpose()
}
None => Ok(None),
}
}

fn make_delivery<R: RedisConnection>(consumer: &RedisConsumer<R>, key: &[u8]) -> Result<Delivery> {
let (_, payload) = from_key(key)?;

fn internal_to_delivery<R: RedisConnection>(
internal: InternalPayloadOwned,
consumer: &RedisConsumer<R>,
old_payload: Vec<u8>,
) -> Result<Delivery> {
let (payload, num_receives) = internal;
Ok(Delivery::new(
payload.to_owned(),
payload,
RedisFallbackAcker {
redis: consumer.redis.clone(),
processing_queue_key: consumer.processing_queue_key.clone(),
key: key.to_owned(),
old_payload,
already_acked_or_nacked: false,
max_receives: consumer.max_receives,
num_receives,
},
))
}

struct RedisFallbackAcker<M: ManageConnection> {
redis: bb8::Pool<M>,
processing_queue_key: String,
key: RawPayload,
pub(super) struct RedisFallbackAcker<M: ManageConnection> {
pub(super) redis: bb8::Pool<M>,
pub(super) processing_queue_key: String,
// We delete entries by matching on the payload. Since
// `num_receives` is encoded in the payload and changes on
// every receive, we keep the exact bytes we popped
// (`old_payload`) so LREM removes the right entry.
pub(super) old_payload: RawPayload,

pub(super) already_acked_or_nacked: bool,

already_acked_or_nacked: bool,
pub(super) max_receives: usize,
pub(super) num_receives: usize,
}

impl<R: RedisConnection> Acker for RedisFallbackAcker<R> {
@@ -97,7 +123,7 @@ impl<R: RedisConnection> Acker for RedisFallbackAcker<R> {
.get()
.await
.map_err(QueueError::generic)?
.lrem(&self.processing_queue_key, 1, &self.key)
.lrem(&self.processing_queue_key, 1, &self.old_payload)
.await
.map_err(QueueError::generic)?;

@@ -107,6 +133,11 @@ impl<R: RedisConnection> Acker for RedisFallbackAcker<R> {
}

async fn nack(&mut self) -> Result<()> {
if self.num_receives >= self.max_receives {
trace!("Maximum attempts reached");
return self.ack().await;
}

if self.already_acked_or_nacked {
return Err(QueueError::CannotAckOrNackTwice);
}
@@ -144,13 +175,19 @@ pub(super) async fn background_task_processing<R: RedisConnection>(
queue_key: String,
processing_queue_key: String,
ack_deadline_ms: i64,
max_receives: usize,
) -> Result<()> {
// FIXME: ack_deadline_ms should be unsigned
let ack_deadline = Duration::from_millis(ack_deadline_ms as _);
loop {
if let Err(err) =
reenqueue_timed_out_messages(&pool, &queue_key, &processing_queue_key, ack_deadline)
.await
if let Err(err) = reenqueue_timed_out_messages(
&pool,
&queue_key,
&processing_queue_key,
ack_deadline,
max_receives,
)
.await
{
error!("{err}");
tokio::time::sleep(Duration::from_millis(500)).await;
@@ -164,6 +201,7 @@ async fn reenqueue_timed_out_messages<R: RedisConnection>(
queue_key: &str,
processing_queue_key: &str,
ack_deadline: Duration,
max_receives: usize,
) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
const BATCH_SIZE: isize = 50;

@@ -180,10 +218,24 @@ async fn reenqueue_timed_out_messages<R: RedisConnection>(
let keys: Vec<RawPayload> = conn.lrange(processing_queue_key, 0, BATCH_SIZE).await?;
for key in keys {
if key <= validity_limit {
let internal = internal_from_list(&key)?;
let num_receives = internal.1;
if num_receives >= max_receives {
trace!(
num_receives = num_receives,
"Maximum attempts reached for message, not reenqueuing",
);
} else {
trace!(
num_receives = num_receives,
"Pushing back overdue task to queue"
);
let _: () = conn
.rpush(queue_key, internal_to_list_payload(internal))
.await?;
}

// We use LREM to be sure we only delete the keys we should be deleting
trace!("Pushing back overdue task to queue");
let refreshed_key = regenerate_key(&key)?;
let _: () = conn.rpush(queue_key, &refreshed_key).await?;
let _: () = conn.lrem(processing_queue_key, 1, &key).await?;
}
}
@@ -196,6 +248,5 @@ async fn reenqueue_timed_out_messages<R: RedisConnection>(
}

fn regenerate_key(key: &[u8]) -> Result<RawPayload> {
let (_, payload) = from_key(key)?;
Ok(to_key(payload))
Ok(internal_to_list_payload(internal_from_list(key)?))
}
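
To make the new control flow in fallback.rs concrete: both `nack()` on the
acker and `reenqueue_timed_out_messages` apply the same rule before pushing an
entry back onto the main queue. Below is a rough, self-contained sketch of
that rule; the helper name is hypothetical, and the defaulting of an unset
`max_receives` to `usize::MAX` mirrors the builder code in mod.rs.

fn should_reenqueue(num_receives: usize, max_receives: Option<usize>) -> bool {
    // An unset `max_receives` behaves like "unlimited", preserving the old
    // always-re-queue behaviour.
    num_receives < max_receives.unwrap_or(usize::MAX)
}

fn main() {
    assert!(should_reenqueue(1, Some(3)));
    assert!(should_reenqueue(2, Some(3)));
    assert!(!should_reenqueue(3, Some(3))); // abandoned after the third receive
    assert!(should_reenqueue(1_000_000, None)); // no limit configured
}
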
95 changes: 69 additions & 26 deletions omniqueue/src/backends/redis/mod.rs
@@ -85,6 +85,63 @@ impl RedisConnection for RedisClusterConnectionManager {
}
}

// First element is the raw payload slice; the second
// is `num_receives`, the number of times the message
// has previously been received.
type InternalPayload<'a> = (&'a [u8], usize);

// The same as `InternalPayload` but with an
// owned payload.
type InternalPayloadOwned = (Vec<u8>, usize);

fn internal_from_list(payload: &[u8]) -> Result<InternalPayload<'_>> {
// Everything is stored in a single list entry: the ID and the [optional]
// number of prior receives are separated by a `#`, and the JSON-formatted
// task comes after a `|`. So: read the ID, then the optional receive
// count, then take everything after the `|` as the payload.
let count_sep_pos = payload.iter().position(|&byte| byte == b'#');
let payload_sep_pos = payload
.iter()
.position(|&byte| byte == b'|')
.ok_or_else(|| QueueError::Generic("Improper key format".into()))?;

let id_end_pos = match count_sep_pos {
Some(count_sep_pos) if count_sep_pos < payload_sep_pos => count_sep_pos,
_ => payload_sep_pos,
};
let _id = str::from_utf8(&payload[..id_end_pos])
.map_err(|_| QueueError::Generic("Non-UTF8 key ID".into()))?;

// This should be backward-compatible with messages that don't include
// `num_receives`
let num_receives = if let Some(count_sep_pos) = count_sep_pos {
let num_receives = std::str::from_utf8(&payload[(count_sep_pos + 1)..payload_sep_pos])
.map_err(|_| QueueError::Generic("Improper key format".into()))?
.parse::<usize>()
.map_err(|_| QueueError::Generic("Improper key format".into()))?;
num_receives + 1
} else {
1
};

Ok((&payload[payload_sep_pos + 1..], num_receives))
}

fn internal_to_list_payload(internal: InternalPayload) -> Vec<u8> {
let id = delayed_key_id();
let (payload, num_receives) = internal;
let num_receives = num_receives.to_string();
let mut result =
Vec::with_capacity(id.len() + num_receives.as_bytes().len() + payload.len() + 3);
result.extend(id.as_bytes());
result.push(b'#');
result.extend(num_receives.as_bytes());
result.push(b'|');
result.extend(payload);
result
}

#[derive(Debug, Error)]
enum EvictionCheckError {
#[error("Unable to verify eviction policy. Ensure `maxmemory-policy` set to `noeviction` or `volatile-*`")]
@@ -139,6 +196,7 @@ pub struct RedisConfig {
pub consumer_name: String,
pub payload_key: String,
pub ack_deadline_ms: i64,
pub max_receives: Option<usize>,
}

pub struct RedisBackend<R = RedisConnectionManager>(PhantomData<R>);
@@ -290,6 +348,7 @@ impl<R: RedisConnection> RedisBackendBuilder<R> {
payload_key: self.config.payload_key,
use_redis_streams: self.use_redis_streams,
_background_tasks: background_tasks.clone(),
max_receives: self.config.max_receives.unwrap_or(usize::MAX),
},
))
}
@@ -332,6 +391,7 @@ impl<R: RedisConnection> RedisBackendBuilder<R> {
consumer_name: self.config.consumer_name,
payload_key: self.config.payload_key,
use_redis_streams: self.use_redis_streams,
max_receives: self.config.max_receives.unwrap_or(usize::MAX),
_background_tasks,
})
}
@@ -395,13 +455,16 @@ impl<R: RedisConnection> RedisBackendBuilder<R> {
self.config.consumer_group.to_owned(),
self.config.consumer_name.to_owned(),
self.config.ack_deadline_ms,
self.config.max_receives.unwrap_or(usize::MAX),
self.config.payload_key.to_owned(),
));
} else {
join_set.spawn(fallback::background_task_processing(
redis.clone(),
self.config.queue_key.to_owned(),
self.get_processing_queue_key(),
self.config.ack_deadline_ms,
self.config.max_receives.unwrap_or(usize::MAX),
));
}

@@ -555,7 +618,11 @@ impl<R: RedisConnection> RedisProducer<R> {
.get()
.await
.map_err(QueueError::generic)?
.zadd(&self.delayed_queue_key, to_key(payload), timestamp)
.zadd(
&self.delayed_queue_key,
internal_to_list_payload((payload, 0)),
timestamp,
)
.await
.map_err(QueueError::generic)?;

@@ -594,31 +661,6 @@ fn delayed_key_id() -> String {
svix_ksuid::Ksuid::new(None, None).to_base62()
}

/// Prefixes a payload with an id, separated by a pipe, e.g `ID|payload`.
fn to_key(payload: &[u8]) -> RawPayload {
let id = delayed_key_id();

let mut result = Vec::with_capacity(id.len() + payload.len() + 1);
result.extend(id.as_bytes());
result.push(b'|');
result.extend(payload);
result
}

/// Splits a key encoded with [`to_key`] into ID and payload.
fn from_key(key: &[u8]) -> Result<(&str, &[u8])> {
// All information is stored in the key in which the ID and JSON formatted task
// are separated by a `|`. So, take the key, then take the part after the `|`.
let sep_pos = key
.iter()
.position(|&byte| byte == b'|')
.ok_or_else(|| QueueError::Generic("Improper key format".into()))?;
let id = str::from_utf8(&key[..sep_pos])
.map_err(|_| QueueError::Generic("Non-UTF8 key ID".into()))?;

Ok((id, &key[sep_pos + 1..]))
}

pub struct RedisConsumer<M: ManageConnection> {
redis: bb8::Pool<M>,
queue_key: String,
@@ -627,6 +669,7 @@ pub struct RedisConsumer<M: ManageConnection> {
consumer_name: String,
payload_key: String,
use_redis_streams: bool,
max_receives: usize,
_background_tasks: Arc<JoinSet<Result<()>>>,
}
