Skip to content

Commit

Permalink
fix: task manager panic
Browse files Browse the repository at this point in the history
Signed-off-by: lxl66566 <[email protected]>
  • Loading branch information
lxl66566 committed Jul 29, 2024
1 parent c86d12f commit 7418d9f
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 29 deletions.
4 changes: 3 additions & 1 deletion crates/curp/src/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,9 @@ impl<C: Command, RC: RoleChange> Rpc<C, RC> {

use crate::rpc::{InnerProtocolServer, ProtocolServer};

let n = task_manager.get_shutdown_listener(TaskName::TonicServer);
let n = task_manager
.get_shutdown_listener(TaskName::TonicServer)
.unwrap_or_else(|| unreachable!("task TonicServer should exist"));
let server = Self::new(
cluster_info,
is_leader,
Expand Down
22 changes: 15 additions & 7 deletions crates/curp/tests/it/common/curp_group.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,3 @@
use std::{
collections::HashMap, error::Error, fmt::Display, iter, path::PathBuf, sync::Arc, thread,
time::Duration,
};

use async_trait::async_trait;
use clippy_utilities::NumericCast;
use curp::{
Expand All @@ -27,6 +22,10 @@ use engine::{
};
use futures::{future::join_all, stream::FuturesUnordered, Future};
use itertools::Itertools;
use std::{
collections::HashMap, error::Error, fmt::Display, iter, path::PathBuf, sync::Arc, thread,
time::Duration,
};
use tokio::{
net::TcpListener,
runtime::{Handle, Runtime},
Expand All @@ -36,7 +35,7 @@ use tokio::{
};
use tokio_stream::wrappers::TcpListenerStream;
use tonic::transport::{Certificate, Channel, ClientTlsConfig, Endpoint, ServerTlsConfig};
use tracing::debug;
use tracing::{debug, info};
use utils::{
build_endpoint,
config::{
Expand Down Expand Up @@ -377,16 +376,25 @@ impl CurpGroup {
}

async fn wait_for_targets_shutdown(targets: impl Iterator<Item = &CurpNode>) {
let targets = targets.collect::<Vec<_>>();
let listeners = targets
.iter()
.flat_map(|node| {
BOTTOM_TASKS
.iter()
.map(|task| node.task_manager.get_shutdown_listener(task.to_owned()))
.filter_map(|task| node.task_manager.get_shutdown_listener(task.to_owned()))
.collect::<Vec<_>>()
})
.collect::<Vec<_>>();
let waiters: Vec<_> = listeners.iter().map(|l| l.wait()).collect();
futures::future::join_all(waiters.into_iter()).await;
for node in targets {
assert!(
node.task_manager.is_empty(),
"The tm in target node({}) is not empty",
node.id
);
}
}

async fn stop(&mut self) {
Expand Down
45 changes: 28 additions & 17 deletions crates/utils/src/task_manager/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,16 +135,14 @@ impl TaskManager {
/// Get shutdown listener
#[must_use]
#[inline]
pub fn get_shutdown_listener(&self, name: TaskName) -> Listener {
let task = self
.tasks
.get(&name)
.unwrap_or_else(|| unreachable!("task {:?} should exist", name));
Listener::new(
Arc::clone(&self.state),
Arc::clone(&task.notifier),
Arc::clone(&self.cluster_shutdown_tracker),
)
pub fn get_shutdown_listener(&self, name: TaskName) -> Option<Listener> {
self.tasks.get(&name).map(|task| {
Listener::new(
Arc::clone(&self.state),
Arc::clone(&task.notifier),
Arc::clone(&self.cluster_shutdown_tracker),
)
})
}

/// Spawn a task
Expand Down Expand Up @@ -173,10 +171,14 @@ impl TaskManager {

/// Get root tasks queue
fn root_tasks_queue(tasks: &DashMap<TaskName, Task>) -> VecDeque<TaskName> {
tasks
let root_tasks: VecDeque<_> = tasks
.iter()
.filter_map(|task| (task.depend_cnt == 0).then_some(task.name))
.collect()
.collect();
if !tasks.is_empty() {
assert!(!root_tasks.is_empty(), "root tasks should not be empty");
}
root_tasks
}

/// Inner shutdown task
Expand All @@ -187,12 +189,14 @@ impl TaskManager {
let Some((_name, mut task)) = tasks.remove(&v) else {
continue;
};
let handles = task.handle.drain(..);
task.notifier.notify_waiters();
for handle in task.handle.drain(..) {
handle
.await
.unwrap_or_else(|e| unreachable!("background task should not panic: {e}"));
}
futures::future::join_all(handles)
.await
.into_iter()
.for_each(|res| {
res.unwrap_or_else(|e| unreachable!("background task should not panic: {e}"));
});
for child in task.depend_by.drain(..) {
let Some(mut child_task) = tasks.get_mut(&child) else {
continue;
Expand Down Expand Up @@ -260,6 +264,13 @@ impl TaskManager {
}
true
}

/// is the task empty
#[inline]
#[must_use]
pub fn is_empty(&self) -> bool {
self.tasks.is_empty()
}
}

impl Default for TaskManager {
Expand Down
9 changes: 9 additions & 0 deletions crates/utils/src/task_manager/tasks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ macro_rules! enum_with_iter {
VARIANTS.iter().copied()
}
}

impl From<TaskName> for &'static str {
#[inline]
fn from(task_name: TaskName) -> &'static str {
match task_name {
$(TaskName::$variant => stringify!($variant)),*
}
}
}
}
}
enum_with_iter! {
Expand Down
6 changes: 4 additions & 2 deletions crates/xline/src/server/lease_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ impl LeaseServer {
) -> Pin<Box<dyn Stream<Item = Result<LeaseKeepAliveResponse, tonic::Status>> + Send>> {
let shutdown_listener = self
.task_manager
.get_shutdown_listener(TaskName::LeaseKeepAlive);
.get_shutdown_listener(TaskName::LeaseKeepAlive)
.unwrap_or_else(|| unreachable!("task LeaseKeepAlive should exist"));
let lease_storage = Arc::clone(&self.lease_storage);
let stream = try_stream! {
loop {
Expand Down Expand Up @@ -192,7 +193,8 @@ impl LeaseServer {
> {
let shutdown_listener = self
.task_manager
.get_shutdown_listener(TaskName::LeaseKeepAlive);
.get_shutdown_listener(TaskName::LeaseKeepAlive)
.unwrap_or_else(|| unreachable!("task LeaseKeepAlive should exist"));
let endpoints = build_endpoints(leader_addrs, self.client_tls_config.as_ref())?;
let channel = tonic::transport::Channel::balance_list(endpoints.into_iter());
let mut lease_client = LeaseClient::new(channel);
Expand Down
8 changes: 6 additions & 2 deletions crates/xline/src/server/watch_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,9 @@ mod test {
.return_const(-1_i64);
let watcher = Arc::new(mock_watcher);
let next_id = Arc::new(WatchIdGenerator::new(1));
let n = task_manager.get_shutdown_listener(TaskName::WatchTask);
let n = task_manager
.get_shutdown_listener(TaskName::WatchTask)
.unwrap_or_else(|| unreachable!("task WatchTask should exist"));
let handle = tokio::spawn(WatchServer::task(
next_id,
Arc::clone(&watcher),
Expand Down Expand Up @@ -729,7 +731,9 @@ mod test {
.return_const(-1_i64);
let watcher = Arc::new(mock_watcher);
let next_id = Arc::new(WatchIdGenerator::new(1));
let n = task_manager.get_shutdown_listener(TaskName::WatchTask);
let n = task_manager
.get_shutdown_listener(TaskName::WatchTask)
.unwrap_or_else(|| unreachable!("task WatchTask should exist"));
let handle = tokio::spawn(WatchServer::task(
next_id,
Arc::clone(&watcher),
Expand Down

0 comments on commit 7418d9f

Please sign in to comment.