diff --git a/crates/msr-core/src/realtime/worker/thread/mod.rs b/crates/msr-core/src/realtime/worker/thread/mod.rs index d89f67b..7a5a7dc 100644 --- a/crates/msr-core/src/realtime/worker/thread/mod.rs +++ b/crates/msr-core/src/realtime/worker/thread/mod.rs @@ -1,5 +1,9 @@ use std::{ any::Any, + sync::{ + atomic::{AtomicU8, Ordering}, + Arc, + }, thread::{self, JoinHandle}, }; @@ -14,12 +18,9 @@ pub enum State { #[default] Unknown, Starting, - Started, Running, - Suspended, - Resumed, + Suspending, Finishing, - Finished, Terminating, } @@ -35,17 +36,6 @@ impl State { } } -/// Emitted event -#[derive(Debug)] -pub enum Event { - /// The new state - StateChanged(State), -} - -pub trait EmitEvent { - fn emit_event(&mut self, event: Event); -} - /// Spawn parameters /// /// The parameters are passed into the worker thread when spawned @@ -54,16 +44,27 @@ pub trait EmitEvent { /// If joining the work thread fails these parameters will be lost /// inevitably! #[allow(missing_debug_implementations)] -pub struct Context { +pub struct Context { pub progress_hint_rx: ProgressHintReceiver, pub worker: W, pub environment: ::Environment, - pub emit_event: E, } #[derive(Debug)] -pub struct WorkerThread { - join_handle: JoinHandle>, +pub struct WorkerThread { + shared_state: Arc, + shared_state_load_ordering: Ordering, + join_handle: JoinHandle>, +} + +impl WorkerThread +where + W: Worker, +{ + #[must_use] + pub fn state(&self) -> State { + State::from_u8(self.shared_state.load(self.shared_state_load_ordering)).unwrap() + } } struct ThreadSchedulingScope { @@ -264,25 +265,25 @@ impl Drop for ThreadSchedulingScope { } } -fn thread_fn(thread_scheduling: ThreadScheduling, context: &mut Context) -> Result<()> +fn thread_fn( + context: &mut Context, + thread_scheduling: ThreadScheduling, + shared_state: Arc, + shared_state_store_ordering: Ordering, +) -> Result<()> where W: Worker, - E: EmitEvent, { let Context { progress_hint_rx, worker, environment, - emit_event, } = context; log::debug!("Starting"); - emit_event.emit_event(Event::StateChanged(State::Starting)); - + shared_state.store(State::Starting.to_u8(), shared_state_store_ordering); worker.start_working(environment)?; - log::debug!("Started"); - emit_event.emit_event(Event::StateChanged(State::Started)); let scheduling_scope = match thread_scheduling { ThreadScheduling::Default => None, @@ -291,7 +292,7 @@ where }; loop { log::debug!("Running"); - emit_event.emit_event(Event::StateChanged(State::Running)); + shared_state.store(State::Running.to_u8(), shared_state_store_ordering); match worker.perform_work(environment, progress_hint_rx)? { CompletionStatus::Suspending => { @@ -303,13 +304,10 @@ where continue; } - log::debug!("Suspended"); - emit_event.emit_event(Event::StateChanged(State::Suspended)); - + log::debug!("Suspending"); + shared_state.store(State::Suspending.to_u8(), shared_state_store_ordering); progress_hint_rx.wait_while_suspending(); - - log::debug!("Resumed"); - emit_event.emit_event(Event::StateChanged(State::Resumed)); + log::debug!("Resuming"); } CompletionStatus::Finishing => { // The worker may have decided to finish itself independent @@ -328,33 +326,30 @@ where } log::debug!("Finishing"); - emit_event.emit_event(Event::StateChanged(State::Finishing)); - + shared_state.store(State::Finishing.to_u8(), shared_state_store_ordering); worker.finish_working(environment)?; - log::debug!("Finished"); - emit_event.emit_event(Event::StateChanged(State::Finished)); log::debug!("Terminating"); - emit_event.emit_event(Event::StateChanged(State::Terminating)); + shared_state.store(State::Terminating.to_u8(), shared_state_store_ordering); Ok(()) } /// Outcome of [`WorkerThread::join()`] #[allow(missing_debug_implementations)] -pub struct TerminatedThread { +pub struct TerminatedThread { /// The result of the thread function pub result: Result<()>, /// The recovered parameters - pub context: Context, + pub context: Context, } /// Outcome of [`WorkerThread::join()`] #[allow(missing_debug_implementations)] -pub enum JoinedThread { - Terminated(TerminatedThread), +pub enum JoinedThread { + Terminated(TerminatedThread), JoinError(Box), } @@ -382,33 +377,68 @@ pub enum ThreadScheduling { RealtimeOrDefault, } -impl WorkerThread +#[derive(Debug, Clone, Copy)] +pub enum AtomicStateOrdering { + Relaxed, + AcquireRelease, +} + +impl WorkerThread where W: Worker + Send + 'static, ::Environment: Send + 'static, - E: EmitEvent + Send + 'static, { - pub fn spawn(thread_scheduling: ThreadScheduling, context: Context) -> Self { + pub fn spawn( + context: Context, + thread_scheduling: ThreadScheduling, + atomic_state_ordering: AtomicStateOrdering, + ) -> Self { + let (shared_state_load_ordering, shared_state_store_ordering) = match atomic_state_ordering + { + AtomicStateOrdering::Relaxed => (Ordering::Relaxed, Ordering::Relaxed), + AtomicStateOrdering::AcquireRelease => (Ordering::Acquire, Ordering::Release), + }; + let shared_state = Arc::new(AtomicU8::new(State::Unknown.to_u8())); let join_handle = { + let shared_state = Arc::clone(&shared_state); std::thread::spawn({ move || { // The function parameters need to be mutable within the real-time thread let mut context = context; - let result = thread_fn(thread_scheduling, &mut context); + let result = thread_fn( + &mut context, + thread_scheduling, + shared_state, + shared_state_store_ordering, + ); let context = context; TerminatedThread { result, context } } }) }; - Self { join_handle } + Self { + shared_state, + shared_state_load_ordering, + join_handle, + } } - pub fn join(self) -> JoinedThread { - let Self { join_handle } = self; - join_handle + pub fn join(self) -> JoinedThread { + let Self { + join_handle, + shared_state, + shared_state_load_ordering, + } = self; + log::debug!("Joining thread"); + let joined_thread = join_handle .join() .map(JoinedThread::Terminated) - .unwrap_or_else(JoinedThread::JoinError) + .unwrap_or_else(JoinedThread::JoinError); + debug_assert!( + shared_state_load_ordering == Ordering::Relaxed + || shared_state.load(shared_state_load_ordering) == State::Terminating.to_u8() + ); + joined_thread } } diff --git a/crates/msr-core/src/realtime/worker/thread/tests.rs b/crates/msr-core/src/realtime/worker/thread/tests.rs index 2436364..be9758d 100644 --- a/crates/msr-core/src/realtime/worker/thread/tests.rs +++ b/crates/msr-core/src/realtime/worker/thread/tests.rs @@ -1,5 +1,3 @@ -use std::sync::atomic::{AtomicUsize, Ordering}; - use crate::realtime::worker::progress::{ProgressHint, ProgressHintSender, SwitchProgressHintOk}; use super::*; @@ -57,107 +55,42 @@ impl Worker for SmokeTestWorker { } } -#[derive(Default)] -struct StateChangedCount { - starting: AtomicUsize, - started: AtomicUsize, - running: AtomicUsize, - suspended: AtomicUsize, - resumed: AtomicUsize, - finishing: AtomicUsize, - finished: AtomicUsize, - terminating: AtomicUsize, -} - -struct SmokeTestEvents { - progress_hint_tx: ProgressHintSender, - state_changed_count: StateChangedCount, -} - -impl SmokeTestEvents { - fn new(progress_hint_tx: ProgressHintSender) -> Self { - Self { - progress_hint_tx, - state_changed_count: Default::default(), - } - } - - fn on_event(&self, event: Event) { - match event { - Event::StateChanged(state) => match state { - State::Unknown => unreachable!(), - State::Starting => { - self.state_changed_count - .starting - .fetch_add(1, Ordering::SeqCst); - } - State::Started => { - self.state_changed_count - .started - .fetch_add(1, Ordering::SeqCst); - } - State::Running => { - self.state_changed_count - .running - .fetch_add(1, Ordering::SeqCst); - } - State::Suspended => { - self.state_changed_count - .suspended - .fetch_add(1, Ordering::SeqCst); - assert_eq!( - SwitchProgressHintOk::Accepted { - previous_state: ProgressHint::Suspend, - }, - self.progress_hint_tx.resume().expect("resuming") - ); - } - State::Resumed => { - self.state_changed_count - .resumed - .fetch_add(1, Ordering::SeqCst); - } - State::Finishing => { - self.state_changed_count - .finishing - .fetch_add(1, Ordering::SeqCst); - } - State::Finished => { - self.state_changed_count - .finished - .fetch_add(1, Ordering::SeqCst); - } - State::Terminating => { - self.state_changed_count - .terminating - .fetch_add(1, Ordering::SeqCst); - } - }, - } - } -} - -impl EmitEvent for SmokeTestEvents { - fn emit_event(&mut self, event: Event) { - self.on_event(event); - } -} - #[test] fn smoke_test() -> anyhow::Result<()> { for expected_perform_work_invocations in 1..10 { let worker = SmokeTestWorker::new(expected_perform_work_invocations); let progress_hint_rx = ProgressHintReceiver::default(); - let event_handler = SmokeTestEvents::new(ProgressHintSender::attach(&progress_hint_rx)); + let progress_hint_tx = ProgressHintSender::attach(&progress_hint_rx); let context = Context { progress_hint_rx, worker, environment: SmokeTestEnvironment, - emit_event: event_handler, }; // Real-time thread scheduling might not be supported when running the tests // in containers on CI platforms. - let worker_thread = WorkerThread::spawn(ThreadScheduling::Default, context); + let worker_thread = WorkerThread::spawn( + context, + ThreadScheduling::Default, + AtomicStateOrdering::AcquireRelease, + ); + let mut resume_accepted = 0; + loop { + match worker_thread.state() { + State::Starting | State::Finishing | State::Running | State::Unknown => (), + State::Suspending => match progress_hint_tx.resume() { + Ok(SwitchProgressHintOk::Accepted { .. }) => { + resume_accepted += 1; + } + // The worker thread might already have terminated itself, which in turn + // detaches our `ProgressHintSender`. + Ok(SwitchProgressHintOk::Ignored) | Err(_) => (), + }, + State::Terminating => { + // Exit loop + break; + } + } + } match worker_thread.join() { JoinedThread::Terminated(TerminatedThread { context: @@ -165,7 +98,6 @@ fn smoke_test() -> anyhow::Result<()> { progress_hint_rx: _, worker, environment: _, - emit_event: event_handler, }, result, }) => { @@ -176,62 +108,7 @@ fn smoke_test() -> anyhow::Result<()> { expected_perform_work_invocations, worker.actual_perform_work_invocations ); - assert_eq!( - 1, - event_handler - .state_changed_count - .starting - .load(Ordering::SeqCst) - ); - assert_eq!( - 1, - event_handler - .state_changed_count - .started - .load(Ordering::SeqCst) - ); - assert_eq!( - expected_perform_work_invocations, - event_handler - .state_changed_count - .running - .load(Ordering::SeqCst) - ); - assert_eq!( - expected_perform_work_invocations - 1, - event_handler - .state_changed_count - .suspended - .load(Ordering::SeqCst) - ); - assert_eq!( - expected_perform_work_invocations - 1, - event_handler - .state_changed_count - .resumed - .load(Ordering::SeqCst) - ); - assert_eq!( - 1, - event_handler - .state_changed_count - .finishing - .load(Ordering::SeqCst) - ); - assert_eq!( - 1, - event_handler - .state_changed_count - .finished - .load(Ordering::SeqCst) - ); - assert_eq!( - 1, - event_handler - .state_changed_count - .terminating - .load(Ordering::SeqCst) - ); + assert_eq!(expected_perform_work_invocations, resume_accepted + 1,); } JoinedThread::JoinError(err) => { return Err(anyhow::anyhow!("Failed to join worker thread: {:?}", err)) diff --git a/crates/msr-core/tests/cyclic_realtime_worker_timing.rs b/crates/msr-core/tests/cyclic_realtime_worker_timing.rs index f81e183..9d196b9 100644 --- a/crates/msr-core/tests/cyclic_realtime_worker_timing.rs +++ b/crates/msr-core/tests/cyclic_realtime_worker_timing.rs @@ -1,20 +1,13 @@ use std::{ fmt, - ops::Deref, - sync::{ - atomic::{AtomicU8, Ordering}, - mpsc, Arc, - }, + sync::mpsc, time::{Duration, Instant}, }; use msr_core::{ realtime::worker::{ progress::{ProgressHint, ProgressHintReceiver, ProgressHintSender}, - thread::{ - Context, EmitEvent, Event, JoinedThread, State, TerminatedThread, ThreadScheduling, - WorkerThread, - }, + thread::{Context, JoinedThread, State, TerminatedThread, ThreadScheduling, WorkerThread}, CompletionStatus, Worker, }, thread, @@ -23,53 +16,6 @@ use msr_core::{ #[derive(Default)] struct CyclicWorkerEnvironment; -struct CyclicWorkerEvents { - state: AtomicU8, -} - -impl CyclicWorkerEvents { - pub const fn new() -> Self { - Self { - state: AtomicU8::new(State::Unknown.to_u8()), - } - } - - pub fn last_state(&self) -> State { - State::from_u8(self.state.load(Ordering::Acquire)).expect("valid value") - } - - fn on_event(&self, event: Event) { - match event { - Event::StateChanged(state) => { - self.state.store(state.to_u8(), Ordering::Release); - } - } - } -} - -impl Default for CyclicWorkerEvents { - fn default() -> Self { - Self::new() - } -} - -#[derive(Clone, Default)] -struct SharedCyclicWorkerEvents(Arc); - -impl EmitEvent for SharedCyclicWorkerEvents { - fn emit_event(&mut self, event: Event) { - self.on_event(event); - } -} - -impl Deref for SharedCyclicWorkerEvents { - type Target = CyclicWorkerEvents; - - fn deref(&self) -> &Self::Target { - &self.0 - } -} - #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum CyclicWorkerTiming { /// Avoid being blocked by lower priority threads due to priority @@ -263,33 +209,29 @@ fn run_cyclic_worker(params: CyclicWorkerParams) -> anyhow::Result { + match worker_thread.state() { + State::Unknown | State::Starting | State::Running | State::Finishing => { // These (intermediate) states might not be visible when reading // the last state at arbitrary times from an atomic and cannot // be used for controlling the control flow of the test! } - State::Suspended => { + State::Suspending => { assert!(resumed_count <= suspended_count); if resumed_count < suspended_count { progress_hint_tx.resume().expect("resumed"); @@ -349,7 +291,6 @@ fn run_cyclic_worker(params: CyclicWorkerParams) -> anyhow::Result result