Skip to content

Commit

Permalink
RT Worker: Replace event callback with shared state
Browse files Browse the repository at this point in the history
  • Loading branch information
uklotzde committed Oct 7, 2022
1 parent bad4098 commit 67844fc
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 268 deletions.
132 changes: 81 additions & 51 deletions crates/msr-core/src/realtime/worker/thread/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
use std::{
any::Any,
sync::{
atomic::{AtomicU8, Ordering},
Arc,
},
thread::{self, JoinHandle},
};

Expand All @@ -14,12 +18,9 @@ pub enum State {
#[default]
Unknown,
Starting,
Started,
Running,
Suspended,
Resumed,
Suspending,
Finishing,
Finished,
Terminating,
}

Expand All @@ -35,17 +36,6 @@ impl State {
}
}

/// Emitted event
#[derive(Debug)]
pub enum Event {
/// The new state
StateChanged(State),
}

pub trait EmitEvent {
fn emit_event(&mut self, event: Event);
}

/// Spawn parameters
///
/// The parameters are passed into the worker thread when spawned
Expand All @@ -54,16 +44,27 @@ pub trait EmitEvent {
/// If joining the work thread fails these parameters will be lost
/// inevitably!
#[allow(missing_debug_implementations)]
pub struct Context<W: Worker, E: EmitEvent> {
pub struct Context<W: Worker> {
pub progress_hint_rx: ProgressHintReceiver,
pub worker: W,
pub environment: <W as Worker>::Environment,
pub emit_event: E,
}

#[derive(Debug)]
pub struct WorkerThread<W: Worker, E: EmitEvent> {
join_handle: JoinHandle<TerminatedThread<W, E>>,
pub struct WorkerThread<W: Worker> {
shared_state: Arc<AtomicU8>,
shared_state_load_ordering: Ordering,
join_handle: JoinHandle<TerminatedThread<W>>,
}

impl<W> WorkerThread<W>
where
W: Worker,
{
#[must_use]
pub fn state(&self) -> State {
State::from_u8(self.shared_state.load(self.shared_state_load_ordering)).unwrap()
}
}

struct ThreadSchedulingScope {
Expand Down Expand Up @@ -264,25 +265,25 @@ impl Drop for ThreadSchedulingScope {
}
}

fn thread_fn<W, E>(thread_scheduling: ThreadScheduling, context: &mut Context<W, E>) -> Result<()>
fn thread_fn<W>(
context: &mut Context<W>,
thread_scheduling: ThreadScheduling,
shared_state: Arc<AtomicU8>,
shared_state_store_ordering: Ordering,
) -> Result<()>
where
W: Worker,
E: EmitEvent,
{
let Context {
progress_hint_rx,
worker,
environment,
emit_event,
} = context;

log::debug!("Starting");
emit_event.emit_event(Event::StateChanged(State::Starting));

shared_state.store(State::Starting.to_u8(), shared_state_store_ordering);
worker.start_working(environment)?;

log::debug!("Started");
emit_event.emit_event(Event::StateChanged(State::Started));

let scheduling_scope = match thread_scheduling {
ThreadScheduling::Default => None,
Expand All @@ -291,7 +292,7 @@ where
};
loop {
log::debug!("Running");
emit_event.emit_event(Event::StateChanged(State::Running));
shared_state.store(State::Running.to_u8(), shared_state_store_ordering);

match worker.perform_work(environment, progress_hint_rx)? {
CompletionStatus::Suspending => {
Expand All @@ -303,13 +304,10 @@ where
continue;
}

log::debug!("Suspended");
emit_event.emit_event(Event::StateChanged(State::Suspended));

log::debug!("Suspending");
shared_state.store(State::Suspending.to_u8(), shared_state_store_ordering);
progress_hint_rx.wait_while_suspending();

log::debug!("Resumed");
emit_event.emit_event(Event::StateChanged(State::Resumed));
log::debug!("Resuming");
}
CompletionStatus::Finishing => {
// The worker may have decided to finish itself independent
Expand All @@ -328,33 +326,30 @@ where
}

log::debug!("Finishing");
emit_event.emit_event(Event::StateChanged(State::Finishing));

shared_state.store(State::Finishing.to_u8(), shared_state_store_ordering);
worker.finish_working(environment)?;

log::debug!("Finished");
emit_event.emit_event(Event::StateChanged(State::Finished));

log::debug!("Terminating");
emit_event.emit_event(Event::StateChanged(State::Terminating));
shared_state.store(State::Terminating.to_u8(), shared_state_store_ordering);

Ok(())
}

/// Outcome of [`WorkerThread::join()`]
#[allow(missing_debug_implementations)]
pub struct TerminatedThread<W: Worker, E: EmitEvent> {
pub struct TerminatedThread<W: Worker> {
/// The result of the thread function
pub result: Result<()>,

/// The recovered parameters
pub context: Context<W, E>,
pub context: Context<W>,
}

/// Outcome of [`WorkerThread::join()`]
#[allow(missing_debug_implementations)]
pub enum JoinedThread<W: Worker, E: EmitEvent> {
Terminated(TerminatedThread<W, E>),
pub enum JoinedThread<W: Worker> {
Terminated(TerminatedThread<W>),
JoinError(Box<dyn Any + Send + 'static>),
}

Expand Down Expand Up @@ -382,33 +377,68 @@ pub enum ThreadScheduling {
RealtimeOrDefault,
}

impl<W, E> WorkerThread<W, E>
#[derive(Debug, Clone, Copy)]
pub enum AtomicStateOrdering {
Relaxed,
AcquireRelease,
}

impl<W> WorkerThread<W>
where
W: Worker + Send + 'static,
<W as Worker>::Environment: Send + 'static,
E: EmitEvent + Send + 'static,
{
pub fn spawn(thread_scheduling: ThreadScheduling, context: Context<W, E>) -> Self {
pub fn spawn(
context: Context<W>,
thread_scheduling: ThreadScheduling,
atomic_state_ordering: AtomicStateOrdering,
) -> Self {
let (shared_state_load_ordering, shared_state_store_ordering) = match atomic_state_ordering
{
AtomicStateOrdering::Relaxed => (Ordering::Relaxed, Ordering::Relaxed),
AtomicStateOrdering::AcquireRelease => (Ordering::Acquire, Ordering::Release),
};
let shared_state = Arc::new(AtomicU8::new(State::Unknown.to_u8()));
let join_handle = {
let shared_state = Arc::clone(&shared_state);
std::thread::spawn({
move || {
// The function parameters need to be mutable within the real-time thread
let mut context = context;
let result = thread_fn(thread_scheduling, &mut context);
let result = thread_fn(
&mut context,
thread_scheduling,
shared_state,
shared_state_store_ordering,
);
let context = context;
TerminatedThread { result, context }
}
})
};
Self { join_handle }
Self {
shared_state,
shared_state_load_ordering,
join_handle,
}
}

pub fn join(self) -> JoinedThread<W, E> {
let Self { join_handle } = self;
join_handle
pub fn join(self) -> JoinedThread<W> {
let Self {
join_handle,
shared_state,
shared_state_load_ordering,
} = self;
log::debug!("Joining thread");
let joined_thread = join_handle
.join()
.map(JoinedThread::Terminated)
.unwrap_or_else(JoinedThread::JoinError)
.unwrap_or_else(JoinedThread::JoinError);
debug_assert!(
shared_state_load_ordering == Ordering::Relaxed
|| shared_state.load(shared_state_load_ordering) == State::Terminating.to_u8()
);
joined_thread
}
}

Expand Down
Loading

0 comments on commit 67844fc

Please sign in to comment.