diff --git a/curp/src/client.rs b/curp/src/client.rs index 6f5640ab0..9b5080b1b 100644 --- a/curp/src/client.rs +++ b/curp/src/client.rs @@ -72,7 +72,7 @@ impl Builder { let Some(config) = self.config else { return Err(ClientBuildError::invalid_arguments("timeout is required")); }; - let connects = rpc::connect(all_members).await?.collect(); + let connects = rpc::connects(all_members).await?.collect(); let client = Client:: { local_server_id: self.local_server_id, state: RwLock::new(State::new(leader_id, 0)), @@ -136,7 +136,7 @@ impl Builder { state: RwLock::new(State::new(res.leader_id, res.term)), config, cluster_version: AtomicU64::new(res.cluster_version), - connects: rpc::connect(res.into_members_addrs()).await?.collect(), + connects: rpc::connects(res.into_members_addrs()).await?.collect(), phantom: PhantomData, }; Ok(client) @@ -448,7 +448,7 @@ where .map(|m| (m.id, m.addrs)) .collect::>>(); self.connects.clear(); - for (id, connect) in rpc::connect(member_addrs) + for (id, connect) in rpc::connects(member_addrs) .await .map_err(|e| ClientError::InternalError(format!("connect to cluster failed: {e}")))? { diff --git a/curp/src/members.rs b/curp/src/members.rs index 7c87cf61c..00dd1ad77 100644 --- a/curp/src/members.rs +++ b/curp/src/members.rs @@ -376,7 +376,7 @@ pub async fn get_cluster_info_from_remote( timeout: Duration, ) -> Option { let peers = init_cluster_info.peers_addrs(); - let connects = rpc::connect(peers) + let connects = rpc::connects(peers) .await .ok()? .map(|pair| pair.1) diff --git a/curp/src/rpc/connect.rs b/curp/src/rpc/connect.rs index 625928901..4e676a04d 100644 --- a/curp/src/rpc/connect.rs +++ b/curp/src/rpc/connect.rs @@ -11,7 +11,7 @@ use async_trait::async_trait; use bytes::BytesMut; use clippy_utilities::NumericCast; use engine::SnapshotApi; -use futures::Stream; +use futures::{stream::FuturesUnordered, Stream}; #[cfg(test)] use mockall::automock; use tokio::sync::Mutex; @@ -43,63 +43,100 @@ const SNAPSHOT_CHUNK_SIZE: u64 = 64 * 1024; /// The default buffer size for rpc connection const DEFAULT_BUFFER_SIZE: usize = 1024; -/// Connect implementation -macro_rules! connect_impl { - ($client:ty, $api:path, $members:ident) => { - futures::future::join_all( - $members - .into_iter() - .map(|(id, mut addrs)| single_connect_impl!($client, $api, id, addrs)), - ) - .await - .into_iter() - .collect::, tonic::transport::Error>>() - .map(IntoIterator::into_iter) - }; +/// For protocol client +trait FromTonicChannel { + /// New from channel + fn from_channel(channel: Channel) -> Self; } -/// Single Connect implementation -macro_rules! single_connect_impl { - ($client:ty, $api:path, $id:ident, $addrs:ident) => { - async move { - let (channel, change_tx) = Channel::balance_channel(DEFAULT_BUFFER_SIZE); - // Addrs must start with "http" to communicate with the server - for addr in &mut $addrs { - if !addr.starts_with("http://") { - addr.insert_str(0, "http://"); - } - let endpoint = Endpoint::from_shared(addr.clone())?; - let _ig = change_tx - .send(tower::discover::Change::Insert(addr.clone(), endpoint)) - .await; - } - let client = <$client>::new(channel); - let connect: Arc = Arc::new(Connect { - $id, - rpc_connect: client, - change_tx, - addrs: Mutex::new($addrs), - }); - Ok(($id, connect)) +impl FromTonicChannel for ProtocolClient { + fn from_channel(channel: Channel) -> Self { + ProtocolClient::new(channel) + } +} + +impl FromTonicChannel for InnerProtocolClient { + fn from_channel(channel: Channel) -> Self { + InnerProtocolClient::new(channel) + } +} + +/// Connect to a server +async fn connect_to( + id: ServerId, + mut addrs: Vec, +) -> Result>, tonic::transport::Error> { + let (channel, change_tx) = Channel::balance_channel(DEFAULT_BUFFER_SIZE); + for addr in &mut addrs { + // TODO: support TLS + if !addr.starts_with("http://") { + addr.insert_str(0, "http://"); } - }; + let endpoint = Endpoint::from_shared(addr.clone())?; + let _ig = change_tx + .send(tower::discover::Change::Insert(addr.clone(), endpoint)) + .await; + } + let client = Client::from_channel(channel); + let connect = Arc::new(Connect { + id, + rpc_connect: client, + change_tx, + addrs: Mutex::new(addrs), + }); + Ok(connect) } -/// Convert a vec of addr string to a vec of `Connect` -/// # Errors -/// Return error if any of the address format is invalid +/// Connect to a map of members +async fn connect_all( + members: HashMap>, +) -> Result>)>, tonic::transport::Error> { + let conns_to: FuturesUnordered<_> = + members + .into_iter() + .map(|(id, addrs)| async move { + connect_to::(id, addrs).await.map(|conn| (id, conn)) + }) + .collect(); + futures::StreamExt::collect::>(conns_to) + .await + .into_iter() + .collect::, _>>() +} + +/// A wrapper of [`connect_to`], hide the detailed [`Connect`] +#[allow(unused)] // TODO: will be used in curp client refactor pub(crate) async fn connect( + id: ServerId, + addrs: Vec, +) -> Result, tonic::transport::Error> { + let conn = connect_to::>(id, addrs).await?; + Ok(conn) +} + +/// Wrapper of [`connect_all`], hide the details of [`Connect`] +pub(crate) async fn connects( members: HashMap>, ) -> Result)>, tonic::transport::Error> { - connect_impl!(ProtocolClient, ConnectApi, members) + // It seems that casting high-rank types cannot be inferred, so we allow trivial_casts to cast manually + #[allow(trivial_casts)] + #[allow(clippy::as_conversions)] + let conns = connect_all(members) + .await? + .into_iter() + .map(|(id, conn)| (id, conn as Arc)); + Ok(conns) } -/// Convert a vec of addr string to a vec of `InnerConnect` -pub(crate) async fn inner_connect( +/// Wrapper of [`connect_all`], hide the details of [`Connect`] +pub(crate) async fn inner_connects( members: HashMap>, ) -> Result, tonic::transport::Error> { - connect_impl!(InnerProtocolClient, InnerConnectApi, members) - .map(|iter| iter.map(|(id, connect)| (id, InnerConnectApiWrapper::new_from_arc(connect)))) + let conns = connect_all(members) + .await? + .into_iter() + .map(|(id, conn)| (id, InnerConnectApiWrapper::new_from_arc(conn))); + Ok(conns) } /// Connect interface between server and clients @@ -209,11 +246,10 @@ impl InnerConnectApiWrapper { /// Create a new `InnerConnectApiWrapper` from id and addrs pub(crate) async fn connect( id: ServerId, - mut addrs: Vec, + addrs: Vec, ) -> Result { - let (_id, connect) = - single_connect_impl!(InnerProtocolClient, InnerConnectApi, id, addrs).await?; - Ok(InnerConnectApiWrapper::new_from_arc(connect)) + let conn = connect_to::>(id, addrs).await?; + Ok(InnerConnectApiWrapper::new_from_arc(conn)) } } diff --git a/curp/src/rpc/mod.rs b/curp/src/rpc/mod.rs index 041a91a07..e1730e3a4 100644 --- a/curp/src/rpc/mod.rs +++ b/curp/src/rpc/mod.rs @@ -39,7 +39,7 @@ use crate::{ /// Rpc connect pub(crate) mod connect; -pub(crate) use connect::{connect, inner_connect}; +pub(crate) use connect::{connects, inner_connects}; // Skip for generated code #[allow( diff --git a/curp/src/server/curp_node.rs b/curp/src/server/curp_node.rs index 29f9878af..a35dda6f8 100644 --- a/curp/src/server/curp_node.rs +++ b/curp/src/server/curp_node.rs @@ -785,7 +785,7 @@ impl CurpNode { .into_iter() .map(|server_id| (server_id, Arc::new(Event::new()))) .collect(); - let connects = rpc::inner_connect(cluster_info.peers_addrs()) + let connects = rpc::inner_connects(cluster_info.peers_addrs()) .await .map_err(|e| CurpError::Internal(format!("parse peers addresses failed, err {e:?}")))? .collect();