Skip to content

Commit

Permalink
Optimize Lazy<ArcSwap<ConnectionState>> into ArcSwapOption<Connection…
Browse files Browse the repository at this point in the history
…State>
  • Loading branch information
Bajix committed Mar 23, 2023
1 parent df70047 commit 461673e
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 51 deletions.
4 changes: 2 additions & 2 deletions crates/derive-redis-swapplex/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ pub fn derive_manager_context(input: proc_macro::TokenStream) -> proc_macro::Tok
}

fn state_cache(
) -> &'static std::thread::LocalKey<std::cell::RefCell<redis_swapplex::arc_swap::Cache<&'static redis_swapplex::arc_swap::ArcSwap<redis_swapplex::ConnectionState>, std::sync::Arc<redis_swapplex::ConnectionState>>>>
) -> &'static std::thread::LocalKey<std::cell::RefCell<redis_swapplex::arc_swap::Cache<&'static redis_swapplex::arc_swap::ArcSwapOption<redis_swapplex::ConnectionState>, Option<std::sync::Arc<redis_swapplex::ConnectionState>>>>>
{
thread_local! {
static STATE_CACHE:
std::cell::RefCell<redis_swapplex::arc_swap::Cache<&'static redis_swapplex::arc_swap::ArcSwap<redis_swapplex::ConnectionState>, std::sync::Arc<redis_swapplex::ConnectionState>>> = std::cell::RefCell::new(redis_swapplex::arc_swap::Cache::new(<#ident>::connection_manager().deref()));
std::cell::RefCell<redis_swapplex::arc_swap::Cache<&'static redis_swapplex::arc_swap::ArcSwapOption<redis_swapplex::ConnectionState>, Option<std::sync::Arc<redis_swapplex::ConnectionState>>>> = std::cell::RefCell::new(redis_swapplex::arc_swap::Cache::new(<#ident>::connection_manager().deref()));
}

&STATE_CACHE
Expand Down
102 changes: 53 additions & 49 deletions crates/redis-swapplex/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ mod bytes;

pub use bytes::IntoBytes;

use arc_swap::{ArcSwap, ArcSwapAny, Cache};
use arc_swap::{ArcSwapAny, ArcSwapOption, Cache};
pub use derive_redis_swapplex::ConnectionManagerContext;
use env_url::*;
use futures_util::{future::FutureExt, stream::unfold, Stream};
Expand Down Expand Up @@ -150,7 +150,6 @@ where

#[doc(hidden)]
pub enum ConnectionState {
Idle,
Connecting,
ClientError(ErrorKind),
ConnectionError(ErrorKind, SystemTime),
Expand All @@ -159,7 +158,7 @@ pub enum ConnectionState {

#[doc(hidden)]
pub struct ConnectionManager<T: ConnectionInfo> {
state: Lazy<ArcSwap<ConnectionState>>,
state: ArcSwapOption<ConnectionState>,
notify: Notify,
connection_info: Lazy<T>,
}
Expand All @@ -170,14 +169,14 @@ where
{
pub const fn new(connection_info: fn() -> T) -> ConnectionManager<T> {
ConnectionManager {
state: Lazy::new(|| ArcSwap::from(Arc::new(ConnectionState::Idle))),
state: ArcSwapOption::const_empty(),
notify: Notify::const_new(),
connection_info: Lazy::new(connection_info),
}
}

fn store_and_notify<S: Into<Arc<ConnectionState>>>(&self, state: S) {
self.state.store(state.into());
fn store_and_notify(&self, state: Option<Arc<ConnectionState>>) {
self.state.store(state);
self.notify.notify_waiters();
}

Expand All @@ -194,10 +193,10 @@ impl<T> Deref for ConnectionManager<T>
where
T: ConnectionInfo,
{
type Target = ArcSwapAny<Arc<ConnectionState>>;
type Target = ArcSwapAny<Option<Arc<ConnectionState>>>;

fn deref(&self) -> &Self::Target {
self.state.deref()
&self.state
}
}

Expand Down Expand Up @@ -234,10 +233,11 @@ pub trait ConnectionManagerContext: Send + Sync + 'static + Sized {
Self::connection_manager().get_db()
}

fn state_cache(
) -> &'static LocalKey<RefCell<Cache<&'static ArcSwap<ConnectionState>, Arc<ConnectionState>>>>;
fn state_cache() -> &'static LocalKey<
RefCell<Cache<&'static ArcSwapOption<ConnectionState>, Option<Arc<ConnectionState>>>>,
>;

fn with_state<T>(with_fn: fn(&ConnectionState) -> T) -> T {
fn with_state<T>(with_fn: fn(&Option<Arc<ConnectionState>>) -> T) -> T {
<Self as ConnectionManagerContext>::state_cache()
.with(|cache| with_fn(cache.borrow_mut().load()))
}
Expand All @@ -248,32 +248,32 @@ where
T: ConnectionManagerContext,
{
async fn get_multiplexed_connection() -> RedisResult<(MultiplexedConnection, ConnectionAddr)> {
let connection = T::with_state(|connection_state| match connection_state {
ConnectionState::Idle => {
let connection = T::with_state(|connection_state| match connection_state.as_deref() {
None => {
Self::establish_connection(None);
None
}
ConnectionState::Connecting => None,
ConnectionState::ClientError(kind) => Some(Err(RedisError::from((
Some(ConnectionState::Connecting) => None,
Some(ConnectionState::ClientError(kind)) => Some(Err(RedisError::from((
kind.to_owned(),
"Invalid Redis connection URL",
)))),
ConnectionState::ConnectionError(
Some(ConnectionState::ConnectionError(
ErrorKind::IoError | ErrorKind::ClusterDown | ErrorKind::BusyLoadingError,
time,
) if SystemTime::now()
)) if SystemTime::now()
.duration_since(*time)
.unwrap()
.gt(&Duration::from_millis(1500)) =>
{
Self::establish_connection(None);
None
}
ConnectionState::ConnectionError(kind, _) => Some(Err(RedisError::from((
Some(ConnectionState::ConnectionError(kind, _)) => Some(Err(RedisError::from((
kind.to_owned(),
"Unable to establish Redis connection",
)))),
ConnectionState::Connected(connection) => {
Some(ConnectionState::Connected(connection)) => {
let conn_addr = ConnectionAddr(addr_of!(*connection));
Some(Ok((connection.clone(), conn_addr)))
}
Expand All @@ -284,18 +284,18 @@ where
None => {
T::connection_manager().notify.notified().await;

T::with_state(|connection_state| match connection_state {
ConnectionState::Idle => unreachable!(),
ConnectionState::Connecting => unreachable!(),
ConnectionState::ClientError(kind) => Err(RedisError::from((
T::with_state(|connection_state| match connection_state.as_deref() {
None => unreachable!(),
Some(ConnectionState::Connecting) => unreachable!(),
Some(ConnectionState::ClientError(kind)) => Err(RedisError::from((
kind.to_owned(),
"Invalid Redis connection URL",
))),
ConnectionState::ConnectionError(kind, _timestamp) => Err(RedisError::from((
Some(ConnectionState::ConnectionError(kind, _timestamp)) => Err(RedisError::from((
kind.to_owned(),
"Unable to establish Redis connection",
))),
ConnectionState::Connected(connection) => {
Some(ConnectionState::Connected(connection)) => {
let conn_addr = ConnectionAddr(addr_of!(*connection));
Ok((connection.clone(), conn_addr))
}
Expand All @@ -307,25 +307,25 @@ where
fn establish_connection(conn_addr: Option<ConnectionAddr>) {
let state = T::connection_manager().state.load();

let should_connect = match state.as_ref() {
ConnectionState::Idle => true,
ConnectionState::Connecting => false,
let should_connect = match state.as_deref() {
None => true,
Some(ConnectionState::Connecting) => false,
// Never reconnect if there's been a client error; treat as poisoned
ConnectionState::ClientError(_) => false,
ConnectionState::ConnectionError(
Some(ConnectionState::ClientError(_)) => false,
Some(ConnectionState::ConnectionError(
ErrorKind::AuthenticationFailed | ErrorKind::InvalidClientConfig,
_,
) => false,
ConnectionState::ConnectionError(_, time)
)) => false,
Some(ConnectionState::ConnectionError(_, time))
if SystemTime::now()
.duration_since(*time)
.unwrap()
.gt(&Duration::from_millis(1500)) =>
{
true
}
ConnectionState::ConnectionError(_, _) => false,
ConnectionState::Connected(connection) => {
Some(ConnectionState::ConnectionError(_, _)) => false,
Some(ConnectionState::Connected(connection)) => {
if let Some(conn_addr) = conn_addr {
let current_addr = ConnectionAddr(addr_of!(*connection));

Expand All @@ -340,22 +340,26 @@ where
if should_connect {
let prev = T::connection_manager()
.state
.compare_and_swap(&state, Arc::new(ConnectionState::Connecting));
.compare_and_swap(&state, Some(Arc::new(ConnectionState::Connecting)));

if Arc::ptr_eq(&prev, &state) {
if match (prev.as_ref(), state.as_ref()) {
(None, None) => true,
(Some(prev), Some(state)) => Arc::ptr_eq(prev, state),
_ => false,
} {
tokio::task::spawn(async move {
match T::client() {
Ok(client) => match client.get_multiplexed_tokio_connection().await {
Ok(conn) => {
T::connection_manager().store_and_notify(ConnectionState::Connected(conn));
T::connection_manager()
.store_and_notify(Some(Arc::new(ConnectionState::Connected(conn))));
}
Err(err) => T::connection_manager().store_and_notify(
Err(err) => T::connection_manager().store_and_notify(Some(Arc::new(
ConnectionState::ConnectionError(err.kind(), SystemTime::now()),
),
))),
},
Err(err) => {
T::connection_manager().store_and_notify(ConnectionState::ClientError(err.kind()))
}
Err(err) => T::connection_manager()
.store_and_notify(Some(Arc::new(ConnectionState::ClientError(err.kind())))),
}
});
}
Expand All @@ -366,20 +370,20 @@ where
loop {
T::connection_manager().notify.notified().await;

let poll = T::with_state(|connection_state| match connection_state {
ConnectionState::ClientError(kind) => Poll::Ready(Err(RedisError::from((
let poll = T::with_state(|connection_state| match connection_state.as_deref() {
Some(ConnectionState::ClientError(kind)) => Poll::Ready(Err(RedisError::from((
kind.to_owned(),
"Invalid Redis connection URL",
)))),
ConnectionState::ConnectionError(
Some(ConnectionState::ConnectionError(
ErrorKind::BusyLoadingError | ErrorKind::ClusterDown | ErrorKind::IoError,
_,
) => Poll::Pending,
ConnectionState::ConnectionError(kind, _) => Poll::Ready(Err(RedisError::from((
)) => Poll::Pending,
Some(ConnectionState::ConnectionError(kind, _)) => Poll::Ready(Err(RedisError::from((
kind.to_owned(),
"Unable to establish Redis connection",
)))),
ConnectionState::Connected(_) => Poll::Ready(Ok(())),
Some(ConnectionState::Connected(_)) => Poll::Ready(Ok(())),
_ => Poll::Pending,
});

Expand Down Expand Up @@ -490,7 +494,7 @@ where
T: ConnectionManagerContext,
{
T::with_state(|connect_state| {
if let ConnectionState::Connected(connection) = connect_state {
if let Some(ConnectionState::Connected(connection)) = connect_state.as_deref() {
let conn_addr = ConnectionAddr(addr_of!(*connection));

Some(conn_addr)
Expand Down

0 comments on commit 461673e

Please sign in to comment.