From 6b8bd70cd5149ae325fd2254671ee13e9f3b11a2 Mon Sep 17 00:00:00 2001 From: Ivan Berg Date: Mon, 22 Jan 2024 12:53:23 -0800 Subject: [PATCH 1/7] Support capturestate for logging session options --- .../core/platform/windows/telemetry.cc | 24 +++++++-- onnxruntime/core/platform/windows/telemetry.h | 15 +++++- onnxruntime/core/session/inference_session.cc | 52 +++++++++++++++++-- onnxruntime/core/session/inference_session.h | 10 +++- 4 files changed, 91 insertions(+), 10 deletions(-) diff --git a/onnxruntime/core/platform/windows/telemetry.cc b/onnxruntime/core/platform/windows/telemetry.cc index a9849873fd060..654281d526e4d 100644 --- a/onnxruntime/core/platform/windows/telemetry.cc +++ b/onnxruntime/core/platform/windows/telemetry.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "core/platform/windows/telemetry.h" +#include "core/platform/ort_mutex.h" #include "core/common/logging/logging.h" #include "onnxruntime_config.h" @@ -63,6 +64,8 @@ bool WindowsTelemetry::enabled_ = true; uint32_t WindowsTelemetry::projection_ = 0; UCHAR WindowsTelemetry::level_ = 0; UINT64 WindowsTelemetry::keyword_ = 0; +std::vector WindowsTelemetry::callbacks_; +OrtMutex WindowsTelemetry::callbacks_mutex_; WindowsTelemetry::WindowsTelemetry() { std::lock_guard lock(mutex_); @@ -104,6 +107,11 @@ UINT64 WindowsTelemetry::Keyword() const { // return etw_status_; // } +void WindowsTelemetry::RegisterInternalCallback(const EtwInternalCallback& callback) { + std::lock_guard lock(callbacks_mutex_); + callbacks_.push_back(callback); +} + void NTAPI WindowsTelemetry::ORT_TL_EtwEnableCallback( _In_ LPCGUID SourceId, _In_ ULONG IsEnabled, @@ -112,15 +120,21 @@ void NTAPI WindowsTelemetry::ORT_TL_EtwEnableCallback( _In_ ULONGLONG MatchAllKeyword, _In_opt_ PEVENT_FILTER_DESCRIPTOR FilterData, _In_opt_ PVOID CallbackContext) { - (void)SourceId; - (void)MatchAllKeyword; - (void)FilterData; - (void)CallbackContext; - std::lock_guard lock(provider_change_mutex_); enabled_ = (IsEnabled != 0); level_ = Level; keyword_ = MatchAnyKeyword; + + InvokeCallbacks(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext); +} + +void WindowsTelemetry::InvokeCallbacks(LPCGUID SourceId, ULONG IsEnabled, UCHAR Level, ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, PEVENT_FILTER_DESCRIPTOR FilterData, + PVOID CallbackContext) { + std::lock_guard lock(callbacks_mutex_); + for (const auto& callback : callbacks_) { + callback(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext); + } } void WindowsTelemetry::EnableTelemetryEvents() const { diff --git a/onnxruntime/core/platform/windows/telemetry.h b/onnxruntime/core/platform/windows/telemetry.h index c3798943d491d..cdb186e9ed703 100644 --- a/onnxruntime/core/platform/windows/telemetry.h +++ b/onnxruntime/core/platform/windows/telemetry.h @@ -2,12 +2,14 @@ // Licensed under the MIT License. #pragma once +#include +#include + #include "core/platform/telemetry.h" #include #include #include "core/platform/ort_mutex.h" #include "core/platform/windows/TraceLoggingConfig.h" -#include namespace onnxruntime { @@ -58,16 +60,27 @@ class WindowsTelemetry : public Telemetry { void LogExecutionProviderEvent(LUID* adapterLuid) const override; + using EtwInternalCallback = std::function; + + static void RegisterInternalCallback(const EtwInternalCallback& callback); + private: static OrtMutex mutex_; static uint32_t global_register_count_; static bool enabled_; static uint32_t projection_; + static std::vector callbacks_; + static OrtMutex callbacks_mutex_; static OrtMutex provider_change_mutex_; static UCHAR level_; static ULONGLONG keyword_; + static void InvokeCallbacks(LPCGUID SourceId, ULONG IsEnabled, UCHAR Level, ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, PEVENT_FILTER_DESCRIPTOR FilterData, PVOID CallbackContext); + static void NTAPI ORT_TL_EtwEnableCallback( _In_ LPCGUID SourceId, _In_ ULONG IsEnabled, diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index cae714954f72f..cd80ac8b86939 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -46,10 +46,11 @@ #include "core/optimizer/transformer_memcpy.h" #include "core/optimizer/transpose_optimization/ort_optimizer_utils.h" #include "core/platform/Barrier.h" -#include "core/platform/ort_mutex.h" #include "core/platform/threadpool.h" #ifdef _WIN32 #include "core/platform/tracing.h" +#include +#include "core/platform/windows/telemetry.h" #endif #include "core/providers/cpu/controlflow/utils.h" #include "core/providers/cpu/cpu_execution_provider.h" @@ -241,6 +242,8 @@ Status GetMinimalBuildOptimizationHandling( } // namespace std::atomic InferenceSession::global_session_id_{1}; +std::map InferenceSession::active_sessions_; +OrtMutex InferenceSession::active_sessions_mutex_; // Protects access to active_sessions_ static Status FinalizeSessionOptions(const SessionOptions& user_provided_session_options, const ONNX_NAMESPACE::ModelProto& model_proto, @@ -351,11 +354,40 @@ void InferenceSession::SetLoggingManager(const SessionOptions& session_options, void InferenceSession::ConstructorCommon(const SessionOptions& session_options, const Environment& session_env) { auto status = FinalizeSessionOptions(session_options, model_proto_, is_model_proto_parsed_, session_options_); - // a monotonically increasing session id for use in telemetry - session_id_ = global_session_id_.fetch_add(1); ORT_ENFORCE(status.IsOK(), "Could not finalize session options while constructing the inference session. Error Message: ", status.ErrorMessage()); + // a monotonically increasing session id for use in telemetry + session_id_ = global_session_id_.fetch_add(1); + std::lock_guard lock(active_sessions_mutex_); + active_sessions_[global_session_id_++] = this; + +#ifdef _WIN32 + // auto& manager = WindowsTelemetry:: EtwRegistrationManager::Instance(); + WindowsTelemetry::RegisterInternalCallback( + [this]( + LPCGUID SourceId, + ULONG IsEnabled, + UCHAR Level, + ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, + PEVENT_FILTER_DESCRIPTOR FilterData, + PVOID CallbackContext) { + (void)SourceId; + (void)Level; + (void)MatchAnyKeyword; + (void)MatchAllKeyword; + (void)FilterData; + (void)CallbackContext; + + // Check if this callback is for capturing state + if ((IsEnabled == EVENT_CONTROL_CODE_CAPTURE_STATE) && + ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)) != 0)) { + LogAllSessions(); + } + }); +#endif + SetLoggingManager(session_options, session_env); // The call to InitLogger depends on the final state of session_options_. Hence it should be invoked @@ -616,6 +648,10 @@ InferenceSession::~InferenceSession() { } } + // Unregister the session + std::lock_guard lock(active_sessions_mutex_); + active_sessions_.erase(global_session_id_); + #ifdef ONNXRUNTIME_ENABLE_INSTRUMENT if (session_activity_started_) TraceLoggingWriteStop(session_activity, "OrtInferenceSessionActivity"); @@ -3070,4 +3106,14 @@ IOBinding* SessionIOBinding::Get() { return binding_.get(); } +#ifdef _WIN32 +void InferenceSession::LogAllSessions() { + std::lock_guard lock(active_sessions_mutex_); + for (const auto& session_pair : active_sessions_) { + InferenceSession* session = session_pair.second; + TraceSessionOptions(session->session_options_); + } +} +#endif + } // namespace onnxruntime diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 96db49aabdaf6..d4228c3603879 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -3,6 +3,7 @@ #pragma once +#include #include #include #include @@ -21,11 +22,12 @@ #include "core/framework/session_state.h" #include "core/framework/tuning_results.h" #include "core/framework/framework_provider_common.h" +#include "core/framework/session_options.h" #include "core/graph/basic_types.h" #include "core/optimizer/graph_transformer_level.h" #include "core/optimizer/graph_transformer_mgr.h" #include "core/optimizer/insert_cast_transformer.h" -#include "core/framework/session_options.h" +#include "core/platform/ort_mutex.h" #ifdef ENABLE_LANGUAGE_INTEROP_OPS #include "core/language_interop_ops/language_interop_ops.h" #endif @@ -119,6 +121,8 @@ class InferenceSession { }; using InputOutputDefMetaMap = InlinedHashMap; + static std::map active_sessions_; + static OrtMutex active_sessions_mutex_; // Protects access to active_sessions_ public: #if !defined(ORT_MINIMAL_BUILD) @@ -679,6 +683,10 @@ class InferenceSession { */ void ShrinkMemoryArenas(gsl::span arenas_to_shrink); +#ifdef _WIN32 + void LogAllSessions(); +#endif + #if !defined(ORT_MINIMAL_BUILD) virtual common::Status AddPredefinedTransformers( GraphTransformerManager& transformer_manager, From c49d2831b991bf6ebd77c9b0522b2d67fbd0078d Mon Sep 17 00:00:00 2001 From: Ivan Berg Date: Thu, 1 Feb 2024 16:59:26 -0800 Subject: [PATCH 2/7] Partial check-in of rundown flag and attempt to log EP options --- .../core/framework/execution_providers.h | 53 ++++++++++++++++--- onnxruntime/core/session/inference_session.cc | 16 +++--- onnxruntime/core/session/inference_session.h | 2 +- 3 files changed, 57 insertions(+), 14 deletions(-) diff --git a/onnxruntime/core/framework/execution_providers.h b/onnxruntime/core/framework/execution_providers.h index 61147e4367876..b0659cb19e4cc 100644 --- a/onnxruntime/core/framework/execution_providers.h +++ b/onnxruntime/core/framework/execution_providers.h @@ -15,6 +15,8 @@ #ifdef _WIN32 #include #include "core/platform/tracing.h" +#include +#include "core/platform/windows/telemetry.h" #endif namespace onnxruntime { @@ -44,6 +46,48 @@ class ExecutionProviders { exec_provider_options_[provider_id] = providerOptions; #ifdef _WIN32 + LogProviderOptions(provider_id, providerOptions, false); + + WindowsTelemetry::RegisterInternalCallback( + [this]( + LPCGUID SourceId, + ULONG IsEnabled, + UCHAR Level, + ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, + PEVENT_FILTER_DESCRIPTOR FilterData, + PVOID CallbackContext) { + (void)SourceId; + (void)Level; + (void)MatchAnyKeyword; + (void)MatchAllKeyword; + (void)FilterData; + (void)CallbackContext; + + // Check if this callback is for capturing state + if ((IsEnabled == EVENT_CONTROL_CODE_CAPTURE_STATE) && + ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)) != 0)) { + for (size_t i = 0; i < exec_providers_.size(); ++i) { + const auto& provider_id = exec_provider_ids_[i]; + + auto it = exec_provider_options_.find(provider_id); + if (it != exec_provider_options_.end()) { + const auto& options = it->second; + + LogProviderOptions(provider_id, options, true); + } + } + } + }); +#endif + + exec_provider_ids_.push_back(provider_id); + exec_providers_.push_back(p_exec_provider); + return Status::OK(); + } + +#ifdef _WIN32 + void LogProviderOptions(const std::string& provider_id, const ProviderOptions& providerOptions, bool captureState) { for (const auto& config_pair : providerOptions) { TraceLoggingWrite( telemetry_provider_handle, @@ -52,14 +96,11 @@ class ExecutionProviders { TraceLoggingLevel(WINEVENT_LEVEL_INFO), TraceLoggingString(provider_id.c_str(), "ProviderId"), TraceLoggingString(config_pair.first.c_str(), "Key"), - TraceLoggingString(config_pair.second.c_str(), "Value")); + TraceLoggingString(config_pair.second.c_str(), "Value"), + TraceLoggingBool(captureState, "CaptureState")); } -#endif - - exec_provider_ids_.push_back(provider_id); - exec_providers_.push_back(p_exec_provider); - return Status::OK(); } +#endif const IExecutionProvider* Get(const onnxruntime::Node& node) const { return Get(node.GetExecutionProviderType()); diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index cd80ac8b86939..68947845a5543 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -363,7 +363,6 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, active_sessions_[global_session_id_++] = this; #ifdef _WIN32 - // auto& manager = WindowsTelemetry:: EtwRegistrationManager::Instance(); WindowsTelemetry::RegisterInternalCallback( [this]( LPCGUID SourceId, @@ -393,7 +392,7 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, // The call to InitLogger depends on the final state of session_options_. Hence it should be invoked // after the invocation of FinalizeSessionOptions. InitLogger(logging_manager_); // this sets session_logger_ so that it can be used for logging after this point. - TraceSessionOptions(session_options); + TraceSessionOptions(session_options, false); #if !defined(ORT_MINIMAL_BUILD) // Update the number of steps for the graph transformer manager using the "finalized" session options @@ -507,7 +506,7 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, telemetry_ = {}; } -void InferenceSession::TraceSessionOptions(const SessionOptions& session_options) { +void InferenceSession::TraceSessionOptions(const SessionOptions& session_options, bool rundown) { LOGS(*session_logger_, INFO) << session_options; #ifdef _WIN32 @@ -530,7 +529,8 @@ void InferenceSession::TraceSessionOptions(const SessionOptions& session_options TraceLoggingUInt8(static_cast(session_options.graph_optimization_level), "graph_optimization_level"), TraceLoggingBoolean(session_options.use_per_session_threads, "use_per_session_threads"), TraceLoggingBoolean(session_options.thread_pool_allow_spinning, "thread_pool_allow_spinning"), - TraceLoggingBoolean(session_options.use_deterministic_compute, "use_deterministic_compute")); + TraceLoggingBoolean(session_options.use_deterministic_compute, "use_deterministic_compute"), + TraceLoggingBoolean(rundown, "isRundown")); TraceLoggingWrite( telemetry_provider_handle, @@ -543,7 +543,8 @@ void InferenceSession::TraceSessionOptions(const SessionOptions& session_options TraceLoggingInt32(session_options.intra_op_param.dynamic_block_base_, "dynamic_block_base_"), TraceLoggingUInt32(session_options.intra_op_param.stack_size, "stack_size"), TraceLoggingString(!session_options.intra_op_param.affinity_str.empty() ? session_options.intra_op_param.affinity_str.c_str() : "", "affinity_str"), - TraceLoggingBoolean(session_options.intra_op_param.set_denormal_as_zero, "set_denormal_as_zero")); + TraceLoggingBoolean(session_options.intra_op_param.set_denormal_as_zero, "set_denormal_as_zero"), + TraceLoggingBoolean(rundown, "isRundown")); for (const auto& config_pair : session_options.config_options.configurations) { TraceLoggingWrite( @@ -552,7 +553,8 @@ void InferenceSession::TraceSessionOptions(const SessionOptions& session_options TraceLoggingKeyword(static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)), TraceLoggingLevel(WINEVENT_LEVEL_INFO), TraceLoggingString(config_pair.first.c_str(), "Key"), - TraceLoggingString(config_pair.second.c_str(), "Value")); + TraceLoggingString(config_pair.second.c_str(), "Value"), + TraceLoggingBoolean(rundown, "isRundown")); } #endif } @@ -3111,7 +3113,7 @@ void InferenceSession::LogAllSessions() { std::lock_guard lock(active_sessions_mutex_); for (const auto& session_pair : active_sessions_) { InferenceSession* session = session_pair.second; - TraceSessionOptions(session->session_options_); + TraceSessionOptions(session->session_options_, true); } } #endif diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index d4228c3603879..b2ebb168ad2fd 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -646,7 +646,7 @@ class InferenceSession { void InitLogger(logging::LoggingManager* logging_manager); - void TraceSessionOptions(const SessionOptions& session_options); + void TraceSessionOptions(const SessionOptions& session_options, bool rundown); [[nodiscard]] common::Status CheckShapes(const std::string& input_name, const TensorShape& input_shape, const TensorShape& expected_shape, const char* input_output_moniker) const; From e6052ce85c48578b774e5396a57af4e27e3b8967 Mon Sep 17 00:00:00 2001 From: Ivan Berg Date: Fri, 2 Feb 2024 12:56:24 -0800 Subject: [PATCH 3/7] Add provider registration config options to session config options so that values can be logged at rundown --- onnxruntime/core/session/provider_registration.cc | 4 ++++ .../platform/windows/logging/HowToValidateEtwSinkOutput.md | 6 +++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/session/provider_registration.cc b/onnxruntime/core/session/provider_registration.cc index ade1d96d617fb..17a955ba8ce1a 100644 --- a/onnxruntime/core/session/provider_registration.cc +++ b/onnxruntime/core/session/provider_registration.cc @@ -90,6 +90,10 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider, (std::string(provider_name) + " execution provider is not supported in this build. ").c_str()); }; + for (const auto& config_pair : provider_options) { + ORT_THROW_IF_ERROR(options->value.config_options.AddConfigEntry((std::string(provider_name) + ":" + config_pair.first).c_str(), config_pair.second.c_str())); + } + if (strcmp(provider_name, "DML") == 0) { #if defined(USE_DML) options->provider_factories.push_back(DMLProviderFactoryCreator::CreateFromProviderOptions(provider_options)); diff --git a/onnxruntime/test/platform/windows/logging/HowToValidateEtwSinkOutput.md b/onnxruntime/test/platform/windows/logging/HowToValidateEtwSinkOutput.md index 59fe946b929f2..309b474c016c9 100644 --- a/onnxruntime/test/platform/windows/logging/HowToValidateEtwSinkOutput.md +++ b/onnxruntime/test/platform/windows/logging/HowToValidateEtwSinkOutput.md @@ -3,13 +3,13 @@ The ETW Sink (ONNXRuntimeTraceLoggingProvider) allows ONNX semi-structured printf style logs to be output via ETW. ETW makes it easy and useful to only enable and listen for events with great performance, and when you need them instead of only at compile time. -Therefore ONNX will preserve any existing loggers and log severity [provided at compile time](docs/FAQ.md?plain=1#L7). +Therefore ONNX will preserve any existing loggers and log severity [provided at compile time](/docs/FAQ.md?plain=1#L7). However, when the provider is enabled a new ETW logger sink will also be added and the severity separately controlled via ETW dynamically. - Provider GUID: 929DD115-1ECB-4CB5-B060-EBD4983C421D -- Keyword: Logs (0x2) keyword per [logging.h](include\onnxruntime\core\common\logging\logging.h) -- Level: 1-5 ([CRITICAL through VERBOSE](https://learn.microsoft.com/en-us/windows/win32/api/evntprov/ns-evntprov-event_descriptor)) [mapping](onnxruntime\core\platform\windows\logging\etw_sink.cc) to [ONNX severity](include\onnxruntime\core\common\logging\severity.h) in an intuitive manner +- Keyword: Logs (0x2) keyword per [logging.h](/include/onnxruntime/core/common/logging/logging.h) +- Level: 1-5 ([CRITICAL through VERBOSE](https://learn.microsoft.com/en-us/windows/win32/api/evntprov/ns-evntprov-event_descriptor)) [mapping](/onnxruntime/core/platform/windows/logging/etw_sink.cc) to [ONNX severity](/include/onnxruntime/core/common/logging/severity.h) in an intuitive manner Notes: - The ETW provider must be enabled prior to session creation, as that as when internal logging setup is complete From 2904cc2989b01b1934feec17d39a5a20d698065d Mon Sep 17 00:00:00 2001 From: Ivan Berg Date: Fri, 2 Feb 2024 14:04:57 -0800 Subject: [PATCH 4/7] Fix Linux build issue due to header include order --- onnxruntime/core/framework/execution_providers.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/onnxruntime/core/framework/execution_providers.h b/onnxruntime/core/framework/execution_providers.h index b0659cb19e4cc..98278dde234da 100644 --- a/onnxruntime/core/framework/execution_providers.h +++ b/onnxruntime/core/framework/execution_providers.h @@ -3,7 +3,6 @@ #pragma once -// #include #include #include #include @@ -14,8 +13,8 @@ #include "core/common/logging/logging.h" #ifdef _WIN32 #include -#include "core/platform/tracing.h" #include +#include "core/platform/tracing.h" #include "core/platform/windows/telemetry.h" #endif From e50c606c2777729d7e1ef3058f914b56e338b26f Mon Sep 17 00:00:00 2001 From: Ivan Berg Date: Fri, 2 Feb 2024 14:59:25 -0800 Subject: [PATCH 5/7] Fix another Linux build issue due to unused param (on Linux) --- onnxruntime/core/session/inference_session.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 68947845a5543..c1d48c731b988 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -507,6 +507,8 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, } void InferenceSession::TraceSessionOptions(const SessionOptions& session_options, bool rundown) { + (void)rundown; // Otherwise Linux build error + LOGS(*session_logger_, INFO) << session_options; #ifdef _WIN32 From 2086d76de73aa6c3d0e95640e66db3680e3e2bf1 Mon Sep 17 00:00:00 2001 From: Ivan Berg Date: Fri, 2 Feb 2024 16:21:20 -0800 Subject: [PATCH 6/7] Rename rundown to captureState --- onnxruntime/core/framework/execution_providers.h | 3 ++- onnxruntime/core/session/inference_session.cc | 11 ++++++----- onnxruntime/core/session/inference_session.h | 2 +- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/onnxruntime/core/framework/execution_providers.h b/onnxruntime/core/framework/execution_providers.h index 98278dde234da..dc45cad692b6e 100644 --- a/onnxruntime/core/framework/execution_providers.h +++ b/onnxruntime/core/framework/execution_providers.h @@ -47,6 +47,7 @@ class ExecutionProviders { #ifdef _WIN32 LogProviderOptions(provider_id, providerOptions, false); + // Register callback for ETW capture state (rundown) WindowsTelemetry::RegisterInternalCallback( [this]( LPCGUID SourceId, @@ -96,7 +97,7 @@ class ExecutionProviders { TraceLoggingString(provider_id.c_str(), "ProviderId"), TraceLoggingString(config_pair.first.c_str(), "Key"), TraceLoggingString(config_pair.second.c_str(), "Value"), - TraceLoggingBool(captureState, "CaptureState")); + TraceLoggingBool(captureState, "isCaptureState")); } } #endif diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index c1d48c731b988..9a4dac480dd48 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -363,6 +363,7 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, active_sessions_[global_session_id_++] = this; #ifdef _WIN32 + // Register callback for ETW capture state (rundown) WindowsTelemetry::RegisterInternalCallback( [this]( LPCGUID SourceId, @@ -506,8 +507,8 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, telemetry_ = {}; } -void InferenceSession::TraceSessionOptions(const SessionOptions& session_options, bool rundown) { - (void)rundown; // Otherwise Linux build error +void InferenceSession::TraceSessionOptions(const SessionOptions& session_options, bool captureState) { + (void)captureState; // Otherwise Linux build error LOGS(*session_logger_, INFO) << session_options; @@ -532,7 +533,7 @@ void InferenceSession::TraceSessionOptions(const SessionOptions& session_options TraceLoggingBoolean(session_options.use_per_session_threads, "use_per_session_threads"), TraceLoggingBoolean(session_options.thread_pool_allow_spinning, "thread_pool_allow_spinning"), TraceLoggingBoolean(session_options.use_deterministic_compute, "use_deterministic_compute"), - TraceLoggingBoolean(rundown, "isRundown")); + TraceLoggingBoolean(captureState, "isCaptureState")); TraceLoggingWrite( telemetry_provider_handle, @@ -546,7 +547,7 @@ void InferenceSession::TraceSessionOptions(const SessionOptions& session_options TraceLoggingUInt32(session_options.intra_op_param.stack_size, "stack_size"), TraceLoggingString(!session_options.intra_op_param.affinity_str.empty() ? session_options.intra_op_param.affinity_str.c_str() : "", "affinity_str"), TraceLoggingBoolean(session_options.intra_op_param.set_denormal_as_zero, "set_denormal_as_zero"), - TraceLoggingBoolean(rundown, "isRundown")); + TraceLoggingBoolean(captureState, "isCaptureState")); for (const auto& config_pair : session_options.config_options.configurations) { TraceLoggingWrite( @@ -556,7 +557,7 @@ void InferenceSession::TraceSessionOptions(const SessionOptions& session_options TraceLoggingLevel(WINEVENT_LEVEL_INFO), TraceLoggingString(config_pair.first.c_str(), "Key"), TraceLoggingString(config_pair.second.c_str(), "Value"), - TraceLoggingBoolean(rundown, "isRundown")); + TraceLoggingBoolean(captureState, "isCaptureState")); } #endif } diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index b2ebb168ad2fd..ad9aece0c7052 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -646,7 +646,7 @@ class InferenceSession { void InitLogger(logging::LoggingManager* logging_manager); - void TraceSessionOptions(const SessionOptions& session_options, bool rundown); + void TraceSessionOptions(const SessionOptions& session_options, bool captureState); [[nodiscard]] common::Status CheckShapes(const std::string& input_name, const TensorShape& input_shape, const TensorShape& expected_shape, const char* input_output_moniker) const; From 4fc424b91c8413f39efde0907a246b5cbb9ea930 Mon Sep 17 00:00:00 2001 From: Ivan Berg Date: Wed, 7 Feb 2024 14:29:50 -0800 Subject: [PATCH 7/7] Per PR, place active_sessions_mutex_ under win32 since it's only used there for now --- onnxruntime/core/session/inference_session.cc | 7 ++++++- onnxruntime/core/session/inference_session.h | 2 ++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 9a4dac480dd48..b045f30a59797 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -243,7 +243,9 @@ Status GetMinimalBuildOptimizationHandling( std::atomic InferenceSession::global_session_id_{1}; std::map InferenceSession::active_sessions_; +#ifdef _WIN32 OrtMutex InferenceSession::active_sessions_mutex_; // Protects access to active_sessions_ +#endif static Status FinalizeSessionOptions(const SessionOptions& user_provided_session_options, const ONNX_NAMESPACE::ModelProto& model_proto, @@ -359,10 +361,11 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, // a monotonically increasing session id for use in telemetry session_id_ = global_session_id_.fetch_add(1); + +#ifdef _WIN32 std::lock_guard lock(active_sessions_mutex_); active_sessions_[global_session_id_++] = this; -#ifdef _WIN32 // Register callback for ETW capture state (rundown) WindowsTelemetry::RegisterInternalCallback( [this]( @@ -654,7 +657,9 @@ InferenceSession::~InferenceSession() { } // Unregister the session +#ifdef _WIN32 std::lock_guard lock(active_sessions_mutex_); +#endif active_sessions_.erase(global_session_id_); #ifdef ONNXRUNTIME_ENABLE_INSTRUMENT diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index ad9aece0c7052..f8211bfd2dd4e 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -122,7 +122,9 @@ class InferenceSession { using InputOutputDefMetaMap = InlinedHashMap; static std::map active_sessions_; +#ifdef _WIN32 static OrtMutex active_sessions_mutex_; // Protects access to active_sessions_ +#endif public: #if !defined(ORT_MINIMAL_BUILD)