From acdc4c297429efa648f9c31e0c862dd972e88259 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E7=82=8E=E6=B3=BC?= Date: Mon, 22 Jul 2024 13:54:49 +0800 Subject: [PATCH] Feature: Add MPSC channel to AsyncRuntime `AsyncRuntime` trait defines the async-runtime such as tokio to run Openraft. This commit add MPSC abstraction to `AsyncRuntime` and MPSC implementations to tokio based runtime and monoio based runtime. --- openraft/src/testing/runtime/mod.rs | 83 +++++++++++++++ openraft/src/type_config.rs | 19 +++- openraft/src/type_config/async_runtime/mod.rs | 7 ++ .../src/type_config/async_runtime/mpsc/mod.rs | 73 +++++++++++++ .../tokio_impls/tokio_runtime.rs | 70 ++++++++++++ openraft/src/type_config/util.rs | 14 +++ rt-monoio/Cargo.toml | 7 +- rt-monoio/src/lib.rs | 100 +++++++++++++++++- 8 files changed, 366 insertions(+), 7 deletions(-) create mode 100644 openraft/src/type_config/async_runtime/mpsc/mod.rs diff --git a/openraft/src/testing/runtime/mod.rs b/openraft/src/testing/runtime/mod.rs index d1270b914..5b119df01 100644 --- a/openraft/src/testing/runtime/mod.rs +++ b/openraft/src/testing/runtime/mod.rs @@ -7,7 +7,11 @@ use std::task::Poll; use crate::async_runtime::watch::WatchReceiver; use crate::async_runtime::watch::WatchSender; +use crate::async_runtime::Mpsc; +use crate::async_runtime::MpscReceiver; +use crate::async_runtime::MpscSender; use crate::async_runtime::MpscUnboundedWeakSender; +use crate::async_runtime::MpscWeakSender; use crate::instant::Instant; use crate::type_config::async_runtime::mpsc_unbounded::MpscUnbounded; use crate::type_config::async_runtime::mpsc_unbounded::MpscUnboundedReceiver; @@ -43,11 +47,19 @@ impl Suite { Self::test_sleep_until().await; Self::test_timeout().await; Self::test_timeout_at().await; + + Self::test_mpsc_recv_empty().await; + Self::test_mpsc_recv_channel_closed().await; + Self::test_mpsc_weak_sender_wont_prevent_channel_close().await; + Self::test_mpsc_weak_sender_upgrade().await; + Self::test_mpsc_send().await; + Self::test_unbounded_mpsc_recv_empty().await; Self::test_unbounded_mpsc_recv_channel_closed().await; Self::test_unbounded_mpsc_weak_sender_wont_prevent_channel_close().await; Self::test_unbounded_mpsc_weak_sender_upgrade().await; Self::test_unbounded_mpsc_send().await; + Self::test_watch_init_value().await; Self::test_watch_overwrite_init_value().await; Self::test_watch_send_error_no_receiver().await; @@ -131,6 +143,77 @@ impl Suite { assert!(timeout_result.is_err()); } + pub async fn test_mpsc_recv_empty() { + let (_tx, mut rx) = Rt::Mpsc::channel::<()>(5); + let recv_err = rx.try_recv().unwrap_err(); + assert!(matches!(recv_err, TryRecvError::Empty)); + } + + pub async fn test_mpsc_recv_channel_closed() { + let (_, mut rx) = Rt::Mpsc::channel::<()>(5); + let recv_err = rx.try_recv().unwrap_err(); + assert!(matches!(recv_err, TryRecvError::Disconnected)); + + let recv_result = rx.recv().await; + assert!(recv_result.is_none()); + } + + pub async fn test_mpsc_weak_sender_wont_prevent_channel_close() { + let (tx, mut rx) = Rt::Mpsc::channel::<()>(5); + + let _weak_tx = tx.downgrade(); + drop(tx); + let recv_err = rx.try_recv().unwrap_err(); + assert!(matches!(recv_err, TryRecvError::Disconnected)); + + let recv_result = rx.recv().await; + assert!(recv_result.is_none()); + } + + pub async fn test_mpsc_weak_sender_upgrade() { + let (tx, _rx) = Rt::Mpsc::channel::<()>(5); + + let weak_tx = tx.downgrade(); + let opt_tx = weak_tx.upgrade(); + assert!(opt_tx.is_some()); + + drop(tx); + drop(opt_tx); + // now there is no Sender instances alive + + let opt_tx = weak_tx.upgrade(); + assert!(opt_tx.is_none()); + } + + pub async fn test_mpsc_send() { + let (tx, mut rx) = Rt::Mpsc::channel::(5); + let tx = Arc::new(tx); + + let n_senders = 10_usize; + let recv_expected = (0..n_senders).collect::>(); + + for idx in 0..n_senders { + let tx = tx.clone(); + // no need to wait for senders here, we wait by recv()ing + let _handle = Rt::spawn(async move { + tx.send(idx).await.unwrap(); + }); + } + + let mut recv = Vec::with_capacity(n_senders); + while let Some(recv_number) = rx.recv().await { + recv.push(recv_number); + + if recv.len() == n_senders { + break; + } + } + + recv.sort(); + + assert_eq!(recv_expected, recv); + } + pub async fn test_unbounded_mpsc_recv_empty() { let (_tx, mut rx) = Rt::MpscUnbounded::channel::<()>(); let recv_err = rx.try_recv().unwrap_err(); diff --git a/openraft/src/type_config.rs b/openraft/src/type_config.rs index 2d784ea41..1a1eb17b9 100644 --- a/openraft/src/type_config.rs +++ b/openraft/src/type_config.rs @@ -95,6 +95,7 @@ pub trait RaftTypeConfig: /// [`type-alias`]: crate::docs::feature_flags#feature-flag-type-alias pub mod alias { use crate::async_runtime::watch; + use crate::async_runtime::Mpsc; use crate::async_runtime::MpscUnbounded; use crate::async_runtime::Oneshot; use crate::raft::responder::Responder; @@ -125,13 +126,23 @@ pub mod alias { pub type OneshotReceiverErrorOf = as Oneshot>::ReceiverError; pub type OneshotReceiverOf = as Oneshot>::Receiver; + pub type MpscOf = as AsyncRuntime>::Mpsc; + + // MPSC bounded + type MpscB = MpscOf; + + pub type MpscSenderOf = as Mpsc>::Sender; + pub type MpscReceiverOf = as Mpsc>::Receiver; + pub type MpscWeakSenderOf = as Mpsc>::WeakSender; + pub type MpscUnboundedOf = as AsyncRuntime>::MpscUnbounded; - type Mpsc = MpscUnboundedOf; + // MPSC unbounded + type MpscUB = MpscUnboundedOf; - pub type MpscUnboundedSenderOf = as MpscUnbounded>::Sender; - pub type MpscUnboundedReceiverOf = as MpscUnbounded>::Receiver; - pub type MpscUnboundedWeakSenderOf = as MpscUnbounded>::WeakSender; + pub type MpscUnboundedSenderOf = as MpscUnbounded>::Sender; + pub type MpscUnboundedReceiverOf = as MpscUnbounded>::Receiver; + pub type MpscUnboundedWeakSenderOf = as MpscUnbounded>::WeakSender; pub type WatchOf = as AsyncRuntime>::Watch; pub type WatchSenderOf = as watch::Watch>::Sender; diff --git a/openraft/src/type_config/async_runtime/mod.rs b/openraft/src/type_config/async_runtime/mod.rs index 132a60c28..eca8e1718 100644 --- a/openraft/src/type_config/async_runtime/mod.rs +++ b/openraft/src/type_config/async_runtime/mod.rs @@ -9,6 +9,7 @@ pub(crate) mod tokio_impls { mod tokio_runtime; pub use tokio_runtime::TokioRuntime; } +pub mod mpsc; pub mod mpsc_unbounded; pub mod mutex; pub mod oneshot; @@ -19,6 +20,10 @@ use std::fmt::Display; use std::future::Future; use std::time::Duration; +pub use mpsc::Mpsc; +pub use mpsc::MpscReceiver; +pub use mpsc::MpscSender; +pub use mpsc::MpscWeakSender; pub use mpsc_unbounded::MpscUnbounded; pub use mpsc_unbounded::MpscUnboundedReceiver; pub use mpsc_unbounded::MpscUnboundedSender; @@ -99,6 +104,8 @@ pub trait AsyncRuntime: Debug + Default + PartialEq + Eq + OptionalSend + Option /// sent to another thread. fn thread_rng() -> Self::ThreadLocalRng; + type Mpsc: Mpsc; + type MpscUnbounded: MpscUnbounded; type Watch: Watch; diff --git a/openraft/src/type_config/async_runtime/mpsc/mod.rs b/openraft/src/type_config/async_runtime/mpsc/mod.rs new file mode 100644 index 000000000..1514ceee8 --- /dev/null +++ b/openraft/src/type_config/async_runtime/mpsc/mod.rs @@ -0,0 +1,73 @@ +use std::future::Future; + +use base::OptionalSend; +use base::OptionalSync; + +/// mpsc shares the same error types as mpsc_unbounded +pub use super::mpsc_unbounded::SendError; +pub use super::mpsc_unbounded::TryRecvError; +use crate::base; + +/// Multi-producer, single-consumer channel. +pub trait Mpsc: Sized + OptionalSend { + type Sender: MpscSender; + type Receiver: MpscReceiver; + type WeakSender: MpscWeakSender; + + /// Creates a bounded mpsc channel for communicating between asynchronous tasks with + /// backpressure. + fn channel(buffer: usize) -> (Self::Sender, Self::Receiver); +} + +/// Send values to the associated [`MpscReceiver`]. +pub trait MpscSender: OptionalSend + OptionalSync + Clone +where + MU: Mpsc, + T: OptionalSend, +{ + /// Attempts to send a message, blocks if there is no capacity. + /// + /// If the receiving half of the channel is closed, this + /// function returns an error. The error includes the value passed to `send`. + fn send(&self, msg: T) -> impl Future>> + OptionalSend; + + /// Converts the [`MpscSender`] to a [`MpscWeakSender`] that does not count + /// towards RAII semantics, i.e. if all `Sender` instances of the + /// channel were dropped and only `WeakSender` instances remain, + /// the channel is closed. + fn downgrade(&self) -> MU::WeakSender; +} + +/// Receive values from the associated [`MpscSender`]. +pub trait MpscReceiver: OptionalSend + OptionalSync { + /// Receives the next value for this receiver. + /// + /// This method returns `None` if the channel has been closed and there are + /// no remaining messages in the channel's buffer. + fn recv(&mut self) -> impl Future> + OptionalSend; + + /// Tries to receive the next value for this receiver. + /// + /// This method returns the [`TryRecvError::Empty`] error if the channel is currently + /// empty, but there are still outstanding senders. + /// + /// This method returns the [`TryRecvError::Disconnected`] error if the channel is + /// currently empty, and there are no outstanding senders. + fn try_recv(&mut self) -> Result; +} + +/// A sender that does not prevent the channel from being closed. +/// +/// If all [`MpscSender`] instances of a channel were dropped and only +/// `WeakSender` instances remain, the channel is closed. +pub trait MpscWeakSender: OptionalSend + OptionalSync + Clone +where + MU: Mpsc, + T: OptionalSend, +{ + /// Tries to convert a [`MpscWeakSender`] into an [`MpscSender`]. + /// + /// This will return `Some` if there are other `Sender` instances alive and + /// the channel wasn't previously dropped, otherwise `None` is returned. + fn upgrade(&self) -> Option>; +} diff --git a/openraft/src/type_config/async_runtime/tokio_impls/tokio_runtime.rs b/openraft/src/type_config/async_runtime/tokio_impls/tokio_runtime.rs index b79b8084e..78882411f 100644 --- a/openraft/src/type_config/async_runtime/tokio_impls/tokio_runtime.rs +++ b/openraft/src/type_config/async_runtime/tokio_impls/tokio_runtime.rs @@ -74,6 +74,7 @@ impl AsyncRuntime for TokioRuntime { rand::thread_rng() } + type Mpsc = mpsc_impl::TokioMpsc; type MpscUnbounded = TokioMpscUnbounded; type Watch = TokioWatch; type Oneshot = TokioOneshot; @@ -134,6 +135,75 @@ where T: OptionalSend } } +mod mpsc_impl { + use std::future::Future; + + use futures::TryFutureExt; + use tokio::sync::mpsc; + + use crate::async_runtime::Mpsc; + use crate::async_runtime::MpscReceiver; + use crate::async_runtime::MpscSender; + use crate::async_runtime::MpscWeakSender; + use crate::async_runtime::SendError; + use crate::async_runtime::TryRecvError; + use crate::OptionalSend; + + pub struct TokioMpsc; + + impl Mpsc for TokioMpsc { + type Sender = mpsc::Sender; + type Receiver = mpsc::Receiver; + type WeakSender = mpsc::WeakSender; + + /// Creates a bounded mpsc channel for communicating between asynchronous + /// tasks with backpressure. + fn channel(buffer: usize) -> (Self::Sender, Self::Receiver) { + mpsc::channel(buffer) + } + } + + impl MpscSender for mpsc::Sender + where T: OptionalSend + { + #[inline] + fn send(&self, msg: T) -> impl Future>> + OptionalSend { + self.send(msg).map_err(|e| SendError(e.0)) + } + + #[inline] + fn downgrade(&self) -> ::WeakSender { + self.downgrade() + } + } + + impl MpscReceiver for mpsc::Receiver + where T: OptionalSend + { + #[inline] + fn recv(&mut self) -> impl Future> + OptionalSend { + self.recv() + } + + #[inline] + fn try_recv(&mut self) -> Result { + self.try_recv().map_err(|e| match e { + mpsc::error::TryRecvError::Empty => TryRecvError::Empty, + mpsc::error::TryRecvError::Disconnected => TryRecvError::Disconnected, + }) + } + } + + impl MpscWeakSender for mpsc::WeakSender + where T: OptionalSend + { + #[inline] + fn upgrade(&self) -> Option<::Sender> { + self.upgrade() + } + } +} + pub struct TokioWatch; impl watch::Watch for TokioWatch { diff --git a/openraft/src/type_config/util.rs b/openraft/src/type_config/util.rs index 7fb9b53e2..f55d31260 100644 --- a/openraft/src/type_config/util.rs +++ b/openraft/src/type_config/util.rs @@ -5,11 +5,15 @@ use openraft_macros::since; use crate::async_runtime::mutex::Mutex; use crate::async_runtime::watch::Watch; +use crate::async_runtime::Mpsc; use crate::async_runtime::MpscUnbounded; use crate::async_runtime::Oneshot; use crate::type_config::alias::AsyncRuntimeOf; use crate::type_config::alias::InstantOf; use crate::type_config::alias::JoinHandleOf; +use crate::type_config::alias::MpscOf; +use crate::type_config::alias::MpscReceiverOf; +use crate::type_config::alias::MpscSenderOf; use crate::type_config::alias::MpscUnboundedOf; use crate::type_config::alias::MpscUnboundedReceiverOf; use crate::type_config::alias::MpscUnboundedSenderOf; @@ -72,6 +76,16 @@ pub trait TypeConfigExt: RaftTypeConfig { OneshotOf::::channel() } + /// Creates a mpsc channel for communicating between asynchronous + /// tasks with backpressure. + /// + /// This is just a wrapper of + /// [`AsyncRuntime::Mpsc::channel()`](`crate::async_runtime::Mpsc::channel`). + fn mpsc(buffer: usize) -> (MpscSenderOf, MpscReceiverOf) + where T: OptionalSend { + MpscOf::::channel(buffer) + } + /// Creates an unbounded mpsc channel for communicating between asynchronous /// tasks without backpressure. /// diff --git a/rt-monoio/Cargo.toml b/rt-monoio/Cargo.toml index c30f8f104..8749c7b20 100644 --- a/rt-monoio/Cargo.toml +++ b/rt-monoio/Cargo.toml @@ -18,6 +18,9 @@ repository = "https://github.com/datafuselabs/openraft" openraft = { path = "../openraft", version = "0.10.0", default-features = false, features = ["singlethreaded"] } rand = "0.8" -tokio = { version = "1.22", features = ["sync"] } -monoio = "0.2.3" + +futures = { version = "0.3" } local-sync = "0.1.1" + +monoio = "0.2.3" +tokio = { version = "1.22", features = ["sync"] } diff --git a/rt-monoio/src/lib.rs b/rt-monoio/src/lib.rs index 62441f0cd..5e36286f3 100644 --- a/rt-monoio/src/lib.rs +++ b/rt-monoio/src/lib.rs @@ -86,7 +86,8 @@ impl AsyncRuntime for MonoioRuntime { rand::thread_rng() } - type MpscUnbounded = mpsc_mod::TokioMpscUnbounded; + type Mpsc = mpsc_mod::MonoioMpsc; + type MpscUnbounded = mpsc_unbounded_mod::TokioMpscUnbounded; type Watch = watch_mod::TokioWatch; type Oneshot = oneshot_mod::MonoioOneshot; type Mutex = mutex_mod::TokioMutex; @@ -203,7 +204,104 @@ mod oneshot_mod { // Put the wrapper types in a private module to make them `pub` but not // exposed to the user. +/// MPSC channel is implemented with tokio MPSC channels. +/// +/// Tokio MPSC channel are runtime independent. mod mpsc_mod { + //! MPSC channel wrapper types and their trait impl. + + use std::future::Future; + + use futures::TryFutureExt; + use openraft::async_runtime::Mpsc; + use openraft::async_runtime::MpscReceiver; + use openraft::async_runtime::MpscSender; + use openraft::async_runtime::MpscWeakSender; + use openraft::async_runtime::SendError; + use openraft::async_runtime::TryRecvError; + use openraft::OptionalSend; + use tokio::sync::mpsc as tokio_mpsc; + + pub struct MonoioMpsc; + + pub struct MonoioMpscSender(tokio_mpsc::Sender); + + impl Clone for MonoioMpscSender { + #[inline] + fn clone(&self) -> Self { + Self(self.0.clone()) + } + } + + pub struct MonoioMpscReceiver(tokio_mpsc::Receiver); + + pub struct MonoioMpscWeakSender(tokio_mpsc::WeakSender); + + impl Clone for MonoioMpscWeakSender { + #[inline] + fn clone(&self) -> Self { + Self(self.0.clone()) + } + } + + impl Mpsc for MonoioMpsc { + type Sender = MonoioMpscSender; + type Receiver = MonoioMpscReceiver; + type WeakSender = MonoioMpscWeakSender; + + #[inline] + fn channel(buffer: usize) -> (Self::Sender, Self::Receiver) { + let (tx, rx) = tokio_mpsc::channel(buffer); + let tx_wrapper = MonoioMpscSender(tx); + let rx_wrapper = MonoioMpscReceiver(rx); + + (tx_wrapper, rx_wrapper) + } + } + + impl MpscSender for MonoioMpscSender + where T: OptionalSend + { + #[inline] + fn send(&self, msg: T) -> impl Future>> { + self.0.send(msg).map_err(|e| SendError(e.0)) + } + + #[inline] + fn downgrade(&self) -> ::WeakSender { + let inner = self.0.downgrade(); + MonoioMpscWeakSender(inner) + } + } + + impl MpscReceiver for MonoioMpscReceiver { + #[inline] + fn recv(&mut self) -> impl Future> { + self.0.recv() + } + + #[inline] + fn try_recv(&mut self) -> Result { + self.0.try_recv().map_err(|e| match e { + tokio_mpsc::error::TryRecvError::Empty => TryRecvError::Empty, + tokio_mpsc::error::TryRecvError::Disconnected => TryRecvError::Disconnected, + }) + } + } + + impl MpscWeakSender for MonoioMpscWeakSender + where T: OptionalSend + { + #[inline] + fn upgrade(&self) -> Option<::Sender> { + self.0.upgrade().map(MonoioMpscSender) + } + } +} + +// Put the wrapper types in a private module to make them `pub` but not +// exposed to the user. +mod mpsc_unbounded_mod { //! Unbounded MPSC channel wrapper types and their trait impl. use openraft::type_config::async_runtime::mpsc_unbounded;