[js/webnn] Enable user-supplied MLContext (#20600)
### Description
This PR enables the API added in #20816 and moves `MLContext` creation to
JavaScript.

### Motivation and Context
To enable I/O Binding with the upcoming
[MLBuffer](webmachinelearning/webnn#542) API
in the WebNN specification, multiple sessions must share the same
`MLContext`, because `MLBuffer`s are restricted to the `MLContext` on which
they were created. This PR therefore lets developers supply their own
`MLContext` and reuse it across sessions.
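As a rough sketch of the intended usage (assuming the `context` option shape added in onnxruntime-common; the model paths are hypothetical, and this is not code from this PR):

```ts
import * as ort from 'onnxruntime-web';

// Create one MLContext up front and hand it to every session, so that
// MLBuffers created on this context are usable from all of these sessions.
const context = await navigator.ml.createContext({deviceType: 'gpu'});

const sessionA = await ort.InferenceSession.create('./model-a.onnx', {
  executionProviders: [{name: 'webnn', context}],
});
const sessionB = await ort.InferenceSession.create('./model-b.onnx', {
  executionProviders: [{name: 'webnn', context}],
});
```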
egalli authored Jul 8, 2024
1 parent cd516a1 commit 4c3c809
Showing 8 changed files with 458 additions and 55 deletions.
401 changes: 401 additions & 0 deletions js/web/lib/wasm/jsep/webnn/webnn.d.ts

Large diffs are not rendered by default.
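Since the new webnn.d.ts is not rendered here, the following is an illustrative sketch (not the actual file contents) of the kind of declarations it supplies, shaped after the WebNN spec surface that the code in this PR relies on:

```ts
// Illustrative excerpt only; the real 401-line file covers the full WebNN API.
type MLDeviceType = 'cpu'|'gpu';
type MLPowerPreference = 'default'|'low-power'|'high-performance';

interface MLContextOptions {
  deviceType?: MLDeviceType;
  powerPreference?: MLPowerPreference;
  numThreads?: number;
}

interface ML {
  createContext(options?: MLContextOptions): Promise<MLContext>;
  createContext(gpuDevice: GPUDevice): Promise<MLContext>;
}

interface MLContext {}

interface Navigator {
  readonly ml: ML;
}
```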

22 changes: 0 additions & 22 deletions js/web/lib/wasm/session-options.ts
@@ -66,8 +66,6 @@ const setExecutionProviders =
             const webnnOptions = ep as InferenceSession.WebNNExecutionProviderOption;
             // const context = (webnnOptions as InferenceSession.WebNNOptionsWithMLContext)?.context;
             const deviceType = (webnnOptions as InferenceSession.WebNNContextOptions)?.deviceType;
-            const numThreads = (webnnOptions as InferenceSession.WebNNContextOptions)?.numThreads;
-            const powerPreference = (webnnOptions as InferenceSession.WebNNContextOptions)?.powerPreference;
             if (deviceType) {
               const keyDataOffset = allocWasmString('deviceType', allocs);
               const valueDataOffset = allocWasmString(deviceType, allocs);
@@ -76,26 +74,6 @@
                 checkLastError(`Can't set a session config entry: 'deviceType' - ${deviceType}.`);
               }
             }
-            if (numThreads !== undefined) {
-              // Just ignore invalid webnnOptions.numThreads.
-              const validatedNumThreads =
-                  (typeof numThreads !== 'number' || !Number.isInteger(numThreads) || numThreads < 0) ? 0 :
-                  numThreads;
-              const keyDataOffset = allocWasmString('numThreads', allocs);
-              const valueDataOffset = allocWasmString(validatedNumThreads.toString(), allocs);
-              if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !==
-                  0) {
-                checkLastError(`Can't set a session config entry: 'numThreads' - ${numThreads}.`);
-              }
-            }
-            if (powerPreference) {
-              const keyDataOffset = allocWasmString('powerPreference', allocs);
-              const valueDataOffset = allocWasmString(powerPreference, allocs);
-              if (getInstance()._OrtAddSessionConfigEntry(sessionOptionsHandle, keyDataOffset, valueDataOffset) !==
-                  0) {
-                checkLastError(`Can't set a session config entry: 'powerPreference' - ${powerPreference}.`);
-              }
-            }
           }
           break;
         case 'webgpu':
37 changes: 37 additions & 0 deletions js/web/lib/wasm/wasm-core-impl.ts
@@ -1,6 +1,11 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+// WebNN API currently does not have a TypeScript definition file. This file is a workaround with types generated from
+// WebNN API specification.
+// https://github.com/webmachinelearning/webnn/issues/677
+/// <reference path="jsep/webnn/webnn.d.ts" />
+
 import {Env, InferenceSession, Tensor} from 'onnxruntime-common';
 
 import {SerializableInternalBuffer, SerializableSessionMetadata, SerializableTensorMetadata, TensorMetadata} from './proxy-messages';
@@ -253,11 +258,43 @@ export const createSession = async(
       await Promise.all(loadingPromises);
     }
 
+    for (const provider of options?.executionProviders ?? []) {
+      const providerName = typeof provider === 'string' ? provider : provider.name;
+      if (providerName === 'webnn') {
+        if (wasm.currentContext) {
+          throw new Error('WebNN execution provider is already set.');
+        }
+        if (typeof provider !== 'string') {
+          const webnnOptions = provider as InferenceSession.WebNNExecutionProviderOption;
+          const context = (webnnOptions as InferenceSession.WebNNOptionsWithMLContext)?.context;
+          const gpuDevice = (webnnOptions as InferenceSession.WebNNOptionsWebGpu)?.gpuDevice;
+          const deviceType = (webnnOptions as InferenceSession.WebNNContextOptions)?.deviceType;
+          const numThreads = (webnnOptions as InferenceSession.WebNNContextOptions)?.numThreads;
+          const powerPreference = (webnnOptions as InferenceSession.WebNNContextOptions)?.powerPreference;
+          if (context) {
+            wasm.currentContext = context as MLContext;
+          } else if (gpuDevice) {
+            wasm.currentContext = await navigator.ml.createContext(gpuDevice);
+          } else {
+            wasm.currentContext = await navigator.ml.createContext({deviceType, numThreads, powerPreference});
+          }
+        } else {
+          wasm.currentContext = await navigator.ml.createContext();
+        }
+        break;
+      }
+    }
+
     sessionHandle = await wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle);
     if (sessionHandle === 0) {
       checkLastError('Can\'t create a session.');
     }
 
+    // clear current MLContext after session creation
+    if (wasm.currentContext) {
+      wasm.currentContext = undefined;
+    }
+
     const [inputCount, outputCount] = getSessionInputOutputCount(sessionHandle);
 
     const enableGraphCapture = !!options?.enableGraphCapture;
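The loop above establishes the precedence for how a session obtains its `MLContext`: an explicitly supplied context wins, then a supplied WebGPU device, then plain context options, then browser defaults. A sketch of the three corresponding `executionProviders` entries (option names per the onnxruntime-common types referenced above; the concrete values are examples):

```ts
// 1. Reuse an existing MLContext - required for sharing MLBuffers across sessions.
const withContext = {
  name: 'webnn',
  context: await navigator.ml.createContext({deviceType: 'gpu'}),
};

// 2. Derive the context from an existing WebGPU device.
const adapter = await navigator.gpu.requestAdapter();
const withGpuDevice = {
  name: 'webnn',
  gpuDevice: await adapter!.requestDevice(),
};

// 3. Let onnxruntime-web create a context from plain options (the previous
//    behavior, now performed in JS instead of in the native EP).
const withOptions = {
  name: 'webnn',
  deviceType: 'cpu',
  numThreads: 4,
  powerPreference: 'low-power',
};
```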
14 changes: 13 additions & 1 deletion js/web/lib/wasm/wasm-types.ts
@@ -1,6 +1,11 @@
 // Copyright (c) Microsoft Corporation. All rights reserved.
 // Licensed under the MIT License.
 
+// WebNN API currently does not have a TypeScript definition file. This file is a workaround with types generated from
+// WebNN API specification.
+// https://github.com/webmachinelearning/webnn/issues/677
+/// <reference path="jsep/webnn/webnn.d.ts" />
+
 import type {Tensor} from 'onnxruntime-common';
 
 /* eslint-disable @typescript-eslint/naming-convention */
@@ -19,7 +24,7 @@ export declare namespace JSEP {
   type CaptureEndFunction = () => void;
   type ReplayFunction = () => void;
 
-  export interface Module extends WebGpuModule {
+  export interface Module extends WebGpuModule, WebNnModule {
     /**
      * Mount the external data file to an internal map, which will be used during session initialization.
      *
@@ -106,6 +111,13 @@ export declare namespace JSEP {
      */
    jsepOnReleaseSession: (sessionId: number) => void;
   }
+
+  export interface WebNnModule {
+    /**
+     * Active MLContext used to create WebNN EP.
+     */
+    currentContext: MLContext;
+  }
 }
 
 export interface OrtInferenceAPIs {
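Together with the wasm-core-impl.ts change above, `currentContext` acts as a short-lived handoff slot between JavaScript and the native WebNN EP, which reads it as `Module.currentContext` during session creation. A simplified sketch of that protocol (the structural types here are assumptions for illustration, not the verbatim code):

```ts
async function createSessionWithContext(
    wasm: {currentContext?: MLContext; _OrtCreateSession(d: number, l: number, o: number): Promise<number>},
    context: MLContext, dataOffset: number, dataLength: number, optionsHandle: number): Promise<number> {
  wasm.currentContext = context;  // 1. publish the context on the Emscripten module
  // 2. the C++ WebNN EP constructor picks up Module.currentContext while the session is created
  const handle = await wasm._OrtCreateSession(dataOffset, dataLength, optionsHandle);
  wasm.currentContext = undefined;  // 3. clear the slot so the next session supplies its own
  return handle;
}
```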
19 changes: 2 additions & 17 deletions onnxruntime/core/providers/webnn/webnn_execution_provider.cc
@@ -17,24 +17,12 @@
 
 namespace onnxruntime {
 
-WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_flags,
-                                               const std::string& webnn_threads_number, const std::string& webnn_power_flags)
+WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_flags)
     : IExecutionProvider{onnxruntime::kWebNNExecutionProvider} {
-  // Create WebNN context and graph builder.
-  const emscripten::val ml = emscripten::val::global("navigator")["ml"];
-  if (!ml.as<bool>()) {
-    ORT_THROW("Failed to get ml from navigator.");
-  }
-  emscripten::val context_options = emscripten::val::object();
-  context_options.set("deviceType", emscripten::val(webnn_device_flags));
   // WebNN EP uses NHWC layout for CPU XNNPACK backend and NCHW for GPU DML backend.
   if (webnn_device_flags.compare("cpu") == 0) {
     preferred_layout_ = DataLayout::NHWC;
     wnn_device_type_ = webnn::WebnnDeviceType::CPU;
-    // Set "numThreads" if it's not default 0.
-    if (webnn_threads_number.compare("0") != 0) {
-      context_options.set("numThreads", stoi(webnn_threads_number));
-    }
   } else {
     preferred_layout_ = DataLayout::NCHW;
     if (webnn_device_flags.compare("gpu") == 0) {
@@ -45,11 +33,8 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f
       ORT_THROW("Unknown WebNN deviceType.");
     }
   }
-  if (webnn_power_flags.compare("default") != 0) {
-    context_options.set("powerPreference", emscripten::val(webnn_power_flags));
-  }
 
-  wnn_context_ = ml.call<emscripten::val>("createContext", context_options).await();
+  wnn_context_ = emscripten::val::module_property("currentContext");
   if (!wnn_context_.as<bool>()) {
     ORT_THROW("Failed to create WebNN context.");
   }
3 changes: 1 addition & 2 deletions onnxruntime/core/providers/webnn/webnn_execution_provider.h
@@ -19,8 +19,7 @@ class Model;
 
 class WebNNExecutionProvider : public IExecutionProvider {
  public:
-  WebNNExecutionProvider(const std::string& webnn_device_flags, const std::string& webnn_threads_number,
-                         const std::string& webnn_power_flags);
+  explicit WebNNExecutionProvider(const std::string& webnn_device_flags);
   virtual ~WebNNExecutionProvider();
 
   std::vector<std::unique_ptr<ComputeCapability>>
13 changes: 4 additions & 9 deletions onnxruntime/core/providers/webnn/webnn_provider_factory.cc
@@ -10,27 +10,22 @@ using namespace onnxruntime;
 
 namespace onnxruntime {
 struct WebNNProviderFactory : IExecutionProviderFactory {
-  WebNNProviderFactory(const std::string& webnn_device_flags, const std::string& webnn_threads_number,
-                       const std::string& webnn_power_flags)
-      : webnn_device_flags_(webnn_device_flags), webnn_threads_number_(webnn_threads_number), webnn_power_flags_(webnn_power_flags) {}
+  explicit WebNNProviderFactory(const std::string& webnn_device_flags)
+      : webnn_device_flags_(webnn_device_flags) {}
   ~WebNNProviderFactory() override {}
 
   std::unique_ptr<IExecutionProvider> CreateProvider() override;
 
   std::string webnn_device_flags_;
-  std::string webnn_threads_number_;
-  std::string webnn_power_flags_;
 };
 
 std::unique_ptr<IExecutionProvider> WebNNProviderFactory::CreateProvider() {
-  return std::make_unique<WebNNExecutionProvider>(webnn_device_flags_, webnn_threads_number_, webnn_power_flags_);
+  return std::make_unique<WebNNExecutionProvider>(webnn_device_flags_);
 }
 
 std::shared_ptr<IExecutionProviderFactory> WebNNProviderFactoryCreator::Create(
     const ProviderOptions& provider_options) {
-  return std::make_shared<onnxruntime::WebNNProviderFactory>(provider_options.at("deviceType"),
-                                                             provider_options.at("numThreads"),
-                                                             provider_options.at("powerPreference"));
+  return std::make_shared<onnxruntime::WebNNProviderFactory>(provider_options.at("deviceType"));
 }
 
 }  // namespace onnxruntime
4 changes: 0 additions & 4 deletions onnxruntime/core/session/provider_registration.cc
@@ -127,11 +127,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider,
   } else if (strcmp(provider_name, "WEBNN") == 0) {
 #if defined(USE_WEBNN)
     std::string deviceType = options->value.config_options.GetConfigOrDefault("deviceType", "cpu");
-    std::string numThreads = options->value.config_options.GetConfigOrDefault("numThreads", "0");
-    std::string powerPreference = options->value.config_options.GetConfigOrDefault("powerPreference", "default");
     provider_options["deviceType"] = deviceType;
-    provider_options["numThreads"] = numThreads;
-    provider_options["powerPreference"] = powerPreference;
     options->provider_factories.push_back(WebNNProviderFactoryCreator::Create(provider_options));
 #else
     status = create_not_supported_status();
