diff --git a/crates/curp/src/client/fetch.rs b/crates/curp/src/client/fetch.rs index 1b7f4e187..e9af6d884 100644 --- a/crates/curp/src/client/fetch.rs +++ b/crates/curp/src/client/fetch.rs @@ -16,7 +16,8 @@ use super::cluster_state::ClusterState; use super::config::Config; /// Fetch cluster implementation -struct Fetch { +#[derive(Debug, Default, Clone)] +pub(crate) struct Fetch { /// The fetch config config: Config, } diff --git a/crates/curp/src/client/keep_alive.rs b/crates/curp/src/client/keep_alive.rs new file mode 100644 index 000000000..3bcc77a04 --- /dev/null +++ b/crates/curp/src/client/keep_alive.rs @@ -0,0 +1,104 @@ +use std::{ + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, + }, + time::Duration, +}; + +use event_listener::Event; +use futures::Future; +use parking_lot::RwLock; +use tokio::{sync::broadcast, task::JoinHandle}; +use tracing::{debug, info, warn}; + +use super::{cluster_state::ClusterState, state::State}; +use crate::rpc::{connect::ConnectApi, CurpError, Redirect}; + +/// Keep alive +#[derive(Clone, Debug)] +pub(crate) struct KeepAlive { + /// Heartbeat interval + heartbeat_interval: Duration, +} + +/// Handle of the keep alive task +#[derive(Debug)] +pub(crate) struct KeepAliveHandle { + /// Client id + client_id: Arc, + /// Update event of client id + update_event: Arc, + /// Task join handle + handle: JoinHandle<()>, +} + +impl KeepAliveHandle { + /// Wait for the client id + pub(crate) async fn wait_id_update(&self, current_id: u64) -> u64 { + loop { + let id = self.client_id.load(Ordering::Relaxed); + if current_id != id { + return id; + } + self.update_event.listen().await; + } + } +} + +impl KeepAlive { + /// Creates a new `KeepAlive` + pub(crate) fn new(heartbeat_interval: Duration) -> Self { + Self { heartbeat_interval } + } + + /// Streaming keep alive + pub(crate) fn spawn_keep_alive( + self, + cluster_state: Arc>, + ) -> KeepAliveHandle { + /// Sleep duration when keep alive failed + const FAIL_SLEEP_DURATION: Duration = Duration::from_secs(1); + let client_id = Arc::new(AtomicU64::new(0)); + let client_id_c = Arc::clone(&client_id); + let update_event = Arc::new(Event::new()); + let update_event_c = Arc::clone(&update_event); + let handle = tokio::spawn(async move { + loop { + let current_state = cluster_state.read().clone(); + let current_id = client_id.load(Ordering::Relaxed); + match self.keep_alive_with(current_id, current_state).await { + Ok(new_id) => { + client_id.store(new_id, Ordering::Relaxed); + let _ignore = update_event.notify(usize::MAX); + } + Err(e) => { + warn!("keep alive failed: {e:?}"); + // Sleep for some time, the cluster state should be updated in a while + tokio::time::sleep(FAIL_SLEEP_DURATION).await; + } + } + } + }); + + KeepAliveHandle { + client_id: client_id_c, + update_event: update_event_c, + handle, + } + } + + /// Keep alive with the given state and config + pub(crate) async fn keep_alive_with( + &self, + client_id: u64, + cluster_state: ClusterState, + ) -> Result { + cluster_state + .map_leader(|conn| async move { + conn.lease_keep_alive(client_id, self.heartbeat_interval) + .await + }) + .await + } +} diff --git a/crates/curp/src/client/mod.rs b/crates/curp/src/client/mod.rs index f759f18b9..6f054967a 100644 --- a/crates/curp/src/client/mod.rs +++ b/crates/curp/src/client/mod.rs @@ -8,12 +8,15 @@ mod metrics; /// Unary rpc client mod unary; +#[cfg(ignore)] /// Stream rpc client mod stream; +#[allow(unused)] /// Retry layer mod retry; +#[allow(unused)] /// State for clients mod state; @@ -29,6 +32,10 @@ mod fetch; /// Config of the client mod config; +#[allow(unused)] +/// Lease keep alive implementation +mod keep_alive; + /// Tests for client #[cfg(test)] mod tests; @@ -41,7 +48,6 @@ use async_trait::async_trait; use curp_external_api::cmd::Command; use futures::{stream::FuturesUnordered, StreamExt}; use parking_lot::RwLock; -use tokio::task::JoinHandle; #[cfg(not(madsim))] use tonic::transport::ClientTlsConfig; use tracing::{debug, warn}; @@ -50,6 +56,8 @@ use utils::ClientTlsConfig; use utils::{build_endpoint, config::ClientConfig}; use self::{ + fetch::Fetch, + keep_alive::KeepAlive, retry::{Retry, RetryConfig}, state::StateBuilder, unary::{Unary, UnaryConfig}, @@ -424,16 +432,6 @@ impl ClientBuilder { ) } - /// Spawn background tasks for the client - fn spawn_bg_tasks(&self, state: Arc) -> JoinHandle<()> { - let interval = *self.config.keep_alive_interval(); - tokio::spawn(async move { - let stream = stream::Streaming::new(state, stream::StreamingConfig::new(interval)); - stream.keep_heartbeat().await; - debug!("keep heartbeat task shutdown"); - }) - } - /// Build the client /// /// # Errors @@ -445,10 +443,14 @@ impl ClientBuilder { ) -> Result + Send + Sync + 'static, tonic::Status> { let state = Arc::new(self.init_state_builder().build()); + let keep_alive = KeepAlive::new(*self.config.keep_alive_interval()); + // TODO: build the fetch object + let fetch = Fetch::default(); let client = Retry::new( Unary::new(Arc::clone(&state), self.init_unary_config()), self.init_retry_config(), - Some(self.spawn_bg_tasks(Arc::clone(&state))), + keep_alive, + fetch, ); Ok(client) @@ -496,10 +498,14 @@ impl ClientBuilderWithBypass

{ .init_state_builder() .build_bypassed::

(self.local_server_id, self.local_server); let state = Arc::new(state); + let keep_alive = KeepAlive::new(*self.inner.config.keep_alive_interval()); + // TODO: build the fetch object + let fetch = Fetch::default(); let client = Retry::new( Unary::new(Arc::clone(&state), self.inner.init_unary_config()), self.inner.init_retry_config(), - Some(self.inner.spawn_bg_tasks(Arc::clone(&state))), + keep_alive, + fetch, ); Ok(client) diff --git a/crates/curp/src/client/retry.rs b/crates/curp/src/client/retry.rs index c67db6019..387e4dbdb 100644 --- a/crates/curp/src/client/retry.rs +++ b/crates/curp/src/client/retry.rs @@ -1,14 +1,19 @@ -use std::{ops::SubAssign, time::Duration}; +use std::{ops::SubAssign, sync::Arc, time::Duration}; use async_trait::async_trait; use futures::Future; -use tokio::task::JoinHandle; -use tracing::{info, warn}; - -use super::{ClientApi, LeaderStateUpdate, ProposeResponse, RepeatableClientApi}; +use parking_lot::RwLock; +use tracing::warn; + +use super::{ + cluster_state::ClusterState, + fetch::Fetch, + keep_alive::{KeepAlive, KeepAliveHandle}, + ClientApi, LeaderStateUpdate, ProposeResponse, RepeatableClientApi, +}; use crate::{ members::ServerId, - rpc::{ConfChange, CurpError, FetchClusterResponse, Member, ReadState, Redirect}, + rpc::{ConfChange, CurpError, FetchClusterResponse, Member, ReadState}, }; /// Backoff config @@ -95,6 +100,35 @@ impl Backoff { } } +/// The context of a retry +#[derive(Debug)] +pub(crate) struct Context { + /// The current client id + client_id: u64, + /// The current cluster state + cluster_state: ClusterState, +} + +impl Context { + /// Creates a new `Context` + pub(crate) fn new(client_id: u64, cluster_state: ClusterState) -> Self { + Self { + client_id, + cluster_state, + } + } + + /// Returns the current client id + pub(crate) fn client_id(&self) -> u64 { + self.client_id + } + + /// Returns the current client id + pub(crate) fn cluster_state(&self) -> ClusterState { + self.cluster_state.clone() + } +} + /// The retry client automatically retry the requests of the inner client api /// which raises the [`tonic::Status`] error #[derive(Debug)] @@ -102,18 +136,13 @@ pub(super) struct Retry { /// Inner client inner: Api, /// Retry config - config: RetryConfig, - /// Background task handle - bg_handle: Option>, -} - -impl Drop for Retry { - fn drop(&mut self) { - if let Some(handle) = self.bg_handle.as_ref() { - info!("stopping background task"); - handle.abort(); - } - } + retry_config: RetryConfig, + /// Cluster state + cluster_state: Arc>, + /// Keep alive client + keep_alive: KeepAliveHandle, + /// Fetch cluster object + fetch: Fetch, } impl Retry @@ -121,77 +150,49 @@ where Api: RepeatableClientApi + LeaderStateUpdate + Send + Sync + 'static, { /// Create a retry client - pub(super) fn new(inner: Api, config: RetryConfig, bg_handle: Option>) -> Self { + pub(super) fn new( + inner: Api, + retry_config: RetryConfig, + keep_alive: KeepAlive, + fetch: Fetch, + ) -> Self { + // TODO: build state from parameters + let cluster_state = Arc::new(RwLock::default()); + let keep_alive_handle = keep_alive.spawn_keep_alive(Arc::clone(&cluster_state)); Self { inner, - config, - bg_handle, + retry_config, + cluster_state, + keep_alive: keep_alive_handle, + fetch, } } /// Takes a function f and run retry. - async fn retry<'a, R, F>(&'a self, f: impl Fn(&'a Api) -> F) -> Result + async fn retry<'a, R, F>( + &'a self, + f: impl Fn(&'a Api, Context) -> F, + ) -> Result where F: Future>, { - let mut backoff = self.config.init_backoff(); + let mut backoff = self.retry_config.init_backoff(); let mut last_err = None; + let client_id = self.keep_alive.wait_id_update(0).await; while let Some(delay) = backoff.next_delay() { - let err = match f(&self.inner).await { + let cluster_state = self.cluster_state.read().clone(); + let context = Context::new(client_id, cluster_state.clone()); + let result = tokio::select! { + result = f(&self.inner, context) => result, + _ = self.keep_alive.wait_id_update(client_id) => { + return Err(CurpError::expired_client_id().into()); + }, + }; + let err = match result { Ok(res) => return Ok(res), Err(err) => err, }; - - match err { - // some errors that should not retry - CurpError::Duplicated(()) - | CurpError::ShuttingDown(()) - | CurpError::InvalidConfig(()) - | CurpError::NodeNotExists(()) - | CurpError::NodeAlreadyExists(()) - | CurpError::LearnerNotCatchUp(()) => { - return Err(tonic::Status::from(err)); - } - - // some errors that could have a retry - CurpError::ExpiredClientId(()) - | CurpError::KeyConflict(()) - | CurpError::Internal(_) - | CurpError::LeaderTransfer(_) => {} - - // update leader state if we got a rpc transport error - CurpError::RpcTransport(()) => { - if let Err(e) = self.inner.fetch_leader_id(true).await { - warn!("fetch leader failed, error {e:?}"); - } - } - - // update the cluster state if got WrongClusterVersion - CurpError::WrongClusterVersion(()) => { - // the inner client should automatically update cluster state when fetch_cluster - if let Err(e) = self.inner.fetch_cluster(true).await { - warn!("fetch cluster failed, error {e:?}"); - } - } - - // update the leader state if got Redirect - CurpError::Redirect(Redirect { - ref leader_id, - term, - }) => { - let _ig = self - .inner - .update_leader(leader_id.as_ref().map(Into::into), term) - .await; - } - - // update the cluster state if got Zombie - CurpError::Zombie(()) => { - if let Err(e) = self.inner.fetch_cluster(true).await { - warn!("fetch cluster failed, error {e:?}"); - } - } - } + self.handle_err(&err, cluster_state).await?; #[cfg(feature = "client-metrics")] super::metrics::get().client_retry_count.add(1, &[]); @@ -209,6 +210,43 @@ where last_err.unwrap_or_else(|| unreachable!("last error must be set")) ))) } + + /// Handles errors before another retry + async fn handle_err( + &self, + err: &CurpError, + cluster_state: ClusterState, + ) -> Result<(), tonic::Status> { + match *err { + // some errors that should not retry + CurpError::Duplicated(()) + | CurpError::ShuttingDown(()) + | CurpError::InvalidConfig(()) + | CurpError::NodeNotExists(()) + | CurpError::NodeAlreadyExists(()) + | CurpError::LearnerNotCatchUp(()) => { + return Err(tonic::Status::from(err.clone())); + } + + // some errors that could have a retry + CurpError::ExpiredClientId(()) + | CurpError::KeyConflict(()) + | CurpError::Internal(_) + | CurpError::LeaderTransfer(_) => {} + + // Some error that needs to update cluster state + CurpError::RpcTransport(()) + | CurpError::WrongClusterVersion(()) + | CurpError::Redirect(_) // FIXME: The redirect error needs to include full cluster state + | CurpError::Zombie(()) => { + let new_cluster_state = self.fetch.fetch_cluster(cluster_state).await?; + // TODO: Prevent concurrent updating cluster state + *self.cluster_state.write() = new_cluster_state; + } + } + + Ok(()) + } } #[async_trait] @@ -230,7 +268,7 @@ where token: Option<&String>, use_fast_path: bool, ) -> Result, tonic::Status> { - self.retry::<_, _>(|client| async move { + self.retry::<_, _>(|client, _ctx| async move { let propose_id = self.inner.gen_propose_id().await?; RepeatableClientApi::propose(client, *propose_id, cmd, token, use_fast_path).await }) @@ -242,7 +280,7 @@ where &self, changes: Vec, ) -> Result, tonic::Status> { - self.retry::<_, _>(|client| { + self.retry::<_, _>(|client, _ctx| { let changes_c = changes.clone(); async move { let propose_id = self.inner.gen_propose_id().await?; @@ -254,7 +292,7 @@ where /// Send propose to shutdown cluster async fn propose_shutdown(&self) -> Result<(), tonic::Status> { - self.retry::<_, _>(|client| async move { + self.retry::<_, _>(|client, _ctx| async move { let propose_id = self.inner.gen_propose_id().await?; RepeatableClientApi::propose_shutdown(client, *propose_id).await }) @@ -268,7 +306,7 @@ where node_name: String, node_client_urls: Vec, ) -> Result<(), Self::Error> { - self.retry::<_, _>(|client| { + self.retry::<_, _>(|client, _ctx| { let name_c = node_name.clone(); let node_client_urls_c = node_client_urls.clone(); async move { @@ -288,13 +326,13 @@ where /// Send move leader request async fn move_leader(&self, node_id: u64) -> Result<(), Self::Error> { - self.retry::<_, _>(|client| client.move_leader(node_id)) + self.retry::<_, _>(|client, _ctx| client.move_leader(node_id)) .await } /// Send fetch read state from leader async fn fetch_read_state(&self, cmd: &Self::Cmd) -> Result { - self.retry::<_, _>(|client| client.fetch_read_state(cmd)) + self.retry::<_, _>(|client, _ctx| client.fetch_read_state(cmd)) .await } @@ -306,7 +344,7 @@ where &self, linearizable: bool, ) -> Result { - self.retry::<_, _>(|client| client.fetch_cluster(linearizable)) + self.retry::<_, _>(|client, _ctx| client.fetch_cluster(linearizable)) .await } } diff --git a/crates/curp/src/client/tests.rs b/crates/curp/src/client/tests.rs index 39c8b88bc..1ec9b7971 100644 --- a/crates/curp/src/client/tests.rs +++ b/crates/curp/src/client/tests.rs @@ -1,36 +1,32 @@ use std::{ collections::HashMap, - sync::{atomic::AtomicU64, Arc, Mutex}, + sync::{Arc, Mutex}, time::{Duration, Instant}, }; use curp_test_utils::test_cmd::{LogIndexResult, TestCommand, TestCommandResult}; -use futures::{future::BoxFuture, Stream}; #[cfg(not(madsim))] use tonic::transport::ClientTlsConfig; -use tonic::Status; use tracing_test::traced_test; #[cfg(madsim)] use utils::ClientTlsConfig; use super::{ state::State, - stream::{Streaming, StreamingConfig}, unary::{Unary, UnaryConfig}, }; use crate::{ client::{ + fetch::Fetch, + keep_alive::KeepAlive, retry::{Retry, RetryConfig}, ClientApi, }, members::ServerId, rpc::{ connect::{ConnectApi, MockConnectApi}, - CurpError, FetchClusterRequest, FetchClusterResponse, FetchReadStateRequest, - FetchReadStateResponse, Member, MoveLeaderRequest, MoveLeaderResponse, OpResponse, - ProposeConfChangeRequest, ProposeConfChangeResponse, ProposeRequest, ProposeResponse, - PublishRequest, PublishResponse, ReadIndexResponse, RecordRequest, RecordResponse, - ResponseOp, ShutdownRequest, ShutdownResponse, SyncedResponse, + CurpError, FetchClusterResponse, Member, OpResponse, ProposeResponse, ReadIndexResponse, + RecordResponse, ResponseOp, SyncedResponse, }, }; @@ -472,7 +468,8 @@ async fn test_retry_propose_return_no_retry_error() { let retry = Retry::new( unary, RetryConfig::new_fixed(Duration::from_millis(100), 5), - None, + KeepAlive::new(Duration::from_secs(1)), + Fetch::default(), ); let err = retry .propose(&TestCommand::new_put(vec![1], 1), None, false) @@ -522,7 +519,8 @@ async fn test_retry_propose_return_retry_error() { let retry = Retry::new( unary, RetryConfig::new_fixed(Duration::from_millis(10), 5), - None, + KeepAlive::new(Duration::from_secs(1)), + Fetch::default(), ); let err = retry .propose(&TestCommand::new_put(vec![1], 1), None, false) @@ -595,242 +593,263 @@ async fn test_read_index_fail() { assert!(res.is_err()); } -// Tests for stream client - -struct MockedStreamConnectApi { - id: ServerId, - lease_keep_alive_handle: - Box) -> BoxFuture<'static, CurpError> + Send + Sync + 'static>, -} - -#[async_trait::async_trait] -impl ConnectApi for MockedStreamConnectApi { - /// Get server id - fn id(&self) -> ServerId { - self.id - } - - /// Update server addresses, the new addresses will override the old ones - async fn update_addrs(&self, _addrs: Vec) -> Result<(), tonic::transport::Error> { - Ok(()) - } - - /// Send `ProposeRequest` - async fn propose_stream( - &self, - _request: ProposeRequest, - _token: Option, - _timeout: Duration, - ) -> Result> + Send>>, CurpError> - { - unreachable!("please use MockedConnectApi") - } - - /// Send `RecordRequest` - async fn record( - &self, - _request: RecordRequest, - _timeout: Duration, - ) -> Result, CurpError> { - unreachable!("please use MockedConnectApi") - } +// TODO: rewrite these tests +#[cfg(ignore)] +mod test_stream { + use super::*; - /// Send `ReadIndexRequest` - async fn read_index( - &self, - _timeout: Duration, - ) -> Result, CurpError> { - unreachable!("please use MockedConnectApi") - } + // Tests for stream client - /// Send `ProposeConfChange` - async fn propose_conf_change( - &self, - _request: ProposeConfChangeRequest, - _timeout: Duration, - ) -> Result, CurpError> { - unreachable!("please use MockedConnectApi") + struct MockedStreamConnectApi { + id: ServerId, + lease_keep_alive_handle: + Box BoxFuture<'static, Result> + Send + Sync + 'static>, } - /// Send `PublishRequest` - async fn publish( - &self, - _request: PublishRequest, - _timeout: Duration, - ) -> Result, CurpError> { - unreachable!("please use MockedConnectApi") + #[async_trait::async_trait] + impl ConnectApi for MockedStreamConnectApi { + /// Get server id + fn id(&self) -> ServerId { + self.id + } + + /// Update server addresses, the new addresses will override the old ones + async fn update_addrs(&self, _addrs: Vec) -> Result<(), tonic::transport::Error> { + Ok(()) + } + + /// Send `ProposeRequest` + async fn propose_stream( + &self, + _request: ProposeRequest, + _token: Option, + _timeout: Duration, + ) -> Result< + tonic::Response> + Send>>, + CurpError, + > { + unreachable!("please use MockedConnectApi") + } + + /// Send `RecordRequest` + async fn record( + &self, + _request: RecordRequest, + _timeout: Duration, + ) -> Result, CurpError> { + unreachable!("please use MockedConnectApi") + } + + /// Send `ReadIndexRequest` + async fn read_index( + &self, + _timeout: Duration, + ) -> Result, CurpError> { + unreachable!("please use MockedConnectApi") + } + + /// Send `ProposeConfChange` + async fn propose_conf_change( + &self, + _request: ProposeConfChangeRequest, + _timeout: Duration, + ) -> Result, CurpError> { + unreachable!("please use MockedConnectApi") + } + + /// Send `PublishRequest` + async fn publish( + &self, + _request: PublishRequest, + _timeout: Duration, + ) -> Result, CurpError> { + unreachable!("please use MockedConnectApi") + } + + /// Send `ShutdownRequest` + async fn shutdown( + &self, + _request: ShutdownRequest, + _timeout: Duration, + ) -> Result, CurpError> { + unreachable!("please use MockedConnectApi") + } + + /// Send `FetchClusterRequest` + async fn fetch_cluster( + &self, + _request: FetchClusterRequest, + _timeout: Duration, + ) -> Result, CurpError> { + unreachable!("please use MockedConnectApi") + } + + /// Send `FetchReadStateRequest` + async fn fetch_read_state( + &self, + _request: FetchReadStateRequest, + _timeout: Duration, + ) -> Result, CurpError> { + unreachable!("please use MockedConnectApi") + } + + /// Send `MoveLeaderRequest` + async fn move_leader( + &self, + _request: MoveLeaderRequest, + _timeout: Duration, + ) -> Result, CurpError> { + unreachable!("please use MockedConnectApi") + } + + /// Keep send lease keep alive to server and mutate the client id + async fn lease_keep_alive( + &self, + client_id: u64, + _interval: Duration, + ) -> Result { + (self.lease_keep_alive_handle)(client_id).await + } } - /// Send `ShutdownRequest` - async fn shutdown( - &self, - _request: ShutdownRequest, - _timeout: Duration, - ) -> Result, CurpError> { - unreachable!("please use MockedConnectApi") + /// Create mocked stream connects + /// + /// The leader is S0 + #[allow(trivial_casts)] // cannot be inferred + fn init_mocked_stream_connects( + size: usize, + leader_idx: usize, + leader_term: u64, + keep_alive_handle: impl Fn(u64) -> BoxFuture<'static, Result> + + Send + + Sync + + 'static, + ) -> HashMap> { + let mut keep_alive_handle = Some(keep_alive_handle); + let redirect_handle = move |_id| { + Box::pin(async move { + Err(CurpError::redirect( + Some(leader_idx as ServerId), + leader_term, + )) + }) as BoxFuture<'static, Result> + }; + (0..size) + .map(|id| MockedStreamConnectApi { + id: id as ServerId, + lease_keep_alive_handle: if id == leader_idx { + Box::new(keep_alive_handle.take().unwrap()) + } else { + Box::new(redirect_handle) + }, + }) + .enumerate() + .map(|(id, api)| (id as ServerId, Arc::new(api) as Arc)) + .collect() } - /// Send `FetchClusterRequest` - async fn fetch_cluster( - &self, - _request: FetchClusterRequest, - _timeout: Duration, - ) -> Result, CurpError> { - unreachable!("please use MockedConnectApi") + /// Create stream client for test + fn init_stream_client( + connects: HashMap>, + local_server: Option, + leader: Option, + term: u64, + cluster_version: u64, + ) -> Streaming { + let state = State::new_arc(connects, local_server, leader, term, cluster_version, None); + Streaming::new(state, StreamingConfig::new(Duration::from_secs(1))) } - /// Send `FetchReadStateRequest` - async fn fetch_read_state( - &self, - _request: FetchReadStateRequest, - _timeout: Duration, - ) -> Result, CurpError> { - unreachable!("please use MockedConnectApi") + #[traced_test] + #[tokio::test] + async fn test_stream_client_keep_alive_works() { + let connects = init_mocked_stream_connects(5, 0, 1, move |client_id| { + Box::pin(async move { + client_id + .compare_exchange( + 1, + 10, + std::sync::atomic::Ordering::Relaxed, + std::sync::atomic::Ordering::Relaxed, + ) + .unwrap(); + tokio::time::sleep(Duration::from_secs(30)).await; + unreachable!("test timeout") + }) + }); + let stream = init_stream_client(connects, None, Some(0), 1, 1); + tokio::time::timeout(Duration::from_millis(100), stream.keep_heartbeat()) + .await + .unwrap_err(); + assert_eq!(stream.state.client_id(), 10); } - /// Send `MoveLeaderRequest` - async fn move_leader( - &self, - _request: MoveLeaderRequest, - _timeout: Duration, - ) -> Result, CurpError> { - unreachable!("please use MockedConnectApi") + #[traced_test] + #[tokio::test] + async fn test_stream_client_keep_alive_on_redirect() { + let connects = init_mocked_stream_connects(5, 0, 2, move |client_id| { + Box::pin(async move { + client_id + .compare_exchange( + 1, + 10, + std::sync::atomic::Ordering::Relaxed, + std::sync::atomic::Ordering::Relaxed, + ) + .unwrap(); + tokio::time::sleep(Duration::from_secs(30)).await; + unreachable!("test timeout") + }) + }); + let stream = init_stream_client(connects, None, Some(1), 1, 1); + tokio::time::timeout(Duration::from_millis(100), stream.keep_heartbeat()) + .await + .unwrap_err(); + assert_eq!(stream.state.client_id(), 10); } - /// Keep send lease keep alive to server and mutate the client id - async fn lease_keep_alive(&self, client_id: Arc, _interval: Duration) -> CurpError { - (self.lease_keep_alive_handle)(Arc::clone(&client_id)).await + #[traced_test] + #[tokio::test] + async fn test_stream_client_keep_alive_hang_up_on_bypassed() { + let connects = init_mocked_stream_connects(5, 0, 1, |_client_id| { + Box::pin( + async move { panic!("should not invoke lease_keep_alive in bypassed connection") }, + ) + }); + let stream = init_stream_client(connects, Some(0), Some(0), 1, 1); + tokio::time::timeout(Duration::from_millis(100), stream.keep_heartbeat()) + .await + .unwrap_err(); + assert_ne!(stream.state.client_id(), 0); } -} - -/// Create mocked stream connects -/// -/// The leader is S0 -#[allow(trivial_casts)] // cannot be inferred -fn init_mocked_stream_connects( - size: usize, - leader_idx: usize, - leader_term: u64, - keep_alive_handle: impl Fn(Arc) -> BoxFuture<'static, CurpError> + Send + Sync + 'static, -) -> HashMap> { - let mut keep_alive_handle = Some(keep_alive_handle); - let redirect_handle = move |_id| { - Box::pin(async move { CurpError::redirect(Some(leader_idx as ServerId), leader_term) }) - as BoxFuture<'static, CurpError> - }; - (0..size) - .map(|id| MockedStreamConnectApi { - id: id as ServerId, - lease_keep_alive_handle: if id == leader_idx { - Box::new(keep_alive_handle.take().unwrap()) - } else { - Box::new(redirect_handle) - }, - }) - .enumerate() - .map(|(id, api)| (id as ServerId, Arc::new(api) as Arc)) - .collect() -} - -/// Create stream client for test -fn init_stream_client( - connects: HashMap>, - local_server: Option, - leader: Option, - term: u64, - cluster_version: u64, -) -> Streaming { - let state = State::new_arc(connects, local_server, leader, term, cluster_version, None); - Streaming::new(state, StreamingConfig::new(Duration::from_secs(1))) -} - -#[traced_test] -#[tokio::test] -async fn test_stream_client_keep_alive_works() { - let connects = init_mocked_stream_connects(5, 0, 1, move |client_id| { - Box::pin(async move { - client_id - .compare_exchange( - 1, - 10, - std::sync::atomic::Ordering::Relaxed, - std::sync::atomic::Ordering::Relaxed, - ) - .unwrap(); - tokio::time::sleep(Duration::from_secs(30)).await; - unreachable!("test timeout") - }) - }); - let stream = init_stream_client(connects, None, Some(0), 1, 1); - tokio::time::timeout(Duration::from_millis(100), stream.keep_heartbeat()) - .await - .unwrap_err(); - assert_eq!(stream.state.client_id(), 10); -} - -#[traced_test] -#[tokio::test] -async fn test_stream_client_keep_alive_on_redirect() { - let connects = init_mocked_stream_connects(5, 0, 2, move |client_id| { - Box::pin(async move { - client_id - .compare_exchange( - 1, - 10, - std::sync::atomic::Ordering::Relaxed, - std::sync::atomic::Ordering::Relaxed, - ) - .unwrap(); - tokio::time::sleep(Duration::from_secs(30)).await; - unreachable!("test timeout") - }) - }); - let stream = init_stream_client(connects, None, Some(1), 1, 1); - tokio::time::timeout(Duration::from_millis(100), stream.keep_heartbeat()) - .await - .unwrap_err(); - assert_eq!(stream.state.client_id(), 10); -} - -#[traced_test] -#[tokio::test] -async fn test_stream_client_keep_alive_hang_up_on_bypassed() { - let connects = init_mocked_stream_connects(5, 0, 1, |_client_id| { - Box::pin(async move { panic!("should not invoke lease_keep_alive in bypassed connection") }) - }); - let stream = init_stream_client(connects, Some(0), Some(0), 1, 1); - tokio::time::timeout(Duration::from_millis(100), stream.keep_heartbeat()) - .await - .unwrap_err(); - assert_ne!(stream.state.client_id(), 0); -} -#[traced_test] -#[tokio::test] -#[allow(clippy::ignored_unit_patterns)] // tokio select internal triggered -async fn test_stream_client_keep_alive_resume_on_leadership_changed() { - let connects = init_mocked_stream_connects(5, 1, 2, move |client_id| { - Box::pin(async move { - // generated a client id for bypassed client - assert_ne!(client_id.load(std::sync::atomic::Ordering::Relaxed), 0); - client_id.store(10, std::sync::atomic::Ordering::Relaxed); - tokio::time::sleep(Duration::from_secs(30)).await; - unreachable!("test timeout") - }) - }); - let stream = init_stream_client(connects, Some(0), Some(0), 1, 1); - let update_leader = async { - // wait for stream to hang up - tokio::time::sleep(Duration::from_millis(100)).await; - // check the local id - assert_ne!(stream.state.client_id(), 0); - stream.state.check_and_update_leader(Some(1), 2).await; - // wait for stream to resume - tokio::time::sleep(Duration::from_millis(100)).await; - }; - tokio::select! { - _ = stream.keep_heartbeat() => {}, - _ = update_leader => {} + #[traced_test] + #[tokio::test] + #[allow(clippy::ignored_unit_patterns)] // tokio select internal triggered + async fn test_stream_client_keep_alive_resume_on_leadership_changed() { + let connects = init_mocked_stream_connects(5, 1, 2, move |client_id| { + Box::pin(async move { + // generated a client id for bypassed client + assert_ne!(client_id.load(std::sync::atomic::Ordering::Relaxed), 0); + client_id.store(10, std::sync::atomic::Ordering::Relaxed); + tokio::time::sleep(Duration::from_secs(30)).await; + unreachable!("test timeout") + }) + }); + let stream = init_stream_client(connects, Some(0), Some(0), 1, 1); + let update_leader = async { + // wait for stream to hang up + tokio::time::sleep(Duration::from_millis(100)).await; + // check the local id + assert_ne!(stream.state.client_id(), 0); + stream.state.check_and_update_leader(Some(1), 2).await; + // wait for stream to resume + tokio::time::sleep(Duration::from_millis(100)).await; + }; + tokio::select! { + _ = stream.keep_heartbeat() => {}, + _ = update_leader => {} + } + assert_eq!(stream.state.client_id(), 10); } - assert_eq!(stream.state.client_id(), 10); } diff --git a/crates/curp/src/rpc/connect.rs b/crates/curp/src/rpc/connect.rs index c62b37d31..68c8ae18d 100644 --- a/crates/curp/src/rpc/connect.rs +++ b/crates/curp/src/rpc/connect.rs @@ -2,7 +2,7 @@ use std::{ collections::{HashMap, HashSet}, fmt::{Debug, Formatter}, ops::Deref, - sync::{atomic::AtomicU64, Arc}, + sync::Arc, time::Duration, }; @@ -223,7 +223,7 @@ pub(crate) trait ConnectApi: Send + Sync + 'static { ) -> Result, CurpError>; /// Keep send lease keep alive to server and mutate the client id - async fn lease_keep_alive(&self, client_id: Arc, interval: Duration) -> CurpError; + async fn lease_keep_alive(&self, client_id: u64, interval: Duration) -> Result; } /// Inner Connect interface among different servers @@ -513,22 +513,18 @@ impl ConnectApi for Connect> { with_timeout!(timeout, client.move_leader(req)).map_err(Into::into) } - /// Keep send lease keep alive to server and mutate the client id - async fn lease_keep_alive(&self, client_id: Arc, interval: Duration) -> CurpError { + /// Keep send lease keep alive to server with the current client id + async fn lease_keep_alive(&self, client_id: u64, interval: Duration) -> Result { let mut client = self.rpc_connect.clone(); - loop { - let stream = heartbeat_stream( - client_id.load(std::sync::atomic::Ordering::Relaxed), - interval, - ); - let new_id = match client.lease_keep_alive(stream).await { - Err(err) => return err.into(), - Ok(res) => res.into_inner().client_id, - }; - // The only place to update the client id for follower - info!("client_id update to {new_id}"); - client_id.store(new_id, std::sync::atomic::Ordering::Relaxed); - } + let stream = heartbeat_stream(client_id, interval); + let new_id = client + .lease_keep_alive(stream) + .await? + .into_inner() + .client_id; + // The only place to update the client id for follower + info!("client_id update to {new_id}"); + Ok(new_id) } } @@ -812,7 +808,11 @@ where } /// Keep send lease keep alive to server and mutate the client id - async fn lease_keep_alive(&self, _client_id: Arc, _interval: Duration) -> CurpError { + async fn lease_keep_alive( + &self, + _client_id: u64, + _interval: Duration, + ) -> Result { unreachable!("cannot invoke lease_keep_alive in bypassed connect") } } diff --git a/crates/curp/src/rpc/reconnect.rs b/crates/curp/src/rpc/reconnect.rs index e392db38a..f92844234 100644 --- a/crates/curp/src/rpc/reconnect.rs +++ b/crates/curp/src/rpc/reconnect.rs @@ -1,7 +1,4 @@ -use std::{ - sync::{atomic::AtomicU64, Arc}, - time::Duration, -}; +use std::time::Duration; use async_trait::async_trait; use event_listener::Event; @@ -48,7 +45,6 @@ impl Reconnect { // Cancel the leader keep alive loop task because it hold a read lock let _cancel = self.event.notify(1); let _ignore = self.connect.write().await.replace(new_connect); - // After connection is updated, notify to start the keep alive loop let _continue = self.event.notify(1); } @@ -178,21 +174,15 @@ impl ConnectApi for Reconnect { } /// Keep send lease keep alive to server and mutate the client id - async fn lease_keep_alive(&self, client_id: Arc, interval: Duration) -> CurpError { - loop { - let connect = self.connect.read().await; - let connect_ref = connect.as_ref().unwrap(); - tokio::select! { - err = connect_ref.lease_keep_alive(Arc::clone(&client_id), interval) => { - return err; - } - _empty = self.event.listen() => {}, - } - // Creates the listener before dropping the read lock. - // This prevents us from losting the event. - let listener = self.event.listen(); - drop(connect); - let _connection_updated = listener.await; - } + async fn lease_keep_alive(&self, client_id: u64, interval: Duration) -> Result { + let connect = self.connect.read().await; + let connect_ref = connect.as_ref().unwrap(); + let result = tokio::select! { + result = connect_ref.lease_keep_alive(client_id, interval) => result, + _empty = self.event.listen() => Err(CurpError::RpcTransport(())), + }; + // Wait for connection update + self.event.listen().await; + result } }