From 13b92804de6b3374d2aad4465730a8422cb312a7 Mon Sep 17 00:00:00 2001 From: Finomnis Date: Wed, 7 Feb 2024 11:04:15 +0100 Subject: [PATCH] Add `.finished()` to `NestedSubsystem`, add sequential shutdown example --- examples/19_sequential_shutdown.rs | 126 +++++++++++++++++++++ src/future_ext.rs | 2 +- src/lib.rs | 1 + src/subsystem/mod.rs | 13 ++- src/subsystem/nested_subsystem.rs | 14 ++- src/subsystem/subsystem_finished_future.rs | 25 ++++ src/utils/joiner_token.rs | 1 + tests/integration_test_2.rs | 39 ++++++- 8 files changed, 215 insertions(+), 6 deletions(-) create mode 100644 examples/19_sequential_shutdown.rs create mode 100644 src/subsystem/subsystem_finished_future.rs diff --git a/examples/19_sequential_shutdown.rs b/examples/19_sequential_shutdown.rs new file mode 100644 index 0000000..ca0da8d --- /dev/null +++ b/examples/19_sequential_shutdown.rs @@ -0,0 +1,126 @@ +//! This example demonstrates how multiple subsystems could be shut down sequentially. +//! +//! When a shutdown gets triggered (via Ctrl+C), Nested1 will shutdown first, +//! followed by Nested2 and Nested3. Only once the previous subsystem is finished shutting down, +//! the next subsystem will follow. + +use miette::Result; +use tokio::time::{sleep, Duration}; +use tokio_graceful_shutdown::{ + FutureExt, SubsystemBuilder, SubsystemFinishedFuture, SubsystemHandle, Toplevel, +}; + +async fn counter(id: &str) { + let mut i = 0; + loop { + tracing::info!("{id}: {i}"); + i += 1; + sleep(Duration::from_millis(50)).await; + } +} + +async fn nested1(subsys: SubsystemHandle) -> Result<()> { + tracing::info!("Nested1 started."); + if counter("Nested1").cancel_on_shutdown(&subsys).await.is_ok() { + tracing::info!("Nested1 counter finished."); + } else { + tracing::info!("Nested1 shutting down ..."); + sleep(Duration::from_millis(200)).await; + } + subsys.on_shutdown_requested().await; + tracing::info!("Nested1 stopped."); + Ok(()) +} + +async fn nested2(subsys: SubsystemHandle, nested1_finished: SubsystemFinishedFuture) -> Result<()> { + // Create a future that triggers once nested1 is finished **and** a shutdown is requested + let shutdown = { + let shutdown_requested = subsys.on_shutdown_requested(); + async move { + tokio::join!(shutdown_requested, nested1_finished); + } + }; + + tracing::info!("Nested2 started."); + tokio::select! { + _ = shutdown => { + tracing::info!("Nested2 shutting down ..."); + sleep(Duration::from_millis(200)).await; + } + _ = counter("Nested2") => { + tracing::info!("Nested2 counter finished."); + } + } + + tracing::info!("Nested2 stopped."); + Ok(()) +} + +async fn nested3(subsys: SubsystemHandle, nested2_finished: SubsystemFinishedFuture) -> Result<()> { + // Create a future that triggers once nested2 is finished **and** a shutdown is requested + let shutdown = { + // This is an alternative to `on_shutdown_requested()` (as shown in nested2). + // Use this if `on_shutdown_requested()` gives you lifetime issues. + let cancellation_token = subsys.create_cancellation_token(); + async move { + tokio::join!(cancellation_token.cancelled(), nested2_finished); + } + }; + + tracing::info!("Nested3 started."); + tokio::select! { + _ = shutdown => { + tracing::info!("Nested3 shutting down ..."); + sleep(Duration::from_millis(200)).await; + } + _ = counter("Nested3") => { + tracing::info!("Nested3 counter finished."); + } + } + + tracing::info!("Nested3 stopped."); + Ok(()) +} + +async fn root(subsys: SubsystemHandle) -> Result<()> { + // This subsystem shuts down the nested subsystem after 5 seconds. + tracing::info!("Root started."); + + tracing::info!("Starting nested subsystems ..."); + let nested1 = subsys.start(SubsystemBuilder::new("Nested1", nested1)); + let nested1_finished = nested1.finished(); + let nested2 = subsys.start(SubsystemBuilder::new("Nested2", |s| { + nested2(s, nested1_finished) + })); + let nested2_finished = nested2.finished(); + subsys.start(SubsystemBuilder::new("Nested3", |s| { + nested3(s, nested2_finished) + })); + tracing::info!("Nested subsystems started."); + + // Wait for all children to finish shutting down. + subsys.wait_for_children().await; + + tracing::info!("All children finished, stopping Root ..."); + sleep(Duration::from_millis(200)).await; + tracing::info!("Root stopped."); + + Ok(()) +} + +#[tokio::main(flavor = "current_thread")] +async fn main() -> Result<()> { + // Init logging + tracing_subscriber::fmt() + .with_max_level(tracing::Level::TRACE) + .init(); + + // Setup and execute subsystem tree + Toplevel::new(|s| async move { + s.start(SubsystemBuilder::new("Root", root)); + }) + .catch_signals() + .handle_shutdown_requests(Duration::from_millis(1000)) + .await + .map_err(Into::into) +} diff --git a/src/future_ext.rs b/src/future_ext.rs index f3c2ee7..310d4e1 100644 --- a/src/future_ext.rs +++ b/src/future_ext.rs @@ -5,7 +5,7 @@ use pin_project_lite::pin_project; use tokio_util::sync::WaitForCancellationFuture; pin_project! { - /// A Future that is resolved once the corresponding task is finished + /// A future that is resolved once the corresponding task is finished /// or a shutdown is initiated. #[must_use = "futures do nothing unless polled"] pub struct CancelOnShutdownFuture<'a, T: std::future::Future>{ diff --git a/src/lib.rs b/src/lib.rs index e1d4a89..9ecc147 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -121,5 +121,6 @@ pub use future_ext::FutureExt; pub use into_subsystem::IntoSubsystem; pub use subsystem::NestedSubsystem; pub use subsystem::SubsystemBuilder; +pub use subsystem::SubsystemFinishedFuture; pub use subsystem::SubsystemHandle; pub use toplevel::Toplevel; diff --git a/src/subsystem/mod.rs b/src/subsystem/mod.rs index 7073c1e..086637a 100644 --- a/src/subsystem/mod.rs +++ b/src/subsystem/mod.rs @@ -1,9 +1,14 @@ mod error_collector; mod nested_subsystem; mod subsystem_builder; +mod subsystem_finished_future; mod subsystem_handle; -use std::sync::{Arc, Mutex}; +use std::{ + future::Future, + pin::Pin, + sync::{Arc, Mutex}, +}; pub use subsystem_builder::SubsystemBuilder; pub use subsystem_handle::SubsystemHandle; @@ -35,3 +40,9 @@ pub(crate) struct ErrorActions { pub(crate) on_failure: Atomic, pub(crate) on_panic: Atomic, } + +/// A future that is resolved once the corresponding subsystem is finished. +#[must_use = "futures do nothing unless polled"] +pub struct SubsystemFinishedFuture { + future: Pin + Send + Sync>>, +} diff --git a/src/subsystem/nested_subsystem.rs b/src/subsystem/nested_subsystem.rs index cfc5378..eaf5fc9 100644 --- a/src/subsystem/nested_subsystem.rs +++ b/src/subsystem/nested_subsystem.rs @@ -2,7 +2,7 @@ use std::sync::atomic::Ordering; use crate::{errors::SubsystemJoinError, ErrTypeTraits, ErrorAction}; -use super::NestedSubsystem; +use super::{NestedSubsystem, SubsystemFinishedFuture}; impl NestedSubsystem { /// Wait for the subsystem to be finished. @@ -68,7 +68,7 @@ impl NestedSubsystem { /// Changes the way this subsystem should react to failures, /// meaning if it or one of its children returns an `Err` value. /// - /// For more information, see [ErrorAction]. + /// For more information, see [`ErrorAction`]. pub fn change_failure_action(&self, action: ErrorAction) { self.error_actions .on_failure @@ -78,8 +78,16 @@ impl NestedSubsystem { /// Changes the way this subsystem should react if it or one /// of its children panic. /// - /// For more information, see [ErrorAction]. + /// For more information, see [`ErrorAction`]. pub fn change_panic_action(&self, action: ErrorAction) { self.error_actions.on_panic.store(action, Ordering::Relaxed); } + + /// Returns a future that resolves once the subsystem is finished. + /// + /// Similar to [`join`](NestedSubsystem::join), but more light-weight + /// as does not perform any error handling. + pub fn finished(&self) -> SubsystemFinishedFuture { + SubsystemFinishedFuture::new(self.joiner.clone()) + } } diff --git a/src/subsystem/subsystem_finished_future.rs b/src/subsystem/subsystem_finished_future.rs new file mode 100644 index 0000000..4008395 --- /dev/null +++ b/src/subsystem/subsystem_finished_future.rs @@ -0,0 +1,25 @@ +use std::{ + future::Future, + pin::Pin, + task::{Context, Poll}, +}; + +use crate::utils::JoinerTokenRef; + +use super::SubsystemFinishedFuture; + +impl SubsystemFinishedFuture { + pub(crate) fn new(joiner: JoinerTokenRef) -> Self { + Self { + future: Box::pin(async move { joiner.join().await }), + } + } +} + +impl Future for SubsystemFinishedFuture { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + self.future.as_mut().poll(cx) + } +} diff --git a/src/utils/joiner_token.rs b/src/utils/joiner_token.rs index 4973791..3561d5d 100644 --- a/src/utils/joiner_token.rs +++ b/src/utils/joiner_token.rs @@ -20,6 +20,7 @@ pub(crate) struct JoinerToken { /// A reference version that does not keep the content alive; purely for /// joining the subtree. +#[derive(Clone)] pub(crate) struct JoinerTokenRef { counter: watch::Receiver<(bool, u32)>, } diff --git a/tests/integration_test_2.rs b/tests/integration_test_2.rs index 4d67c98..47b0802 100644 --- a/tests/integration_test_2.rs +++ b/tests/integration_test_2.rs @@ -6,7 +6,10 @@ pub mod common; use std::{ error::Error, - sync::{Arc, Mutex}, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, Mutex, + }, }; use crate::common::Event; @@ -247,3 +250,37 @@ async fn cancellation_token_does_not_propagate_up() { .await; assert!(result.is_ok()); } + +#[tokio::test] +#[traced_test] +async fn subsystem_finished_works_correctly() { + let subsystem = |subsys: SubsystemHandle| async move { + subsys.on_shutdown_requested().await; + BoxedResult::Ok(()) + }; + + let toplevel = Toplevel::new(move |s| async move { + let nested = s.start(SubsystemBuilder::new("subsys", subsystem)); + let nested_finished = nested.finished(); + + let is_finished = AtomicBool::new(false); + tokio::join!( + async { + nested_finished.await; + is_finished.store(true, Ordering::Release); + }, + async { + sleep(Duration::from_millis(20)).await; + assert!(!is_finished.load(Ordering::Acquire)); + nested.initiate_shutdown(); + sleep(Duration::from_millis(20)).await; + assert!(is_finished.load(Ordering::Acquire)); + } + ); + }); + + let result = toplevel + .handle_shutdown_requests(Duration::from_millis(400)) + .await; + assert!(result.is_ok()); +}