Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into skottmckay/CoreML_MLP…
Browse files Browse the repository at this point in the history
…rogram_support
  • Loading branch information
skottmckay committed Feb 9, 2024
2 parents dd3d802 + 90cf037 commit 6468186
Show file tree
Hide file tree
Showing 65 changed files with 958 additions and 301 deletions.
8 changes: 2 additions & 6 deletions docs/python/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@
# Licensed under the MIT License.
# pylint: disable=C0103

# -*- coding: utf-8 -*-
#
# Configuration file for the Sphinx documentation builder.
"""Configuration file for the Sphinx documentation builder."""

import os
import shutil # noqa: F401
import shutil
import sys

sys.path.append(os.path.join(os.path.dirname(__file__), "..", "_common"))
Expand Down Expand Up @@ -127,7 +125,5 @@ def setup(app):
urllib.request.urlretrieve(url, dest)
loc = os.path.split(dest)[-1]
if not os.path.exists(loc):
import shutil # noqa: F811

shutil.copy(dest, loc)
return app
10 changes: 10 additions & 0 deletions js/web/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -69,46 +69,56 @@
"exports": {
".": {
"node": "./dist/ort.node.min.js",
"types": "./types.d.ts",
"default": {
"import": "./dist/esm/ort.min.js",
"require": "./dist/cjs/ort.min.js",
"types": "./types.d.ts",
"default": {
"development": "./dist/ort.js",
"types": "./types.d.ts",
"default": "./dist/ort.min.js"
}
}
},
"./experimental": {
"import": "./dist/esm/ort.all.min.js",
"require": "./dist/cjs/ort.all.min.js",
"types": "./types.d.ts",
"default": {
"development": "./dist/ort.all.js",
"types": "./types.d.ts",
"default": "./dist/ort.all.min.js"
}
},
"./wasm": {
"import": "./dist/esm/ort.wasm.min.js",
"require": "./dist/cjs/ort.wasm.min.js",
"types": "./types.d.ts",
"default": "./dist/ort.wasm.min.js"
},
"./wasm-core": {
"import": "./dist/esm/ort.wasm-core.min.js",
"require": "./dist/cjs/ort.wasm-core.min.js",
"types": "./types.d.ts",
"default": "./dist/ort.wasm-core.min.js"
},
"./webgl": {
"import": "./dist/esm/ort.webgl.min.js",
"require": "./dist/cjs/ort.webgl.min.js",
"types": "./types.d.ts",
"default": "./dist/ort.webgl.min.js"
},
"./webgpu": {
"import": "./dist/esm/ort.webgpu.min.js",
"require": "./dist/cjs/ort.webgpu.min.js",
"types": "./types.d.ts",
"default": "./dist/ort.webgpu.min.js"
},
"./training": {
"import": "./dist/esm/ort.training.wasm.min.js",
"require": "./dist/cjs/ort.training.wasm.min.js",
"types": "./types.d.ts",
"default": "./dist/ort.training.wasm.min.js"
}
},
Expand Down
55 changes: 48 additions & 7 deletions onnxruntime/core/framework/execution_providers.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

#pragma once

// #include <map>
#include <memory>
#include <string>
#include <unordered_map>
Expand All @@ -14,7 +13,9 @@
#include "core/common/logging/logging.h"
#ifdef _WIN32
#include <winmeta.h>
#include <evntrace.h>
#include "core/platform/tracing.h"
#include "core/platform/windows/telemetry.h"
#endif

namespace onnxruntime {
Expand Down Expand Up @@ -44,6 +45,49 @@ class ExecutionProviders {
exec_provider_options_[provider_id] = providerOptions;

#ifdef _WIN32
LogProviderOptions(provider_id, providerOptions, false);

// Register callback for ETW capture state (rundown)
WindowsTelemetry::RegisterInternalCallback(
[this](
LPCGUID SourceId,
ULONG IsEnabled,
UCHAR Level,
ULONGLONG MatchAnyKeyword,
ULONGLONG MatchAllKeyword,
PEVENT_FILTER_DESCRIPTOR FilterData,
PVOID CallbackContext) {
(void)SourceId;
(void)Level;
(void)MatchAnyKeyword;
(void)MatchAllKeyword;
(void)FilterData;
(void)CallbackContext;

// Check if this callback is for capturing state
if ((IsEnabled == EVENT_CONTROL_CODE_CAPTURE_STATE) &&
((MatchAnyKeyword & static_cast<ULONGLONG>(onnxruntime::logging::ORTTraceLoggingKeyword::Session)) != 0)) {
for (size_t i = 0; i < exec_providers_.size(); ++i) {
const auto& provider_id = exec_provider_ids_[i];

auto it = exec_provider_options_.find(provider_id);
if (it != exec_provider_options_.end()) {
const auto& options = it->second;

LogProviderOptions(provider_id, options, true);
}
}
}
});
#endif

exec_provider_ids_.push_back(provider_id);
exec_providers_.push_back(p_exec_provider);
return Status::OK();
}

#ifdef _WIN32
void LogProviderOptions(const std::string& provider_id, const ProviderOptions& providerOptions, bool captureState) {
for (const auto& config_pair : providerOptions) {
TraceLoggingWrite(
telemetry_provider_handle,
Expand All @@ -52,14 +96,11 @@ class ExecutionProviders {
TraceLoggingLevel(WINEVENT_LEVEL_INFO),
TraceLoggingString(provider_id.c_str(), "ProviderId"),
TraceLoggingString(config_pair.first.c_str(), "Key"),
TraceLoggingString(config_pair.second.c_str(), "Value"));
TraceLoggingString(config_pair.second.c_str(), "Value"),
TraceLoggingBool(captureState, "isCaptureState"));
}
#endif

exec_provider_ids_.push_back(provider_id);
exec_providers_.push_back(p_exec_provider);
return Status::OK();
}
#endif

const IExecutionProvider* Get(const onnxruntime::Node& node) const {
return Get(node.GetExecutionProviderType());
Expand Down
24 changes: 19 additions & 5 deletions onnxruntime/core/platform/windows/telemetry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

#include "core/platform/windows/telemetry.h"
#include "core/platform/ort_mutex.h"
#include "core/common/logging/logging.h"
#include "onnxruntime_config.h"

Expand Down Expand Up @@ -63,6 +64,8 @@ bool WindowsTelemetry::enabled_ = true;
uint32_t WindowsTelemetry::projection_ = 0;
UCHAR WindowsTelemetry::level_ = 0;
UINT64 WindowsTelemetry::keyword_ = 0;
std::vector<WindowsTelemetry::EtwInternalCallback> WindowsTelemetry::callbacks_;
OrtMutex WindowsTelemetry::callbacks_mutex_;

WindowsTelemetry::WindowsTelemetry() {
std::lock_guard<OrtMutex> lock(mutex_);
Expand Down Expand Up @@ -104,6 +107,11 @@ UINT64 WindowsTelemetry::Keyword() const {
// return etw_status_;
// }

void WindowsTelemetry::RegisterInternalCallback(const EtwInternalCallback& callback) {
std::lock_guard<OrtMutex> lock(callbacks_mutex_);
callbacks_.push_back(callback);
}

void NTAPI WindowsTelemetry::ORT_TL_EtwEnableCallback(
_In_ LPCGUID SourceId,
_In_ ULONG IsEnabled,
Expand All @@ -112,15 +120,21 @@ void NTAPI WindowsTelemetry::ORT_TL_EtwEnableCallback(
_In_ ULONGLONG MatchAllKeyword,
_In_opt_ PEVENT_FILTER_DESCRIPTOR FilterData,
_In_opt_ PVOID CallbackContext) {
(void)SourceId;
(void)MatchAllKeyword;
(void)FilterData;
(void)CallbackContext;

std::lock_guard<OrtMutex> lock(provider_change_mutex_);
enabled_ = (IsEnabled != 0);
level_ = Level;
keyword_ = MatchAnyKeyword;

InvokeCallbacks(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext);
}

void WindowsTelemetry::InvokeCallbacks(LPCGUID SourceId, ULONG IsEnabled, UCHAR Level, ULONGLONG MatchAnyKeyword,
ULONGLONG MatchAllKeyword, PEVENT_FILTER_DESCRIPTOR FilterData,
PVOID CallbackContext) {
std::lock_guard<OrtMutex> lock(callbacks_mutex_);
for (const auto& callback : callbacks_) {
callback(SourceId, IsEnabled, Level, MatchAnyKeyword, MatchAllKeyword, FilterData, CallbackContext);
}
}

void WindowsTelemetry::EnableTelemetryEvents() const {
Expand Down
15 changes: 14 additions & 1 deletion onnxruntime/core/platform/windows/telemetry.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
// Licensed under the MIT License.

#pragma once
#include <atomic>
#include <vector>

#include "core/platform/telemetry.h"
#include <Windows.h>
#include <TraceLoggingProvider.h>
#include "core/platform/ort_mutex.h"
#include "core/platform/windows/TraceLoggingConfig.h"
#include <atomic>

namespace onnxruntime {

Expand Down Expand Up @@ -58,16 +60,27 @@ class WindowsTelemetry : public Telemetry {

void LogExecutionProviderEvent(LUID* adapterLuid) const override;

using EtwInternalCallback = std::function<void(LPCGUID SourceId, ULONG IsEnabled, UCHAR Level,
ULONGLONG MatchAnyKeyword, ULONGLONG MatchAllKeyword,
PEVENT_FILTER_DESCRIPTOR FilterData, PVOID CallbackContext)>;

static void RegisterInternalCallback(const EtwInternalCallback& callback);

private:
static OrtMutex mutex_;
static uint32_t global_register_count_;
static bool enabled_;
static uint32_t projection_;

static std::vector<EtwInternalCallback> callbacks_;
static OrtMutex callbacks_mutex_;
static OrtMutex provider_change_mutex_;
static UCHAR level_;
static ULONGLONG keyword_;

static void InvokeCallbacks(LPCGUID SourceId, ULONG IsEnabled, UCHAR Level, ULONGLONG MatchAnyKeyword,
ULONGLONG MatchAllKeyword, PEVENT_FILTER_DESCRIPTOR FilterData, PVOID CallbackContext);

static void NTAPI ORT_TL_EtwEnableCallback(
_In_ LPCGUID SourceId,
_In_ ULONG IsEnabled,
Expand Down
Loading

0 comments on commit 6468186

Please sign in to comment.