diff --git a/js/web/lib/wasm/jsep/backend-webnn.ts b/js/web/lib/wasm/jsep/backend-webnn.ts index b302354c46eeb..2b9a9208e2e53 100644 --- a/js/web/lib/wasm/jsep/backend-webnn.ts +++ b/js/web/lib/wasm/jsep/backend-webnn.ts @@ -75,6 +75,19 @@ export class WebNNBackend { * Current session id. */ private activeSessionId?: number; + /** + * Maps from session id to list of graph inputs. + */ + private sessionGraphInputs: Map = new Map(); + /** + * Temporary graph inputs for the current session. + * These inputs will be registered when the session is created. + */ + private temporaryGraphInputs: string[] = []; + /** + * Temporary tensors for the current session. + */ + private temporarySessionTensorIds: Map = new Map(); constructor(env: Env) { configureLogger(env.logLevel!, !!env.debug); @@ -88,9 +101,24 @@ export class WebNNBackend { } public onRunStart(sessionId: number): void { + LOG_DEBUG('verbose', () => `[WebNN] onRunStart {sessionId: ${sessionId}}`); this.activeSessionId = sessionId; } + public onRunEnd(sessionId: number): void { + LOG_DEBUG('verbose', () => `[WebNN] onRunEnd {sessionId: ${sessionId}}`); + const tensorIds = this.temporarySessionTensorIds.get(sessionId); + if (!tensorIds) { + return; + } + for (const tensorId of tensorIds) { + LOG_DEBUG('verbose', () => `[WebNN] releasing temporary tensor {tensorId: ${tensorId}}`); + this.tensorManager.releaseTensorId(tensorId); + } + this.temporarySessionTensorIds.delete(sessionId); + this.activeSessionId = undefined; + } + public async createMLContext(optionsOrDevice?: MLContextOptions | GPUDevice): Promise { if (optionsOrDevice instanceof GPUDevice) { const mlContextIndex = this.mlContextCache.findIndex((entry) => entry.gpuDevice === optionsOrDevice); @@ -126,14 +154,6 @@ export class WebNNBackend { } } - public get currentContext(): MLContext { - const mlContext = this.getMLContext(this.currentSessionId); - if (!mlContext) { - throw new Error(`No MLContext found for session ${this.currentSessionId}`); - } - return mlContext; - } - public registerMLContext(sessionId: number, mlContext: MLContext): void { this.mlContextBySessionId.set(sessionId, mlContext); let sessionIds = this.sessionIdsByMLContext.get(mlContext); @@ -142,9 +162,15 @@ export class WebNNBackend { this.sessionIdsByMLContext.set(mlContext, sessionIds); } sessionIds.add(sessionId); + + if (this.temporaryGraphInputs.length > 0) { + this.sessionGraphInputs.set(sessionId, this.temporaryGraphInputs); + this.temporaryGraphInputs = []; + } } public onReleaseSession(sessionId: number): void { + this.sessionGraphInputs.delete(sessionId); const mlContext = this.mlContextBySessionId.get(sessionId)!; if (!mlContext) { // Current session is not a WebNN session. @@ -177,6 +203,7 @@ export class WebNNBackend { } public async ensureTensor( + sessionId: number | undefined, tensorId: TensorId, onnxDataType: DataType, dimensions: number[], @@ -186,7 +213,34 @@ export class WebNNBackend { if (!webnnDataType) { throw new Error(`Unsupported ONNX data type: ${onnxDataType}`); } - return this.tensorManager.ensureTensor(tensorId, webnnDataType, dimensions, copyOld); + return this.tensorManager.ensureTensor( + sessionId ?? this.currentSessionId, + tensorId, + webnnDataType, + dimensions, + copyOld, + ); + } + + public async createTemporaryTensor( + sessionId: number, + onnxDataType: DataType, + shape: readonly number[], + ): Promise { + LOG_DEBUG('verbose', () => `[WebNN] createTemporaryTensor {onnxDataType: ${onnxDataType}, shape: ${shape}}`); + const dataType = onnxDataTypeToWebnnDataType.get(onnxDataType); + if (!dataType) { + throw new Error(`Unsupported ONNX data type: ${onnxDataType}`); + } + const tensorId = this.tensorManager.reserveTensorId(); + await this.tensorManager.ensureTensor(sessionId, tensorId, dataType, shape, false); + const tensorIds = this.temporarySessionTensorIds.get(sessionId); + if (!tensorIds) { + this.temporarySessionTensorIds.set(sessionId, [tensorId]); + } else { + tensorIds.push(tensorId); + } + return tensorId; } public uploadTensor(tensorId: TensorId, data: Uint8Array): void { @@ -209,13 +263,13 @@ export class WebNNBackend { }; } - public registerMLTensor(tensor: MLTensor, onnxDataType: DataType, dimensions: number[]): TensorId { + public registerMLTensor(sessionId: number, tensor: MLTensor, onnxDataType: DataType, dimensions: number[]): TensorId { const webnnDataType = onnxDataTypeToWebnnDataType.get(onnxDataType); if (!webnnDataType) { throw new Error(`Unsupported ONNX data type: ${onnxDataType}`); } - const id = this.tensorManager.registerTensor(this.currentContext, tensor, webnnDataType, dimensions); + const id = this.tensorManager.registerTensor(sessionId, tensor, webnnDataType, dimensions); LOG_DEBUG( 'verbose', () => @@ -291,6 +345,18 @@ export class WebNNBackend { return builder.constant(desc, bufferView); } + public registerGraphInput(inputName: string): void { + this.temporaryGraphInputs.push(inputName); + } + + public isGraphInput(sessionId: number, inputName: string): boolean { + const inputNames = this.sessionGraphInputs.get(sessionId); + if (!inputNames) { + return false; + } + return inputNames.includes(inputName); + } + public flush(): void { // Unlike the WebGPU backend, the WebNN backend does not need to flush any pending operations. } diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index 48bd3ef2bc36f..b4071eae51c8f 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -287,8 +287,8 @@ export const init = async ( // jsepReleaseTensorId, (tensorId: number) => backend.releaseTensorId(tensorId), // jsepEnsureTensor - async (tensorId: number, onnxDataType: number, shape: number[], copyOld) => - backend.ensureTensor(tensorId, onnxDataType, shape, copyOld), + async (sessionId: number | undefined, tensorId: number, onnxDataType: number, shape: number[], copyOld) => + backend.ensureTensor(sessionId, tensorId, onnxDataType, shape, copyOld), // jsepUploadTensor (tensorId: number, data: Uint8Array) => { backend.uploadTensor(tensorId, data); diff --git a/js/web/lib/wasm/jsep/webnn/tensor-manager.ts b/js/web/lib/wasm/jsep/webnn/tensor-manager.ts index 45b5b8b4fa932..ebdd5069aa089 100644 --- a/js/web/lib/wasm/jsep/webnn/tensor-manager.ts +++ b/js/web/lib/wasm/jsep/webnn/tensor-manager.ts @@ -27,6 +27,7 @@ export interface TensorManager { * Ensure a MLTensor is created for the TensorId. */ ensureTensor( + sessionId: number, tensorId: TensorId, dataType: MLOperandDataType, shape: readonly number[], @@ -46,9 +47,9 @@ export interface TensorManager { */ releaseTensorsForSession(session: number): void; /** - * Register an externally created MLTensor with a given MLContext and return a TensorId. + * Register an externally created MLTensor with a given session id and return a TensorId. */ - registerTensor(mlContext: MLContext, mlTensor: MLTensor, dataType: MLOperandDataType, shape: number[]): TensorId; + registerTensor(sessionId: number, mlTensor: MLTensor, dataType: MLOperandDataType, shape: number[]): TensorId; } let tensorGuid = 1; @@ -177,11 +178,12 @@ class TensorIdTracker { } public async ensureTensor( - context: MLContext, + sessionId: number, dataType: MLOperandDataType, shape: readonly number[], copyOld: boolean, ): Promise { + const context = this.tensorManager.getMLContext(sessionId); if (this.wrapper) { if (this.wrapper.canReuseTensor(context, dataType, shape)) { return this.wrapper.tensor; @@ -198,7 +200,7 @@ class TensorIdTracker { // eslint-disable-next-line no-bitwise const usage = typeof MLTensorUsage == 'undefined' ? undefined : MLTensorUsage.READ | MLTensorUsage.WRITE; - this.wrapper = await this.tensorManager.getCachedTensor(dataType, shape, usage, true, true); + this.wrapper = await this.tensorManager.getCachedTensor(sessionId, dataType, shape, usage, true, true); if (copyOld && this.activeUpload) { this.wrapper.write(this.activeUpload); @@ -256,6 +258,14 @@ class TensorManagerImpl implements TensorManager { constructor(private backend: WebNNBackend) {} + public getMLContext(sessionId: number): MLContext { + const context = this.backend.getMLContext(sessionId); + if (!context) { + throw new Error('MLContext not found for session.'); + } + return context; + } + public reserveTensorId(): TensorId { const tensorId = createNewTensorId(); this.tensorTrackersById.set(tensorId, new TensorIdTracker(this)); @@ -274,6 +284,7 @@ class TensorManagerImpl implements TensorManager { } public async ensureTensor( + sessionId: number, tensorId: TensorId, dataType: MLOperandDataType, shape: number[], @@ -290,7 +301,7 @@ class TensorManagerImpl implements TensorManager { if (!tensor) { throw new Error('Tensor not found.'); } - return tensor.ensureTensor(this.backend.currentContext, dataType, shape, copyOld); + return tensor.ensureTensor(sessionId, dataType, shape, copyOld); } public upload(tensorId: TensorId, data: Uint8Array): void { @@ -325,17 +336,18 @@ class TensorManagerImpl implements TensorManager { } public registerTensor( - mlContext: MLContext, + sessionId: number, mlTensor: MLTensor, dataType: MLOperandDataType, shape: readonly number[], ): TensorId { + const context = this.getMLContext(sessionId); const tensorId = createNewTensorId(); // Defaulting to READ | WRITE if usage is not provided. // eslint-disable-next-line no-bitwise const wrapper = new TensorWrapper({ - sessionId: this.backend.currentSessionId, - context: mlContext, + sessionId, + context, tensor: mlTensor, dataType, shape, @@ -349,14 +361,14 @@ class TensorManagerImpl implements TensorManager { * Get or create an MLTensor with the given data type and shape. */ public async getCachedTensor( + sessionId: number, dataType: MLOperandDataType, shape: readonly number[], usage: MLTensorUsageFlags | undefined, writable: boolean, readable: boolean, ): Promise { - const sessionId = this.backend.currentSessionId; - const context = this.backend.currentContext; + const context = this.getMLContext(sessionId); for (const [index, tensor] of this.freeTensors.entries()) { if (tensor.canReuseTensor(context, dataType, shape)) { LOG_DEBUG('verbose', () => `[WebNN] Reusing tensor {dataType: ${dataType}, shape: ${shape}}`); diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index da8939cd0263a..4bccfa76fdda3 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -453,14 +453,14 @@ export const releaseSession = (sessionId: number): void => { activeSessions.delete(sessionId); }; -export const prepareInputOutputTensor = ( +export const prepareInputOutputTensor = async ( tensor: TensorMetadata | null, tensorHandles: number[], allocs: number[], sessionId: number, index: number, enableGraphCapture = false, -): void => { +): Promise => { if (!tensor) { tensorHandles.push(0); return; @@ -472,6 +472,7 @@ export const prepareInputOutputTensor = ( const dataType = tensor[0]; const dims = tensor[1]; const location = tensor[3]; + let actualLocation = location; let rawData: number; let dataByteLength: number; @@ -503,7 +504,7 @@ export const prepareInputOutputTensor = ( if (!registerMLTensor) { throw new Error('Tensor location "ml-tensor" is not supported without using WebNN.'); } - rawData = registerMLTensor(mlTensor, tensorDataTypeStringToEnum(dataType), dims); + rawData = registerMLTensor(sessionId, mlTensor, tensorDataTypeStringToEnum(dataType), dims); } else { const data = tensor[2]; @@ -519,10 +520,35 @@ export const prepareInputOutputTensor = ( wasm.setValue(rawData + i * ptrSize, allocWasmString(data[i], allocs), '*'); } } else { - dataByteLength = data.byteLength; - rawData = wasm._malloc(dataByteLength); - allocs.push(rawData); - wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData); + const isGraphInput = wasm.jsepIsGraphInput; + if (dataType !== 'string' && isGraphInput) { + const tensorNameUTF8 = wasm._OrtGetInputName(sessionId, index); + const tensorName = wasm.UTF8ToString(tensorNameUTF8); + // Promote the tensor to 'ml-tensor' if it is a graph input. + if (isGraphInput(sessionId, tensorName)) { + const dataTypeEnum = tensorDataTypeStringToEnum(dataType); + dataByteLength = calculateTensorSizeInBytes(dataTypeEnum, dims)!; + actualLocation = 'ml-tensor'; + const createTemporaryTensor = wasm.jsepCreateTemporaryTensor; + const uploadTensor = wasm.jsepUploadTensor; + if (!createTemporaryTensor || !uploadTensor) { + throw new Error('Tensor location "ml-tensor" is not supported without using WebNN.'); + } + const tensorId = await createTemporaryTensor(sessionId, dataTypeEnum, dims as number[]); + uploadTensor(tensorId, new Uint8Array(data.buffer, data.byteOffset, data.byteLength)); + rawData = tensorId; + } else { + dataByteLength = data.byteLength; + rawData = wasm._malloc(dataByteLength); + allocs.push(rawData); + wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData); + } + } else { + dataByteLength = data.byteLength; + rawData = wasm._malloc(dataByteLength); + allocs.push(rawData); + wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData); + } } } @@ -536,7 +562,7 @@ export const prepareInputOutputTensor = ( dataByteLength, dimsOffset, dims.length, - dataLocationStringToEnum(location), + dataLocationStringToEnum(actualLocation), ); if (tensor === 0) { checkLastError(`Can't create tensor for input/output. session=${sessionId}, index=${index}.`); @@ -588,14 +614,11 @@ export const run = async ( const outputNamesOffset = wasm.stackAlloc(outputCount * ptrSize); try { - // WebNN backend needs the active session to check MLTensors with the current context. - wasm.jsepOnRunStart?.(sessionHandle); - [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); // create input tensors for (let i = 0; i < inputCount; i++) { - prepareInputOutputTensor( + await prepareInputOutputTensor( inputTensors[i], inputTensorHandles, inputOutputAllocs, @@ -607,7 +630,7 @@ export const run = async ( // create output tensors for (let i = 0; i < outputCount; i++) { - prepareInputOutputTensor( + await prepareInputOutputTensor( outputTensors[i], outputTensorHandles, inputOutputAllocs, @@ -678,6 +701,8 @@ export const run = async ( ]); } + wasm.jsepOnRunStart?.(sessionHandle); + let errorCode: number; if (!BUILD_DEFS.DISABLE_JSEP && ioBindingState) { errorCode = await wasm._OrtRunWithBinding( @@ -806,7 +831,7 @@ export const run = async ( // If the graph has been partitioned, the output tensor may have not been created. For this reason, we use // ensureTensor to get/create the MLTensor. In which case, we don't need to copy the data if a new tensor // has been created. - const mlTensor = await ensureTensor(dataOffset, dataType, dims, false); + const mlTensor = await ensureTensor(sessionId, dataOffset, dataType, dims, false); // do not release the tensor right now. it will be released when user calls tensor.dispose(). keepOutputTensor = true; @@ -841,6 +866,7 @@ export const run = async ( if (!keepOutputTensor) { wasm._OrtReleaseTensor(tensor); } + wasm.jsepOnRunEnd?.(sessionHandle); } } diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts index ebeac5dc9e587..b4871e145f4d7 100644 --- a/js/web/lib/wasm/wasm-types.ts +++ b/js/web/lib/wasm/wasm-types.ts @@ -31,6 +31,7 @@ export declare namespace JSEP { type ReserveTensorIdFunction = () => number; type ReleaseTensorIdFunction = (tensorId: number) => void; type EnsureTensorFunction = ( + sessionId: number | undefined, tensorId: number, dataType: DataType, shape: readonly number[], @@ -167,6 +168,13 @@ export declare namespace JSEP { */ shouldTransferToMLTensor: boolean; + /** + * [exported from pre-jsep.js] Called when InferenceSession.run finished. This function will be called after + * _OrtRun[WithBinding]() is called. + * @param sessionId - specify the session ID. + */ + jsepOnRunEnd: (sessionId: number) => void; + /** * [exported from pre-jsep.js] Register MLContext for a session. * @param sessionId - specify the session ID. @@ -187,13 +195,20 @@ export declare namespace JSEP { jsepReleaseTensorId: (tensorId: number) => void; /** * [exported from pre-jsep.js] Ensure that an MLTensor of a given type and shape exists for a MLTensor ID. + * @param sessionId - specify the session ID or current active session ID if undefined. * @param tensorId - specify the MLTensor ID. * @param onnxDataType - specify the data type. * @param shape - specify the dimensions (WebNN shape) of the tensor. * @param copyOld - specify whether to copy the old tensor if a new tensor was created. * @returns the MLTensor associated with the tensor ID. */ - jsepEnsureTensor: (tensorId: number, dataType: DataType, shape: number[], copyOld: boolean) => Promise; + jsepEnsureTensor: ( + sessionId: number | undefined, + tensorId: number, + dataType: DataType, + shape: number[], + copyOld: boolean, + ) => Promise; /** * [exported from pre-jsep.js] Upload data to an MLTensor. * @param tensorId - specify the MLTensor ID. @@ -219,12 +234,18 @@ export declare namespace JSEP { ) => () => Promise; /** * [exported from pre-jsep.js] Registers an external MLTensor to a session. + * @param sessionId - specify the session ID. * @param tensor - specify the MLTensor. * @param dataType - specify the data type. * @param dimensions - specify the dimensions. * @returns the MLTensor ID for the external MLTensor. */ - jsepRegisterMLTensor: (tensor: MLTensor, onnxDataType: DataType, dimensions: readonly number[]) => number; + jsepRegisterMLTensor: ( + sessionId: number, + tensor: MLTensor, + onnxDataType: DataType, + dimensions: readonly number[], + ) => number; /** * [exported from pre-jsep.js] Create an MLContext from a GPUDevice or MLContextOptions. @@ -249,6 +270,27 @@ export declare namespace JSEP { builder: MLGraphBuilder, desc: MLOperandDescriptor, ): MLOperand; + + /** + * [exported from pre-jsep.js] Register a WebNN graph input. + * @param inputName - specify the input name. + */ + jsepRegisterGraphInput: (inputName: string) => void; + /** + * [exported from pre-jsep.js] Check if a graph input is a WebNN graph input. + * @param sessionId - specify the session ID. + * @param inputName - specify the input name. + * @returns whether the input is a WebNN graph input. + */ + jsepIsGraphInput: (sessionId: number, inputName: string) => boolean; + /** + * [exported from pre-jsep.js] Create a temporary MLTensor for a session. + * @param sessionId - specify the session ID. + * @param dataType - specify the data type. + * @param shape - specify the shape. + * @returns the MLTensor ID for the temporary MLTensor. + */ + jsepCreateTemporaryTensor: (sessionId: number, dataType: DataType, shape: readonly number[]) => Promise; } } diff --git a/onnxruntime/core/providers/webnn/builders/model.cc b/onnxruntime/core/providers/webnn/builders/model.cc index 231b65a4d1894..35964d85862e4 100644 --- a/onnxruntime/core/providers/webnn/builders/model.cc +++ b/onnxruntime/core/providers/webnn/builders/model.cc @@ -165,7 +165,7 @@ onnxruntime::common::Status Model::Dispatch(const InlinedHashMap(dim); shape.call("push", dim_val); } - auto ml_tensor = jsepEnsureTensor(reinterpret_cast(tensor.buffer), tensor.tensor_info.data_type, shape, true); + auto ml_tensor = jsepEnsureTensor(emscripten::val::undefined(), reinterpret_cast(tensor.buffer), tensor.tensor_info.data_type, shape, true); promises.call("push", ml_tensor); } for (const auto& [_, tensor] : outputs) { @@ -174,7 +174,7 @@ onnxruntime::common::Status Model::Dispatch(const InlinedHashMap(dim); shape.call("push", dim_val); } - auto ml_tensor = jsepEnsureTensor(reinterpret_cast(tensor.buffer), tensor.tensor_info.data_type, shape, false); + auto ml_tensor = jsepEnsureTensor(emscripten::val::undefined(), reinterpret_cast(tensor.buffer), tensor.tensor_info.data_type, shape, false); promises.call("push", ml_tensor); } auto ml_tensors = emscripten::val::global("Promise").call("all", promises).await(); diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index e8f116d390199..4b7cab684ae81 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -252,6 +252,7 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i if (is_input) { wnn_operands_.insert(std::make_pair(name, wnn_builder_.call("input", name, desc))); + emscripten::val::module_property("jsepRegisterGraphInput")(name); input_names_.push_back(name); } else { output_names_.push_back(name); diff --git a/onnxruntime/wasm/pre-jsep.js b/onnxruntime/wasm/pre-jsep.js index 45e2475548df5..0c83e71a921cb 100644 --- a/onnxruntime/wasm/pre-jsep.js +++ b/onnxruntime/wasm/pre-jsep.js @@ -220,12 +220,14 @@ Module['jsepInit'] = (name, params) => { // This function is called from both JS and an EM_ASM block, it needs both a minifiable name and an explicit name. Module['jsepReleaseTensorId'] = Module.jsepReleaseTensorId; + Module['jsepUploadTensor'] = Module.jsepUploadTensor; // Functions called from JS also need to have explicit names. const backend = Module.jsepBackend; Module['jsepOnRunStart'] = sessionId => { return backend['onRunStart'](sessionId); }; + Module['jsepOnRunEnd'] = backend['onRunEnd'].bind(backend); Module['jsepRegisterMLContext'] = (sessionId, mlContext) => { backend['registerMLContext'](sessionId, mlContext); }; @@ -235,8 +237,8 @@ Module['jsepInit'] = (name, params) => { Module['jsepCreateMLTensorDownloader'] = (tensorId, type) => { return backend['createMLTensorDownloader'](tensorId, type); } - Module['jsepRegisterMLTensor'] = (tensor, dataType, shape) => { - return backend['registerMLTensor'](tensor, dataType, shape); + Module['jsepRegisterMLTensor'] = (sessionId, tensor, dataType, shape) => { + return backend['registerMLTensor'](sessionId, tensor, dataType, shape); }; Module['jsepCreateMLContext'] = (optionsOrGpuDevice) => { return backend['createMLContext'](optionsOrGpuDevice); @@ -245,5 +247,9 @@ Module['jsepInit'] = (name, params) => { return backend['registerMLConstant']( externalFilePath, dataOffset, dataLength, builder, desc, Module.MountedFiles); }; + Module['jsepRegisterGraphInput'] = backend['registerGraphInput'].bind(backend); + Module['jsepIsGraphInput'] = backend['isGraphInput'].bind(backend); + + Module['jsepCreateTemporaryTensor'] = backend['createTemporaryTensor'].bind(backend); } };