diff --git a/src/errors.rs b/src/errors.rs index 693ec98..54f1ff1 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -4,20 +4,21 @@ use std::sync::Arc; use miette::Diagnostic; use thiserror::Error; +use tokio::sync::mpsc; use crate::ErrTypeTraits; /// This enum contains all the possible errors that could be returned /// by [`handle_shutdown_requests()`](crate::Toplevel::handle_shutdown_requests). -#[derive(Error, Debug, Diagnostic)] +#[derive(Debug, Error, Diagnostic)] pub enum GracefulShutdownError { /// At least one subsystem caused an error. - #[error("at least one subsystem returned an error")] #[diagnostic(code(graceful_shutdown::failed))] + #[error("at least one subsystem returned an error")] SubsystemsFailed(#[related] Box<[SubsystemError]>), /// The shutdown did not finish within the given timeout. - #[error("shutdown timed out")] #[diagnostic(code(graceful_shutdown::timeout))] + #[error("shutdown timed out")] ShutdownTimeout(#[related] Box<[SubsystemError]>), } @@ -124,7 +125,34 @@ impl SubsystemError { /// [`cancel_on_shutdown()`](crate::FutureExt::cancel_on_shutdown). #[derive(Error, Debug, Diagnostic)] #[error("A shutdown request caused this task to be cancelled")] +#[diagnostic(code(graceful_shutdown::future::cancelled_by_shutdown))] pub struct CancelledByShutdown; +// This function contains code that stems from the principle +// of defensive coding - meaning, handle potential errors +// gracefully, even if they should not happen. +// Therefore it is in this special function, so we don't +// get coverage problems. +pub(crate) fn handle_dropped_error( + result: Result<(), mpsc::error::SendError>, +) { + if let Err(mpsc::error::SendError(e)) = result { + tracing::warn!("An error got dropped: {e:?}"); + } +} + +// This function contains code that stems from the principle +// of defensive coding - meaning, handle potential errors +// gracefully, even if they should not happen. +// Therefore it is in this special function, so we don't +// get coverage problems. +pub(crate) fn handle_unhandled_stopreason( + maybe_stop_reason: Option>, +) { + if let Some(stop_reason) = maybe_stop_reason { + tracing::warn!("Unhandled stop reason: {:?}", stop_reason); + } +} + #[cfg(test)] mod tests; diff --git a/src/errors/tests.rs b/src/errors/tests.rs index ac86457..79efd75 100644 --- a/src/errors/tests.rs +++ b/src/errors/tests.rs @@ -1,8 +1,18 @@ +use tracing_test::traced_test; + use crate::BoxedError; use super::*; -fn examine_report(report: miette::Report) { +fn examine_report( + error: impl miette::Diagnostic + std::error::Error + std::fmt::Debug + Sync + Send + 'static, +) { + println!("{}", error); + println!("{:?}", error); + println!("{:?}", error.source()); + println!("{}", error.code().unwrap()); + // Convert to report + let report: miette::Report = error.into(); println!("{}", report); println!("{:?}", report); // Convert to std::error::Error @@ -13,14 +23,21 @@ fn examine_report(report: miette::Report) { #[test] fn errors_can_be_converted_to_diagnostic() { - examine_report(GracefulShutdownError::ShutdownTimeout::(Box::new([])).into()); - examine_report(GracefulShutdownError::SubsystemsFailed::(Box::new([])).into()); - examine_report(SubsystemJoinError::SubsystemsFailed::(Arc::new([])).into()); - examine_report(SubsystemError::Panicked::("".into()).into()); - examine_report( - SubsystemError::Failed::("".into(), SubsystemFailure("".into())).into(), - ); - examine_report(CancelledByShutdown.into()); + examine_report(GracefulShutdownError::ShutdownTimeout::( + Box::new([]), + )); + examine_report(GracefulShutdownError::SubsystemsFailed::( + Box::new([]), + )); + examine_report(SubsystemJoinError::SubsystemsFailed::( + Arc::new([]), + )); + examine_report(SubsystemError::Panicked::("".into())); + examine_report(SubsystemError::Failed::( + "".into(), + SubsystemFailure("".into()), + )); + examine_report(CancelledByShutdown); } #[test] @@ -61,3 +78,23 @@ fn extract_contained_error_from_convert_subsystem_failure() { assert_eq!(msg, *failure); assert_eq!(msg, failure.into_error()); } + +#[test] +#[traced_test] +fn handle_dropped_errors() { + handle_dropped_error(Err(mpsc::error::SendError(BoxedError::from(String::from( + "ABC", + ))))); + + assert!(logs_contain("An error got dropped: \"ABC\"")); +} + +#[test] +#[traced_test] +fn handle_unhandled_stopreasons() { + handle_unhandled_stopreason(Some(SubsystemError::::Panicked(Arc::from( + "def", + )))); + + assert!(logs_contain("Unhandled stop reason: Panicked(\"def\")")); +} diff --git a/src/runner.rs b/src/runner.rs index 1354057..837465c 100644 --- a/src/runner.rs +++ b/src/runner.rs @@ -75,15 +75,10 @@ async fn run_subsystem( Ok(Ok(())) => None, Ok(Err(e)) => Some(SubsystemError::Failed(name, SubsystemFailure(e))), Err(e) => { - if e.is_panic() { - Some(SubsystemError::Panicked(name)) - } else { - // Don't do anything in case of a cancellation; - // cancellations can't be forwarded (because the - // current function we are in will be cancelled - // simultaneously) - None - } + // We can assume that this is a panic, because a cancellation + // can never happen as long as we still hold `guard`. + assert!(e.is_panic()); + Some(SubsystemError::Panicked(name)) } }; @@ -95,7 +90,10 @@ async fn run_subsystem( // It is still important that the handle does not leak out of the subsystem. let subsystem_handle = match redirected_subsystem_handle.try_recv() { Ok(s) => s, - Err(_) => panic!("The SubsystemHandle object must not be leaked out of the subsystem!"), + Err(_) => { + tracing::error!("The SubsystemHandle object must not be leaked out of the subsystem!"); + panic!("The SubsystemHandle object must not be leaked out of the subsystem!"); + } }; // Raise potential errors diff --git a/src/subsystem/subsystem_handle.rs b/src/subsystem/subsystem_handle.rs index 52a5f2f..0cc4c99 100644 --- a/src/subsystem/subsystem_handle.rs +++ b/src/subsystem/subsystem_handle.rs @@ -9,7 +9,7 @@ use tokio::sync::{mpsc, oneshot}; use tokio_util::sync::CancellationToken; use crate::{ - errors::SubsystemError, + errors::{handle_dropped_error, SubsystemError}, runner::{AliveGuard, SubsystemRunner}, utils::{remote_drop_collection::RemotelyDroppableItems, JoinerToken}, BoxedError, ErrTypeTraits, ErrorAction, NestedSubsystem, SubsystemBuilder, @@ -124,9 +124,7 @@ impl SubsystemHandle { match error_action { ErrorAction::Forward => Some(e), ErrorAction::CatchAndLocalShutdown => { - if let Err(mpsc::error::SendError(e)) = error_sender.send(e) { - tracing::warn!("An error got dropped: {e:?}"); - }; + handle_dropped_error(error_sender.send(e)); cancellation_token.cancel(); None } @@ -167,7 +165,7 @@ impl SubsystemHandle { } /// Waits until all the children of this subsystem are finished. - pub async fn wait_for_children(&mut self) { + pub async fn wait_for_children(&self) { self.inner.joiner_token.join_children().await } diff --git a/src/toplevel.rs b/src/toplevel.rs index bcd52a0..3476ae7 100644 --- a/src/toplevel.rs +++ b/src/toplevel.rs @@ -5,7 +5,7 @@ use tokio::sync::mpsc; use tokio_util::sync::CancellationToken; use crate::{ - errors::{GracefulShutdownError, SubsystemError}, + errors::{handle_dropped_error, GracefulShutdownError, SubsystemError}, signal_handling::wait_for_signal, subsystem::{self, ErrorActions}, BoxedError, ErrTypeTraits, ErrorAction, NestedSubsystem, SubsystemHandle, @@ -74,9 +74,7 @@ impl Toplevel { } }; - if let Err(mpsc::error::SendError(e)) = error_sender.send(e) { - tracing::warn!("An error got dropped: {e:?}"); - }; + handle_dropped_error(error_sender.send(e)); }); let toplevel_subsys = root_handle.start_with_abs_name( @@ -181,7 +179,12 @@ impl Toplevel { ); match tokio::time::timeout(shutdown_timeout, self.toplevel_subsys.join()).await { - Ok(Ok(())) => { + Ok(result) => { + // An `Err` here would indicate a programming error, + // because the toplevel subsys doesn't catch any errors; + // it only forwards them. + assert!(result.is_ok()); + let errors = collect_errors(); if errors.is_empty() { tracing::info!("Shutdown finished."); @@ -191,10 +194,6 @@ impl Toplevel { Err(GracefulShutdownError::SubsystemsFailed(errors)) } } - Ok(Err(_)) => { - // This can't happen because the toplevel subsys doesn't catch any errors; it only forwards them. - unreachable!(); - } Err(_) => { tracing::error!("Shutdown timed out!"); Err(GracefulShutdownError::ShutdownTimeout(collect_errors())) diff --git a/src/utils/joiner_token.rs b/src/utils/joiner_token.rs index 6b672ea..4973791 100644 --- a/src/utils/joiner_token.rs +++ b/src/utils/joiner_token.rs @@ -2,7 +2,10 @@ use std::{fmt::Debug, sync::Arc}; use tokio::sync::watch; -use crate::{errors::SubsystemError, ErrTypeTraits}; +use crate::{ + errors::{handle_unhandled_stopreason, SubsystemError}, + ErrTypeTraits, +}; struct Inner { counter: watch::Sender<(bool, u32)>, @@ -67,9 +70,7 @@ impl JoinerToken { (Self { inner }, weak_ref) } - // Requires `mut` access to prevent children from being spawned - // while waiting - pub(crate) async fn join_children(&mut self) { + pub(crate) async fn join_children(&self) { let mut subscriber = self.inner.counter.subscribe(); // Ignore errors; if the channel got closed, that definitely means @@ -126,9 +127,7 @@ impl JoinerToken { maybe_parent = parent.parent.as_ref(); } - if let Some(stop_reason) = maybe_stop_reason { - tracing::warn!("Unhandled stop reason: {:?}", stop_reason); - } + handle_unhandled_stopreason(maybe_stop_reason); } pub(crate) fn downgrade(self) -> JoinerTokenRef { diff --git a/src/utils/joiner_token/tests.rs b/src/utils/joiner_token/tests.rs index 03ba7c9..c63c3b8 100644 --- a/src/utils/joiner_token/tests.rs +++ b/src/utils/joiner_token/tests.rs @@ -116,7 +116,7 @@ fn counters_weak() { async fn join() { let (superroot, _) = JoinerToken::::new(|_| None); - let (mut root, _) = superroot.child_token(|_| None); + let (root, _) = superroot.child_token(|_| None); let (child1, _) = root.child_token(|_| None); let (child2, _) = child1.child_token(|_| None); diff --git a/src/utils/remote_drop_collection.rs b/src/utils/remote_drop_collection.rs index ad88612..c149496 100644 --- a/src/utils/remote_drop_collection.rs +++ b/src/utils/remote_drop_collection.rs @@ -62,18 +62,22 @@ impl Drop for RemoteDrop { // Important: lock first, then read the offset. let mut data = data.lock().unwrap(); - if let Some(offset) = self.offset.upgrade() { - let offset = offset.load(Ordering::Acquire); + let offset = self + .offset + .upgrade() + .expect("Trying to delete non-existent item! Please report this.") + .load(Ordering::Acquire); - if let Some(last_item) = data.pop() { - if offset != data.len() { - // There must have been at least two items, and we are not at the end. - // So swap first before dropping. + let last_item = data + .pop() + .expect("Trying to delete non-existent item! Please report this."); - last_item.offset.store(offset, Ordering::Release); - data[offset] = last_item; - } - } + if offset != data.len() { + // There must have been at least two items, and we are not at the end. + // So swap first before dropping. + + last_item.offset.store(offset, Ordering::Release); + data[offset] = last_item; } } } diff --git a/src/utils/remote_drop_collection/tests.rs b/src/utils/remote_drop_collection/tests.rs index 45a194b..ef770be 100644 --- a/src/utils/remote_drop_collection/tests.rs +++ b/src/utils/remote_drop_collection/tests.rs @@ -1,6 +1,20 @@ use super::*; use crate::{utils::JoinerToken, BoxedError}; +#[test] +fn single_item() { + let items = RemotelyDroppableItems::new(); + + let (count1, _) = JoinerToken::::new(|_| None); + assert_eq!(0, count1.count()); + + let token1 = items.insert(count1.child_token(|_| None)); + assert_eq!(1, count1.count()); + + drop(token1); + assert_eq!(0, count1.count()); +} + #[test] fn insert_and_drop() { let items = RemotelyDroppableItems::new(); diff --git a/tests/integration_test_2.rs b/tests/integration_test_2.rs new file mode 100644 index 0000000..28eff68 --- /dev/null +++ b/tests/integration_test_2.rs @@ -0,0 +1,163 @@ +use tokio::time::{sleep, Duration}; +use tokio_graceful_shutdown::{SubsystemBuilder, SubsystemHandle, Toplevel}; +use tracing_test::traced_test; + +pub mod common; + +use std::{ + error::Error, + sync::{Arc, Mutex}, +}; + +use crate::common::Event; + +/// Wrapper function to simplify lambdas +type BoxedError = Box; +type BoxedResult = Result<(), BoxedError>; + +#[tokio::test] +#[traced_test] +async fn leak_subsystem_handle() { + let subsys_ext: Arc>> = Default::default(); + let subsys_ext2 = Arc::clone(&subsys_ext); + + let subsystem = move |subsys: SubsystemHandle| async move { + subsys.on_shutdown_requested().await; + + *subsys_ext2.lock().unwrap() = Some(subsys); + + BoxedResult::Ok(()) + }; + + let toplevel = Toplevel::new(move |s| async move { + s.start(SubsystemBuilder::new("subsys", subsystem)); + + sleep(Duration::from_millis(100)).await; + s.request_shutdown(); + }); + + let result = toplevel + .handle_shutdown_requests(Duration::from_millis(100)) + .await; + assert!(result.is_err()); + assert!(logs_contain( + "The SubsystemHandle object must not be leaked out of the subsystem!" + )); +} + +#[tokio::test] +#[traced_test] +async fn wait_for_children() { + let (nested1_started, set_nested1_started) = Event::create(); + let (nested1_finished, set_nested1_finished) = Event::create(); + let (nested2_started, set_nested2_started) = Event::create(); + let (nested2_finished, set_nested2_finished) = Event::create(); + + let nested_subsys2 = move |subsys: SubsystemHandle| async move { + set_nested2_started(); + subsys.on_shutdown_requested().await; + sleep(Duration::from_millis(100)).await; + set_nested2_finished(); + BoxedResult::Ok(()) + }; + + let nested_subsys1 = move |subsys: SubsystemHandle| async move { + subsys.start(SubsystemBuilder::new("nested2", nested_subsys2)); + set_nested1_started(); + subsys.on_shutdown_requested().await; + sleep(Duration::from_millis(100)).await; + set_nested1_finished(); + BoxedResult::Ok(()) + }; + + let subsys1 = move |subsys: SubsystemHandle| async move { + subsys.start(SubsystemBuilder::new("nested1", nested_subsys1)); + + sleep(Duration::from_millis(100)).await; + + subsys.request_shutdown(); + + assert!(nested1_started.get()); + assert!(!nested1_finished.get()); + assert!(nested2_started.get()); + assert!(!nested2_finished.get()); + + subsys.wait_for_children().await; + + assert!(nested1_finished.get()); + assert!(nested2_finished.get()); + + BoxedResult::Ok(()) + }; + + Toplevel::new(|s| async move { + s.start(SubsystemBuilder::new("subsys", subsys1)); + }) + .handle_shutdown_requests(Duration::from_millis(500)) + .await + .unwrap(); +} + +#[tokio::test] +#[traced_test] +async fn request_local_shutdown() { + let (nested1_started, set_nested1_started) = Event::create(); + let (nested1_finished, set_nested1_finished) = Event::create(); + let (nested2_started, set_nested2_started) = Event::create(); + let (nested2_finished, set_nested2_finished) = Event::create(); + let (global_finished, set_global_finished) = Event::create(); + + let nested_subsys2 = move |subsys: SubsystemHandle| async move { + set_nested2_started(); + subsys.on_shutdown_requested().await; + set_nested2_finished(); + BoxedResult::Ok(()) + }; + + let nested_subsys1 = move |subsys: SubsystemHandle| async move { + subsys.start(SubsystemBuilder::new("nested2", nested_subsys2)); + set_nested1_started(); + subsys.on_shutdown_requested().await; + set_nested1_finished(); + BoxedResult::Ok(()) + }; + + let subsys1 = move |subsys: SubsystemHandle| async move { + subsys.start(SubsystemBuilder::new("nested1", nested_subsys1)); + + sleep(Duration::from_millis(100)).await; + + assert!(nested1_started.get()); + assert!(!nested1_finished.get()); + assert!(nested2_started.get()); + assert!(!nested2_finished.get()); + assert!(!global_finished.get()); + assert!(!subsys.is_shutdown_requested()); + + subsys.request_local_shutdown(); + sleep(Duration::from_millis(200)).await; + + assert!(nested1_finished.get()); + assert!(nested2_finished.get()); + assert!(!global_finished.get()); + assert!(subsys.is_shutdown_requested()); + + subsys.request_shutdown(); + sleep(Duration::from_millis(50)).await; + + assert!(global_finished.get()); + assert!(subsys.is_shutdown_requested()); + + BoxedResult::Ok(()) + }; + + Toplevel::new(move |s| async move { + s.start(SubsystemBuilder::new("subsys", subsys1)); + + s.on_shutdown_requested().await; + set_global_finished(); + }) + .handle_shutdown_requests(Duration::from_millis(100)) + .await + .unwrap(); +}