[js/webnn] Enable user-supplied MLContext #20600

Merged · 3 commits · Jul 8, 2024

401 changes: 401 additions & 0 deletions js/web/lib/wasm/jsep/webnn/webnn.d.ts

Large diffs are not rendered by default.
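Since the 401-line declaration file is collapsed in the diff view, here is a minimal sketch of the kind of declarations such a file carries. This is a hypothetical excerpt with names taken from the WebNN specification, not the actual file contents:

```ts
// Hypothetical excerpt: hand-written declarations mirroring the WebNN spec,
// standing in for the 401-line webnn.d.ts this PR adds so TypeScript can
// type-check navigator.ml usage.
type GPUDevice = unknown; // placeholder; the real file can reference @webgpu/types

type MLDeviceType = 'cpu'|'gpu'|'npu';
type MLPowerPreference = 'default'|'high-performance'|'low-power';

interface MLContextOptions {
  deviceType?: MLDeviceType;
  powerPreference?: MLPowerPreference;
  numThreads?: number;  // Chromium extension consumed by the WebNN EP
}

interface ML {
  createContext(options?: MLContextOptions): Promise<MLContext>;
  createContext(gpuDevice: GPUDevice): Promise<MLContext>;
}

// The real file also declares MLGraphBuilder, MLOperand, and the rest of the API.
interface MLContext {}

interface Navigator {
  readonly ml: ML;
}
```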

22 changes: 0 additions & 22 deletions js/web/lib/wasm/session-options.ts
```diff
@@ -66,8 +66,6 @@ const setExecutionProviders =
           const webnnOptions = ep as InferenceSession.WebNNExecutionProviderOption;
           // const context = (webnnOptions as InferenceSession.WebNNOptionsWithMLContext)?.context;
           const deviceType = (webnnOptions as InferenceSession.WebNNContextOptions)?.deviceType;
-          const numThreads = (webnnOptions as InferenceSession.WebNNContextOptions)?.numThreads;
-          const powerPreference = (webnnOptions as InferenceSession.WebNNContextOptions)?.powerPreference;
           if (deviceType) {
             const keyDataOffset = allocWasmString('deviceType', allocs);
             const valueDataOffset = allocWasmString(deviceType, allocs);
@@ -76,26 +74,6 @@
               checkLastError(`Can't set a session config entry: 'deviceType' - ${deviceType}.`);
             }
           }
-          if (numThreads !== undefined) {
-            // Just ignore invalid webnnOptions.numThreads.
-            const validatedNumThreads =
-                (typeof numThreads !== 'number' || !Number.isInteger(numThreads) || numThreads < 0) ? 0 :
-                                                                                                      numThreads;
-            const keyDataOffset = allocWasmString('numThreads', allocs);
-            const valueDataOffset = allocWasmString(validatedNumThreads.toString(), allocs);
-            if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !==
-                0) {
-              checkLastError(`Can't set a session config entry: 'numThreads' - ${numThreads}.`);
-            }
-          }
-          if (powerPreference) {
-            const keyDataOffset = allocWasmString('powerPreference', allocs);
-            const valueDataOffset = allocWasmString(powerPreference, allocs);
-            if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !==
-                0) {
-              checkLastError(`Can't set a session config entry: 'powerPreference' - ${powerPreference}.`);
-            }
-          }
         }
         break;
       case 'webgpu':
```
37 changes: 37 additions & 0 deletions js/web/lib/wasm/wasm-core-impl.ts
```diff
@@ -1,6 +1,11 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+// WebNN API currently does not have a TypeScript definition file. This file is a workaround with types generated from
+// WebNN API specification.
+// https://github.com/webmachinelearning/webnn/issues/677
+/// <reference path="jsep/webnn/webnn.d.ts" />
+
 import {Env, InferenceSession, Tensor} from 'onnxruntime-common';
 
 import {SerializableInternalBuffer, SerializableSessionMetadata, SerializableTensorMetadata, TensorMetadata} from './proxy-messages';
@@ -253,11 +258,43 @@ export const createSession = async(
       await Promise.all(loadingPromises);
     }
 
+    for (const provider of options?.executionProviders ?? []) {
+      const providerName = typeof provider === 'string' ? provider : provider.name;
+      if (providerName === 'webnn') {
+        if (wasm.currentContext) {
+          throw new Error('WebNN execution provider is already set.');
+        }
+        if (typeof provider !== 'string') {
+          const webnnOptions = provider as InferenceSession.WebNNExecutionProviderOption;
+          const context = (webnnOptions as InferenceSession.WebNNOptionsWithMLContext)?.context;
+          const gpuDevice = (webnnOptions as InferenceSession.WebNNOptionsWebGpu)?.gpuDevice;
+          const deviceType = (webnnOptions as InferenceSession.WebNNContextOptions)?.deviceType;
+          const numThreads = (webnnOptions as InferenceSession.WebNNContextOptions)?.numThreads;
+          const powerPreference = (webnnOptions as InferenceSession.WebNNContextOptions)?.powerPreference;
+          if (context) {
+            wasm.currentContext = context as MLContext;
+          } else if (gpuDevice) {
+            wasm.currentContext = await navigator.ml.createContext(gpuDevice);
+          } else {
+            wasm.currentContext = await navigator.ml.createContext({deviceType, numThreads, powerPreference});
+          }
+        } else {
+          wasm.currentContext = await navigator.ml.createContext();
+        }
+        break;
+      }
+    }
+
     sessionHandle = await wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle);
     if (sessionHandle === 0) {
       checkLastError('Can\'t create a session.');
     }
 
+    // clear current MLContext after session creation
+    if (wasm.currentContext) {
+      wasm.currentContext = undefined;
+    }
+
     const [inputCount, outputCount] = getSessionInputOutputCount(sessionHandle);
 
     const enableGraphCapture = !!options?.enableGraphCapture;
```
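With this change, a caller can hand the WebNN EP a pre-created MLContext instead of letting ONNX Runtime create one internally. A minimal usage sketch (assuming a WebNN-enabled browser; 'model.onnx' is a placeholder path):

```ts
import * as ort from 'onnxruntime-web';

// Create (and fully control) the MLContext up front, e.g. to share it with
// other WebNN graphs compiled outside ONNX Runtime.
const context = await navigator.ml.createContext({deviceType: 'gpu'});

const session = await ort.InferenceSession.create('model.onnx', {
  executionProviders: [{name: 'webnn', context}],
});
```

Passing `gpuDevice` instead of `context` takes the `navigator.ml.createContext(gpuDevice)` branch above, and the plain string form `'webnn'` falls back to a default context.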
14 changes: 13 additions & 1 deletion js/web/lib/wasm/wasm-types.ts
```diff
@@ -1,6 +1,11 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+// WebNN API currently does not have a TypeScript definition file. This file is a workaround with types generated from
+// WebNN API specification.
+// https://github.com/webmachinelearning/webnn/issues/677
+/// <reference path="jsep/webnn/webnn.d.ts" />
+
 import type {Tensor} from 'onnxruntime-common';
 
 /* eslint-disable @typescript-eslint/naming-convention */
@@ -19,7 +24,7 @@ export declare namespace JSEP {
   type CaptureEndFunction = () => void;
   type ReplayFunction = () => void;
 
-  export interface Module extends WebGpuModule {
+  export interface Module extends WebGpuModule, WebNnModule {
     /**
      * Mount the external data file to an internal map, which will be used during session initialization.
      *
@@ -106,6 +111,13 @@
      */
     jsepOnReleaseSession: (sessionId: number) => void;
   }
+
+  export interface WebNnModule {
+    /**
+     * Active MLContext used to create WebNN EP.
+     */
+    currentContext?: MLContext;
+  }
 }
 
 export interface OrtInferenceAPIs {
```
19 changes: 2 additions & 17 deletions onnxruntime/core/providers/webnn/webnn_execution_provider.cc
```diff
@@ -17,24 +17,12 @@
 
 namespace onnxruntime {
 
-WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_flags,
-                                               const std::string& webnn_threads_number, const std::string& webnn_power_flags)
+WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_flags)
     : IExecutionProvider{onnxruntime::kWebNNExecutionProvider} {
-  // Create WebNN context and graph builder.
-  const emscripten::val ml = emscripten::val::global("navigator")["ml"];
-  if (!ml.as<bool>()) {
-    ORT_THROW("Failed to get ml from navigator.");
-  }
-  emscripten::val context_options = emscripten::val::object();
-  context_options.set("deviceType", emscripten::val(webnn_device_flags));
   // WebNN EP uses NHWC layout for CPU XNNPACK backend and NCHW for GPU DML backend.
   if (webnn_device_flags.compare("cpu") == 0) {
     preferred_layout_ = DataLayout::NHWC;
     wnn_device_type_ = webnn::WebnnDeviceType::CPU;
-    // Set "numThreads" if it's not default 0.
-    if (webnn_threads_number.compare("0") != 0) {
-      context_options.set("numThreads", stoi(webnn_threads_number));
-    }
   } else {
     preferred_layout_ = DataLayout::NCHW;
     if (webnn_device_flags.compare("gpu") == 0) {
@@ -45,11 +33,8 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f
       ORT_THROW("Unknown WebNN deviceType.");
     }
   }
-  if (webnn_power_flags.compare("default") != 0) {
-    context_options.set("powerPreference", emscripten::val(webnn_power_flags));
-  }
 
-  wnn_context_ = ml.call<emscripten::val>("createContext", context_options).await();
+  wnn_context_ = emscripten::val::module_property("currentContext");
   if (!wnn_context_.as<bool>()) {
     ORT_THROW("Failed to create WebNN context.");
   }
```
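The EP no longer builds its own context; it picks up whatever the JS side stashed on the Emscripten module before `_OrtCreateSession` ran (the C++ lookup is `emscripten::val::module_property("currentContext")`). A rough sketch of that handshake from the TypeScript side, using the names that appear in the diffs above:

```ts
// Sketch of the creation-time handshake, assuming `wasm` is the JSEP.Module
// instance and `context` came from the user or navigator.ml.createContext().
wasm.currentContext = context;                // 1. publish the context on the module
const handle = await wasm._OrtCreateSession(  // 2. WebNNExecutionProvider's ctor reads
    modelDataOffset, modelDataLength,         //    Module["currentContext"] on the C++ side
    sessionOptionsHandle);
wasm.currentContext = undefined;              // 3. clear it so the next session starts clean
```

Because the context is only published for the duration of session creation, concurrent WebNN session creations would race on `currentContext`, which is why the JS code throws if it finds one already set.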
3 changes: 1 addition & 2 deletions onnxruntime/core/providers/webnn/webnn_execution_provider.h
```diff
@@ -19,8 +19,7 @@ class Model;
 
 class WebNNExecutionProvider : public IExecutionProvider {
  public:
-  WebNNExecutionProvider(const std::string& webnn_device_flags, const std::string& webnn_threads_number,
-                         const std::string& webnn_power_flags);
+  explicit WebNNExecutionProvider(const std::string& webnn_device_flags);
   virtual ~WebNNExecutionProvider();
 
   std::vector<std::unique_ptr<ComputeCapability>>
```
13 changes: 4 additions & 9 deletions onnxruntime/core/providers/webnn/webnn_provider_factory.cc
```diff
@@ -10,27 +10,22 @@ using namespace onnxruntime;
 
 namespace onnxruntime {
 struct WebNNProviderFactory : IExecutionProviderFactory {
-  WebNNProviderFactory(const std::string& webnn_device_flags, const std::string& webnn_threads_number,
-                       const std::string& webnn_power_flags)
-      : webnn_device_flags_(webnn_device_flags), webnn_threads_number_(webnn_threads_number), webnn_power_flags_(webnn_power_flags) {}
+  explicit WebNNProviderFactory(const std::string& webnn_device_flags)
+      : webnn_device_flags_(webnn_device_flags) {}
   ~WebNNProviderFactory() override {}
 
   std::unique_ptr<IExecutionProvider> CreateProvider() override;
 
   std::string webnn_device_flags_;
-  std::string webnn_threads_number_;
-  std::string webnn_power_flags_;
 };
 
 std::unique_ptr<IExecutionProvider> WebNNProviderFactory::CreateProvider() {
-  return std::make_unique<WebNNExecutionProvider>(webnn_device_flags_, webnn_threads_number_, webnn_power_flags_);
+  return std::make_unique<WebNNExecutionProvider>(webnn_device_flags_);
 }
 
 std::shared_ptr<IExecutionProviderFactory> WebNNProviderFactoryCreator::Create(
     const ProviderOptions& provider_options) {
-  return std::make_shared<onnxruntime::WebNNProviderFactory>(provider_options.at("deviceType"),
-                                                             provider_options.at("numThreads"),
-                                                             provider_options.at("powerPreference"));
+  return std::make_shared<onnxruntime::WebNNProviderFactory>(provider_options.at("deviceType"));
 }
 
 }  // namespace onnxruntime
```
4 changes: 0 additions & 4 deletions onnxruntime/core/session/provider_registration.cc
```diff
@@ -128,11 +128,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider,
   } else if (strcmp(provider_name, "WEBNN") == 0) {
 #if defined(USE_WEBNN)
     std::string deviceType = options->value.config_options.GetConfigOrDefault("deviceType", "cpu");
-    std::string numThreads = options->value.config_options.GetConfigOrDefault("numThreads", "0");
-    std::string powerPreference = options->value.config_options.GetConfigOrDefault("powerPreference", "default");
     provider_options["deviceType"] = deviceType;
-    provider_options["numThreads"] = numThreads;
-    provider_options["powerPreference"] = powerPreference;
     options->provider_factories.push_back(WebNNProviderFactoryCreator::Create(provider_options));
 #else
     status = create_not_supported_status();
```