From 918c25a82b77bf55a1afde0e4a991faeb8757ecc Mon Sep 17 00:00:00 2001 From: cirias Date: Wed, 5 Jun 2024 09:56:52 +0800 Subject: [PATCH] feat: Replace unbounded channel with bounded (#312) Co-authored-by: Sirius Fang --- Cargo.toml | 1 + src/client.rs | 13 ++++++++++ src/connection.rs | 54 +++++++++++++++++++++++++-------------- src/connection_manager.rs | 4 +++ src/consumer/engine.rs | 5 ++++ src/error.rs | 12 +++++++++ 6 files changed, 70 insertions(+), 19 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index d4927f9e..1fd57f72 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,7 @@ description = "Rust client for Apache Pulsar" keywords = ["pulsar", "api", "client"] [dependencies] +async-channel = "2" bytes = "^1.4.0" crc = "^3.0.1" nom = { version="^7.1.3", default-features=false, features=["alloc"] } diff --git a/src/client.rs b/src/client.rs index e4fd0bec..45c86bd0 100644 --- a/src/client.rs +++ b/src/client.rs @@ -179,17 +179,20 @@ impl Pulsar { connection_retry_parameters: Option, operation_retry_parameters: Option, tls_options: Option, + outbound_channel_size: Option, executor: Exe, ) -> Result { let url: String = url.into(); let executor = Arc::new(executor); let operation_retry_options = operation_retry_parameters.unwrap_or_default(); + let outbound_channel_size = outbound_channel_size.unwrap_or(100); let manager = ConnectionManager::new( url, auth, connection_retry_parameters, operation_retry_options.clone(), tls_options, + outbound_channel_size, executor.clone(), ) .await?; @@ -252,6 +255,7 @@ impl Pulsar { connection_retry_options: None, operation_retry_options: None, tls_options: None, + outbound_channel_size: None, executor, } } @@ -452,6 +456,7 @@ pub struct PulsarBuilder { connection_retry_options: Option, operation_retry_options: Option, tls_options: Option, + outbound_channel_size: Option, executor: Exe, } @@ -549,6 +554,12 @@ impl PulsarBuilder { Ok(self.with_certificate_chain(v)) } + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + pub fn with_outbound_channel_size(mut self, size: usize) -> Result { + self.outbound_channel_size = Some(size); + Ok(self) + } + /// creates the Pulsar client and connects it #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] pub async fn build(self) -> Result, Error> { @@ -558,6 +569,7 @@ impl PulsarBuilder { connection_retry_options, operation_retry_options, tls_options, + outbound_channel_size, executor, } = self; @@ -567,6 +579,7 @@ impl PulsarBuilder { connection_retry_options, operation_retry_options, tls_options, + outbound_channel_size, executor, ) .await diff --git a/src/connection.rs b/src/connection.rs index ecc3148e..13eaf085 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -99,7 +99,7 @@ impl crate::authentication::Authentication for Authentication { pub(crate) struct Receiver>> { inbound: Pin>, - outbound: mpsc::UnboundedSender, + outbound: async_channel::Sender, error: SharedError, pending_requests: BTreeMap>, consumers: BTreeMap>, @@ -114,7 +114,7 @@ impl>> Receiver { #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] pub fn new( inbound: S, - outbound: mpsc::UnboundedSender, + outbound: async_channel::Sender, error: SharedError, registrations: mpsc::UnboundedReceiver, shutdown: oneshot::Receiver<()>, @@ -187,7 +187,9 @@ impl>> Future for Receiver command: BaseCommand { ping: Some(_), .. }, .. } => { - let _ = self.outbound.unbounded_send(messages::pong()); + if let Err(e) = self.outbound.try_send(messages::pong()) { + error!("failed to send pong: {}", e); + } } Message { command: BaseCommand { pong: Some(_), .. }, @@ -289,7 +291,7 @@ impl SerialId { //#[derive(Clone)] pub struct ConnectionSender { connection_id: Uuid, - tx: mpsc::UnboundedSender, + tx: async_channel::Sender, registrations: mpsc::UnboundedSender, receiver_shutdown: Option>, request_id: SerialId, @@ -302,7 +304,7 @@ impl ConnectionSender { #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] pub(crate) fn new( connection_id: Uuid, - tx: mpsc::UnboundedSender, + tx: async_channel::Sender, registrations: mpsc::UnboundedSender, receiver_shutdown: oneshot::Sender<()>, request_id: SerialId, @@ -349,9 +351,9 @@ impl ConnectionSender { match ( self.registrations .unbounded_send(Register::Ping { resolver }), - self.tx.unbounded_send(messages::ping()), + self.tx.try_send(messages::ping())?, ) { - (Ok(_), Ok(_)) => { + (Ok(_), ()) => { let delay_f = self.executor.delay(self.operation_timeout); pin_mut!(response); pin_mut!(delay_f); @@ -526,8 +528,8 @@ impl ConnectionSender { #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] pub fn send_flow(&self, consumer_id: u64, message_permits: u32) -> Result<(), ConnectionError> { self.tx - .unbounded_send(messages::flow(consumer_id, message_permits)) - .map_err(|_| ConnectionError::Disconnected) + .try_send(messages::flow(consumer_id, message_permits))?; + Ok(()) } #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] @@ -538,8 +540,8 @@ impl ConnectionSender { cumulative: bool, ) -> Result<(), ConnectionError> { self.tx - .unbounded_send(messages::ack(consumer_id, message_ids, cumulative)) - .map_err(|_| ConnectionError::Disconnected) + .try_send(messages::ack(consumer_id, message_ids, cumulative))?; + Ok(()) } #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] @@ -549,11 +551,11 @@ impl ConnectionSender { message_ids: Vec, ) -> Result<(), ConnectionError> { self.tx - .unbounded_send(messages::redeliver_unacknowleged_messages( + .try_send(messages::redeliver_unacknowleged_messages( consumer_id, message_ids, - )) - .map_err(|_| ConnectionError::Disconnected) + ))?; + Ok(()) } #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] @@ -661,7 +663,7 @@ impl ConnectionSender { match ( self.registrations .unbounded_send(Register::Request { key, resolver }), - self.tx.unbounded_send(msg), + self.tx.try_send(msg), ) { (Ok(_), Ok(_)) => { let connection_id = self.connection_id; @@ -700,6 +702,7 @@ impl ConnectionSender { Ok(fut) } + (_, Err(e)) if e.is_full() => Err(ConnectionError::SlowDown), _ => { warn!( "connection {} disconnected sending message to the Pulsar server", @@ -781,6 +784,7 @@ impl Connection { tls_hostname_verification_enabled: bool, connection_timeout: Duration, operation_timeout: Duration, + outbound_channel_size: usize, executor: Arc, ) -> Result, ConnectionError> { if url.scheme() != "pulsar" && url.scheme() != "pulsar+ssl" { @@ -839,6 +843,7 @@ impl Connection { tls_hostname_verification_enabled, executor.clone(), operation_timeout, + outbound_channel_size, ); let delay_f = executor.delay(connection_timeout); @@ -916,6 +921,7 @@ impl Connection { tls_hostname_verification_enabled: bool, executor: Arc, operation_timeout: Duration, + outbound_channel_size: usize, ) -> Result, ConnectionError> { match executor.kind() { #[cfg(feature = "tokio-runtime")] @@ -945,6 +951,7 @@ impl Connection { proxy_to_broker_url, executor, operation_timeout, + outbound_channel_size, ) .await } else { @@ -959,6 +966,7 @@ impl Connection { proxy_to_broker_url, executor, operation_timeout, + outbound_channel_size, ) .await } @@ -1007,6 +1015,7 @@ impl Connection { proxy_to_broker_url, executor, operation_timeout, + outbound_channel_size, ) .await } else { @@ -1021,6 +1030,7 @@ impl Connection { proxy_to_broker_url, executor, operation_timeout, + outbound_channel_size, ) .await } @@ -1053,6 +1063,7 @@ impl Connection { proxy_to_broker_url, executor, operation_timeout, + outbound_channel_size, ) .await } else { @@ -1067,6 +1078,7 @@ impl Connection { proxy_to_broker_url, executor, operation_timeout, + outbound_channel_size, ) .await } @@ -1119,6 +1131,7 @@ impl Connection { proxy_to_broker_url, executor, operation_timeout, + outbound_channel_size, ) .await } else { @@ -1133,6 +1146,7 @@ impl Connection { proxy_to_broker_url, executor, operation_timeout, + outbound_channel_size, ) .await } @@ -1155,6 +1169,7 @@ impl Connection { proxy_to_broker_url: Option, executor: Arc, operation_timeout: Duration, + outbound_channel_size: usize, ) -> Result, ConnectionError> where S: Stream>, @@ -1194,7 +1209,7 @@ impl Connection { }?; let (mut sink, stream) = stream.split(); - let (tx, mut rx) = mpsc::unbounded(); + let (tx, rx) = async_channel::bounded(outbound_channel_size); let (registrations_tx, registrations_rx) = mpsc::unbounded(); let error = SharedError::new(); let (receiver_shutdown_tx, receiver_shutdown_rx) = oneshot::channel(); @@ -1220,7 +1235,7 @@ impl Connection { let err = error.clone(); let res = executor.spawn(Box::pin(async move { - while let Some(msg) = rx.next().await { + while let Ok(msg) = rx.recv().await { // println!("real sent msg: {:?}", msg); if let Err(e) = sink.send(msg).await { err.set(e); @@ -1236,7 +1251,7 @@ impl Connection { if auth.is_some() { let auth_challenge_res = executor.spawn({ let err = error.clone(); - let mut tx = tx.clone(); + let tx = tx.clone(); let auth = auth.clone(); Box::pin(async move { while auth_challenge_rx.next().await.is_some() { @@ -1796,7 +1811,7 @@ mod tests { #[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))] async fn receiver_auth_challenge_test() { let (message_tx, message_rx) = mpsc::unbounded(); - let (tx, _) = mpsc::unbounded(); + let (tx, _) = async_channel::bounded(10); let (_registrations_tx, registrations_rx) = mpsc::unbounded(); let error = SharedError::new(); let (_receiver_shutdown_tx, receiver_shutdown_rx) = oneshot::channel(); @@ -1904,6 +1919,7 @@ mod tests { None, TokioExecutor.into(), Duration::from_secs(10), + 100, ) .await; diff --git a/src/connection_manager.rs b/src/connection_manager.rs index 43909473..12ab4895 100644 --- a/src/connection_manager.rs +++ b/src/connection_manager.rs @@ -123,6 +123,7 @@ pub struct ConnectionManager { pub(crate) operation_retry_options: OperationRetryOptions, tls_options: TlsOptions, certificate_chain: Vec, + outbound_channel_size: usize, } impl ConnectionManager { @@ -133,6 +134,7 @@ impl ConnectionManager { connection_retry: Option, operation_retry_options: OperationRetryOptions, tls: Option, + outbound_channel_size: usize, executor: Arc, ) -> Result { let connection_retry_options = connection_retry.unwrap_or_default(); @@ -191,6 +193,7 @@ impl ConnectionManager { operation_retry_options, tls_options, certificate_chain, + outbound_channel_size, }; let broker_address = BrokerAddress { url: url.clone(), @@ -312,6 +315,7 @@ impl ConnectionManager { self.tls_options.tls_hostname_verification_enabled, self.connection_retry_options.connection_timeout, self.operation_retry_options.operation_timeout, + self.outbound_channel_size, self.executor.clone(), ) .await diff --git a/src/consumer/engine.rs b/src/consumer/engine.rs index 08eb7a4e..5da88dd0 100644 --- a/src/consumer/engine.rs +++ b/src/consumer/engine.rs @@ -159,6 +159,11 @@ impl ConsumerEngine { .sender() .send_flow(self.id, self.batch_size - self.remaining_messages)?; } + Err(ConnectionError::SlowDown) => { + self.connection + .sender() + .send_flow(self.id, self.batch_size - self.remaining_messages)?; + } Err(e) => return Err(e.into()), } self.remaining_messages = self.batch_size; diff --git a/src/error.rs b/src/error.rs index 182dbf50..b9df7e9c 100644 --- a/src/error.rs +++ b/src/error.rs @@ -81,6 +81,7 @@ impl std::error::Error for Error { #[derive(Debug)] pub enum ConnectionError { Io(io::Error), + SlowDown, Disconnected, PulsarError(Option, Option), Unexpected(String), @@ -155,11 +156,22 @@ impl From for ConnectionError { } } +impl From> for ConnectionError { + #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] + fn from(err: async_channel::TrySendError) -> Self { + match err { + async_channel::TrySendError::Full(_) => ConnectionError::SlowDown, + async_channel::TrySendError::Closed(_) => ConnectionError::Disconnected, + } + } +} + impl fmt::Display for ConnectionError { #[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))] fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { ConnectionError::Io(e) => write!(f, "{e}"), + ConnectionError::SlowDown => write!(f, "SlowDown"), ConnectionError::Disconnected => write!(f, "Disconnected"), ConnectionError::PulsarError(e, s) => { write!(f, "Server error ({:?}): {}", e, s.as_deref().unwrap_or(""))