From d05970cb9d4868ee38204ee4a6a03090dc621c81 Mon Sep 17 00:00:00 2001 From: Enrico Galli Date: Thu, 15 Aug 2024 22:23:53 -0700 Subject: [PATCH 01/17] [WebNN EP] Enable IO Bindings with MLBuffer Enables using the MLBuffers to pass data between models. This reduces the number of copies between the CPU and devices as well as the renderer and GPU process in Chromium. --- .../onnxruntime/core/framework/allocator.h | 1 + js/common/lib/tensor-factory-impl.ts | 12 + js/common/lib/tensor-factory.ts | 46 ++++ js/common/lib/tensor-impl.ts | 59 ++++- js/common/lib/tensor-utils-impl.ts | 8 + js/common/lib/tensor.ts | 30 ++- js/web/lib/wasm/jsep/backend-webnn.ts | 146 ++++++++++++ js/web/lib/wasm/jsep/init.ts | 19 +- js/web/lib/wasm/jsep/webnn/buffer-manager.ts | 212 ++++++++++++++++++ js/web/lib/wasm/jsep/webnn/webnn.d.ts | 13 +- js/web/lib/wasm/proxy-messages.ts | 10 +- js/web/lib/wasm/session-handler-inference.ts | 12 +- js/web/lib/wasm/wasm-common.ts | 18 +- js/web/lib/wasm/wasm-core-impl.ts | 59 ++++- js/web/lib/wasm/wasm-types.ts | 85 ++++++- js/web/script/test-runner-cli-args.ts | 6 +- js/web/script/test-runner-cli.ts | 2 +- js/web/test/test-runner.ts | 95 +++++++- js/web/test/test-types.ts | 6 +- onnxruntime/core/framework/allocator.cc | 3 +- onnxruntime/core/providers/webnn/allocator.cc | 41 ++++ onnxruntime/core/providers/webnn/allocator.h | 32 +++ .../core/providers/webnn/builders/helper.cc | 18 ++ .../core/providers/webnn/builders/helper.h | 4 + .../core/providers/webnn/builders/model.cc | 54 ++++- .../core/providers/webnn/builders/model.h | 10 +- .../providers/webnn/builders/model_builder.cc | 2 +- .../core/providers/webnn/data_transfer.cc | 47 ++++ .../core/providers/webnn/data_transfer.h | 21 ++ .../webnn/webnn_execution_provider.cc | 39 +++- .../webnn/webnn_execution_provider.h | 2 + onnxruntime/wasm/api.cc | 26 ++- onnxruntime/wasm/pre-jsep.js | 33 +++ 33 files changed, 1110 insertions(+), 61 deletions(-) create mode 100644 js/web/lib/wasm/jsep/backend-webnn.ts create mode 100644 js/web/lib/wasm/jsep/webnn/buffer-manager.ts create mode 100644 onnxruntime/core/providers/webnn/allocator.cc create mode 100644 onnxruntime/core/providers/webnn/allocator.h create mode 100644 onnxruntime/core/providers/webnn/data_transfer.cc create mode 100644 onnxruntime/core/providers/webnn/data_transfer.h diff --git a/include/onnxruntime/core/framework/allocator.h b/include/onnxruntime/core/framework/allocator.h index 097873c5e3653..17d8d804d4ae3 100644 --- a/include/onnxruntime/core/framework/allocator.h +++ b/include/onnxruntime/core/framework/allocator.h @@ -51,6 +51,7 @@ constexpr const char* HIP_PINNED = "HipPinned"; constexpr const char* OpenVINO_CPU = "OpenVINO_CPU"; constexpr const char* OpenVINO_GPU = "OpenVINO_GPU"; constexpr const char* WEBGPU_BUFFER = "WebGPU_Buffer"; +constexpr const char* WEBNN_BUFFER = "WebNN_Buffer"; constexpr size_t kAllocAlignment = 256; diff --git a/js/common/lib/tensor-factory-impl.ts b/js/common/lib/tensor-factory-impl.ts index 52e028a9fcd31..38e3d841146f9 100644 --- a/js/common/lib/tensor-factory-impl.ts +++ b/js/common/lib/tensor-factory-impl.ts @@ -11,6 +11,7 @@ import { TensorFromImageBitmapOptions, TensorFromImageDataOptions, TensorFromImageElementOptions, + TensorFromMLBufferOptions, TensorFromTextureOptions, TensorFromUrlOptions, } from './tensor-factory.js'; @@ -310,6 +311,17 @@ export const tensorFromGpuBuffer = ( + mlBuffer: TensorInterface.MLBufferType, + options: TensorFromMLBufferOptions, +): Tensor => { + const { dataType, dims, download, dispose } = options; + return new Tensor({ location: 'ml-buffer', type: dataType ?? 'float32', mlBuffer, dims, download, dispose }); +}; + /** * implementation of Tensor.fromPinnedBuffer(). */ diff --git a/js/common/lib/tensor-factory.ts b/js/common/lib/tensor-factory.ts index 7938b4a4eb927..95e822535a787 100644 --- a/js/common/lib/tensor-factory.ts +++ b/js/common/lib/tensor-factory.ts @@ -86,6 +86,20 @@ export interface GpuBufferConstructorParameters + extends CommonConstructorParameters, + GpuResourceConstructorParameters { + /** + * Specify the location of the data to be 'ml-buffer'. + */ + readonly location: 'ml-buffer'; + + /** + * Specify the WebNN buffer that holds the tensor data. + */ + readonly mlBuffer: Tensor.MLBufferType; +} + // #endregion // the following region contains type definitions of each individual options. @@ -219,6 +233,15 @@ export interface TensorFromGpuBufferOptions dataType?: T; } +export interface TensorFromMLBufferOptions + extends Pick, + GpuResourceConstructorParameters { + /** + * Describes the data type of the tensor. + */ + dataType?: T; +} + // #endregion /** @@ -336,6 +359,29 @@ export interface TensorFactory { options: TensorFromGpuBufferOptions, ): TypedTensor; + /** + * create a tensor from a WebNN MLBuffer + * + * @param buffer - the MLBuffer object to create tensor from + * @param options - An optional object representing options for creating tensor from a WebNN MLBuffer. + * + * The options include following properties: + * - `dataType`: the data type of the tensor. If omitted, assume 'float32'. + * - `dims`: the dimension of the tensor. Required. + * - `download`: an optional function to download the tensor data from the MLBuffer to CPU. If omitted, the MLBuffer + * data will not be able to download. Usually, this is provided by the WebNN backend for the inference outputs. + * Users don't need to provide this function. + * - `dispose`: an optional function to dispose the tensor data on the WebNN MLBuffer. If omitted, the MLBuffer will + * not be disposed. Usually, this is provided by the WebNN backend for the inference outputs. Users don't need to + * provide this function. + * + * @returns a tensor object + */ + fromMLBuffer( + buffer: Tensor.MLBufferType, + options: TensorFromMLBufferOptions, + ): TypedTensor; + /** * create a tensor from a pre-allocated buffer. The buffer will be used as a pinned buffer. * diff --git a/js/common/lib/tensor-impl.ts b/js/common/lib/tensor-impl.ts index 12c6d79d88d2b..47cb3e3e9905c 100644 --- a/js/common/lib/tensor-impl.ts +++ b/js/common/lib/tensor-impl.ts @@ -6,16 +6,19 @@ import { TensorToDataUrlOptions, TensorToImageDataOptions } from './tensor-conve import { tensorFromGpuBuffer, tensorFromImage, + tensorFromMLBuffer, tensorFromPinnedBuffer, tensorFromTexture, } from './tensor-factory-impl.js'; import { CpuPinnedConstructorParameters, GpuBufferConstructorParameters, + MLBufferConstructorParameters, TensorFromGpuBufferOptions, TensorFromImageBitmapOptions, TensorFromImageDataOptions, TensorFromImageElementOptions, + TensorFromMLBufferOptions, TensorFromTextureOptions, TensorFromUrlOptions, TextureConstructorParameters, @@ -37,6 +40,7 @@ type TensorDataType = TensorInterface.DataType; type TensorDataLocation = TensorInterface.DataLocation; type TensorTextureType = TensorInterface.TextureType; type TensorGpuBufferType = TensorInterface.GpuBufferType; +type TensorMLBufferType = TensorInterface.MLBufferType; /** * the implementation of Tensor interface. @@ -83,6 +87,15 @@ export class Tensor implements TensorInterface { */ constructor(params: GpuBufferConstructorParameters); + /** + * Construct a new tensor object from the WebNN buffer with the given type and dims. + * + * Tensor's location will be set to 'ml-buffer'. + * + * @param params - Specify the parameters to construct the tensor. + */ + constructor(params: MLBufferConstructorParameters); + /** * implementation. */ @@ -94,7 +107,8 @@ export class Tensor implements TensorInterface { | readonly boolean[] | CpuPinnedConstructorParameters | TextureConstructorParameters - | GpuBufferConstructorParameters, + | GpuBufferConstructorParameters + | MLBufferConstructorParameters, arg1?: TensorDataType | readonly number[] | readonly string[] | readonly boolean[], arg2?: readonly number[], ) { @@ -149,6 +163,25 @@ export class Tensor implements TensorInterface { this.disposer = arg0.dispose; break; } + case 'ml-buffer': { + if ( + type !== 'float32' && + type !== 'float16' && + type !== 'int32' && + type !== 'int64' && + type !== 'uint32' && + type !== 'uint64' && + type !== 'int8' && + type !== 'uint8' && + type !== 'bool' + ) { + throw new TypeError(`unsupported type "${type}" to create tensor from MLBuffer`); + } + this.mlBufferData = arg0.mlBuffer; + this.downloader = arg0.download; + this.disposer = arg0.dispose; + break; + } default: throw new Error(`Tensor constructor: unsupported location '${this.dataLocation}'`); } @@ -310,6 +343,13 @@ export class Tensor implements TensorInterface { return tensorFromGpuBuffer(gpuBuffer, options); } + static fromMLBuffer( + mlBuffer: TensorMLBufferType, + options: TensorFromMLBufferOptions, + ): TensorInterface { + return tensorFromMLBuffer(mlBuffer, options); + } + static fromPinnedBuffer( type: T, buffer: TensorInterface.DataTypeMap[T], @@ -358,6 +398,11 @@ export class Tensor implements TensorInterface { */ private gpuBufferData?: TensorGpuBufferType; + /** + * stores the underlying WebNN MLBuffer when location is 'ml-buffer'. otherwise empty. + */ + private mlBufferData?: TensorMLBufferType; + /** * stores an optional downloader function to download data from GPU to CPU. */ @@ -405,6 +450,14 @@ export class Tensor implements TensorInterface { } return this.gpuBufferData; } + + get mlBuffer(): TensorMLBufferType { + this.ensureValid(); + if (!this.mlBufferData) { + throw new Error('The data is not stored as a WebNN buffer.'); + } + return this.mlBufferData; + } // #endregion // #region methods @@ -416,7 +469,8 @@ export class Tensor implements TensorInterface { case 'cpu-pinned': return this.data; case 'texture': - case 'gpu-buffer': { + case 'gpu-buffer': + case 'ml-buffer': { if (!this.downloader) { throw new Error('The current tensor is not created with a specified data downloader.'); } @@ -457,6 +511,7 @@ export class Tensor implements TensorInterface { this.cpuData = undefined; this.gpuTextureData = undefined; this.gpuBufferData = undefined; + this.mlBufferData = undefined; this.downloader = undefined; this.isDownloading = undefined; diff --git a/js/common/lib/tensor-utils-impl.ts b/js/common/lib/tensor-utils-impl.ts index 9c633cd95fac3..4c4c9b1d80185 100644 --- a/js/common/lib/tensor-utils-impl.ts +++ b/js/common/lib/tensor-utils-impl.ts @@ -4,6 +4,7 @@ import { CpuPinnedConstructorParameters, GpuBufferConstructorParameters, + MLBufferConstructorParameters, TextureConstructorParameters, } from './tensor-factory.js'; import { Tensor } from './tensor-impl.js'; @@ -56,6 +57,13 @@ export const tensorReshape = (tensor: Tensor, dims: readonly number[]): Tensor = type: tensor.type as GpuBufferConstructorParameters['type'], dims, }); + case 'ml-buffer': + return new Tensor({ + location: 'ml-buffer', + mlBuffer: tensor.mlBuffer, + type: tensor.type as MLBufferConstructorParameters['type'], + dims, + }); default: throw new Error(`tensorReshape: tensor location ${tensor.location} is not supported`); } diff --git a/js/common/lib/tensor.ts b/js/common/lib/tensor.ts index 70396bbe1e9a3..636ab0704ffe5 100644 --- a/js/common/lib/tensor.ts +++ b/js/common/lib/tensor.ts @@ -42,6 +42,13 @@ interface TypedTensorBase { */ readonly gpuBuffer: Tensor.GpuBufferType; + /** + * Get the WebNN buffer that holds the tensor data. + * + * If the data is not in a WebNN MLBuffer, throw error. + */ + readonly mlBuffer: Tensor.MLBufferType; + /** * Get the buffer data of the tensor. * @@ -136,15 +143,36 @@ export declare namespace Tensor { */ export type GpuBufferType = { size: number; mapState: 'unmapped' | 'pending' | 'mapped' }; + /** + * type alias for WebNN MLBuffer + * + * The specification for WebNN's ML Buffer is currently in flux. + */ + export type MLBufferType = unknown; + /** * supported data types for constructing a tensor from a WebGPU buffer */ export type GpuBufferDataTypes = 'float32' | 'float16' | 'int32' | 'int64' | 'uint32' | 'uint8' | 'bool'; + /** + * supported data types for constructing a tensor from a WebNN MLBuffer + */ + export type MLBufferDataTypes = + | 'float32' + | 'float16' + | 'int8' + | 'uint8' + | 'int32' + | 'uint32' + | 'int64' + | 'uint64' + | 'bool'; + /** * represent where the tensor data is stored */ - export type DataLocation = 'none' | 'cpu' | 'cpu-pinned' | 'texture' | 'gpu-buffer'; + export type DataLocation = 'none' | 'cpu' | 'cpu-pinned' | 'texture' | 'gpu-buffer' | 'ml-buffer'; /** * represent the data type of a tensor diff --git a/js/web/lib/wasm/jsep/backend-webnn.ts b/js/web/lib/wasm/jsep/backend-webnn.ts new file mode 100644 index 0000000000000..c623f8d4e67d0 --- /dev/null +++ b/js/web/lib/wasm/jsep/backend-webnn.ts @@ -0,0 +1,146 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// WebNN API currently does not have a TypeScript definition file. This file is a workaround with types generated from +// WebNN API specification. +// https://github.com/webmachinelearning/webnn/issues/677 +/// + +import { Tensor } from 'onnxruntime-common'; + +import { DataType } from '../wasm-common'; +import { getInstance } from '../wasm-factory'; + +import { createView } from './tensor-view'; +import { BufferId, BufferManager, createBufferManager } from './webnn/buffer-manager'; + +/* + * TensorProto::data_type to WebNN OperandType mapping. + */ +const onnxDataTypeToWebnnDataType = new Map([ + [DataType.float, 'float32'], + [DataType.float16, 'float16'], + [DataType.int32, 'int32'], + [DataType.uint32, 'uint32'], + [DataType.int64, 'int64'], + [DataType.uint64, 'uint64'], + [DataType.int8, 'int8'], + [DataType.uint8, 'uint8'], + [DataType.bool, 'uint8'], +]); + +/** + * WebNN backend implementation. This class is used to keep track of the MLBuffers created by the backend and keep track + * of the current MLContext being used by the sessions. + */ +export class WebNNBackend { + private bufferManager: BufferManager = createBufferManager(this); + /** + * Maps from session id to MLContexts. + */ + private mlContextBySessionId = new Map(); + /** + * Maps from MLContext to session ids. + */ + private sessionIdsByMLContext = new Map>(); + /** + * Current session id. + */ + currentSessionId?: number; + + public onRunStart(sessionId: number): void { + this.currentSessionId = sessionId; + } + + public get currentContext(): MLContext { + if (this.currentSessionId === undefined) { + throw new Error('No active session'); + } + return this.getMLContext(this.currentSessionId); + } + + public registerMLContext(sessionId: number, mlContext: MLContext): void { + this.mlContextBySessionId.set(sessionId, mlContext); + let sessionIds = this.sessionIdsByMLContext.get(mlContext); + if (!sessionIds) { + sessionIds = new Set(); + this.sessionIdsByMLContext.set(mlContext, sessionIds); + } + sessionIds.add(sessionId); + } + + public unregisterMLContext(sessionId: number): void { + const mlContext = this.mlContextBySessionId.get(sessionId)!; + if (!mlContext) { + throw new Error(`No MLContext found for session ${sessionId}`); + } + this.mlContextBySessionId.delete(sessionId); + const sessionIds = this.sessionIdsByMLContext.get(mlContext)!; + sessionIds.delete(sessionId); + if (sessionIds.size === 0) { + this.sessionIdsByMLContext.delete(mlContext); + } + } + + public onReleaseSession(sessionId: number): void { + this.unregisterMLContext(sessionId); + this.bufferManager.releaseBuffersForContext(this.getMLContext(sessionId)); + } + + public getMLContext(sessionId: number): MLContext { + return this.mlContextBySessionId.get(sessionId)!; + } + + public reserveBufferId(): BufferId { + return this.bufferManager.reserveBufferId(); + } + + public releaseBufferId(bufferId: BufferId): void { + this.bufferManager.releaseBufferId(bufferId); + } + + public async ensureBuffer( + bufferId: BufferId, + onnxDataType: number | MLOperandDataType, + dimensions: number[], + ): Promise { + let dataType: MLOperandDataType; + if (typeof onnxDataType === 'number') { + const webnnDataType = onnxDataTypeToWebnnDataType.get(onnxDataType)!; + if (!webnnDataType) { + throw new Error(`Unsupported ONNX data type: ${onnxDataType}`); + } + dataType = webnnDataType; + } else { + dataType = onnxDataType; + } + return this.bufferManager.ensureBuffer(bufferId, dataType, dimensions); + } + + public uploadBuffer(bufferId: BufferId, data: Uint8Array): void { + const wasm = getInstance(); + if (!wasm.shouldTransferToMLBuffer) { + throw new Error('Trying to upload to a MLBuffer while shouldTransferToMLBuffer is false'); + } + this.bufferManager.upload(bufferId, data); + } + + public async downloadBuffer(bufferId: BufferId): Promise { + return this.bufferManager.download(bufferId); + } + + public createMLBufferDownloader(bufferId: BufferId, type: Tensor.MLBufferDataTypes): () => Promise { + return async () => { + const data = await this.bufferManager.download(bufferId); + return createView(data, type); + }; + } + + public registerMLBuffer(buffer: MLBuffer): BufferId { + return this.bufferManager.registerBuffer(this.currentContext, buffer); + } + + public flush(): void { + // Unlike the WebGPU backend, the WebNN backend does not need to flush any pending operations. + } +} diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index 3f326881079f0..33e0b1a5c2133 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -11,6 +11,7 @@ import { LOG_DEBUG } from './log'; import { TensorView } from './tensor-view'; import { ShapeUtil } from './util'; import { AdapterInfo, ComputeContext, ComputeContextInputsOutputsMapping, ProgramInfo } from './webgpu/types'; +import { WebNNBackend } from './backend-webnn'; /* eslint-disable no-bitwise */ @@ -257,6 +258,22 @@ export const init = async ( () => backend.replay(), ]); } else { - jsepInit('webnn'); + const backend = new WebNNBackend(); + jsepInit('webnn', [ + backend, + // jsepReserveBufferId + () => backend.reserveBufferId(), + // jsepReleaseBufferId, + (bufferId: number) => backend.releaseBufferId(bufferId), + // jsepEnsureBuffer + async (bufferId: number, onnxDataType: number, dimensions: number[]) => + backend.ensureBuffer(bufferId, onnxDataType, dimensions), + // jsepUploadBuffer + (bufferId: number, data: Uint8Array) => { + backend.uploadBuffer(bufferId, data); + }, + // jsepDownloadBuffer + async (bufferId: number) => backend.downloadBuffer(bufferId), + ]); } }; diff --git a/js/web/lib/wasm/jsep/webnn/buffer-manager.ts b/js/web/lib/wasm/jsep/webnn/buffer-manager.ts new file mode 100644 index 0000000000000..fcd980d36a828 --- /dev/null +++ b/js/web/lib/wasm/jsep/webnn/buffer-manager.ts @@ -0,0 +1,212 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import { WebNNBackend } from '../backend-webnn'; + +// WebNN API currently does not have a TypeScript definition file. This file is a workaround with types generated from +// WebNN API specification. +// https://github.com/webmachinelearning/webnn/issues/677 +/// + +export type BufferId = number; + +/** + * Manages BufferId to MLBuffer mapping. + */ +export interface BufferManager { + /** + * Reserve a new BufferId. + */ + reserveBufferId(): BufferId; + /** + * Release a BufferId. + */ + releaseBufferId(bufferId: BufferId): void; + /** + * Ensure a MLBuffer is created for the BufferId. + */ + ensureBuffer(bufferId: BufferId, dataType: MLOperandDataType, dimensions: number[]): Promise; + /** + * Upload data to a MLBuffer. + */ + upload(bufferId: BufferId, data: Uint8Array): void; + /** + * Download data from a MLBuffer. + */ + download(bufferId: BufferId): Promise; + /** + * Release all buffers for a MLContext. + */ + releaseBuffersForContext(mlContext: MLContext): void; + /** + * Register an externally created MLBuffer with a given MLContext and return a BufferId. + */ + registerBuffer(mlContext: MLContext, mlBuffer: MLBuffer): BufferId; +} + +let bufferGuid = 1; +const createNewBufferId = (): BufferId => bufferGuid++; + +/** + * BufferTracker tracks the MLBuffer and pending upload data. + * + * We need to track the MLBuffer and pending upload data because we delay the creation of MLBuffer until + * we know the data type and dimensions. This is because future implementations of WebNN will only support creating + * MLBuffers with dataTypes and dimensions. + */ +class BufferTracker { + private mlBuffer?: MLBuffer; + private activeUpload?: Uint8Array; + + constructor( + private mlContext?: MLContext, + buffer?: MLBuffer, + ) { + this.mlBuffer = buffer; + } + + public get buffer(): MLBuffer | undefined { + return this.mlBuffer; + } + + public get context(): MLContext { + if (!this.mlContext) { + throw new Error('MLContext has not been set.'); + } + return this.mlContext; + } + + public set context(mlContext: MLContext) { + if (this.mlContext && this.mlContext !== mlContext) { + throw new Error('MLBuffer in use in a different MLContext.'); + } + this.mlContext = mlContext; + } + + public destroy(): void { + this.mlBuffer?.destroy(); + this.mlBuffer = undefined; + } + + public async ensureBuffer(dataType: MLOperandDataType, dimensions: number[]): Promise { + if (this.mlBuffer) { + return this.mlBuffer; + } + + const buffer = await this.context.createBuffer({ dataType, dimensions }); + this.mlBuffer = buffer; + + if (this.activeUpload) { + this.mlContext?.writeBuffer(buffer, this.activeUpload); + this.activeUpload = undefined; + } + + return buffer; + } + + public upload(data: Uint8Array): void { + if (!this.mlBuffer) { + this.activeUpload = new Uint8Array(data); + return; + } + + this.mlContext?.writeBuffer(this.mlBuffer, data); + } + + public async download(): Promise { + if (this.activeUpload) { + return this.activeUpload.buffer; + } + if (!this.mlBuffer) { + throw new Error('Buffer has not been created.'); + } + return this.context.readBuffer(this.mlBuffer); + } +} + +class BufferManagerImpl implements BufferManager { + private buffersById = new Map(); + private bufferIdsByContext = new Map>(); + + constructor(private backend: WebNNBackend) {} + + public reserveBufferId(): BufferId { + const bufferId = createNewBufferId(); + this.buffersById.set(bufferId, new BufferTracker()); + return bufferId; + } + + public releaseBufferId(bufferId: BufferId): void { + const bufferTracker = this.buffersById.get(bufferId); + if (!bufferTracker) { + return; + } + bufferTracker.destroy(); + this.buffersById.delete(bufferId); + for (const [mlContext, buffers] of this.bufferIdsByContext) { + if (buffers.has(bufferId)) { + buffers.delete(bufferId); + if (buffers.size === 0) { + this.bufferIdsByContext.delete(mlContext); + } + break; + } + } + } + + public async ensureBuffer(bufferId: BufferId, dataType: MLOperandDataType, dimensions: number[]): Promise { + const buffer = this.buffersById.get(bufferId); + if (!buffer) { + throw new Error('Buffer not found.'); + } + buffer.context = this.backend.currentContext; + if (!this.bufferIdsByContext.has(this.backend.currentContext)) { + this.bufferIdsByContext.set(this.backend.currentContext, new Set()); + } + this.bufferIdsByContext.get(this.backend.currentContext)?.add(bufferId); + return buffer.ensureBuffer(dataType, dimensions); + } + + public upload(bufferId: BufferId, data: Uint8Array): void { + this.buffersById.get(bufferId)!.upload(data); + } + + public async download(bufferId: BufferId): Promise { + return this.buffersById.get(bufferId)!.download(); + } + + public releaseBuffersForContext(mlContext: MLContext): void { + const buffers = this.bufferIdsByContext.get(mlContext); + if (!buffers) { + return; + } + for (const bufferId of buffers) { + this.buffersById.get(bufferId)!.destroy(); + this.buffersById.delete(bufferId); + } + this.bufferIdsByContext.delete(mlContext); + } + + public registerBuffer(mlContext: MLContext, mlBuffer: MLBuffer): BufferId { + for (const [bufferId, bufferTracker] of this.buffersById) { + if (bufferTracker.buffer === mlBuffer) { + if (bufferTracker.context !== mlContext) { + throw new Error('MLBuffer cannot be registered with a different MLContext.'); + } + return bufferId; + } + } + const bufferId = createNewBufferId(); + this.buffersById.set(bufferId, new BufferTracker(mlContext, mlBuffer)); + let buffers = this.bufferIdsByContext.get(mlContext); + if (!buffers) { + buffers = new Set(); + this.bufferIdsByContext.set(mlContext, buffers); + } + buffers.add(bufferId); + return bufferId; + } +} + +export const createBufferManager = (...args: ConstructorParameters): BufferManager => + new BufferManagerImpl(...args); diff --git a/js/web/lib/wasm/jsep/webnn/webnn.d.ts b/js/web/lib/wasm/jsep/webnn/webnn.d.ts index f8a1e1966fd4c..17bd3b6243342 100644 --- a/js/web/lib/wasm/jsep/webnn/webnn.d.ts +++ b/js/web/lib/wasm/jsep/webnn/webnn.d.ts @@ -381,21 +381,16 @@ interface MLGraphBuilder { // Experimental MLBuffer interface -type MLSize64Out = number; interface MLBuffer { - readonly size: MLSize64Out; destroy(): void; } -type MLSize64 = number; -interface MLBufferDescriptor { - size: MLSize64; -} + type MLNamedBuffers = Record; interface MLContext { - createBuffer(descriptor: MLBufferDescriptor): MLBuffer; + createBuffer(descriptor: MLOperandDescriptor): Promise; writeBuffer( - dstBuffer: MLBuffer, srcData: ArrayBufferView|ArrayBuffer, srcElementOffset?: MLSize64, - srcElementSize?: MLSize64): void; + dstBuffer: MLBuffer, srcData: ArrayBufferView|ArrayBuffer, srcElementOffset?: number, + srcElementSize?: number): void; readBuffer(srcBuffer: MLBuffer): Promise; dispatch(graph: MLGraph, inputs: MLNamedBuffers, outputs: MLNamedBuffers): void; } diff --git a/js/web/lib/wasm/proxy-messages.ts b/js/web/lib/wasm/proxy-messages.ts index 8f3acdd582445..58aea4d0c6591 100644 --- a/js/web/lib/wasm/proxy-messages.ts +++ b/js/web/lib/wasm/proxy-messages.ts @@ -19,11 +19,18 @@ export type GpuBufferMetadata = { dispose?: () => void; }; +export type MLBufferMetadata = { + mlBuffer: Tensor.MLBufferType; + download?: () => Promise; + dispose?: () => void; +}; + /** - * Tensors on location 'cpu-pinned' and 'gpu-buffer' are not serializable. + * Tensors on location 'cpu-pinned', 'gpu-buffer', and 'ml-buffer' are not serializable. */ export type UnserializableTensorMetadata = | [dataType: Tensor.Type, dims: readonly number[], data: GpuBufferMetadata, location: 'gpu-buffer'] + | [dataType: Tensor.Type, dims: readonly number[], data: MLBufferMetadata, location: 'ml-buffer'] | [dataType: Tensor.Type, dims: readonly number[], data: Tensor.DataType, location: 'cpu-pinned']; /** @@ -34,6 +41,7 @@ export type UnserializableTensorMetadata = * - cpu: Uint8Array * - cpu-pinned: Uint8Array * - gpu-buffer: GpuBufferMetadata + * - ml-buffer: MLBufferMetadata * - location: tensor data location */ export type TensorMetadata = SerializableTensorMetadata | UnserializableTensorMetadata; diff --git a/js/web/lib/wasm/session-handler-inference.ts b/js/web/lib/wasm/session-handler-inference.ts index eff3e91389c98..7ea52f3f470b7 100644 --- a/js/web/lib/wasm/session-handler-inference.ts +++ b/js/web/lib/wasm/session-handler-inference.ts @@ -12,7 +12,7 @@ import { import { SerializableInternalBuffer, TensorMetadata } from './proxy-messages'; import { copyFromExternalBuffer, createSession, endProfiling, releaseSession, run } from './proxy-wrapper'; -import { isGpuBufferSupportedType } from './wasm-common'; +import { isGpuBufferSupportedType, isMLBufferSupportedType } from './wasm-common'; import { isNode } from './wasm-utils-env'; import { loadFile } from './wasm-utils-load-file'; @@ -22,6 +22,8 @@ export const encodeTensorMetadata = (tensor: Tensor, getName: () => string): Ten return [tensor.type, tensor.dims, tensor.data, 'cpu']; case 'gpu-buffer': return [tensor.type, tensor.dims, { gpuBuffer: tensor.gpuBuffer }, 'gpu-buffer']; + case 'ml-buffer': + return [tensor.type, tensor.dims, { mlBuffer: tensor.mlBuffer }, 'ml-buffer']; default: throw new Error(`invalid data location: ${tensor.location} for ${getName()}`); } @@ -39,6 +41,14 @@ export const decodeTensorMetadata = (tensor: TensorMetadata): Tensor => { const { gpuBuffer, download, dispose } = tensor[2]; return Tensor.fromGpuBuffer(gpuBuffer, { dataType, dims: tensor[1], download, dispose }); } + case 'ml-buffer': { + const dataType = tensor[0]; + if (!isMLBufferSupportedType(dataType)) { + throw new Error(`not supported data type: ${dataType} for deserializing MLBuffer tensor`); + } + const { mlBuffer, download, dispose } = tensor[2]; + return Tensor.fromMLBuffer(mlBuffer, { dataType, dims: tensor[1], download, dispose }); + } default: throw new Error(`invalid data location: ${tensor[3]}`); } diff --git a/js/web/lib/wasm/wasm-common.ts b/js/web/lib/wasm/wasm-common.ts index fd5d93675154c..e403aad0d0f2d 100644 --- a/js/web/lib/wasm/wasm-common.ts +++ b/js/web/lib/wasm/wasm-common.ts @@ -238,6 +238,20 @@ export const isGpuBufferSupportedType = (type: Tensor.Type): type is Tensor.GpuB type === 'uint8' || type === 'bool'; +/** + * Check whether the given tensor type is supported by WebNN MLBuffer + */ +export const isMLBufferSupportedType = (type: Tensor.Type): type is Tensor.MLBufferDataTypes => + type === 'float32' || + type === 'float16' || + type === 'int32' || + type === 'int64' || + type === 'uint32' || + type === 'uint64' || + type === 'int8' || + type === 'uint8' || + type === 'bool'; + /** * Map string data location to integer value */ @@ -253,6 +267,8 @@ export const dataLocationStringToEnum = (location: Tensor.DataLocation): number return 3; case 'gpu-buffer': return 4; + case 'ml-buffer': + return 5; default: throw new Error(`unsupported data location: ${location}`); } @@ -262,4 +278,4 @@ export const dataLocationStringToEnum = (location: Tensor.DataLocation): number * Map integer data location to string value */ export const dataLocationEnumToString = (location: number): Tensor.DataLocation | undefined => - (['none', 'cpu', 'cpu-pinned', 'texture', 'gpu-buffer'] as const)[location]; + (['none', 'cpu', 'cpu-pinned', 'texture', 'gpu-buffer', 'ml-buffer'] as const)[location]; diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 6c4e28df62f23..ef6fb87c41fe0 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -20,6 +20,7 @@ import { calculateTensorSizeInBytes, dataLocationStringToEnum, isGpuBufferSupportedType, + isMLBufferSupportedType, logLevelStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, @@ -163,7 +164,7 @@ export const initEp = async (env: Env, epName: string): Promise => { /** * valid data locations for input/output tensors. */ -type SupportedTensorDataLocationForInputOutput = 'cpu' | 'cpu-pinned' | 'gpu-buffer'; +type SupportedTensorDataLocationForInputOutput = 'cpu' | 'cpu-pinned' | 'gpu-buffer' | 'ml-buffer'; type IOBindingState = { /** @@ -174,7 +175,7 @@ type IOBindingState = { /** * the preferred location for each output tensor. * - * value is one of 'cpu', 'cpu-pinned', 'gpu-buffer'. + * value is one of 'cpu', 'cpu-pinned', 'gpu-buffer', 'ml-buffer'. */ readonly outputPreferredLocations: readonly SupportedTensorDataLocationForInputOutput[]; @@ -288,6 +289,7 @@ export const createSession = async ( for (const provider of options?.executionProviders ?? []) { const providerName = typeof provider === 'string' ? provider : provider.name; if (providerName === 'webnn') { + wasm.shouldTransferToMLBuffer = false; if (wasm.currentContext) { throw new Error('WebNN execution provider is already set.'); } @@ -319,7 +321,9 @@ export const createSession = async ( // clear current MLContext after session creation if (wasm.currentContext) { + wasm.jsepRegisterMLContext!(sessionHandle, wasm.currentContext); wasm.currentContext = undefined; + wasm.shouldTransferToMLBuffer = true; } const [inputCount, outputCount] = getSessionInputOutputCount(sessionHandle); @@ -355,7 +359,7 @@ export const createSession = async ( typeof options?.preferredOutputLocation === 'string' ? options.preferredOutputLocation : (options?.preferredOutputLocation?.[nameString] ?? 'cpu'); - if (location !== 'cpu' && location !== 'cpu-pinned' && location !== 'gpu-buffer') { + if (location !== 'cpu' && location !== 'cpu-pinned' && location !== 'gpu-buffer' && location !== 'ml-buffer') { throw new Error(`Not supported preferred output location: ${location}.`); } if (enableGraphCapture && location !== 'gpu-buffer') { @@ -369,7 +373,7 @@ export const createSession = async ( // use IO binding only when at least one output is preffered to be on GPU. let bindingState: IOBindingState | null = null; - if (!BUILD_DEFS.DISABLE_JSEP && outputPreferredLocations.some((l) => l === 'gpu-buffer')) { + if (!BUILD_DEFS.DISABLE_JSEP && outputPreferredLocations.some((l) => l === 'gpu-buffer' || l === 'ml-buffer')) { ioBindingHandle = wasm._OrtCreateBinding(sessionHandle); if (ioBindingHandle === 0) { checkLastError("Can't create IO binding."); @@ -460,7 +464,7 @@ export const prepareInputOutputTensor = ( let rawData: number; let dataByteLength: number; - if (dataType === 'string' && location === 'gpu-buffer') { + if (dataType === 'string' && (location === 'gpu-buffer' || location === 'ml-buffer')) { throw new Error('String tensor is not supported on GPU.'); } @@ -479,6 +483,15 @@ export const prepareInputOutputTensor = ( throw new Error('Tensor location "gpu-buffer" is not supported without using WebGPU.'); } rawData = registerBuffer(sessionId, index, gpuBuffer, dataByteLength); + } else if (location === 'ml-buffer') { + const mlBuffer = tensor[2].mlBuffer as MLBuffer; + dataByteLength = calculateTensorSizeInBytes(tensorDataTypeStringToEnum(dataType), dims)!; + + const registerMLBuffer = wasm.jsepRegisterMLBuffer; + if (!registerMLBuffer) { + throw new Error('Tensor location "ml-buffer" is not supported without using WebNN.'); + } + rawData = registerMLBuffer(mlBuffer); } else { const data = tensor[2]; @@ -564,6 +577,9 @@ export const run = async ( const outputNamesOffset = wasm.stackAlloc(outputCount * 4); try { + // WebNN backend needs the active session to check MLBuffers with the current context. + wasm.jsepOnRunStart?.(sessionHandle); + [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); // create input tensors @@ -655,7 +671,6 @@ export const run = async ( ]); } - wasm.jsepOnRunStart?.(sessionHandle); let errorCode: number; if (!BUILD_DEFS.DISABLE_JSEP && ioBindingState) { errorCode = await wasm._OrtRunWithBinding( @@ -727,7 +742,7 @@ export const run = async ( const preferredLocation = ioBindingState?.outputPreferredLocations[outputIndices[i]]; if (type === 'string') { - if (preferredLocation === 'gpu-buffer') { + if (preferredLocation === 'gpu-buffer' || preferredLocation === 'ml-buffer') { throw new Error('String tensor is not supported on GPU.'); } const stringData: string[] = []; @@ -767,6 +782,36 @@ export const run = async ( }, 'gpu-buffer', ]); + } else if (preferredLocation === 'ml-buffer' && size > 0) { + const ensureBuffer = wasm.jsepEnsureBuffer; + if (!ensureBuffer) { + throw new Error('preferredLocation "ml-buffer" is not supported without using WebNN.'); + } + const bufferSize = calculateTensorSizeInBytes(dataType, size); + if (bufferSize === undefined || !isMLBufferSupportedType(type)) { + throw new Error(`Unsupported data type: ${type}`); + } + + // If the graph has been partitioned, the output tensor may have not been created. For this reason, we use + // ensureBuffer to get/create the MLBuffer. + const mlBuffer = await ensureBuffer(dataOffset, dataType, dims); + + // do not release the tensor right now. it will be released when user calls tensor.dispose(). + keepOutputTensor = true; + + output.push([ + type, + dims, + { + mlBuffer, + download: wasm.jsepCreateMLBufferDownloader!(dataOffset, type), + dispose: () => { + wasm.jsepReleaseBufferId!(dataOffset); + wasm._OrtReleaseTensor(tensor); + }, + }, + 'ml-buffer', + ]); } else { const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type); const data = new typedArrayConstructor(size); diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts index 70b6cceab0eef..8cbe111ff373a 100644 --- a/js/web/lib/wasm/wasm-types.ts +++ b/js/web/lib/wasm/wasm-types.ts @@ -27,6 +27,15 @@ export declare namespace JSEP { type CaptureBeginFunction = () => void; type CaptureEndFunction = () => void; type ReplayFunction = () => void; + type ReserveBufferIdFunction = () => number; + type ReleaseBufferIdFunction = (bufferId: number) => void; + type EnsureBufferFunction = ( + bufferId: number, + dataType: number | MLOperandDataType, + dimensions: number[], + ) => Promise; + type UploadBufferFunction = (bufferId: number, data: Uint8Array) => void; + type DownloadBufferFunction = (bufferId: number) => Promise; export interface Module extends WebGpuModule, WebNnModule { /** @@ -62,7 +71,17 @@ export declare namespace JSEP { replay: ReplayFunction, ], ): void; - jsepInit(name: 'webnn', initParams?: never): void; + jsepInit( + name: 'webnn', + initParams: [ + backend: BackendType, + reserveBufferId: ReserveBufferIdFunction, + releaseBufferId: ReleaseBufferIdFunction, + ensureBuffer: EnsureBufferFunction, + uploadBuffer: UploadBufferFunction, + downloadBuffer: DownloadBufferFunction, + ], + ): void; } export interface WebGpuModule { @@ -134,6 +153,70 @@ export declare namespace JSEP { * Active MLContext used to create WebNN EP. */ currentContext: MLContext; + + /** + * Disables creating MLBuffers. This is used to avoid creating MLBuffers for graph initializers. + */ + shouldTransferToMLBuffer: boolean; + + /** + * [exported from pre-jsep.js] Register MLContext for a session. + * @param sessionId - specify the session ID. + * @param context - specify the MLContext. + * @returns + */ + jsepRegisterMLContext: (sessionId: number, context: MLContext) => void; + /** + * [exported from pre-jsep.js] Reserve a MLBuffer ID attached to the current session. + * @returns the MLBuffer ID. + */ + jsepReserveBufferId: () => number; + /** + * [exported from pre-jsep.js] Release a MLBuffer ID from use and destroy buffer if no longer in use. + * @param bufferId - specify the MLBuffer ID. + * @returns + */ + jsepReleaseBufferId: (bufferId: number) => void; + /** + * [exported from pre-jsep.js] Get MLBuffer by ID. + * @param bufferId - specify the MLBuffer ID. + * @returns the MLBuffer. + */ + jsepEnsureBuffer: ( + bufferId: number, + dataType: number | MLOperandDataType, + dimensions: number[], + ) => Promise; + /** + * [exported from pre-jsep.js] Upload data to MLBuffer. + * @param bufferId - specify the MLBuffer ID. + * @param data - specify the data to upload. It can be a TensorProto::data_type or a WebNN MLOperandDataType. + * @param dimensions - specify the dimensions. + * @returns + */ + jsepUploadBuffer: (bufferId: number, data: Uint8Array) => void; + /** + * [exported from pre-jsep.js] Download data from MLBuffer. + * @param bufferId - specify the MLBuffer ID. + * @returns the downloaded data. + */ + jsepDownloadBuffer: (bufferId: number) => Promise; + /** + * [exported from pre-jsep.js] Create a downloader function to download data from MLBuffer. + * @param bufferId - specify the MLBuffer ID. + * @param type - specify the data type. + * @returns the downloader function. + */ + jsepCreateMLBufferDownloader: ( + bufferId: number, + type: Tensor.MLBufferDataTypes, + ) => () => Promise; + /** + * [exported from pre-jsep.js] Register MLBuffer for a session. + * @param mlBuffer - specify the MLBuffer. + * @returns the MLBuffer ID. + */ + jsepRegisterMLBuffer: (buffer: MLBuffer) => number; } } diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts index d237293dbb192..6e156c5e17516 100644 --- a/js/web/script/test-runner-cli-args.ts +++ b/js/web/script/test-runner-cli-args.ts @@ -62,6 +62,8 @@ Options: none (default) gpu-tensor use pre-allocated GPU tensors for inputs and outputs gpu-location use pre-allocated GPU tensors for inputs and set preferredOutputLocation to 'gpu-buffer' + ml-tensor use pre-allocated MLBuffer tensors for inputs and outputs + ml-location use pre-allocated MLBuffer tensors for inputs and set preferredOutputLocation to 'ml-buffer' *** Logging Options *** @@ -133,7 +135,7 @@ export declare namespace TestRunnerCliArgs { type Backend = 'cpu' | 'webgl' | 'webgpu' | 'wasm' | 'onnxruntime' | 'webnn'; type Environment = 'chrome' | 'chromecanary' | 'edge' | 'firefox' | 'electron' | 'safari' | 'node' | 'bs'; type BundleMode = 'dev' | 'perf'; - type IOBindingMode = 'none' | 'gpu-tensor' | 'gpu-location'; + type IOBindingMode = 'none' | 'gpu-tensor' | 'gpu-location' | 'ml-tensor' | 'ml-location'; } export interface TestRunnerCliArgs { @@ -455,7 +457,7 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs // Option: -i=<...>, --io-binding=<...> const ioBindingArg = args['io-binding'] || args.i; const ioBindingMode = typeof ioBindingArg !== 'string' ? 'none' : ioBindingArg; - if (['none', 'gpu-tensor', 'gpu-location'].indexOf(ioBindingMode) === -1) { + if (['none', 'gpu-tensor', 'gpu-location', 'ml-tensor', 'ml-location'].indexOf(ioBindingMode) === -1) { throw new Error(`not supported io binding mode ${ioBindingMode}`); } diff --git a/js/web/script/test-runner-cli.ts b/js/web/script/test-runner-cli.ts index a9fcd7b876b2f..68ee58dab7094 100644 --- a/js/web/script/test-runner-cli.ts +++ b/js/web/script/test-runner-cli.ts @@ -380,7 +380,7 @@ async function main() { } let ioBinding: Test.IOBindingMode; - if (backend !== 'webgpu' && args.ioBindingMode !== 'none') { + if (!['webgpu', 'webnn'].includes(backend) && args.ioBindingMode !== 'none') { npmlog.warn( 'TestRunnerCli.Init.Model', `Ignoring IO Binding Mode "${args.ioBindingMode}" for backend "${backend}".`, diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index aa9555c191501..e4b38827f874d 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -1,6 +1,11 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +// WebNN API currently does not have a TypeScript definition file. This file is a workaround with types generated from +// WebNN API specification. +// https://github.com/webmachinelearning/webnn/issues/677 +/// + import { Float16Array as Float16ArrayPolyfill } from '@petamoriken/float16'; import { expect } from 'chai'; import * as ort from 'onnxruntime-common'; @@ -19,6 +24,7 @@ import { createView } from '../lib/wasm/jsep/tensor-view'; import { calculateTensorSizeInBytes, isGpuBufferSupportedType, + isMLBufferSupportedType, tensorDataTypeStringToEnum, } from '../lib/wasm/wasm-common'; @@ -170,13 +176,20 @@ async function initializeSession( }`, ); + let preferredOutputLocation: ort.Tensor.DataLocation | undefined; + if (ioBindingMode === 'gpu-location') { + preferredOutputLocation = 'gpu-buffer'; + } else if (ioBindingMode === 'ml-location') { + preferredOutputLocation = 'ml-buffer'; + } + const profilerConfig = profile ? { maxNumberEvents: 65536 } : undefined; const sessionConfig = { ...sessionOptions, executionProviders: [backendHint], profiler: profilerConfig, enableProfiling: profile, - preferredOutputLocation: ioBindingMode === 'gpu-location' ? ('gpu-buffer' as const) : undefined, + preferredOutputLocation, externalData, }; @@ -219,6 +232,7 @@ export class ModelTestContext { readonly perfData: ModelTestContext.ModelTestPerfData, readonly ioBinding: Test.IOBindingMode, private readonly profile: boolean, + public readonly mlContext?: MLContext, ) {} /** @@ -272,7 +286,24 @@ export class ModelTestContext { const initStart = now(); const executionProviderConfig = - modelTest.backend === 'webnn' ? testOptions?.webnnOptions || 'webnn' : modelTest.backend!; + modelTest.backend === 'webnn' ? testOptions?.webnnOptions || { name: 'webnn' } : modelTest.backend!; + let mlContext: MLContext | undefined; + if (['ml-tensor', 'ml-location'].includes(modelTest.ioBinding)) { + const webnnOptions = executionProviderConfig as ort.InferenceSession.WebNNExecutionProviderOption; + const deviceType = (webnnOptions as ort.InferenceSession.WebNNContextOptions)?.deviceType; + const numThreads = (webnnOptions as ort.InferenceSession.WebNNContextOptions)?.numThreads; + const powerPreference = (webnnOptions as ort.InferenceSession.WebNNContextOptions)?.powerPreference; + + mlContext = await navigator.ml.createContext({ + deviceType, + numThreads, + powerPreference, + }); + (executionProviderConfig as ort.InferenceSession.WebNNExecutionProviderOption).context = mlContext; + if (!deviceType) { + (executionProviderConfig as ort.InferenceSession.WebNNContextOptions).deviceType = deviceType; + } + } const session = await initializeSession( modelTest.modelUrl, executionProviderConfig, @@ -295,6 +326,7 @@ export class ModelTestContext { { init: initEnd - initStart, firstRun: -1, runs: [], count: 0 }, modelTest.ioBinding, profile, + mlContext, ); } finally { this.initializing = false; @@ -622,30 +654,70 @@ function createGpuTensorForOutput(type: ort.Tensor.Type, dims: readonly number[] }); } +async function createMLTensorForOutput(mlContext: MLContext, type: ort.Tensor.Type, dims: readonly number[]) { + if (!isMLBufferSupportedType(type)) { + throw new Error(`createMLTensorForOutput can not work with ${type} tensor`); + } + + const dataType = type === 'bool' ? 'uint8' : type; + + const mlBuffer = await mlContext.createBuffer({ dataType, dimensions: dims as number[] }); + + return ort.Tensor.fromMLBuffer(mlBuffer, { + dataType: type, + dims, + dispose: () => mlBuffer.destroy(), + download: async () => { + const arrayBuffer = await mlContext.readBuffer(mlBuffer); + return createView(arrayBuffer, type) as ort.Tensor.DataTypeMap[ort.Tensor.MLBufferDataTypes]; + }, + }); +} + +async function createMLTensorForInput(mlContext: MLContext, cpuTensor: ort.Tensor): Promise { + if (!isMLBufferSupportedType(cpuTensor.type) || Array.isArray(cpuTensor.data)) { + throw new Error(`createMLTensorForInput can not work with ${cpuTensor.type} tensor`); + } + const dataType = cpuTensor.type === 'bool' ? 'uint8' : cpuTensor.type; + const mlBuffer = await mlContext.createBuffer({ dataType, dimensions: cpuTensor.dims as number[] }); + mlContext.writeBuffer(mlBuffer, cpuTensor.data); + return ort.Tensor.fromMLBuffer(mlBuffer, { + dataType: cpuTensor.type, + dims: cpuTensor.dims, + dispose: () => mlBuffer.destroy(), + }); +} + export async function sessionRun(options: { session: ort.InferenceSession; feeds: Record; outputsMetaInfo: Record>; ioBinding: Test.IOBindingMode; + mlContext?: MLContext; }): Promise<[number, number, ort.InferenceSession.OnnxValueMapType]> { const session = options.session; const feeds = options.feeds; const fetches: Record = {}; - // currently we only support IO Binding for WebGPU + // currently we only support IO Binding for WebGPU and WebNN // - // For inputs, we create GPU tensors on both 'gpu-tensor' and 'gpu-location' binding testing mode. - // For outputs, we create GPU tensors on 'gpu-tensor' binding testing mode only. + // For inputs, we create tensors on 'gpu-tensor', 'gpu-location', 'ml-tensor', and 'ml-location' binding testing + // modes. + // For outputs, we create tensors on 'gpu-tensor' and 'ml-tensor' binding testing modes. // in 'gpu-device' binding mode, outputs are not pre-allocated. - const shouldUploadInput = options.ioBinding === 'gpu-tensor' || options.ioBinding === 'gpu-location'; - const shouldUploadOutput = options.ioBinding === 'gpu-tensor'; + const shouldUploadInput = ['gpu-tensor', 'gpu-location', 'ml-location', 'ml-tensor'].includes(options.ioBinding); + const shouldUploadOutput = options.ioBinding === 'gpu-tensor' || options.ioBinding === 'ml-tensor'; try { if (shouldUploadInput) { // replace the CPU tensors in feeds into GPU tensors for (const name in feeds) { if (Object.hasOwnProperty.call(feeds, name)) { if (feeds[name].size > 0) { - feeds[name] = createGpuTensorForInput(feeds[name]); + if (options.ioBinding === 'ml-location' || options.ioBinding === 'ml-tensor') { + feeds[name] = await createMLTensorForInput(options.mlContext!, feeds[name]); + } else { + feeds[name] = createGpuTensorForInput(feeds[name]); + } } } } @@ -658,7 +730,11 @@ export async function sessionRun(options: { if (dims.some((d) => d === 0)) { fetches[name] = new ort.Tensor(type, [], dims); } else { - fetches[name] = createGpuTensorForOutput(type, dims); + if (options.ioBinding === 'ml-tensor') { + fetches[name] = await createMLTensorForOutput(options.mlContext!, type, dims); + } else { + fetches[name] = createGpuTensorForOutput(type, dims); + } } } } @@ -714,6 +790,7 @@ export async function runModelTestSet( feeds, outputsMetaInfo, ioBinding: context.ioBinding, + mlContext: context.mlContext, }); if (context.perfData.count === 0) { context.perfData.firstRun = end - start; diff --git a/js/web/test/test-types.ts b/js/web/test/test-types.ts index be1e56485ec5a..eddda1206eec9 100644 --- a/js/web/test/test-types.ts +++ b/js/web/test/test-types.ts @@ -52,8 +52,12 @@ export declare namespace Test { * `preferredOutputLocation` will be set to `gpu-buffer`. * - gpu-tensor: inputs and outputs will all be pre-allocated as GPU tensors. `preferredOutputLocation` * will not be set. + * - ml-location: inputs will be pre-allocated as ML tensors; no output will be pre-allocated; + * `preferredOutputLocation` will be set to `ml-buffer`. + * - ml-tensor: inputs and outputs will all be pre-allocated as MLBuffer tensors. `preferredOutputLocation` + * will not be set. */ - export type IOBindingMode = 'none' | 'gpu-tensor' | 'gpu-location'; + export type IOBindingMode = 'none' | 'gpu-tensor' | 'gpu-location' | 'ml-tensor' | 'ml-location'; export interface ModelTestCase { name: string; diff --git a/onnxruntime/core/framework/allocator.cc b/onnxruntime/core/framework/allocator.cc index c3e96e450c59b..7bd9f64e5603f 100644 --- a/onnxruntime/core/framework/allocator.cc +++ b/onnxruntime/core/framework/allocator.cc @@ -141,7 +141,8 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtA strcmp(name1, onnxruntime::OpenVINO_GPU) == 0 || strcmp(name1, onnxruntime::DML) == 0 || strcmp(name1, onnxruntime::HIP) == 0 || - strcmp(name1, onnxruntime::WEBGPU_BUFFER) == 0) { + strcmp(name1, onnxruntime::WEBGPU_BUFFER) == 0 || + strcmp(name1, onnxruntime::WEBNN_BUFFER) == 0) { *out = new OrtMemoryInfo( name1, type, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast(id1)), id1, mem_type1); diff --git a/onnxruntime/core/providers/webnn/allocator.cc b/onnxruntime/core/providers/webnn/allocator.cc new file mode 100644 index 0000000000000..4b8188a6f8344 --- /dev/null +++ b/onnxruntime/core/providers/webnn/allocator.cc @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webnn/allocator.h" + +#include "core/common/safeint.h" + +namespace onnxruntime { +namespace webnn { + +void* WebNNBufferAllocator::Alloc(size_t size) { + if (size == 0) { + return nullptr; + } + if (!emscripten::val::module_property("shouldTransferToMLBuffer").as()) { + // We don't need to transfer the buffer to an MLBuffer, so we don't need to allocate buffer id. + return nullptr; + } + void* p = EM_ASM_PTR({ return Module.jsepReserveBufferId(); }); + allocations_[p] = size; + stats_.num_allocs++; + stats_.bytes_in_use += SafeInt(size); + return p; +} + +void WebNNBufferAllocator::Free(void* p) { + if (p == nullptr) { + return; + } + EM_ASM({ Module.jsepReleaseBufferId($0); }, p); + size_t size = allocations_[p]; + stats_.bytes_in_use -= size; + allocations_.erase(p); +} + +void WebNNBufferAllocator::GetStats(AllocatorStats* stats) { + *stats = stats_; +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/allocator.h b/onnxruntime/core/providers/webnn/allocator.h new file mode 100644 index 0000000000000..6d9fd2c0542e2 --- /dev/null +++ b/onnxruntime/core/providers/webnn/allocator.h @@ -0,0 +1,32 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +#include "core/common/inlined_containers.h" +#include "core/framework/allocator.h" +#include "core/framework/ortdevice.h" + +namespace onnxruntime { +namespace webnn { + +class WebNNBufferAllocator : public IAllocator { + public: + WebNNBufferAllocator() : IAllocator(OrtMemoryInfo(WEBNN_BUFFER, OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0), 0, OrtMemTypeDefault)) {} + + void* Alloc(size_t size) override; + + void Free(void* p) override; + + void GetStats(AllocatorStats* stats) override; + + private: + AllocatorStats stats_; + InlinedHashMap allocations_; +}; + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index d3c1d06818db2..22271640ef57f 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -12,6 +12,19 @@ namespace onnxruntime { namespace webnn { +WebnnDeviceType DeviceTypeFromString(const std::string& device_type) { + if (device_type == "gpu") { + return WebnnDeviceType::GPU; + } + if (device_type == "cpu") { + return WebnnDeviceType::CPU; + } + if (device_type == "npu") { + return WebnnDeviceType::NPU; + } + ORT_THROW("Unknown WebNN deviceType."); +} + InitializedTensorSet CollectAllInitializedTensors(const GraphViewer& graph_viewer) { InitializedTensorSet all_initializers; if (graph_viewer.IsSubgraph()) { @@ -198,5 +211,10 @@ bool SetWebnnDataType(emscripten::val& desc, const int32_t data_type) { } } +bool IsMLBufferSupported() { + static bool is_supported = !emscripten::val::global("MLBuffer").isUndefined(); + return is_supported; +} + } // namespace webnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index fc13ce201f2e9..278c1c4e13ad2 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -31,6 +31,8 @@ enum class WebnnDeviceType { NPU, }; +WebnnDeviceType DeviceTypeFromString(const std::string& device_type); + typedef struct { std::string opName; bool isCpuSupported; // The WebNN CPU backend XNNPack supports it (not about the CPU EP). @@ -284,5 +286,7 @@ bool GetBidirectionalBroadcastShape(std::vector& shape_a, bool SetWebnnDataType(emscripten::val& desc, const int32_t data_type); +bool IsMLBufferSupported(); + } // namespace webnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/model.cc b/onnxruntime/core/providers/webnn/builders/model.cc index ef807a8c4fa26..ba84a5d6c56fd 100644 --- a/onnxruntime/core/providers/webnn/builders/model.cc +++ b/onnxruntime/core/providers/webnn/builders/model.cc @@ -11,21 +11,31 @@ #include "core/common/safeint.h" #include "core/graph/onnx_protobuf.h" #include "core/providers/common.h" -#include "core/providers/webnn/builders/helper.h" #include "model.h" namespace onnxruntime { namespace webnn { -Model::Model(const emscripten::val& context, const emscripten::val& graph, const logging::Logger& logger) +Model::Model(const emscripten::val& context, const emscripten::val& graph, const logging::Logger& logger, bool use_dispatch) : wnn_context_(context), wnn_graph_(graph), - logger_(logger) {} + logger_(logger), + use_dispatch_(use_dispatch) {} Model::~Model() {} Status Model::Predict(const InlinedHashMap& inputs, const InlinedHashMap& outputs) { + if (use_dispatch_) { + return Dispatch(inputs, outputs); + + } else { + return Compute(inputs, outputs); + } +} + +onnxruntime::common::Status Model::Compute(const InlinedHashMap& inputs, + const InlinedHashMap& outputs) { for (const auto& input : inputs) { const std::string& name = input.first; const struct OnnxTensorData tensor = input.second; @@ -142,6 +152,40 @@ Status Model::Predict(const InlinedHashMap& inputs, return Status::OK(); } +onnxruntime::common::Status Model::Dispatch(const InlinedHashMap& inputs, + const InlinedHashMap& outputs) { + auto jsepEnsureBuffer = emscripten::val::module_property("jsepEnsureBuffer"); + auto promises = emscripten::val::array(); + for (const auto& [_, tensor] : inputs) { + emscripten::val shape = emscripten::val::array(); + for (const auto& dim : tensor.tensor_info.shape) { + uint32_t dim_val = SafeInt(dim); + shape.call("push", dim_val); + } + auto buffer = jsepEnsureBuffer(reinterpret_cast(tensor.buffer), tensor.tensor_info.data_type, shape); + promises.call("push", buffer); + } + for (const auto& [_, tensor] : outputs) { + emscripten::val shape = emscripten::val::array(); + for (const auto& dim : tensor.tensor_info.shape) { + uint32_t dim_val = SafeInt(dim); + shape.call("push", dim_val); + } + auto buffer = jsepEnsureBuffer(reinterpret_cast(tensor.buffer), tensor.tensor_info.data_type, shape); + promises.call("push", buffer); + } + auto buffers = emscripten::val::global("Promise").call("all", promises).await(); + for (const auto& [name, _] : inputs) { + wnn_inputs_.set(name, buffers.call("shift")); + } + for (const auto& [name, _] : outputs) { + wnn_outputs_.set(name, buffers.call("shift")); + } + wnn_context_.call("dispatch", wnn_graph_, wnn_inputs_, wnn_outputs_); + + return Status::OK(); +} + bool Model::IsScalarOutput(const std::string& output_name) const { return Contains(scalar_outputs_, output_name); } @@ -160,6 +204,10 @@ void Model::SetOutputMap(InlinedHashMap&& output_map) { // Pre-allocate the input and output buffers for the WebNN graph. void Model::AllocateInputOutputBuffers() { + // We don't need to allocate JS array buffers if the WebNN API supports MLBuffer. + if (use_dispatch_) { + return; + } for (const auto& input : inputs_) { const auto& input_info = input_output_info_.at(input); const auto input_shape = input_info.shape; diff --git a/onnxruntime/core/providers/webnn/builders/model.h b/onnxruntime/core/providers/webnn/builders/model.h index 4af82a2675691..f5ca137f5f6b5 100644 --- a/onnxruntime/core/providers/webnn/builders/model.h +++ b/onnxruntime/core/providers/webnn/builders/model.h @@ -58,6 +58,12 @@ class Model { size_t GetMappedOutputIdx(const std::string& name) const; private: + onnxruntime::common::Status Dispatch(const InlinedHashMap& inputs, + const InlinedHashMap& outputs); + + onnxruntime::common::Status Compute(const InlinedHashMap& inputs, + const InlinedHashMap& outputs); + emscripten::val wnn_context_ = emscripten::val::object(); emscripten::val wnn_graph_ = emscripten::val::object(); const logging::Logger& logger_; @@ -77,7 +83,9 @@ class Model { OrtMutex mutex_; - Model(const emscripten::val& context, const emscripten::val& path, const logging::Logger& logger); + bool use_dispatch_; + + Model(const emscripten::val& context, const emscripten::val& path, const logging::Logger& logger, bool use_dispatch); void SetInputOutputInfo(InlinedHashMap&& input_output_info) { input_output_info_ = std::move(input_output_info); diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index b21f717eedc7a..8cc56e212b444 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -340,7 +340,7 @@ Status ModelBuilder::Compile(std::unique_ptr& model) { } // Explicitly release the WebNN builder to free memory. wnn_builder_ = emscripten::val::undefined(); - model.reset(new Model(std::move(wnn_context_), std::move(wnn_graph), logger_)); + model.reset(new Model(std::move(wnn_context_), std::move(wnn_graph), logger_, IsMLBufferSupported())); model->SetInputs(std::move(input_names_)); model->SetOutputs(std::move(output_names_)); model->SetScalarOutputs(std::move(scalar_outputs_)); diff --git a/onnxruntime/core/providers/webnn/data_transfer.cc b/onnxruntime/core/providers/webnn/data_transfer.cc new file mode 100644 index 0000000000000..5644de25fd306 --- /dev/null +++ b/onnxruntime/core/providers/webnn/data_transfer.cc @@ -0,0 +1,47 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/webnn/data_transfer.h" + +#include +#include "core/framework/tensor.h" + +namespace onnxruntime { +namespace webnn { + +bool DataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { + // Copying data between MLBuffers is not supported by WebNN. + return (dst_device.Type() == OrtDevice::GPU && src_device.Type() == OrtDevice::CPU) || + (dst_device.Type() == OrtDevice::CPU && src_device.Type() == OrtDevice::GPU); +} + +common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { + if (!emscripten::val::module_property("shouldTransferToMLBuffer").as()) { + // We don't need to transfer the buffer to an MLBuffer, so we don't need to copy the buffer. + return Status::OK(); + } + + size_t bytes = src.SizeInBytes(); + if (bytes > 0) { + const void* src_data = src.DataRaw(); + void* dst_data = dst.MutableDataRaw(); + + const auto& dst_device = dst.Location().device; + + if (dst_device.Type() == OrtDevice::GPU) { + EM_ASM({ Module.jsepUploadBuffer($0, HEAPU8.subarray($1, $1 + $2)); }, dst_data, reinterpret_cast(src_data), bytes); + } else { + auto jsepDownloadBuffer = emscripten::val::module_property("jsepDownloadBuffer"); + auto buffer = jsepDownloadBuffer(reinterpret_cast(src_data)).await(); + EM_ASM({ + const buffer = Emval.toValue($0); + const src_array = new Uint8Array(buffer, 0, $2); + HEAPU8.set(src_array, $1); }, buffer.as_handle(), reinterpret_cast(dst_data), bytes); + } + } + + return Status::OK(); +} + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/data_transfer.h b/onnxruntime/core/providers/webnn/data_transfer.h new file mode 100644 index 0000000000000..03cfada46d1a0 --- /dev/null +++ b/onnxruntime/core/providers/webnn/data_transfer.h @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#include "core/framework/data_transfer.h" + +namespace onnxruntime { +namespace webnn { + +class DataTransfer : public IDataTransfer { + public: + bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override; + + common::Status CopyTensor(const Tensor& src, Tensor& dst) const override; +}; + +} // namespace webnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index 1cd382c1e75e9..c3280ee3855d1 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -10,6 +10,8 @@ #include "core/graph/graph_viewer.h" #include "core/session/onnxruntime_cxx_api.h" #include "core/common/safeint.h" +#include "core/providers/webnn/allocator.h" +#include "core/providers/webnn/data_transfer.h" #include "builders/model.h" #include "builders/helper.h" @@ -18,20 +20,19 @@ namespace onnxruntime { WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_flags) - : IExecutionProvider{onnxruntime::kWebNNExecutionProvider} { + : IExecutionProvider{ + onnxruntime::kWebNNExecutionProvider, + // If MLBuffer is supported, we force all the tensors to be allocated as MLBuffer. + OrtDevice( + webnn::IsMLBufferSupported() ? OrtDevice::GPU : OrtDevice::CPU, + OrtDevice::MemType::DEFAULT, + 0)}, + wnn_device_type_(webnn::DeviceTypeFromString(webnn_device_flags)) { // WebNN EP uses NHWC layout for CPU XNNPACK backend and NCHW for GPU DML backend. - if (webnn_device_flags.compare("cpu") == 0) { + if (wnn_device_type_ == webnn::WebnnDeviceType::CPU) { preferred_layout_ = DataLayout::NHWC; - wnn_device_type_ = webnn::WebnnDeviceType::CPU; } else { preferred_layout_ = DataLayout::NCHW; - if (webnn_device_flags.compare("gpu") == 0) { - wnn_device_type_ = webnn::WebnnDeviceType::GPU; - } else if (webnn_device_flags.compare("npu") == 0) { - wnn_device_type_ = webnn::WebnnDeviceType::NPU; - } else { - ORT_THROW("Unknown WebNN deviceType."); - } } wnn_context_ = emscripten::val::module_property("currentContext"); @@ -379,4 +380,22 @@ WebNNExecutionProvider::GetKernelRegistry() const { return kernel_registry; } +std::unique_ptr WebNNExecutionProvider::GetDataTransfer() const { + if (!webnn::IsMLBufferSupported()) { + return nullptr; + } + return std::make_unique(); +} + +std::vector WebNNExecutionProvider::CreatePreferredAllocators() { + if (!webnn::IsMLBufferSupported()) { + return {}; + } + AllocatorCreationInfo customAllocatorCreationInfo([&](OrtDevice::DeviceId) { + return std::make_unique(); + }, + 0, false); + return {CreateAllocator(customAllocatorCreationInfo)}; +} + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.h b/onnxruntime/core/providers/webnn/webnn_execution_provider.h index d8c1e90c86cdb..81ab504582162 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.h +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.h @@ -40,6 +40,8 @@ class WebNNExecutionProvider : public IExecutionProvider { #endif std::shared_ptr GetKernelRegistry() const override; + std::unique_ptr GetDataTransfer() const override; + std::vector CreatePreferredAllocators() override; private: emscripten::val wnn_context_ = emscripten::val::undefined(); diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index 0e58bb4f93f7f..f84af6c1a2325 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -23,7 +23,8 @@ enum DataLocation { DATA_LOCATION_CPU = 1, DATA_LOCATION_CPU_PINNED = 2, DATA_LOCATION_TEXTURE = 3, - DATA_LOCATION_GPU_BUFFER = 4 + DATA_LOCATION_GPU_BUFFER = 4, + DATA_LOCATION_ML_BUFFER = 5 }; static_assert(sizeof(const char*) == sizeof(size_t), "size of a pointer and a size_t value should be the same."); @@ -235,7 +236,8 @@ void OrtFree(void* ptr) { OrtValue* OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* dims, size_t dims_length, int data_location) { if (data_location != DATA_LOCATION_CPU && data_location != DATA_LOCATION_CPU_PINNED && - data_location != DATA_LOCATION_GPU_BUFFER) { + data_location != DATA_LOCATION_GPU_BUFFER && + data_location != DATA_LOCATION_ML_BUFFER) { std::ostringstream ostr; ostr << "Invalid data location: " << data_location; CheckStatus(Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, ostr.str().c_str())); @@ -264,10 +266,15 @@ OrtValue* OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* return UNREGISTER_AUTO_RELEASE(value); } else { OrtMemoryInfo* memory_info = nullptr; - if (data_location != DATA_LOCATION_GPU_BUFFER) { - RETURN_NULLPTR_IF_ERROR(CreateCpuMemoryInfo, OrtDeviceAllocator, OrtMemTypeDefault, &memory_info); - } else { - RETURN_NULLPTR_IF_ERROR(CreateMemoryInfo, "WebGPU_Buffer", OrtDeviceAllocator, 0, OrtMemTypeDefault, &memory_info); + switch (data_location) { + case DATA_LOCATION_GPU_BUFFER: + RETURN_NULLPTR_IF_ERROR(CreateMemoryInfo, "WebGPU_Buffer", OrtDeviceAllocator, 0, OrtMemTypeDefault, &memory_info); + break; + case DATA_LOCATION_ML_BUFFER: + RETURN_NULLPTR_IF_ERROR(CreateMemoryInfo, "WebNN_Buffer", OrtDeviceAllocator, 0, OrtMemTypeDefault, &memory_info); + break; + default: + RETURN_NULLPTR_IF_ERROR(CreateCpuMemoryInfo, OrtDeviceAllocator, OrtMemTypeDefault, &memory_info); } REGISTER_AUTO_RELEASE_HANDLE(MemoryInfo, memory_info); @@ -418,15 +425,18 @@ int EMSCRIPTEN_KEEPALIVE OrtBindOutput(OrtIoBinding* io_binding, if (output_location != DATA_LOCATION_NONE && output_location != DATA_LOCATION_CPU && output_location != DATA_LOCATION_CPU_PINNED && - output_location != DATA_LOCATION_GPU_BUFFER) { + output_location != DATA_LOCATION_GPU_BUFFER && + output_location != DATA_LOCATION_ML_BUFFER) { std::ostringstream ostr; ostr << "Invalid data location (" << output_location << ") for output: \"" << name << "\"."; return CheckStatus(Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, ostr.str().c_str())); } OrtMemoryInfo* memory_info = nullptr; - if (output_location != DATA_LOCATION_GPU_BUFFER) { + if (output_location != DATA_LOCATION_GPU_BUFFER && output_location != DATA_LOCATION_ML_BUFFER) { RETURN_ERROR_CODE_IF_ERROR(CreateCpuMemoryInfo, OrtDeviceAllocator, OrtMemTypeDefault, &memory_info); + } else if (output_location == DATA_LOCATION_ML_BUFFER) { + RETURN_ERROR_CODE_IF_ERROR(CreateMemoryInfo, "WebNN_Buffer", OrtDeviceAllocator, 0, OrtMemTypeDefault, &memory_info); } else { RETURN_ERROR_CODE_IF_ERROR(CreateMemoryInfo, "WebGPU_Buffer", OrtDeviceAllocator, 0, OrtMemTypeDefault, &memory_info); } diff --git a/onnxruntime/wasm/pre-jsep.js b/onnxruntime/wasm/pre-jsep.js index 1cb7c6f5d8250..7587e4b6196c2 100644 --- a/onnxruntime/wasm/pre-jsep.js +++ b/onnxruntime/wasm/pre-jsep.js @@ -198,5 +198,38 @@ Module['jsepInit'] = (name, params) => { Module['jsepOnRunStart'] = sessionId => { return backend['onRunStart'](sessionId); }; + } else if(name === 'webnn') { + // Functions called from EM_ASM need to be assigned in a way that can be minified. + // Functions called via emscripten::val::module_property need to be assigned by name so that the minifier doesn't + // change the name. + + [Module.jsepBackend, + Module.jsepReserveBufferId, + Module.jsepReleaseBufferId, + Module['jsepEnsureBuffer'], + Module.jsepUploadBuffer, + Module['jsepDownloadBuffer'], + ] = params; + + // This function is called from both JS and an EM_ASM block, it needs both a minifiable name and an explicit name. + Module['jsepReleaseBufferId'] = Module.jsepReleaseBufferId; + + // Functions called from JS also need to have explicit names. + const backend = Module.jsepBackend; + Module['jsepOnRunStart'] = sessionId => { + return backend['onRunStart'](sessionId); + }; + Module['jsepRegisterMLContext'] = (sessionId, mlContext) => { + backend['registerMLContext'](sessionId, mlContext); + }; + Module['jsepOnReleaseSession'] = sessionId => { + backend['onReleaseSession'](sessionId); + }; + Module['jsepCreateMLBufferDownloader'] = (bufferId, type) => { + return backend['createMLBufferDownloader'](bufferId, type); + } + Module['jsepRegisterMLBuffer'] = (buffer) => { + return backend['registerMLBuffer'](buffer); + } } }; From 50d19dc4c762525b522f5d172d0c4e17d4e3cfd1 Mon Sep 17 00:00:00 2001 From: Enrico Galli Date: Thu, 15 Aug 2024 23:03:37 -0700 Subject: [PATCH 02/17] Workaround for ONNXRuntime reusing typed MLBuffers as a different type --- js/web/lib/wasm/jsep/backend-webnn.ts | 48 +++++--- js/web/lib/wasm/jsep/init.ts | 4 +- js/web/lib/wasm/jsep/webnn/buffer-manager.ts | 115 ++++++++++++++---- js/web/lib/wasm/jsep/webnn/webnn.d.ts | 2 +- js/web/lib/wasm/wasm-core-impl.ts | 7 +- js/web/lib/wasm/wasm-types.ts | 22 ++-- .../core/providers/webnn/builders/model.cc | 4 +- .../webnn/webnn_execution_provider.cc | 29 ++++- onnxruntime/wasm/pre-jsep.js | 4 +- 9 files changed, 173 insertions(+), 62 deletions(-) diff --git a/js/web/lib/wasm/jsep/backend-webnn.ts b/js/web/lib/wasm/jsep/backend-webnn.ts index c623f8d4e67d0..5764249de8292 100644 --- a/js/web/lib/wasm/jsep/backend-webnn.ts +++ b/js/web/lib/wasm/jsep/backend-webnn.ts @@ -12,7 +12,8 @@ import { DataType } from '../wasm-common'; import { getInstance } from '../wasm-factory'; import { createView } from './tensor-view'; -import { BufferId, BufferManager, createBufferManager } from './webnn/buffer-manager'; +import { BufferId, createBufferManager } from './webnn/buffer-manager'; +import { LOG_DEBUG } from './log'; /* * TensorProto::data_type to WebNN OperandType mapping. @@ -34,7 +35,10 @@ const onnxDataTypeToWebnnDataType = new Map([ * of the current MLContext being used by the sessions. */ export class WebNNBackend { - private bufferManager: BufferManager = createBufferManager(this); + /** + * Buffer managers for each session. + */ + private bufferManager = createBufferManager(this); /** * Maps from session id to MLContexts. */ @@ -46,16 +50,20 @@ export class WebNNBackend { /** * Current session id. */ - currentSessionId?: number; + private activeSessionId?: number; + + public get currentSessionId(): number { + if (this.activeSessionId === undefined) { + throw new Error('No active session'); + } + return this.activeSessionId; + } public onRunStart(sessionId: number): void { - this.currentSessionId = sessionId; + this.activeSessionId = sessionId; } public get currentContext(): MLContext { - if (this.currentSessionId === undefined) { - throw new Error('No active session'); - } return this.getMLContext(this.currentSessionId); } @@ -96,25 +104,21 @@ export class WebNNBackend { } public releaseBufferId(bufferId: BufferId): void { + LOG_DEBUG('verbose', () => `[WebNN] releaseBufferId {bufferId: ${bufferId}}`); this.bufferManager.releaseBufferId(bufferId); } public async ensureBuffer( bufferId: BufferId, - onnxDataType: number | MLOperandDataType, + onnxDataType: DataType, dimensions: number[], + copyOld: boolean, ): Promise { - let dataType: MLOperandDataType; - if (typeof onnxDataType === 'number') { - const webnnDataType = onnxDataTypeToWebnnDataType.get(onnxDataType)!; - if (!webnnDataType) { - throw new Error(`Unsupported ONNX data type: ${onnxDataType}`); - } - dataType = webnnDataType; - } else { - dataType = onnxDataType; + const webnnDataType = onnxDataTypeToWebnnDataType.get(onnxDataType)!; + if (!webnnDataType) { + throw new Error(`Unsupported ONNX data type: ${onnxDataType}`); } - return this.bufferManager.ensureBuffer(bufferId, dataType, dimensions); + return this.bufferManager.ensureBuffer(bufferId, webnnDataType, dimensions, copyOld); } public uploadBuffer(bufferId: BufferId, data: Uint8Array): void { @@ -136,8 +140,12 @@ export class WebNNBackend { }; } - public registerMLBuffer(buffer: MLBuffer): BufferId { - return this.bufferManager.registerBuffer(this.currentContext, buffer); + public registerMLBuffer(buffer: MLBuffer, onnxDataType: DataType, dimensions: number[]): BufferId { + const webnnDataType = onnxDataTypeToWebnnDataType.get(onnxDataType)!; + if (!webnnDataType) { + throw new Error(`Unsupported ONNX data type: ${onnxDataType}`); + } + return this.bufferManager.registerBuffer(this.currentContext, buffer, webnnDataType, dimensions); } public flush(): void { diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index 33e0b1a5c2133..bbab6c688cad3 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -266,8 +266,8 @@ export const init = async ( // jsepReleaseBufferId, (bufferId: number) => backend.releaseBufferId(bufferId), // jsepEnsureBuffer - async (bufferId: number, onnxDataType: number, dimensions: number[]) => - backend.ensureBuffer(bufferId, onnxDataType, dimensions), + async (bufferId: number, onnxDataType: number, dimensions: number[], copyOld) => + backend.ensureBuffer(bufferId, onnxDataType, dimensions, copyOld), // jsepUploadBuffer (bufferId: number, data: Uint8Array) => { backend.uploadBuffer(bufferId, data); diff --git a/js/web/lib/wasm/jsep/webnn/buffer-manager.ts b/js/web/lib/wasm/jsep/webnn/buffer-manager.ts index fcd980d36a828..d2c7a4ad9b8a0 100644 --- a/js/web/lib/wasm/jsep/webnn/buffer-manager.ts +++ b/js/web/lib/wasm/jsep/webnn/buffer-manager.ts @@ -2,6 +2,7 @@ // Licensed under the MIT License. import { WebNNBackend } from '../backend-webnn'; +import { LOG_DEBUG } from '../log'; // WebNN API currently does not have a TypeScript definition file. This file is a workaround with types generated from // WebNN API specification. @@ -25,7 +26,12 @@ export interface BufferManager { /** * Ensure a MLBuffer is created for the BufferId. */ - ensureBuffer(bufferId: BufferId, dataType: MLOperandDataType, dimensions: number[]): Promise; + ensureBuffer( + bufferId: BufferId, + dataType: MLOperandDataType, + dimensions: readonly number[], + copyOld: boolean, + ): Promise; /** * Upload data to a MLBuffer. */ @@ -41,12 +47,14 @@ export interface BufferManager { /** * Register an externally created MLBuffer with a given MLContext and return a BufferId. */ - registerBuffer(mlContext: MLContext, mlBuffer: MLBuffer): BufferId; + registerBuffer(mlContext: MLContext, mlBuffer: MLBuffer, dataType: MLOperandDataType, dimensions: number[]): BufferId; } let bufferGuid = 1; const createNewBufferId = (): BufferId => bufferGuid++; +export type MLBufferEntry = [MLBuffer, MLOperandDataType, readonly number[]]; + /** * BufferTracker tracks the MLBuffer and pending upload data. * @@ -55,18 +63,20 @@ const createNewBufferId = (): BufferId => bufferGuid++; * MLBuffers with dataTypes and dimensions. */ class BufferTracker { - private mlBuffer?: MLBuffer; + private bufferEntry?: MLBufferEntry; private activeUpload?: Uint8Array; + private bufferCache: MLBufferEntry[]; constructor( private mlContext?: MLContext, - buffer?: MLBuffer, + bufferEntry?: MLBufferEntry, ) { - this.mlBuffer = buffer; + this.bufferEntry = bufferEntry; + this.bufferCache = bufferEntry ? [bufferEntry] : []; } public get buffer(): MLBuffer | undefined { - return this.mlBuffer; + return this.bufferEntry?.[0]; } public get context(): MLContext { @@ -84,17 +94,60 @@ class BufferTracker { } public destroy(): void { - this.mlBuffer?.destroy(); - this.mlBuffer = undefined; + for (const [mlBuffer] of this.bufferCache) { + mlBuffer.destroy(); + } + this.bufferCache = []; + this.bufferEntry = undefined; + } + + public trySelectBuffer(context: MLContext, tryMlBuffer: MLBuffer): boolean { + for (const [mlBuffer, dataType, dimensions] of this.bufferCache) { + if (tryMlBuffer === mlBuffer) { + if (this.context !== context) { + throw new Error('MLBuffer cannot be registered with a different MLContext.'); + } + this.bufferEntry = [mlBuffer, dataType, dimensions]; + return true; + } + } + return false; } - public async ensureBuffer(dataType: MLOperandDataType, dimensions: number[]): Promise { - if (this.mlBuffer) { - return this.mlBuffer; + public async ensureBuffer( + dataType: MLOperandDataType, + dimensions: readonly number[], + copyOld: boolean, + ): Promise { + if (this.bufferEntry) { + const [mlBuffer, existingDataType, existingDimensions] = this.bufferEntry; + if (existingDataType === dataType && existingDimensions.every((v, i) => v === dimensions[i])) { + return mlBuffer; + } } + for (const [mlBuffer, existingDataType, existingDimensions] of this.bufferCache) { + if (existingDataType === dataType && existingDimensions.every((v, i) => v === dimensions[i])) { + if (copyOld && this.bufferEntry) { + // WebNN does not support copyBufferToBuffer, so we need to read and write the buffers. + LOG_DEBUG( + 'verbose', + () => + `[WebNN] Slowdown may occur, having to copy existing buffer {dataType: ${ + dataType + }, dimensions: ${dimensions}}`, + ); + const data = await this.context.readBuffer(this.bufferEntry[0]); + this.context.writeBuffer(mlBuffer, data); + } + this.bufferEntry = [mlBuffer, existingDataType, existingDimensions]; + return mlBuffer; + } + } + LOG_DEBUG('verbose', () => `[WebNN] createBuffer {dataType: ${dataType}, dimensions: ${dimensions}}`); const buffer = await this.context.createBuffer({ dataType, dimensions }); - this.mlBuffer = buffer; + this.bufferEntry = [buffer, dataType, dimensions]; + this.bufferCache.push(this.bufferEntry); if (this.activeUpload) { this.mlContext?.writeBuffer(buffer, this.activeUpload); @@ -105,22 +158,22 @@ class BufferTracker { } public upload(data: Uint8Array): void { - if (!this.mlBuffer) { + if (!this.bufferEntry) { this.activeUpload = new Uint8Array(data); return; } - this.mlContext?.writeBuffer(this.mlBuffer, data); + this.mlContext?.writeBuffer(this.bufferEntry[0], data); } public async download(): Promise { if (this.activeUpload) { return this.activeUpload.buffer; } - if (!this.mlBuffer) { + if (!this.bufferEntry) { throw new Error('Buffer has not been created.'); } - return this.context.readBuffer(this.mlBuffer); + return this.context.readBuffer(this.bufferEntry[0]); } } @@ -154,7 +207,19 @@ class BufferManagerImpl implements BufferManager { } } - public async ensureBuffer(bufferId: BufferId, dataType: MLOperandDataType, dimensions: number[]): Promise { + public async ensureBuffer( + bufferId: BufferId, + dataType: MLOperandDataType, + dimensions: number[], + copyOld: boolean, + ): Promise { + LOG_DEBUG( + 'verbose', + () => + `[WebNN] BufferManager.ensureBuffer {bufferId: ${bufferId}, dataType: ${ + dataType + }, dimensions: ${dimensions}}, copyOld: ${copyOld}`, + ); const buffer = this.buffersById.get(bufferId); if (!buffer) { throw new Error('Buffer not found.'); @@ -164,7 +229,7 @@ class BufferManagerImpl implements BufferManager { this.bufferIdsByContext.set(this.backend.currentContext, new Set()); } this.bufferIdsByContext.get(this.backend.currentContext)?.add(bufferId); - return buffer.ensureBuffer(dataType, dimensions); + return buffer.ensureBuffer(dataType, dimensions, copyOld); } public upload(bufferId: BufferId, data: Uint8Array): void { @@ -187,17 +252,19 @@ class BufferManagerImpl implements BufferManager { this.bufferIdsByContext.delete(mlContext); } - public registerBuffer(mlContext: MLContext, mlBuffer: MLBuffer): BufferId { + public registerBuffer( + mlContext: MLContext, + mlBuffer: MLBuffer, + dataType: MLOperandDataType, + dimensions: readonly number[], + ): BufferId { for (const [bufferId, bufferTracker] of this.buffersById) { - if (bufferTracker.buffer === mlBuffer) { - if (bufferTracker.context !== mlContext) { - throw new Error('MLBuffer cannot be registered with a different MLContext.'); - } + if (bufferTracker.trySelectBuffer(mlContext, mlBuffer)) { return bufferId; } } const bufferId = createNewBufferId(); - this.buffersById.set(bufferId, new BufferTracker(mlContext, mlBuffer)); + this.buffersById.set(bufferId, new BufferTracker(mlContext, [mlBuffer, dataType, dimensions])); let buffers = this.bufferIdsByContext.get(mlContext); if (!buffers) { buffers = new Set(); diff --git a/js/web/lib/wasm/jsep/webnn/webnn.d.ts b/js/web/lib/wasm/jsep/webnn/webnn.d.ts index 17bd3b6243342..0e765c714e1e7 100644 --- a/js/web/lib/wasm/jsep/webnn/webnn.d.ts +++ b/js/web/lib/wasm/jsep/webnn/webnn.d.ts @@ -30,7 +30,7 @@ type MLInputOperandLayout = 'nchw'|'nhwc'; type MLOperandDataType = 'float32'|'float16'|'int32'|'uint32'|'int64'|'uint64'|'int8'|'uint8'; interface MLOperandDescriptor { dataType: MLOperandDataType; - dimensions?: number[]; + dimensions?: readonly number[]; } interface MLOperand { dataType(): MLOperandDataType; diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index ef6fb87c41fe0..250c9358fe8d5 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -491,7 +491,7 @@ export const prepareInputOutputTensor = ( if (!registerMLBuffer) { throw new Error('Tensor location "ml-buffer" is not supported without using WebNN.'); } - rawData = registerMLBuffer(mlBuffer); + rawData = registerMLBuffer(mlBuffer, tensorDataTypeStringToEnum(dataType), dims); } else { const data = tensor[2]; @@ -793,8 +793,9 @@ export const run = async ( } // If the graph has been partitioned, the output tensor may have not been created. For this reason, we use - // ensureBuffer to get/create the MLBuffer. - const mlBuffer = await ensureBuffer(dataOffset, dataType, dims); + // ensureBuffer to get/create the MLBuffer. In which case, we don't need to copy the data if a new buffer is + // created. + const mlBuffer = await ensureBuffer(dataOffset, dataType, dims, false); // do not release the tensor right now. it will be released when user calls tensor.dispose(). keepOutputTensor = true; diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts index 8cbe111ff373a..bcb0fa44eb087 100644 --- a/js/web/lib/wasm/wasm-types.ts +++ b/js/web/lib/wasm/wasm-types.ts @@ -7,6 +7,7 @@ /// import type { Tensor } from 'onnxruntime-common'; +import { DataType } from './wasm-common'; /* eslint-disable @typescript-eslint/naming-convention */ @@ -31,8 +32,9 @@ export declare namespace JSEP { type ReleaseBufferIdFunction = (bufferId: number) => void; type EnsureBufferFunction = ( bufferId: number, - dataType: number | MLOperandDataType, - dimensions: number[], + dataType: DataType, + dimensions: readonly number[], + copyOld: boolean, ) => Promise; type UploadBufferFunction = (bufferId: number, data: Uint8Array) => void; type DownloadBufferFunction = (bufferId: number) => Promise; @@ -178,14 +180,18 @@ export declare namespace JSEP { */ jsepReleaseBufferId: (bufferId: number) => void; /** - * [exported from pre-jsep.js] Get MLBuffer by ID. - * @param bufferId - specify the MLBuffer ID. - * @returns the MLBuffer. + * [exported from pre-jsep.js] Ensure a MLBuffer of a given type and shape has exists for a buffer ID. + * @param bufferId - specify the buffer ID. + * @param onnxDataType - specify the data type. + * @param dimensions - specify the dimensions. + * @param copyOld - specify whether to copy the old buffer, it . + * @returns the MLBuffer associated with the buffer ID. */ jsepEnsureBuffer: ( bufferId: number, - dataType: number | MLOperandDataType, + dataType: DataType, dimensions: number[], + copyOld: boolean, ) => Promise; /** * [exported from pre-jsep.js] Upload data to MLBuffer. @@ -214,9 +220,11 @@ export declare namespace JSEP { /** * [exported from pre-jsep.js] Register MLBuffer for a session. * @param mlBuffer - specify the MLBuffer. + * @param dataType - specify the data type. + * @param dimensions - specify the dimensions. * @returns the MLBuffer ID. */ - jsepRegisterMLBuffer: (buffer: MLBuffer) => number; + jsepRegisterMLBuffer: (buffer: MLBuffer, onnxDataType: DataType, dimensions: readonly number[]) => number; } } diff --git a/onnxruntime/core/providers/webnn/builders/model.cc b/onnxruntime/core/providers/webnn/builders/model.cc index ba84a5d6c56fd..3e048ed278078 100644 --- a/onnxruntime/core/providers/webnn/builders/model.cc +++ b/onnxruntime/core/providers/webnn/builders/model.cc @@ -162,7 +162,7 @@ onnxruntime::common::Status Model::Dispatch(const InlinedHashMap(dim); shape.call("push", dim_val); } - auto buffer = jsepEnsureBuffer(reinterpret_cast(tensor.buffer), tensor.tensor_info.data_type, shape); + auto buffer = jsepEnsureBuffer(reinterpret_cast(tensor.buffer), tensor.tensor_info.data_type, shape, true); promises.call("push", buffer); } for (const auto& [_, tensor] : outputs) { @@ -171,7 +171,7 @@ onnxruntime::common::Status Model::Dispatch(const InlinedHashMap(dim); shape.call("push", dim_val); } - auto buffer = jsepEnsureBuffer(reinterpret_cast(tensor.buffer), tensor.tensor_info.data_type, shape); + auto buffer = jsepEnsureBuffer(reinterpret_cast(tensor.buffer), tensor.tensor_info.data_type, shape, false); promises.call("push", buffer); } auto buffers = emscripten::val::global("Promise").call("all", promises).await(); diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index c3280ee3855d1..539d563b8c45e 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -5,6 +5,7 @@ #include "webnn_execution_provider.h" #include "core/framework/compute_capability.h" +#include "core/framework/data_transfer_manager.h" #include "core/framework/memcpy.h" #include "core/framework/kernel_registry.h" #include "core/graph/graph_viewer.h" @@ -329,6 +330,32 @@ common::Status WebNNExecutionProvider::Compile(const std::vectorInput(0); + ORT_ENFORCE(X != nullptr, "Memcpy: input tensor is null"); + auto* Y = context->Output(0, X->Shape()); + ORT_ENFORCE(X != nullptr, "Memcpy: output tensor is null"); + emscripten::val shape = emscripten::val::array(); + for (auto dim : X->Shape().GetDims()) { + shape.call("push", SafeInt(dim).Ref()); + } + + jsepEnsureBuffer(reinterpret_cast(Y->MutableDataRaw()), + Y->GetElementType(), + shape, false) + .await(); + + const auto* data_transfer = Info().GetDataTransferManager().GetDataTransfer(X->Location().device, Y->Location().device); + + return data_transfer->CopyTensor(*X, *Y); + } +}; + ONNX_OPERATOR_KERNEL_EX( MemcpyFromHost, kOnnxDomain, @@ -337,7 +364,7 @@ ONNX_OPERATOR_KERNEL_EX( KernelDefBuilder() .InputMemoryType(OrtMemTypeCPUInput, 0) .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()), - Memcpy); + WebNNMemcpy); ONNX_OPERATOR_KERNEL_EX( MemcpyToHost, diff --git a/onnxruntime/wasm/pre-jsep.js b/onnxruntime/wasm/pre-jsep.js index 7587e4b6196c2..94b1ba381bf5d 100644 --- a/onnxruntime/wasm/pre-jsep.js +++ b/onnxruntime/wasm/pre-jsep.js @@ -228,8 +228,8 @@ Module['jsepInit'] = (name, params) => { Module['jsepCreateMLBufferDownloader'] = (bufferId, type) => { return backend['createMLBufferDownloader'](bufferId, type); } - Module['jsepRegisterMLBuffer'] = (buffer) => { - return backend['registerMLBuffer'](buffer); + Module['jsepRegisterMLBuffer'] = (buffer, dataType, dimensions) => { + return backend['registerMLBuffer'](buffer, dataType, dimensions); } } }; From ed2f366812fc3d4275c85478f2c5b09569276aed Mon Sep 17 00:00:00 2001 From: Enrico Galli Date: Fri, 16 Aug 2024 10:27:53 -0700 Subject: [PATCH 03/17] PR feedback --- js/web/lib/wasm/jsep/backend-webnn.ts | 33 +++++++++++--------- js/web/lib/wasm/jsep/init.ts | 2 +- js/web/lib/wasm/jsep/webnn/buffer-manager.ts | 4 +-- js/web/lib/wasm/wasm-types.ts | 2 +- 4 files changed, 23 insertions(+), 18 deletions(-) diff --git a/js/web/lib/wasm/jsep/backend-webnn.ts b/js/web/lib/wasm/jsep/backend-webnn.ts index 5764249de8292..6e4ffcaf4a5a3 100644 --- a/js/web/lib/wasm/jsep/backend-webnn.ts +++ b/js/web/lib/wasm/jsep/backend-webnn.ts @@ -6,14 +6,14 @@ // https://github.com/webmachinelearning/webnn/issues/677 /// -import { Tensor } from 'onnxruntime-common'; +import { Env, Tensor } from 'onnxruntime-common'; import { DataType } from '../wasm-common'; import { getInstance } from '../wasm-factory'; import { createView } from './tensor-view'; import { BufferId, createBufferManager } from './webnn/buffer-manager'; -import { LOG_DEBUG } from './log'; +import { configureLogger, LOG_DEBUG } from './log'; /* * TensorProto::data_type to WebNN OperandType mapping. @@ -52,6 +52,10 @@ export class WebNNBackend { */ private activeSessionId?: number; + constructor(env: Env) { + configureLogger(env.logLevel!, !!env.debug); + } + public get currentSessionId(): number { if (this.activeSessionId === undefined) { throw new Error('No active session'); @@ -64,7 +68,11 @@ export class WebNNBackend { } public get currentContext(): MLContext { - return this.getMLContext(this.currentSessionId); + const mlContext = this.getMLContext(this.currentSessionId); + if (!mlContext) { + throw new Error(`No MLContext found for session ${this.currentSessionId}`); + } + return mlContext; } public registerMLContext(sessionId: number, mlContext: MLContext): void { @@ -77,26 +85,23 @@ export class WebNNBackend { sessionIds.add(sessionId); } - public unregisterMLContext(sessionId: number): void { + public onReleaseSession(sessionId: number): void { const mlContext = this.mlContextBySessionId.get(sessionId)!; if (!mlContext) { - throw new Error(`No MLContext found for session ${sessionId}`); + // Current session is not a WebNN session. + return; } this.mlContextBySessionId.delete(sessionId); const sessionIds = this.sessionIdsByMLContext.get(mlContext)!; sessionIds.delete(sessionId); if (sessionIds.size === 0) { this.sessionIdsByMLContext.delete(mlContext); + this.bufferManager.releaseBuffersForContext(mlContext); } } - public onReleaseSession(sessionId: number): void { - this.unregisterMLContext(sessionId); - this.bufferManager.releaseBuffersForContext(this.getMLContext(sessionId)); - } - - public getMLContext(sessionId: number): MLContext { - return this.mlContextBySessionId.get(sessionId)!; + public getMLContext(sessionId: number): MLContext | undefined { + return this.mlContextBySessionId.get(sessionId); } public reserveBufferId(): BufferId { @@ -114,7 +119,7 @@ export class WebNNBackend { dimensions: number[], copyOld: boolean, ): Promise { - const webnnDataType = onnxDataTypeToWebnnDataType.get(onnxDataType)!; + const webnnDataType = onnxDataTypeToWebnnDataType.get(onnxDataType); if (!webnnDataType) { throw new Error(`Unsupported ONNX data type: ${onnxDataType}`); } @@ -141,7 +146,7 @@ export class WebNNBackend { } public registerMLBuffer(buffer: MLBuffer, onnxDataType: DataType, dimensions: number[]): BufferId { - const webnnDataType = onnxDataTypeToWebnnDataType.get(onnxDataType)!; + const webnnDataType = onnxDataTypeToWebnnDataType.get(onnxDataType); if (!webnnDataType) { throw new Error(`Unsupported ONNX data type: ${onnxDataType}`); } diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index bbab6c688cad3..786e004bcd06c 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -258,7 +258,7 @@ export const init = async ( () => backend.replay(), ]); } else { - const backend = new WebNNBackend(); + const backend = new WebNNBackend(env); jsepInit('webnn', [ backend, // jsepReserveBufferId diff --git a/js/web/lib/wasm/jsep/webnn/buffer-manager.ts b/js/web/lib/wasm/jsep/webnn/buffer-manager.ts index d2c7a4ad9b8a0..49d2eac7a99ec 100644 --- a/js/web/lib/wasm/jsep/webnn/buffer-manager.ts +++ b/js/web/lib/wasm/jsep/webnn/buffer-manager.ts @@ -144,7 +144,7 @@ class BufferTracker { return mlBuffer; } } - LOG_DEBUG('verbose', () => `[WebNN] createBuffer {dataType: ${dataType}, dimensions: ${dimensions}}`); + LOG_DEBUG('verbose', () => `[WebNN] MLContext.createBuffer {dataType: ${dataType}, dimensions: ${dimensions}}`); const buffer = await this.context.createBuffer({ dataType, dimensions }); this.bufferEntry = [buffer, dataType, dimensions]; this.bufferCache.push(this.bufferEntry); @@ -218,7 +218,7 @@ class BufferManagerImpl implements BufferManager { () => `[WebNN] BufferManager.ensureBuffer {bufferId: ${bufferId}, dataType: ${ dataType - }, dimensions: ${dimensions}}, copyOld: ${copyOld}`, + }, dimensions: ${dimensions}, copyOld: ${copyOld}}`, ); const buffer = this.buffersById.get(bufferId); if (!buffer) { diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts index bcb0fa44eb087..9f10d695f191c 100644 --- a/js/web/lib/wasm/wasm-types.ts +++ b/js/web/lib/wasm/wasm-types.ts @@ -184,7 +184,7 @@ export declare namespace JSEP { * @param bufferId - specify the buffer ID. * @param onnxDataType - specify the data type. * @param dimensions - specify the dimensions. - * @param copyOld - specify whether to copy the old buffer, it . + * @param copyOld - specify whether to copy the old buffer if a new buffer was created. * @returns the MLBuffer associated with the buffer ID. */ jsepEnsureBuffer: ( From f30218c2908f6d8f387bae7cb38cd0846ee6d380 Mon Sep 17 00:00:00 2001 From: Enrico Galli Date: Thu, 22 Aug 2024 11:01:04 -0700 Subject: [PATCH 04/17] Updating MLBuffer API --- js/web/lib/wasm/jsep/webnn/buffer-manager.ts | 4 +++- js/web/lib/wasm/jsep/webnn/webnn.d.ts | 18 +++++++++++++++++- js/web/test/test-runner.ts | 12 ++++++++++-- 3 files changed, 30 insertions(+), 4 deletions(-) diff --git a/js/web/lib/wasm/jsep/webnn/buffer-manager.ts b/js/web/lib/wasm/jsep/webnn/buffer-manager.ts index 49d2eac7a99ec..af192c2d58a3d 100644 --- a/js/web/lib/wasm/jsep/webnn/buffer-manager.ts +++ b/js/web/lib/wasm/jsep/webnn/buffer-manager.ts @@ -145,7 +145,9 @@ class BufferTracker { } } LOG_DEBUG('verbose', () => `[WebNN] MLContext.createBuffer {dataType: ${dataType}, dimensions: ${dimensions}}`); - const buffer = await this.context.createBuffer({ dataType, dimensions }); + // eslint-disable-next-line no-bitwise + const usage = MLBufferUsage.READ_FROM | MLBufferUsage.WRITE_TO; + const buffer = await this.context.createBuffer({ dataType, dimensions, usage }); this.bufferEntry = [buffer, dataType, dimensions]; this.bufferCache.push(this.bufferEntry); diff --git a/js/web/lib/wasm/jsep/webnn/webnn.d.ts b/js/web/lib/wasm/jsep/webnn/webnn.d.ts index 0e765c714e1e7..ca1baba257719 100644 --- a/js/web/lib/wasm/jsep/webnn/webnn.d.ts +++ b/js/web/lib/wasm/jsep/webnn/webnn.d.ts @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +/* eslint-disable @typescript-eslint/naming-convention */ + interface NavigatorML { readonly ml: ML; } @@ -386,11 +388,25 @@ interface MLBuffer { } type MLNamedBuffers = Record; + +type MLBufferUsageFlags = number; + +declare const MLBufferUsage: { + readonly WEBGPU_INTEROP: MLBufferUsageFlags; + readonly READ_FROM: MLBufferUsageFlags; + readonly WRITE_TO: MLBufferUsageFlags; +}; + +interface MLBufferDescriptor extends MLOperandDescriptor { + usage: MLBufferUsageFlags; +} + interface MLContext { - createBuffer(descriptor: MLOperandDescriptor): Promise; + createBuffer(descriptor: MLBufferDescriptor): Promise; writeBuffer( dstBuffer: MLBuffer, srcData: ArrayBufferView|ArrayBuffer, srcElementOffset?: number, srcElementSize?: number): void; readBuffer(srcBuffer: MLBuffer): Promise; + readBuffer(srcBuffer: MLBuffer, dstBuffer: ArrayBuffer): Promise; dispatch(graph: MLGraph, inputs: MLNamedBuffers, outputs: MLNamedBuffers): void; } diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index e4b38827f874d..4d5ea53c7f860 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -661,7 +661,11 @@ async function createMLTensorForOutput(mlContext: MLContext, type: ort.Tensor.Ty const dataType = type === 'bool' ? 'uint8' : type; - const mlBuffer = await mlContext.createBuffer({ dataType, dimensions: dims as number[] }); + const mlBuffer = await mlContext.createBuffer({ + dataType, + dimensions: dims as number[], + usage: MLBufferUsage.READ_FROM, + }); return ort.Tensor.fromMLBuffer(mlBuffer, { dataType: type, @@ -679,7 +683,11 @@ async function createMLTensorForInput(mlContext: MLContext, cpuTensor: ort.Tenso throw new Error(`createMLTensorForInput can not work with ${cpuTensor.type} tensor`); } const dataType = cpuTensor.type === 'bool' ? 'uint8' : cpuTensor.type; - const mlBuffer = await mlContext.createBuffer({ dataType, dimensions: cpuTensor.dims as number[] }); + const mlBuffer = await mlContext.createBuffer({ + dataType, + dimensions: cpuTensor.dims as number[], + usage: MLBufferUsage.WRITE_TO, + }); mlContext.writeBuffer(mlBuffer, cpuTensor.data); return ort.Tensor.fromMLBuffer(mlBuffer, { dataType: cpuTensor.type, From d27afb38ef5758f852ab9985f63014263899b70a Mon Sep 17 00:00:00 2001 From: Enrico Galli Date: Tue, 27 Aug 2024 16:56:24 -0700 Subject: [PATCH 05/17] Switching to readBuffer into ArrayBufferView --- js/web/lib/wasm/jsep/backend-webnn.ts | 4 +-- js/web/lib/wasm/jsep/init.ts | 2 +- js/web/lib/wasm/jsep/webnn/buffer-manager.ts | 25 ++++++++++++++++--- js/web/lib/wasm/jsep/webnn/webnn.d.ts | 2 +- js/web/lib/wasm/wasm-types.ts | 9 +++++-- .../core/providers/webnn/data_transfer.cc | 7 ++---- 6 files changed, 34 insertions(+), 15 deletions(-) diff --git a/js/web/lib/wasm/jsep/backend-webnn.ts b/js/web/lib/wasm/jsep/backend-webnn.ts index 6e4ffcaf4a5a3..446bcb8a11bbb 100644 --- a/js/web/lib/wasm/jsep/backend-webnn.ts +++ b/js/web/lib/wasm/jsep/backend-webnn.ts @@ -134,8 +134,8 @@ export class WebNNBackend { this.bufferManager.upload(bufferId, data); } - public async downloadBuffer(bufferId: BufferId): Promise { - return this.bufferManager.download(bufferId); + public async downloadBuffer(bufferId: BufferId, dstBuffer: ArrayBufferView | ArrayBuffer): Promise { + return this.bufferManager.download(bufferId, dstBuffer); } public createMLBufferDownloader(bufferId: BufferId, type: Tensor.MLBufferDataTypes): () => Promise { diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index 786e004bcd06c..e0363912e5c14 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -273,7 +273,7 @@ export const init = async ( backend.uploadBuffer(bufferId, data); }, // jsepDownloadBuffer - async (bufferId: number) => backend.downloadBuffer(bufferId), + async (bufferId: number, dstBuffer: ArrayBufferView | ArrayBuffer) => backend.downloadBuffer(bufferId, dstBuffer), ]); } }; diff --git a/js/web/lib/wasm/jsep/webnn/buffer-manager.ts b/js/web/lib/wasm/jsep/webnn/buffer-manager.ts index af192c2d58a3d..8338cddbba6ba 100644 --- a/js/web/lib/wasm/jsep/webnn/buffer-manager.ts +++ b/js/web/lib/wasm/jsep/webnn/buffer-manager.ts @@ -40,6 +40,7 @@ export interface BufferManager { * Download data from a MLBuffer. */ download(bufferId: BufferId): Promise; + download(bufferId: BufferId, dstBuffer: ArrayBufferView | ArrayBuffer): Promise; /** * Release all buffers for a MLContext. */ @@ -168,13 +169,26 @@ class BufferTracker { this.mlContext?.writeBuffer(this.bufferEntry[0], data); } - public async download(): Promise { + public async download(dstBuffer?: ArrayBufferView | ArrayBuffer): Promise { if (this.activeUpload) { - return this.activeUpload.buffer; + if (dstBuffer) { + if (dstBuffer instanceof ArrayBuffer) { + new Uint8Array(dstBuffer).set(this.activeUpload); + } else { + new Uint8Array(dstBuffer.buffer, dstBuffer.byteOffset, dstBuffer.byteLength).set(this.activeUpload); + } + + return; + } else { + return this.activeUpload.buffer; + } } if (!this.bufferEntry) { throw new Error('Buffer has not been created.'); } + if (dstBuffer) { + return this.context.readBuffer(this.bufferEntry[0], dstBuffer); + } return this.context.readBuffer(this.bufferEntry[0]); } } @@ -238,8 +252,11 @@ class BufferManagerImpl implements BufferManager { this.buffersById.get(bufferId)!.upload(data); } - public async download(bufferId: BufferId): Promise { - return this.buffersById.get(bufferId)!.download(); + public async download(bufferId: BufferId): Promise; + public async download(bufferId: BufferId, dstBuffer: ArrayBufferView | ArrayBuffer): Promise; + async download(bufferId: BufferId, dstBuffer?: ArrayBufferView | ArrayBuffer): Promise { + LOG_DEBUG('verbose', () => `[WebNN] BufferManager.download {bufferId: ${bufferId}, dstBuffer: ${dstBuffer}}`); + return this.buffersById.get(bufferId)!.download(dstBuffer); } public releaseBuffersForContext(mlContext: MLContext): void { diff --git a/js/web/lib/wasm/jsep/webnn/webnn.d.ts b/js/web/lib/wasm/jsep/webnn/webnn.d.ts index ca1baba257719..2376b0ce96634 100644 --- a/js/web/lib/wasm/jsep/webnn/webnn.d.ts +++ b/js/web/lib/wasm/jsep/webnn/webnn.d.ts @@ -407,6 +407,6 @@ interface MLContext { dstBuffer: MLBuffer, srcData: ArrayBufferView|ArrayBuffer, srcElementOffset?: number, srcElementSize?: number): void; readBuffer(srcBuffer: MLBuffer): Promise; - readBuffer(srcBuffer: MLBuffer, dstBuffer: ArrayBuffer): Promise; + readBuffer(srcBuffer: MLBuffer, dstBuffer: ArrayBufferView|ArrayBuffer): Promise; dispatch(graph: MLGraph, inputs: MLNamedBuffers, outputs: MLNamedBuffers): void; } diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts index 9f10d695f191c..af623788af606 100644 --- a/js/web/lib/wasm/wasm-types.ts +++ b/js/web/lib/wasm/wasm-types.ts @@ -37,7 +37,7 @@ export declare namespace JSEP { copyOld: boolean, ) => Promise; type UploadBufferFunction = (bufferId: number, data: Uint8Array) => void; - type DownloadBufferFunction = (bufferId: number) => Promise; + type DownloadBufferFunction = (bufferId: number, dstBuffer: ArrayBufferView | ArrayBuffer) => Promise; export interface Module extends WebGpuModule, WebNnModule { /** @@ -206,7 +206,12 @@ export declare namespace JSEP { * @param bufferId - specify the MLBuffer ID. * @returns the downloaded data. */ - jsepDownloadBuffer: (bufferId: number) => Promise; + jsepDownloadBuffer: (bufferId: number, dstBuffer: ArrayBufferView | ArrayBuffer) => Promise; + /** + * [exported from pre-jsep.js] Download data from MLBuffer. + * @param bufferId - specify the MLBuffer ID. + * @returns the downloaded data. + */ /** * [exported from pre-jsep.js] Create a downloader function to download data from MLBuffer. * @param bufferId - specify the MLBuffer ID. diff --git a/onnxruntime/core/providers/webnn/data_transfer.cc b/onnxruntime/core/providers/webnn/data_transfer.cc index 5644de25fd306..833b5ecb7838c 100644 --- a/onnxruntime/core/providers/webnn/data_transfer.cc +++ b/onnxruntime/core/providers/webnn/data_transfer.cc @@ -32,11 +32,8 @@ common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { EM_ASM({ Module.jsepUploadBuffer($0, HEAPU8.subarray($1, $1 + $2)); }, dst_data, reinterpret_cast(src_data), bytes); } else { auto jsepDownloadBuffer = emscripten::val::module_property("jsepDownloadBuffer"); - auto buffer = jsepDownloadBuffer(reinterpret_cast(src_data)).await(); - EM_ASM({ - const buffer = Emval.toValue($0); - const src_array = new Uint8Array(buffer, 0, $2); - HEAPU8.set(src_array, $1); }, buffer.as_handle(), reinterpret_cast(dst_data), bytes); + auto subarray = emscripten::typed_memory_view(bytes, static_cast(dst_data)); + jsepDownloadBuffer(reinterpret_cast(src_data), subarray).await(); } } From 66aec8ead28b186f3bc320a198bf6c7d76c5cede Mon Sep 17 00:00:00 2001 From: Enrico Galli Date: Mon, 9 Sep 2024 18:21:10 -0700 Subject: [PATCH 06/17] Rename MLBuffer to MLTensor --- js/common/lib/tensor-factory-impl.ts | 12 +- js/common/lib/tensor-factory.ts | 26 +- js/common/lib/tensor-impl.ts | 42 +-- js/common/lib/tensor-utils-impl.ts | 10 +- js/common/lib/tensor.ts | 18 +- js/web/lib/wasm/jsep/backend-webnn.ts | 58 ++-- js/web/lib/wasm/jsep/init.ts | 10 +- js/web/lib/wasm/jsep/webnn/buffer-manager.ts | 298 ----------------- js/web/lib/wasm/jsep/webnn/tensor-manager.ts | 300 ++++++++++++++++++ js/web/lib/wasm/jsep/webnn/webnn.d.ts | 30 +- js/web/lib/wasm/proxy-messages.ts | 12 +- js/web/lib/wasm/session-handler-inference.ts | 16 +- js/web/lib/wasm/wasm-common.ts | 8 +- js/web/lib/wasm/wasm-core-impl.ts | 54 ++-- js/web/lib/wasm/wasm-types.ts | 91 +++--- js/web/script/test-runner-cli-args.ts | 4 +- js/web/test/test-runner.ts | 30 +- js/web/test/test-types.ts | 4 +- onnxruntime/core/providers/webnn/allocator.cc | 8 +- .../core/providers/webnn/builders/helper.cc | 4 +- .../core/providers/webnn/builders/helper.h | 2 +- .../core/providers/webnn/builders/model.cc | 8 +- .../providers/webnn/builders/model_builder.cc | 2 +- .../core/providers/webnn/data_transfer.cc | 12 +- .../webnn/webnn_execution_provider.cc | 8 +- onnxruntime/wasm/api.cc | 12 +- onnxruntime/wasm/pre-jsep.js | 20 +- 27 files changed, 553 insertions(+), 546 deletions(-) delete mode 100644 js/web/lib/wasm/jsep/webnn/buffer-manager.ts create mode 100644 js/web/lib/wasm/jsep/webnn/tensor-manager.ts diff --git a/js/common/lib/tensor-factory-impl.ts b/js/common/lib/tensor-factory-impl.ts index 38e3d841146f9..7b32858ac0f02 100644 --- a/js/common/lib/tensor-factory-impl.ts +++ b/js/common/lib/tensor-factory-impl.ts @@ -11,7 +11,7 @@ import { TensorFromImageBitmapOptions, TensorFromImageDataOptions, TensorFromImageElementOptions, - TensorFromMLBufferOptions, + TensorFromMLTensorOptions, TensorFromTextureOptions, TensorFromUrlOptions, } from './tensor-factory.js'; @@ -312,14 +312,14 @@ export const tensorFromGpuBuffer = ( - mlBuffer: TensorInterface.MLBufferType, - options: TensorFromMLBufferOptions, +export const tensorFromMLTensor = ( + mlTensor: TensorInterface.MLTensorType, + options: TensorFromMLTensorOptions, ): Tensor => { const { dataType, dims, download, dispose } = options; - return new Tensor({ location: 'ml-buffer', type: dataType ?? 'float32', mlBuffer, dims, download, dispose }); + return new Tensor({ location: 'ml-tensor', type: dataType ?? 'float32', mlTensor, dims, download, dispose }); }; /** diff --git a/js/common/lib/tensor-factory.ts b/js/common/lib/tensor-factory.ts index 95e822535a787..784e0d1a609d4 100644 --- a/js/common/lib/tensor-factory.ts +++ b/js/common/lib/tensor-factory.ts @@ -86,18 +86,18 @@ export interface GpuBufferConstructorParameters +export interface MLTensorConstructorParameters extends CommonConstructorParameters, GpuResourceConstructorParameters { /** - * Specify the location of the data to be 'ml-buffer'. + * Specify the location of the data to be 'ml-tensor'. */ - readonly location: 'ml-buffer'; + readonly location: 'ml-tensor'; /** * Specify the WebNN buffer that holds the tensor data. */ - readonly mlBuffer: Tensor.MLBufferType; + readonly mlTensor: Tensor.MLTensorType; } // #endregion @@ -233,7 +233,7 @@ export interface TensorFromGpuBufferOptions dataType?: T; } -export interface TensorFromMLBufferOptions +export interface TensorFromMLTensorOptions extends Pick, GpuResourceConstructorParameters { /** @@ -360,26 +360,26 @@ export interface TensorFactory { ): TypedTensor; /** - * create a tensor from a WebNN MLBuffer + * create a tensor from a WebNN MLTensor * - * @param buffer - the MLBuffer object to create tensor from - * @param options - An optional object representing options for creating tensor from a WebNN MLBuffer. + * @param buffer - the MLTensor object to create tensor from + * @param options - An optional object representing options for creating tensor from a WebNN MLTensor. * * The options include following properties: * - `dataType`: the data type of the tensor. If omitted, assume 'float32'. * - `dims`: the dimension of the tensor. Required. - * - `download`: an optional function to download the tensor data from the MLBuffer to CPU. If omitted, the MLBuffer + * - `download`: an optional function to download the tensor data from the MLTensor to CPU. If omitted, the MLTensor * data will not be able to download. Usually, this is provided by the WebNN backend for the inference outputs. * Users don't need to provide this function. - * - `dispose`: an optional function to dispose the tensor data on the WebNN MLBuffer. If omitted, the MLBuffer will + * - `dispose`: an optional function to dispose the tensor data on the WebNN MLTensor. If omitted, the MLTensor will * not be disposed. Usually, this is provided by the WebNN backend for the inference outputs. Users don't need to * provide this function. * * @returns a tensor object */ - fromMLBuffer( - buffer: Tensor.MLBufferType, - options: TensorFromMLBufferOptions, + fromMLTensor( + buffer: Tensor.MLTensorType, + options: TensorFromMLTensorOptions, ): TypedTensor; /** diff --git a/js/common/lib/tensor-impl.ts b/js/common/lib/tensor-impl.ts index 658ff6f5bb024..e798977d6a92f 100644 --- a/js/common/lib/tensor-impl.ts +++ b/js/common/lib/tensor-impl.ts @@ -6,19 +6,19 @@ import { TensorToDataUrlOptions, TensorToImageDataOptions } from './tensor-conve import { tensorFromGpuBuffer, tensorFromImage, - tensorFromMLBuffer, + tensorFromMLTensor, tensorFromPinnedBuffer, tensorFromTexture, } from './tensor-factory-impl.js'; import { CpuPinnedConstructorParameters, GpuBufferConstructorParameters, - MLBufferConstructorParameters, + MLTensorConstructorParameters, TensorFromGpuBufferOptions, TensorFromImageBitmapOptions, TensorFromImageDataOptions, TensorFromImageElementOptions, - TensorFromMLBufferOptions, + TensorFromMLTensorOptions, TensorFromTextureOptions, TensorFromUrlOptions, TextureConstructorParameters, @@ -40,7 +40,7 @@ type TensorDataType = TensorInterface.DataType; type TensorDataLocation = TensorInterface.DataLocation; type TensorTextureType = TensorInterface.TextureType; type TensorGpuBufferType = TensorInterface.GpuBufferType; -type TensorMLBufferType = TensorInterface.MLBufferType; +type TensorMLTensorType = TensorInterface.MLTensorType; /** * the implementation of Tensor interface. @@ -90,11 +90,11 @@ export class Tensor implements TensorInterface { /** * Construct a new tensor object from the WebNN buffer with the given type and dims. * - * Tensor's location will be set to 'ml-buffer'. + * Tensor's location will be set to 'ml-tensor'. * * @param params - Specify the parameters to construct the tensor. */ - constructor(params: MLBufferConstructorParameters); + constructor(params: MLTensorConstructorParameters); /** * implementation. @@ -108,7 +108,7 @@ export class Tensor implements TensorInterface { | CpuPinnedConstructorParameters | TextureConstructorParameters | GpuBufferConstructorParameters - | MLBufferConstructorParameters, + | MLTensorConstructorParameters, arg1?: TensorDataType | readonly number[] | readonly string[] | readonly boolean[], arg2?: readonly number[], ) { @@ -165,7 +165,7 @@ export class Tensor implements TensorInterface { this.disposer = arg0.dispose; break; } - case 'ml-buffer': { + case 'ml-tensor': { if ( type !== 'float32' && type !== 'float16' && @@ -177,9 +177,9 @@ export class Tensor implements TensorInterface { type !== 'uint8' && type !== 'bool' ) { - throw new TypeError(`unsupported type "${type}" to create tensor from MLBuffer`); + throw new TypeError(`unsupported type "${type}" to create tensor from MLTensor`); } - this.mlBufferData = arg0.mlBuffer; + this.mlTensorData = arg0.mlTensor; this.downloader = arg0.download; this.disposer = arg0.dispose; break; @@ -345,11 +345,11 @@ export class Tensor implements TensorInterface { return tensorFromGpuBuffer(gpuBuffer, options); } - static fromMLBuffer( - mlBuffer: TensorMLBufferType, - options: TensorFromMLBufferOptions, + static fromMLTensor( + mlTensor: TensorMLTensorType, + options: TensorFromMLTensorOptions, ): TensorInterface { - return tensorFromMLBuffer(mlBuffer, options); + return tensorFromMLTensor(mlTensor, options); } static fromPinnedBuffer( @@ -401,9 +401,9 @@ export class Tensor implements TensorInterface { private gpuBufferData?: TensorGpuBufferType; /** - * stores the underlying WebNN MLBuffer when location is 'ml-buffer'. otherwise empty. + * stores the underlying WebNN MLTensor when location is 'ml-tensor'. otherwise empty. */ - private mlBufferData?: TensorMLBufferType; + private mlTensorData?: TensorMLTensorType; /** * stores an optional downloader function to download data from GPU to CPU. @@ -453,12 +453,12 @@ export class Tensor implements TensorInterface { return this.gpuBufferData; } - get mlBuffer(): TensorMLBufferType { + get mlTensor(): TensorMLTensorType { this.ensureValid(); - if (!this.mlBufferData) { + if (!this.mlTensorData) { throw new Error('The data is not stored as a WebNN buffer.'); } - return this.mlBufferData; + return this.mlTensorData; } // #endregion @@ -472,7 +472,7 @@ export class Tensor implements TensorInterface { return this.data; case 'texture': case 'gpu-buffer': - case 'ml-buffer': { + case 'ml-tensor': { if (!this.downloader) { throw new Error('The current tensor is not created with a specified data downloader.'); } @@ -513,7 +513,7 @@ export class Tensor implements TensorInterface { this.cpuData = undefined; this.gpuTextureData = undefined; this.gpuBufferData = undefined; - this.mlBufferData = undefined; + this.mlTensorData = undefined; this.downloader = undefined; this.isDownloading = undefined; diff --git a/js/common/lib/tensor-utils-impl.ts b/js/common/lib/tensor-utils-impl.ts index 4c4c9b1d80185..97b1735e6eac5 100644 --- a/js/common/lib/tensor-utils-impl.ts +++ b/js/common/lib/tensor-utils-impl.ts @@ -4,7 +4,7 @@ import { CpuPinnedConstructorParameters, GpuBufferConstructorParameters, - MLBufferConstructorParameters, + MLTensorConstructorParameters, TextureConstructorParameters, } from './tensor-factory.js'; import { Tensor } from './tensor-impl.js'; @@ -57,11 +57,11 @@ export const tensorReshape = (tensor: Tensor, dims: readonly number[]): Tensor = type: tensor.type as GpuBufferConstructorParameters['type'], dims, }); - case 'ml-buffer': + case 'ml-tensor': return new Tensor({ - location: 'ml-buffer', - mlBuffer: tensor.mlBuffer, - type: tensor.type as MLBufferConstructorParameters['type'], + location: 'ml-tensor', + mlTensor: tensor.mlTensor, + type: tensor.type as MLTensorConstructorParameters['type'], dims, }); default: diff --git a/js/common/lib/tensor.ts b/js/common/lib/tensor.ts index 636ab0704ffe5..b8e44abd9cb27 100644 --- a/js/common/lib/tensor.ts +++ b/js/common/lib/tensor.ts @@ -43,11 +43,11 @@ interface TypedTensorBase { readonly gpuBuffer: Tensor.GpuBufferType; /** - * Get the WebNN buffer that holds the tensor data. + * Get the WebNN MLTensor that holds the tensor data. * - * If the data is not in a WebNN MLBuffer, throw error. + * If the data is not in a WebNN MLTensor, throw error. */ - readonly mlBuffer: Tensor.MLBufferType; + readonly mlTensor: Tensor.MLTensorType; /** * Get the buffer data of the tensor. @@ -144,11 +144,11 @@ export declare namespace Tensor { export type GpuBufferType = { size: number; mapState: 'unmapped' | 'pending' | 'mapped' }; /** - * type alias for WebNN MLBuffer + * type alias for WebNN MLTensor * - * The specification for WebNN's ML Buffer is currently in flux. + * The specification for WebNN's MLTensor is currently in flux. */ - export type MLBufferType = unknown; + export type MLTensorType = unknown; /** * supported data types for constructing a tensor from a WebGPU buffer @@ -156,9 +156,9 @@ export declare namespace Tensor { export type GpuBufferDataTypes = 'float32' | 'float16' | 'int32' | 'int64' | 'uint32' | 'uint8' | 'bool'; /** - * supported data types for constructing a tensor from a WebNN MLBuffer + * supported data types for constructing a tensor from a WebNN MLTensor */ - export type MLBufferDataTypes = + export type MLTensorDataTypes = | 'float32' | 'float16' | 'int8' @@ -172,7 +172,7 @@ export declare namespace Tensor { /** * represent where the tensor data is stored */ - export type DataLocation = 'none' | 'cpu' | 'cpu-pinned' | 'texture' | 'gpu-buffer' | 'ml-buffer'; + export type DataLocation = 'none' | 'cpu' | 'cpu-pinned' | 'texture' | 'gpu-buffer' | 'ml-tensor'; /** * represent the data type of a tensor diff --git a/js/web/lib/wasm/jsep/backend-webnn.ts b/js/web/lib/wasm/jsep/backend-webnn.ts index 446bcb8a11bbb..2caa840ed03b5 100644 --- a/js/web/lib/wasm/jsep/backend-webnn.ts +++ b/js/web/lib/wasm/jsep/backend-webnn.ts @@ -12,7 +12,7 @@ import { DataType } from '../wasm-common'; import { getInstance } from '../wasm-factory'; import { createView } from './tensor-view'; -import { BufferId, createBufferManager } from './webnn/buffer-manager'; +import { TensorId, createTensorManager } from './webnn/tensor-manager'; import { configureLogger, LOG_DEBUG } from './log'; /* @@ -31,14 +31,14 @@ const onnxDataTypeToWebnnDataType = new Map([ ]); /** - * WebNN backend implementation. This class is used to keep track of the MLBuffers created by the backend and keep track + * WebNN backend implementation. This class is used to keep track of the MLTensors created by the backend and keep track * of the current MLContext being used by the sessions. */ export class WebNNBackend { /** - * Buffer managers for each session. + * Tensor managers for each session. */ - private bufferManager = createBufferManager(this); + private tensorManager = createTensorManager(this); /** * Maps from session id to MLContexts. */ @@ -96,7 +96,7 @@ export class WebNNBackend { sessionIds.delete(sessionId); if (sessionIds.size === 0) { this.sessionIdsByMLContext.delete(mlContext); - this.bufferManager.releaseBuffersForContext(mlContext); + this.tensorManager.releaseTensorsForContext(mlContext); } } @@ -104,53 +104,63 @@ export class WebNNBackend { return this.mlContextBySessionId.get(sessionId); } - public reserveBufferId(): BufferId { - return this.bufferManager.reserveBufferId(); + public reserveTensorId(): TensorId { + return this.tensorManager.reserveTensorId(); } - public releaseBufferId(bufferId: BufferId): void { - LOG_DEBUG('verbose', () => `[WebNN] releaseBufferId {bufferId: ${bufferId}}`); - this.bufferManager.releaseBufferId(bufferId); + public releaseTensorId(tensorId: TensorId): void { + LOG_DEBUG('verbose', () => `[WebNN] releaseTensorId {tensorId: ${tensorId}}`); + this.tensorManager.releaseTensorId(tensorId); } - public async ensureBuffer( - bufferId: BufferId, + public async ensureTensor( + tensorId: TensorId, onnxDataType: DataType, dimensions: number[], copyOld: boolean, - ): Promise { + ): Promise { const webnnDataType = onnxDataTypeToWebnnDataType.get(onnxDataType); if (!webnnDataType) { throw new Error(`Unsupported ONNX data type: ${onnxDataType}`); } - return this.bufferManager.ensureBuffer(bufferId, webnnDataType, dimensions, copyOld); + return this.tensorManager.ensureTensor(tensorId, webnnDataType, dimensions, copyOld); } - public uploadBuffer(bufferId: BufferId, data: Uint8Array): void { + public uploadTensor(tensorId: TensorId, data: Uint8Array): void { const wasm = getInstance(); - if (!wasm.shouldTransferToMLBuffer) { - throw new Error('Trying to upload to a MLBuffer while shouldTransferToMLBuffer is false'); + if (!wasm.shouldTransferToMLTensor) { + throw new Error('Trying to upload to a MLTensor while shouldTransferToMLTensor is false'); } - this.bufferManager.upload(bufferId, data); + LOG_DEBUG('verbose', () => `[WebNN] uploadBuffer {tensorId: ${tensorId}, data: ${data.byteLength}}`); + this.tensorManager.upload(tensorId, data); } - public async downloadBuffer(bufferId: BufferId, dstBuffer: ArrayBufferView | ArrayBuffer): Promise { - return this.bufferManager.download(bufferId, dstBuffer); + public async downloadTensor(tensorId: TensorId, dstBuffer: ArrayBufferView | ArrayBuffer): Promise { + return this.tensorManager.download(tensorId, dstBuffer); } - public createMLBufferDownloader(bufferId: BufferId, type: Tensor.MLBufferDataTypes): () => Promise { + public createMLTensorDownloader(tensorId: TensorId, type: Tensor.MLTensorDataTypes): () => Promise { return async () => { - const data = await this.bufferManager.download(bufferId); + const data = await this.tensorManager.download(tensorId); return createView(data, type); }; } - public registerMLBuffer(buffer: MLBuffer, onnxDataType: DataType, dimensions: number[]): BufferId { + public registerMLTensor(tensor: MLTensor, onnxDataType: DataType, dimensions: number[]): TensorId { const webnnDataType = onnxDataTypeToWebnnDataType.get(onnxDataType); if (!webnnDataType) { throw new Error(`Unsupported ONNX data type: ${onnxDataType}`); } - return this.bufferManager.registerBuffer(this.currentContext, buffer, webnnDataType, dimensions); + + const id = this.tensorManager.registerTensor(this.currentContext, tensor, webnnDataType, dimensions); + LOG_DEBUG( + 'verbose', + () => + `[WebNN] registerMLTensor {tensor: ${tensor}, dataType: ${webnnDataType}, dimensions: ${ + dimensions + }} -> {bufferId: ${id}}`, + ); + return id; } public flush(): void { diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index d07b7e762ddd5..a13c8335f9e91 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -271,18 +271,18 @@ export const init = async ( jsepInit('webnn', [ backend, // jsepReserveBufferId - () => backend.reserveBufferId(), + () => backend.reserveTensorId(), // jsepReleaseBufferId, - (bufferId: number) => backend.releaseBufferId(bufferId), + (bufferId: number) => backend.releaseTensorId(bufferId), // jsepEnsureBuffer async (bufferId: number, onnxDataType: number, dimensions: number[], copyOld) => - backend.ensureBuffer(bufferId, onnxDataType, dimensions, copyOld), + backend.ensureTensor(bufferId, onnxDataType, dimensions, copyOld), // jsepUploadBuffer (bufferId: number, data: Uint8Array) => { - backend.uploadBuffer(bufferId, data); + backend.uploadTensor(bufferId, data); }, // jsepDownloadBuffer - async (bufferId: number, dstBuffer: ArrayBufferView | ArrayBuffer) => backend.downloadBuffer(bufferId, dstBuffer), + async (bufferId: number, dstBuffer: ArrayBufferView | ArrayBuffer) => backend.downloadTensor(bufferId, dstBuffer), ]); } }; diff --git a/js/web/lib/wasm/jsep/webnn/buffer-manager.ts b/js/web/lib/wasm/jsep/webnn/buffer-manager.ts deleted file mode 100644 index 8338cddbba6ba..0000000000000 --- a/js/web/lib/wasm/jsep/webnn/buffer-manager.ts +++ /dev/null @@ -1,298 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -import { WebNNBackend } from '../backend-webnn'; -import { LOG_DEBUG } from '../log'; - -// WebNN API currently does not have a TypeScript definition file. This file is a workaround with types generated from -// WebNN API specification. -// https://github.com/webmachinelearning/webnn/issues/677 -/// - -export type BufferId = number; - -/** - * Manages BufferId to MLBuffer mapping. - */ -export interface BufferManager { - /** - * Reserve a new BufferId. - */ - reserveBufferId(): BufferId; - /** - * Release a BufferId. - */ - releaseBufferId(bufferId: BufferId): void; - /** - * Ensure a MLBuffer is created for the BufferId. - */ - ensureBuffer( - bufferId: BufferId, - dataType: MLOperandDataType, - dimensions: readonly number[], - copyOld: boolean, - ): Promise; - /** - * Upload data to a MLBuffer. - */ - upload(bufferId: BufferId, data: Uint8Array): void; - /** - * Download data from a MLBuffer. - */ - download(bufferId: BufferId): Promise; - download(bufferId: BufferId, dstBuffer: ArrayBufferView | ArrayBuffer): Promise; - /** - * Release all buffers for a MLContext. - */ - releaseBuffersForContext(mlContext: MLContext): void; - /** - * Register an externally created MLBuffer with a given MLContext and return a BufferId. - */ - registerBuffer(mlContext: MLContext, mlBuffer: MLBuffer, dataType: MLOperandDataType, dimensions: number[]): BufferId; -} - -let bufferGuid = 1; -const createNewBufferId = (): BufferId => bufferGuid++; - -export type MLBufferEntry = [MLBuffer, MLOperandDataType, readonly number[]]; - -/** - * BufferTracker tracks the MLBuffer and pending upload data. - * - * We need to track the MLBuffer and pending upload data because we delay the creation of MLBuffer until - * we know the data type and dimensions. This is because future implementations of WebNN will only support creating - * MLBuffers with dataTypes and dimensions. - */ -class BufferTracker { - private bufferEntry?: MLBufferEntry; - private activeUpload?: Uint8Array; - private bufferCache: MLBufferEntry[]; - - constructor( - private mlContext?: MLContext, - bufferEntry?: MLBufferEntry, - ) { - this.bufferEntry = bufferEntry; - this.bufferCache = bufferEntry ? [bufferEntry] : []; - } - - public get buffer(): MLBuffer | undefined { - return this.bufferEntry?.[0]; - } - - public get context(): MLContext { - if (!this.mlContext) { - throw new Error('MLContext has not been set.'); - } - return this.mlContext; - } - - public set context(mlContext: MLContext) { - if (this.mlContext && this.mlContext !== mlContext) { - throw new Error('MLBuffer in use in a different MLContext.'); - } - this.mlContext = mlContext; - } - - public destroy(): void { - for (const [mlBuffer] of this.bufferCache) { - mlBuffer.destroy(); - } - this.bufferCache = []; - this.bufferEntry = undefined; - } - - public trySelectBuffer(context: MLContext, tryMlBuffer: MLBuffer): boolean { - for (const [mlBuffer, dataType, dimensions] of this.bufferCache) { - if (tryMlBuffer === mlBuffer) { - if (this.context !== context) { - throw new Error('MLBuffer cannot be registered with a different MLContext.'); - } - this.bufferEntry = [mlBuffer, dataType, dimensions]; - return true; - } - } - return false; - } - - public async ensureBuffer( - dataType: MLOperandDataType, - dimensions: readonly number[], - copyOld: boolean, - ): Promise { - if (this.bufferEntry) { - const [mlBuffer, existingDataType, existingDimensions] = this.bufferEntry; - if (existingDataType === dataType && existingDimensions.every((v, i) => v === dimensions[i])) { - return mlBuffer; - } - } - - for (const [mlBuffer, existingDataType, existingDimensions] of this.bufferCache) { - if (existingDataType === dataType && existingDimensions.every((v, i) => v === dimensions[i])) { - if (copyOld && this.bufferEntry) { - // WebNN does not support copyBufferToBuffer, so we need to read and write the buffers. - LOG_DEBUG( - 'verbose', - () => - `[WebNN] Slowdown may occur, having to copy existing buffer {dataType: ${ - dataType - }, dimensions: ${dimensions}}`, - ); - const data = await this.context.readBuffer(this.bufferEntry[0]); - this.context.writeBuffer(mlBuffer, data); - } - this.bufferEntry = [mlBuffer, existingDataType, existingDimensions]; - return mlBuffer; - } - } - LOG_DEBUG('verbose', () => `[WebNN] MLContext.createBuffer {dataType: ${dataType}, dimensions: ${dimensions}}`); - // eslint-disable-next-line no-bitwise - const usage = MLBufferUsage.READ_FROM | MLBufferUsage.WRITE_TO; - const buffer = await this.context.createBuffer({ dataType, dimensions, usage }); - this.bufferEntry = [buffer, dataType, dimensions]; - this.bufferCache.push(this.bufferEntry); - - if (this.activeUpload) { - this.mlContext?.writeBuffer(buffer, this.activeUpload); - this.activeUpload = undefined; - } - - return buffer; - } - - public upload(data: Uint8Array): void { - if (!this.bufferEntry) { - this.activeUpload = new Uint8Array(data); - return; - } - - this.mlContext?.writeBuffer(this.bufferEntry[0], data); - } - - public async download(dstBuffer?: ArrayBufferView | ArrayBuffer): Promise { - if (this.activeUpload) { - if (dstBuffer) { - if (dstBuffer instanceof ArrayBuffer) { - new Uint8Array(dstBuffer).set(this.activeUpload); - } else { - new Uint8Array(dstBuffer.buffer, dstBuffer.byteOffset, dstBuffer.byteLength).set(this.activeUpload); - } - - return; - } else { - return this.activeUpload.buffer; - } - } - if (!this.bufferEntry) { - throw new Error('Buffer has not been created.'); - } - if (dstBuffer) { - return this.context.readBuffer(this.bufferEntry[0], dstBuffer); - } - return this.context.readBuffer(this.bufferEntry[0]); - } -} - -class BufferManagerImpl implements BufferManager { - private buffersById = new Map(); - private bufferIdsByContext = new Map>(); - - constructor(private backend: WebNNBackend) {} - - public reserveBufferId(): BufferId { - const bufferId = createNewBufferId(); - this.buffersById.set(bufferId, new BufferTracker()); - return bufferId; - } - - public releaseBufferId(bufferId: BufferId): void { - const bufferTracker = this.buffersById.get(bufferId); - if (!bufferTracker) { - return; - } - bufferTracker.destroy(); - this.buffersById.delete(bufferId); - for (const [mlContext, buffers] of this.bufferIdsByContext) { - if (buffers.has(bufferId)) { - buffers.delete(bufferId); - if (buffers.size === 0) { - this.bufferIdsByContext.delete(mlContext); - } - break; - } - } - } - - public async ensureBuffer( - bufferId: BufferId, - dataType: MLOperandDataType, - dimensions: number[], - copyOld: boolean, - ): Promise { - LOG_DEBUG( - 'verbose', - () => - `[WebNN] BufferManager.ensureBuffer {bufferId: ${bufferId}, dataType: ${ - dataType - }, dimensions: ${dimensions}, copyOld: ${copyOld}}`, - ); - const buffer = this.buffersById.get(bufferId); - if (!buffer) { - throw new Error('Buffer not found.'); - } - buffer.context = this.backend.currentContext; - if (!this.bufferIdsByContext.has(this.backend.currentContext)) { - this.bufferIdsByContext.set(this.backend.currentContext, new Set()); - } - this.bufferIdsByContext.get(this.backend.currentContext)?.add(bufferId); - return buffer.ensureBuffer(dataType, dimensions, copyOld); - } - - public upload(bufferId: BufferId, data: Uint8Array): void { - this.buffersById.get(bufferId)!.upload(data); - } - - public async download(bufferId: BufferId): Promise; - public async download(bufferId: BufferId, dstBuffer: ArrayBufferView | ArrayBuffer): Promise; - async download(bufferId: BufferId, dstBuffer?: ArrayBufferView | ArrayBuffer): Promise { - LOG_DEBUG('verbose', () => `[WebNN] BufferManager.download {bufferId: ${bufferId}, dstBuffer: ${dstBuffer}}`); - return this.buffersById.get(bufferId)!.download(dstBuffer); - } - - public releaseBuffersForContext(mlContext: MLContext): void { - const buffers = this.bufferIdsByContext.get(mlContext); - if (!buffers) { - return; - } - for (const bufferId of buffers) { - this.buffersById.get(bufferId)!.destroy(); - this.buffersById.delete(bufferId); - } - this.bufferIdsByContext.delete(mlContext); - } - - public registerBuffer( - mlContext: MLContext, - mlBuffer: MLBuffer, - dataType: MLOperandDataType, - dimensions: readonly number[], - ): BufferId { - for (const [bufferId, bufferTracker] of this.buffersById) { - if (bufferTracker.trySelectBuffer(mlContext, mlBuffer)) { - return bufferId; - } - } - const bufferId = createNewBufferId(); - this.buffersById.set(bufferId, new BufferTracker(mlContext, [mlBuffer, dataType, dimensions])); - let buffers = this.bufferIdsByContext.get(mlContext); - if (!buffers) { - buffers = new Set(); - this.bufferIdsByContext.set(mlContext, buffers); - } - buffers.add(bufferId); - return bufferId; - } -} - -export const createBufferManager = (...args: ConstructorParameters): BufferManager => - new BufferManagerImpl(...args); diff --git a/js/web/lib/wasm/jsep/webnn/tensor-manager.ts b/js/web/lib/wasm/jsep/webnn/tensor-manager.ts new file mode 100644 index 0000000000000..1f6550fb578b8 --- /dev/null +++ b/js/web/lib/wasm/jsep/webnn/tensor-manager.ts @@ -0,0 +1,300 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import { WebNNBackend } from '../backend-webnn'; +import { LOG_DEBUG } from '../log'; + +// WebNN API currently does not have a TypeScript definition file. This file is a workaround with types generated from +// WebNN API specification. +// https://github.com/webmachinelearning/webnn/issues/677 +/// + +export type TensorId = number; + +/** + * Manages TensorId to MLTensor mapping. + */ +export interface TensorManager { + /** + * Reserve a new TensorId. + */ + reserveTensorId(): TensorId; + /** + * Release a TensorId. + */ + releaseTensorId(tensorId: TensorId): void; + /** + * Ensure a MLTensor is created for the TensorId. + */ + ensureTensor( + tensorId: TensorId, + dataType: MLOperandDataType, + dimensions: readonly number[], + copyOld: boolean, + ): Promise; + /** + * Upload data to a MLTensor. + */ + upload(tensorId: TensorId, data: Uint8Array): void; + /** + * Download data from a MLTensor. + */ + download(tensorId: TensorId): Promise; + download(tensorId: TensorId, dstTensor: ArrayBufferView | ArrayBuffer): Promise; + /** + * Release all tensors for a MLContext. + */ + releaseTensorsForContext(mlContext: MLContext): void; + /** + * Register an externally created MLTensor with a given MLContext and return a TensorId. + */ + registerTensor(mlContext: MLContext, mlTensor: MLTensor, dataType: MLOperandDataType, dimensions: number[]): TensorId; +} + +let tensorGuid = 1; +const createNewTensorId = (): TensorId => tensorGuid++; + +export type MLTensorEntry = [MLTensor, MLOperandDataType, readonly number[]]; + +/** + * TensorTracker tracks the MLTensor and pending upload data. + * + * We need to track the MLTensor and pending upload data because we delay the creation of MLTensor until + * we know the data type and dimensions. This is because future implementations of WebNN will only support creating + * MLTensors with dataTypes and dimensions. + */ +class TensorTracker { + private tensorEntry?: MLTensorEntry; + private activeUpload?: Uint8Array; + private tensorCache: MLTensorEntry[]; + + constructor( + private mlContext?: MLContext, + tensorEntry?: MLTensorEntry, + ) { + this.tensorEntry = tensorEntry; + this.tensorCache = tensorEntry ? [tensorEntry] : []; + } + + public get tensor(): MLTensor | undefined { + return this.tensorEntry?.[0]; + } + + public get context(): MLContext { + if (!this.mlContext) { + throw new Error('MLContext has not been set.'); + } + return this.mlContext; + } + + public set context(mlContext: MLContext) { + if (this.mlContext && this.mlContext !== mlContext) { + throw new Error('MLTensor in use in a different MLContext.'); + } + this.mlContext = mlContext; + } + + public destroy(): void { + for (const [mlTensor] of this.tensorCache) { + mlTensor.destroy(); + } + this.tensorCache = []; + this.tensorEntry = undefined; + } + + public trySelectTensor(context: MLContext, tryMLTensor: MLTensor): boolean { + for (const [mlTensor, dataType, dimensions] of this.tensorCache) { + if (tryMLTensor === mlTensor) { + if (this.context !== context) { + throw new Error('MLTensor cannot be registered with a different MLContext.'); + } + this.tensorEntry = [mlTensor, dataType, dimensions]; + return true; + } + } + return false; + } + + public async ensureTensor( + dataType: MLOperandDataType, + dimensions: readonly number[], + copyOld: boolean, + ): Promise { + if (this.tensorEntry) { + const [mlTensor, existingDataType, existingDimensions] = this.tensorEntry; + if (existingDataType === dataType && existingDimensions.every((v, i) => v === dimensions[i])) { + return mlTensor; + } + } + + for (const [mlTensor, existingDataType, existingDimensions] of this.tensorCache) { + if (existingDataType === dataType && existingDimensions.every((v, i) => v === dimensions[i])) { + if (copyOld && this.tensorEntry) { + // WebNN does not support copyTensorToTensor, so we need to read and write the tensors. + LOG_DEBUG( + 'verbose', + () => + `[WebNN] Slowdown may occur, having to copy existing tensor {dataType: ${ + dataType + }, dimensions: ${dimensions}}`, + ); + const data = await this.context.readTensor(this.tensorEntry[0]); + this.context.writeTensor(mlTensor, data); + } + this.tensorEntry = [mlTensor, existingDataType, existingDimensions]; + return mlTensor; + } + } + LOG_DEBUG('verbose', () => `[WebNN] MLContext.createTensor {dataType: ${dataType}, dimensions: ${dimensions}}`); + // eslint-disable-next-line no-bitwise + const usage = MLTensorUsage.READ_FROM | MLTensorUsage.WRITE_TO; + const tensor = await this.context.createTensor({ dataType, dimensions, usage }); + this.tensorEntry = [tensor, dataType, dimensions]; + this.tensorCache.push(this.tensorEntry); + + if (this.activeUpload) { + this.mlContext?.writeTensor(tensor, this.activeUpload); + this.activeUpload = undefined; + } + + return tensor; + } + + public upload(data: Uint8Array): void { + if (!this.tensorEntry) { + this.activeUpload = new Uint8Array(data); + return; + } + this.mlContext?.writeTensor(this.tensorEntry[0], data); + } + + public async download(dstBuffer?: ArrayBufferView | ArrayBuffer): Promise { + if (this.activeUpload) { + if (dstBuffer) { + if (dstBuffer instanceof ArrayBuffer) { + new Uint8Array(dstBuffer).set(this.activeUpload); + } else { + new Uint8Array(dstBuffer.buffer, dstBuffer.byteOffset, dstBuffer.byteLength).set(this.activeUpload); + } + + return; + } else { + return this.activeUpload.buffer; + } + } + if (!this.tensorEntry) { + throw new Error('Tensor has not been created.'); + } + if (dstBuffer) { + return this.context.readTensor(this.tensorEntry[0], dstBuffer); + } + return this.context.readTensor(this.tensorEntry[0]); + } +} + +class TensorManagerImpl implements TensorManager { + private tensorsById = new Map(); + private tensorIdsByContext = new Map>(); + + constructor(private backend: WebNNBackend) {} + + public reserveTensorId(): TensorId { + const tensorId = createNewTensorId(); + this.tensorsById.set(tensorId, new TensorTracker()); + return tensorId; + } + + public releaseTensorId(tensorId: TensorId): void { + const tensorTracker = this.tensorsById.get(tensorId); + if (!tensorTracker) { + return; + } + tensorTracker.destroy(); + this.tensorsById.delete(tensorId); + for (const [mlContext, tensors] of this.tensorIdsByContext) { + if (tensors.has(tensorId)) { + tensors.delete(tensorId); + if (tensors.size === 0) { + this.tensorIdsByContext.delete(mlContext); + } + break; + } + } + } + + public async ensureTensor( + tensorId: TensorId, + dataType: MLOperandDataType, + dimensions: number[], + copyOld: boolean, + ): Promise { + LOG_DEBUG( + 'verbose', + () => + `[WebNN] TensorManager.ensureTensor {tensorId: ${tensorId}, dataType: ${ + dataType + }, dimensions: ${dimensions}, copyOld: ${copyOld}}`, + ); + const tensor = this.tensorsById.get(tensorId); + if (!tensor) { + throw new Error('Tensor not found.'); + } + tensor.context = this.backend.currentContext; + if (!this.tensorIdsByContext.has(this.backend.currentContext)) { + this.tensorIdsByContext.set(this.backend.currentContext, new Set()); + } + this.tensorIdsByContext.get(this.backend.currentContext)?.add(tensorId); + return tensor.ensureTensor(dataType, dimensions, copyOld); + } + + public upload(tensorId: TensorId, data: Uint8Array): void { + this.tensorsById.get(tensorId)!.upload(data); + } + + public async download(tensorId: TensorId): Promise; + public async download(tensorId: TensorId, dstBuffer: ArrayBufferView | ArrayBuffer): Promise; + async download(tensorId: TensorId, dstBuffer?: ArrayBufferView | ArrayBuffer): Promise { + LOG_DEBUG( + 'verbose', + () => `[WebNN] TensorManager.download {tensorId: ${tensorId}, dstBuffer: ${dstBuffer?.byteLength}}`, + ); + return this.tensorsById.get(tensorId)!.download(dstBuffer); + } + + public releaseTensorsForContext(mlContext: MLContext): void { + const tensors = this.tensorIdsByContext.get(mlContext); + if (!tensors) { + return; + } + for (const tensorId of tensors) { + this.tensorsById.get(tensorId)!.destroy(); + this.tensorsById.delete(tensorId); + } + this.tensorIdsByContext.delete(mlContext); + } + + public registerTensor( + mlContext: MLContext, + mlTensor: MLTensor, + dataType: MLOperandDataType, + dimensions: readonly number[], + ): TensorId { + for (const [tensorId, tensorTracker] of this.tensorsById) { + if (tensorTracker.trySelectTensor(mlContext, mlTensor)) { + return tensorId; + } + } + const tensorId = createNewTensorId(); + this.tensorsById.set(tensorId, new TensorTracker(mlContext, [mlTensor, dataType, dimensions])); + let tensors = this.tensorIdsByContext.get(mlContext); + if (!tensors) { + tensors = new Set(); + this.tensorIdsByContext.set(mlContext, tensors); + } + tensors.add(tensorId); + return tensorId; + } +} + +export const createTensorManager = (...args: ConstructorParameters): TensorManager => + new TensorManagerImpl(...args); diff --git a/js/web/lib/wasm/jsep/webnn/webnn.d.ts b/js/web/lib/wasm/jsep/webnn/webnn.d.ts index 2376b0ce96634..386b850b7e221 100644 --- a/js/web/lib/wasm/jsep/webnn/webnn.d.ts +++ b/js/web/lib/wasm/jsep/webnn/webnn.d.ts @@ -381,32 +381,32 @@ interface MLGraphBuilder { where(condition: MLOperand, input: MLOperand, other: MLOperand): MLOperand; } -// Experimental MLBuffer interface +// Experimental MLTensor interface -interface MLBuffer { +interface MLTensor { destroy(): void; } -type MLNamedBuffers = Record; +type MLNamedBuffers = Record; -type MLBufferUsageFlags = number; +type MLTensorUsageFlags = number; -declare const MLBufferUsage: { - readonly WEBGPU_INTEROP: MLBufferUsageFlags; - readonly READ_FROM: MLBufferUsageFlags; - readonly WRITE_TO: MLBufferUsageFlags; +declare const MLTensorUsage: { + readonly WEBGPU_INTEROP: MLTensorUsageFlags; + readonly READ_FROM: MLTensorUsageFlags; + readonly WRITE_TO: MLTensorUsageFlags; }; -interface MLBufferDescriptor extends MLOperandDescriptor { - usage: MLBufferUsageFlags; +interface MLTensorDescriptor extends MLOperandDescriptor { + usage: MLTensorUsageFlags; } interface MLContext { - createBuffer(descriptor: MLBufferDescriptor): Promise; - writeBuffer( - dstBuffer: MLBuffer, srcData: ArrayBufferView|ArrayBuffer, srcElementOffset?: number, + createTensor(descriptor: MLTensorDescriptor): Promise; + writeTensor( + destinationTensor: MLTensor, sourceData: ArrayBufferView|ArrayBuffer, sourceElementOffset?: number, srcElementSize?: number): void; - readBuffer(srcBuffer: MLBuffer): Promise; - readBuffer(srcBuffer: MLBuffer, dstBuffer: ArrayBufferView|ArrayBuffer): Promise; + readTensor(sourceTensor: MLTensor): Promise; + readTensor(sourceTensor: MLTensor, destinationData: ArrayBufferView|ArrayBuffer): Promise; dispatch(graph: MLGraph, inputs: MLNamedBuffers, outputs: MLNamedBuffers): void; } diff --git a/js/web/lib/wasm/proxy-messages.ts b/js/web/lib/wasm/proxy-messages.ts index 58aea4d0c6591..559f319a10f66 100644 --- a/js/web/lib/wasm/proxy-messages.ts +++ b/js/web/lib/wasm/proxy-messages.ts @@ -19,18 +19,18 @@ export type GpuBufferMetadata = { dispose?: () => void; }; -export type MLBufferMetadata = { - mlBuffer: Tensor.MLBufferType; - download?: () => Promise; +export type MLTensorMetadata = { + mlTensor: Tensor.MLTensorType; + download?: () => Promise; dispose?: () => void; }; /** - * Tensors on location 'cpu-pinned', 'gpu-buffer', and 'ml-buffer' are not serializable. + * Tensors on location 'cpu-pinned', 'gpu-buffer', and 'ml-tensor' are not serializable. */ export type UnserializableTensorMetadata = | [dataType: Tensor.Type, dims: readonly number[], data: GpuBufferMetadata, location: 'gpu-buffer'] - | [dataType: Tensor.Type, dims: readonly number[], data: MLBufferMetadata, location: 'ml-buffer'] + | [dataType: Tensor.Type, dims: readonly number[], data: MLTensorMetadata, location: 'ml-tensor'] | [dataType: Tensor.Type, dims: readonly number[], data: Tensor.DataType, location: 'cpu-pinned']; /** @@ -41,7 +41,7 @@ export type UnserializableTensorMetadata = * - cpu: Uint8Array * - cpu-pinned: Uint8Array * - gpu-buffer: GpuBufferMetadata - * - ml-buffer: MLBufferMetadata + * - ml-tensor: MLTensorMetadata * - location: tensor data location */ export type TensorMetadata = SerializableTensorMetadata | UnserializableTensorMetadata; diff --git a/js/web/lib/wasm/session-handler-inference.ts b/js/web/lib/wasm/session-handler-inference.ts index 7ea52f3f470b7..c19043cc3637f 100644 --- a/js/web/lib/wasm/session-handler-inference.ts +++ b/js/web/lib/wasm/session-handler-inference.ts @@ -12,7 +12,7 @@ import { import { SerializableInternalBuffer, TensorMetadata } from './proxy-messages'; import { copyFromExternalBuffer, createSession, endProfiling, releaseSession, run } from './proxy-wrapper'; -import { isGpuBufferSupportedType, isMLBufferSupportedType } from './wasm-common'; +import { isGpuBufferSupportedType, isMLTensorSupportedType } from './wasm-common'; import { isNode } from './wasm-utils-env'; import { loadFile } from './wasm-utils-load-file'; @@ -22,8 +22,8 @@ export const encodeTensorMetadata = (tensor: Tensor, getName: () => string): Ten return [tensor.type, tensor.dims, tensor.data, 'cpu']; case 'gpu-buffer': return [tensor.type, tensor.dims, { gpuBuffer: tensor.gpuBuffer }, 'gpu-buffer']; - case 'ml-buffer': - return [tensor.type, tensor.dims, { mlBuffer: tensor.mlBuffer }, 'ml-buffer']; + case 'ml-tensor': + return [tensor.type, tensor.dims, { mlTensor: tensor.mlTensor }, 'ml-tensor']; default: throw new Error(`invalid data location: ${tensor.location} for ${getName()}`); } @@ -41,13 +41,13 @@ export const decodeTensorMetadata = (tensor: TensorMetadata): Tensor => { const { gpuBuffer, download, dispose } = tensor[2]; return Tensor.fromGpuBuffer(gpuBuffer, { dataType, dims: tensor[1], download, dispose }); } - case 'ml-buffer': { + case 'ml-tensor': { const dataType = tensor[0]; - if (!isMLBufferSupportedType(dataType)) { - throw new Error(`not supported data type: ${dataType} for deserializing MLBuffer tensor`); + if (!isMLTensorSupportedType(dataType)) { + throw new Error(`not supported data type: ${dataType} for deserializing MLTensor tensor`); } - const { mlBuffer, download, dispose } = tensor[2]; - return Tensor.fromMLBuffer(mlBuffer, { dataType, dims: tensor[1], download, dispose }); + const { mlTensor, download, dispose } = tensor[2]; + return Tensor.fromMLTensor(mlTensor, { dataType, dims: tensor[1], download, dispose }); } default: throw new Error(`invalid data location: ${tensor[3]}`); diff --git a/js/web/lib/wasm/wasm-common.ts b/js/web/lib/wasm/wasm-common.ts index bb4b323bff8c7..ad2ff62587252 100644 --- a/js/web/lib/wasm/wasm-common.ts +++ b/js/web/lib/wasm/wasm-common.ts @@ -241,9 +241,9 @@ export const isGpuBufferSupportedType = (type: Tensor.Type): type is Tensor.GpuB type === 'int4'; /** - * Check whether the given tensor type is supported by WebNN MLBuffer + * Check whether the given tensor type is supported by WebNN MLTensor */ -export const isMLBufferSupportedType = (type: Tensor.Type): type is Tensor.MLBufferDataTypes => +export const isMLTensorSupportedType = (type: Tensor.Type): type is Tensor.MLTensorDataTypes => type === 'float32' || type === 'float16' || type === 'int32' || @@ -269,7 +269,7 @@ export const dataLocationStringToEnum = (location: Tensor.DataLocation): number return 3; case 'gpu-buffer': return 4; - case 'ml-buffer': + case 'ml-tensor': return 5; default: throw new Error(`unsupported data location: ${location}`); @@ -280,4 +280,4 @@ export const dataLocationStringToEnum = (location: Tensor.DataLocation): number * Map integer data location to string value */ export const dataLocationEnumToString = (location: number): Tensor.DataLocation | undefined => - (['none', 'cpu', 'cpu-pinned', 'texture', 'gpu-buffer', 'ml-buffer'] as const)[location]; + (['none', 'cpu', 'cpu-pinned', 'texture', 'gpu-buffer', 'ml-tensor'] as const)[location]; diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 250c9358fe8d5..f912813b43458 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -20,7 +20,7 @@ import { calculateTensorSizeInBytes, dataLocationStringToEnum, isGpuBufferSupportedType, - isMLBufferSupportedType, + isMLTensorSupportedType, logLevelStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, @@ -164,7 +164,7 @@ export const initEp = async (env: Env, epName: string): Promise => { /** * valid data locations for input/output tensors. */ -type SupportedTensorDataLocationForInputOutput = 'cpu' | 'cpu-pinned' | 'gpu-buffer' | 'ml-buffer'; +type SupportedTensorDataLocationForInputOutput = 'cpu' | 'cpu-pinned' | 'gpu-buffer' | 'ml-tensor'; type IOBindingState = { /** @@ -175,7 +175,7 @@ type IOBindingState = { /** * the preferred location for each output tensor. * - * value is one of 'cpu', 'cpu-pinned', 'gpu-buffer', 'ml-buffer'. + * value is one of 'cpu', 'cpu-pinned', 'gpu-buffer', 'ml-tensor'. */ readonly outputPreferredLocations: readonly SupportedTensorDataLocationForInputOutput[]; @@ -289,7 +289,7 @@ export const createSession = async ( for (const provider of options?.executionProviders ?? []) { const providerName = typeof provider === 'string' ? provider : provider.name; if (providerName === 'webnn') { - wasm.shouldTransferToMLBuffer = false; + wasm.shouldTransferToMLTensor = false; if (wasm.currentContext) { throw new Error('WebNN execution provider is already set.'); } @@ -323,7 +323,7 @@ export const createSession = async ( if (wasm.currentContext) { wasm.jsepRegisterMLContext!(sessionHandle, wasm.currentContext); wasm.currentContext = undefined; - wasm.shouldTransferToMLBuffer = true; + wasm.shouldTransferToMLTensor = true; } const [inputCount, outputCount] = getSessionInputOutputCount(sessionHandle); @@ -359,7 +359,7 @@ export const createSession = async ( typeof options?.preferredOutputLocation === 'string' ? options.preferredOutputLocation : (options?.preferredOutputLocation?.[nameString] ?? 'cpu'); - if (location !== 'cpu' && location !== 'cpu-pinned' && location !== 'gpu-buffer' && location !== 'ml-buffer') { + if (location !== 'cpu' && location !== 'cpu-pinned' && location !== 'gpu-buffer' && location !== 'ml-tensor') { throw new Error(`Not supported preferred output location: ${location}.`); } if (enableGraphCapture && location !== 'gpu-buffer') { @@ -371,9 +371,9 @@ export const createSession = async ( } } - // use IO binding only when at least one output is preffered to be on GPU. + // use IO binding only when at least one output is preferred to be on GPU. let bindingState: IOBindingState | null = null; - if (!BUILD_DEFS.DISABLE_JSEP && outputPreferredLocations.some((l) => l === 'gpu-buffer' || l === 'ml-buffer')) { + if (!BUILD_DEFS.DISABLE_JSEP && outputPreferredLocations.some((l) => l === 'gpu-buffer' || l === 'ml-tensor')) { ioBindingHandle = wasm._OrtCreateBinding(sessionHandle); if (ioBindingHandle === 0) { checkLastError("Can't create IO binding."); @@ -464,7 +464,7 @@ export const prepareInputOutputTensor = ( let rawData: number; let dataByteLength: number; - if (dataType === 'string' && (location === 'gpu-buffer' || location === 'ml-buffer')) { + if (dataType === 'string' && (location === 'gpu-buffer' || location === 'ml-tensor')) { throw new Error('String tensor is not supported on GPU.'); } @@ -483,15 +483,15 @@ export const prepareInputOutputTensor = ( throw new Error('Tensor location "gpu-buffer" is not supported without using WebGPU.'); } rawData = registerBuffer(sessionId, index, gpuBuffer, dataByteLength); - } else if (location === 'ml-buffer') { - const mlBuffer = tensor[2].mlBuffer as MLBuffer; + } else if (location === 'ml-tensor') { + const mlTensor = tensor[2].mlTensor as MLTensor; dataByteLength = calculateTensorSizeInBytes(tensorDataTypeStringToEnum(dataType), dims)!; - const registerMLBuffer = wasm.jsepRegisterMLBuffer; - if (!registerMLBuffer) { - throw new Error('Tensor location "ml-buffer" is not supported without using WebNN.'); + const registerMLTensor = wasm.jsepRegisterMLTensor; + if (!registerMLTensor) { + throw new Error('Tensor location "ml-tensor" is not supported without using WebNN.'); } - rawData = registerMLBuffer(mlBuffer, tensorDataTypeStringToEnum(dataType), dims); + rawData = registerMLTensor(mlTensor, tensorDataTypeStringToEnum(dataType), dims); } else { const data = tensor[2]; @@ -577,7 +577,7 @@ export const run = async ( const outputNamesOffset = wasm.stackAlloc(outputCount * 4); try { - // WebNN backend needs the active session to check MLBuffers with the current context. + // WebNN backend needs the active session to check MLTensors with the current context. wasm.jsepOnRunStart?.(sessionHandle); [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); @@ -742,7 +742,7 @@ export const run = async ( const preferredLocation = ioBindingState?.outputPreferredLocations[outputIndices[i]]; if (type === 'string') { - if (preferredLocation === 'gpu-buffer' || preferredLocation === 'ml-buffer') { + if (preferredLocation === 'gpu-buffer' || preferredLocation === 'ml-tensor') { throw new Error('String tensor is not supported on GPU.'); } const stringData: string[] = []; @@ -782,20 +782,20 @@ export const run = async ( }, 'gpu-buffer', ]); - } else if (preferredLocation === 'ml-buffer' && size > 0) { - const ensureBuffer = wasm.jsepEnsureBuffer; + } else if (preferredLocation === 'ml-tensor' && size > 0) { + const ensureBuffer = wasm.jsepEnsureTensor; if (!ensureBuffer) { - throw new Error('preferredLocation "ml-buffer" is not supported without using WebNN.'); + throw new Error('preferredLocation "ml-tensor" is not supported without using WebNN.'); } const bufferSize = calculateTensorSizeInBytes(dataType, size); - if (bufferSize === undefined || !isMLBufferSupportedType(type)) { + if (bufferSize === undefined || !isMLTensorSupportedType(type)) { throw new Error(`Unsupported data type: ${type}`); } // If the graph has been partitioned, the output tensor may have not been created. For this reason, we use - // ensureBuffer to get/create the MLBuffer. In which case, we don't need to copy the data if a new buffer is + // ensureBuffer to get/create the MLTensor. In which case, we don't need to copy the data if a new buffer is // created. - const mlBuffer = await ensureBuffer(dataOffset, dataType, dims, false); + const mlTensor = await ensureBuffer(dataOffset, dataType, dims, false); // do not release the tensor right now. it will be released when user calls tensor.dispose(). keepOutputTensor = true; @@ -804,14 +804,14 @@ export const run = async ( type, dims, { - mlBuffer, - download: wasm.jsepCreateMLBufferDownloader!(dataOffset, type), + mlTensor, + download: wasm.jsepCreateMLTensorDownloader!(dataOffset, type), dispose: () => { - wasm.jsepReleaseBufferId!(dataOffset); + wasm.jsepReleaseTensorId!(dataOffset); wasm._OrtReleaseTensor(tensor); }, }, - 'ml-buffer', + 'ml-tensor', ]); } else { const typedArrayConstructor = tensorTypeToTypedArrayConstructor(type); diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts index af623788af606..82049d04d9926 100644 --- a/js/web/lib/wasm/wasm-types.ts +++ b/js/web/lib/wasm/wasm-types.ts @@ -28,16 +28,16 @@ export declare namespace JSEP { type CaptureBeginFunction = () => void; type CaptureEndFunction = () => void; type ReplayFunction = () => void; - type ReserveBufferIdFunction = () => number; - type ReleaseBufferIdFunction = (bufferId: number) => void; - type EnsureBufferFunction = ( - bufferId: number, + type ReserveTensorIdFunction = () => number; + type ReleaseTensorIdFunction = (tensorId: number) => void; + type EnsureTensorFunction = ( + tensorId: number, dataType: DataType, dimensions: readonly number[], copyOld: boolean, - ) => Promise; - type UploadBufferFunction = (bufferId: number, data: Uint8Array) => void; - type DownloadBufferFunction = (bufferId: number, dstBuffer: ArrayBufferView | ArrayBuffer) => Promise; + ) => Promise; + type UploadTensorFunction = (tensorId: number, data: Uint8Array) => void; + type DownloadTensorFunction = (tensorId: number, dstBuffer: ArrayBufferView | ArrayBuffer) => Promise; export interface Module extends WebGpuModule, WebNnModule { /** @@ -77,11 +77,11 @@ export declare namespace JSEP { name: 'webnn', initParams: [ backend: BackendType, - reserveBufferId: ReserveBufferIdFunction, - releaseBufferId: ReleaseBufferIdFunction, - ensureBuffer: EnsureBufferFunction, - uploadBuffer: UploadBufferFunction, - downloadBuffer: DownloadBufferFunction, + reserveTensorId: ReserveTensorIdFunction, + releaseTensorId: ReleaseTensorIdFunction, + ensureTensor: EnsureTensorFunction, + uploadTensor: UploadTensorFunction, + downloadTensor: DownloadTensorFunction, ], ): void; } @@ -157,9 +157,9 @@ export declare namespace JSEP { currentContext: MLContext; /** - * Disables creating MLBuffers. This is used to avoid creating MLBuffers for graph initializers. + * Disables creating MLTensors. This is used to avoid creating MLTensors for graph initializers. */ - shouldTransferToMLBuffer: boolean; + shouldTransferToMLTensor: boolean; /** * [exported from pre-jsep.js] Register MLContext for a session. @@ -169,67 +169,62 @@ export declare namespace JSEP { */ jsepRegisterMLContext: (sessionId: number, context: MLContext) => void; /** - * [exported from pre-jsep.js] Reserve a MLBuffer ID attached to the current session. - * @returns the MLBuffer ID. + * [exported from pre-jsep.js] Reserve a MLTensor ID attached to the current session. + * @returns the MLTensor ID. */ - jsepReserveBufferId: () => number; + jsepReserveTensorId: () => number; /** - * [exported from pre-jsep.js] Release a MLBuffer ID from use and destroy buffer if no longer in use. - * @param bufferId - specify the MLBuffer ID. + * [exported from pre-jsep.js] Release a MLTensor ID from use and destroy buffer if no longer in use. + * @param tensorId - specify the MLTensor ID. * @returns */ - jsepReleaseBufferId: (bufferId: number) => void; + jsepReleaseTensorId: (tensorId: number) => void; /** - * [exported from pre-jsep.js] Ensure a MLBuffer of a given type and shape has exists for a buffer ID. - * @param bufferId - specify the buffer ID. + * [exported from pre-jsep.js] Ensure a MLTensor of a given type and shape has exists for a buffer ID. + * @param tensorId - specify the tensor ID. * @param onnxDataType - specify the data type. * @param dimensions - specify the dimensions. - * @param copyOld - specify whether to copy the old buffer if a new buffer was created. - * @returns the MLBuffer associated with the buffer ID. + * @param copyOld - specify whether to copy the old tensor if a new tensor was created. + * @returns the MLTensor associated with the tensor ID. */ - jsepEnsureBuffer: ( - bufferId: number, + jsepEnsureTensor: ( + tensorId: number, dataType: DataType, dimensions: number[], copyOld: boolean, - ) => Promise; + ) => Promise; /** - * [exported from pre-jsep.js] Upload data to MLBuffer. - * @param bufferId - specify the MLBuffer ID. + * [exported from pre-jsep.js] Upload data to MLTensor. + * @param tensorId - specify the MLTensor ID. * @param data - specify the data to upload. It can be a TensorProto::data_type or a WebNN MLOperandDataType. * @param dimensions - specify the dimensions. * @returns */ - jsepUploadBuffer: (bufferId: number, data: Uint8Array) => void; + jsepUploadTensor: (tensorId: number, data: Uint8Array) => void; /** - * [exported from pre-jsep.js] Download data from MLBuffer. - * @param bufferId - specify the MLBuffer ID. + * [exported from pre-jsep.js] Download data from MLTensor. + * @param tensorId - specify the MLTensor ID. * @returns the downloaded data. */ - jsepDownloadBuffer: (bufferId: number, dstBuffer: ArrayBufferView | ArrayBuffer) => Promise; + jsepDownloadTensor: (tensorId: number, dstBuffer: ArrayBufferView | ArrayBuffer) => Promise; /** - * [exported from pre-jsep.js] Download data from MLBuffer. - * @param bufferId - specify the MLBuffer ID. - * @returns the downloaded data. - */ - /** - * [exported from pre-jsep.js] Create a downloader function to download data from MLBuffer. - * @param bufferId - specify the MLBuffer ID. + * [exported from pre-jsep.js] Create a downloader function to download data from MLTensor. + * @param tensorId - specify the MLTensor ID. * @param type - specify the data type. * @returns the downloader function. */ - jsepCreateMLBufferDownloader: ( - bufferId: number, - type: Tensor.MLBufferDataTypes, - ) => () => Promise; + jsepCreateMLTensorDownloader: ( + tensorId: number, + type: Tensor.MLTensorDataTypes, + ) => () => Promise; /** - * [exported from pre-jsep.js] Register MLBuffer for a session. - * @param mlBuffer - specify the MLBuffer. + * [exported from pre-jsep.js] Register MLTensor for a session. + * @param tensor - specify the MLTensor. * @param dataType - specify the data type. * @param dimensions - specify the dimensions. - * @returns the MLBuffer ID. + * @returns the MLTensor ID. */ - jsepRegisterMLBuffer: (buffer: MLBuffer, onnxDataType: DataType, dimensions: readonly number[]) => number; + jsepRegisterMLTensor: (tensor: MLTensor, onnxDataType: DataType, dimensions: readonly number[]) => number; } } diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts index 6e156c5e17516..e94e11d0ace56 100644 --- a/js/web/script/test-runner-cli-args.ts +++ b/js/web/script/test-runner-cli-args.ts @@ -62,8 +62,8 @@ Options: none (default) gpu-tensor use pre-allocated GPU tensors for inputs and outputs gpu-location use pre-allocated GPU tensors for inputs and set preferredOutputLocation to 'gpu-buffer' - ml-tensor use pre-allocated MLBuffer tensors for inputs and outputs - ml-location use pre-allocated MLBuffer tensors for inputs and set preferredOutputLocation to 'ml-buffer' + ml-tensor use pre-allocated MLTensor tensors for inputs and outputs + ml-location use pre-allocated MLTensor tensors for inputs and set preferredOutputLocation to 'ml-tensor' *** Logging Options *** diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index 4d5ea53c7f860..29183ebd83657 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -24,7 +24,7 @@ import { createView } from '../lib/wasm/jsep/tensor-view'; import { calculateTensorSizeInBytes, isGpuBufferSupportedType, - isMLBufferSupportedType, + isMLTensorSupportedType, tensorDataTypeStringToEnum, } from '../lib/wasm/wasm-common'; @@ -180,7 +180,7 @@ async function initializeSession( if (ioBindingMode === 'gpu-location') { preferredOutputLocation = 'gpu-buffer'; } else if (ioBindingMode === 'ml-location') { - preferredOutputLocation = 'ml-buffer'; + preferredOutputLocation = 'ml-tensor'; } const profilerConfig = profile ? { maxNumberEvents: 65536 } : undefined; @@ -655,44 +655,44 @@ function createGpuTensorForOutput(type: ort.Tensor.Type, dims: readonly number[] } async function createMLTensorForOutput(mlContext: MLContext, type: ort.Tensor.Type, dims: readonly number[]) { - if (!isMLBufferSupportedType(type)) { + if (!isMLTensorSupportedType(type)) { throw new Error(`createMLTensorForOutput can not work with ${type} tensor`); } const dataType = type === 'bool' ? 'uint8' : type; - const mlBuffer = await mlContext.createBuffer({ + const mlTensor = await mlContext.createTensor({ dataType, dimensions: dims as number[], - usage: MLBufferUsage.READ_FROM, + usage: MLTensorUsage.READ_FROM, }); - return ort.Tensor.fromMLBuffer(mlBuffer, { + return ort.Tensor.fromMLTensor(mlTensor, { dataType: type, dims, - dispose: () => mlBuffer.destroy(), + dispose: () => mlTensor.destroy(), download: async () => { - const arrayBuffer = await mlContext.readBuffer(mlBuffer); - return createView(arrayBuffer, type) as ort.Tensor.DataTypeMap[ort.Tensor.MLBufferDataTypes]; + const arrayBuffer = await mlContext.readTensor(mlTensor); + return createView(arrayBuffer, type) as ort.Tensor.DataTypeMap[ort.Tensor.MLTensorDataTypes]; }, }); } async function createMLTensorForInput(mlContext: MLContext, cpuTensor: ort.Tensor): Promise { - if (!isMLBufferSupportedType(cpuTensor.type) || Array.isArray(cpuTensor.data)) { + if (!isMLTensorSupportedType(cpuTensor.type) || Array.isArray(cpuTensor.data)) { throw new Error(`createMLTensorForInput can not work with ${cpuTensor.type} tensor`); } const dataType = cpuTensor.type === 'bool' ? 'uint8' : cpuTensor.type; - const mlBuffer = await mlContext.createBuffer({ + const mlTensor = await mlContext.createTensor({ dataType, dimensions: cpuTensor.dims as number[], - usage: MLBufferUsage.WRITE_TO, + usage: MLTensorUsage.WRITE_TO, }); - mlContext.writeBuffer(mlBuffer, cpuTensor.data); - return ort.Tensor.fromMLBuffer(mlBuffer, { + mlContext.writeTensor(mlTensor, cpuTensor.data); + return ort.Tensor.fromMLTensor(mlTensor, { dataType: cpuTensor.type, dims: cpuTensor.dims, - dispose: () => mlBuffer.destroy(), + dispose: () => mlTensor.destroy(), }); } diff --git a/js/web/test/test-types.ts b/js/web/test/test-types.ts index eddda1206eec9..29a11f969ffea 100644 --- a/js/web/test/test-types.ts +++ b/js/web/test/test-types.ts @@ -53,8 +53,8 @@ export declare namespace Test { * - gpu-tensor: inputs and outputs will all be pre-allocated as GPU tensors. `preferredOutputLocation` * will not be set. * - ml-location: inputs will be pre-allocated as ML tensors; no output will be pre-allocated; - * `preferredOutputLocation` will be set to `ml-buffer`. - * - ml-tensor: inputs and outputs will all be pre-allocated as MLBuffer tensors. `preferredOutputLocation` + * `preferredOutputLocation` will be set to `ml-tensor`. + * - ml-tensor: inputs and outputs will all be pre-allocated as MLTensor tensors. `preferredOutputLocation` * will not be set. */ export type IOBindingMode = 'none' | 'gpu-tensor' | 'gpu-location' | 'ml-tensor' | 'ml-location'; diff --git a/onnxruntime/core/providers/webnn/allocator.cc b/onnxruntime/core/providers/webnn/allocator.cc index 4b8188a6f8344..6e83386eee7c0 100644 --- a/onnxruntime/core/providers/webnn/allocator.cc +++ b/onnxruntime/core/providers/webnn/allocator.cc @@ -12,11 +12,11 @@ void* WebNNBufferAllocator::Alloc(size_t size) { if (size == 0) { return nullptr; } - if (!emscripten::val::module_property("shouldTransferToMLBuffer").as()) { - // We don't need to transfer the buffer to an MLBuffer, so we don't need to allocate buffer id. + if (!emscripten::val::module_property("shouldTransferToMLTensor").as()) { + // We don't need to transfer the buffer to an MLTensor, so we don't need to allocate buffer id. return nullptr; } - void* p = EM_ASM_PTR({ return Module.jsepReserveBufferId(); }); + void* p = EM_ASM_PTR({ return Module.jsepReserveTensorId(); }); allocations_[p] = size; stats_.num_allocs++; stats_.bytes_in_use += SafeInt(size); @@ -27,7 +27,7 @@ void WebNNBufferAllocator::Free(void* p) { if (p == nullptr) { return; } - EM_ASM({ Module.jsepReleaseBufferId($0); }, p); + EM_ASM({ Module.jsepReleaseTensorId($0); }, p); size_t size = allocations_[p]; stats_.bytes_in_use -= size; allocations_.erase(p); diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index 22271640ef57f..7f1a591234f4b 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -211,8 +211,8 @@ bool SetWebnnDataType(emscripten::val& desc, const int32_t data_type) { } } -bool IsMLBufferSupported() { - static bool is_supported = !emscripten::val::global("MLBuffer").isUndefined(); +bool IsMLTensorSupported() { + static bool is_supported = !emscripten::val::global("MLTensor").isUndefined(); return is_supported; } diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index dc9ab5ac7f70a..8cd3cb65e707b 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -286,7 +286,7 @@ bool GetBidirectionalBroadcastShape(std::vector& shape_a, bool SetWebnnDataType(emscripten::val& desc, const int32_t data_type); -bool IsMLBufferSupported(); +bool IsMLTensorSupported(); } // namespace webnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/model.cc b/onnxruntime/core/providers/webnn/builders/model.cc index a8ded83aac90b..499ab60a65e23 100644 --- a/onnxruntime/core/providers/webnn/builders/model.cc +++ b/onnxruntime/core/providers/webnn/builders/model.cc @@ -154,7 +154,7 @@ onnxruntime::common::Status Model::Compute(const InlinedHashMap& inputs, const InlinedHashMap& outputs) { - auto jsepEnsureBuffer = emscripten::val::module_property("jsepEnsureBuffer"); + auto jsepEnsureTensor = emscripten::val::module_property("jsepEnsureTensor"); auto promises = emscripten::val::array(); for (const auto& [_, tensor] : inputs) { emscripten::val shape = emscripten::val::array(); @@ -162,7 +162,7 @@ onnxruntime::common::Status Model::Dispatch(const InlinedHashMap(dim); shape.call("push", dim_val); } - auto buffer = jsepEnsureBuffer(reinterpret_cast(tensor.buffer), tensor.tensor_info.data_type, shape, true); + auto buffer = jsepEnsureTensor(reinterpret_cast(tensor.buffer), tensor.tensor_info.data_type, shape, true); promises.call("push", buffer); } for (const auto& [_, tensor] : outputs) { @@ -171,7 +171,7 @@ onnxruntime::common::Status Model::Dispatch(const InlinedHashMap(dim); shape.call("push", dim_val); } - auto buffer = jsepEnsureBuffer(reinterpret_cast(tensor.buffer), tensor.tensor_info.data_type, shape, false); + auto buffer = jsepEnsureTensor(reinterpret_cast(tensor.buffer), tensor.tensor_info.data_type, shape, false); promises.call("push", buffer); } auto buffers = emscripten::val::global("Promise").call("all", promises).await(); @@ -200,7 +200,7 @@ void Model::SetOutputMap(InlinedHashMap&& output_map) { // Pre-allocate the input and output buffers for the WebNN graph. void Model::AllocateInputOutputBuffers() { - // We don't need to allocate JS array buffers if the WebNN API supports MLBuffer. + // We don't need to allocate JS array buffers if the WebNN API supports MLTensor. if (use_dispatch_) { return; } diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index 86c2b9ec81b65..fd7e05d4a3a84 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -333,7 +333,7 @@ Status ModelBuilder::Compile(std::unique_ptr& model) { } // Explicitly release the WebNN builder to free memory. wnn_builder_ = emscripten::val::undefined(); - model.reset(new Model(std::move(wnn_context_), std::move(wnn_graph), logger_, IsMLBufferSupported())); + model.reset(new Model(std::move(wnn_context_), std::move(wnn_graph), logger_, IsMLTensorSupported())); model->SetInputs(std::move(input_names_)); model->SetOutputs(std::move(output_names_)); model->SetInputOutputInfo(std::move(input_output_info_)); diff --git a/onnxruntime/core/providers/webnn/data_transfer.cc b/onnxruntime/core/providers/webnn/data_transfer.cc index 833b5ecb7838c..a84c389c2fdaf 100644 --- a/onnxruntime/core/providers/webnn/data_transfer.cc +++ b/onnxruntime/core/providers/webnn/data_transfer.cc @@ -10,14 +10,14 @@ namespace onnxruntime { namespace webnn { bool DataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { - // Copying data between MLBuffers is not supported by WebNN. + // Copying data between MLTensors is not supported by WebNN. return (dst_device.Type() == OrtDevice::GPU && src_device.Type() == OrtDevice::CPU) || (dst_device.Type() == OrtDevice::CPU && src_device.Type() == OrtDevice::GPU); } common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { - if (!emscripten::val::module_property("shouldTransferToMLBuffer").as()) { - // We don't need to transfer the buffer to an MLBuffer, so we don't need to copy the buffer. + if (!emscripten::val::module_property("shouldTransferToMLTensor").as()) { + // We don't need to transfer the buffer to an MLTensor, so we don't need to copy the buffer. return Status::OK(); } @@ -29,11 +29,11 @@ common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { const auto& dst_device = dst.Location().device; if (dst_device.Type() == OrtDevice::GPU) { - EM_ASM({ Module.jsepUploadBuffer($0, HEAPU8.subarray($1, $1 + $2)); }, dst_data, reinterpret_cast(src_data), bytes); + EM_ASM({ Module.jsepUploadTensor($0, HEAPU8.subarray($1, $1 + $2)); }, dst_data, reinterpret_cast(src_data), bytes); } else { - auto jsepDownloadBuffer = emscripten::val::module_property("jsepDownloadBuffer"); + auto jsepDownloadTensor = emscripten::val::module_property("jsepDownloadTensor"); auto subarray = emscripten::typed_memory_view(bytes, static_cast(dst_data)); - jsepDownloadBuffer(reinterpret_cast(src_data), subarray).await(); + jsepDownloadTensor(reinterpret_cast(src_data), subarray).await(); } } diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index 6ff4bde4c3eba..7a4440919ae3e 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -23,9 +23,9 @@ namespace onnxruntime { WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_flags) : IExecutionProvider{ onnxruntime::kWebNNExecutionProvider, - // If MLBuffer is supported, we force all the tensors to be allocated as MLBuffer. + // If MLTensor is supported, we force all the tensors to be allocated as MLTensor. OrtDevice( - webnn::IsMLBufferSupported() ? OrtDevice::GPU : OrtDevice::CPU, + webnn::IsMLTensorSupported() ? OrtDevice::GPU : OrtDevice::CPU, OrtDevice::MemType::DEFAULT, 0)}, wnn_device_type_(webnn::DeviceTypeFromString(webnn_device_flags)) { @@ -398,14 +398,14 @@ WebNNExecutionProvider::GetKernelRegistry() const { } std::unique_ptr WebNNExecutionProvider::GetDataTransfer() const { - if (!webnn::IsMLBufferSupported()) { + if (!webnn::IsMLTensorSupported()) { return nullptr; } return std::make_unique(); } std::vector WebNNExecutionProvider::CreatePreferredAllocators() { - if (!webnn::IsMLBufferSupported()) { + if (!webnn::IsMLTensorSupported()) { return {}; } AllocatorCreationInfo customAllocatorCreationInfo([&](OrtDevice::DeviceId) { diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index f84af6c1a2325..4bba0148a93ff 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -24,7 +24,7 @@ enum DataLocation { DATA_LOCATION_CPU_PINNED = 2, DATA_LOCATION_TEXTURE = 3, DATA_LOCATION_GPU_BUFFER = 4, - DATA_LOCATION_ML_BUFFER = 5 + DATA_LOCATION_ML_TENSOR = 5 }; static_assert(sizeof(const char*) == sizeof(size_t), "size of a pointer and a size_t value should be the same."); @@ -237,7 +237,7 @@ OrtValue* OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* if (data_location != DATA_LOCATION_CPU && data_location != DATA_LOCATION_CPU_PINNED && data_location != DATA_LOCATION_GPU_BUFFER && - data_location != DATA_LOCATION_ML_BUFFER) { + data_location != DATA_LOCATION_ML_TENSOR) { std::ostringstream ostr; ostr << "Invalid data location: " << data_location; CheckStatus(Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, ostr.str().c_str())); @@ -270,7 +270,7 @@ OrtValue* OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* case DATA_LOCATION_GPU_BUFFER: RETURN_NULLPTR_IF_ERROR(CreateMemoryInfo, "WebGPU_Buffer", OrtDeviceAllocator, 0, OrtMemTypeDefault, &memory_info); break; - case DATA_LOCATION_ML_BUFFER: + case DATA_LOCATION_ML_TENSOR: RETURN_NULLPTR_IF_ERROR(CreateMemoryInfo, "WebNN_Buffer", OrtDeviceAllocator, 0, OrtMemTypeDefault, &memory_info); break; default: @@ -426,16 +426,16 @@ int EMSCRIPTEN_KEEPALIVE OrtBindOutput(OrtIoBinding* io_binding, output_location != DATA_LOCATION_CPU && output_location != DATA_LOCATION_CPU_PINNED && output_location != DATA_LOCATION_GPU_BUFFER && - output_location != DATA_LOCATION_ML_BUFFER) { + output_location != DATA_LOCATION_ML_TENSOR) { std::ostringstream ostr; ostr << "Invalid data location (" << output_location << ") for output: \"" << name << "\"."; return CheckStatus(Ort::GetApi().CreateStatus(ORT_INVALID_ARGUMENT, ostr.str().c_str())); } OrtMemoryInfo* memory_info = nullptr; - if (output_location != DATA_LOCATION_GPU_BUFFER && output_location != DATA_LOCATION_ML_BUFFER) { + if (output_location != DATA_LOCATION_GPU_BUFFER && output_location != DATA_LOCATION_ML_TENSOR) { RETURN_ERROR_CODE_IF_ERROR(CreateCpuMemoryInfo, OrtDeviceAllocator, OrtMemTypeDefault, &memory_info); - } else if (output_location == DATA_LOCATION_ML_BUFFER) { + } else if (output_location == DATA_LOCATION_ML_TENSOR) { RETURN_ERROR_CODE_IF_ERROR(CreateMemoryInfo, "WebNN_Buffer", OrtDeviceAllocator, 0, OrtMemTypeDefault, &memory_info); } else { RETURN_ERROR_CODE_IF_ERROR(CreateMemoryInfo, "WebGPU_Buffer", OrtDeviceAllocator, 0, OrtMemTypeDefault, &memory_info); diff --git a/onnxruntime/wasm/pre-jsep.js b/onnxruntime/wasm/pre-jsep.js index 4aef3e361e9dc..f5aa3394213cf 100644 --- a/onnxruntime/wasm/pre-jsep.js +++ b/onnxruntime/wasm/pre-jsep.js @@ -208,15 +208,15 @@ Module['jsepInit'] = (name, params) => { // change the name. [Module.jsepBackend, - Module.jsepReserveBufferId, - Module.jsepReleaseBufferId, - Module['jsepEnsureBuffer'], - Module.jsepUploadBuffer, - Module['jsepDownloadBuffer'], + Module.jsepReserveTensorId, + Module.jsepReleaseTensorId, + Module['jsepEnsureTensor'], + Module.jsepUploadTensor, + Module['jsepDownloadTensor'], ] = params; // This function is called from both JS and an EM_ASM block, it needs both a minifiable name and an explicit name. - Module['jsepReleaseBufferId'] = Module.jsepReleaseBufferId; + Module['jsepReleaseTensorId'] = Module.jsepReleaseTensorId; // Functions called from JS also need to have explicit names. const backend = Module.jsepBackend; @@ -229,11 +229,11 @@ Module['jsepInit'] = (name, params) => { Module['jsepOnReleaseSession'] = sessionId => { backend['onReleaseSession'](sessionId); }; - Module['jsepCreateMLBufferDownloader'] = (bufferId, type) => { - return backend['createMLBufferDownloader'](bufferId, type); + Module['jsepCreateMLTensorDownloader'] = (tensorId, type) => { + return backend['createMLTensorDownloader'](tensorId, type); } - Module['jsepRegisterMLBuffer'] = (buffer, dataType, dimensions) => { - return backend['registerMLBuffer'](buffer, dataType, dimensions); + Module['jsepRegisterMLTensor'] = (tensor, dataType, dimensions) => { + return backend['registerMLTensor'](tensor, dataType, dimensions); } } }; From 5e7f912c828c028803041495693e4ff319bd6390 Mon Sep 17 00:00:00 2001 From: Enrico Galli Date: Mon, 9 Sep 2024 22:22:33 -0700 Subject: [PATCH 07/17] Missed a few renames --- onnxruntime/core/framework/allocator.cc | 2 +- onnxruntime/core/providers/webnn/allocator.cc | 6 +++--- onnxruntime/core/providers/webnn/allocator.h | 4 ++-- onnxruntime/core/providers/webnn/builders/model.cc | 14 +++++++------- .../core/providers/webnn/builders/model_builder.cc | 2 ++ .../providers/webnn/webnn_execution_provider.cc | 6 +++--- onnxruntime/wasm/api.cc | 4 ++-- 7 files changed, 20 insertions(+), 18 deletions(-) diff --git a/onnxruntime/core/framework/allocator.cc b/onnxruntime/core/framework/allocator.cc index 7bd9f64e5603f..04f027b0fb903 100644 --- a/onnxruntime/core/framework/allocator.cc +++ b/onnxruntime/core/framework/allocator.cc @@ -142,7 +142,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateMemoryInfo, _In_ const char* name1, enum OrtA strcmp(name1, onnxruntime::DML) == 0 || strcmp(name1, onnxruntime::HIP) == 0 || strcmp(name1, onnxruntime::WEBGPU_BUFFER) == 0 || - strcmp(name1, onnxruntime::WEBNN_BUFFER) == 0) { + strcmp(name1, onnxruntime::WEBNN_TENSOR) == 0) { *out = new OrtMemoryInfo( name1, type, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, static_cast(id1)), id1, mem_type1); diff --git a/onnxruntime/core/providers/webnn/allocator.cc b/onnxruntime/core/providers/webnn/allocator.cc index 6e83386eee7c0..91e20f51cfa08 100644 --- a/onnxruntime/core/providers/webnn/allocator.cc +++ b/onnxruntime/core/providers/webnn/allocator.cc @@ -8,7 +8,7 @@ namespace onnxruntime { namespace webnn { -void* WebNNBufferAllocator::Alloc(size_t size) { +void* WebNNTensorAllocator::Alloc(size_t size) { if (size == 0) { return nullptr; } @@ -23,7 +23,7 @@ void* WebNNBufferAllocator::Alloc(size_t size) { return p; } -void WebNNBufferAllocator::Free(void* p) { +void WebNNTensorAllocator::Free(void* p) { if (p == nullptr) { return; } @@ -33,7 +33,7 @@ void WebNNBufferAllocator::Free(void* p) { allocations_.erase(p); } -void WebNNBufferAllocator::GetStats(AllocatorStats* stats) { +void WebNNTensorAllocator::GetStats(AllocatorStats* stats) { *stats = stats_; } diff --git a/onnxruntime/core/providers/webnn/allocator.h b/onnxruntime/core/providers/webnn/allocator.h index 6d9fd2c0542e2..c06da909801cc 100644 --- a/onnxruntime/core/providers/webnn/allocator.h +++ b/onnxruntime/core/providers/webnn/allocator.h @@ -13,9 +13,9 @@ namespace onnxruntime { namespace webnn { -class WebNNBufferAllocator : public IAllocator { +class WebNNTensorAllocator : public IAllocator { public: - WebNNBufferAllocator() : IAllocator(OrtMemoryInfo(WEBNN_BUFFER, OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0), 0, OrtMemTypeDefault)) {} + WebNNTensorAllocator() : IAllocator(OrtMemoryInfo(WEBNN_TENSOR, OrtAllocatorType::OrtDeviceAllocator, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0), 0, OrtMemTypeDefault)) {} void* Alloc(size_t size) override; diff --git a/onnxruntime/core/providers/webnn/builders/model.cc b/onnxruntime/core/providers/webnn/builders/model.cc index 499ab60a65e23..9642173f904fe 100644 --- a/onnxruntime/core/providers/webnn/builders/model.cc +++ b/onnxruntime/core/providers/webnn/builders/model.cc @@ -162,8 +162,8 @@ onnxruntime::common::Status Model::Dispatch(const InlinedHashMap(dim); shape.call("push", dim_val); } - auto buffer = jsepEnsureTensor(reinterpret_cast(tensor.buffer), tensor.tensor_info.data_type, shape, true); - promises.call("push", buffer); + auto ml_tensor = jsepEnsureTensor(reinterpret_cast(tensor.buffer), tensor.tensor_info.data_type, shape, true); + promises.call("push", ml_tensor); } for (const auto& [_, tensor] : outputs) { emscripten::val shape = emscripten::val::array(); @@ -171,15 +171,15 @@ onnxruntime::common::Status Model::Dispatch(const InlinedHashMap(dim); shape.call("push", dim_val); } - auto buffer = jsepEnsureTensor(reinterpret_cast(tensor.buffer), tensor.tensor_info.data_type, shape, false); - promises.call("push", buffer); + auto ml_tensor = jsepEnsureTensor(reinterpret_cast(tensor.buffer), tensor.tensor_info.data_type, shape, false); + promises.call("push", ml_tensor); } - auto buffers = emscripten::val::global("Promise").call("all", promises).await(); + auto ml_tensors = emscripten::val::global("Promise").call("all", promises).await(); for (const auto& [name, _] : inputs) { - wnn_inputs_.set(name, buffers.call("shift")); + wnn_inputs_.set(name, ml_tensors.call("shift")); } for (const auto& [name, _] : outputs) { - wnn_outputs_.set(name, buffers.call("shift")); + wnn_outputs_.set(name, ml_tensors.call("shift")); } wnn_context_.call("dispatch", wnn_graph_, wnn_inputs_, wnn_outputs_); diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index fd7e05d4a3a84..65b75ba5a87fa 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -215,6 +215,8 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i ORT_RETURN_IF_NOT(SetWebnnDataType(desc, data_type), "Unsupported data type"); } + emscripten::val::module_property("reserveTensor")(name, desc, is_input); + if (is_input) { wnn_operands_.insert(std::make_pair(name, wnn_builder_.call("input", name, desc))); input_names_.push_back(name); diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index 7a4440919ae3e..f67b359b3c28b 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -325,7 +325,7 @@ class WebNNMemcpy : public OpKernel { explicit WebNNMemcpy(const OpKernelInfo& info) : OpKernel(info) {} Status Compute(OpKernelContext* context) const override { - auto jsepEnsureBuffer = emscripten::val::module_property("jsepEnsureBuffer"); + auto jsepEnsureTensor = emscripten::val::module_property("jsepEnsureTensor"); const auto* X = context->Input(0); ORT_ENFORCE(X != nullptr, "Memcpy: input tensor is null"); auto* Y = context->Output(0, X->Shape()); @@ -335,7 +335,7 @@ class WebNNMemcpy : public OpKernel { shape.call("push", SafeInt(dim).Ref()); } - jsepEnsureBuffer(reinterpret_cast(Y->MutableDataRaw()), + jsepEnsureTensor(reinterpret_cast(Y->MutableDataRaw()), Y->GetElementType(), shape, false) .await(); @@ -409,7 +409,7 @@ std::vector WebNNExecutionProvider::CreatePreferredAllocators() { return {}; } AllocatorCreationInfo customAllocatorCreationInfo([&](OrtDevice::DeviceId) { - return std::make_unique(); + return std::make_unique(); }, 0, false); return {CreateAllocator(customAllocatorCreationInfo)}; diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index 4bba0148a93ff..5173125cb8634 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -271,7 +271,7 @@ OrtValue* OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* RETURN_NULLPTR_IF_ERROR(CreateMemoryInfo, "WebGPU_Buffer", OrtDeviceAllocator, 0, OrtMemTypeDefault, &memory_info); break; case DATA_LOCATION_ML_TENSOR: - RETURN_NULLPTR_IF_ERROR(CreateMemoryInfo, "WebNN_Buffer", OrtDeviceAllocator, 0, OrtMemTypeDefault, &memory_info); + RETURN_NULLPTR_IF_ERROR(CreateMemoryInfo, "WebNN_Tensor", OrtDeviceAllocator, 0, OrtMemTypeDefault, &memory_info); break; default: RETURN_NULLPTR_IF_ERROR(CreateCpuMemoryInfo, OrtDeviceAllocator, OrtMemTypeDefault, &memory_info); @@ -436,7 +436,7 @@ int EMSCRIPTEN_KEEPALIVE OrtBindOutput(OrtIoBinding* io_binding, if (output_location != DATA_LOCATION_GPU_BUFFER && output_location != DATA_LOCATION_ML_TENSOR) { RETURN_ERROR_CODE_IF_ERROR(CreateCpuMemoryInfo, OrtDeviceAllocator, OrtMemTypeDefault, &memory_info); } else if (output_location == DATA_LOCATION_ML_TENSOR) { - RETURN_ERROR_CODE_IF_ERROR(CreateMemoryInfo, "WebNN_Buffer", OrtDeviceAllocator, 0, OrtMemTypeDefault, &memory_info); + RETURN_ERROR_CODE_IF_ERROR(CreateMemoryInfo, "WebNN_Tensor", OrtDeviceAllocator, 0, OrtMemTypeDefault, &memory_info); } else { RETURN_ERROR_CODE_IF_ERROR(CreateMemoryInfo, "WebGPU_Buffer", OrtDeviceAllocator, 0, OrtMemTypeDefault, &memory_info); } From a76acd4b913823e80de5cefdc3af868df10dadf3 Mon Sep 17 00:00:00 2001 From: Enrico Galli Date: Mon, 9 Sep 2024 22:24:30 -0700 Subject: [PATCH 08/17] Missing change to allocator.h --- include/onnxruntime/core/framework/allocator.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/onnxruntime/core/framework/allocator.h b/include/onnxruntime/core/framework/allocator.h index 17d8d804d4ae3..a2566c004bbf1 100644 --- a/include/onnxruntime/core/framework/allocator.h +++ b/include/onnxruntime/core/framework/allocator.h @@ -51,7 +51,7 @@ constexpr const char* HIP_PINNED = "HipPinned"; constexpr const char* OpenVINO_CPU = "OpenVINO_CPU"; constexpr const char* OpenVINO_GPU = "OpenVINO_GPU"; constexpr const char* WEBGPU_BUFFER = "WebGPU_Buffer"; -constexpr const char* WEBNN_BUFFER = "WebNN_Buffer"; +constexpr const char* WEBNN_TENSOR = "WebNN_Tensor"; constexpr size_t kAllocAlignment = 256; From f68de675b8fbabae5ca0f16969f77b14ba7996b1 Mon Sep 17 00:00:00 2001 From: Enrico Galli Date: Mon, 9 Sep 2024 23:12:55 -0700 Subject: [PATCH 09/17] Removing extra call to backend function --- onnxruntime/core/providers/webnn/builders/model_builder.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index 65b75ba5a87fa..fd7e05d4a3a84 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -215,8 +215,6 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i ORT_RETURN_IF_NOT(SetWebnnDataType(desc, data_type), "Unsupported data type"); } - emscripten::val::module_property("reserveTensor")(name, desc, is_input); - if (is_input) { wnn_operands_.insert(std::make_pair(name, wnn_builder_.call("input", name, desc))); input_names_.push_back(name); From 5e713ba663b16034b5b57c175d0188b98c771529 Mon Sep 17 00:00:00 2001 From: Enrico Galli Date: Tue, 10 Sep 2024 10:29:09 -0700 Subject: [PATCH 10/17] More renames --- js/common/lib/tensor-factory.ts | 6 +++--- js/common/lib/tensor-impl.ts | 2 +- js/web/lib/wasm/jsep/backend-webnn.ts | 4 ++-- js/web/lib/wasm/jsep/init.ts | 22 +++++++++++----------- js/web/lib/wasm/jsep/webnn/webnn.d.ts | 4 ++-- js/web/lib/wasm/wasm-core-impl.ts | 12 ++++++------ js/web/lib/wasm/wasm-types.ts | 14 +++++++------- 7 files changed, 32 insertions(+), 32 deletions(-) diff --git a/js/common/lib/tensor-factory.ts b/js/common/lib/tensor-factory.ts index 784e0d1a609d4..f66684112623e 100644 --- a/js/common/lib/tensor-factory.ts +++ b/js/common/lib/tensor-factory.ts @@ -95,7 +95,7 @@ export interface MLTensorConstructorParameters( - buffer: Tensor.MLTensorType, + tensor: Tensor.MLTensorType, options: TensorFromMLTensorOptions, ): TypedTensor; diff --git a/js/common/lib/tensor-impl.ts b/js/common/lib/tensor-impl.ts index e798977d6a92f..0680a15365f42 100644 --- a/js/common/lib/tensor-impl.ts +++ b/js/common/lib/tensor-impl.ts @@ -88,7 +88,7 @@ export class Tensor implements TensorInterface { constructor(params: GpuBufferConstructorParameters); /** - * Construct a new tensor object from the WebNN buffer with the given type and dims. + * Construct a new tensor object from the WebNN MLTensor with the given type and dims. * * Tensor's location will be set to 'ml-tensor'. * diff --git a/js/web/lib/wasm/jsep/backend-webnn.ts b/js/web/lib/wasm/jsep/backend-webnn.ts index 2caa840ed03b5..685f3dc019461 100644 --- a/js/web/lib/wasm/jsep/backend-webnn.ts +++ b/js/web/lib/wasm/jsep/backend-webnn.ts @@ -131,7 +131,7 @@ export class WebNNBackend { if (!wasm.shouldTransferToMLTensor) { throw new Error('Trying to upload to a MLTensor while shouldTransferToMLTensor is false'); } - LOG_DEBUG('verbose', () => `[WebNN] uploadBuffer {tensorId: ${tensorId}, data: ${data.byteLength}}`); + LOG_DEBUG('verbose', () => `[WebNN] uploadTensor {tensorId: ${tensorId}, data: ${data.byteLength}}`); this.tensorManager.upload(tensorId, data); } @@ -158,7 +158,7 @@ export class WebNNBackend { () => `[WebNN] registerMLTensor {tensor: ${tensor}, dataType: ${webnnDataType}, dimensions: ${ dimensions - }} -> {bufferId: ${id}}`, + }} -> {tensorId: ${id}}`, ); return id; } diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index b6964064e8d63..71a48d7b74f1f 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -270,19 +270,19 @@ export const init = async ( const backend = new WebNNBackend(env); jsepInit('webnn', [ backend, - // jsepReserveBufferId + // jsepReserveTensorId () => backend.reserveTensorId(), - // jsepReleaseBufferId, - (bufferId: number) => backend.releaseTensorId(bufferId), - // jsepEnsureBuffer - async (bufferId: number, onnxDataType: number, dimensions: number[], copyOld) => - backend.ensureTensor(bufferId, onnxDataType, dimensions, copyOld), - // jsepUploadBuffer - (bufferId: number, data: Uint8Array) => { - backend.uploadTensor(bufferId, data); + // jsepReleaseTensorId, + (tensorId: number) => backend.releaseTensorId(tensorId), + // jsepEnsureTensor + async (tensorId: number, onnxDataType: number, dimensions: number[], copyOld) => + backend.ensureTensor(tensorId, onnxDataType, dimensions, copyOld), + // jsepUploadTensor + (tensorId: number, data: Uint8Array) => { + backend.uploadTensor(tensorId, data); }, - // jsepDownloadBuffer - async (bufferId: number, dstBuffer: ArrayBufferView | ArrayBuffer) => backend.downloadTensor(bufferId, dstBuffer), + // jsepDownloadTensor + async (tensorId: number, dstBuffer: ArrayBufferView | ArrayBuffer) => backend.downloadTensor(tensorId, dstBuffer), ]); } }; diff --git a/js/web/lib/wasm/jsep/webnn/webnn.d.ts b/js/web/lib/wasm/jsep/webnn/webnn.d.ts index 386b850b7e221..d6f93357512e4 100644 --- a/js/web/lib/wasm/jsep/webnn/webnn.d.ts +++ b/js/web/lib/wasm/jsep/webnn/webnn.d.ts @@ -387,7 +387,7 @@ interface MLTensor { destroy(): void; } -type MLNamedBuffers = Record; +type MLNamedTensor = Record; type MLTensorUsageFlags = number; @@ -408,5 +408,5 @@ interface MLContext { srcElementSize?: number): void; readTensor(sourceTensor: MLTensor): Promise; readTensor(sourceTensor: MLTensor, destinationData: ArrayBufferView|ArrayBuffer): Promise; - dispatch(graph: MLGraph, inputs: MLNamedBuffers, outputs: MLNamedBuffers): void; + dispatch(graph: MLGraph, inputs: MLNamedTensor, outputs: MLNamedTensor): void; } diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index f912813b43458..296da79a0ce5b 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -783,19 +783,19 @@ export const run = async ( 'gpu-buffer', ]); } else if (preferredLocation === 'ml-tensor' && size > 0) { - const ensureBuffer = wasm.jsepEnsureTensor; - if (!ensureBuffer) { + const ensureTensor = wasm.jsepEnsureTensor; + if (!ensureTensor) { throw new Error('preferredLocation "ml-tensor" is not supported without using WebNN.'); } - const bufferSize = calculateTensorSizeInBytes(dataType, size); - if (bufferSize === undefined || !isMLTensorSupportedType(type)) { + const tensorSize = calculateTensorSizeInBytes(dataType, size); + if (tensorSize === undefined || !isMLTensorSupportedType(type)) { throw new Error(`Unsupported data type: ${type}`); } // If the graph has been partitioned, the output tensor may have not been created. For this reason, we use - // ensureBuffer to get/create the MLTensor. In which case, we don't need to copy the data if a new buffer is + // ensureTensor to get/create the MLTensor. In which case, we don't need to copy the data if a new buffer is // created. - const mlTensor = await ensureBuffer(dataOffset, dataType, dims, false); + const mlTensor = await ensureTensor(dataOffset, dataType, dims, false); // do not release the tensor right now. it will be released when user calls tensor.dispose(). keepOutputTensor = true; diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts index 82049d04d9926..e85a87d567777 100644 --- a/js/web/lib/wasm/wasm-types.ts +++ b/js/web/lib/wasm/wasm-types.ts @@ -174,14 +174,14 @@ export declare namespace JSEP { */ jsepReserveTensorId: () => number; /** - * [exported from pre-jsep.js] Release a MLTensor ID from use and destroy buffer if no longer in use. + * [exported from pre-jsep.js] Release an MLTensor ID from use and destroys underlying MLTensor if no longer in use. * @param tensorId - specify the MLTensor ID. * @returns */ jsepReleaseTensorId: (tensorId: number) => void; /** - * [exported from pre-jsep.js] Ensure a MLTensor of a given type and shape has exists for a buffer ID. - * @param tensorId - specify the tensor ID. + * [exported from pre-jsep.js] Ensure that an MLTensor of a given type and shape exists for a MLTensor ID. + * @param tensorId - specify the MLTensor ID. * @param onnxDataType - specify the data type. * @param dimensions - specify the dimensions. * @param copyOld - specify whether to copy the old tensor if a new tensor was created. @@ -194,7 +194,7 @@ export declare namespace JSEP { copyOld: boolean, ) => Promise; /** - * [exported from pre-jsep.js] Upload data to MLTensor. + * [exported from pre-jsep.js] Upload data to an MLTensor. * @param tensorId - specify the MLTensor ID. * @param data - specify the data to upload. It can be a TensorProto::data_type or a WebNN MLOperandDataType. * @param dimensions - specify the dimensions. @@ -202,13 +202,13 @@ export declare namespace JSEP { */ jsepUploadTensor: (tensorId: number, data: Uint8Array) => void; /** - * [exported from pre-jsep.js] Download data from MLTensor. + * [exported from pre-jsep.js] Download data from an MLTensor. * @param tensorId - specify the MLTensor ID. * @returns the downloaded data. */ jsepDownloadTensor: (tensorId: number, dstBuffer: ArrayBufferView | ArrayBuffer) => Promise; /** - * [exported from pre-jsep.js] Create a downloader function to download data from MLTensor. + * [exported from pre-jsep.js] Creates a downloader function to download data from an MLTensor. * @param tensorId - specify the MLTensor ID. * @param type - specify the data type. * @returns the downloader function. @@ -218,7 +218,7 @@ export declare namespace JSEP { type: Tensor.MLTensorDataTypes, ) => () => Promise; /** - * [exported from pre-jsep.js] Register MLTensor for a session. + * [exported from pre-jsep.js] Registers an external MLTensor to a session. * @param tensor - specify the MLTensor. * @param dataType - specify the data type. * @param dimensions - specify the dimensions. From a43caa4dbfbedb00775ba0ce25d13eaa91234075 Mon Sep 17 00:00:00 2001 From: Enrico Galli Date: Tue, 10 Sep 2024 11:35:28 -0700 Subject: [PATCH 11/17] Missing renames --- js/common/lib/tensor-impl.ts | 2 +- js/web/lib/wasm/wasm-core-impl.ts | 4 ++-- js/web/lib/wasm/wasm-types.ts | 2 +- onnxruntime/core/providers/webnn/allocator.cc | 2 +- onnxruntime/core/providers/webnn/builders/model.cc | 2 +- onnxruntime/core/providers/webnn/data_transfer.cc | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/js/common/lib/tensor-impl.ts b/js/common/lib/tensor-impl.ts index 0680a15365f42..e7b3f3ff35cb1 100644 --- a/js/common/lib/tensor-impl.ts +++ b/js/common/lib/tensor-impl.ts @@ -456,7 +456,7 @@ export class Tensor implements TensorInterface { get mlTensor(): TensorMLTensorType { this.ensureValid(); if (!this.mlTensorData) { - throw new Error('The data is not stored as a WebNN buffer.'); + throw new Error('The data is not stored as a WebNN MLTensor.'); } return this.mlTensorData; } diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 296da79a0ce5b..3027c5ab65871 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -793,8 +793,8 @@ export const run = async ( } // If the graph has been partitioned, the output tensor may have not been created. For this reason, we use - // ensureTensor to get/create the MLTensor. In which case, we don't need to copy the data if a new buffer is - // created. + // ensureTensor to get/create the MLTensor. In which case, we don't need to copy the data if a new tensor + // has been created. const mlTensor = await ensureTensor(dataOffset, dataType, dims, false); // do not release the tensor right now. it will be released when user calls tensor.dispose(). diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts index e85a87d567777..73fa79182a4e1 100644 --- a/js/web/lib/wasm/wasm-types.ts +++ b/js/web/lib/wasm/wasm-types.ts @@ -222,7 +222,7 @@ export declare namespace JSEP { * @param tensor - specify the MLTensor. * @param dataType - specify the data type. * @param dimensions - specify the dimensions. - * @returns the MLTensor ID. + * @returns the MLTensor ID for the external MLTensor. */ jsepRegisterMLTensor: (tensor: MLTensor, onnxDataType: DataType, dimensions: readonly number[]) => number; } diff --git a/onnxruntime/core/providers/webnn/allocator.cc b/onnxruntime/core/providers/webnn/allocator.cc index 91e20f51cfa08..9c5cd651e1f00 100644 --- a/onnxruntime/core/providers/webnn/allocator.cc +++ b/onnxruntime/core/providers/webnn/allocator.cc @@ -13,7 +13,7 @@ void* WebNNTensorAllocator::Alloc(size_t size) { return nullptr; } if (!emscripten::val::module_property("shouldTransferToMLTensor").as()) { - // We don't need to transfer the buffer to an MLTensor, so we don't need to allocate buffer id. + // We don't need to transfer the tensor to an MLTensor, so we don't need to allocate an MLTensor id. return nullptr; } void* p = EM_ASM_PTR({ return Module.jsepReserveTensorId(); }); diff --git a/onnxruntime/core/providers/webnn/builders/model.cc b/onnxruntime/core/providers/webnn/builders/model.cc index 9642173f904fe..349dd6b71381d 100644 --- a/onnxruntime/core/providers/webnn/builders/model.cc +++ b/onnxruntime/core/providers/webnn/builders/model.cc @@ -200,7 +200,7 @@ void Model::SetOutputMap(InlinedHashMap&& output_map) { // Pre-allocate the input and output buffers for the WebNN graph. void Model::AllocateInputOutputBuffers() { - // We don't need to allocate JS array buffers if the WebNN API supports MLTensor. + // We don't need to allocate JS ArrayBuffers if the WebNN API supports MLTensor. if (use_dispatch_) { return; } diff --git a/onnxruntime/core/providers/webnn/data_transfer.cc b/onnxruntime/core/providers/webnn/data_transfer.cc index a84c389c2fdaf..d98407eb3bdfb 100644 --- a/onnxruntime/core/providers/webnn/data_transfer.cc +++ b/onnxruntime/core/providers/webnn/data_transfer.cc @@ -17,7 +17,7 @@ bool DataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_dev common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { if (!emscripten::val::module_property("shouldTransferToMLTensor").as()) { - // We don't need to transfer the buffer to an MLTensor, so we don't need to copy the buffer. + // We don't need to transfer the tensor to an MLTensor, so we don't need to copy the buffer. return Status::OK(); } From 9c87bd8418dc7f449a0fb4d4c16d6cf4112265b4 Mon Sep 17 00:00:00 2001 From: Enrico Galli Date: Tue, 10 Sep 2024 11:43:17 -0700 Subject: [PATCH 12/17] Missed buffer reference in comment --- onnxruntime/core/providers/webnn/data_transfer.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/webnn/data_transfer.cc b/onnxruntime/core/providers/webnn/data_transfer.cc index d98407eb3bdfb..44e9bf9edf3d9 100644 --- a/onnxruntime/core/providers/webnn/data_transfer.cc +++ b/onnxruntime/core/providers/webnn/data_transfer.cc @@ -17,7 +17,7 @@ bool DataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_dev common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { if (!emscripten::val::module_property("shouldTransferToMLTensor").as()) { - // We don't need to transfer the tensor to an MLTensor, so we don't need to copy the buffer. + // We don't need to transfer the tensor to an MLTensor, so we don't need to copy the data. return Status::OK(); } From 722802e214d4b35370703ccf708a3cd7eedf391b Mon Sep 17 00:00:00 2001 From: Enrico Galli Date: Tue, 17 Sep 2024 11:50:04 -0700 Subject: [PATCH 13/17] Rename READ_FROM/WRITE_TO to READ/WRITE --- js/web/lib/wasm/jsep/webnn/tensor-manager.ts | 2 +- js/web/lib/wasm/jsep/webnn/webnn.d.ts | 4 ++-- js/web/test/test-runner.ts | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/js/web/lib/wasm/jsep/webnn/tensor-manager.ts b/js/web/lib/wasm/jsep/webnn/tensor-manager.ts index 1f6550fb578b8..c7ce3a17321ca 100644 --- a/js/web/lib/wasm/jsep/webnn/tensor-manager.ts +++ b/js/web/lib/wasm/jsep/webnn/tensor-manager.ts @@ -147,7 +147,7 @@ class TensorTracker { } LOG_DEBUG('verbose', () => `[WebNN] MLContext.createTensor {dataType: ${dataType}, dimensions: ${dimensions}}`); // eslint-disable-next-line no-bitwise - const usage = MLTensorUsage.READ_FROM | MLTensorUsage.WRITE_TO; + const usage = MLTensorUsage.READ | MLTensorUsage.WRITE; const tensor = await this.context.createTensor({ dataType, dimensions, usage }); this.tensorEntry = [tensor, dataType, dimensions]; this.tensorCache.push(this.tensorEntry); diff --git a/js/web/lib/wasm/jsep/webnn/webnn.d.ts b/js/web/lib/wasm/jsep/webnn/webnn.d.ts index d6f93357512e4..de03ac814f107 100644 --- a/js/web/lib/wasm/jsep/webnn/webnn.d.ts +++ b/js/web/lib/wasm/jsep/webnn/webnn.d.ts @@ -393,8 +393,8 @@ type MLTensorUsageFlags = number; declare const MLTensorUsage: { readonly WEBGPU_INTEROP: MLTensorUsageFlags; - readonly READ_FROM: MLTensorUsageFlags; - readonly WRITE_TO: MLTensorUsageFlags; + readonly READ: MLTensorUsageFlags; + readonly WRITE: MLTensorUsageFlags; }; interface MLTensorDescriptor extends MLOperandDescriptor { diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index 29183ebd83657..3d7eba62cd226 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -664,7 +664,7 @@ async function createMLTensorForOutput(mlContext: MLContext, type: ort.Tensor.Ty const mlTensor = await mlContext.createTensor({ dataType, dimensions: dims as number[], - usage: MLTensorUsage.READ_FROM, + usage: MLTensorUsage.READ, }); return ort.Tensor.fromMLTensor(mlTensor, { @@ -686,7 +686,7 @@ async function createMLTensorForInput(mlContext: MLContext, cpuTensor: ort.Tenso const mlTensor = await mlContext.createTensor({ dataType, dimensions: cpuTensor.dims as number[], - usage: MLTensorUsage.WRITE_TO, + usage: MLTensorUsage.WRITE, }); mlContext.writeTensor(mlTensor, cpuTensor.data); return ort.Tensor.fromMLTensor(mlTensor, { From 16c14fbf04e7bb86f2f84e7670c3e9d668dcc1a7 Mon Sep 17 00:00:00 2001 From: Enrico Galli Date: Tue, 17 Sep 2024 14:09:35 -0700 Subject: [PATCH 14/17] PR feedback --- onnxruntime/core/providers/webnn/builders/helper.cc | 2 +- onnxruntime/core/providers/webnn/builders/helper.h | 2 +- onnxruntime/core/providers/webnn/builders/model.cc | 1 - onnxruntime/wasm/pre-jsep.js | 2 +- 4 files changed, 3 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index a0ee0156b3cd2..b90c7d76a6507 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -12,7 +12,7 @@ namespace onnxruntime { namespace webnn { -WebnnDeviceType DeviceTypeFromString(const std::string& device_type) { +WebnnDeviceType DeviceTypeFromString(const std::string_view& device_type) { if (device_type == "gpu") { return WebnnDeviceType::GPU; } diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 854944728061f..ccb661ec11973 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -31,7 +31,7 @@ enum class WebnnDeviceType { NPU, }; -WebnnDeviceType DeviceTypeFromString(const std::string& device_type); +WebnnDeviceType DeviceTypeFromString(const std::string_view& device_type); // Collects all the initializer tensors in the subGraph and its ancestor graphs. InitializedTensorSet CollectAllInitializedTensors(const GraphViewer& graph_viewer); diff --git a/onnxruntime/core/providers/webnn/builders/model.cc b/onnxruntime/core/providers/webnn/builders/model.cc index 349dd6b71381d..fcfdb146bff34 100644 --- a/onnxruntime/core/providers/webnn/builders/model.cc +++ b/onnxruntime/core/providers/webnn/builders/model.cc @@ -28,7 +28,6 @@ Status Model::Predict(const InlinedHashMap& inputs, const InlinedHashMap& outputs) { if (use_dispatch_) { return Dispatch(inputs, outputs); - } else { return Compute(inputs, outputs); } diff --git a/onnxruntime/wasm/pre-jsep.js b/onnxruntime/wasm/pre-jsep.js index f5aa3394213cf..3d3870801e212 100644 --- a/onnxruntime/wasm/pre-jsep.js +++ b/onnxruntime/wasm/pre-jsep.js @@ -202,7 +202,7 @@ Module['jsepInit'] = (name, params) => { Module.jsepUploadExternalBuffer = (dataId, buffer) => { backend['upload'](dataId, buffer); }; - } else if(name === 'webnn') { + } else if (name === 'webnn') { // Functions called from EM_ASM need to be assigned in a way that can be minified. // Functions called via emscripten::val::module_property need to be assigned by name so that the minifier doesn't // change the name. From bbbb0b15f2067541127629caeed4eb19469467fc Mon Sep 17 00:00:00 2001 From: Enrico Galli Date: Tue, 17 Sep 2024 22:55:20 -0700 Subject: [PATCH 15/17] Rename dimensions to shape --- js/web/lib/wasm/jsep/init.ts | 4 +- js/web/lib/wasm/jsep/webnn/tensor-manager.ts | 42 ++++++++++---------- js/web/lib/wasm/jsep/webnn/webnn.d.ts | 4 +- js/web/lib/wasm/wasm-types.ts | 7 ++-- js/web/test/test-runner.ts | 4 +- onnxruntime/wasm/pre-jsep.js | 4 +- 6 files changed, 32 insertions(+), 33 deletions(-) diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index 71a48d7b74f1f..7bce5ff9390e8 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -275,8 +275,8 @@ export const init = async ( // jsepReleaseTensorId, (tensorId: number) => backend.releaseTensorId(tensorId), // jsepEnsureTensor - async (tensorId: number, onnxDataType: number, dimensions: number[], copyOld) => - backend.ensureTensor(tensorId, onnxDataType, dimensions, copyOld), + async (tensorId: number, onnxDataType: number, shape: number[], copyOld) => + backend.ensureTensor(tensorId, onnxDataType, shape, copyOld), // jsepUploadTensor (tensorId: number, data: Uint8Array) => { backend.uploadTensor(tensorId, data); diff --git a/js/web/lib/wasm/jsep/webnn/tensor-manager.ts b/js/web/lib/wasm/jsep/webnn/tensor-manager.ts index c7ce3a17321ca..1a6e12c670b27 100644 --- a/js/web/lib/wasm/jsep/webnn/tensor-manager.ts +++ b/js/web/lib/wasm/jsep/webnn/tensor-manager.ts @@ -29,7 +29,7 @@ export interface TensorManager { ensureTensor( tensorId: TensorId, dataType: MLOperandDataType, - dimensions: readonly number[], + shape: readonly number[], copyOld: boolean, ): Promise; /** @@ -48,7 +48,7 @@ export interface TensorManager { /** * Register an externally created MLTensor with a given MLContext and return a TensorId. */ - registerTensor(mlContext: MLContext, mlTensor: MLTensor, dataType: MLOperandDataType, dimensions: number[]): TensorId; + registerTensor(mlContext: MLContext, mlTensor: MLTensor, dataType: MLOperandDataType, shape: number[]): TensorId; } let tensorGuid = 1; @@ -60,8 +60,8 @@ export type MLTensorEntry = [MLTensor, MLOperandDataType, readonly number[]]; * TensorTracker tracks the MLTensor and pending upload data. * * We need to track the MLTensor and pending upload data because we delay the creation of MLTensor until - * we know the data type and dimensions. This is because future implementations of WebNN will only support creating - * MLTensors with dataTypes and dimensions. + * we know the data type and shape. This is because future implementations of WebNN will only support creating + * MLTensors with dataTypes and shape. */ class TensorTracker { private tensorEntry?: MLTensorEntry; @@ -103,12 +103,12 @@ class TensorTracker { } public trySelectTensor(context: MLContext, tryMLTensor: MLTensor): boolean { - for (const [mlTensor, dataType, dimensions] of this.tensorCache) { + for (const [mlTensor, dataType, shape] of this.tensorCache) { if (tryMLTensor === mlTensor) { if (this.context !== context) { throw new Error('MLTensor cannot be registered with a different MLContext.'); } - this.tensorEntry = [mlTensor, dataType, dimensions]; + this.tensorEntry = [mlTensor, dataType, shape]; return true; } } @@ -117,18 +117,18 @@ class TensorTracker { public async ensureTensor( dataType: MLOperandDataType, - dimensions: readonly number[], + shape: readonly number[], copyOld: boolean, ): Promise { if (this.tensorEntry) { - const [mlTensor, existingDataType, existingDimensions] = this.tensorEntry; - if (existingDataType === dataType && existingDimensions.every((v, i) => v === dimensions[i])) { + const [mlTensor, existingDataType, existingShape] = this.tensorEntry; + if (existingDataType === dataType && existingShape.every((v, i) => v === shape[i])) { return mlTensor; } } - for (const [mlTensor, existingDataType, existingDimensions] of this.tensorCache) { - if (existingDataType === dataType && existingDimensions.every((v, i) => v === dimensions[i])) { + for (const [mlTensor, existingDataType, existingShape] of this.tensorCache) { + if (existingDataType === dataType && existingShape.every((v, i) => v === shape[i])) { if (copyOld && this.tensorEntry) { // WebNN does not support copyTensorToTensor, so we need to read and write the tensors. LOG_DEBUG( @@ -136,20 +136,20 @@ class TensorTracker { () => `[WebNN] Slowdown may occur, having to copy existing tensor {dataType: ${ dataType - }, dimensions: ${dimensions}}`, + }, shape: ${shape}}`, ); const data = await this.context.readTensor(this.tensorEntry[0]); this.context.writeTensor(mlTensor, data); } - this.tensorEntry = [mlTensor, existingDataType, existingDimensions]; + this.tensorEntry = [mlTensor, existingDataType, existingShape]; return mlTensor; } } - LOG_DEBUG('verbose', () => `[WebNN] MLContext.createTensor {dataType: ${dataType}, dimensions: ${dimensions}}`); + LOG_DEBUG('verbose', () => `[WebNN] MLContext.createTensor {dataType: ${dataType}, shape: ${shape}}`); // eslint-disable-next-line no-bitwise const usage = MLTensorUsage.READ | MLTensorUsage.WRITE; - const tensor = await this.context.createTensor({ dataType, dimensions, usage }); - this.tensorEntry = [tensor, dataType, dimensions]; + const tensor = await this.context.createTensor({ dataType, shape, usage }); + this.tensorEntry = [tensor, dataType, shape]; this.tensorCache.push(this.tensorEntry); if (this.activeUpload) { @@ -225,7 +225,7 @@ class TensorManagerImpl implements TensorManager { public async ensureTensor( tensorId: TensorId, dataType: MLOperandDataType, - dimensions: number[], + shape: number[], copyOld: boolean, ): Promise { LOG_DEBUG( @@ -233,7 +233,7 @@ class TensorManagerImpl implements TensorManager { () => `[WebNN] TensorManager.ensureTensor {tensorId: ${tensorId}, dataType: ${ dataType - }, dimensions: ${dimensions}, copyOld: ${copyOld}}`, + }, shape: ${shape}, copyOld: ${copyOld}}`, ); const tensor = this.tensorsById.get(tensorId); if (!tensor) { @@ -244,7 +244,7 @@ class TensorManagerImpl implements TensorManager { this.tensorIdsByContext.set(this.backend.currentContext, new Set()); } this.tensorIdsByContext.get(this.backend.currentContext)?.add(tensorId); - return tensor.ensureTensor(dataType, dimensions, copyOld); + return tensor.ensureTensor(dataType, shape, copyOld); } public upload(tensorId: TensorId, data: Uint8Array): void { @@ -277,7 +277,7 @@ class TensorManagerImpl implements TensorManager { mlContext: MLContext, mlTensor: MLTensor, dataType: MLOperandDataType, - dimensions: readonly number[], + shape: readonly number[], ): TensorId { for (const [tensorId, tensorTracker] of this.tensorsById) { if (tensorTracker.trySelectTensor(mlContext, mlTensor)) { @@ -285,7 +285,7 @@ class TensorManagerImpl implements TensorManager { } } const tensorId = createNewTensorId(); - this.tensorsById.set(tensorId, new TensorTracker(mlContext, [mlTensor, dataType, dimensions])); + this.tensorsById.set(tensorId, new TensorTracker(mlContext, [mlTensor, dataType, shape])); let tensors = this.tensorIdsByContext.get(mlContext); if (!tensors) { tensors = new Set(); diff --git a/js/web/lib/wasm/jsep/webnn/webnn.d.ts b/js/web/lib/wasm/jsep/webnn/webnn.d.ts index de03ac814f107..7075a88abedcf 100644 --- a/js/web/lib/wasm/jsep/webnn/webnn.d.ts +++ b/js/web/lib/wasm/jsep/webnn/webnn.d.ts @@ -32,7 +32,7 @@ type MLInputOperandLayout = 'nchw'|'nhwc'; type MLOperandDataType = 'float32'|'float16'|'int32'|'uint32'|'int64'|'uint64'|'int8'|'uint8'; interface MLOperandDescriptor { dataType: MLOperandDataType; - dimensions?: readonly number[]; + shape?: readonly number[]; } interface MLOperand { dataType(): MLOperandDataType; @@ -405,7 +405,7 @@ interface MLContext { createTensor(descriptor: MLTensorDescriptor): Promise; writeTensor( destinationTensor: MLTensor, sourceData: ArrayBufferView|ArrayBuffer, sourceElementOffset?: number, - srcElementSize?: number): void; + sourceElementSize?: number): void; readTensor(sourceTensor: MLTensor): Promise; readTensor(sourceTensor: MLTensor, destinationData: ArrayBufferView|ArrayBuffer): Promise; dispatch(graph: MLGraph, inputs: MLNamedTensor, outputs: MLNamedTensor): void; diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts index 977a7d683b9a0..132738b803875 100644 --- a/js/web/lib/wasm/wasm-types.ts +++ b/js/web/lib/wasm/wasm-types.ts @@ -33,7 +33,7 @@ export declare namespace JSEP { type EnsureTensorFunction = ( tensorId: number, dataType: DataType, - dimensions: readonly number[], + shape: readonly number[], copyOld: boolean, ) => Promise; type UploadTensorFunction = (tensorId: number, data: Uint8Array) => void; @@ -183,21 +183,20 @@ export declare namespace JSEP { * [exported from pre-jsep.js] Ensure that an MLTensor of a given type and shape exists for a MLTensor ID. * @param tensorId - specify the MLTensor ID. * @param onnxDataType - specify the data type. - * @param dimensions - specify the dimensions. + * @param shape - specify the dimensions (WebNN shape) of the tensor. * @param copyOld - specify whether to copy the old tensor if a new tensor was created. * @returns the MLTensor associated with the tensor ID. */ jsepEnsureTensor: ( tensorId: number, dataType: DataType, - dimensions: number[], + shape: number[], copyOld: boolean, ) => Promise; /** * [exported from pre-jsep.js] Upload data to an MLTensor. * @param tensorId - specify the MLTensor ID. * @param data - specify the data to upload. It can be a TensorProto::data_type or a WebNN MLOperandDataType. - * @param dimensions - specify the dimensions. * @returns */ jsepUploadTensor: (tensorId: number, data: Uint8Array) => void; diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index 3d7eba62cd226..d7c4bb90e239c 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -663,7 +663,7 @@ async function createMLTensorForOutput(mlContext: MLContext, type: ort.Tensor.Ty const mlTensor = await mlContext.createTensor({ dataType, - dimensions: dims as number[], + shape: dims as number[], usage: MLTensorUsage.READ, }); @@ -685,7 +685,7 @@ async function createMLTensorForInput(mlContext: MLContext, cpuTensor: ort.Tenso const dataType = cpuTensor.type === 'bool' ? 'uint8' : cpuTensor.type; const mlTensor = await mlContext.createTensor({ dataType, - dimensions: cpuTensor.dims as number[], + shape: cpuTensor.dims as number[], usage: MLTensorUsage.WRITE, }); mlContext.writeTensor(mlTensor, cpuTensor.data); diff --git a/onnxruntime/wasm/pre-jsep.js b/onnxruntime/wasm/pre-jsep.js index 3d3870801e212..68332d07a9782 100644 --- a/onnxruntime/wasm/pre-jsep.js +++ b/onnxruntime/wasm/pre-jsep.js @@ -232,8 +232,8 @@ Module['jsepInit'] = (name, params) => { Module['jsepCreateMLTensorDownloader'] = (tensorId, type) => { return backend['createMLTensorDownloader'](tensorId, type); } - Module['jsepRegisterMLTensor'] = (tensor, dataType, dimensions) => { - return backend['registerMLTensor'](tensor, dataType, dimensions); + Module['jsepRegisterMLTensor'] = (tensor, dataType, shape) => { + return backend['registerMLTensor'](tensor, dataType, shape); } } }; From 5b3ceec86e382ccb671a820c5e6bd3c29ae6d338 Mon Sep 17 00:00:00 2001 From: Enrico Galli Date: Tue, 17 Sep 2024 22:58:03 -0700 Subject: [PATCH 16/17] Format --- js/web/lib/wasm/jsep/webnn/tensor-manager.ts | 5 +---- js/web/lib/wasm/wasm-types.ts | 7 +------ 2 files changed, 2 insertions(+), 10 deletions(-) diff --git a/js/web/lib/wasm/jsep/webnn/tensor-manager.ts b/js/web/lib/wasm/jsep/webnn/tensor-manager.ts index 1a6e12c670b27..e1c089e2d0fd3 100644 --- a/js/web/lib/wasm/jsep/webnn/tensor-manager.ts +++ b/js/web/lib/wasm/jsep/webnn/tensor-manager.ts @@ -133,10 +133,7 @@ class TensorTracker { // WebNN does not support copyTensorToTensor, so we need to read and write the tensors. LOG_DEBUG( 'verbose', - () => - `[WebNN] Slowdown may occur, having to copy existing tensor {dataType: ${ - dataType - }, shape: ${shape}}`, + () => `[WebNN] Slowdown may occur, having to copy existing tensor {dataType: ${dataType}, shape: ${shape}}`, ); const data = await this.context.readTensor(this.tensorEntry[0]); this.context.writeTensor(mlTensor, data); diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts index 132738b803875..3e08fe97f559d 100644 --- a/js/web/lib/wasm/wasm-types.ts +++ b/js/web/lib/wasm/wasm-types.ts @@ -187,12 +187,7 @@ export declare namespace JSEP { * @param copyOld - specify whether to copy the old tensor if a new tensor was created. * @returns the MLTensor associated with the tensor ID. */ - jsepEnsureTensor: ( - tensorId: number, - dataType: DataType, - shape: number[], - copyOld: boolean, - ) => Promise; + jsepEnsureTensor: (tensorId: number, dataType: DataType, shape: number[], copyOld: boolean) => Promise; /** * [exported from pre-jsep.js] Upload data to an MLTensor. * @param tensorId - specify the MLTensor ID. From 7bc892df7605a5b590d80672d6582754f9b620e6 Mon Sep 17 00:00:00 2001 From: Enrico Galli Date: Tue, 17 Sep 2024 23:38:51 -0700 Subject: [PATCH 17/17] Assign to both shape and dimensions when creating MLTensors --- js/web/lib/wasm/jsep/webnn/tensor-manager.ts | 8 +++++++- js/web/lib/wasm/jsep/webnn/webnn.d.ts | 2 ++ js/web/test/test-runner.ts | 4 ++++ 3 files changed, 13 insertions(+), 1 deletion(-) diff --git a/js/web/lib/wasm/jsep/webnn/tensor-manager.ts b/js/web/lib/wasm/jsep/webnn/tensor-manager.ts index e1c089e2d0fd3..9475de019ed1d 100644 --- a/js/web/lib/wasm/jsep/webnn/tensor-manager.ts +++ b/js/web/lib/wasm/jsep/webnn/tensor-manager.ts @@ -145,7 +145,13 @@ class TensorTracker { LOG_DEBUG('verbose', () => `[WebNN] MLContext.createTensor {dataType: ${dataType}, shape: ${shape}}`); // eslint-disable-next-line no-bitwise const usage = MLTensorUsage.READ | MLTensorUsage.WRITE; - const tensor = await this.context.createTensor({ dataType, shape, usage }); + const tensor = await this.context.createTensor({ + dataType, + shape, + // Assign both shape and dimensions while transitioning to new API. + dimensions: shape, + usage, + }); this.tensorEntry = [tensor, dataType, shape]; this.tensorCache.push(this.tensorEntry); diff --git a/js/web/lib/wasm/jsep/webnn/webnn.d.ts b/js/web/lib/wasm/jsep/webnn/webnn.d.ts index 7075a88abedcf..5cb0f4e74c3df 100644 --- a/js/web/lib/wasm/jsep/webnn/webnn.d.ts +++ b/js/web/lib/wasm/jsep/webnn/webnn.d.ts @@ -33,6 +33,8 @@ type MLOperandDataType = 'float32'|'float16'|'int32'|'uint32'|'int64'|'uint64'|' interface MLOperandDescriptor { dataType: MLOperandDataType; shape?: readonly number[]; + /** @deprecated Use shape instead of dimensions */ + dimensions?: readonly number[]; } interface MLOperand { dataType(): MLOperandDataType; diff --git a/js/web/test/test-runner.ts b/js/web/test/test-runner.ts index d7c4bb90e239c..2176a776a0192 100644 --- a/js/web/test/test-runner.ts +++ b/js/web/test/test-runner.ts @@ -664,6 +664,8 @@ async function createMLTensorForOutput(mlContext: MLContext, type: ort.Tensor.Ty const mlTensor = await mlContext.createTensor({ dataType, shape: dims as number[], + // Assign both shape and dimensions while transitioning to new API. + dimensions: dims as number[], usage: MLTensorUsage.READ, }); @@ -686,6 +688,8 @@ async function createMLTensorForInput(mlContext: MLContext, cpuTensor: ort.Tenso const mlTensor = await mlContext.createTensor({ dataType, shape: cpuTensor.dims as number[], + // Assign both shape and dimensions while transitioning to new API. + dimensions: cpuTensor.dims as number[], usage: MLTensorUsage.WRITE, }); mlContext.writeTensor(mlTensor, cpuTensor.data);