diff --git a/js/web/lib/wasm/jsep/backend-webnn.ts b/js/web/lib/wasm/jsep/backend-webnn.ts
index db8afcf3b775c..c8bd6a52002fa 100644
--- a/js/web/lib/wasm/jsep/backend-webnn.ts
+++ b/js/web/lib/wasm/jsep/backend-webnn.ts
@@ -99,7 +99,8 @@ export class WebNNBackend {
     this.bufferManager.releaseBufferId(bufferId);
   }
 
-  public ensureBuffer(bufferId: BufferId, onnxDataType: number|MLOperandDataType, dimensions: number[]): MLBuffer {
+  public async ensureBuffer(bufferId: BufferId, onnxDataType: number|MLOperandDataType, dimensions: number[]):
+      Promise<MLBuffer> {
     let dataType: MLOperandDataType;
     if (typeof onnxDataType === 'number') {
       const webnnDataType = onnxDataTypeToWebnnDataType.get(onnxDataType)!;
diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts
index ea9fe6f35b6b8..2f38a042249ff 100644
--- a/js/web/lib/wasm/jsep/init.ts
+++ b/js/web/lib/wasm/jsep/init.ts
@@ -246,7 +246,7 @@ export const init =
         // jsepReleaseBufferId,
         (bufferId: number) => backend.releaseBufferId(bufferId),
         // jsepEnsureBuffer
-        (bufferId: number, onnxDataType: number, dimensions: number[]) =>
+        async (bufferId: number, onnxDataType: number, dimensions: number[]) =>
            backend.ensureBuffer(bufferId, onnxDataType, dimensions),
         // jsepUploadBuffer
         (bufferId: number, data: Uint8Array) => {
diff --git a/js/web/lib/wasm/jsep/webnn/buffer-manager.ts b/js/web/lib/wasm/jsep/webnn/buffer-manager.ts
index 7d13aa760504b..6e84195e7d314 100644
--- a/js/web/lib/wasm/jsep/webnn/buffer-manager.ts
+++ b/js/web/lib/wasm/jsep/webnn/buffer-manager.ts
@@ -25,7 +25,7 @@ export interface BufferManager {
   /**
    * Ensure a MLBuffer is created for the BufferId.
    */
-  ensureBuffer(bufferId: BufferId, dataType: MLOperandDataType, dimensions: number[]): MLBuffer;
+  ensureBuffer(bufferId: BufferId, dataType: MLOperandDataType, dimensions: number[]): Promise<MLBuffer>;
  /**
   * Upload data to a MLBuffer.
  */
@@ -85,12 +85,12 @@ class BufferTracker {
     this.mlBuffer = undefined;
   }
 
-  public ensureBuffer(dataType: MLOperandDataType, dimensions: number[]): MLBuffer {
+  public async ensureBuffer(dataType: MLOperandDataType, dimensions: number[]): Promise<MLBuffer> {
     if (this.mlBuffer) {
       return this.mlBuffer;
     }
 
-    const buffer = this.context.createBuffer({dataType, dimensions});
+    const buffer = await this.context.createBuffer({dataType, dimensions});
     this.mlBuffer = buffer;
 
     if (this.activeUpload) {
@@ -151,7 +151,7 @@ class BufferManagerImpl implements BufferManager {
     }
   }
 
-  public ensureBuffer(bufferId: BufferId, dataType: MLOperandDataType, dimensions: number[]): MLBuffer {
+  public async ensureBuffer(bufferId: BufferId, dataType: MLOperandDataType, dimensions: number[]): Promise<MLBuffer> {
     const buffer = this.buffersById.get(bufferId);
     if (!buffer) {
       throw new Error('Buffer not found.');
diff --git a/js/web/lib/wasm/jsep/webnn/webnn.d.ts b/js/web/lib/wasm/jsep/webnn/webnn.d.ts
index 11e75af7245bc..17bd3b6243342 100644
--- a/js/web/lib/wasm/jsep/webnn/webnn.d.ts
+++ b/js/web/lib/wasm/jsep/webnn/webnn.d.ts
@@ -387,7 +387,7 @@ interface MLBuffer {
 type MLNamedBuffers = Record<string, MLBuffer>;
 
 interface MLContext {
-  createBuffer(descriptor: MLOperandDescriptor): MLBuffer;
+  createBuffer(descriptor: MLOperandDescriptor): Promise<MLBuffer>;
   writeBuffer(
       dstBuffer: MLBuffer, srcData: ArrayBufferView|ArrayBuffer, srcElementOffset?: number,
       srcElementSize?: number): void;
diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts
index 0c26afeb6bac4..829d4e838c676 100644
--- a/js/web/lib/wasm/wasm-core-impl.ts
+++ b/js/web/lib/wasm/wasm-core-impl.ts
@@ -704,7 +704,7 @@ export const run = async(
 
           // If the graph has been partitioned, the output tensor may have not been created. For this reason, we use
           // ensureBuffer to get/create the MLBuffer.
-          const mlBuffer = ensureBuffer(dataOffset, dataType, dims);
+          const mlBuffer = await ensureBuffer(dataOffset, dataType, dims);
 
          // do not release the tensor right now. it will be released when user calls tensor.dispose().
          keepOutputTensor = true;
diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts
index bcf4360aa8585..e30e6dd3c8e7f 100644
--- a/js/web/lib/wasm/wasm-types.ts
+++ b/js/web/lib/wasm/wasm-types.ts
@@ -25,7 +25,8 @@ export declare namespace JSEP {
   type ReplayFunction = () => void;
   type ReserveBufferIdFunction = () => number;
   type ReleaseBufferIdFunction = (bufferId: number) => void;
-  type EnsureBufferFunction = (bufferId: number, dataType: number|MLOperandDataType, dimensions: number[]) => MLBuffer;
+  type EnsureBufferFunction = (bufferId: number, dataType: number|MLOperandDataType, dimensions: number[]) =>
+      Promise<MLBuffer>;
   type UploadBufferFunction = (bufferId: number, data: Uint8Array) => void;
   type DownloadBufferFunction = (bufferId: number) => Promise<ArrayBuffer>;
 
@@ -154,7 +155,7 @@
    * @param bufferId - specify the MLBuffer ID.
    * @returns the MLBuffer.
    */
-  jsepEnsureBuffer: (bufferId: number, dataType: number|MLOperandDataType, dimensions: number[]) => MLBuffer;
+  jsepEnsureBuffer: (bufferId: number, dataType: number|MLOperandDataType, dimensions: number[]) => Promise<MLBuffer>;
  /**
   * [exported from pre-jsep.js] Upload data to MLBuffer.
   * @param bufferId - specify the MLBuffer ID.
diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts
index 8f78c90d11014..5e71b769278a8 100644
--- a/js/web/test/test-runner.ts
+++ b/js/web/test/test-runner.ts
@@ -257,8 +257,7 @@ export class ModelTestContext {
     const executionProviderConfig =
         modelTest.backend === 'webnn' ? (testOptions?.webnnOptions || {name: 'webnn'}) : modelTest.backend!;
     let mlContext: MLContext|undefined;
-    if(['ml-tensor', 'ml-location'].includes(modelTest.ioBinding)) {
-
+    if (['ml-tensor', 'ml-location'].includes(modelTest.ioBinding)) {
      const webnnOptions = executionProviderConfig as ort.InferenceSession.WebNNExecutionProviderOption;
      const deviceType = (webnnOptions as ort.InferenceSession.WebNNContextOptions)?.deviceType;
      const numThreads = (webnnOptions as ort.InferenceSession.WebNNContextOptions)?.numThreads;
@@ -593,7 +592,7 @@ async function createMLTensorForOutput(mlContext: MLContext, type: ort.Tensor.Ty
 
  const dataType = type === 'bool' ? 'uint8' : type;
 
- const mlBuffer = mlContext.createBuffer({dataType, dimensions: dims as number[]});
+ const mlBuffer = await mlContext.createBuffer({dataType, dimensions: dims as number[]});
 
  return ort.Tensor.fromMLBuffer(mlBuffer, {
    dataType: type,
@@ -611,7 +610,7 @@ async function createMLTensorForInput(mlContext: MLContext, cpuTensor: ort.Tenso
    throw new Error(`createMLTensorForInput can not work with ${cpuTensor.type} tensor`);
  }
  const dataType = cpuTensor.type === 'bool' ? 'uint8' : cpuTensor.type;
- const mlBuffer = mlContext.createBuffer({dataType, dimensions: cpuTensor.dims as number[]});
+ const mlBuffer = await mlContext.createBuffer({dataType, dimensions: cpuTensor.dims as number[]});
  mlContext.writeBuffer(mlBuffer, cpuTensor.data);
  return ort.Tensor.fromMLBuffer(
      mlBuffer, {dataType: cpuTensor.type, dims: cpuTensor.dims, dispose: () => mlBuffer.destroy()});
diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc
index f88e902e5efe0..22271640ef57f 100644
--- a/onnxruntime/core/providers/webnn/builders/helper.cc
+++ b/onnxruntime/core/providers/webnn/builders/helper.cc
@@ -211,10 +211,9 @@ bool SetWebnnDataType(emscripten::val& desc, const int32_t data_type) {
   }
 }
 
-bool IsMLBufferSupported(WebnnDeviceType device_type) {
+bool IsMLBufferSupported() {
   static bool is_supported = !emscripten::val::global("MLBuffer").isUndefined();
-  // The current MLBuffer implementation only supports GPU and NPU devices.
-  return is_supported && device_type != WebnnDeviceType::CPU;
+  return is_supported;
 }
 
 }  // namespace webnn
diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h
index 14decb6e77ca2..ed7f41aea092d 100644
--- a/onnxruntime/core/providers/webnn/builders/helper.h
+++ b/onnxruntime/core/providers/webnn/builders/helper.h
@@ -285,7 +285,7 @@ bool GetBidirectionalBroadcastShape(std::vector<int64_t>& shape_a,
 
 bool SetWebnnDataType(emscripten::val& desc, const int32_t data_type);
 
-bool IsMLBufferSupported(WebnnDeviceType device_type);
+bool IsMLBufferSupported();
 
 }  // namespace webnn
 }  // namespace onnxruntime
diff --git a/onnxruntime/core/providers/webnn/builders/model.cc b/onnxruntime/core/providers/webnn/builders/model.cc
index 151a5ed559d3b..ba84a5d6c56fd 100644
--- a/onnxruntime/core/providers/webnn/builders/model.cc
+++ b/onnxruntime/core/providers/webnn/builders/model.cc
@@ -155,27 +155,31 @@ onnxruntime::common::Status Model::Compute(const InlinedHashMap<std::string, OnnxTensorData>& inputs,
                                            const InlinedHashMap<std::string, OnnxTensorData>& outputs) {
   auto jsepEnsureBuffer = emscripten::val::module_property("jsepEnsureBuffer");
-  for (const auto& input : inputs) {
-    const std::string& name = input.first;
-    const struct OnnxTensorData tensor = input.second;
+  auto promises = emscripten::val::array();
+  for (const auto& [_, tensor] : inputs) {
     emscripten::val shape = emscripten::val::array();
     for (const auto& dim : tensor.tensor_info.shape) {
       uint32_t dim_val = SafeInt<uint32_t>(dim);
       shape.call<void>("push", dim_val);
     }
     auto buffer = jsepEnsureBuffer(reinterpret_cast<intptr_t>(tensor.buffer), tensor.tensor_info.data_type, shape);
-    wnn_inputs_.set(name, buffer);
+    promises.call<void>("push", buffer);
   }
-  for (const auto& output : outputs) {
-    const std::string& name = output.first;
-    const struct OnnxTensorData tensor = output.second;
+  for (const auto& [_, tensor] : outputs) {
     emscripten::val shape = emscripten::val::array();
     for (const auto& dim : tensor.tensor_info.shape) {
       uint32_t dim_val = SafeInt<uint32_t>(dim);
       shape.call<void>("push", dim_val);
     }
     auto buffer = jsepEnsureBuffer(reinterpret_cast<intptr_t>(tensor.buffer), tensor.tensor_info.data_type, shape);
-    wnn_outputs_.set(name, buffer);
+    promises.call<void>("push", buffer);
+  }
+  auto buffers = emscripten::val::global("Promise").call<emscripten::val>("all", promises).await();
+  for (const auto& [name, _] : inputs) {
+    wnn_inputs_.set(name, buffers.call<emscripten::val>("shift"));
+  }
+  for (const auto& [name, _] : outputs) {
+    wnn_outputs_.set(name, buffers.call<emscripten::val>("shift"));
   }
   wnn_context_.call<void>("dispatch", wnn_graph_, wnn_inputs_, wnn_outputs_);
diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc
index cb785e7705d35..8cc56e212b444 100644
--- a/onnxruntime/core/providers/webnn/builders/model_builder.cc
+++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc
@@ -340,7 +340,7 @@ Status ModelBuilder::Compile(std::unique_ptr<Model>& model) {
   }
   // Explicitly release the WebNN builder to free memory.
   wnn_builder_ = emscripten::val::undefined();
-  model.reset(new Model(std::move(wnn_context_), std::move(wnn_graph), logger_, IsMLBufferSupported(wnn_device_type_)));
+  model.reset(new Model(std::move(wnn_context_), std::move(wnn_graph), logger_, IsMLBufferSupported()));
   model->SetInputs(std::move(input_names_));
   model->SetOutputs(std::move(output_names_));
   model->SetScalarOutputs(std::move(scalar_outputs_));
diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc
index efee65b9ea786..c3280ee3855d1 100644
--- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc
+++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc
@@ -24,7 +24,7 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f
           onnxruntime::kWebNNExecutionProvider,
           // If MLBuffer is supported, we force all the tensors to be allocated as MLBuffer.
           OrtDevice(
-              webnn::IsMLBufferSupported(webnn::DeviceTypeFromString(webnn_device_flags)) ? OrtDevice::GPU : OrtDevice::CPU,
+              webnn::IsMLBufferSupported() ? OrtDevice::GPU : OrtDevice::CPU,
               OrtDevice::MemType::DEFAULT,
               0)},
       wnn_device_type_(webnn::DeviceTypeFromString(webnn_device_flags)) {
@@ -381,14 +381,14 @@ WebNNExecutionProvider::GetKernelRegistry() const {
 }
 
 std::unique_ptr<onnxruntime::IDataTransfer> WebNNExecutionProvider::GetDataTransfer() const {
-  if (!webnn::IsMLBufferSupported(wnn_device_type_)) {
+  if (!webnn::IsMLBufferSupported()) {
     return nullptr;
   }
   return std::make_unique<webnn::DataTransfer>();
 }
 
 std::vector<AllocatorPtr> WebNNExecutionProvider::CreatePreferredAllocators() {
-  if (!webnn::IsMLBufferSupported(wnn_device_type_)) {
+  if (!webnn::IsMLBufferSupported()) {
     return {};
   }
   AllocatorCreationInfo customAllocatorCreationInfo([&](OrtDevice::DeviceId) {
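
Note on usage (not part of the patch): because MLContext.createBuffer() and the jsepEnsureBuffer/ensureBuffer path now return promises, callers must await buffer creation before binding a buffer to a tensor; on the native side, Model::Compute collects the jsepEnsureBuffer promises and resolves them with Promise.all before dispatch. The following is a minimal TypeScript sketch of the consuming side, assuming a WebNN-capable browser and an onnxruntime-web import; the context options, data type, and shape are illustrative, not taken from this change:

// Hypothetical usage sketch: MLBuffer creation is now asynchronous and must be awaited.
import * as ort from 'onnxruntime-web';

const mlContext = await navigator.ml.createContext({deviceType: 'gpu'});
const mlBuffer = await mlContext.createBuffer({dataType: 'float32', dimensions: [1, 3, 224, 224]});
// Bind the created MLBuffer to an ONNX Runtime Web tensor, mirroring the test-runner helpers above.
const tensor = ort.Tensor.fromMLBuffer(mlBuffer, {
  dataType: 'float32',
  dims: [1, 3, 224, 224],
  dispose: () => mlBuffer.destroy(),
});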