From 2dbed77a611dbff769fea73236f15af4c0562731 Mon Sep 17 00:00:00 2001 From: Arthur Islamov Date: Tue, 21 Nov 2023 16:10:49 +0400 Subject: [PATCH] External weights load --- cmake/onnxruntime_webassembly.cmake | 4 ++-- js/common/lib/inference-session.ts | 4 ++++ js/web/lib/wasm/binding/ort-wasm.d.ts | 4 ++++ js/web/lib/wasm/wasm-core-impl.ts | 28 +++++++++++++++++++++++++-- onnxruntime/wasm/api.cc | 13 +++++++++++++ onnxruntime/wasm/api.h | 8 ++++++++ onnxruntime/wasm/js_internal_api.js | 6 ++++++ 7 files changed, 63 insertions(+), 4 deletions(-) diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index 9014089cb6112..01c27fbf182fe 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -193,7 +193,7 @@ else() re2::re2 ) - set(EXPORTED_RUNTIME_METHODS "'stackAlloc','stackRestore','stackSave','UTF8ToString','stringToUTF8','lengthBytesUTF8'") + set(EXPORTED_RUNTIME_METHODS "'stackAlloc','stackRestore','stackSave','UTF8ToString','stringToUTF8','lengthBytesUTF8','FS'") if (onnxruntime_USE_XNNPACK) target_link_libraries(onnxruntime_webassembly PRIVATE XNNPACK) @@ -224,7 +224,7 @@ else() "SHELL:-s MODULARIZE=1" "SHELL:-s EXPORT_ALL=0" "SHELL:-s VERBOSE=0" - "SHELL:-s FILESYSTEM=0" + "SHELL:-s FILESYSTEM=1" ${WASM_API_EXCEPTION_CATCHING} --no-entry ) diff --git a/js/common/lib/inference-session.ts b/js/common/lib/inference-session.ts index c7760692eed00..e7425a287abd4 100644 --- a/js/common/lib/inference-session.ts +++ b/js/common/lib/inference-session.ts @@ -226,6 +226,8 @@ export declare namespace InferenceSession { } export interface WebAssemblyExecutionProviderOption extends ExecutionProviderOption { readonly name: 'wasm'; + externalWeights?: ArrayBuffer; + externalWeightsFilename?: string; } export interface WebGLExecutionProviderOption extends ExecutionProviderOption { readonly name: 'webgl'; @@ -237,6 +239,8 @@ export declare namespace InferenceSession { export interface WebGpuExecutionProviderOption extends ExecutionProviderOption { readonly name: 'webgpu'; preferredLayout?: 'NCHW'|'NHWC'; + externalWeights?: ArrayBuffer; + externalWeightsFilename?: string; } export interface WebNNExecutionProviderOption extends ExecutionProviderOption { readonly name: 'webnn'; diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index 00431a4e86d5b..c3f9507913da5 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -17,6 +17,8 @@ export declare namespace JSEP { export interface OrtWasmModule extends EmscriptenModule { // #region emscripten functions + FS: {unlink(path: string): void; mkdir(path: string): void; chdir(path: string): void}; + stackSave(): number; stackRestore(stack: number): void; stackAlloc(size: number): number; @@ -24,6 +26,7 @@ export interface OrtWasmModule extends EmscriptenModule { UTF8ToString(offset: number, maxBytesToRead?: number): string; lengthBytesUTF8(str: string): number; stringToUTF8(str: string, offset: number, maxBytes: number): void; + createFileFromArrayBuffer(path: string, buffer: ArrayBuffer): void; // #endregion // #region ORT APIs @@ -32,6 +35,7 @@ export interface OrtWasmModule extends EmscriptenModule { _OrtGetLastError(errorCodeOffset: number, errorMessageOffset: number): void; _OrtCreateSession(dataOffset: number, dataLength: number, sessionOptionsHandle: number): number; + _OrtCreateSessionFromFile(path: number, sessionOptionsHandle: number): number; _OrtReleaseSession(sessionHandle: number): void; _OrtGetInputOutputCount(sessionHandle: number, inputCountOffset: number, outputCountOffset: number): number; _OrtGetInputName(sessionHandle: number, index: number): number; diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 3aacf8f4d90e0..83251be672535 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -9,6 +9,7 @@ import {setSessionOptions} from './session-options'; import {dataLocationStringToEnum, getTensorElementSize, isGpuBufferSupportedType, logLevelStringToEnum, tensorDataTypeEnumToString, tensorDataTypeStringToEnum, tensorTypeToTypedArrayConstructor} from './wasm-common'; import {getInstance} from './wasm-factory'; import {allocWasmString, checkLastError} from './wasm-utils'; +import WebAssemblyExecutionProviderOption = InferenceSession.WebAssemblyExecutionProviderOption let ortEnvInitialized = false; @@ -123,6 +124,21 @@ export const createSessionFinalize = (modelData: SerializableModeldata, options?: InferenceSession.SessionOptions): SerializableSessionMetadata => { const wasm = getInstance(); + const externalWeightsOption: WebAssemblyExecutionProviderOption|undefined = + options?.executionProviders?.find(e => (e as WebAssemblyExecutionProviderOption).externalWeights) as + WebAssemblyExecutionProviderOption; + let externalWeightsPath = ''; + if (externalWeightsOption) { + const modelDirectory = '/home/web_user/' + Math.random().toString(10); + wasm.FS.mkdir(modelDirectory); + const modelName = modelDirectory + '/model.onnx'; + + externalWeightsPath = `${modelDirectory}/${externalWeightsOption.externalWeightsFilename}`; + wasm.createFileFromArrayBuffer(externalWeightsPath, externalWeightsOption.externalWeights!); + wasm.createFileFromArrayBuffer(modelName, new Uint8Array(wasm.HEAPU8.buffer, modelData[0], modelData[1])); + wasm.FS.chdir(modelDirectory); + } + let sessionHandle = 0; let sessionOptionsHandle = 0; let ioBindingHandle = 0; @@ -132,8 +148,13 @@ export const createSessionFinalize = try { [sessionOptionsHandle, allocs] = setSessionOptions(options); - - sessionHandle = wasm._OrtCreateSession(modelData[0], modelData[1], sessionOptionsHandle); + // when external weights are passed, model must be created from a file + if (externalWeightsOption) { + const modelDirStringPtr = allocWasmString(modelName, allocs); + sessionHandle = wasm._OrtCreateSessionFromFile(modelDirStringPtr, sessionOptionsHandle); + } else { + sessionHandle = wasm._OrtCreateSession(modelData[0], modelData[1], sessionOptionsHandle); + } if (sessionHandle === 0) { checkLastError('Can\'t create a session.'); } @@ -206,6 +227,9 @@ export const createSessionFinalize = wasm._OrtReleaseSessionOptions(sessionOptionsHandle); } allocs.forEach(alloc => wasm._free(alloc)); + if (externalWeightsOption) { + wasm.FS.unlink(externalWeightsPath); + } } }; diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index 0e58bb4f93f7f..2afb5cc4f0b29 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -195,6 +195,19 @@ OrtSession* OrtCreateSession(void* data, size_t data_length, OrtSessionOptions* : nullptr; } +OrtSession* OrtCreateSessionFromFile(char* path, OrtSessionOptions* session_options) { +#if defined(__EMSCRIPTEN_PTHREADS__) + RETURN_NULLPTR_IF_ERROR(DisablePerSessionThreads, session_options); +#else + // must disable thread pool when WebAssembly multi-threads support is disabled. + RETURN_NULLPTR_IF_ERROR(SetIntraOpNumThreads, session_options, 1); + RETURN_NULLPTR_IF_ERROR(SetSessionExecutionMode, session_options, ORT_SEQUENTIAL); +#endif + + OrtSession* session = nullptr; + return (CHECK_STATUS(CreateSession, g_env, path, session_options, &session) == ORT_OK) ? session : nullptr; +} + void OrtReleaseSession(OrtSession* session) { Ort::GetApi().ReleaseSession(session); } diff --git a/onnxruntime/wasm/api.h b/onnxruntime/wasm/api.h index 2cd1515d191c8..d0291eb13b6f7 100644 --- a/onnxruntime/wasm/api.h +++ b/onnxruntime/wasm/api.h @@ -121,6 +121,14 @@ ort_session_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateSession(void* data, size_t data_length, ort_session_options_handle_t session_options); +/** + * create an instance of ORT session. + * @param path a pointer to a string that contains the path to ONNX or ORT format model. + * @returns an ORT session handle. Caller must release it after use by calling OrtReleaseSession(). + */ +ort_session_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateSessionFromFile(char* path, + ort_session_options_handle_t session_options); + /** * release the specified ORT session. */ diff --git a/onnxruntime/wasm/js_internal_api.js b/onnxruntime/wasm/js_internal_api.js index 427ad6f6d14f3..acdda3e1cd9ca 100644 --- a/onnxruntime/wasm/js_internal_api.js +++ b/onnxruntime/wasm/js_internal_api.js @@ -167,3 +167,9 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea return backend['createDownloader'](gpuBuffer, size, type); }; }; + +Module["createFileFromArrayBuffer"] = (path, buffer) => { + const weightsFile = FS.create(path); + weightsFile.contents = new Uint8Array(buffer); + weightsFile.usedBytes = buffer.byteLength; +} \ No newline at end of file