From 64d7d9664e202f81b7a71231437cfe4acdd522c6 Mon Sep 17 00:00:00 2001 From: lxl66566 Date: Mon, 29 Jul 2024 22:18:23 +0800 Subject: [PATCH] fix: task manager panic Signed-off-by: lxl66566 --- crates/curp/tests/it/common/curp_group.rs | 26 +++++++++++++---------- crates/curp/tests/it/server.rs | 2 +- crates/utils/src/task_manager/mod.rs | 18 +++++++++++++--- crates/utils/src/task_manager/tasks.rs | 9 ++++++++ crates/xline/src/server/watch_server.rs | 4 ++-- 5 files changed, 42 insertions(+), 17 deletions(-) diff --git a/crates/curp/tests/it/common/curp_group.rs b/crates/curp/tests/it/common/curp_group.rs index 8fe32ae18..efc5a9052 100644 --- a/crates/curp/tests/it/common/curp_group.rs +++ b/crates/curp/tests/it/common/curp_group.rs @@ -1,8 +1,3 @@ -use std::{ - collections::HashMap, error::Error, fmt::Display, iter, path::PathBuf, sync::Arc, thread, - time::Duration, -}; - use async_trait::async_trait; use clippy_utilities::NumericCast; use curp::{ @@ -27,6 +22,10 @@ use engine::{ }; use futures::{future::join_all, stream::FuturesUnordered, Future}; use itertools::Itertools; +use std::{ + collections::HashMap, error::Error, fmt::Display, iter, path::PathBuf, sync::Arc, thread, + time::Duration, +}; use tokio::{ net::TcpListener, runtime::{Handle, Runtime}, @@ -36,7 +35,7 @@ use tokio::{ }; use tokio_stream::wrappers::TcpListenerStream; use tonic::transport::{Certificate, Channel, ClientTlsConfig, Endpoint, ServerTlsConfig}; -use tracing::debug; +use tracing::{debug, info}; use utils::{ build_endpoint, config::{ @@ -379,20 +378,25 @@ impl CurpGroup { } async fn wait_for_targets_shutdown(targets: impl Iterator) { + let targets = targets.collect::>(); let listeners = targets + .iter() .flat_map(|node| { BOTTOM_TASKS .iter() - .map(|task| { - node.task_manager - .get_shutdown_listener(task.to_owned()) - .unwrap() - }) + .filter_map(|task| node.task_manager.get_shutdown_listener(task.to_owned())) .collect::>() }) .collect::>(); let waiters: Vec<_> = listeners.iter().map(|l| l.wait()).collect(); futures::future::join_all(waiters.into_iter()).await; + for node in targets { + assert!( + node.task_manager.is_empty(), + "The tm in target node({}) is not empty", + node.id + ); + } } async fn stop(&mut self) { diff --git a/crates/curp/tests/it/server.rs b/crates/curp/tests/it/server.rs index 09ff3879c..a91e226c7 100644 --- a/crates/curp/tests/it/server.rs +++ b/crates/curp/tests/it/server.rs @@ -329,7 +329,7 @@ async fn shutdown_rpc_should_shutdown_the_cluster() { .await; assert!(matches!( CurpError::from(res.unwrap_err()), - CurpError::ShuttingDown(_) + CurpError::ShuttingDown(_) | CurpError::RpcTransport(_) )); let collection = collection_task.await.unwrap(); diff --git a/crates/utils/src/task_manager/mod.rs b/crates/utils/src/task_manager/mod.rs index 587613cb7..0d11ff1a5 100644 --- a/crates/utils/src/task_manager/mod.rs +++ b/crates/utils/src/task_manager/mod.rs @@ -174,10 +174,14 @@ impl TaskManager { /// Get root tasks queue fn root_tasks_queue(tasks: &DashMap) -> VecDeque { - tasks + let root_tasks: VecDeque<_> = tasks .iter() .filter_map(|task| (task.depend_cnt == 0).then_some(task.name)) - .collect() + .collect(); + if !tasks.is_empty() { + assert!(!root_tasks.is_empty(), "root tasks should not be empty"); + } + root_tasks } /// Inner shutdown task @@ -187,8 +191,9 @@ impl TaskManager { let Some((_name, mut task)) = tasks.remove(&v) else { continue; }; + let handles = task.handle.drain(..); task.notifier.notify_waiters(); - for handle in task.handle.drain(..) { + for handle in handles { // Directly abort the task if it's cancel safe if task.name.cancel_safe() { handle.abort(); @@ -268,6 +273,13 @@ impl TaskManager { } true } + + /// is the task empty + #[inline] + #[must_use] + pub fn is_empty(&self) -> bool { + self.tasks.is_empty() + } } impl Default for TaskManager { diff --git a/crates/utils/src/task_manager/tasks.rs b/crates/utils/src/task_manager/tasks.rs index e32606b00..8d7244a2c 100644 --- a/crates/utils/src/task_manager/tasks.rs +++ b/crates/utils/src/task_manager/tasks.rs @@ -33,6 +33,15 @@ macro_rules! enum_with_iter { VARIANTS.iter().copied() } } + + impl From for &'static str { + #[inline] + fn from(task_name: TaskName) -> &'static str { + match task_name { + $(TaskName::$variant => stringify!($variant)),* + } + } + } } } enum_with_iter! { diff --git a/crates/xline/src/server/watch_server.rs b/crates/xline/src/server/watch_server.rs index 29f67cf74..f6409829f 100644 --- a/crates/xline/src/server/watch_server.rs +++ b/crates/xline/src/server/watch_server.rs @@ -483,7 +483,7 @@ mod test { let next_id = Arc::new(WatchIdGenerator::new(1)); let n = task_manager .get_shutdown_listener(TaskName::WatchTask) - .unwrap(); + .unwrap_or_else(|| unreachable!("task WatchTask should exist")); let handle = tokio::spawn(WatchServer::task( next_id, Arc::clone(&watcher), @@ -737,7 +737,7 @@ mod test { let next_id = Arc::new(WatchIdGenerator::new(1)); let n = task_manager .get_shutdown_listener(TaskName::WatchTask) - .unwrap(); + .unwrap_or_else(|| unreachable!("task WatchTask should exist")); let handle = tokio::spawn(WatchServer::task( next_id, Arc::clone(&watcher),