From 85d4add8acdffb6296eb6049c6945fe45f083264 Mon Sep 17 00:00:00 2001
From: Phoeniix Zhao <Phoenix500526@163.com>
Date: Sat, 20 Apr 2024 11:32:30 +0800
Subject: [PATCH] refactor: add timeout mechanism for wait_for_XX_shutdown
 method

Signed-off-by: Phoeniix Zhao <Phoenix500526@163.com>
---
 crates/curp/tests/it/common/curp_group.rs | 34 ++++++++++++++++-------
 crates/curp/tests/it/server.rs            | 17 +++++++++---
 crates/utils/src/task_manager/tasks.rs    |  5 ++++
 3 files changed, 42 insertions(+), 14 deletions(-)

diff --git a/crates/curp/tests/it/common/curp_group.rs b/crates/curp/tests/it/common/curp_group.rs
index f34f4dffa..0a12cff6a 100644
--- a/crates/curp/tests/it/common/curp_group.rs
+++ b/crates/curp/tests/it/common/curp_group.rs
@@ -32,6 +32,7 @@ use tokio::{
     runtime::{Handle, Runtime},
     sync::{mpsc, watch},
     task::{block_in_place, JoinHandle},
+    time::timeout,
 };
 use tokio_stream::wrappers::TcpListenerStream;
 use tonic::transport::{Certificate, Channel, ClientTlsConfig, Endpoint, ServerTlsConfig};
@@ -52,6 +53,17 @@ pub use commandpb::{
     ProposeResponse,
 };
 
+/// `BOTTOM_TASKS` are tasks which not dependent on other tasks in the task group.
+/// `CurpGroup` uses `BOTTOM_TASKS` to detect whether the curp group is closed or not.
+const BOTTOM_TASKS: [TaskName; 3] = [
+    TaskName::WatchTask,
+    TaskName::ConfChange,
+    TaskName::LogPersist,
+];
+
+/// The default shutdown timeout used in `wait_for_targets_shutdown`
+pub(crate) const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(7);
+
 pub struct CurpNode {
     pub id: ServerId,
     pub addr: String,
@@ -339,33 +351,35 @@ impl CurpGroup {
             .all(|node| node.task_manager.is_finished())
     }
 
-    pub async fn wait_for_node_shutdown(&self, node_id: u64) {
+    pub async fn wait_for_node_shutdown(&self, node_id: u64, duration: Duration) {
         let node = self
             .nodes
             .get(&node_id)
             .expect("{node_id} should exist in nodes");
         let res = std::iter::once(node);
-        Self::wait_for_targets_shutdown(res).await;
+        timeout(duration, Self::wait_for_targets_shutdown(res))
+            .await
+            .expect("wait for group to shutdown timeout");
         assert!(
             node.task_manager.is_finished(),
             "The target node({node_id}) is not finished yet"
         );
     }
 
-    pub async fn wait_for_group_shutdown(&self) {
-        Self::wait_for_targets_shutdown(self.nodes.values()).await;
+    pub async fn wait_for_group_shutdown(&self, duration: Duration) {
+        timeout(
+            duration,
+            Self::wait_for_targets_shutdown(self.nodes.values()),
+        )
+        .await
+        .expect("wait for group to shutdown timeout");
         assert!(self.is_finished(), "The group is not finished yet");
     }
 
     async fn wait_for_targets_shutdown(targets: impl Iterator<Item = &CurpNode>) {
-        let final_tasks: [TaskName; 3] = [
-            TaskName::WatchTask,
-            TaskName::ConfChange,
-            TaskName::LogPersist,
-        ];
         let listeners = targets
             .flat_map(|node| {
-                final_tasks
+                BOTTOM_TASKS
                     .iter()
                     .map(|task| node.task_manager.get_shutdown_listener(task.to_owned()))
                     .collect::<Vec<_>>()
diff --git a/crates/curp/tests/it/server.rs b/crates/curp/tests/it/server.rs
index df9843e8f..3726772f0 100644
--- a/crates/curp/tests/it/server.rs
+++ b/crates/curp/tests/it/server.rs
@@ -19,6 +19,7 @@ use utils::{config::ClientConfig, timestamp};
 
 use crate::common::curp_group::{
     commandpb::ProposeId, CurpGroup, FetchClusterRequest, ProposeRequest, ProposeResponse,
+    DEFAULT_SHUTDOWN_TIMEOUT,
 };
 
 #[tokio::test(flavor = "multi_thread")]
@@ -297,7 +298,9 @@ async fn shutdown_rpc_should_shutdown_the_cluster() {
     ));
 
     let collection = collection_task.await.unwrap();
-    group.wait_for_group_shutdown().await;
+    group
+        .wait_for_group_shutdown(DEFAULT_SHUTDOWN_TIMEOUT)
+        .await;
 
     let group = CurpGroup::new_rocks(3, tmp_path).await;
     let client = group.new_client().await;
@@ -418,7 +421,9 @@ async fn shutdown_rpc_should_shutdown_the_cluster_when_client_has_wrong_leader()
         .unwrap();
     client.propose_shutdown().await.unwrap();
 
-    group.wait_for_group_shutdown().await;
+    group
+        .wait_for_group_shutdown(DEFAULT_SHUTDOWN_TIMEOUT)
+        .await;
 }
 
 #[tokio::test(flavor = "multi_thread")]
@@ -548,7 +553,9 @@ async fn shutdown_rpc_should_shutdown_the_cluster_when_client_has_wrong_cluster(
         .await;
     client.propose_shutdown().await.unwrap();
 
-    group.wait_for_group_shutdown().await;
+    group
+        .wait_for_group_shutdown(DEFAULT_SHUTDOWN_TIMEOUT)
+        .await;
 }
 
 #[tokio::test(flavor = "multi_thread")]
@@ -574,7 +581,9 @@ async fn propose_conf_change_rpc_should_work_when_client_has_wrong_cluster() {
     let members = client.propose_conf_change(changes).await.unwrap();
     assert_eq!(members.len(), 3);
     assert!(members.iter().all(|m| m.id != node_id));
-    group.wait_for_node_shutdown(node_id).await;
+    group
+        .wait_for_node_shutdown(node_id, DEFAULT_SHUTDOWN_TIMEOUT)
+        .await;
 }
 
 #[tokio::test(flavor = "multi_thread")]
diff --git a/crates/utils/src/task_manager/tasks.rs b/crates/utils/src/task_manager/tasks.rs
index 02dc6fa7b..b4e29f2ec 100644
--- a/crates/utils/src/task_manager/tasks.rs
+++ b/crates/utils/src/task_manager/tasks.rs
@@ -6,6 +6,11 @@
 //                    \        /      |      \       /
 //                   WATCH_TASK  CONF_CHANGE  LOG_PERSIST
 
+// NOTE: In integration tests, we use bottom tasks, like `WatchTask`, `ConfChange`, and `LogPersist`,
+// which are not dependent on other tasks to detect the curp group is closed or not. If you want
+// to refactor the task group, don't forget to modify the `BOTTOM_TASKS` in `crates/curp/tests/it/common/curp_group.rs`
+// to prevent the integration tests from failing.
+
 /// Generate enum with iterator
 macro_rules! enum_with_iter {
     ( $($variant:ident),* $(,)? ) => {