Skip to content

Commit

Permalink
Add capturestate / rundown ETW support logging for session and provid…
Browse files Browse the repository at this point in the history
…er options (#19397)

### Description
Add capturestate / rundown ETW support logging for session and provider
options.

### Motivation and Context
Follow-up to #16259 and #18882

This is very useful for long-running ONNX sessions, which will be the case
for a lot of AI workloads. With such sessions, ETW tracing may start minutes
or hours after the process and session have been established. When a trace is
captured, you want to know the state of ONNX at that time. For ONNX, that
state is the session options and config options, so this change logs them
again on capture state so that they show up in the trace.

Tested with xperf and ORT:
xperf -start ort -on 3a26b1ff-7484-7484-7484-15261f42614d
xperf -capturestate ort 3a26b1ff-7484-7484-7484-15261f42614d   <--- Run this after the session has been up for some time
xperf -stop ort -d .\ort.etl   <--- The trace will now also contain the rundown events

These events will also show up if you use WPR with
[CaptureStateOnSave](https://learn.microsoft.com/en-us/windows-hardware/test/wpt/capturestateonsave).
  • Loading branch information
ivberg authored Feb 8, 2024
1 parent 3b1b183 commit 148f54c
Show file tree
Hide file tree
Showing 7 changed files with 164 additions and 26 deletions.
55 changes: 48 additions & 7 deletions onnxruntime/core/framework/execution_providers.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

#pragma once

// #include <map>
#include <memory>
#include <string>
#include <unordered_map>
Expand All @@ -14,7 +13,9 @@
#include "core/common/logging/logging.h"
#ifdef _WIN32
#include <winmeta.h>
#include <evntrace.h>
#include "core/platform/tracing.h"
#include "core/platform/windows/telemetry.h"
#endif

namespace onnxruntime {
Expand Down Expand Up @@ -44,6 +45,49 @@ class ExecutionProviders {
exec_provider_options_[provider_id] = providerOptions;

#ifdef _WIN32
LogProviderOptions(provider_id, providerOptions, false);

// Register callback for ETW capture state (rundown)
WindowsTelemetry::RegisterInternalCallback(
[this](
LPCGUID SourceId,
ULONG IsEnabled,
UCHAR Level,
ULONGLONG MatchAnyKeyword,
ULONGLONG MatchAllKeyword,
PEVENT_FILTER_DESCRIPTOR FilterData,
PVOID CallbackContext) {
(void)SourceId;
(void)Level;
(void)MatchAnyKeyword;
(void)MatchAllKeyword;
(void)FilterData;
(void)CallbackContext;

// Check if this callback is for capturing state
if ((IsEnabled == EVENT_CONTROL_CODE_CAPTURE_STATE) &&
((MatchAnyKeyword & static_cast<ULONGLONG>(onnxruntime::logging::ORTTraceLoggingKeyword::Session)) != 0)) {
for (size_t i = 0; i < exec_providers_.size(); ++i) {
const auto& provider_id = exec_provider_ids_[i];

auto it = exec_provider_options_.find(provider_id);
if (it != exec_provider_options_.end()) {
const auto& options = it->second;

LogProviderOptions(provider_id, options, true);
}
}
}
});
#endif

exec_provider_ids_.push_back(provider_id);
exec_providers_.push_back(p_exec_provider);
return Status::OK();
}

#ifdef _WIN32
void LogProviderOptions(const std::string& provider_id, const ProviderOptions& providerOptions, bool captureState) {
for (const auto& config_pair : providerOptions) {
TraceLoggingWrite(
telemetry_provider_handle,
Expand All @@ -52,14 +96,11 @@ class ExecutionProviders {
TraceLoggingLevel(WINEVENT_LEVEL_INFO),
TraceLoggingString(provider_id.c_str(), "ProviderId"),
TraceLoggingString(config_pair.first.c_str(), "Key"),
TraceLoggingString(config_pair.second.c_str(), "Value"));
TraceLoggingString(config_pair.second.c_str(), "Value"),
TraceLoggingBool(captureState, "isCaptureState"));
}
#endif

exec_provider_ids_.push_back(provider_id);
exec_providers_.push_back(p_exec_provider);
return Status::OK();
}
#endif

const IExecutionProvider* Get(const onnxruntime::Node& node) const {
return Get(node.GetExecutionProviderType());
Expand Down
24 changes: 19 additions & 5 deletions onnxruntime/core/platform/windows/telemetry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

#include "core/platform/windows/telemetry.h"
#include "core/platform/ort_mutex.h"
#include "core/common/logging/logging.h"
#include "onnxruntime_config.h"

Expand Down Expand Up @@ -63,6 +64,8 @@ bool WindowsTelemetry::enabled_ = true;
uint32_t WindowsTelemetry::projection_ = 0;
UCHAR WindowsTelemetry::level_ = 0;
UINT64 WindowsTelemetry::keyword_ = 0;
std::vector<WindowsTelemetry::EtwInternalCallback> WindowsTelemetry::callbacks_;
OrtMutex WindowsTelemetry::callbacks_mutex_;

WindowsTelemetry::WindowsTelemetry() {
std::lock_guard<OrtMutex> lock(mutex_);
Expand Down Expand Up @@ -104,6 +107,11 @@ UINT64 WindowsTelemetry::Keyword() const {
// return etw_status_;
// }

// Appends `callback` to the static callbacks_ list so that InvokeCallbacks will call it.
// callbacks_mutex_ is held while the list is modified.
void WindowsTelemetry::RegisterInternalCallback(const EtwInternalCallback& callback) {
  const std::lock_guard<OrtMutex> guard(callbacks_mutex_);
  callbacks_.emplace_back(callback);
}

void NTAPI WindowsTelemetry::ORT_TL_EtwEnableCallback(
_In_ LPCGUID SourceId,
_In_ ULONG IsEnabled,
Expand All @@ -112,15 +120,21 @@ void NTAPI WindowsTelemetry::ORT_TL_EtwEnableCallback(
_In_ ULONGLONG MatchAllKeyword,
_In_opt_ PEVENT_FILTER_DESCRIPTOR FilterData,
_In_opt_ PVOID CallbackContext) {
(void)SourceId;
(void)MatchAllKeyword;
(void)FilterData;
(void)CallbackContext;

std::lock_guard<OrtMutex> lock(provider_change_mutex_);
enabled_ = (IsEnabled != 0);
level_ = Level;
keyword_ = MatchAnyKeyword;

InvokeCallbacks(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext);
}

// Forwards an ETW provider notification to every callback registered through RegisterInternalCallback.
// All arguments are passed to each callback unchanged, in registration order.
//
// The callback list is copied while callbacks_mutex_ is held and the callbacks are then invoked with
// the mutex released. Callbacks are arbitrary code: they may acquire other locks or call
// RegisterInternalCallback themselves. Running them under callbacks_mutex_ would deadlock on
// re-entrant registration and would create a lock-order inversion with any caller that registers a
// callback while already holding a lock that a callback also takes.
void WindowsTelemetry::InvokeCallbacks(LPCGUID SourceId, ULONG IsEnabled, UCHAR Level, ULONGLONG MatchAnyKeyword,
                                       ULONGLONG MatchAllKeyword, PEVENT_FILTER_DESCRIPTOR FilterData,
                                       PVOID CallbackContext) {
  std::vector<EtwInternalCallback> snapshot;
  {
    // Keep the critical section to the copy only.
    std::lock_guard<OrtMutex> lock(callbacks_mutex_);
    snapshot = callbacks_;
  }

  for (const auto& callback : snapshot) {
    callback(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext);
  }
}

void WindowsTelemetry::EnableTelemetryEvents() const {
Expand Down
15 changes: 14 additions & 1 deletion onnxruntime/core/platform/windows/telemetry.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
// Licensed under the MIT License.

#pragma once
#include <atomic>
#include <vector>

#include "core/platform/telemetry.h"
#include <Windows.h>
#include <TraceLoggingProvider.h>
#include "core/platform/ort_mutex.h"
#include "core/platform/windows/TraceLoggingConfig.h"
#include <atomic>

namespace onnxruntime {

Expand Down Expand Up @@ -58,16 +60,27 @@ class WindowsTelemetry : public Telemetry {

void LogExecutionProviderEvent(LUID* adapterLuid) const override;

// Callable type for internal listeners of ETW provider notifications. Its parameters are the same
// ones that InvokeCallbacks forwards from ORT_TL_EtwEnableCallback.
using EtwInternalCallback = std::function<void(LPCGUID SourceId, ULONG IsEnabled, UCHAR Level,
ULONGLONG MatchAnyKeyword, ULONGLONG MatchAllKeyword,
PEVENT_FILTER_DESCRIPTOR FilterData, PVOID CallbackContext)>;

// Adds `callback` to the static list of internal ETW callbacks.
// NOTE(review): no matching unregister function is visible here; callers whose callback captures
// `this` should confirm the captured object outlives the registration.
static void RegisterInternalCallback(const EtwInternalCallback& callback);

private:
static OrtMutex mutex_;
static uint32_t global_register_count_;
static bool enabled_;
static uint32_t projection_;

static std::vector<EtwInternalCallback> callbacks_;
static OrtMutex callbacks_mutex_;
static OrtMutex provider_change_mutex_;
static UCHAR level_;
static ULONGLONG keyword_;

static void InvokeCallbacks(LPCGUID SourceId, ULONG IsEnabled, UCHAR Level, ULONGLONG MatchAnyKeyword,
ULONGLONG MatchAllKeyword, PEVENT_FILTER_DESCRIPTOR FilterData, PVOID CallbackContext);

static void NTAPI ORT_TL_EtwEnableCallback(
_In_ LPCGUID SourceId,
_In_ ULONG IsEnabled,
Expand Down
72 changes: 64 additions & 8 deletions onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,11 @@
#include "core/optimizer/transformer_memcpy.h"
#include "core/optimizer/transpose_optimization/ort_optimizer_utils.h"
#include "core/platform/Barrier.h"
#include "core/platform/ort_mutex.h"
#include "core/platform/threadpool.h"
#ifdef _WIN32
#include "core/platform/tracing.h"
#include <Windows.h>
#include "core/platform/windows/telemetry.h"
#endif
#include "core/providers/cpu/controlflow/utils.h"
#include "core/providers/cpu/cpu_execution_provider.h"
Expand Down Expand Up @@ -241,6 +242,10 @@ Status GetMinimalBuildOptimizationHandling(
} // namespace

std::atomic<uint32_t> InferenceSession::global_session_id_{1};
std::map<uint32_t, InferenceSession*> InferenceSession::active_sessions_;
#ifdef _WIN32
OrtMutex InferenceSession::active_sessions_mutex_; // Protects access to active_sessions_
#endif

static Status FinalizeSessionOptions(const SessionOptions& user_provided_session_options,
const ONNX_NAMESPACE::ModelProto& model_proto,
Expand Down Expand Up @@ -351,17 +356,47 @@ void InferenceSession::SetLoggingManager(const SessionOptions& session_options,
void InferenceSession::ConstructorCommon(const SessionOptions& session_options,
const Environment& session_env) {
auto status = FinalizeSessionOptions(session_options, model_proto_, is_model_proto_parsed_, session_options_);
// a monotonically increasing session id for use in telemetry
session_id_ = global_session_id_.fetch_add(1);
ORT_ENFORCE(status.IsOK(), "Could not finalize session options while constructing the inference session. Error Message: ",
status.ErrorMessage());

// a monotonically increasing session id for use in telemetry
session_id_ = global_session_id_.fetch_add(1);

#ifdef _WIN32
std::lock_guard<OrtMutex> lock(active_sessions_mutex_);
active_sessions_[global_session_id_++] = this;

// Register callback for ETW capture state (rundown)
WindowsTelemetry::RegisterInternalCallback(
[this](
LPCGUID SourceId,
ULONG IsEnabled,
UCHAR Level,
ULONGLONG MatchAnyKeyword,
ULONGLONG MatchAllKeyword,
PEVENT_FILTER_DESCRIPTOR FilterData,
PVOID CallbackContext) {
(void)SourceId;
(void)Level;
(void)MatchAnyKeyword;
(void)MatchAllKeyword;
(void)FilterData;
(void)CallbackContext;

// Check if this callback is for capturing state
if ((IsEnabled == EVENT_CONTROL_CODE_CAPTURE_STATE) &&
((MatchAnyKeyword & static_cast<ULONGLONG>(onnxruntime::logging::ORTTraceLoggingKeyword::Session)) != 0)) {
LogAllSessions();
}
});
#endif

SetLoggingManager(session_options, session_env);

// The call to InitLogger depends on the final state of session_options_. Hence it should be invoked
// after the invocation of FinalizeSessionOptions.
InitLogger(logging_manager_); // this sets session_logger_ so that it can be used for logging after this point.
TraceSessionOptions(session_options);
TraceSessionOptions(session_options, false);

#if !defined(ORT_MINIMAL_BUILD)
// Update the number of steps for the graph transformer manager using the "finalized" session options
Expand Down Expand Up @@ -475,7 +510,9 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options,
telemetry_ = {};
}

void InferenceSession::TraceSessionOptions(const SessionOptions& session_options) {
void InferenceSession::TraceSessionOptions(const SessionOptions& session_options, bool captureState) {
(void)captureState; // Otherwise Linux build error

LOGS(*session_logger_, INFO) << session_options;

#ifdef _WIN32
Expand All @@ -498,7 +535,8 @@ void InferenceSession::TraceSessionOptions(const SessionOptions& session_options
TraceLoggingUInt8(static_cast<UINT8>(session_options.graph_optimization_level), "graph_optimization_level"),
TraceLoggingBoolean(session_options.use_per_session_threads, "use_per_session_threads"),
TraceLoggingBoolean(session_options.thread_pool_allow_spinning, "thread_pool_allow_spinning"),
TraceLoggingBoolean(session_options.use_deterministic_compute, "use_deterministic_compute"));
TraceLoggingBoolean(session_options.use_deterministic_compute, "use_deterministic_compute"),
TraceLoggingBoolean(captureState, "isCaptureState"));

TraceLoggingWrite(
telemetry_provider_handle,
Expand All @@ -511,7 +549,8 @@ void InferenceSession::TraceSessionOptions(const SessionOptions& session_options
TraceLoggingInt32(session_options.intra_op_param.dynamic_block_base_, "dynamic_block_base_"),
TraceLoggingUInt32(session_options.intra_op_param.stack_size, "stack_size"),
TraceLoggingString(!session_options.intra_op_param.affinity_str.empty() ? session_options.intra_op_param.affinity_str.c_str() : "", "affinity_str"),
TraceLoggingBoolean(session_options.intra_op_param.set_denormal_as_zero, "set_denormal_as_zero"));
TraceLoggingBoolean(session_options.intra_op_param.set_denormal_as_zero, "set_denormal_as_zero"),
TraceLoggingBoolean(captureState, "isCaptureState"));

for (const auto& config_pair : session_options.config_options.configurations) {
TraceLoggingWrite(
Expand All @@ -520,7 +559,8 @@ void InferenceSession::TraceSessionOptions(const SessionOptions& session_options
TraceLoggingKeyword(static_cast<uint64_t>(onnxruntime::logging::ORTTraceLoggingKeyword::Session)),
TraceLoggingLevel(WINEVENT_LEVEL_INFO),
TraceLoggingString(config_pair.first.c_str(), "Key"),
TraceLoggingString(config_pair.second.c_str(), "Value"));
TraceLoggingString(config_pair.second.c_str(), "Value"),
TraceLoggingBoolean(captureState, "isCaptureState"));
}
#endif
}
Expand Down Expand Up @@ -616,6 +656,12 @@ InferenceSession::~InferenceSession() {
}
}

// Unregister the session
#ifdef _WIN32
std::lock_guard<OrtMutex> lock(active_sessions_mutex_);
#endif
active_sessions_.erase(global_session_id_);

#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
if (session_activity_started_)
TraceLoggingWriteStop(session_activity, "OrtInferenceSessionActivity");
Expand Down Expand Up @@ -3070,4 +3116,14 @@ IOBinding* SessionIOBinding::Get() {
return binding_.get();
}

#ifdef _WIN32
void InferenceSession::LogAllSessions() {
std::lock_guard<OrtMutex> lock(active_sessions_mutex_);
for (const auto& session_pair : active_sessions_) {
InferenceSession* session = session_pair.second;
TraceSessionOptions(session->session_options_, true);
}
}
#endif

} // namespace onnxruntime
14 changes: 12 additions & 2 deletions onnxruntime/core/session/inference_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#pragma once

#include <map>
#include <optional>
#include <string>
#include <unordered_map>
Expand All @@ -21,11 +22,12 @@
#include "core/framework/session_state.h"
#include "core/framework/tuning_results.h"
#include "core/framework/framework_provider_common.h"
#include "core/framework/session_options.h"
#include "core/graph/basic_types.h"
#include "core/optimizer/graph_transformer_level.h"
#include "core/optimizer/graph_transformer_mgr.h"
#include "core/optimizer/insert_cast_transformer.h"
#include "core/framework/session_options.h"
#include "core/platform/ort_mutex.h"
#ifdef ENABLE_LANGUAGE_INTEROP_OPS
#include "core/language_interop_ops/language_interop_ops.h"
#endif
Expand Down Expand Up @@ -119,6 +121,10 @@ class InferenceSession {
};

using InputOutputDefMetaMap = InlinedHashMap<std::string_view, InputOutputDefMetaData>;
static std::map<uint32_t, InferenceSession*> active_sessions_;
#ifdef _WIN32
static OrtMutex active_sessions_mutex_; // Protects access to active_sessions_
#endif

public:
#if !defined(ORT_MINIMAL_BUILD)
Expand Down Expand Up @@ -642,7 +648,7 @@ class InferenceSession {

void InitLogger(logging::LoggingManager* logging_manager);

void TraceSessionOptions(const SessionOptions& session_options);
void TraceSessionOptions(const SessionOptions& session_options, bool captureState);

[[nodiscard]] common::Status CheckShapes(const std::string& input_name, const TensorShape& input_shape,
const TensorShape& expected_shape, const char* input_output_moniker) const;
Expand Down Expand Up @@ -679,6 +685,10 @@ class InferenceSession {
*/
void ShrinkMemoryArenas(gsl::span<const AllocatorPtr> arenas_to_shrink);

#ifdef _WIN32
void LogAllSessions();
#endif

#if !defined(ORT_MINIMAL_BUILD)
virtual common::Status AddPredefinedTransformers(
GraphTransformerManager& transformer_manager,
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/session/provider_registration.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider,
(std::string(provider_name) + " execution provider is not supported in this build. ").c_str());
};

for (const auto& config_pair : provider_options) {
ORT_THROW_IF_ERROR(options->value.config_options.AddConfigEntry((std::string(provider_name) + ":" + config_pair.first).c_str(), config_pair.second.c_str()));
}

if (strcmp(provider_name, "DML") == 0) {
#if defined(USE_DML)
options->provider_factories.push_back(DMLProviderFactoryCreator::CreateFromProviderOptions(provider_options));
Expand Down
Loading

0 comments on commit 148f54c

Please sign in to comment.