Skip to content

Commit

Permalink
replace OnceCell with better logic
Browse files Browse the repository at this point in the history
  • Loading branch information
rob-maron committed Jul 2, 2024
1 parent fdcf888 commit d2347eb
Showing 1 changed file with 104 additions and 93 deletions.
197 changes: 104 additions & 93 deletions cdn-client/src/retry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ use cdn_proto::{
message::{Message, Topic},
};
use tokio::{
sync::{OnceCell, RwLock},
spawn,
sync::{RwLock, Semaphore},
time::sleep,
};
use tracing::{error, info};
Expand All @@ -46,7 +47,10 @@ pub struct Inner<C: ConnectionDef> {
use_local_authority: bool,

/// The underlying connection
connection: Arc<RwLock<OnceCell<Connection>>>,
connection: Arc<RwLock<Option<Connection>>>,

/// The semaphore to ensure only one reconnection is happening at a time
connecting_guard: Semaphore,

/// The keypair to use when authenticating
pub keypair: KeyPair<Scheme<C>>,
Expand Down Expand Up @@ -126,51 +130,93 @@ pub struct Config<C: ConnectionDef> {
pub subscribed_topics: Vec<Topic>,
}

/// This is a macro that helps with reconnections when sending
/// and receiving messages. You can specify the operation and it
/// will reconnect on the operation's failure, while handling all
/// reconnection logic and synchronization patterns.
macro_rules! try_with_reconnect {
($self: expr, $out: expr) => {{
// See if operation was an error
match $out {
Ok(res) => Ok(res),
Err(err) => {
// Acquire our "semaphore". If another task is doing this, just return an error
if let Ok(mut connection_guard) = $self.inner.connection.clone().try_write_owned() {
error!("connection failed: {err}, reconnecting");

// Clone `inner` so we can use it in the task
let inner = $self.inner.clone();
// We are the only ones reconnecting. Let's launch the task to reconnect
tokio::spawn(async move {
// Loop to connect and authenticate
let connection = loop {
// Sleep so we don't overload the server
sleep(Duration::from_secs(2)).await;

// Create a connection
match inner.connect().await {
Ok(connection) => break connection,
Err(err) => {
error!("failed connection: {err}");
}
}
};

// Update the connection and drop the guard
*connection_guard = OnceCell::from(connection);
});
// Disconnects the current connection if an error was passed in and we're
// not already reconnecting.
macro_rules! disconnect_on_error {
($self:expr, $res: expr) => {
match $res {
Ok(t) => Ok(t),
Err(e) => {
// If we are not currently reconnecting, take the current connection
if $self.inner.connecting_guard.available_permits() > 0 {
// If we ran into an error, take the current connection.
// This will only start reconnecting if we try to receive or send another message.
$self.inner.connection.write().await.take();
}

// We are trying to reconnect. Return an error.
return Err(Error::Connection("reconnection in progress".to_string()));
Err(e)
}
}
}};
};
}

impl<C: ConnectionDef> Retry<C> {
/// Get the underlying connection if it exists, otherwise try to reconnect.
///
/// # Errors
/// - If we are in the middle of reconnecting
fn reconnect_if_needed(&self, possible_connection: Option<Connection>) -> Result<Connection> {
let Some(connection) = possible_connection else {
// If the connection is not initialized for one reason or another, try to reconnect
// Acquire the semaphore to ensure only one reconnection is happening at a time
if let Ok(permit) = self.inner.connecting_guard.try_acquire() {
// We were the first to try reconnecting, spawn a reconnection task
let inner = self.inner.clone();
spawn(async move {
let mut connection = inner.connection.write().await;

// Forever,
loop {
// Try to reconnect
match inner.connect().await {
Ok(new_connection) => {
// We successfully reconnected
*connection = Some(new_connection);
break;
}
Err(e) => {
// We failed to reconnect
// Sleep for 2 seconds and then try again
error!(error = %e, "failed to reconnect");
sleep(Duration::from_secs(2)).await;
}
}
}
_ = permit;
});
}

// If we are in the middle of reconnecting, return an error
return Err(Error::Connection("connection in progress".to_string()));
};

Ok(connection)
}

/// Get the connection if it exists, wait for a potential reconnection if it does not.
///
/// # Errors
/// - If somebody else is already reconnecting
async fn get_connection(&self) -> Result<Connection> {
let possible_connection = self.inner.connection.read().await;

// TODO: figure out a potential way to remove this clone
self.reconnect_if_needed(possible_connection.clone())
}

/// Try to get the connection if it exists, otherwise try to reconnect. Does not block.
///
/// # Errors
/// - If we are in the middle of reconnecting
fn try_get_connection(&self) -> Result<Connection> {
let Ok(possible_connection) = self.inner.connection.try_read() else {
// Someone else is already reconnecting
return Err(Error::Connection("connection in progress".to_string()));
};

self.reconnect_if_needed(possible_connection.clone())
}

/// Creates a new `Retry` connection from a `Config`
/// Attempts to make an initial connection.
/// This allows us to create elastic clients that always try to maintain a connection.
Expand All @@ -197,6 +243,7 @@ impl<C: ConnectionDef> Retry<C> {
use_local_authority,
// TODO: parameterize batch params
connection: Arc::default(),
connecting_guard: Semaphore::const_new(1),
keypair,
subscribed_topics,
}),
Expand All @@ -205,18 +252,9 @@ impl<C: ConnectionDef> Retry<C> {

/// Returns only when the connection is fully initialized
pub async fn ensure_initialized(&self) {
// In a loop, attempt to initialize the connection (if not yet)
while let Err(err) = self
.inner
.connection
.read()
.await
.get_or_try_init(|| self.inner.connect())
.await
{
// Try to get the underlying connection
while let Err(err) = self.get_connection().await {
error!("failed to initialize connection: {err}");

// Wait a bit so we don't overload the server
sleep(Duration::from_secs(2)).await;
}
}
Expand All @@ -229,22 +267,11 @@ impl<C: ConnectionDef> Retry<C> {
/// - If we are reconnecting
/// - If we are disconnected
pub async fn send_message(&self, message: Message) -> Result<()> {
// Check if we're (probably) reconnecting or not
if let Ok(connection_guard) = self.inner.connection.try_read() {
// We're not reconnecting, try to send the message
// Initialize the connection if it does not yet exist
let out = connection_guard
.get_or_try_init(|| self.inner.connect())
.await?
.send_message(message)
.await;
drop(connection_guard);

try_with_reconnect!(self, out)
} else {
// We are reconnecting, return an error
Err(Error::Connection("reconnection in progress".to_string()))
}
// Try to get the underlying connection
let connection = self.try_get_connection()?;

// Soft close the connection
disconnect_on_error!(self, connection.send_message(message).await)
}

/// Receives a message from the underlying fallible connection. Reconnection logic is here,
Expand All @@ -254,19 +281,11 @@ impl<C: ConnectionDef> Retry<C> {
/// - If we are in the middle of reconnecting
/// - If the message receiving failed
pub async fn receive_message(&self) -> Result<Message> {
// Wait for a reconnection before trying to receive
let connection_guard = self.inner.connection.read().await;

// Initialize the connection if it does not yet exist
let out = connection_guard
.get_or_try_init(|| self.inner.connect())
.await?
.recv_message()
.await;
drop(connection_guard);

// If we failed to receive a message, kick off reconnection logic
try_with_reconnect!(self, out)
// Try to synchronously get the underlying connection
let connection = self.get_connection().await?;

// Receive a message
disconnect_on_error!(self, connection.recv_message().await)
}

/// Soft close the connection, ensuring that all messages are sent.
Expand All @@ -275,18 +294,10 @@ impl<C: ConnectionDef> Retry<C> {
/// - If we are in the middle of reconnecting
/// - If the connection is closed
pub async fn soft_close(&self) -> Result<()> {
// Check if we're (probably) reconnecting or not
if let Ok(connection_guard) = self.inner.connection.try_read() {
// We're not reconnecting, try to send the message
// Initialize the connection if it does not yet exist
connection_guard
.get_or_try_init(|| self.inner.connect())
.await?
.soft_close()
.await
} else {
// We are reconnecting, return an error
Err(Error::Connection("reconnection in progress".to_string()))
}
// Try to get the underlying connection
let connection = self.try_get_connection()?;

// Soft close the connection
disconnect_on_error!(self, connection.soft_close().await)
}
}

0 comments on commit d2347eb

Please sign in to comment.