Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: Replace async Mutex for RaftInner.core_state with standard Mutex and a watch channel #1211

Merged
merged 2 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions openraft/src/raft/core_state.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::error::Fatal;
use crate::error::Infallible;
use crate::type_config::alias::JoinHandleOf;
use crate::type_config::alias::WatchReceiverOf;
use crate::RaftTypeConfig;

/// The running state of RaftCore
Expand All @@ -10,6 +11,9 @@ where C: RaftTypeConfig
/// The RaftCore task is still running.
Running(JoinHandleOf<C, Result<Infallible, Fatal<C>>>),

/// The RaftCore task is waiting for a signal to finish joining.
Joining(WatchReceiverOf<C, bool>),

/// The RaftCore task has finished. The return value of the task is stored.
Done(Result<Infallible, Fatal<C>>),
}
Expand Down
8 changes: 4 additions & 4 deletions openraft/src/raft/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,8 +318,8 @@ where C: RaftTypeConfig
rx_metrics,
rx_data_metrics,
rx_server_metrics,
tx_shutdown: Mutex::new(Some(tx_shutdown)),
core_state: Mutex::new(CoreState::Running(core_handle)),
tx_shutdown: std::sync::Mutex::new(Some(tx_shutdown)),
core_state: std::sync::Mutex::new(CoreState::Running(core_handle)),

snapshot: Mutex::new(None),
};
Expand Down Expand Up @@ -828,7 +828,7 @@ where C: RaftTypeConfig
tracing::debug!("{} receives result is error: {:?}", func_name!(), recv_res.is_err());

let Ok(v) = recv_res else {
if self.inner.is_core_running().await {
if self.inner.is_core_running() {
return Ok(Err(InvalidStateMachineType::new::<SM>()));
} else {
let fatal = self.inner.get_core_stopped_error("receiving rx from RaftCore", None::<&'static str>).await;
Expand Down Expand Up @@ -919,7 +919,7 @@ where C: RaftTypeConfig
///
/// It sends a shutdown signal and waits until `RaftCore` returns.
pub async fn shutdown(&self) -> Result<(), JoinErrorOf<C>> {
if let Some(tx) = self.inner.tx_shutdown.lock().await.take() {
if let Some(tx) = self.inner.tx_shutdown.lock().unwrap().take() {
// A failure to send means the RaftCore is already shutdown. Continue to check the task
// return value.
let send_res = tx.send(());
Expand Down
74 changes: 57 additions & 17 deletions openraft/src/raft/raft_inner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use std::sync::Arc;
use tokio::sync::Mutex;
use tracing::Level;

use crate::async_runtime::watch::WatchReceiver;
use crate::async_runtime::watch::WatchSender;
use crate::async_runtime::MpscUnboundedSender;
use crate::config::RuntimeConfig;
use crate::core::raft_msg::external_command::ExternalCommand;
Expand All @@ -16,11 +18,13 @@ use crate::error::RaftError;
use crate::metrics::RaftDataMetrics;
use crate::metrics::RaftServerMetrics;
use crate::raft::core_state::CoreState;
use crate::type_config::alias::AsyncRuntimeOf;
use crate::type_config::alias::MpscUnboundedSenderOf;
use crate::type_config::alias::OneshotReceiverOf;
use crate::type_config::alias::OneshotSenderOf;
use crate::type_config::alias::WatchReceiverOf;
use crate::type_config::AsyncRuntime;
use crate::type_config::TypeConfigExt;
use crate::Config;
use crate::OptionalSend;
use crate::RaftMetrics;
Expand All @@ -40,10 +44,8 @@ where C: RaftTypeConfig
pub(in crate::raft) rx_data_metrics: WatchReceiverOf<C, RaftDataMetrics<C>>,
pub(in crate::raft) rx_server_metrics: WatchReceiverOf<C, RaftServerMetrics<C>>,

// TODO(xp): it does not need to be a async mutex.
#[allow(clippy::type_complexity)]
pub(in crate::raft) tx_shutdown: Mutex<Option<OneshotSenderOf<C, ()>>>,
pub(in crate::raft) core_state: Mutex<CoreState<C>>,
pub(in crate::raft) tx_shutdown: std::sync::Mutex<Option<OneshotSenderOf<C, ()>>>,
pub(in crate::raft) core_state: std::sync::Mutex<CoreState<C>>,

/// The ongoing snapshot transmission.
pub(in crate::raft) snapshot: Mutex<Option<crate::network::snapshot_transport::Streaming<C>>>,
Expand Down Expand Up @@ -131,8 +133,8 @@ where C: RaftTypeConfig
Ok(())
}

pub(in crate::raft) async fn is_core_running(&self) -> bool {
let state = self.core_state.lock().await;
pub(in crate::raft) fn is_core_running(&self) -> bool {
let state = self.core_state.lock().unwrap();
state.is_running()
}

Expand All @@ -147,7 +149,7 @@ where C: RaftTypeConfig

// Retrieve the result.
let core_res = {
let state = self.core_state.lock().await;
let state = self.core_state.lock().unwrap();
if let CoreState::Done(core_task_res) = &*state {
core_task_res.clone()
} else {
Expand All @@ -172,15 +174,40 @@ where C: RaftTypeConfig
/// Wait for `RaftCore` task to finish and record the returned value from the task.
#[tracing::instrument(level = "debug", skip_all)]
pub(in crate::raft) async fn join_core_task(&self) {
let mut state = self.core_state.lock().await;
match &mut *state {
CoreState::Running(handle) => {
let res = handle.await;
tracing::info!(res = debug(&res), "RaftCore exited");
// Get the Running state of RaftCore,
// or an error if RaftCore has been in Joining state.
let running_res = {
let mut state = self.core_state.lock().unwrap();

match &*state {
CoreState::Running(_) => {
let (tx, rx) = C::watch_channel::<bool>(false);

let prev = std::mem::replace(&mut *state, CoreState::Joining(rx));

let CoreState::Running(join_handle) = prev else {
unreachable!()
};

Ok((join_handle, tx))
}
CoreState::Joining(watch_rx) => Err(watch_rx.clone()),
CoreState::Done(_) => {
// RaftCore has already finished exiting, nothing to do
return;
}
}
};

match running_res {
Ok((join_handle, tx)) => {
let join_res = join_handle.await;

let core_task_res = match res {
tracing::info!(res = debug(&join_res), "RaftCore exited");

let core_task_res = match join_res {
Err(err) => {
if C::AsyncRuntime::is_panic(&err) {
if AsyncRuntimeOf::<C>::is_panic(&err) {
Err(Fatal::Panicked)
} else {
Err(Fatal::Stopped)
Expand All @@ -189,10 +216,23 @@ where C: RaftTypeConfig
Ok(returned_res) => returned_res,
};

*state = CoreState::Done(core_task_res);
{
let mut state = self.core_state.lock().unwrap();
*state = CoreState::Done(core_task_res);
}
tx.send(true).ok();
}
CoreState::Done(_) => {
// RaftCore has already quit, nothing to do
Err(mut rx) => {
// Other thread is waiting for the core to finish.
loop {
let res = rx.changed().await;
if res.is_err() {
break;
}
if *rx.borrow_watched() {
break;
}
}
}
}
}
Expand Down
Loading