Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add capturestate / rundown ETW support logging for session and provider options #19397

Merged
merged 7 commits into from
Feb 8, 2024
55 changes: 48 additions & 7 deletions onnxruntime/core/framework/execution_providers.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

#pragma once

// #include <map>
#include <memory>
#include <string>
#include <unordered_map>
Expand All @@ -14,7 +13,9 @@
#include "core/common/logging/logging.h"
#ifdef _WIN32
#include <winmeta.h>
#include <evntrace.h>
#include "core/platform/tracing.h"
#include "core/platform/windows/telemetry.h"
#endif

namespace onnxruntime {
Expand Down Expand Up @@ -44,6 +45,49 @@
exec_provider_options_[provider_id] = providerOptions;

#ifdef _WIN32
LogProviderOptions(provider_id, providerOptions, false);

// Register callback for ETW capture state (rundown)
WindowsTelemetry::RegisterInternalCallback(
[this](
LPCGUID SourceId,
ULONG IsEnabled,
UCHAR Level,
ULONGLONG MatchAnyKeyword,
ULONGLONG MatchAllKeyword,
PEVENT_FILTER_DESCRIPTOR FilterData,
PVOID CallbackContext) {
(void)SourceId;
(void)Level;
(void)MatchAnyKeyword;
(void)MatchAllKeyword;
(void)FilterData;
(void)CallbackContext;

// Check if this callback is for capturing state
if ((IsEnabled == EVENT_CONTROL_CODE_CAPTURE_STATE) &&
((MatchAnyKeyword & static_cast<ULONGLONG>(onnxruntime::logging::ORTTraceLoggingKeyword::Session)) != 0)) {

Check warning on line 69 in onnxruntime/core/framework/execution_providers.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/framework/execution_providers.h:69: Lines should be <= 120 characters long [whitespace/line_length] [2]
for (size_t i = 0; i < exec_providers_.size(); ++i) {
const auto& provider_id = exec_provider_ids_[i];

auto it = exec_provider_options_.find(provider_id);
if (it != exec_provider_options_.end()) {
const auto& options = it->second;

LogProviderOptions(provider_id, options, true);
}
}
}
});
#endif

exec_provider_ids_.push_back(provider_id);
exec_providers_.push_back(p_exec_provider);
return Status::OK();
}

#ifdef _WIN32
void LogProviderOptions(const std::string& provider_id, const ProviderOptions& providerOptions, bool captureState) {
for (const auto& config_pair : providerOptions) {
TraceLoggingWrite(
telemetry_provider_handle,
Expand All @@ -52,14 +96,11 @@
TraceLoggingLevel(WINEVENT_LEVEL_INFO),
TraceLoggingString(provider_id.c_str(), "ProviderId"),
TraceLoggingString(config_pair.first.c_str(), "Key"),
TraceLoggingString(config_pair.second.c_str(), "Value"));
TraceLoggingString(config_pair.second.c_str(), "Value"),
TraceLoggingBool(captureState, "isCaptureState"));
}
#endif

exec_provider_ids_.push_back(provider_id);
exec_providers_.push_back(p_exec_provider);
return Status::OK();
}
#endif

const IExecutionProvider* Get(const onnxruntime::Node& node) const {
return Get(node.GetExecutionProviderType());
Expand Down
24 changes: 19 additions & 5 deletions onnxruntime/core/platform/windows/telemetry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

#include "core/platform/windows/telemetry.h"
#include "core/platform/ort_mutex.h"
#include "core/common/logging/logging.h"
#include "onnxruntime_config.h"

Expand Down Expand Up @@ -63,6 +64,8 @@ bool WindowsTelemetry::enabled_ = true;
uint32_t WindowsTelemetry::projection_ = 0;
UCHAR WindowsTelemetry::level_ = 0;
UINT64 WindowsTelemetry::keyword_ = 0;
// Process-wide list of internal ETW callbacks. RegisterInternalCallback appends to it and
// InvokeCallbacks reads it; both take callbacks_mutex_ before touching the vector.
std::vector<WindowsTelemetry::EtwInternalCallback> WindowsTelemetry::callbacks_;
// Guards callbacks_.
OrtMutex WindowsTelemetry::callbacks_mutex_;

WindowsTelemetry::WindowsTelemetry() {
std::lock_guard<OrtMutex> lock(mutex_);
Expand Down Expand Up @@ -104,6 +107,11 @@ UINT64 WindowsTelemetry::Keyword() const {
// return etw_status_;
// }

// Adds `callback` to the static list of callbacks that InvokeCallbacks dispatches to.
// A copy of the callable is stored; the list is modified only while callbacks_mutex_ is held.
// NOTE(review): no matching unregister function is visible here, so anything the callable captures
// (e.g. `this`) presumably has to outlive the process-wide list -- confirm against the callers.
void WindowsTelemetry::RegisterInternalCallback(const EtwInternalCallback& callback) {
  std::lock_guard<OrtMutex> guard(callbacks_mutex_);
  callbacks_.emplace_back(callback);
}

void NTAPI WindowsTelemetry::ORT_TL_EtwEnableCallback(
_In_ LPCGUID SourceId,
_In_ ULONG IsEnabled,
Expand All @@ -112,15 +120,21 @@ void NTAPI WindowsTelemetry::ORT_TL_EtwEnableCallback(
_In_ ULONGLONG MatchAllKeyword,
_In_opt_ PEVENT_FILTER_DESCRIPTOR FilterData,
_In_opt_ PVOID CallbackContext) {
(void)SourceId;
(void)MatchAllKeyword;
(void)FilterData;
(void)CallbackContext;

std::lock_guard<OrtMutex> lock(provider_change_mutex_);
enabled_ = (IsEnabled != 0);
level_ = Level;
keyword_ = MatchAnyKeyword;

InvokeCallbacks(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext);
}

// Forwards an ETW enable / capture-state notification to every callback that was registered through
// RegisterInternalCallback. All seven arguments are passed through to each callback unchanged.
//
// The callback list is copied while callbacks_mutex_ is held and the callbacks are then run with the
// lock released. Running them under the lock is unsafe for two reasons:
//   * a callback may take other locks (the InferenceSession callback ends up in LogAllSessions, which
//     locks active_sessions_mutex_), while session construction takes active_sessions_mutex_ first and
//     then calls RegisterInternalCallback, which locks callbacks_mutex_ -- an ABBA lock-order inversion;
//   * a callback that itself calls RegisterInternalCallback would self-deadlock on the mutex.
// Callbacks registered while a dispatch is in progress are picked up by the next dispatch.
void WindowsTelemetry::InvokeCallbacks(LPCGUID SourceId, ULONG IsEnabled, UCHAR Level, ULONGLONG MatchAnyKeyword,
                                       ULONGLONG MatchAllKeyword, PEVENT_FILTER_DESCRIPTOR FilterData,
                                       PVOID CallbackContext) {
  std::vector<EtwInternalCallback> snapshot;
  {
    // Hold the lock only long enough to copy the list.
    std::lock_guard<OrtMutex> lock(callbacks_mutex_);
    snapshot = callbacks_;
  }

  for (const auto& callback : snapshot) {
    callback(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext);
  }
}

void WindowsTelemetry::EnableTelemetryEvents() const {
Expand Down
15 changes: 14 additions & 1 deletion onnxruntime/core/platform/windows/telemetry.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
// Licensed under the MIT License.

#pragma once
#include <atomic>
#include <vector>

#include "core/platform/telemetry.h"
#include <Windows.h>
#include <TraceLoggingProvider.h>
#include "core/platform/ort_mutex.h"
#include "core/platform/windows/TraceLoggingConfig.h"
#include <atomic>

namespace onnxruntime {

Expand Down Expand Up @@ -58,16 +60,27 @@ class WindowsTelemetry : public Telemetry {

void LogExecutionProviderEvent(LUID* adapterLuid) const override;

// Callable type for internal listeners of ETW provider notifications. The parameter list is the same
// as ORT_TL_EtwEnableCallback's, which forwards its arguments to these callbacks via InvokeCallbacks.
using EtwInternalCallback = std::function<void(LPCGUID SourceId, ULONG IsEnabled, UCHAR Level,
ULONGLONG MatchAnyKeyword, ULONGLONG MatchAllKeyword,
PEVENT_FILTER_DESCRIPTOR FilterData, PVOID CallbackContext)>;

// Stores a copy of `callback` in the static callback list (guarded by callbacks_mutex_) so that it is
// invoked on subsequent ETW enable / capture-state notifications.
static void RegisterInternalCallback(const EtwInternalCallback& callback);

private:
static OrtMutex mutex_;
static uint32_t global_register_count_;
static bool enabled_;
static uint32_t projection_;

static std::vector<EtwInternalCallback> callbacks_;
static OrtMutex callbacks_mutex_;
static OrtMutex provider_change_mutex_;
static UCHAR level_;
static ULONGLONG keyword_;

static void InvokeCallbacks(LPCGUID SourceId, ULONG IsEnabled, UCHAR Level, ULONGLONG MatchAnyKeyword,
ULONGLONG MatchAllKeyword, PEVENT_FILTER_DESCRIPTOR FilterData, PVOID CallbackContext);

static void NTAPI ORT_TL_EtwEnableCallback(
_In_ LPCGUID SourceId,
_In_ ULONG IsEnabled,
Expand Down
72 changes: 64 additions & 8 deletions onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,11 @@
#include "core/optimizer/transformer_memcpy.h"
#include "core/optimizer/transpose_optimization/ort_optimizer_utils.h"
#include "core/platform/Barrier.h"
#include "core/platform/ort_mutex.h"
#include "core/platform/threadpool.h"
#ifdef _WIN32
#include "core/platform/tracing.h"
#include <Windows.h>
#include "core/platform/windows/telemetry.h"
#endif
#include "core/providers/cpu/controlflow/utils.h"
#include "core/providers/cpu/cpu_execution_provider.h"
Expand Down Expand Up @@ -241,6 +242,10 @@ Status GetMinimalBuildOptimizationHandling(
} // namespace

std::atomic<uint32_t> InferenceSession::global_session_id_{1};
// Registry of live sessions, walked by LogAllSessions when an ETW capture-state request arrives.
// NOTE(review): ConstructorCommon inserts with key `global_session_id_++` and the destructor erases
// `global_session_id_`; neither is this session's `session_id_`, so the erase looks like it never
// matches the inserted key and stale pointers may remain in the map -- confirm and key on session_id_.
std::map<uint32_t, InferenceSession*> InferenceSession::active_sessions_;
#ifdef _WIN32
OrtMutex InferenceSession::active_sessions_mutex_; // Protects access to active_sessions_
#endif

static Status FinalizeSessionOptions(const SessionOptions& user_provided_session_options,
const ONNX_NAMESPACE::ModelProto& model_proto,
Expand Down Expand Up @@ -351,17 +356,47 @@ void InferenceSession::SetLoggingManager(const SessionOptions& session_options,
void InferenceSession::ConstructorCommon(const SessionOptions& session_options,
const Environment& session_env) {
auto status = FinalizeSessionOptions(session_options, model_proto_, is_model_proto_parsed_, session_options_);
// a monotonically increasing session id for use in telemetry
session_id_ = global_session_id_.fetch_add(1);
ORT_ENFORCE(status.IsOK(), "Could not finalize session options while constructing the inference session. Error Message: ",
status.ErrorMessage());

// a monotonically increasing session id for use in telemetry
session_id_ = global_session_id_.fetch_add(1);

#ifdef _WIN32
std::lock_guard<OrtMutex> lock(active_sessions_mutex_);
ivberg marked this conversation as resolved.
Show resolved Hide resolved
active_sessions_[global_session_id_++] = this;

// Register callback for ETW capture state (rundown)
WindowsTelemetry::RegisterInternalCallback(
[this](
LPCGUID SourceId,
ULONG IsEnabled,
UCHAR Level,
ULONGLONG MatchAnyKeyword,
ULONGLONG MatchAllKeyword,
PEVENT_FILTER_DESCRIPTOR FilterData,
PVOID CallbackContext) {
(void)SourceId;
(void)Level;
(void)MatchAnyKeyword;
(void)MatchAllKeyword;
(void)FilterData;
(void)CallbackContext;

// Check if this callback is for capturing state
if ((IsEnabled == EVENT_CONTROL_CODE_CAPTURE_STATE) &&
((MatchAnyKeyword & static_cast<ULONGLONG>(onnxruntime::logging::ORTTraceLoggingKeyword::Session)) != 0)) {
LogAllSessions();
}
});
#endif

SetLoggingManager(session_options, session_env);

// The call to InitLogger depends on the final state of session_options_. Hence it should be invoked
// after the invocation of FinalizeSessionOptions.
InitLogger(logging_manager_); // this sets session_logger_ so that it can be used for logging after this point.
TraceSessionOptions(session_options);
TraceSessionOptions(session_options, false);

#if !defined(ORT_MINIMAL_BUILD)
// Update the number of steps for the graph transformer manager using the "finalized" session options
Expand Down Expand Up @@ -475,7 +510,9 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options,
telemetry_ = {};
}

void InferenceSession::TraceSessionOptions(const SessionOptions& session_options) {
void InferenceSession::TraceSessionOptions(const SessionOptions& session_options, bool captureState) {
(void)captureState; // Otherwise Linux build error

LOGS(*session_logger_, INFO) << session_options;

#ifdef _WIN32
Expand All @@ -498,7 +535,8 @@ void InferenceSession::TraceSessionOptions(const SessionOptions& session_options
TraceLoggingUInt8(static_cast<UINT8>(session_options.graph_optimization_level), "graph_optimization_level"),
TraceLoggingBoolean(session_options.use_per_session_threads, "use_per_session_threads"),
TraceLoggingBoolean(session_options.thread_pool_allow_spinning, "thread_pool_allow_spinning"),
TraceLoggingBoolean(session_options.use_deterministic_compute, "use_deterministic_compute"));
TraceLoggingBoolean(session_options.use_deterministic_compute, "use_deterministic_compute"),
TraceLoggingBoolean(captureState, "isCaptureState"));

TraceLoggingWrite(
telemetry_provider_handle,
Expand All @@ -511,7 +549,8 @@ void InferenceSession::TraceSessionOptions(const SessionOptions& session_options
TraceLoggingInt32(session_options.intra_op_param.dynamic_block_base_, "dynamic_block_base_"),
TraceLoggingUInt32(session_options.intra_op_param.stack_size, "stack_size"),
TraceLoggingString(!session_options.intra_op_param.affinity_str.empty() ? session_options.intra_op_param.affinity_str.c_str() : "", "affinity_str"),
TraceLoggingBoolean(session_options.intra_op_param.set_denormal_as_zero, "set_denormal_as_zero"));
TraceLoggingBoolean(session_options.intra_op_param.set_denormal_as_zero, "set_denormal_as_zero"),
TraceLoggingBoolean(captureState, "isCaptureState"));

for (const auto& config_pair : session_options.config_options.configurations) {
TraceLoggingWrite(
Expand All @@ -520,7 +559,8 @@ void InferenceSession::TraceSessionOptions(const SessionOptions& session_options
TraceLoggingKeyword(static_cast<uint64_t>(onnxruntime::logging::ORTTraceLoggingKeyword::Session)),
TraceLoggingLevel(WINEVENT_LEVEL_INFO),
TraceLoggingString(config_pair.first.c_str(), "Key"),
TraceLoggingString(config_pair.second.c_str(), "Value"));
TraceLoggingString(config_pair.second.c_str(), "Value"),
TraceLoggingBoolean(captureState, "isCaptureState"));
}
#endif
}
Expand Down Expand Up @@ -616,6 +656,12 @@ InferenceSession::~InferenceSession() {
}
}

// Unregister the session
#ifdef _WIN32
std::lock_guard<OrtMutex> lock(active_sessions_mutex_);
#endif
active_sessions_.erase(global_session_id_);

#ifdef ONNXRUNTIME_ENABLE_INSTRUMENT
if (session_activity_started_)
TraceLoggingWriteStop(session_activity, "OrtInferenceSessionActivity");
Expand Down Expand Up @@ -3070,4 +3116,14 @@ IOBinding* SessionIOBinding::Get() {
return binding_.get();
}

#ifdef _WIN32
void InferenceSession::LogAllSessions() {
std::lock_guard<OrtMutex> lock(active_sessions_mutex_);
for (const auto& session_pair : active_sessions_) {
InferenceSession* session = session_pair.second;
TraceSessionOptions(session->session_options_, true);
}
}
#endif

} // namespace onnxruntime
14 changes: 12 additions & 2 deletions onnxruntime/core/session/inference_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#pragma once

#include <map>
#include <optional>
#include <string>
#include <unordered_map>
Expand All @@ -21,11 +22,12 @@
#include "core/framework/session_state.h"
#include "core/framework/tuning_results.h"
#include "core/framework/framework_provider_common.h"
#include "core/framework/session_options.h"
#include "core/graph/basic_types.h"
#include "core/optimizer/graph_transformer_level.h"
#include "core/optimizer/graph_transformer_mgr.h"
#include "core/optimizer/insert_cast_transformer.h"
#include "core/framework/session_options.h"
#include "core/platform/ort_mutex.h"
#ifdef ENABLE_LANGUAGE_INTEROP_OPS
#include "core/language_interop_ops/language_interop_ops.h"
#endif
Expand Down Expand Up @@ -119,6 +121,10 @@ class InferenceSession {
};

using InputOutputDefMetaMap = InlinedHashMap<std::string_view, InputOutputDefMetaData>;
static std::map<uint32_t, InferenceSession*> active_sessions_;
#ifdef _WIN32
static OrtMutex active_sessions_mutex_; // Protects access to active_sessions_
#endif

public:
#if !defined(ORT_MINIMAL_BUILD)
Expand Down Expand Up @@ -642,7 +648,7 @@ class InferenceSession {

void InitLogger(logging::LoggingManager* logging_manager);

void TraceSessionOptions(const SessionOptions& session_options);
void TraceSessionOptions(const SessionOptions& session_options, bool captureState);

[[nodiscard]] common::Status CheckShapes(const std::string& input_name, const TensorShape& input_shape,
const TensorShape& expected_shape, const char* input_output_moniker) const;
Expand Down Expand Up @@ -679,6 +685,10 @@ class InferenceSession {
*/
void ShrinkMemoryArenas(gsl::span<const AllocatorPtr> arenas_to_shrink);

#ifdef _WIN32
void LogAllSessions();
#endif

#if !defined(ORT_MINIMAL_BUILD)
virtual common::Status AddPredefinedTransformers(
GraphTransformerManager& transformer_manager,
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/session/provider_registration.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@
(std::string(provider_name) + " execution provider is not supported in this build. ").c_str());
};

for (const auto& config_pair : provider_options) {
ORT_THROW_IF_ERROR(options->value.config_options.AddConfigEntry((std::string(provider_name) + ":" + config_pair.first).c_str(), config_pair.second.c_str()));

Check warning on line 94 in onnxruntime/core/session/provider_registration.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/session/provider_registration.cc:94: Lines should be <= 120 characters long [whitespace/line_length] [2]
}

if (strcmp(provider_name, "DML") == 0) {
#if defined(USE_DML)
options->provider_factories.push_back(DMLProviderFactoryCreator::CreateFromProviderOptions(provider_options));
Expand Down
Loading
Loading