Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fully dynamic ETW controlled logging for ORT and QNN logs #20537

Merged
merged 12 commits into from
Jun 7, 2024
Merged
5 changes: 5 additions & 0 deletions include/onnxruntime/core/common/logging/isink.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ class ISink {
public:
ISink() = default;

enum SinkType { BaseSink,
CompositeSink,
EtwSink };
virtual SinkType GetType() const { return BaseSink; }
ivberg marked this conversation as resolved.
Show resolved Hide resolved

/**
Sends the message to the sink.
@param timestamp The timestamp.
Expand Down
32 changes: 26 additions & 6 deletions include/onnxruntime/core/common/logging/logging.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@
#include "core/common/common.h"
#include "core/common/profiler_common.h"
#include "core/common/logging/capture.h"
#include "core/common/logging/severity.h"

#include "core/common/logging/macros.h"

#include "core/common/logging/severity.h"
#include "core/platform/ort_mutex.h"
#include "date/date.h"

/*
Expand Down Expand Up @@ -167,6 +166,24 @@ class LoggingManager final {
*/
static bool HasDefaultLogger() { return nullptr != s_default_logger_; }

/**
Gets the default instance of the LoggingManager.
*/
static LoggingManager* GetDefaultInstance();
ivberg marked this conversation as resolved.
Show resolved Hide resolved

#ifdef _WIN32
/**
Removes the ETW Sink if one is present
*/
void RemoveEtwSink();

/**
Adds an ETW Sink to the current sink creating a CompositeSink if necessary
@param etwSeverity The severity level for the ETW Sink
*/
void AddEtwSink(logging::Severity etwSeverity);
#endif

/**
Change the minimum severity level for log messages to be output by the default logger.
@param severity The severity.
Expand Down Expand Up @@ -214,7 +231,10 @@ class LoggingManager final {
void CreateDefaultLogger(const std::string& logger_id);

std::unique_ptr<ISink> sink_;
const Severity default_min_severity_;
#ifdef _WIN32
mutable OrtMutex sink_mutex_;
#endif
Severity default_min_severity_;
const bool default_filter_user_data_;
const int default_max_vlog_level_;
bool owns_default_logger_;
Expand Down Expand Up @@ -362,8 +382,8 @@ unsigned int GetProcessId();
/**
If the ONNXRuntimeTraceLoggingProvider ETW Provider is enabled, then adds to the existing logger.
*/
std::unique_ptr<ISink> EnhanceLoggerWithEtw(std::unique_ptr<ISink> existingLogger, logging::Severity originalSeverity,
logging::Severity etwSeverity);
std::unique_ptr<ISink> EnhanceSinkWithEtw(std::unique_ptr<ISink> existingSink, logging::Severity originalSeverity,
logging::Severity etwSeverity);

/**
If the ONNXRuntimeTraceLoggingProvider ETW Provider is enabled, then can override the logging level.
Expand Down
91 changes: 85 additions & 6 deletions onnxruntime/core/common/logging/logging.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,10 @@
#include <sys/syscall.h>
#endif
#endif
#include "core/platform/ort_mutex.h"

#if __FreeBSD__
#include <sys/thr.h> // Use thr_self() syscall under FreeBSD to get thread id
#include "logging.h"

Check warning on line 28 in onnxruntime/core/common/logging/logging.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Include the directory when naming header files [build/include_subdir] [4] Raw Output: onnxruntime/core/common/logging/logging.cc:28: Include the directory when naming header files [build/include_subdir] [4]
#endif

namespace onnxruntime {
Expand All @@ -52,7 +52,11 @@
return default_instance;
}

LoggingManager* LoggingManager::GetDefaultInstance() {
return static_cast<LoggingManager*>(DefaultLoggerManagerInstance().load());
}

// GSL_SUPRESS(i.22) is broken. Ignore the warnings for the static local variables that are trivial

Check warning on line 59 in onnxruntime/core/common/logging/logging.cc

View workflow job for this annotation

GitHub Actions / Optional Lint

[misspell] reported by reviewdog 🐶 "SUPRESS" is a misspelling of "SUPPRESS" Raw Output: ./onnxruntime/core/common/logging/logging.cc:59:7: "SUPRESS" is a misspelling of "SUPPRESS"
// and should not have any destruction order issues via pragmas instead.
// https://developercommunity.visualstudio.com/content/problem/249706/gslsuppress-does-not-work-for-i22-c-core-guideline.html
#ifdef _MSC_VER
Expand All @@ -66,6 +70,7 @@
}

Logger* LoggingManager::s_default_logger_ = nullptr;
OrtMutex sink_mutex_;

#ifdef _MSC_VER
#pragma warning(pop)
Expand Down Expand Up @@ -245,23 +250,23 @@
#endif
}

std::unique_ptr<ISink> EnhanceLoggerWithEtw(std::unique_ptr<ISink> existingLogger, logging::Severity originalSeverity,
logging::Severity etwSeverity) {
std::unique_ptr<ISink> EnhanceSinkWithEtw(std::unique_ptr<ISink> existingSink, logging::Severity originalSeverity,
logging::Severity etwSeverity) {
ivberg marked this conversation as resolved.
Show resolved Hide resolved
#ifdef _WIN32
auto& manager = EtwRegistrationManager::Instance();
if (manager.IsEnabled()) {
auto compositeSink = std::make_unique<CompositeSink>();
compositeSink->AddSink(std::move(existingLogger), originalSeverity);
compositeSink->AddSink(std::move(existingSink), originalSeverity);
compositeSink->AddSink(std::make_unique<EtwSink>(), etwSeverity);
return compositeSink;
} else {
return existingLogger;
return existingSink;
}
#else
// On non-Windows platforms, just return the existing logger
(void)originalSeverity;
(void)etwSeverity;
return existingLogger;
return existingSink;
#endif // _WIN32
}

Expand All @@ -276,5 +281,79 @@
return originalSeverity;
}

#ifdef _WIN32
void LoggingManager::AddEtwSink(logging::Severity etwSeverity) {
ivberg marked this conversation as resolved.
Show resolved Hide resolved
std::lock_guard<OrtMutex> guard(sink_mutex_);

// Check if the EtwRegistrationManager is enabled
auto& manager = EtwRegistrationManager::Instance();
if (!manager.IsEnabled()) {
return; // ETW not enabled, no operation needed
}

if (sink_->GetType() != ISink::CompositeSink) {
// Current sink is not a composite, create a new composite sink and add the current sink to it
auto new_composite = std::make_unique<CompositeSink>();
new_composite->AddSink(std::move(sink_), default_min_severity_); // Move the current sink into the new composite
ivberg marked this conversation as resolved.
Show resolved Hide resolved
sink_ = std::move(new_composite); // Now sink_ is pointing to the new composite
}

// Adjust the default minimum severity level to accommodate ETW logging needs
default_min_severity_ = std::min(default_min_severity_, etwSeverity);
if (s_default_logger_ != nullptr) {
s_default_logger_->SetSeverity(default_min_severity_);
}

CompositeSink* current_composite = static_cast<CompositeSink*>(sink_.get());
// Check if an EtwSink already exists in the current composite
const auto& sinks = current_composite->GetSinks();
if (std::any_of(sinks.begin(), sinks.end(), [](const auto& pair) {
return pair.first->GetType() == ISink::EtwSink;
})) {
return; // EtwSink already exists, do not add another
}

// Add a new EtwSink
current_composite->AddSink(std::make_unique<EtwSink>(), etwSeverity);
}

void LoggingManager::RemoveEtwSink() {
std::lock_guard<OrtMutex> guard(sink_mutex_);

if (sink_->GetType() == ISink::CompositeSink) {
auto composite_sink = static_cast<CompositeSink*>(sink_.get());
const auto& sinks_with_severity = composite_sink->GetSinks();
std::unique_ptr<ISink> remaining_sink;

Severity newSeverity = Severity::kFATAL;

for (const auto& sink_pair : sinks_with_severity) {
if (sink_pair.first->GetType() != ISink::EtwSink) {
if (remaining_sink) {
// If more than one non-EtwSink is found, we leave the CompositeSink intact
return;
}
newSeverity = std::min(newSeverity, sink_pair.second);
remaining_sink = std::move(const_cast<std::unique_ptr<ISink>&>(sink_pair.first));
}
}

// If only one non-EtwSink remains, replace the CompositeSink with this sink
if (remaining_sink) {
sink_ = std::move(remaining_sink);

} else {
// Handle the case where all sinks were EtwSinks
// sink_ = std::make_unique<NullSink>(); // Assuming NullSink is a basic ISink that does nothing
}

default_min_severity_ = newSeverity;
if (s_default_logger_ != nullptr) {
s_default_logger_->SetSeverity(default_min_severity_);
}
}
}
#endif

} // namespace logging
} // namespace onnxruntime
2 changes: 2 additions & 0 deletions onnxruntime/core/common/logging/sinks/composite_sink.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ class CompositeSink : public ISink {
/// </summary>
CompositeSink() {}

SinkType GetType() const override { return ISink::CompositeSink; }

ivberg marked this conversation as resolved.
Show resolved Hide resolved
/// <summary>
/// Adds a sink. Takes ownership of the sink (so pass unique_ptr by value).
/// </summary>
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/core/platform/telemetry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ void Telemetry::LogSessionCreation(uint32_t session_id, int64_t ir_version, cons
const std::string& model_graph_name,
const std::unordered_map<std::string, std::string>& model_metadata,
const std::string& loadedFrom, const std::vector<std::string>& execution_provider_ids,
bool use_fp16) const {
bool use_fp16, bool captureState) const {
ORT_UNUSED_PARAMETER(session_id);
ORT_UNUSED_PARAMETER(ir_version);
ORT_UNUSED_PARAMETER(model_producer_name);
Expand All @@ -67,6 +67,7 @@ void Telemetry::LogSessionCreation(uint32_t session_id, int64_t ir_version, cons
ORT_UNUSED_PARAMETER(loadedFrom);
ORT_UNUSED_PARAMETER(execution_provider_ids);
ORT_UNUSED_PARAMETER(use_fp16);
ORT_UNUSED_PARAMETER(captureState);
}

void Telemetry::LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file,
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/platform/telemetry.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ class Telemetry {
const std::string& model_graph_name,
const std::unordered_map<std::string, std::string>& model_metadata,
const std::string& loadedFrom, const std::vector<std::string>& execution_provider_ids,
bool use_fp16) const;
bool use_fp16, bool captureState) const;

virtual void LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file,
const char* function, uint32_t line) const;
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/platform/windows/logging/etw_sink.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ class EtwSink : public ISink {
EtwSink() = default;
~EtwSink() = default;

SinkType GetType() const override { return ISink::EtwSink; }

constexpr static const char* kEventName = "ONNXRuntimeLogEvent";

private:
Expand Down
79 changes: 52 additions & 27 deletions onnxruntime/core/platform/windows/telemetry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -210,23 +210,23 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio
const std::string& model_graph_name,
const std::unordered_map<std::string, std::string>& model_metadata,
const std::string& loaded_from, const std::vector<std::string>& execution_provider_ids,
bool use_fp16) const {
bool use_fp16, bool captureState) const {
if (global_register_count_ == 0 || enabled_ == false)
return;

// build the strings we need

std::string domain_to_verison_string;
std::string domain_to_version_string;
bool first = true;
for (auto& i : domain_to_version_map) {
if (first) {
first = false;
} else {
domain_to_verison_string += ',';
domain_to_version_string += ',';
}
domain_to_verison_string += i.first;
domain_to_verison_string += '=';
domain_to_verison_string += std::to_string(i.second);
domain_to_version_string += i.first;
domain_to_version_string += '=';
domain_to_version_string += std::to_string(i.second);
}

std::string model_metadata_string;
Expand All @@ -253,27 +253,52 @@ void WindowsTelemetry::LogSessionCreation(uint32_t session_id, int64_t ir_versio
execution_provider_string += i;
}

TraceLoggingWrite(telemetry_provider_handle,
"SessionCreation",
TraceLoggingBool(true, "UTCReplace_AppSessionGuid"),
TelemetryPrivacyDataTag(PDT_ProductAndServiceUsage),
TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES),
TraceLoggingKeyword(static_cast<uint64_t>(onnxruntime::logging::ORTTraceLoggingKeyword::Session)),
TraceLoggingLevel(WINEVENT_LEVEL_INFO),
// Telemetry info
TraceLoggingUInt8(0, "schemaVersion"),
TraceLoggingUInt32(session_id, "sessionId"),
TraceLoggingInt64(ir_version, "irVersion"),
TraceLoggingUInt32(projection_, "OrtProgrammingProjection"),
TraceLoggingString(model_producer_name.c_str(), "modelProducerName"),
TraceLoggingString(model_producer_version.c_str(), "modelProducerVersion"),
TraceLoggingString(model_domain.c_str(), "modelDomain"),
TraceLoggingBool(use_fp16, "usefp16"),
TraceLoggingString(domain_to_verison_string.c_str(), "domainToVersionMap"),
TraceLoggingString(model_graph_name.c_str(), "modelGraphName"),
TraceLoggingString(model_metadata_string.c_str(), "modelMetaData"),
TraceLoggingString(loaded_from.c_str(), "loadedFrom"),
TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"));
// Difference is MeasureEvent & isCaptureState, but keep in sync otherwise
if (!captureState) {
TraceLoggingWrite(telemetry_provider_handle,
"SessionCreation",
ivberg marked this conversation as resolved.
Show resolved Hide resolved
TraceLoggingBool(true, "UTCReplace_AppSessionGuid"),
TelemetryPrivacyDataTag(PDT_ProductAndServiceUsage),
TraceLoggingKeyword(MICROSOFT_KEYWORD_MEASURES),
TraceLoggingKeyword(static_cast<uint64_t>(onnxruntime::logging::ORTTraceLoggingKeyword::Session)),
TraceLoggingLevel(WINEVENT_LEVEL_INFO),
// Telemetry info
TraceLoggingUInt8(0, "schemaVersion"),
TraceLoggingUInt32(session_id, "sessionId"),
TraceLoggingInt64(ir_version, "irVersion"),
TraceLoggingUInt32(projection_, "OrtProgrammingProjection"),
TraceLoggingString(model_producer_name.c_str(), "modelProducerName"),
TraceLoggingString(model_producer_version.c_str(), "modelProducerVersion"),
TraceLoggingString(model_domain.c_str(), "modelDomain"),
TraceLoggingBool(use_fp16, "usefp16"),
TraceLoggingString(domain_to_version_string.c_str(), "domainToVersionMap"),
TraceLoggingString(model_graph_name.c_str(), "modelGraphName"),
TraceLoggingString(model_metadata_string.c_str(), "modelMetaData"),
TraceLoggingString(loaded_from.c_str(), "loadedFrom"),
TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"));
} else {
TraceLoggingWrite(telemetry_provider_handle,
"SessionCreation_CaptureState",
TraceLoggingBool(true, "UTCReplace_AppSessionGuid"),
TelemetryPrivacyDataTag(PDT_ProductAndServiceUsage),
// Not a measure event
TraceLoggingKeyword(static_cast<uint64_t>(onnxruntime::logging::ORTTraceLoggingKeyword::Session)),
TraceLoggingLevel(WINEVENT_LEVEL_INFO),
// Telemetry info
TraceLoggingUInt8(0, "schemaVersion"),
TraceLoggingUInt32(session_id, "sessionId"),
TraceLoggingInt64(ir_version, "irVersion"),
TraceLoggingUInt32(projection_, "OrtProgrammingProjection"),
TraceLoggingString(model_producer_name.c_str(), "modelProducerName"),
TraceLoggingString(model_producer_version.c_str(), "modelProducerVersion"),
TraceLoggingString(model_domain.c_str(), "modelDomain"),
TraceLoggingBool(use_fp16, "usefp16"),
TraceLoggingString(domain_to_version_string.c_str(), "domainToVersionMap"),
TraceLoggingString(model_graph_name.c_str(), "modelGraphName"),
TraceLoggingString(model_metadata_string.c_str(), "modelMetaData"),
TraceLoggingString(loaded_from.c_str(), "loadedFrom"),
TraceLoggingString(execution_provider_string.c_str(), "executionProviderIds"));
}
}

void WindowsTelemetry::LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file,
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/platform/windows/telemetry.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ class WindowsTelemetry : public Telemetry {
const std::string& model_graph_name,
const std::unordered_map<std::string, std::string>& model_metadata,
const std::string& loadedFrom, const std::vector<std::string>& execution_provider_ids,
bool use_fp16) const override;
bool use_fp16, bool captureState) const override;

void LogRuntimeError(uint32_t session_id, const common::Status& status, const char* file,
const char* function, uint32_t line) const override;
Expand Down
Loading
Loading