From 7252c6e747de83b65285601281a9d07aea801fba Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Thu, 25 Jan 2024 07:37:35 +0800 Subject: [PATCH] [WebNN EP] Support WebNN async API with Asyncify (#19145) --- js/web/lib/build-def.d.ts | 4 --- js/web/lib/index.ts | 4 +-- js/web/lib/wasm/binding/ort-wasm.d.ts | 2 +- js/web/lib/wasm/wasm-core-impl.ts | 4 +-- js/web/script/build.ts | 7 +--- js/web/script/test-runner-cli-args.ts | 4 --- .../core/providers/webnn/builders/model.cc | 35 ++++++++----------- .../providers/webnn/builders/model_builder.cc | 12 +++---- .../webnn/webnn_execution_provider.cc | 3 +- onnxruntime/wasm/js_internal_api.js | 4 +++ 10 files changed, 30 insertions(+), 49 deletions(-) diff --git a/js/web/lib/build-def.d.ts b/js/web/lib/build-def.d.ts index b3868871a4753..2c9cd88a375bd 100644 --- a/js/web/lib/build-def.d.ts +++ b/js/web/lib/build-def.d.ts @@ -21,10 +21,6 @@ interface BuildDefinitions { /** * defines whether to disable the whole WebNN backend in the build. */ - readonly DISABLE_WEBNN: boolean; - /** - * defines whether to disable the whole WebAssembly backend in the build. - */ readonly DISABLE_WASM: boolean; /** * defines whether to disable proxy feature in WebAssembly backend in the build. diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index baf45e74addea..b212c0f49df3b 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -23,12 +23,10 @@ if (!BUILD_DEFS.DISABLE_WASM) { require('./backend-wasm-training').wasmBackend; if (!BUILD_DEFS.DISABLE_WEBGPU) { registerBackend('webgpu', wasmBackend, 5); + registerBackend('webnn', wasmBackend, 5); } registerBackend('cpu', wasmBackend, 10); registerBackend('wasm', wasmBackend, 10); - if (!BUILD_DEFS.DISABLE_WEBNN) { - registerBackend('webnn', wasmBackend, 9); - } } Object.defineProperty(env.versions, 'web', {value: version, enumerable: true}); diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index 68054210e79a7..24d7062c85fcb 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -31,7 +31,7 @@ export interface OrtWasmModule extends EmscriptenModule { _OrtGetLastError(errorCodeOffset: number, errorMessageOffset: number): void; - _OrtCreateSession(dataOffset: number, dataLength: number, sessionOptionsHandle: number): number; + _OrtCreateSession(dataOffset: number, dataLength: number, sessionOptionsHandle: number): Promise; _OrtReleaseSession(sessionHandle: number): void; _OrtGetInputOutputCount(sessionHandle: number, inputCountOffset: number, outputCountOffset: number): number; _OrtGetInputName(sessionHandle: number, index: number): number; diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 8768643fa7257..046336dc9cac0 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -84,7 +84,7 @@ export const initRuntime = async(env: Env): Promise => { * @param epName */ export const initEp = async(env: Env, epName: string): Promise => { - if (!BUILD_DEFS.DISABLE_WEBGPU && epName === 'webgpu') { + if (!BUILD_DEFS.DISABLE_WEBGPU && (epName === 'webgpu' || epName === 'webnn')) { // perform WebGPU availability check if (typeof navigator === 'undefined' || !navigator.gpu) { throw new Error('WebGPU is not supported in current environment'); @@ -228,7 +228,7 @@ export const createSession = async( await Promise.all(loadingPromises); } - sessionHandle = wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle); + sessionHandle = await wasm._OrtCreateSession(modelDataOffset, modelDataLength, sessionOptionsHandle); if (sessionHandle === 0) { checkLastError('Can\'t create a session.'); } diff --git a/js/web/script/build.ts b/js/web/script/build.ts index ea0c122cb51de..d3652f3820357 100644 --- a/js/web/script/build.ts +++ b/js/web/script/build.ts @@ -44,7 +44,6 @@ const SOURCE_ROOT_FOLDER = path.join(__dirname, '../..'); // /js/ const DEFAULT_DEFINE = { 'BUILD_DEFS.DISABLE_WEBGL': 'false', 'BUILD_DEFS.DISABLE_WEBGPU': 'false', - 'BUILD_DEFS.DISABLE_WEBNN': 'false', 'BUILD_DEFS.DISABLE_WASM': 'false', 'BUILD_DEFS.DISABLE_WASM_PROXY': 'false', 'BUILD_DEFS.DISABLE_WASM_THREAD': 'false', @@ -364,7 +363,6 @@ async function main() { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGPU': 'true', 'BUILD_DEFS.DISABLE_WEBGL': 'true', - 'BUILD_DEFS.DISABLE_WEBNN': 'true', 'BUILD_DEFS.DISABLE_WASM_PROXY': 'true', 'BUILD_DEFS.DISABLE_WASM_THREAD': 'true', }, @@ -397,7 +395,7 @@ async function main() { // ort.webgpu[.min].js await addAllWebBuildTasks({ outputBundleName: 'ort.webgpu', - define: {...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGL': 'true', 'BUILD_DEFS.DISABLE_WEBNN': 'true'}, + define: {...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGL': 'true'}, }); // ort.wasm[.min].js await addAllWebBuildTasks({ @@ -411,7 +409,6 @@ async function main() { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGPU': 'true', 'BUILD_DEFS.DISABLE_WASM': 'true', - 'BUILD_DEFS.DISABLE_WEBNN': 'true', }, }); // ort.wasm-core[.min].js @@ -421,7 +418,6 @@ async function main() { ...DEFAULT_DEFINE, 'BUILD_DEFS.DISABLE_WEBGPU': 'true', 'BUILD_DEFS.DISABLE_WEBGL': 'true', - 'BUILD_DEFS.DISABLE_WEBNN': 'true', 'BUILD_DEFS.DISABLE_WASM_PROXY': 'true', 'BUILD_DEFS.DISABLE_WASM_THREAD': 'true', }, @@ -434,7 +430,6 @@ async function main() { 'BUILD_DEFS.DISABLE_TRAINING': 'false', 'BUILD_DEFS.DISABLE_WEBGPU': 'true', 'BUILD_DEFS.DISABLE_WEBGL': 'true', - 'BUILD_DEFS.DISABLE_WEBNN': 'true', }, }); } diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts index 8f6c5f6f04122..ed4dd76a6e315 100644 --- a/js/web/script/test-runner-cli-args.ts +++ b/js/web/script/test-runner-cli-args.ts @@ -396,10 +396,6 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs const globalEnvFlags = parseGlobalEnvFlags(args); - if (backend.includes('webnn') && !globalEnvFlags.wasm!.proxy) { - throw new Error('Backend webnn requires flag "wasm-enable-proxy" to be set to true.'); - } - // Options: // --log-verbose=<...> // --log-info=<...> diff --git a/onnxruntime/core/providers/webnn/builders/model.cc b/onnxruntime/core/providers/webnn/builders/model.cc index eaf549ef4e072..ef807a8c4fa26 100644 --- a/onnxruntime/core/providers/webnn/builders/model.cc +++ b/onnxruntime/core/providers/webnn/builders/model.cc @@ -70,22 +70,13 @@ Status Model::Predict(const InlinedHashMap& inputs, "The input of graph has unsupported type, name: ", name, " type: ", tensor.tensor_info.data_type); } -#ifdef ENABLE_WEBASSEMBLY_THREADS - // Copy the inputs from Wasm SharedArrayBuffer to the pre-allocated ArrayBuffers. + // Copy the inputs from Wasm ArrayBuffer to the WebNN inputs ArrayBuffer. + // As Wasm ArrayBuffer is not detachable. wnn_inputs_[name].call("set", view); -#else - wnn_inputs_.set(name, view); -#endif } -#ifdef ENABLE_WEBASSEMBLY_THREADS - // This vector uses for recording output buffers from WebNN graph compution when WebAssembly - // multi-threads is enabled, since WebNN API only accepts non-shared ArrayBufferView, - // https://www.w3.org/TR/webnn/#typedefdef-mlnamedarraybufferviews - // and at this time the 'view' defined by Emscripten is shared ArrayBufferView, the memory - // address is different from the non-shared one, additional memory copy is required here. InlinedHashMap output_views; -#endif + for (const auto& output : outputs) { const std::string& name = output.first; const struct OnnxTensorData tensor = output.second; @@ -131,21 +122,23 @@ Status Model::Predict(const InlinedHashMap& inputs, name, " type: ", tensor.tensor_info.data_type); } -#ifdef ENABLE_WEBASSEMBLY_THREADS output_views.insert({name, view}); -#else - wnn_outputs_.set(name, view); -#endif } - wnn_context_.call("computeSync", wnn_graph_, wnn_inputs_, wnn_outputs_); -#ifdef ENABLE_WEBASSEMBLY_THREADS - // Copy the outputs from pre-allocated ArrayBuffers back to the Wasm SharedArrayBuffer. + emscripten::val results = wnn_context_.call( + "compute", wnn_graph_, wnn_inputs_, wnn_outputs_) + .await(); + + // Copy the outputs from pre-allocated ArrayBuffers back to the Wasm ArrayBuffer. for (const auto& output : outputs) { const std::string& name = output.first; emscripten::val view = output_views.at(name); - view.call("set", wnn_outputs_[name]); + view.call("set", results["outputs"][name]); } -#endif + // WebNN compute() method would return the input and output buffers via the promise + // resolution. Reuse the buffers to avoid additional allocation. + wnn_inputs_ = results["inputs"]; + wnn_outputs_ = results["outputs"]; + return Status::OK(); } diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index cf8a0e23db43b..56f7ead8ccf5d 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -386,7 +386,8 @@ Status ModelBuilder::Compile(std::unique_ptr& model) { for (auto& name : output_names_) { named_operands.set(name, wnn_operands_.at(name)); } - emscripten::val wnn_graph = wnn_builder_.call("buildSync", named_operands); + + emscripten::val wnn_graph = wnn_builder_.call("build", named_operands).await(); if (!wnn_graph.as()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to build WebNN graph."); } @@ -395,13 +396,10 @@ Status ModelBuilder::Compile(std::unique_ptr& model) { model->SetOutputs(std::move(output_names_)); model->SetScalarOutputs(std::move(scalar_outputs_)); model->SetInputOutputInfo(std::move(input_output_info_)); -#ifdef ENABLE_WEBASSEMBLY_THREADS - // Pre-allocate the input and output tensors for the WebNN graph - // when WebAssembly multi-threads is enabled since WebNN API only - // accepts non-shared ArrayBufferView. - // https://www.w3.org/TR/webnn/#typedefdef-mlnamedarraybufferviews + // Wasm heap is not transferrable, we have to pre-allocate the MLNamedArrayBufferViews + // for inputs and outputs because they will be transferred after compute() done. + // https://webmachinelearning.github.io/webnn/#api-mlcontext-async-execution model->AllocateInputOutputBuffers(); -#endif return Status::OK(); } diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index 2922cf9540a8e..df7871614b267 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -42,7 +42,8 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f if (webnn_power_flags.compare("default") != 0) { context_options.set("powerPreference", emscripten::val(webnn_power_flags)); } - wnn_context_ = ml.call("createContextSync", context_options); + + wnn_context_ = ml.call("createContext", context_options).await(); if (!wnn_context_.as()) { ORT_THROW("Failed to create WebNN context."); } diff --git a/onnxruntime/wasm/js_internal_api.js b/onnxruntime/wasm/js_internal_api.js index 7c70515e73eab..7e9c0a6f99c32 100644 --- a/onnxruntime/wasm/js_internal_api.js +++ b/onnxruntime/wasm/js_internal_api.js @@ -160,6 +160,10 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea }; // replace the original functions with asyncified versions + Module['_OrtCreateSession'] = jsepWrapAsync( + Module['_OrtCreateSession'], + () => Module['_OrtCreateSession'], + v => Module['_OrtCreateSession'] = v); Module['_OrtRun'] = runAsync(jsepWrapAsync( Module['_OrtRun'], () => Module['_OrtRun'],