Skip to content

Commit

Permalink
Dropping CallbackData at the right time
Browse files Browse the repository at this point in the history
Wrapping it into an Arc ensures we're not dropping it when the trace is stopped, but we're waiting for the potential callbacks to terminate first

This fixes "Race 2" in n4r1b#45 (n4r1b#45)
  • Loading branch information
daladim committed Nov 7, 2022
1 parent e446b3f commit 9cd993d
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 11 deletions.
6 changes: 4 additions & 2 deletions src/native/etw_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ use crate::trace::{CallbackData, TraceProperties, TraceTrait};
use std::ffi::c_void;
use std::fmt::Formatter;
use std::marker::PhantomData;
use std::sync::Arc;

use windows::core::GUID;
use windows::core::PWSTR;
use windows::Win32::System::Diagnostics::Etw;
Expand Down Expand Up @@ -240,7 +242,7 @@ pub struct EventTraceLogfile<'callbackdata> {

impl<'callbackdata> EventTraceLogfile<'callbackdata> {
/// Create a new instance
pub fn create(callback_data: &'callbackdata Box<CallbackData>, mut wide_logger_name: U16CString, callback: unsafe extern "system" fn(*mut Etw::EVENT_RECORD)) -> Self {
pub fn create(callback_data: &'callbackdata Box<Arc<CallbackData>>, mut wide_logger_name: U16CString, callback: unsafe extern "system" fn(*mut Etw::EVENT_RECORD)) -> Self {
let mut native = Etw::EVENT_TRACE_LOGFILEW::default();

native.LoggerName = PWSTR(wide_logger_name.as_mut_ptr());
Expand All @@ -249,7 +251,7 @@ impl<'callbackdata> EventTraceLogfile<'callbackdata> {

native.Anonymous2.EventRecordCallback = Some(callback);

let not_really_mut_ptr = callback_data.as_ref() as *const CallbackData as *const c_void as *mut c_void; // That's kind-of fine because the user context is _not supposed_ to be changed by Windows APIs
let not_really_mut_ptr = callback_data.as_ref() as *const Arc<CallbackData> as *const c_void as *mut c_void; // That's kind-of fine because the user context is _not supposed_ to be changed by Windows APIs
native.Context = not_really_mut_ptr;

Self {
Expand Down
10 changes: 7 additions & 3 deletions src/native/evntrace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
//! This module makes sure the calls are safe memory-wise, but does not attempt to ensure they are called in the right order.<br/>
//! Thus, you should prefer using `UserTrace`s, `KernelTrace`s and `TraceBuilder`s, that will ensure these API are correctly used.
use std::panic::AssertUnwindSafe;
use std::sync::Arc;

use widestring::{U16CString, U16CStr};
use windows::Win32::Foundation::WIN32_ERROR;
Expand Down Expand Up @@ -49,7 +50,7 @@ extern "system" fn trace_callback_thunk(p_record: *mut Etw::EVENT_RECORD) {
};

if let Some(event_record) = record_from_ptr {
let p_user_context = event_record.user_context().cast::<CallbackData>();
let p_user_context = event_record.user_context().cast::<Arc<CallbackData>>();
let user_context = unsafe {
// Safety:
// * the API of this create guarantees this points to a `CallbackData` already allocated and created
Expand All @@ -61,7 +62,10 @@ extern "system" fn trace_callback_thunk(p_record: *mut Etw::EVENT_RECORD) {
p_user_context.as_ref()
};
if let Some(user_context) = user_context {
user_context.on_event(event_record);
// The UserContext is owned by the `Trace` object. When it is dropped, so will the UserContext.
// We clone it now, so that the original Arc can be safely dropped at all times, but the callback data (including the closure captured context) will still be alive until the callback ends.
let cloned_arc = Arc::clone(user_context);
cloned_arc.on_event(event_record);
}
}
})) {
Expand Down Expand Up @@ -135,7 +139,7 @@ where
/// Subscribe to a started trace
///
/// Microsoft calls this "opening" the trace (and this calls `OpenTraceW`)
pub fn open_trace(trace_name: U16CString, callback_data: &Box<CallbackData>) -> EvntraceNativeResult<TraceHandle> {
pub fn open_trace(trace_name: U16CString, callback_data: &Box<Arc<CallbackData>>) -> EvntraceNativeResult<TraceHandle> {
let mut log_file = EventTraceLogfile::create(&callback_data, trace_name, trace_callback_thunk);

let trace_handle = unsafe {
Expand Down
15 changes: 9 additions & 6 deletions src/trace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
//! Provides both a Kernel and User trace that allows to start an ETW session
use std::marker::PhantomData;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;

use self::private::PrivateTraceTrait;

Expand Down Expand Up @@ -191,8 +192,9 @@ pub struct UserTrace {
control_handle: ControlHandle,
trace_handle: TraceHandle,
// CallbackData is
// * `Arc`ed, so that dropping a Trace while a callback is still running is not an issue
// * `Boxed`, so that the `UserTrace` can be moved around the stack (e.g. returned by a function) but the pointers to the `CallbackData` given to Windows ETW API stay valid
callback_data: Box<CallbackData>,
callback_data: Box<Arc<CallbackData>>,
}

/// A kernel trace session
Expand All @@ -204,8 +206,9 @@ pub struct KernelTrace {
control_handle: ControlHandle,
trace_handle: TraceHandle,
// CallbackData is
// * `Arc`ed, so that dropping a Trace while a callback is still running is not an issue
// * `Boxed`, so that the `UserTrace` can be moved around the stack (e.g. returned by a function) but the pointers to the `CallbackData` given to Windows ETW API stay valid
callback_data: Box<CallbackData>,
callback_data: Box<Arc<CallbackData>>,
}

/// Provides a way to crate Trace objects.
Expand Down Expand Up @@ -275,7 +278,7 @@ mod private {

pub trait PrivateTraceTrait {
const TRACE_KIND: TraceKind;
fn build(properties: EventTraceProperties, control_handle: ControlHandle, trace_handle: TraceHandle, callback_data: Box<CallbackData>) -> Self;
fn build(properties: EventTraceProperties, control_handle: ControlHandle, trace_handle: TraceHandle, callback_data: Box<Arc<CallbackData>>) -> Self;
fn augmented_file_mode() -> u32;
fn enable_flags(_providers: &[Provider]) -> u32;
// This function aims at de-deduplicating code called by `impl Drop` and `Trace::stop`.
Expand All @@ -287,7 +290,7 @@ mod private {
impl private::PrivateTraceTrait for UserTrace {
const TRACE_KIND: private::TraceKind = private::TraceKind::User;

fn build(properties: EventTraceProperties, control_handle: ControlHandle, trace_handle: TraceHandle, callback_data: Box<CallbackData>) -> Self {
fn build(properties: EventTraceProperties, control_handle: ControlHandle, trace_handle: TraceHandle, callback_data: Box<Arc<CallbackData>>) -> Self {
UserTrace {
properties,
control_handle,
Expand All @@ -313,7 +316,7 @@ impl private::PrivateTraceTrait for UserTrace {
impl private::PrivateTraceTrait for KernelTrace {
const TRACE_KIND: private::TraceKind = private::TraceKind::Kernel;

fn build(properties: EventTraceProperties, control_handle: ControlHandle, trace_handle: TraceHandle, callback_data: Box<CallbackData>) -> Self {
fn build(properties: EventTraceProperties, control_handle: ControlHandle, trace_handle: TraceHandle, callback_data: Box<Arc<CallbackData>>) -> Self {
KernelTrace {
properties,
control_handle,
Expand Down Expand Up @@ -394,7 +397,7 @@ impl<T: TraceTrait + PrivateTraceTrait> TraceBuilder<T> {
trace_wide_vec.truncate(crate::native::etw_types::TRACE_NAME_MAX_CHARS);
let trace_wide_name = U16CString::from_vec_truncate(trace_wide_vec);

let callback_data = Box::new(self.callback_data);
let callback_data = Box::new(Arc::new(self.callback_data));
let flags = callback_data.provider_flags::<T>();
let (full_properties, control_handle) = start_trace::<T>(
&trace_wide_name,
Expand Down

0 comments on commit 9cd993d

Please sign in to comment.