From ebf7a0ba799342d50c691c808179611446cc79d4 Mon Sep 17 00:00:00 2001 From: ivberg Date: Tue, 29 Oct 2024 14:22:35 -0700 Subject: [PATCH] Fix reliability issues in LogAllSessions. (#22568) ### Description The issue can occur with multiple sessions when an ETW captureState / rundown is triggered. Resolves a use-after-free issue. Tested with a local unit test that creates and destroys multiple sessions while continually enabling and disabling ETW. The test currently requires an Admin prompt, so it is not being checked in. ### Motivation and Context ORT should not crash. --- onnxruntime/core/session/inference_session.cc | 190 ++++++++++-------- onnxruntime/core/session/inference_session.h | 4 +- 2 files changed, 105 insertions(+), 89 deletions(-) diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index e6aafaa1f2283..4be107758d392 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -370,86 +370,12 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, // a monotonically increasing session id for use in telemetry session_id_ = global_session_id_.fetch_add(1); -#ifdef _WIN32 - std::lock_guard lock(active_sessions_mutex_); - active_sessions_[global_session_id_++] = this; - - // Register callback for ETW capture state (rundown) for Microsoft.ML.ONNXRuntime provider - callback_ML_ORT_provider_ = onnxruntime::WindowsTelemetry::EtwInternalCallback( - [this](LPCGUID SourceId, - ULONG IsEnabled, - UCHAR Level, - ULONGLONG MatchAnyKeyword, - ULONGLONG MatchAllKeyword, - PEVENT_FILTER_DESCRIPTOR FilterData, - PVOID CallbackContext) { - (void)SourceId; - (void)Level; - (void)MatchAnyKeyword; - (void)MatchAllKeyword; - (void)FilterData; - (void)CallbackContext; - - // Check if this callback is for capturing state - if ((IsEnabled == EVENT_CONTROL_CODE_CAPTURE_STATE) && - ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)) != 0)) { - 
LogAllSessions(); - } - }); - WindowsTelemetry::RegisterInternalCallback(callback_ML_ORT_provider_); - - // Register callback for ETW start / stop so that LOGS tracing can be adjusted dynamically after session start - auto& etwRegistrationManager = logging::EtwRegistrationManager::Instance(); - callback_ETWSink_provider_ = onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback( - [&etwRegistrationManager, this](LPCGUID SourceId, - ULONG IsEnabled, - UCHAR Level, - ULONGLONG MatchAnyKeyword, - ULONGLONG MatchAllKeyword, - PEVENT_FILTER_DESCRIPTOR FilterData, - PVOID CallbackContext) { - (void)SourceId; - (void)Level; - (void)MatchAnyKeyword; - (void)MatchAllKeyword; - (void)FilterData; - (void)CallbackContext; - - if (logging_manager_ != nullptr) { - auto ortETWSeverity = etwRegistrationManager.MapLevelToSeverity(); - - if ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Logs)) != 0 && - IsEnabled == EVENT_CONTROL_CODE_ENABLE_PROVIDER) { - LOGS(*session_logger_, VERBOSE) << "Adding ETW Sink to logger with severity level: " << (ULONG)ortETWSeverity; - logging_manager_->AddSinkOfType( - onnxruntime::logging::SinkType::EtwSink, - []() -> std::unique_ptr { return std::make_unique(); }, - ortETWSeverity); - onnxruntime::logging::LoggingManager::GetDefaultInstance()->AddSinkOfType( - onnxruntime::logging::SinkType::EtwSink, - []() -> std::unique_ptr { return std::make_unique(); }, - ortETWSeverity); - LOGS(*session_logger_, INFO) << "Done Adding ETW Sink to logger with severity level: " << (ULONG)ortETWSeverity; - } - if (IsEnabled == EVENT_CONTROL_CODE_DISABLE_PROVIDER) { - LOGS(*session_logger_, INFO) << "Removing ETW Sink from logger"; - logging_manager_->RemoveSink(onnxruntime::logging::SinkType::EtwSink); - LOGS(*session_logger_, VERBOSE) << "Done Removing ETW Sink from logger"; - } - } - }); - - // Register callback for ETW capture state (rundown) - etwRegistrationManager.RegisterInternalCallback(callback_ETWSink_provider_); - 
-#endif - SetLoggingManager(session_options, session_env); // The call to InitLogger depends on the final state of session_options_. Hence it should be invoked // after the invocation of FinalizeSessionOptions. InitLogger(logging_manager_); // this sets session_logger_ so that it can be used for logging after this point. - TraceSessionOptions(session_options, false); + TraceSessionOptions(session_options, false, *session_logger_); #if !defined(ORT_MINIMAL_BUILD) // Update the number of steps for the graph transformer manager using the "finalized" session options @@ -575,14 +501,97 @@ void InferenceSession::ConstructorCommon(const SessionOptions& session_options, } telemetry_ = {}; + +#ifdef _WIN32 + std::lock_guard lock(active_sessions_mutex_); + active_sessions_[session_id_] = this; + + // Register callback for ETW capture state (rundown) for Microsoft.ML.ONNXRuntime provider + callback_ML_ORT_provider_ = onnxruntime::WindowsTelemetry::EtwInternalCallback( + [](LPCGUID SourceId, + ULONG IsEnabled, + UCHAR Level, + ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, + PEVENT_FILTER_DESCRIPTOR FilterData, + PVOID CallbackContext) { + (void)SourceId; + (void)Level; + (void)MatchAnyKeyword; + (void)MatchAllKeyword; + (void)FilterData; + (void)CallbackContext; + ORT_UNUSED_PARAMETER(SourceId); + ORT_UNUSED_PARAMETER(Level); + ORT_UNUSED_PARAMETER(MatchAnyKeyword); + ORT_UNUSED_PARAMETER(MatchAllKeyword); + ORT_UNUSED_PARAMETER(FilterData); + ORT_UNUSED_PARAMETER(CallbackContext); + + // Check if this callback is for capturing state + if ((IsEnabled == EVENT_CONTROL_CODE_CAPTURE_STATE) && + ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)) != 0)) { + InferenceSession::LogAllSessions(); + } + }); + WindowsTelemetry::RegisterInternalCallback(callback_ML_ORT_provider_); + + // Register callback for ETW start / stop so that LOGS tracing can be adjusted dynamically after session start + auto& etwRegistrationManager = 
logging::EtwRegistrationManager::Instance(); + callback_ETWSink_provider_ = onnxruntime::logging::EtwRegistrationManager::EtwInternalCallback( + [&etwRegistrationManager, this](LPCGUID SourceId, + ULONG IsEnabled, + UCHAR Level, + ULONGLONG MatchAnyKeyword, + ULONGLONG MatchAllKeyword, + PEVENT_FILTER_DESCRIPTOR FilterData, + PVOID CallbackContext) { + ORT_UNUSED_PARAMETER(SourceId); + ORT_UNUSED_PARAMETER(Level); + ORT_UNUSED_PARAMETER(MatchAnyKeyword); + ORT_UNUSED_PARAMETER(MatchAllKeyword); + ORT_UNUSED_PARAMETER(FilterData); + ORT_UNUSED_PARAMETER(CallbackContext); + + if (logging_manager_ != nullptr) { + auto ortETWSeverity = etwRegistrationManager.MapLevelToSeverity(); + + if ((MatchAnyKeyword & static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Logs)) != 0 && + IsEnabled == EVENT_CONTROL_CODE_ENABLE_PROVIDER) { + LOGS(*session_logger_, VERBOSE) << "Adding ETW Sink to logger with severity level: " << (ULONG)ortETWSeverity; + logging_manager_->AddSinkOfType( + onnxruntime::logging::SinkType::EtwSink, + []() -> std::unique_ptr { return std::make_unique(); }, + ortETWSeverity); + onnxruntime::logging::LoggingManager::GetDefaultInstance()->AddSinkOfType( + onnxruntime::logging::SinkType::EtwSink, + []() -> std::unique_ptr { return std::make_unique(); }, + ortETWSeverity); + LOGS(*session_logger_, INFO) << "Done Adding ETW Sink to logger with severity level: " << (ULONG)ortETWSeverity; + } + if (IsEnabled == EVENT_CONTROL_CODE_DISABLE_PROVIDER) { + LOGS(*session_logger_, INFO) << "Removing ETW Sink from logger"; + logging_manager_->RemoveSink(onnxruntime::logging::SinkType::EtwSink); + LOGS(*session_logger_, VERBOSE) << "Done Removing ETW Sink from logger"; + } + } + }); + + // Register callback for ETW capture state (rundown) + etwRegistrationManager.RegisterInternalCallback(callback_ETWSink_provider_); + +#endif } -void InferenceSession::TraceSessionOptions(const SessionOptions& session_options, bool captureState) { +void 
InferenceSession::TraceSessionOptions(const SessionOptions& session_options, bool captureState, const logging::Logger& logger) { ORT_UNUSED_PARAMETER(captureState); // Otherwise Linux build error - LOGS(*session_logger_, INFO) << session_options; + LOGS(logger, INFO) << session_options; #ifdef _WIN32 + std::string optimized_model_filepath = ORT_TSTR_CONVERT_TO_PRINTABLE_STRING(session_options.optimized_model_filepath); + std::string profile_file_prefix = ORT_TSTR_CONVERT_TO_PRINTABLE_STRING(session_options.profile_file_prefix); + TraceLoggingWrite(telemetry_provider_handle, "SessionOptions", TraceLoggingKeyword(static_cast(onnxruntime::logging::ORTTraceLoggingKeyword::Session)), @@ -590,11 +599,11 @@ void InferenceSession::TraceSessionOptions(const SessionOptions& session_options TraceLoggingUInt8(static_cast(session_options.execution_mode), "execution_mode"), TraceLoggingUInt8(static_cast(session_options.execution_order), "execution_order"), TraceLoggingBoolean(session_options.enable_profiling, "enable_profiling"), - TraceLoggingString(ORT_TSTR_CONVERT_TO_PRINTABLE_STRING(session_options.optimized_model_filepath).c_str(), "optimized_model_filepath"), + TraceLoggingString(optimized_model_filepath.c_str(), "optimized_model_filepath"), TraceLoggingBoolean(session_options.enable_mem_pattern, "enable_mem_pattern"), TraceLoggingBoolean(session_options.enable_mem_reuse, "enable_mem_reuse"), TraceLoggingBoolean(session_options.enable_cpu_mem_arena, "enable_cpu_mem_arena"), - TraceLoggingString(ORT_TSTR_CONVERT_TO_PRINTABLE_STRING(session_options.profile_file_prefix).c_str(), "profile_file_prefix"), + TraceLoggingString(profile_file_prefix.c_str(), "profile_file_prefix"), TraceLoggingString(session_options.session_logid.c_str(), "session_logid"), TraceLoggingInt8(static_cast(session_options.session_log_severity_level), "session_log_severity_level"), TraceLoggingInt8(static_cast(session_options.session_log_verbosity_level), "session_log_verbosity_level"), @@ -729,7 +738,7 
@@ InferenceSession::~InferenceSession() { WindowsTelemetry::UnregisterInternalCallback(callback_ML_ORT_provider_); logging::EtwRegistrationManager::Instance().UnregisterInternalCallback(callback_ETWSink_provider_); #endif - active_sessions_.erase(global_session_id_); + active_sessions_.erase(session_id_); #ifdef ONNXRUNTIME_ENABLE_INSTRUMENT if (session_activity_started_) @@ -3313,14 +3322,21 @@ void InferenceSession::LogAllSessions() { for (const auto& session_pair : active_sessions_) { InferenceSession* session = session_pair.second; - onnxruntime::Graph& graph = model_->MainGraph(); - bool model_has_fp16_inputs = ModelHasFP16Inputs(graph); - env.GetTelemetryProvider().LogSessionCreation( - session_id_, model_->IrVersion(), model_->ProducerName(), model_->ProducerVersion(), model_->Domain(), - graph.DomainToVersionMap(), graph.Name(), model_->MetaData(), - telemetry_.event_name_, execution_providers_.GetIds(), model_has_fp16_inputs, true); + if (!session) { + continue; + } + + auto model = session->model_; + if (nullptr != model) { + onnxruntime::Graph& graph = model->MainGraph(); + bool model_has_fp16_inputs = ModelHasFP16Inputs(graph); + env.GetTelemetryProvider().LogSessionCreation( + session->session_id_, model->IrVersion(), model->ProducerName(), model->ProducerVersion(), model->Domain(), + graph.DomainToVersionMap(), graph.Name(), model->MetaData(), + session->telemetry_.event_name_, session->execution_providers_.GetIds(), model_has_fp16_inputs, true); + } - TraceSessionOptions(session->session_options_, true); + InferenceSession::TraceSessionOptions(session->session_options_, true, *session->session_logger_); } } #endif diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index 424248da793f1..514a478e3f97e 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -663,7 +663,7 @@ class InferenceSession { void InitLogger(logging::LoggingManager* 
logging_manager); - void TraceSessionOptions(const SessionOptions& session_options, bool captureState); + static void TraceSessionOptions(const SessionOptions& session_options, bool captureState, const logging::Logger& logger); [[nodiscard]] common::Status CheckShapes(const std::string& input_name, const TensorShape& input_shape, const TensorShape& expected_shape, const char* input_output_moniker) const; @@ -700,7 +700,7 @@ class InferenceSession { void ShrinkMemoryArenas(gsl::span arenas_to_shrink); #ifdef _WIN32 - void LogAllSessions(); + static void LogAllSessions(); #endif #if !defined(ORT_MINIMAL_BUILD)