Skip to content

Commit

Permalink
refactor: refactor connect from macros to function
Browse files Browse the repository at this point in the history
Signed-off-by: iGxnon <[email protected]>
  • Loading branch information
iGxnon committed Nov 24, 2023
1 parent c8aa843 commit 2e3869f
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 56 deletions.
6 changes: 3 additions & 3 deletions curp/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ impl<C: Command> Builder<C> {
let Some(config) = self.config else {
return Err(ClientBuildError::invalid_arguments("timeout is required"));
};
let connects = rpc::connect(all_members).await?.collect();
let connects = rpc::connects(all_members).await?.collect();
let client = Client::<C> {
local_server_id: self.local_server_id,
state: RwLock::new(State::new(leader_id, 0)),
Expand Down Expand Up @@ -136,7 +136,7 @@ impl<C: Command> Builder<C> {
state: RwLock::new(State::new(res.leader_id, res.term)),
config,
cluster_version: AtomicU64::new(res.cluster_version),
connects: rpc::connect(res.into_members_addrs()).await?.collect(),
connects: rpc::connects(res.into_members_addrs()).await?.collect(),
phantom: PhantomData,
};
Ok(client)
Expand Down Expand Up @@ -448,7 +448,7 @@ where
.map(|m| (m.id, m.addrs))
.collect::<HashMap<ServerId, Vec<String>>>();
self.connects.clear();
for (id, connect) in rpc::connect(member_addrs)
for (id, connect) in rpc::connects(member_addrs)
.await
.map_err(|e| ClientError::InternalError(format!("connect to cluster failed: {e}")))?
{
Expand Down
2 changes: 1 addition & 1 deletion curp/src/members.rs
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,7 @@ pub async fn get_cluster_info_from_remote(
timeout: Duration,
) -> Option<ClusterInfo> {
let peers = init_cluster_info.peers_addrs();
let connects = rpc::connect(peers)
let connects = rpc::connects(peers)
.await
.ok()?
.map(|pair| pair.1)
Expand Down
136 changes: 86 additions & 50 deletions curp/src/rpc/connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use async_trait::async_trait;
use bytes::BytesMut;
use clippy_utilities::NumericCast;
use engine::SnapshotApi;
use futures::Stream;
use futures::{stream::FuturesUnordered, Stream};
#[cfg(test)]
use mockall::automock;
use tokio::sync::Mutex;
Expand Down Expand Up @@ -43,63 +43,100 @@ const SNAPSHOT_CHUNK_SIZE: u64 = 64 * 1024;
/// The default buffer size for rpc connection
const DEFAULT_BUFFER_SIZE: usize = 1024;

/// Connect implementation
macro_rules! connect_impl {
($client:ty, $api:path, $members:ident) => {
futures::future::join_all(
$members
.into_iter()
.map(|(id, mut addrs)| single_connect_impl!($client, $api, id, addrs)),
)
.await
.into_iter()
.collect::<Result<Vec<_>, tonic::transport::Error>>()
.map(IntoIterator::into_iter)
};
/// For protocol client
trait FromTonicChannel {
/// New from channel
fn from_channel(channel: Channel) -> Self;
}

/// Single Connect implementation
macro_rules! single_connect_impl {
($client:ty, $api:path, $id:ident, $addrs:ident) => {
async move {
let (channel, change_tx) = Channel::balance_channel(DEFAULT_BUFFER_SIZE);
// Addrs must start with "http" to communicate with the server
for addr in &mut $addrs {
if !addr.starts_with("http://") {
addr.insert_str(0, "http://");
}
let endpoint = Endpoint::from_shared(addr.clone())?;
let _ig = change_tx
.send(tower::discover::Change::Insert(addr.clone(), endpoint))
.await;
}
let client = <$client>::new(channel);
let connect: Arc<dyn $api> = Arc::new(Connect {
$id,
rpc_connect: client,
change_tx,
addrs: Mutex::new($addrs),
});
Ok(($id, connect))
impl FromTonicChannel for ProtocolClient<Channel> {
fn from_channel(channel: Channel) -> Self {
ProtocolClient::new(channel)
}
}

impl FromTonicChannel for InnerProtocolClient<Channel> {
fn from_channel(channel: Channel) -> Self {
InnerProtocolClient::new(channel)
}
}

/// Connect to a server
async fn connect_to<Client: FromTonicChannel>(
id: ServerId,
mut addrs: Vec<String>,
) -> Result<Arc<Connect<Client>>, tonic::transport::Error> {
let (channel, change_tx) = Channel::balance_channel(DEFAULT_BUFFER_SIZE);
for addr in &mut addrs {
// TODO: support TLS
if !addr.starts_with("http://") {
addr.insert_str(0, "http://");
}
};
let endpoint = Endpoint::from_shared(addr.clone())?;
let _ig = change_tx
.send(tower::discover::Change::Insert(addr.clone(), endpoint))
.await;
}
let client = Client::from_channel(channel);
let connect = Arc::new(Connect {
id,
rpc_connect: client,
change_tx,
addrs: Mutex::new(addrs),
});
Ok(connect)
}

/// Convert a vec of addr string to a vec of `Connect`
/// # Errors
/// Return error if any of the address format is invalid
/// Connect to a map of members
async fn connect_all<Client: FromTonicChannel>(
members: HashMap<ServerId, Vec<String>>,
) -> Result<Vec<(u64, Arc<Connect<Client>>)>, tonic::transport::Error> {
let conns_to: FuturesUnordered<_> =
members
.into_iter()
.map(|(id, addrs)| async move {
connect_to::<Client>(id, addrs).await.map(|conn| (id, conn))
})
.collect();
futures::StreamExt::collect::<Vec<_>>(conns_to)
.await
.into_iter()
.collect::<Result<Vec<_>, _>>()
}

/// A wrapper of [`connect_to`], hide the detailed [`Connect<ProtocolClient>`]
#[allow(unused)] // TODO: will be used in curp client refactor
pub(crate) async fn connect(
id: ServerId,
addrs: Vec<String>,
) -> Result<Arc<dyn ConnectApi>, tonic::transport::Error> {
let conn = connect_to::<ProtocolClient<Channel>>(id, addrs).await?;
Ok(conn)
}

/// Wrapper of [`connect_all`], hide the details of [`Connect<ProtocolClient>`]
pub(crate) async fn connects(
members: HashMap<ServerId, Vec<String>>,
) -> Result<impl Iterator<Item = (ServerId, Arc<dyn ConnectApi>)>, tonic::transport::Error> {
connect_impl!(ProtocolClient<Channel>, ConnectApi, members)
// It seems that casting high-rank types cannot be inferred, so we allow trivial_casts to cast manually
#[allow(trivial_casts)]
#[allow(clippy::as_conversions)]
let conns = connect_all(members)
.await?
.into_iter()
.map(|(id, conn)| (id, conn as Arc<dyn ConnectApi>));
Ok(conns)
}

/// Convert a vec of addr string to a vec of `InnerConnect`
pub(crate) async fn inner_connect(
/// Wrapper of [`connect_all`], hide the details of [`Connect<InnerProtocolClient>`]
pub(crate) async fn inner_connects(
members: HashMap<ServerId, Vec<String>>,
) -> Result<impl Iterator<Item = (ServerId, InnerConnectApiWrapper)>, tonic::transport::Error> {
connect_impl!(InnerProtocolClient<Channel>, InnerConnectApi, members)
.map(|iter| iter.map(|(id, connect)| (id, InnerConnectApiWrapper::new_from_arc(connect))))
let conns = connect_all(members)
.await?
.into_iter()
.map(|(id, conn)| (id, InnerConnectApiWrapper::new_from_arc(conn)));
Ok(conns)
}

/// Connect interface between server and clients
Expand Down Expand Up @@ -209,11 +246,10 @@ impl InnerConnectApiWrapper {
/// Create a new `InnerConnectApiWrapper` from id and addrs
pub(crate) async fn connect(
id: ServerId,
mut addrs: Vec<String>,
addrs: Vec<String>,
) -> Result<Self, tonic::transport::Error> {
let (_id, connect) =
single_connect_impl!(InnerProtocolClient<Channel>, InnerConnectApi, id, addrs).await?;
Ok(InnerConnectApiWrapper::new_from_arc(connect))
let conn = connect_to::<InnerProtocolClient<Channel>>(id, addrs).await?;
Ok(InnerConnectApiWrapper::new_from_arc(conn))
}
}

Expand Down
2 changes: 1 addition & 1 deletion curp/src/rpc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ use crate::{

/// Rpc connect
pub(crate) mod connect;
pub(crate) use connect::{connect, inner_connect};
pub(crate) use connect::{connects, inner_connects};

// Skip for generated code
#[allow(
Expand Down
2 changes: 1 addition & 1 deletion curp/src/server/curp_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -785,7 +785,7 @@ impl<C: 'static + Command, RC: RoleChange + 'static> CurpNode<C, RC> {
.into_iter()
.map(|server_id| (server_id, Arc::new(Event::new())))
.collect();
let connects = rpc::inner_connect(cluster_info.peers_addrs())
let connects = rpc::inner_connects(cluster_info.peers_addrs())
.await
.map_err(|e| CurpError::Internal(format!("parse peers addresses failed, err {e:?}")))?
.collect();
Expand Down

0 comments on commit 2e3869f

Please sign in to comment.