From d66258fd5d0b047959358f39554bb34ab835d19e Mon Sep 17 00:00:00 2001
From: Enrico Galli
Date: Fri, 13 Dec 2024 21:46:58 -0800
Subject: [PATCH] Pass sessionHandle/Id directly to function instead of using
 activeSession

---
 js/web/lib/wasm/jsep/backend-webnn.ts        | 38 ++++++++++--------
 js/web/lib/wasm/jsep/init.ts                 |  4 +-
 js/web/lib/wasm/jsep/webnn/tensor-manager.ts | 30 +++++++++-----
 js/web/lib/wasm/wasm-core-impl.ts            | 13 +++----
 js/web/lib/wasm/wasm-types.ts                | 39 +++++++++++++------
 .../core/providers/webnn/builders/model.cc   |  4 +-
 6 files changed, 80 insertions(+), 48 deletions(-)

diff --git a/js/web/lib/wasm/jsep/backend-webnn.ts b/js/web/lib/wasm/jsep/backend-webnn.ts
index 23e722f1cc7d7..2b9a9208e2e53 100644
--- a/js/web/lib/wasm/jsep/backend-webnn.ts
+++ b/js/web/lib/wasm/jsep/backend-webnn.ts
@@ -101,6 +101,7 @@
   }
 
   public onRunStart(sessionId: number): void {
+    LOG_DEBUG('verbose', () => `[WebNN] onRunStart {sessionId: ${sessionId}}`);
     this.activeSessionId = sessionId;
   }
 
@@ -115,6 +116,7 @@
       this.tensorManager.releaseTensorId(tensorId);
     }
     this.temporarySessionTensorIds.delete(sessionId);
+    this.activeSessionId = undefined;
   }
 
   public async createMLContext(optionsOrDevice?: MLContextOptions | GPUDevice): Promise<MLContext> {
@@ -152,14 +154,6 @@
     }
   }
 
-  public get currentContext(): MLContext {
-    const mlContext = this.getMLContext(this.currentSessionId);
-    if (!mlContext) {
-      throw new Error(`No MLContext found for session ${this.currentSessionId}`);
-    }
-    return mlContext;
-  }
-
   public registerMLContext(sessionId: number, mlContext: MLContext): void {
     this.mlContextBySessionId.set(sessionId, mlContext);
     let sessionIds = this.sessionIdsByMLContext.get(mlContext);
@@ -209,6 +203,7 @@
   }
 
   public async ensureTensor(
+    sessionId: number | undefined,
     tensorId: TensorId,
     onnxDataType: DataType,
     dimensions: number[],
@@ -218,20 +213,30 @@
     if (!webnnDataType) {
      throw new Error(`Unsupported ONNX data type: ${onnxDataType}`);
     }
-    return this.tensorManager.ensureTensor(tensorId, webnnDataType, dimensions, copyOld);
+    return this.tensorManager.ensureTensor(
+      sessionId ?? this.currentSessionId,
+      tensorId,
+      webnnDataType,
+      dimensions,
+      copyOld,
+    );
   }
 
-  public async createTemporaryTensor(onnxDataType: DataType, shape: readonly number[]): Promise<TensorId> {
+  public async createTemporaryTensor(
+    sessionId: number,
+    onnxDataType: DataType,
+    shape: readonly number[],
+  ): Promise<TensorId> {
     LOG_DEBUG('verbose', () => `[WebNN] createTemporaryTensor {onnxDataType: ${onnxDataType}, shape: ${shape}}`);
     const dataType = onnxDataTypeToWebnnDataType.get(onnxDataType);
     if (!dataType) {
       throw new Error(`Unsupported ONNX data type: ${onnxDataType}`);
     }
     const tensorId = this.tensorManager.reserveTensorId();
-    await this.tensorManager.ensureTensor(tensorId, dataType, shape, false);
-    const tensorIds = this.temporarySessionTensorIds.get(this.currentSessionId);
+    await this.tensorManager.ensureTensor(sessionId, tensorId, dataType, shape, false);
+    const tensorIds = this.temporarySessionTensorIds.get(sessionId);
     if (!tensorIds) {
-      this.temporarySessionTensorIds.set(this.currentSessionId, [tensorId]);
+      this.temporarySessionTensorIds.set(sessionId, [tensorId]);
     } else {
       tensorIds.push(tensorId);
     }
@@ -258,13 +263,13 @@
     };
   }
 
-  public registerMLTensor(tensor: MLTensor, onnxDataType: DataType, dimensions: number[]): TensorId {
+  public registerMLTensor(sessionId: number, tensor: MLTensor, onnxDataType: DataType, dimensions: number[]): TensorId {
     const webnnDataType = onnxDataTypeToWebnnDataType.get(onnxDataType);
     if (!webnnDataType) {
       throw new Error(`Unsupported ONNX data type: ${onnxDataType}`);
     }
 
-    const id = this.tensorManager.registerTensor(this.currentContext, tensor, webnnDataType, dimensions);
+    const id = this.tensorManager.registerTensor(sessionId, tensor, webnnDataType, dimensions);
     LOG_DEBUG(
       'verbose',
       () =>
@@ -344,8 +349,7 @@
     this.temporaryGraphInputs.push(inputName);
   }
 
-  public isGraphInput(inputName: string): boolean {
-    const sessionId = this.currentSessionId;
+  public isGraphInput(sessionId: number, inputName: string): boolean {
     const inputNames = this.sessionGraphInputs.get(sessionId);
     if (!inputNames) {
       return false;
diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts
index 48bd3ef2bc36f..b4071eae51c8f 100644
--- a/js/web/lib/wasm/jsep/init.ts
+++ b/js/web/lib/wasm/jsep/init.ts
@@ -287,8 +287,8 @@
       // jsepReleaseTensorId,
       (tensorId: number) => backend.releaseTensorId(tensorId),
       // jsepEnsureTensor
-      async (tensorId: number, onnxDataType: number, shape: number[], copyOld) =>
-        backend.ensureTensor(tensorId, onnxDataType, shape, copyOld),
+      async (sessionId: number | undefined, tensorId: number, onnxDataType: number, shape: number[], copyOld) =>
+        backend.ensureTensor(sessionId, tensorId, onnxDataType, shape, copyOld),
       // jsepUploadTensor
       (tensorId: number, data: Uint8Array) => {
         backend.uploadTensor(tensorId, data);
diff --git a/js/web/lib/wasm/jsep/webnn/tensor-manager.ts b/js/web/lib/wasm/jsep/webnn/tensor-manager.ts
index 4932691bda65b..3bf8a5c334b58 100644
--- a/js/web/lib/wasm/jsep/webnn/tensor-manager.ts
+++ b/js/web/lib/wasm/jsep/webnn/tensor-manager.ts
@@ -27,6 +27,7 @@
    * Ensure a MLTensor is created for the TensorId.
    */
   ensureTensor(
+    sessionId: number,
     tensorId: TensorId,
     dataType: MLOperandDataType,
     shape: readonly number[],
@@ -46,9 +47,9 @@
    */
   releaseTensorsForSession(session: number): void;
   /**
-   * Register an externally created MLTensor with a given MLContext and return a TensorId.
+   * Register an externally created MLTensor with a given session id and return a TensorId.
    */
-  registerTensor(mlContext: MLContext, mlTensor: MLTensor, dataType: MLOperandDataType, shape: number[]): TensorId;
+  registerTensor(sessionId: number, mlTensor: MLTensor, dataType: MLOperandDataType, shape: number[]): TensorId;
 }
 
 let tensorGuid = 1;
@@ -176,6 +177,7 @@
   }
 
   public async ensureTensor(
+    sessionId: number,
     dataType: MLOperandDataType,
     shape: readonly number[],
     copyOld: boolean,
@@ -196,7 +198,7 @@
     // eslint-disable-next-line no-bitwise
     const usage = typeof MLTensorUsage == 'undefined' ? undefined : MLTensorUsage.READ | MLTensorUsage.WRITE;
-    this.wrapper = await this.tensorManager.getCachedTensor(dataType, shape, usage, true, true);
+    this.wrapper = await this.tensorManager.getCachedTensor(sessionId, dataType, shape, usage, true, true);
 
     if (copyOld && this.activeUpload) {
       this.wrapper.write(this.activeUpload);
@@ -254,6 +256,14 @@
 
   constructor(private backend: WebNNBackend) {}
 
+  public getMLContext(sessionId: number): MLContext {
+    const context = this.backend.getMLContext(sessionId);
+    if (!context) {
+      throw new Error('MLContext not found for session.');
+    }
+    return context;
+  }
+
   public reserveTensorId(): TensorId {
     const tensorId = createNewTensorId();
     this.tensorTrackersById.set(tensorId, new TensorIdTracker(this));
@@ -272,6 +282,7 @@
   }
 
   public async ensureTensor(
+    sessionId: number,
     tensorId: TensorId,
     dataType: MLOperandDataType,
     shape: number[],
@@ -288,7 +299,7 @@
     if (!tensor) {
       throw new Error('Tensor not found.');
     }
-    return tensor.ensureTensor(dataType, shape, copyOld);
+    return tensor.ensureTensor(sessionId, dataType, shape, copyOld);
   }
 
   public upload(tensorId: TensorId, data: Uint8Array): void {
@@ -323,17 +334,18 @@
   }
 
   public registerTensor(
-    mlContext: MLContext,
+    sessionId: number,
     mlTensor: MLTensor,
     dataType: MLOperandDataType,
     shape: readonly number[],
   ): TensorId {
+    const context = this.getMLContext(sessionId);
     const tensorId = createNewTensorId();
     // Defaulting to READ | WRITE if usage is not provided.
     // eslint-disable-next-line no-bitwise
     const wrapper = new TensorWrapper({
-      sessionId: this.backend.currentSessionId,
-      context: mlContext,
+      sessionId,
+      context,
       tensor: mlTensor,
       dataType,
       shape,
@@ -347,13 +359,13 @@
   /**
    * Get or create an MLTensor with the given data type and shape.
    */
   public async getCachedTensor(
+    sessionId: number,
     dataType: MLOperandDataType,
     shape: readonly number[],
     usage: MLTensorUsageFlags | undefined,
     writable: boolean,
     readable: boolean,
   ): Promise<TensorWrapper> {
-    const sessionId = this.backend.currentSessionId;
     for (const [index, tensor] of this.freeTensors.entries()) {
       if (tensor.sameTypeAndShape(dataType, shape)) {
         LOG_DEBUG('verbose', () => `[WebNN] Reusing tensor {dataType: ${dataType}, shape: ${shape}}`);
         const wrapper = this.freeTensors.splice(index, 1)[0];
         wrapper.sessionId = sessionId;
         return wrapper;
       }
     }
-    const context = this.backend.currentContext;
+    const context = this.getMLContext(sessionId);
     LOG_DEBUG('verbose', () => `[WebNN] MLContext.createTensor {dataType: ${dataType}, shape: ${shape}}`);
     const tensor = await context.createTensor({
       dataType,
diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts
index ff259b707aa65..4bccfa76fdda3 100644
--- a/js/web/lib/wasm/wasm-core-impl.ts
+++ b/js/web/lib/wasm/wasm-core-impl.ts
@@ -504,7 +504,7 @@
       if (!registerMLTensor) {
         throw new Error('Tensor location "ml-tensor" is not supported without using WebNN.');
       }
-      rawData = registerMLTensor(mlTensor, tensorDataTypeStringToEnum(dataType), dims);
+      rawData = registerMLTensor(sessionId, mlTensor, tensorDataTypeStringToEnum(dataType), dims);
     } else {
       const data = tensor[2];
 
@@ -525,7 +525,7 @@
         const tensorNameUTF8 = wasm._OrtGetInputName(sessionId, index);
         const tensorName = wasm.UTF8ToString(tensorNameUTF8);
         // Promote the tensor to 'ml-tensor' if it is a graph input.
-        if (isGraphInput(tensorName)) {
+        if (isGraphInput(sessionId, tensorName)) {
           const dataTypeEnum = tensorDataTypeStringToEnum(dataType);
           dataByteLength = calculateTensorSizeInBytes(dataTypeEnum, dims)!;
           actualLocation = 'ml-tensor';
@@ -534,7 +534,7 @@
           if (!createTemporaryTensor || !uploadTensor) {
             throw new Error('Tensor location "ml-tensor" is not supported without using WebNN.');
           }
-          const tensorId = await createTemporaryTensor(dataTypeEnum, dims as number[]);
+          const tensorId = await createTemporaryTensor(sessionId, dataTypeEnum, dims as number[]);
           uploadTensor(tensorId, new Uint8Array(data.buffer, data.byteOffset, data.byteLength));
           rawData = tensorId;
         } else {
@@ -614,9 +614,6 @@
   const outputNamesOffset = wasm.stackAlloc(outputCount * ptrSize);
 
   try {
-    // WebNN backend needs the active session to check MLTensors with the current context.
-    wasm.jsepOnRunStart?.(sessionHandle);
-
     [runOptionsHandle, runOptionsAllocs] = setRunOptions(options);
 
     // create input tensors
@@ -704,6 +701,8 @@
       ]);
     }
 
+    wasm.jsepOnRunStart?.(sessionHandle);
+
     let errorCode: number;
     if (!BUILD_DEFS.DISABLE_JSEP && ioBindingState) {
       errorCode = await wasm._OrtRunWithBinding(
@@ -832,7 +831,7 @@
           // If the graph has been partitioned, the output tensor may have not been created. For this reason, we use
           // ensureTensor to get/create the MLTensor. In which case, we don't need to copy the data if a new tensor
           // has been created.
-          const mlTensor = await ensureTensor(dataOffset, dataType, dims, false);
+          const mlTensor = await ensureTensor(sessionId, dataOffset, dataType, dims, false);
 
           // do not release the tensor right now. it will be released when user calls tensor.dispose().
           keepOutputTensor = true;
diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts
index 42af5cbd3c91d..b4871e145f4d7 100644
--- a/js/web/lib/wasm/wasm-types.ts
+++ b/js/web/lib/wasm/wasm-types.ts
@@ -31,6 +31,7 @@
   type ReserveTensorIdFunction = () => number;
   type ReleaseTensorIdFunction = (tensorId: number) => void;
   type EnsureTensorFunction = (
+    sessionId: number | undefined,
    tensorId: number,
     dataType: DataType,
     shape: readonly number[],
@@ -141,12 +142,6 @@
      * @param sessionId - specify the session ID.
      */
     jsepOnRunStart: (sessionId: number) => void;
-    /**
-     * [exported from pre-jsep.js] Called when InferenceSession.run finished. This function will be called after
-     * _OrtRun[WithBinding]() is called.
-     * @param sessionId - specify the session ID.
-     */
-    jsepOnRunEnd: (sessionId: number) => void;
     /**
      * [exported from pre-jsep.js] Create a session. This function will be called after _OrtCreateSession() is
      * called.
@@ -173,6 +168,13 @@
      */
     shouldTransferToMLTensor: boolean;
 
+    /**
+     * [exported from pre-jsep.js] Called when InferenceSession.run finished. This function will be called after
+     * _OrtRun[WithBinding]() is called.
+     * @param sessionId - specify the session ID.
+     */
+    jsepOnRunEnd: (sessionId: number) => void;
+
     /**
      * [exported from pre-jsep.js] Register MLContext for a session.
      * @param sessionId - specify the session ID.
@@ -193,13 +195,20 @@
     jsepReleaseTensorId: (tensorId: number) => void;
     /**
      * [exported from pre-jsep.js] Ensure that an MLTensor of a given type and shape exists for a MLTensor ID.
+     * @param sessionId - specify the session ID or current active session ID if undefined.
      * @param tensorId - specify the MLTensor ID.
      * @param onnxDataType - specify the data type.
      * @param shape - specify the dimensions (WebNN shape) of the tensor.
      * @param copyOld - specify whether to copy the old tensor if a new tensor was created.
      * @returns the MLTensor associated with the tensor ID.
      */
-    jsepEnsureTensor: (tensorId: number, dataType: DataType, shape: number[], copyOld: boolean) => Promise<MLTensor>;
+    jsepEnsureTensor: (
+      sessionId: number | undefined,
+      tensorId: number,
+      dataType: DataType,
+      shape: number[],
+      copyOld: boolean,
+    ) => Promise<MLTensor>;
     /**
      * [exported from pre-jsep.js] Upload data to an MLTensor.
      * @param tensorId - specify the MLTensor ID.
@@ -225,12 +234,18 @@
     ) => () => Promise<Tensor.DataType>;
     /**
      * [exported from pre-jsep.js] Registers an external MLTensor to a session.
+     * @param sessionId - specify the session ID.
      * @param tensor - specify the MLTensor.
      * @param dataType - specify the data type.
      * @param dimensions - specify the dimensions.
      * @returns the MLTensor ID for the external MLTensor.
      */
-    jsepRegisterMLTensor: (tensor: MLTensor, onnxDataType: DataType, dimensions: readonly number[]) => number;
+    jsepRegisterMLTensor: (
+      sessionId: number,
+      tensor: MLTensor,
+      onnxDataType: DataType,
+      dimensions: readonly number[],
+    ) => number;
 
     /**
      * [exported from pre-jsep.js] Create an MLContext from a GPUDevice or MLContextOptions.
@@ -260,20 +275,22 @@
     /**
      * [exported from pre-jsep.js] Register a WebNN graph input.
      * @param inputName - specify the input name.
      */
-    jsepRegisterGraphInput(inputName: string): void;
+    jsepRegisterGraphInput: (inputName: string) => void;
     /**
      * [exported from pre-jsep.js] Check if a graph input is a WebNN graph input.
+     * @param sessionId - specify the session ID.
      * @param inputName - specify the input name.
      * @returns whether the input is a WebNN graph input.
      */
-    jsepIsGraphInput(inputName: string): boolean;
+    jsepIsGraphInput: (sessionId: number, inputName: string) => boolean;
     /**
      * [exported from pre-jsep.js] Create a temporary MLTensor for a session.
+     * @param sessionId - specify the session ID.
      * @param dataType - specify the data type.
      * @param shape - specify the shape.
      * @returns the MLTensor ID for the temporary MLTensor.
      */
-    jsepCreateTemporaryTensor: (dataType: DataType, shape: readonly number[]) => Promise<number>;
+    jsepCreateTemporaryTensor: (sessionId: number, dataType: DataType, shape: readonly number[]) => Promise<number>;
   }
 }
diff --git a/onnxruntime/core/providers/webnn/builders/model.cc b/onnxruntime/core/providers/webnn/builders/model.cc
index 231b65a4d1894..35964d85862e4 100644
--- a/onnxruntime/core/providers/webnn/builders/model.cc
+++ b/onnxruntime/core/providers/webnn/builders/model.cc
@@ -165,7 +165,7 @@ onnxruntime::common::Status Model::Dispatch(const InlinedHashMap<std::string, O
       const auto dim_val = static_cast<uint32_t>(dim);
       shape.call<void>("push", dim_val);
     }
-    auto ml_tensor = jsepEnsureTensor(reinterpret_cast<intptr_t>(tensor.buffer), tensor.tensor_info.data_type, shape, true);
+    auto ml_tensor = jsepEnsureTensor(emscripten::val::undefined(), reinterpret_cast<intptr_t>(tensor.buffer), tensor.tensor_info.data_type, shape, true);
     promises.call<void>("push", ml_tensor);
   }
   for (const auto& [_, tensor] : outputs) {
@@ -174,7 +174,7 @@ onnxruntime::common::Status Model::Dispatch(const InlinedHashMap<std::string, O
       const auto dim_val = static_cast<uint32_t>(dim);
       shape.call<void>("push", dim_val);
     }
-    auto ml_tensor = jsepEnsureTensor(reinterpret_cast<intptr_t>(tensor.buffer), tensor.tensor_info.data_type, shape, false);
+    auto ml_tensor = jsepEnsureTensor(emscripten::val::undefined(), reinterpret_cast<intptr_t>(tensor.buffer), tensor.tensor_info.data_type, shape, false);
     promises.call<void>("push", ml_tensor);
   }
   auto ml_tensors = emscripten::val::global("Promise").call<emscripten::val>("all", promises).await();
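---
Editor's note, not part of the patch: a minimal TypeScript sketch of the calling convention
this change establishes. The `backend` instance, the hypothetical `prepareGraphInput` helper,
the session id, and the shape are assumptions for illustration; `createTemporaryTensor` and
`ensureTensor` are the WebNNBackend methods changed above. Callers now thread the session id
explicitly; `ensureTensor` still accepts `undefined` and falls back to the active session,
which is what Model::Dispatch relies on when it passes emscripten::val::undefined().

    import { DataType } from '../wasm-common'; // assumed import path

    // Hypothetical helper: prepare a WebNN graph input for an explicit session.
    async function prepareGraphInput(backend: WebNNBackend, sessionId: number): Promise<TensorId> {
      // The temporary tensor is tracked under this session id and released in onRunEnd(sessionId).
      const tensorId = await backend.createTemporaryTensor(sessionId, DataType.float, [1, 3, 224, 224]);
      // Explicit session id; effectively a no-op here since the tensor already exists,
      // shown only for the new signature. Passing `undefined` would fall back to activeSessionId.
      await backend.ensureTensor(sessionId, tensorId, DataType.float, [1, 3, 224, 224], false);
      return tensorId;
    }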