From 18224f35af7225e5d75a4df76fa3c661adde5b68 Mon Sep 17 00:00:00 2001 From: Enrico Galli Date: Wed, 4 Dec 2024 16:23:57 -0800 Subject: [PATCH 1/4] [WebNN EP] Automatically move input CPU tensors to ml-tensor ### Description If it would improve performance, this patch moves CPU input tensors to ml-tensor before sending them to the ONNXRuntime WebNN EP. ### Motivation and Context We are currently performing 2 extra copies on input tensors located in the CPU when using the WebNN EP (JS -(copy)-> wasm heap -(copy)-> JS -> WebNN API). This patch removes these extra copies. --- js/web/lib/wasm/jsep/backend-webnn.ts | 62 +++++++++++++++++++ js/web/lib/wasm/wasm-core-impl.ts | 45 +++++++++++--- js/web/lib/wasm/wasm-types.ts | 25 ++++++++ .../providers/webnn/builders/model_builder.cc | 1 + onnxruntime/wasm/pre-jsep.js | 6 ++ 5 files changed, 130 insertions(+), 9 deletions(-) diff --git a/js/web/lib/wasm/jsep/backend-webnn.ts b/js/web/lib/wasm/jsep/backend-webnn.ts index b302354c46eeb..f9365251d1905 100644 --- a/js/web/lib/wasm/jsep/backend-webnn.ts +++ b/js/web/lib/wasm/jsep/backend-webnn.ts @@ -75,6 +75,19 @@ export class WebNNBackend { * Current session id. */ private activeSessionId?: number; + /** + * Maps from session id to list of graph inputs. + */ + private sessionGraphInputs: Map = new Map(); + /** + * Temporary graph inputs for the current session. + * These inputs will be registered when the session is created. + */ + private temporaryGraphInputs: string[] = []; + /** + * Temporary tensors for the current session.
+ */ + private temporarySessionTensors: Map = new Map(); constructor(env: Env) { configureLogger(env.logLevel!, !!env.debug); @@ -91,6 +104,19 @@ export class WebNNBackend { this.activeSessionId = sessionId; } + public onRunEnd(sessionId: number): void { + LOG_DEBUG('verbose', () => `[WebNN] onRunEnd {sessionId: ${sessionId}}`); + const tensors = this.temporarySessionTensors.get(sessionId); + if (!tensors) { + return; + } + for (const tensor of tensors) { + LOG_DEBUG('verbose', () => `[WebNN] releasing temporary tensor {tensorId: ${tensor}}`); + this.tensorManager.releaseTensorId(tensor); + } + this.temporarySessionTensors.delete(sessionId); + } + public async createMLContext(optionsOrDevice?: MLContextOptions | GPUDevice): Promise { if (optionsOrDevice instanceof GPUDevice) { const mlContextIndex = this.mlContextCache.findIndex((entry) => entry.gpuDevice === optionsOrDevice); @@ -142,9 +168,15 @@ export class WebNNBackend { this.sessionIdsByMLContext.set(mlContext, sessionIds); } sessionIds.add(sessionId); + + if (this.temporaryGraphInputs.length > 0) { + this.sessionGraphInputs.set(sessionId, this.temporaryGraphInputs); + this.temporaryGraphInputs = []; + } } public onReleaseSession(sessionId: number): void { + this.sessionGraphInputs.delete(sessionId); const mlContext = this.mlContextBySessionId.get(sessionId)!; if (!mlContext) { // Current session is not a WebNN session. 
@@ -189,6 +221,23 @@ export class WebNNBackend { return this.tensorManager.ensureTensor(tensorId, webnnDataType, dimensions, copyOld); } + public async createTemporaryTensor(onnxDataType: DataType, shape: readonly number[]): Promise { + LOG_DEBUG('verbose', () => `[WebNN] createTemporaryTensor {onnxDataType: ${onnxDataType}, shape: ${shape}}`); + const dataType = onnxDataTypeToWebnnDataType.get(onnxDataType); + if (!dataType) { + throw new Error(`Unsupported ONNX data type: ${onnxDataType}`); + } + const tensorId = this.tensorManager.reserveTensorId(); + await this.tensorManager.ensureTensor(tensorId, dataType, shape, false); + const tensors = this.temporarySessionTensors.get(this.currentSessionId); + if (!tensors) { + this.temporarySessionTensors.set(this.currentSessionId, [tensorId]); + } else { + tensors.push(tensorId); + } + return tensorId; + } + public uploadTensor(tensorId: TensorId, data: Uint8Array): void { const wasm = getInstance(); if (!wasm.shouldTransferToMLTensor) { @@ -291,6 +340,19 @@ export class WebNNBackend { return builder.constant(desc, bufferView); } + public registerGraphInput(inputName: string): void { + this.temporaryGraphInputs.push(inputName); + } + + public isGraphInput(inputName: string): boolean { + const sessionId = this.currentSessionId; + const inputNames = this.sessionGraphInputs.get(sessionId); + if (!inputNames) { + return false; + } + return inputNames.includes(inputName); + } + public flush(): void { // Unlike the WebGPU backend, the WebNN backend does not need to flush any pending operations. 
} diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index da8939cd0263a..ff259b707aa65 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -453,14 +453,14 @@ export const releaseSession = (sessionId: number): void => { activeSessions.delete(sessionId); }; -export const prepareInputOutputTensor = ( +export const prepareInputOutputTensor = async ( tensor: TensorMetadata | null, tensorHandles: number[], allocs: number[], sessionId: number, index: number, enableGraphCapture = false, -): void => { +): Promise => { if (!tensor) { tensorHandles.push(0); return; @@ -472,6 +472,7 @@ export const prepareInputOutputTensor = ( const dataType = tensor[0]; const dims = tensor[1]; const location = tensor[3]; + let actualLocation = location; let rawData: number; let dataByteLength: number; @@ -519,10 +520,35 @@ export const prepareInputOutputTensor = ( wasm.setValue(rawData + i * ptrSize, allocWasmString(data[i], allocs), '*'); } } else { - dataByteLength = data.byteLength; - rawData = wasm._malloc(dataByteLength); - allocs.push(rawData); - wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData); + const isGraphInput = wasm.jsepIsGraphInput; + if (dataType !== 'string' && isGraphInput) { + const tensorNameUTF8 = wasm._OrtGetInputName(sessionId, index); + const tensorName = wasm.UTF8ToString(tensorNameUTF8); + // Promote the tensor to 'ml-tensor' if it is a graph input. 
+ if (isGraphInput(tensorName)) { + const dataTypeEnum = tensorDataTypeStringToEnum(dataType); + dataByteLength = calculateTensorSizeInBytes(dataTypeEnum, dims)!; + actualLocation = 'ml-tensor'; + const createTemporaryTensor = wasm.jsepCreateTemporaryTensor; + const uploadTensor = wasm.jsepUploadTensor; + if (!createTemporaryTensor || !uploadTensor) { + throw new Error('Tensor location "ml-tensor" is not supported without using WebNN.'); + } + const tensorId = await createTemporaryTensor(dataTypeEnum, dims as number[]); + uploadTensor(tensorId, new Uint8Array(data.buffer, data.byteOffset, data.byteLength)); + rawData = tensorId; + } else { + dataByteLength = data.byteLength; + rawData = wasm._malloc(dataByteLength); + allocs.push(rawData); + wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData); + } + } else { + dataByteLength = data.byteLength; + rawData = wasm._malloc(dataByteLength); + allocs.push(rawData); + wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData); + } } } @@ -536,7 +562,7 @@ export const prepareInputOutputTensor = ( dataByteLength, dimsOffset, dims.length, - dataLocationStringToEnum(location), + dataLocationStringToEnum(actualLocation), ); if (tensor === 0) { checkLastError(`Can't create tensor for input/output. 
session=${sessionId}, index=${index}.`); @@ -595,7 +621,7 @@ export const run = async ( // create input tensors for (let i = 0; i < inputCount; i++) { - prepareInputOutputTensor( + await prepareInputOutputTensor( inputTensors[i], inputTensorHandles, inputOutputAllocs, @@ -607,7 +633,7 @@ export const run = async ( // create output tensors for (let i = 0; i < outputCount; i++) { - prepareInputOutputTensor( + await prepareInputOutputTensor( outputTensors[i], outputTensorHandles, inputOutputAllocs, @@ -841,6 +867,7 @@ export const run = async ( if (!keepOutputTensor) { wasm._OrtReleaseTensor(tensor); } + wasm.jsepOnRunEnd?.(sessionHandle); } } diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts index ebeac5dc9e587..42af5cbd3c91d 100644 --- a/js/web/lib/wasm/wasm-types.ts +++ b/js/web/lib/wasm/wasm-types.ts @@ -141,6 +141,12 @@ export declare namespace JSEP { * @param sessionId - specify the session ID. */ jsepOnRunStart: (sessionId: number) => void; + /** + * [exported from pre-jsep.js] Called when InferenceSession.run finished. This function will be called after + * _OrtRun[WithBinding]() is called. + * @param sessionId - specify the session ID. + */ + jsepOnRunEnd: (sessionId: number) => void; /** * [exported from pre-jsep.js] Create a session. This function will be called after _OrtCreateSession() is * called. @@ -249,6 +255,25 @@ export declare namespace JSEP { builder: MLGraphBuilder, desc: MLOperandDescriptor, ): MLOperand; + + /** + * [exported from pre-jsep.js] Register a WebNN graph input. + * @param inputName - specify the input name. + */ + jsepRegisterGraphInput(inputName: string): void; + /** + * [exported from pre-jsep.js] Check if a graph input is a WebNN graph input. + * @param inputName - specify the input name. + * @returns whether the input is a WebNN graph input. + */ + jsepIsGraphInput(inputName: string): boolean; + /** + * [exported from pre-jsep.js] Create a temporary MLTensor for a session. 
+ * @param dataType - specify the data type. + * @param shape - specify the shape. + * @returns the MLTensor ID for the temporary MLTensor. + */ + jsepCreateTemporaryTensor: (dataType: DataType, shape: readonly number[]) => Promise; } } diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index e8f116d390199..4b7cab684ae81 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -252,6 +252,7 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i if (is_input) { wnn_operands_.insert(std::make_pair(name, wnn_builder_.call("input", name, desc))); + emscripten::val::module_property("jsepRegisterGraphInput")(name); input_names_.push_back(name); } else { output_names_.push_back(name); diff --git a/onnxruntime/wasm/pre-jsep.js b/onnxruntime/wasm/pre-jsep.js index 45e2475548df5..b3270fdb02f2a 100644 --- a/onnxruntime/wasm/pre-jsep.js +++ b/onnxruntime/wasm/pre-jsep.js @@ -220,12 +220,14 @@ Module['jsepInit'] = (name, params) => { // This function is called from both JS and an EM_ASM block, it needs both a minifiable name and an explicit name. Module['jsepReleaseTensorId'] = Module.jsepReleaseTensorId; + Module['jsepUploadTensor'] = Module.jsepUploadTensor; // Functions called from JS also need to have explicit names. 
const backend = Module.jsepBackend; Module['jsepOnRunStart'] = sessionId => { return backend['onRunStart'](sessionId); }; + Module['jsepOnRunEnd'] = backend['onRunEnd'].bind(backend); Module['jsepRegisterMLContext'] = (sessionId, mlContext) => { backend['registerMLContext'](sessionId, mlContext); }; @@ -245,5 +247,9 @@ Module['jsepInit'] = (name, params) => { return backend['registerMLConstant']( externalFilePath, dataOffset, dataLength, builder, desc, Module.MountedFiles); }; + Module['jsepRegisterGraphInput'] = backend['registerGraphInput'].bind(backend); + Module['jsepIsGraphInput'] = backend['isGraphInput'].bind(backend); + + Module['jsepCreateTemporaryTensor'] = backend['createTemporaryTensor'].bind(backend); } }; From be01b60c37096d301d37862487ac9101e9f5cd26 Mon Sep 17 00:00:00 2001 From: Enrico Galli Date: Wed, 11 Dec 2024 11:16:46 -0800 Subject: [PATCH 2/4] PR feedback --- js/web/lib/wasm/jsep/backend-webnn.ts | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/js/web/lib/wasm/jsep/backend-webnn.ts b/js/web/lib/wasm/jsep/backend-webnn.ts index f9365251d1905..f3e961f80cdb5 100644 --- a/js/web/lib/wasm/jsep/backend-webnn.ts +++ b/js/web/lib/wasm/jsep/backend-webnn.ts @@ -87,7 +87,7 @@ export class WebNNBackend { /** * Temporary tensors for the current session. 
*/ - private temporarySessionTensors: Map = new Map(); + private temporarySessionTensorIds: Map = new Map(); constructor(env: Env) { configureLogger(env.logLevel!, !!env.debug); @@ -106,15 +106,15 @@ export class WebNNBackend { public onRunEnd(sessionId: number): void { LOG_DEBUG('verbose', () => `[WebNN] onRunEnd {sessionId: ${sessionId}}`); - const tensors = this.temporarySessionTensors.get(sessionId); - if (!tensors) { + const tensorIds = this.temporarySessionTensorIds.get(sessionId); + if (!tensorIds) { return; } - for (const tensor of tensors) { - LOG_DEBUG('verbose', () => `[WebNN] releasing temporary tensor {tensorId: ${tensor}}`); - this.tensorManager.releaseTensorId(tensor); + for (const tensorId of tensorIds) { + LOG_DEBUG('verbose', () => `[WebNN] releasing temporary tensor {tensorId: ${tensorId}}`); + this.tensorManager.releaseTensorId(tensorId); } - this.temporarySessionTensors.delete(sessionId); + this.temporarySessionTensorIds.delete(sessionId); } public async createMLContext(optionsOrDevice?: MLContextOptions | GPUDevice): Promise { @@ -229,9 +229,9 @@ export class WebNNBackend { } const tensorId = this.tensorManager.reserveTensorId(); await this.tensorManager.ensureTensor(tensorId, dataType, shape, false); - const tensors = this.temporarySessionTensors.get(this.currentSessionId); + const tensors = this.temporarySessionTensorIds.get(this.currentSessionId); if (!tensors) { - this.temporarySessionTensors.set(this.currentSessionId, [tensorId]); + this.temporarySessionTensorIds.set(this.currentSessionId, [tensorId]); } else { tensors.push(tensorId); } From 5e3295fdef70e4613b6a9c58b473343ef98f7da8 Mon Sep 17 00:00:00 2001 From: Enrico Galli Date: Wed, 11 Dec 2024 18:27:01 -0800 Subject: [PATCH 3/4] More renames from tensor(s) to tensorId(s) --- js/web/lib/wasm/jsep/backend-webnn.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/js/web/lib/wasm/jsep/backend-webnn.ts b/js/web/lib/wasm/jsep/backend-webnn.ts index 
f3e961f80cdb5..23e722f1cc7d7 100644 --- a/js/web/lib/wasm/jsep/backend-webnn.ts +++ b/js/web/lib/wasm/jsep/backend-webnn.ts @@ -229,11 +229,11 @@ export class WebNNBackend { } const tensorId = this.tensorManager.reserveTensorId(); await this.tensorManager.ensureTensor(tensorId, dataType, shape, false); - const tensors = this.temporarySessionTensorIds.get(this.currentSessionId); - if (!tensors) { + const tensorIds = this.temporarySessionTensorIds.get(this.currentSessionId); + if (!tensorIds) { this.temporarySessionTensorIds.set(this.currentSessionId, [tensorId]); } else { - tensors.push(tensorId); + tensorIds.push(tensorId); } return tensorId; } From d66258fd5d0b047959358f39554bb34ab835d19e Mon Sep 17 00:00:00 2001 From: Enrico Galli Date: Fri, 13 Dec 2024 21:46:58 -0800 Subject: [PATCH 4/4] Pass sessionHandle/Id directly to function instead of using activeSession --- js/web/lib/wasm/jsep/backend-webnn.ts | 38 ++++++++++-------- js/web/lib/wasm/jsep/init.ts | 4 +- js/web/lib/wasm/jsep/webnn/tensor-manager.ts | 30 +++++++++----- js/web/lib/wasm/wasm-core-impl.ts | 13 +++---- js/web/lib/wasm/wasm-types.ts | 39 +++++++++++++------ .../core/providers/webnn/builders/model.cc | 4 +- 6 files changed, 80 insertions(+), 48 deletions(-) diff --git a/js/web/lib/wasm/jsep/backend-webnn.ts b/js/web/lib/wasm/jsep/backend-webnn.ts index 23e722f1cc7d7..2b9a9208e2e53 100644 --- a/js/web/lib/wasm/jsep/backend-webnn.ts +++ b/js/web/lib/wasm/jsep/backend-webnn.ts @@ -101,6 +101,7 @@ export class WebNNBackend { } public onRunStart(sessionId: number): void { + LOG_DEBUG('verbose', () => `[WebNN] onRunStart {sessionId: ${sessionId}}`); this.activeSessionId = sessionId; } @@ -115,6 +116,7 @@ export class WebNNBackend { this.tensorManager.releaseTensorId(tensorId); } this.temporarySessionTensorIds.delete(sessionId); + this.activeSessionId = undefined; } public async createMLContext(optionsOrDevice?: MLContextOptions | GPUDevice): Promise { @@ -152,14 +154,6 @@ export class WebNNBackend { } 
} - public get currentContext(): MLContext { - const mlContext = this.getMLContext(this.currentSessionId); - if (!mlContext) { - throw new Error(`No MLContext found for session ${this.currentSessionId}`); - } - return mlContext; - } - public registerMLContext(sessionId: number, mlContext: MLContext): void { this.mlContextBySessionId.set(sessionId, mlContext); let sessionIds = this.sessionIdsByMLContext.get(mlContext); @@ -209,6 +203,7 @@ export class WebNNBackend { } public async ensureTensor( + sessionId: number | undefined, tensorId: TensorId, onnxDataType: DataType, dimensions: number[], @@ -218,20 +213,30 @@ export class WebNNBackend { if (!webnnDataType) { throw new Error(`Unsupported ONNX data type: ${onnxDataType}`); } - return this.tensorManager.ensureTensor(tensorId, webnnDataType, dimensions, copyOld); + return this.tensorManager.ensureTensor( + sessionId ?? this.currentSessionId, + tensorId, + webnnDataType, + dimensions, + copyOld, + ); } - public async createTemporaryTensor(onnxDataType: DataType, shape: readonly number[]): Promise { + public async createTemporaryTensor( + sessionId: number, + onnxDataType: DataType, + shape: readonly number[], + ): Promise { LOG_DEBUG('verbose', () => `[WebNN] createTemporaryTensor {onnxDataType: ${onnxDataType}, shape: ${shape}}`); const dataType = onnxDataTypeToWebnnDataType.get(onnxDataType); if (!dataType) { throw new Error(`Unsupported ONNX data type: ${onnxDataType}`); } const tensorId = this.tensorManager.reserveTensorId(); - await this.tensorManager.ensureTensor(tensorId, dataType, shape, false); - const tensorIds = this.temporarySessionTensorIds.get(this.currentSessionId); + await this.tensorManager.ensureTensor(sessionId, tensorId, dataType, shape, false); + const tensorIds = this.temporarySessionTensorIds.get(sessionId); if (!tensorIds) { - this.temporarySessionTensorIds.set(this.currentSessionId, [tensorId]); + this.temporarySessionTensorIds.set(sessionId, [tensorId]); } else { tensorIds.push(tensorId); } 
@@ -258,13 +263,13 @@ export class WebNNBackend { }; } - public registerMLTensor(tensor: MLTensor, onnxDataType: DataType, dimensions: number[]): TensorId { + public registerMLTensor(sessionId: number, tensor: MLTensor, onnxDataType: DataType, dimensions: number[]): TensorId { const webnnDataType = onnxDataTypeToWebnnDataType.get(onnxDataType); if (!webnnDataType) { throw new Error(`Unsupported ONNX data type: ${onnxDataType}`); } - const id = this.tensorManager.registerTensor(this.currentContext, tensor, webnnDataType, dimensions); + const id = this.tensorManager.registerTensor(sessionId, tensor, webnnDataType, dimensions); LOG_DEBUG( 'verbose', () => @@ -344,8 +349,7 @@ export class WebNNBackend { this.temporaryGraphInputs.push(inputName); } - public isGraphInput(inputName: string): boolean { - const sessionId = this.currentSessionId; + public isGraphInput(sessionId: number, inputName: string): boolean { const inputNames = this.sessionGraphInputs.get(sessionId); if (!inputNames) { return false; diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index 48bd3ef2bc36f..b4071eae51c8f 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -287,8 +287,8 @@ export const init = async ( // jsepReleaseTensorId, (tensorId: number) => backend.releaseTensorId(tensorId), // jsepEnsureTensor - async (tensorId: number, onnxDataType: number, shape: number[], copyOld) => - backend.ensureTensor(tensorId, onnxDataType, shape, copyOld), + async (sessionId: number | undefined, tensorId: number, onnxDataType: number, shape: number[], copyOld) => + backend.ensureTensor(sessionId, tensorId, onnxDataType, shape, copyOld), // jsepUploadTensor (tensorId: number, data: Uint8Array) => { backend.uploadTensor(tensorId, data); diff --git a/js/web/lib/wasm/jsep/webnn/tensor-manager.ts b/js/web/lib/wasm/jsep/webnn/tensor-manager.ts index 4932691bda65b..3bf8a5c334b58 100644 --- a/js/web/lib/wasm/jsep/webnn/tensor-manager.ts +++ 
b/js/web/lib/wasm/jsep/webnn/tensor-manager.ts @@ -27,6 +27,7 @@ export interface TensorManager { * Ensure a MLTensor is created for the TensorId. */ ensureTensor( + sessionId: number, tensorId: TensorId, dataType: MLOperandDataType, shape: readonly number[], @@ -46,9 +47,9 @@ export interface TensorManager { */ releaseTensorsForSession(session: number): void; /** - * Register an externally created MLTensor with a given MLContext and return a TensorId. + * Register an externally created MLTensor with a given session id and return a TensorId. */ - registerTensor(mlContext: MLContext, mlTensor: MLTensor, dataType: MLOperandDataType, shape: number[]): TensorId; + registerTensor(sessionId: number, mlTensor: MLTensor, dataType: MLOperandDataType, shape: number[]): TensorId; } let tensorGuid = 1; @@ -176,6 +177,7 @@ class TensorIdTracker { } public async ensureTensor( + sessionId: number, dataType: MLOperandDataType, shape: readonly number[], copyOld: boolean, @@ -196,7 +198,7 @@ class TensorIdTracker { // eslint-disable-next-line no-bitwise const usage = typeof MLTensorUsage == 'undefined' ? 
undefined : MLTensorUsage.READ | MLTensorUsage.WRITE; - this.wrapper = await this.tensorManager.getCachedTensor(dataType, shape, usage, true, true); + this.wrapper = await this.tensorManager.getCachedTensor(sessionId, dataType, shape, usage, true, true); if (copyOld && this.activeUpload) { this.wrapper.write(this.activeUpload); @@ -254,6 +256,14 @@ class TensorManagerImpl implements TensorManager { constructor(private backend: WebNNBackend) {} + public getMLContext(sessionId: number): MLContext { + const context = this.backend.getMLContext(sessionId); + if (!context) { + throw new Error('MLContext not found for session.'); + } + return context; + } + public reserveTensorId(): TensorId { const tensorId = createNewTensorId(); this.tensorTrackersById.set(tensorId, new TensorIdTracker(this)); @@ -272,6 +282,7 @@ class TensorManagerImpl implements TensorManager { } public async ensureTensor( + sessionId: number, tensorId: TensorId, dataType: MLOperandDataType, shape: number[], @@ -288,7 +299,7 @@ class TensorManagerImpl implements TensorManager { if (!tensor) { throw new Error('Tensor not found.'); } - return tensor.ensureTensor(dataType, shape, copyOld); + return tensor.ensureTensor(sessionId, dataType, shape, copyOld); } public upload(tensorId: TensorId, data: Uint8Array): void { @@ -323,17 +334,18 @@ class TensorManagerImpl implements TensorManager { } public registerTensor( - mlContext: MLContext, + sessionId: number, mlTensor: MLTensor, dataType: MLOperandDataType, shape: readonly number[], ): TensorId { + const context = this.getMLContext(sessionId); const tensorId = createNewTensorId(); // Defaulting to READ | WRITE if usage is not provided. 
// eslint-disable-next-line no-bitwise const wrapper = new TensorWrapper({ - sessionId: this.backend.currentSessionId, - context: mlContext, + sessionId, + context, tensor: mlTensor, dataType, shape, @@ -347,13 +359,13 @@ class TensorManagerImpl implements TensorManager { * Get or create an MLTensor with the given data type and shape. */ public async getCachedTensor( + sessionId: number, dataType: MLOperandDataType, shape: readonly number[], usage: MLTensorUsageFlags | undefined, writable: boolean, readable: boolean, ): Promise { - const sessionId = this.backend.currentSessionId; for (const [index, tensor] of this.freeTensors.entries()) { if (tensor.sameTypeAndShape(dataType, shape)) { LOG_DEBUG('verbose', () => `[WebNN] Reusing tensor {dataType: ${dataType}, shape: ${shape}}`); @@ -362,7 +374,7 @@ class TensorManagerImpl implements TensorManager { return wrapper; } } - const context = this.backend.currentContext; + const context = this.getMLContext(sessionId); LOG_DEBUG('verbose', () => `[WebNN] MLContext.createTensor {dataType: ${dataType}, shape: ${shape}}`); const tensor = await context.createTensor({ dataType, diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index ff259b707aa65..4bccfa76fdda3 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -504,7 +504,7 @@ export const prepareInputOutputTensor = async ( if (!registerMLTensor) { throw new Error('Tensor location "ml-tensor" is not supported without using WebNN.'); } - rawData = registerMLTensor(mlTensor, tensorDataTypeStringToEnum(dataType), dims); + rawData = registerMLTensor(sessionId, mlTensor, tensorDataTypeStringToEnum(dataType), dims); } else { const data = tensor[2]; @@ -525,7 +525,7 @@ export const prepareInputOutputTensor = async ( const tensorNameUTF8 = wasm._OrtGetInputName(sessionId, index); const tensorName = wasm.UTF8ToString(tensorNameUTF8); // Promote the tensor to 'ml-tensor' if it is a graph input. 
- if (isGraphInput(tensorName)) { + if (isGraphInput(sessionId, tensorName)) { const dataTypeEnum = tensorDataTypeStringToEnum(dataType); dataByteLength = calculateTensorSizeInBytes(dataTypeEnum, dims)!; actualLocation = 'ml-tensor'; @@ -534,7 +534,7 @@ export const prepareInputOutputTensor = async ( if (!createTemporaryTensor || !uploadTensor) { throw new Error('Tensor location "ml-tensor" is not supported without using WebNN.'); } - const tensorId = await createTemporaryTensor(dataTypeEnum, dims as number[]); + const tensorId = await createTemporaryTensor(sessionId, dataTypeEnum, dims as number[]); uploadTensor(tensorId, new Uint8Array(data.buffer, data.byteOffset, data.byteLength)); rawData = tensorId; } else { @@ -614,9 +614,6 @@ export const run = async ( const outputNamesOffset = wasm.stackAlloc(outputCount * ptrSize); try { - // WebNN backend needs the active session to check MLTensors with the current context. - wasm.jsepOnRunStart?.(sessionHandle); - [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); // create input tensors @@ -704,6 +701,8 @@ export const run = async ( ]); } + wasm.jsepOnRunStart?.(sessionHandle); + let errorCode: number; if (!BUILD_DEFS.DISABLE_JSEP && ioBindingState) { errorCode = await wasm._OrtRunWithBinding( @@ -832,7 +831,7 @@ export const run = async ( // If the graph has been partitioned, the output tensor may have not been created. For this reason, we use // ensureTensor to get/create the MLTensor. In which case, we don't need to copy the data if a new tensor // has been created. - const mlTensor = await ensureTensor(dataOffset, dataType, dims, false); + const mlTensor = await ensureTensor(sessionId, dataOffset, dataType, dims, false); // do not release the tensor right now. it will be released when user calls tensor.dispose(). 
keepOutputTensor = true; diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts index 42af5cbd3c91d..b4871e145f4d7 100644 --- a/js/web/lib/wasm/wasm-types.ts +++ b/js/web/lib/wasm/wasm-types.ts @@ -31,6 +31,7 @@ export declare namespace JSEP { type ReserveTensorIdFunction = () => number; type ReleaseTensorIdFunction = (tensorId: number) => void; type EnsureTensorFunction = ( + sessionId: number | undefined, tensorId: number, dataType: DataType, shape: readonly number[], @@ -141,12 +142,6 @@ export declare namespace JSEP { * @param sessionId - specify the session ID. */ jsepOnRunStart: (sessionId: number) => void; - /** - * [exported from pre-jsep.js] Called when InferenceSession.run finished. This function will be called after - * _OrtRun[WithBinding]() is called. - * @param sessionId - specify the session ID. - */ - jsepOnRunEnd: (sessionId: number) => void; /** * [exported from pre-jsep.js] Create a session. This function will be called after _OrtCreateSession() is * called. @@ -173,6 +168,13 @@ export declare namespace JSEP { */ shouldTransferToMLTensor: boolean; + /** + * [exported from pre-jsep.js] Called when InferenceSession.run finished. This function will be called after + * _OrtRun[WithBinding]() is called. + * @param sessionId - specify the session ID. + */ + jsepOnRunEnd: (sessionId: number) => void; + /** * [exported from pre-jsep.js] Register MLContext for a session. * @param sessionId - specify the session ID. @@ -193,13 +195,20 @@ export declare namespace JSEP { jsepReleaseTensorId: (tensorId: number) => void; /** * [exported from pre-jsep.js] Ensure that an MLTensor of a given type and shape exists for a MLTensor ID. + * @param sessionId - specify the session ID or current active session ID if undefined. * @param tensorId - specify the MLTensor ID. * @param onnxDataType - specify the data type. * @param shape - specify the dimensions (WebNN shape) of the tensor. 
* @param copyOld - specify whether to copy the old tensor if a new tensor was created. * @returns the MLTensor associated with the tensor ID. */ - jsepEnsureTensor: (tensorId: number, dataType: DataType, shape: number[], copyOld: boolean) => Promise; + jsepEnsureTensor: ( + sessionId: number | undefined, + tensorId: number, + dataType: DataType, + shape: number[], + copyOld: boolean, + ) => Promise; /** * [exported from pre-jsep.js] Upload data to an MLTensor. * @param tensorId - specify the MLTensor ID. @@ -225,12 +234,18 @@ export declare namespace JSEP { ) => () => Promise; /** * [exported from pre-jsep.js] Registers an external MLTensor to a session. + * @param sessionId - specify the session ID. * @param tensor - specify the MLTensor. * @param dataType - specify the data type. * @param dimensions - specify the dimensions. * @returns the MLTensor ID for the external MLTensor. */ - jsepRegisterMLTensor: (tensor: MLTensor, onnxDataType: DataType, dimensions: readonly number[]) => number; + jsepRegisterMLTensor: ( + sessionId: number, + tensor: MLTensor, + onnxDataType: DataType, + dimensions: readonly number[], + ) => number; /** * [exported from pre-jsep.js] Create an MLContext from a GPUDevice or MLContextOptions. @@ -260,20 +275,22 @@ export declare namespace JSEP { * [exported from pre-jsep.js] Register a WebNN graph input. * @param inputName - specify the input name. */ - jsepRegisterGraphInput(inputName: string): void; + jsepRegisterGraphInput: (inputName: string) => void; /** * [exported from pre-jsep.js] Check if a graph input is a WebNN graph input. + * @param sessionId - specify the session ID. * @param inputName - specify the input name. * @returns whether the input is a WebNN graph input. */ - jsepIsGraphInput(inputName: string): boolean; + jsepIsGraphInput: (sessionId: number, inputName: string) => boolean; /** * [exported from pre-jsep.js] Create a temporary MLTensor for a session. + * @param sessionId - specify the session ID. 
* @param dataType - specify the data type. * @param shape - specify the shape. * @returns the MLTensor ID for the temporary MLTensor. */ - jsepCreateTemporaryTensor: (dataType: DataType, shape: readonly number[]) => Promise; + jsepCreateTemporaryTensor: (sessionId: number, dataType: DataType, shape: readonly number[]) => Promise; } } diff --git a/onnxruntime/core/providers/webnn/builders/model.cc b/onnxruntime/core/providers/webnn/builders/model.cc index 231b65a4d1894..35964d85862e4 100644 --- a/onnxruntime/core/providers/webnn/builders/model.cc +++ b/onnxruntime/core/providers/webnn/builders/model.cc @@ -165,7 +165,7 @@ onnxruntime::common::Status Model::Dispatch(const InlinedHashMap(dim); shape.call("push", dim_val); } - auto ml_tensor = jsepEnsureTensor(reinterpret_cast(tensor.buffer), tensor.tensor_info.data_type, shape, true); + auto ml_tensor = jsepEnsureTensor(emscripten::val::undefined(), reinterpret_cast(tensor.buffer), tensor.tensor_info.data_type, shape, true); promises.call("push", ml_tensor); } for (const auto& [_, tensor] : outputs) { @@ -174,7 +174,7 @@ onnxruntime::common::Status Model::Dispatch(const InlinedHashMap(dim); shape.call("push", dim_val); } - auto ml_tensor = jsepEnsureTensor(reinterpret_cast(tensor.buffer), tensor.tensor_info.data_type, shape, false); + auto ml_tensor = jsepEnsureTensor(emscripten::val::undefined(), reinterpret_cast(tensor.buffer), tensor.tensor_info.data_type, shape, false); promises.call("push", ml_tensor); } auto ml_tensors = emscripten::val::global("Promise").call("all", promises).await();