diff --git a/onnxruntime/core/framework/execution_providers.h b/onnxruntime/core/framework/execution_providers.h index 61147e4367876..dc45cad692b6e 100644 --- a/onnxruntime/core/framework/execution_providers.h +++ b/onnxruntime/core/framework/execution_providers.h @@ -3,7 +3,6 @@ #pragma once -// #include #include #include #include @@ -14,7 +13,9 @@ #include "core/common/logging/logging.h" #ifdef _WIN32 #include +#include #include "core/platform/tracing.h" +#include "core/platform/windows/telemetry.h" #endif namespace onnxruntime { @@ -44,6 +45,49 @@ class ExecutionProviders { exec_provider_options_[provider_id] = providerOptions; #ifdef _WIN32 + LogProviderOptions(provider_id, providerOptions, false); + + // Register callback for ETW capture state (rundown) + WindowsTelemetry::RegisterInternalCallback( + [this]( + LPCGUID SourceId, + ULONG IsEnabled, + UCHAR Level, + ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, + PEVENT_FILTER_DESCRIPTOR FilterData, + PVOID CallbackContext) { + (void)SourceId; + (void)Level; + (void)MatchAnyKeyword; + (void)MatchAllKeyword; + (void)FilterData; + (void)CallbackContext; + + // Check if this callback is for capturing state + if ((IsEnabled == EVENT_CONTROL_CODE_CAPTURE_STATE) && + ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)) != 0)) { + for (size_t i = 0; i < exec_providers_.size(); ++i) { + const auto& provider_id = exec_provider_ids_[i]; + + auto it = exec_provider_options_.find(provider_id); + if (it != exec_provider_options_.end()) { + const auto& options = it->second; + + LogProviderOptions(provider_id, options, true); + } + } + } + }); +#endif + + exec_provider_ids_.push_back(provider_id); + exec_providers_.push_back(p_exec_provider); + return Status::OK(); + } + +#ifdef _WIN32 + void LogProviderOptions(const std::string& provider_id, const ProviderOptions& providerOptions, bool captureState) { for (const auto& config_pair : providerOptions) { TraceLoggingWrite( telemetry_provider_handle, @@ -52,14 +96,11 @@ class ExecutionProviders { TraceLoggingLevel(WINEVENT_LEVEL_INFO), TraceLoggingString(provider_id.c_str(), "ProviderId"), TraceLoggingString(config_pair.first.c_str(), "Key"), - TraceLoggingString(config_pair.second.c_str(), "Value")); + TraceLoggingString(config_pair.second.c_str(), "Value"), + TraceLoggingBool(captureState, "isCaptureState")); } -#endif - - exec_provider_ids_.push_back(provider_id); - exec_providers_.push_back(p_exec_provider); - return Status::OK(); } +#endif const IExecutionProvider* Get(const onnxruntime::Node& node) const { return Get(node.GetExecutionProviderType()); diff --git a/onnxruntime/core/platform/windows/telemetry.cc b/onnxruntime/core/platform/windows/telemetry.cc index a9849873fd060..654281d526e4d 100644 --- a/onnxruntime/core/platform/windows/telemetry.cc +++ b/onnxruntime/core/platform/windows/telemetry.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/platform/windows/telemetry.h" +#include "core/platform/ort_mutex.h" #include "core/common/logging/logging.h" #include "onnxruntime_config.h" @@ -63,6 +64,8 @@ bool WindowsTelemetry::enabled_ = true; uint32_t WindowsTelemetry::projection_ = 0; UCHAR WindowsTelemetry::level_ = 0; UINT64 WindowsTelemetry::keyword_ = 0; +std::vector WindowsTelemetry::callbacks_; +OrtMutex WindowsTelemetry::callbacks_mutex_; WindowsTelemetry::WindowsTelemetry() { std::lock_guard lock(mutex_); @@ -104,6 +107,11 @@ UINT64 WindowsTelemetry::Keyword() const { // return etw_status_; // } +void WindowsTelemetry::RegisterInternalCallback(const EtwInternalCallback& callback) { + std::lock_guard lock(callbacks_mutex_); + callbacks_.push_back(callback); +} + void NTAPI WindowsTelemetry::ORT_TL_EtwEnableCallback( _In_ LPCGUID SourceId, _In_ ULONG IsEnabled, @@ -112,15 +120,21 @@ void NTAPI WindowsTelemetry::ORT_TL_EtwEnableCallback( _In_ ULONGLONG MatchAllKeyword, _In_opt_ PEVENT_FILTER_DESCRIPTOR FilterData, _In_opt_ PVOID CallbackContext) { - (void)SourceId; - (void)MatchAllKeyword; - (void)FilterData; - (void)CallbackContext; - std::lock_guard lock(provider_change_mutex_); enabled_ = (IsEnabled != 0); level_ = Level; keyword_ = MatchAnyKeyword; + + InvokeCallbacks(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext); +} + +void WindowsTelemetry::InvokeCallbacks(LPCGUID SourceId, ULONG IsEnabled, UCHAR Level, ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, PEVENT_FILTER_DESCRIPTOR FilterData, + PVOID CallbackContext) { + std::lock_guard lock(callbacks_mutex_); + for (const auto& callback : callbacks_) { + callback(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext); + } } void WindowsTelemetry::EnableTelemetryEvents() const { diff --git a/onnxruntime/core/platform/windows/telemetry.h b/onnxruntime/core/platform/windows/telemetry.h index c3798943d491d..cdb186e9ed703 100644 --- a/onnxruntime/core/platform/windows/telemetry.h +++ b/onnxruntime/core/platform/windows/telemetry.h @@ -2,12 +2,14 @@ // Licensed under the MIT License. #pragma once +#include +#include + #include "core/platform/telemetry.h" #include #include #include "core/platform/ort_mutex.h" #include "core/platform/windows/TraceLoggingConfig.h" -#include namespace onnxruntime { @@ -58,16 +60,27 @@ class WindowsTelemetry : public Telemetry { void LogExecutionProviderEvent(LUID* adapterLuid) const override; + using EtwInternalCallback = std::function; + + static void RegisterInternalCallback(const EtwInternalCallback& callback); + private: static OrtMutex mutex_; static uint32_t global_register_count_; static bool enabled_; static uint32_t projection_; + static std::vector callbacks_; + static OrtMutex callbacks_mutex_; static OrtMutex provider_change_mutex_; static UCHAR level_; static ULONGLONG keyword_; + static void InvokeCallbacks(LPCGUID SourceId, ULONG IsEnabled, UCHAR Level, ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, PEVENT_FILTER_DESCRIPTOR FilterData, PVOID CallbackContext); + static void NTAPI ORT_TL_EtwEnableCallback( _In_ LPCGUID SourceId, _In_ ULONG IsEnabled, diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index cae714954f72f..b045f30a59797 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -46,10 +46,11 @@ #include "core/optimizer/transformer_memcpy.h" #include "core/optimizer/transpose_optimization/ort_optimizer_utils.h" #include "core/platform/Barrier.h" -#include "core/platform/ort_mutex.h" #include "core/platform/threadpool.h" #ifdef _WIN32 #include "core/platform/tracing.h" +#include +#include "core/platform/windows/telemetry.h" #endif #include "core/providers/cpu/controlflow/utils.h" #include "core/providers/cpu/cpu_execution_provider.h" @@ -241,6 +242,10 @@ Status GetMinimalBuildOptimizationHandling( } // namespace std::atomic InferenceSession::global_session_id_{1}; +std::map InferenceSession::active_sessions_; +#ifdef _WIN32 +OrtMutex InferenceSession::active_sessions_mutex_; // Protects access to active_sessions_ +#endif static Status FinalizeSessionOptions(const SessionOptions& user_provided_session_options, const ONNX_NAMESPACE::ModelProto& model_proto, @@ -351,17 +356,47 @@ void InferenceSession::SetLoggingManager(const SessionOptions& session_options, void InferenceSession::ConstructorCommon(const SessionOptions& session_options, const Environment& session_env) { auto status = FinalizeSessionOptions(session_options, model_proto_, is_model_proto_parsed_, session_options_); - // a monotonically increasing session id for use in telemetry - session_id_ = global_session_id_.fetch_add(1); ORT_ENFORCE(status.IsOK(), "Could not finalize session options while constructing the inference session. Error Message: ", status.ErrorMessage()); + // a monotonically increasing session id for use in telemetry + session_id_ = global_session_id_.fetch_add(1); + +#ifdef _WIN32 + std::lock_guard lock(active_sessions_mutex_); + active_sessions_[global_session_id_++] = this; + + // Register callback for ETW capture state (rundown) + WindowsTelemetry::RegisterInternalCallback( + [this]( + LPCGUID SourceId, + ULONG IsEnabled, + UCHAR Level, + ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, + PEVENT_FILTER_DESCRIPTOR FilterData, + PVOID CallbackContext) { + (void)SourceId; + (void)Level; + (void)MatchAnyKeyword; + (void)MatchAllKeyword; + (void)FilterData; + (void)CallbackContext; + + // Check if this callback is for capturing state + if ((IsEnabled == EVENT_CONTROL_CODE_CAPTURE_STATE) && + ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)) != 0)) { + LogAllSessions(); + } + }); +#endif + SetLoggingManager(session_options, session_env); // The call to InitLogger depends on the final state of session_options_. Hence it should be invoked // after the invocation of FinalizeSessionOptions. InitLogger(logging_manager_); // this sets session_logger_ so that it can be used for logging after this point. - TraceSessionOptions(session_options); + TraceSessionOptions(session_options, false); #if !defined(ORT_MINIMAL_BUILD) // Update the number of steps for the graph transformer manager using the "finalized" session options @@ -475,7 +510,9 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, telemetry_ = {}; } -void InferenceSession::TraceSessionOptions(const SessionOptions& session_options) { +void InferenceSession::TraceSessionOptions(const SessionOptions& session_options, bool captureState) { + (void)captureState; // Otherwise Linux build error + LOGS(*session_logger_, INFO) << session_options; #ifdef _WIN32 @@ -498,7 +535,8 @@ void InferenceSession::TraceSessionOptions(const SessionOptions& session_options TraceLoggingUInt8(static_cast(session_options.graph_optimization_level), "graph_optimization_level"), TraceLoggingBoolean(session_options.use_per_session_threads, "use_per_session_threads"), TraceLoggingBoolean(session_options.thread_pool_allow_spinning, "thread_pool_allow_spinning"), - TraceLoggingBoolean(session_options.use_deterministic_compute, "use_deterministic_compute")); + TraceLoggingBoolean(session_options.use_deterministic_compute, "use_deterministic_compute"), + TraceLoggingBoolean(captureState, "isCaptureState")); TraceLoggingWrite( telemetry_provider_handle, @@ -511,7 +549,8 @@ void InferenceSession::TraceSessionOptions(const SessionOptions& session_options TraceLoggingInt32(session_options.intra_op_param.dynamic_block_base_, "dynamic_block_base_"), TraceLoggingUInt32(session_options.intra_op_param.stack_size, "stack_size"), TraceLoggingString(!session_options.intra_op_param.affinity_str.empty() ? session_options.intra_op_param.affinity_str.c_str() : "", "affinity_str"), - TraceLoggingBoolean(session_options.intra_op_param.set_denormal_as_zero, "set_denormal_as_zero")); + TraceLoggingBoolean(session_options.intra_op_param.set_denormal_as_zero, "set_denormal_as_zero"), + TraceLoggingBoolean(captureState, "isCaptureState")); for (const auto& config_pair : session_options.config_options.configurations) { TraceLoggingWrite( @@ -520,7 +559,8 @@ void InferenceSession::TraceSessionOptions(const SessionOptions& session_options TraceLoggingKeyword(static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)), TraceLoggingLevel(WINEVENT_LEVEL_INFO), TraceLoggingString(config_pair.first.c_str(), "Key"), - TraceLoggingString(config_pair.second.c_str(), "Value")); + TraceLoggingString(config_pair.second.c_str(), "Value"), + TraceLoggingBoolean(captureState, "isCaptureState")); } #endif } @@ -616,6 +656,12 @@ InferenceSession::~InferenceSession() { } } + // Unregister the session +#ifdef _WIN32 + std::lock_guard lock(active_sessions_mutex_); +#endif + active_sessions_.erase(global_session_id_); + #ifdef ONNXRUNTIME_ENABLE_INSTRUMENT if (session_activity_started_) TraceLoggingWriteStop(session_activity, "OrtInferenceSessionActivity"); @@ -3070,4 +3116,14 @@ IOBinding* SessionIOBinding::Get() { return binding_.get(); } +#ifdef _WIN32 +void InferenceSession::LogAllSessions() { + std::lock_guard lock(active_sessions_mutex_); + for (const auto& session_pair : active_sessions_) { + InferenceSession* session = session_pair.second; + TraceSessionOptions(session->session_options_, true); + } +} +#endif + } // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 96db49aabdaf6..f8211bfd2dd4e 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -3,6 +3,7 @@ #pragma once +#include #include #include #include @@ -21,11 +22,12 @@ #include "core/framework/session_state.h" #include "core/framework/tuning_results.h" #include "core/framework/framework_provider_common.h" +#include "core/framework/session_options.h" #include "core/graph/basic_types.h" #include "core/optimizer/graph_transformer_level.h" #include "core/optimizer/graph_transformer_mgr.h" #include "core/optimizer/insert_cast_transformer.h" -#include "core/framework/session_options.h" +#include "core/platform/ort_mutex.h" #ifdef ENABLE_LANGUAGE_INTEROP_OPS #include "core/language_interop_ops/language_interop_ops.h" #endif @@ -119,6 +121,10 @@ class InferenceSession { }; using InputOutputDefMetaMap = InlinedHashMap; + static std::map active_sessions_; +#ifdef _WIN32 + static OrtMutex active_sessions_mutex_; // Protects access to active_sessions_ +#endif public: #if !defined(ORT_MINIMAL_BUILD) @@ -642,7 +648,7 @@ class InferenceSession { void InitLogger(logging::LoggingManager* logging_manager); - void TraceSessionOptions(const SessionOptions& session_options); + void TraceSessionOptions(const SessionOptions& session_options, bool captureState); [[nodiscard]] common::Status CheckShapes(const std::string& input_name, const TensorShape& input_shape, const TensorShape& expected_shape, const char* input_output_moniker) const; @@ -679,6 +685,10 @@ class InferenceSession { */ void ShrinkMemoryArenas(gsl::span arenas_to_shrink); +#ifdef _WIN32 + void LogAllSessions(); +#endif + #if !defined(ORT_MINIMAL_BUILD) virtual common::Status AddPredefinedTransformers( GraphTransformerManager& transformer_manager, diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index ade1d96d617fb..17a955ba8ce1a 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -90,6 +90,10 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, (std::string(provider_name) + " execution provider is not supported in this build. ").c_str()); }; + for (const auto& config_pair : provider_options) { + ORT_THROW_IF_ERROR(options->value.config_options.AddConfigEntry((std::string(provider_name) + ":" + config_pair.first).c_str(), config_pair.second.c_str())); + } + if (strcmp(provider_name, "DML") == 0) { #if defined(USE_DML) options->provider_factories.push_back(DMLProviderFactoryCreator::CreateFromProviderOptions(provider_options)); diff --git a/onnxruntime/test/platform/windows/logging/HowToValidateEtwSinkOutput.md b/onnxruntime/test/platform/windows/logging/HowToValidateEtwSinkOutput.md index 59fe946b929f2..309b474c016c9 100644 --- a/onnxruntime/test/platform/windows/logging/HowToValidateEtwSinkOutput.md +++ b/onnxruntime/test/platform/windows/logging/HowToValidateEtwSinkOutput.md @@ -3,13 +3,13 @@ The ETW Sink (ONNXRuntimeTraceLoggingProvider) allows ONNX semi-structured printf style logs to be output via ETW. ETW makes it easy and useful to only enable and listen for events with great performance, and when you need them instead of only at compile time. -Therefore ONNX will preserve any existing loggers and log severity [provided at compile time](docs/FAQ.md?plain=1#L7). +Therefore ONNX will preserve any existing loggers and log severity [provided at compile time](/docs/FAQ.md?plain=1#L7). However, when the provider is enabled a new ETW logger sink will also be added and the severity separately controlled via ETW dynamically. - Provider GUID: 929DD115-1ECB-4CB5-B060-EBD4983C421D -- Keyword: Logs (0x2) keyword per [logging.h](include\onnxruntime\core\common\logging\logging.h) -- Level: 1-5 ([CRITICAL through VERBOSE](https://learn.microsoft.com/en-us/windows/win32/api/evntprov/ns-evntprov-event_descriptor)) [mapping](onnxruntime\core\platform\windows\logging\etw_sink.cc) to [ONNX severity](include\onnxruntime\core\common\logging\severity.h) in an intuitive manner +- Keyword: Logs (0x2) keyword per [logging.h](/include/onnxruntime/core/common/logging/logging.h) +- Level: 1-5 ([CRITICAL through VERBOSE](https://learn.microsoft.com/en-us/windows/win32/api/evntprov/ns-evntprov-event_descriptor)) [mapping](/onnxruntime/core/platform/windows/logging/etw_sink.cc) to [ONNX severity](/include/onnxruntime/core/common/logging/severity.h) in an intuitive manner Notes: - The ETW provider must be enabled prior to session creation, as that as when internal logging setup is complete