From d2347eb88899d723631ec13407fa9f3c801e61fb Mon Sep 17 00:00:00 2001 From: Rob Date: Mon, 1 Jul 2024 20:09:51 -0400 Subject: [PATCH] replace `OnceCell` with better logic --- cdn-client/src/retry.rs | 197 +++++++++++++++++++++------------------- 1 file changed, 104 insertions(+), 93 deletions(-) diff --git a/cdn-client/src/retry.rs b/cdn-client/src/retry.rs index 3323020..d1862bd 100644 --- a/cdn-client/src/retry.rs +++ b/cdn-client/src/retry.rs @@ -19,7 +19,8 @@ use cdn_proto::{ message::{Message, Topic}, }; use tokio::{ - sync::{OnceCell, RwLock}, + spawn, + sync::{RwLock, Semaphore}, time::sleep, }; use tracing::{error, info}; @@ -46,7 +47,10 @@ pub struct Inner { use_local_authority: bool, /// The underlying connection - connection: Arc>>, + connection: Arc>>, + + /// The semaphore to ensure only one reconnection is happening at a time + connecting_guard: Semaphore, /// The keypair to use when authenticating pub keypair: KeyPair>, @@ -126,51 +130,93 @@ pub struct Config { pub subscribed_topics: Vec, } -/// This is a macro that helps with reconnections when sending -/// and receiving messages. You can specify the operation and it -/// will reconnect on the operation's failure, while handling all -/// reconnection logic and synchronization patterns. -macro_rules! try_with_reconnect { - ($self: expr, $out: expr) => {{ - // See if operation was an error - match $out { - Ok(res) => Ok(res), - Err(err) => { - // Acquire our "semaphore". If another task is doing this, just return an error - if let Ok(mut connection_guard) = $self.inner.connection.clone().try_write_owned() { - error!("connection failed: {err}, reconnecting"); - - // Clone `inner` so we can use it in the task - let inner = $self.inner.clone(); - // We are the only ones reconnecting. 
Let's launch the task to reconnect - tokio::spawn(async move { - // Loop to connect and authenticate - let connection = loop { - // Sleep so we don't overload the server - sleep(Duration::from_secs(2)).await; - - // Create a connection - match inner.connect().await { - Ok(connection) => break connection, - Err(err) => { - error!("failed connection: {err}"); - } - } - }; - - // Update the connection and drop the guard - *connection_guard = OnceCell::from(connection); - }); +// Disconnects the current connection if an error was passed in and we're +// not already reconnecting. +macro_rules! disconnect_on_error { + ($self:expr, $res: expr) => { + match $res { + Ok(t) => Ok(t), + Err(e) => { + // If we are not currently reconnecting, take the current connection + if $self.inner.connecting_guard.available_permits() > 0 { + // If we ran into an error, take the current connection. + // This will only start reconnecting if we try to receive or send another message. + $self.inner.connection.write().await.take(); } - // We are trying to reconnect. Return an error. - return Err(Error::Connection("reconnection in progress".to_string())); + Err(e) } } - }}; + }; } impl Retry { + /// Get the underlying connection if it exists, otherwise try to reconnect. 
+ /// + /// # Errors + /// - If we are in the middle of reconnecting + fn reconnect_if_needed(&self, possible_connection: Option) -> Result { + let Some(connection) = possible_connection else { + // If the connection is not initialized for one reason or another, try to reconnect + // Acquire the semaphore to ensure only one reconnection is happening at a time + if let Ok(permit) = self.inner.connecting_guard.try_acquire() { + // We were the first to try reconnecting, spawn a reconnection task + let inner = self.inner.clone(); + spawn(async move { + let mut connection = inner.connection.write().await; + + // Forever, + loop { + // Try to reconnect + match inner.connect().await { + Ok(new_connection) => { + // We successfully reconnected + *connection = Some(new_connection); + break; + } + Err(e) => { + // We failed to reconnect + // Sleep for 2 seconds and then try again + error!(error = %e, "failed to reconnect"); + sleep(Duration::from_secs(2)).await; + } + } + } + _ = permit; + }); + } + + // If we are in the middle of reconnecting, return an error + return Err(Error::Connection("connection in progress".to_string())); + }; + + Ok(connection) + } + + /// Get the connection if it exists, wait for a potential reconnection if it does not. + /// + /// # Errors + /// - If somebody else is already reconnecting + async fn get_connection(&self) -> Result { + let possible_connection = self.inner.connection.read().await; + + // TODO: figure out a potential way to remove this clone + self.reconnect_if_needed(possible_connection.clone()) + } + + /// Try to get the connection if it exists, otherwise try to reconnect. Does not block. 
+ /// + /// # Errors + /// - If we are in the middle of reconnecting + fn try_get_connection(&self) -> Result { + let Ok(possible_connection) = self.inner.connection.try_read() else { + // Someone else is already reconnecting + return Err(Error::Connection("connection in progress".to_string())); + }; + + self.reconnect_if_needed(possible_connection.clone()) + } + /// Creates a new `Retry` connection from a `Config` /// Attempts to make an initial connection. /// This allows us to create elastic clients that always try to maintain a connection. @@ -197,6 +243,7 @@ impl Retry { use_local_authority, // TODO: parameterize batch params connection: Arc::default(), + connecting_guard: Semaphore::const_new(1), keypair, subscribed_topics, }), @@ -205,18 +252,9 @@ impl Retry { /// Returns only when the connection is fully initialized pub async fn ensure_initialized(&self) { - // In a loop, attempt to initialize the connection (if not yet) - while let Err(err) = self - .inner - .connection - .read() - .await - .get_or_try_init(|| self.inner.connect()) - .await - { + // Try to get the underlying connection + while let Err(err) = self.get_connection().await { error!("failed to initialize connection: {err}"); - - // Wait a bit so we don't overload the server sleep(Duration::from_secs(2)).await; } } @@ -229,22 +267,11 @@ impl Retry { /// - If we are reconnecting /// - If we are disconnected pub async fn send_message(&self, message: Message) -> Result<()> { - // Check if we're (probably) reconnecting or not - if let Ok(connection_guard) = self.inner.connection.try_read() { - // We're not reconnecting, try to send the message - // Initialize the connection if it does not yet exist - let out = connection_guard - .get_or_try_init(|| self.inner.connect()) - .await? 
- .send_message(message) - .await; - drop(connection_guard); - - try_with_reconnect!(self, out) - } else { - // We are reconnecting, return an error - Err(Error::Connection("reconnection in progress".to_string())) - } + // Try to get the underlying connection + let connection = self.try_get_connection()?; + + // Send the message + disconnect_on_error!(self, connection.send_message(message).await) } /// Receives a message from the underlying fallible connection. Reconnection logic is here, @@ -254,19 +281,11 @@ impl Retry { /// - If we are in the middle of reconnecting /// - If the message receiving failed pub async fn receive_message(&self) -> Result { - // Wait for a reconnection before trying to receive - let connection_guard = self.inner.connection.read().await; - - // Initialize the connection if it does not yet exist - let out = connection_guard - .get_or_try_init(|| self.inner.connect()) - .await? - .recv_message() - .await; - drop(connection_guard); - - // If we failed to receive a message, kick off reconnection logic - try_with_reconnect!(self, out) + // Wait for any in-progress reconnection, then get the underlying connection + let connection = self.get_connection().await?; + + // Receive a message + disconnect_on_error!(self, connection.recv_message().await) } /// Soft close the connection, ensuring that all messages are sent. @@ -275,18 +294,10 @@ impl Retry { /// - If we are in the middle of reconnecting /// - If the connection is closed pub async fn soft_close(&self) -> Result<()> { - // Check if we're (probably) reconnecting or not - if let Ok(connection_guard) = self.inner.connection.try_read() { - // We're not reconnecting, try to send the message - // Initialize the connection if it does not yet exist - connection_guard - .get_or_try_init(|| self.inner.connect()) - .await? 
- .soft_close() - .await - } else { - // We are reconnecting, return an error - Err(Error::Connection("reconnection in progress".to_string())) - } + // Try to get the underlying connection + let connection = self.try_get_connection()?; + + // Soft close the connection + disconnect_on_error!(self, connection.soft_close().await) } }