Skip to content

Commit

Permalink
feat: Replace unbounded channel with bounded (#312)
Browse files Browse the repository at this point in the history
Co-authored-by: Sirius Fang <[email protected]>
  • Loading branch information
cirias and Sirius Fang authored Jun 5, 2024
1 parent ed4914a commit 918c25a
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 19 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ description = "Rust client for Apache Pulsar"
keywords = ["pulsar", "api", "client"]

[dependencies]
async-channel = "2"
bytes = "^1.4.0"
crc = "^3.0.1"
nom = { version="^7.1.3", default-features=false, features=["alloc"] }
Expand Down
13 changes: 13 additions & 0 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,17 +179,20 @@ impl<Exe: Executor> Pulsar<Exe> {
connection_retry_parameters: Option<ConnectionRetryOptions>,
operation_retry_parameters: Option<OperationRetryOptions>,
tls_options: Option<TlsOptions>,
outbound_channel_size: Option<usize>,
executor: Exe,
) -> Result<Self, Error> {
let url: String = url.into();
let executor = Arc::new(executor);
let operation_retry_options = operation_retry_parameters.unwrap_or_default();
let outbound_channel_size = outbound_channel_size.unwrap_or(100);
let manager = ConnectionManager::new(
url,
auth,
connection_retry_parameters,
operation_retry_options.clone(),
tls_options,
outbound_channel_size,
executor.clone(),
)
.await?;
Expand Down Expand Up @@ -252,6 +255,7 @@ impl<Exe: Executor> Pulsar<Exe> {
connection_retry_options: None,
operation_retry_options: None,
tls_options: None,
outbound_channel_size: None,
executor,
}
}
Expand Down Expand Up @@ -452,6 +456,7 @@ pub struct PulsarBuilder<Exe: Executor> {
connection_retry_options: Option<ConnectionRetryOptions>,
operation_retry_options: Option<OperationRetryOptions>,
tls_options: Option<TlsOptions>,
outbound_channel_size: Option<usize>,
executor: Exe,
}

Expand Down Expand Up @@ -549,6 +554,12 @@ impl<Exe: Executor> PulsarBuilder<Exe> {
Ok(self.with_certificate_chain(v))
}

#[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
pub fn with_outbound_channel_size(mut self, size: usize) -> Result<Self, std::io::Error> {
self.outbound_channel_size = Some(size);
Ok(self)
}

/// creates the Pulsar client and connects it
#[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
pub async fn build(self) -> Result<Pulsar<Exe>, Error> {
Expand All @@ -558,6 +569,7 @@ impl<Exe: Executor> PulsarBuilder<Exe> {
connection_retry_options,
operation_retry_options,
tls_options,
outbound_channel_size,
executor,
} = self;

Expand All @@ -567,6 +579,7 @@ impl<Exe: Executor> PulsarBuilder<Exe> {
connection_retry_options,
operation_retry_options,
tls_options,
outbound_channel_size,
executor,
)
.await
Expand Down
54 changes: 35 additions & 19 deletions src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ impl crate::authentication::Authentication for Authentication {

pub(crate) struct Receiver<S: Stream<Item = Result<Message, ConnectionError>>> {
inbound: Pin<Box<S>>,
outbound: mpsc::UnboundedSender<Message>,
outbound: async_channel::Sender<Message>,
error: SharedError,
pending_requests: BTreeMap<RequestKey, oneshot::Sender<Message>>,
consumers: BTreeMap<u64, mpsc::UnboundedSender<Message>>,
Expand All @@ -114,7 +114,7 @@ impl<S: Stream<Item = Result<Message, ConnectionError>>> Receiver<S> {
#[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
pub fn new(
inbound: S,
outbound: mpsc::UnboundedSender<Message>,
outbound: async_channel::Sender<Message>,
error: SharedError,
registrations: mpsc::UnboundedReceiver<Register>,
shutdown: oneshot::Receiver<()>,
Expand Down Expand Up @@ -187,7 +187,9 @@ impl<S: Stream<Item = Result<Message, ConnectionError>>> Future for Receiver<S>
command: BaseCommand { ping: Some(_), .. },
..
} => {
let _ = self.outbound.unbounded_send(messages::pong());
if let Err(e) = self.outbound.try_send(messages::pong()) {
error!("failed to send pong: {}", e);
}
}
Message {
command: BaseCommand { pong: Some(_), .. },
Expand Down Expand Up @@ -289,7 +291,7 @@ impl SerialId {
//#[derive(Clone)]
pub struct ConnectionSender<Exe: Executor> {
connection_id: Uuid,
tx: mpsc::UnboundedSender<Message>,
tx: async_channel::Sender<Message>,
registrations: mpsc::UnboundedSender<Register>,
receiver_shutdown: Option<oneshot::Sender<()>>,
request_id: SerialId,
Expand All @@ -302,7 +304,7 @@ impl<Exe: Executor> ConnectionSender<Exe> {
#[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
pub(crate) fn new(
connection_id: Uuid,
tx: mpsc::UnboundedSender<Message>,
tx: async_channel::Sender<Message>,
registrations: mpsc::UnboundedSender<Register>,
receiver_shutdown: oneshot::Sender<()>,
request_id: SerialId,
Expand Down Expand Up @@ -349,9 +351,9 @@ impl<Exe: Executor> ConnectionSender<Exe> {
match (
self.registrations
.unbounded_send(Register::Ping { resolver }),
self.tx.unbounded_send(messages::ping()),
self.tx.try_send(messages::ping())?,
) {
(Ok(_), Ok(_)) => {
(Ok(_), ()) => {
let delay_f = self.executor.delay(self.operation_timeout);
pin_mut!(response);
pin_mut!(delay_f);
Expand Down Expand Up @@ -526,8 +528,8 @@ impl<Exe: Executor> ConnectionSender<Exe> {
#[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
pub fn send_flow(&self, consumer_id: u64, message_permits: u32) -> Result<(), ConnectionError> {
self.tx
.unbounded_send(messages::flow(consumer_id, message_permits))
.map_err(|_| ConnectionError::Disconnected)
.try_send(messages::flow(consumer_id, message_permits))?;
Ok(())
}

#[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
Expand All @@ -538,8 +540,8 @@ impl<Exe: Executor> ConnectionSender<Exe> {
cumulative: bool,
) -> Result<(), ConnectionError> {
self.tx
.unbounded_send(messages::ack(consumer_id, message_ids, cumulative))
.map_err(|_| ConnectionError::Disconnected)
.try_send(messages::ack(consumer_id, message_ids, cumulative))?;
Ok(())
}

#[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
Expand All @@ -549,11 +551,11 @@ impl<Exe: Executor> ConnectionSender<Exe> {
message_ids: Vec<proto::MessageIdData>,
) -> Result<(), ConnectionError> {
self.tx
.unbounded_send(messages::redeliver_unacknowleged_messages(
.try_send(messages::redeliver_unacknowleged_messages(
consumer_id,
message_ids,
))
.map_err(|_| ConnectionError::Disconnected)
))?;
Ok(())
}

#[cfg_attr(feature = "telemetry", tracing::instrument(skip_all))]
Expand Down Expand Up @@ -661,7 +663,7 @@ impl<Exe: Executor> ConnectionSender<Exe> {
match (
self.registrations
.unbounded_send(Register::Request { key, resolver }),
self.tx.unbounded_send(msg),
self.tx.try_send(msg),
) {
(Ok(_), Ok(_)) => {
let connection_id = self.connection_id;
Expand Down Expand Up @@ -700,6 +702,7 @@ impl<Exe: Executor> ConnectionSender<Exe> {

Ok(fut)
}
(_, Err(e)) if e.is_full() => Err(ConnectionError::SlowDown),
_ => {
warn!(
"connection {} disconnected sending message to the Pulsar server",
Expand Down Expand Up @@ -781,6 +784,7 @@ impl<Exe: Executor> Connection<Exe> {
tls_hostname_verification_enabled: bool,
connection_timeout: Duration,
operation_timeout: Duration,
outbound_channel_size: usize,
executor: Arc<Exe>,
) -> Result<Connection<Exe>, ConnectionError> {
if url.scheme() != "pulsar" && url.scheme() != "pulsar+ssl" {
Expand Down Expand Up @@ -839,6 +843,7 @@ impl<Exe: Executor> Connection<Exe> {
tls_hostname_verification_enabled,
executor.clone(),
operation_timeout,
outbound_channel_size,
);
let delay_f = executor.delay(connection_timeout);

Expand Down Expand Up @@ -916,6 +921,7 @@ impl<Exe: Executor> Connection<Exe> {
tls_hostname_verification_enabled: bool,
executor: Arc<Exe>,
operation_timeout: Duration,
outbound_channel_size: usize,
) -> Result<ConnectionSender<Exe>, ConnectionError> {
match executor.kind() {
#[cfg(feature = "tokio-runtime")]
Expand Down Expand Up @@ -945,6 +951,7 @@ impl<Exe: Executor> Connection<Exe> {
proxy_to_broker_url,
executor,
operation_timeout,
outbound_channel_size,
)
.await
} else {
Expand All @@ -959,6 +966,7 @@ impl<Exe: Executor> Connection<Exe> {
proxy_to_broker_url,
executor,
operation_timeout,
outbound_channel_size,
)
.await
}
Expand Down Expand Up @@ -1007,6 +1015,7 @@ impl<Exe: Executor> Connection<Exe> {
proxy_to_broker_url,
executor,
operation_timeout,
outbound_channel_size,
)
.await
} else {
Expand All @@ -1021,6 +1030,7 @@ impl<Exe: Executor> Connection<Exe> {
proxy_to_broker_url,
executor,
operation_timeout,
outbound_channel_size,
)
.await
}
Expand Down Expand Up @@ -1053,6 +1063,7 @@ impl<Exe: Executor> Connection<Exe> {
proxy_to_broker_url,
executor,
operation_timeout,
outbound_channel_size,
)
.await
} else {
Expand All @@ -1067,6 +1078,7 @@ impl<Exe: Executor> Connection<Exe> {
proxy_to_broker_url,
executor,
operation_timeout,
outbound_channel_size,
)
.await
}
Expand Down Expand Up @@ -1119,6 +1131,7 @@ impl<Exe: Executor> Connection<Exe> {
proxy_to_broker_url,
executor,
operation_timeout,
outbound_channel_size,
)
.await
} else {
Expand All @@ -1133,6 +1146,7 @@ impl<Exe: Executor> Connection<Exe> {
proxy_to_broker_url,
executor,
operation_timeout,
outbound_channel_size,
)
.await
}
Expand All @@ -1155,6 +1169,7 @@ impl<Exe: Executor> Connection<Exe> {
proxy_to_broker_url: Option<String>,
executor: Arc<Exe>,
operation_timeout: Duration,
outbound_channel_size: usize,
) -> Result<ConnectionSender<Exe>, ConnectionError>
where
S: Stream<Item = Result<Message, ConnectionError>>,
Expand Down Expand Up @@ -1194,7 +1209,7 @@ impl<Exe: Executor> Connection<Exe> {
}?;

let (mut sink, stream) = stream.split();
let (tx, mut rx) = mpsc::unbounded();
let (tx, rx) = async_channel::bounded(outbound_channel_size);
let (registrations_tx, registrations_rx) = mpsc::unbounded();
let error = SharedError::new();
let (receiver_shutdown_tx, receiver_shutdown_rx) = oneshot::channel();
Expand All @@ -1220,7 +1235,7 @@ impl<Exe: Executor> Connection<Exe> {

let err = error.clone();
let res = executor.spawn(Box::pin(async move {
while let Some(msg) = rx.next().await {
while let Ok(msg) = rx.recv().await {
// println!("real sent msg: {:?}", msg);
if let Err(e) = sink.send(msg).await {
err.set(e);
Expand All @@ -1236,7 +1251,7 @@ impl<Exe: Executor> Connection<Exe> {
if auth.is_some() {
let auth_challenge_res = executor.spawn({
let err = error.clone();
let mut tx = tx.clone();
let tx = tx.clone();
let auth = auth.clone();
Box::pin(async move {
while auth_challenge_rx.next().await.is_some() {
Expand Down Expand Up @@ -1796,7 +1811,7 @@ mod tests {
#[cfg(any(feature = "tokio-runtime", feature = "tokio-rustls-runtime"))]
async fn receiver_auth_challenge_test() {
let (message_tx, message_rx) = mpsc::unbounded();
let (tx, _) = mpsc::unbounded();
let (tx, _) = async_channel::bounded(10);
let (_registrations_tx, registrations_rx) = mpsc::unbounded();
let error = SharedError::new();
let (_receiver_shutdown_tx, receiver_shutdown_rx) = oneshot::channel();
Expand Down Expand Up @@ -1904,6 +1919,7 @@ mod tests {
None,
TokioExecutor.into(),
Duration::from_secs(10),
100,
)
.await;

Expand Down
4 changes: 4 additions & 0 deletions src/connection_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ pub struct ConnectionManager<Exe: Executor> {
pub(crate) operation_retry_options: OperationRetryOptions,
tls_options: TlsOptions,
certificate_chain: Vec<Certificate>,
outbound_channel_size: usize,
}

impl<Exe: Executor> ConnectionManager<Exe> {
Expand All @@ -133,6 +134,7 @@ impl<Exe: Executor> ConnectionManager<Exe> {
connection_retry: Option<ConnectionRetryOptions>,
operation_retry_options: OperationRetryOptions,
tls: Option<TlsOptions>,
outbound_channel_size: usize,
executor: Arc<Exe>,
) -> Result<Self, ConnectionError> {
let connection_retry_options = connection_retry.unwrap_or_default();
Expand Down Expand Up @@ -191,6 +193,7 @@ impl<Exe: Executor> ConnectionManager<Exe> {
operation_retry_options,
tls_options,
certificate_chain,
outbound_channel_size,
};
let broker_address = BrokerAddress {
url: url.clone(),
Expand Down Expand Up @@ -312,6 +315,7 @@ impl<Exe: Executor> ConnectionManager<Exe> {
self.tls_options.tls_hostname_verification_enabled,
self.connection_retry_options.connection_timeout,
self.operation_retry_options.operation_timeout,
self.outbound_channel_size,
self.executor.clone(),
)
.await
Expand Down
5 changes: 5 additions & 0 deletions src/consumer/engine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,11 @@ impl<Exe: Executor> ConsumerEngine<Exe> {
.sender()
.send_flow(self.id, self.batch_size - self.remaining_messages)?;
}
Err(ConnectionError::SlowDown) => {
self.connection
.sender()
.send_flow(self.id, self.batch_size - self.remaining_messages)?;
}
Err(e) => return Err(e.into()),
}
self.remaining_messages = self.batch_size;
Expand Down
Loading

0 comments on commit 918c25a

Please sign in to comment.