From 3ea20e9e2e56243ae476662e52aee5f53b8f44da Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Tue, 9 Jul 2024 09:42:44 -0700 Subject: [PATCH] Initial changes to support wasm64. --- cmake/CMakeLists.txt | 3 +- cmake/adjust_global_compile_flags.cmake | 15 +++- cmake/onnxruntime_webassembly.cmake | 74 +++++++++++++++++-- js/web/lib/index.ts | 8 +- js/web/lib/wasm/jsep/backend-webgpu.ts | 3 +- js/web/lib/wasm/jsep/init.ts | 43 ++++++----- .../lib/wasm/jsep/webgpu/gpu-data-manager.ts | 7 +- js/web/lib/wasm/jsep/webgpu/types.ts | 1 + js/web/lib/wasm/wasm-types.ts | 5 ++ .../core/framework/tensorprotoutils.cc | 2 +- onnxruntime/core/graph/model.cc | 2 +- onnxruntime/core/providers/js/js_kernel.h | 27 +++++++ .../core/providers/js/operators/conv.h | 12 +-- .../providers/js/operators/conv_transpose.h | 16 ++-- .../core/providers/js/operators/gather.cc | 18 +---- .../core/providers/js/operators/gemm.h | 4 +- onnxruntime/core/providers/js/operators/pad.h | 2 +- .../core/providers/js/operators/reduce.h | 2 +- .../core/providers/js/operators/resize.h | 2 +- .../core/providers/js/operators/slice.h | 6 +- .../core/providers/js/operators/split.h | 2 +- .../core/providers/js/operators/transpose.h | 2 +- onnxruntime/wasm/api.cc | 4 + onnxruntime/wasm/api.h | 4 + tools/ci_build/build.py | 3 +- 25 files changed, 188 insertions(+), 79 deletions(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index a9b0dfb30cc4e..c61d720d8321c 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -201,6 +201,7 @@ option(onnxruntime_WEBASSEMBLY_RUN_TESTS_IN_BROWSER "Enable this option to run t option(onnxruntime_ENABLE_WEBASSEMBLY_DEBUG_INFO "Enable this option to turn on DWARF format debug info" OFF) option(onnxruntime_ENABLE_WEBASSEMBLY_PROFILING "Enable this option to turn on WebAssembly profiling and preserve function names" OFF) option(onnxruntime_ENABLE_WEBASSEMBLY_OUTPUT_OPTIMIZED_MODEL "Enable this option to allow WebAssembly to output optimized model" OFF) +option(onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64 "Enable this option to allow WebAssembly to use 64bit memory" OFF) # Enable bitcode for iOS option(onnxruntime_ENABLE_BITCODE "Enable bitcode for iOS only" OFF) @@ -241,7 +242,7 @@ option(onnxruntime_ENABLE_TRITON "Enable Triton" OFF) # composable kernel is managed automatically, unless user want to explicitly disable it, it should not be manually set option(onnxruntime_USE_COMPOSABLE_KERNEL "Enable composable kernel for ROCm EP" ON) -cmake_dependent_option(onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE "Enable ck_tile for composable kernel" ON "onnxruntime_USE_COMPOSABLE_KERNEL" OFF) +option(onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE "Enable ck_tile for composable kernel" ON) option(onnxruntime_USE_ROCBLAS_EXTENSION_API "Enable rocblas tuning for ROCm EP" OFF) option(onnxruntime_USE_TRITON_KERNEL "Enable triton compiled kernel" OFF) option(onnxruntime_BUILD_KERNEL_EXPLORER "Build Kernel Explorer for testing and profiling GPU kernels" OFF) diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake index 6eb784a4063ed..2189f728ad5ad 100644 --- a/cmake/adjust_global_compile_flags.cmake +++ b/cmake/adjust_global_compile_flags.cmake @@ -52,8 +52,13 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") endif() if (onnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_CATCHING) - string(APPEND CMAKE_C_FLAGS " -s DISABLE_EXCEPTION_CATCHING=0") - string(APPEND CMAKE_CXX_FLAGS " -s DISABLE_EXCEPTION_CATCHING=0") + if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) + string(APPEND CMAKE_C_FLAGS " -fwasm-exceptions") + string(APPEND CMAKE_CXX_FLAGS " -fwasm-exceptions") + else() + string(APPEND CMAKE_C_FLAGS " -s DISABLE_EXCEPTION_CATCHING=0") + string(APPEND CMAKE_CXX_FLAGS " -s DISABLE_EXCEPTION_CATCHING=0") + endif() endif() # Build WebAssembly with multi-threads support. @@ -61,6 +66,12 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") string(APPEND CMAKE_C_FLAGS " -pthread -Wno-pthreads-mem-growth") string(APPEND CMAKE_CXX_FLAGS " -pthread -Wno-pthreads-mem-growth") endif() + + # Build WebAssembly with 64bit support. + if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) + string(APPEND CMAKE_C_FLAGS " -sMEMORY64 -Wno-experimental") + string(APPEND CMAKE_CXX_FLAGS " -sMEMORY64 -Wno-experimental") + endif() endif() if (onnxruntime_EXTERNAL_TRANSFORMER_SRC_PATH) diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index 7a49e90c00bce..dbd62a7cca0f5 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -168,9 +168,11 @@ else() "${ONNXRUNTIME_ROOT}/wasm/api.cc" "${ONNXRUNTIME_ROOT}/core/session/onnxruntime_c_api.cc" ) - set (WASM_API_EXCEPTION_CATCHING "-s DISABLE_EXCEPTION_CATCHING=0") - message(STATUS "onnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_CATCHING_ON_API set") - set_source_files_properties(${onnxruntime_webassembly_src_exc} PROPERTIES COMPILE_FLAGS ${WASM_API_EXCEPTION_CATCHING}) + if (NOT onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) + set (WASM_API_EXCEPTION_CATCHING "-s DISABLE_EXCEPTION_CATCHING=0") + message(STATUS "onnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_CATCHING_ON_API set") + set_source_files_properties(${onnxruntime_webassembly_src_exc} PROPERTIES COMPILE_FLAGS ${WASM_API_EXCEPTION_CATCHING}) + endif() endif() target_link_libraries(onnxruntime_webassembly PRIVATE @@ -193,7 +195,7 @@ else() re2::re2 ) - set(EXPORTED_RUNTIME_METHODS "'stackAlloc','stackRestore','stackSave','UTF8ToString','stringToUTF8','lengthBytesUTF8'") + set(EXPORTED_RUNTIME_METHODS "'stackAlloc','stackRestore','stackSave','UTF8ToString','stringToUTF8','lengthBytesUTF8','getValue','setValue'") if (onnxruntime_USE_XNNPACK) target_link_libraries(onnxruntime_webassembly PRIVATE XNNPACK) @@ -215,10 +217,55 @@ else() set(EXPORTED_FUNCTIONS "_malloc,_free") endif() + if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) + set(ASYNCIFY 2) + set(MAXIMUM_MEMORY "17179869184") + target_link_options(onnxruntime_webassembly PRIVATE + "SHELL:-s MEMORY64=1" + ) + string(APPEND CMAKE_C_FLAGS " -DWASM_MEMORY64 -sMEMORY64 -Wno-experimental") + string(APPEND CMAKE_CXX_FLAGS " -DWASM_MEMORY64 -sMEMORY64 -Wno-experimental") + set(SMEMORY_FLAG "-sMEMORY64") + + target_compile_options(onnx PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnxruntime_common PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnxruntime_session PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnxruntime_framework PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(nsync_cpp PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(nsync_cpp PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnx_proto PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + # target_compile_options(protoc PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(libprotobuf-lite PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnxruntime_providers PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnxruntime_optimizer PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnxruntime_mlas PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnxruntime_optimizer PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnxruntime_graph PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnxruntime_flatbuffers PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnxruntime_util PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(re2 PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_base PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_hash PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_raw_hash_set PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_throw_delegate PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_city PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_low_level_hash PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + + target_link_options(onnxruntime_webassembly PRIVATE + --post-js "${ONNXRUNTIME_ROOT}/wasm/js_post_js_64.js" + ) + else () + set(ASYNCIFY 1) + set(MAXIMUM_MEMORY "4294967296") + target_link_options(onnxruntime_webassembly PRIVATE + --post-js "${ONNXRUNTIME_ROOT}/wasm/js_post_js.js" + ) + endif () + target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s EXPORTED_RUNTIME_METHODS=[${EXPORTED_RUNTIME_METHODS}]" "SHELL:-s EXPORTED_FUNCTIONS=${EXPORTED_FUNCTIONS}" - "SHELL:-s MAXIMUM_MEMORY=4294967296" + "SHELL:-s MAXIMUM_MEMORY=${MAXIMUM_MEMORY}" "SHELL:-s EXIT_RUNTIME=0" "SHELL:-s ALLOW_MEMORY_GROWTH=1" "SHELL:-s MODULARIZE=1" @@ -231,6 +278,12 @@ else() --no-entry "SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre.js\"" ) + if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) + target_link_options(onnxruntime_webassembly PRIVATE + "SHELL:-s ERROR_ON_UNDEFINED_SYMBOLS=0" + "SHELL:-s SIGNATURE_CONVERSIONS=OrtRun:_pppppppp,OrtGetTensorData:_ppppp,OrtCreateTensor:p_pppp,OrtCreateSession:pppp,OrtReleaseSession:_p,OrtGetInputOutputCount:pppp,OrtCreateSessionOptions:pp__p_ppppp,OrtAddSessionConfigEntry:pppp,OrtReleaseSessionOptions:_p,OrtAppendExecutionProvider:ppp,OrtAddSessionConfigEntry:pppp,OrtGetInputName:ppp,OrtGetOutputName:ppp,OrtCreateRunOptions:ppp_p,OrtReleaseRunOptions:pp,OrtReleaseTensor:_p,OrtFree:_p,OrtGetLastError:_pp,JsepOutput:pp_p" + ) + endif () set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre.js) if (onnxruntime_USE_JSEP) @@ -241,8 +294,11 @@ else() target_compile_definitions(onnxruntime_webassembly PRIVATE USE_JSEP=1) target_link_options(onnxruntime_webassembly PRIVATE "SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js\"" - "SHELL:-s ASYNCIFY=1" + "SHELL:-s ASYNCIFY=${ASYNCIFY}" + #"SHELL:-s JSPI" + #"SHELL:-s ASYNCIFY_IGNORE_INDIRECT=1" "SHELL:-s ASYNCIFY_STACK_SIZE=65536" + "SHELL:-s ASYNCIFY_EXPORTS=['OrtRun']" ) set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js) endif() @@ -279,7 +335,11 @@ else() endif() # Set link flag to enable exceptions support, this will override default disabling exception throwing behavior when disable exceptions. - target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s DISABLE_EXCEPTION_THROWING=0") + if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) + target_link_options(onnxruntime_webassembly PRIVATE "-fwasm-exceptions") + else() + target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s DISABLE_EXCEPTION_THROWING=0") + endif() if (onnxruntime_ENABLE_WEBASSEMBLY_PROFILING) target_link_options(onnxruntime_webassembly PRIVATE --profiling --profiling-funcs) diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index 86c05b9a2fa15..ee4cc0067727b 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -23,11 +23,11 @@ if (!BUILD_DEFS.DISABLE_WASM) { const wasmBackend = BUILD_DEFS.DISABLE_TRAINING ? require('./backend-wasm-inference').wasmBackend : require('./backend-wasm-training').wasmBackend; if (!BUILD_DEFS.DISABLE_JSEP) { - registerBackend('webgpu', wasmBackend, 5); - registerBackend('webnn', wasmBackend, 5); + registerBackend('webgpu', wasmBackend, 1); + registerBackend('webnn', wasmBackend, 1); } - registerBackend('cpu', wasmBackend, 10); - registerBackend('wasm', wasmBackend, 10); + registerBackend('cpu', wasmBackend, 1); + registerBackend('wasm', wasmBackend, 1); } Object.defineProperty(env.versions, 'web', {value: version, enumerable: true}); diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index c701cf3a6df85..faa08ccca38ee 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -219,6 +219,7 @@ export class WebGpuBackend { maxComputeWorkgroupSizeX: adapter.limits.maxComputeWorkgroupSizeX, maxComputeWorkgroupSizeY: adapter.limits.maxComputeWorkgroupSizeY, maxComputeWorkgroupSizeZ: adapter.limits.maxComputeWorkgroupSizeZ, + maxBindingsPerBindGroup: adapter.limits.maxBindingsPerBindGroup, }, requiredFeatures, }; @@ -449,7 +450,7 @@ export class WebGpuBackend { const isPersistent = validatedOutputIndices[i] === -2; const tensorView = (isTemporary || isPersistent) ? createIntermediateOutput(outputs[i].dataType, outputs[i].dims) : - createKernelOutput(validatedOutputIndices[i], outputs[i].dataType, outputs[i].dims); + createKernelOutput(outputs[i].outputIndex || validatedOutputIndices[i], outputs[i].dataType, outputs[i].dims); outputTensorViews.push(tensorView); // if tensor view data is 0, it means the output is zero-sized tensor, and there is no GPU data for it. if (tensorView.data === 0) { diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index 242f7e939cda0..a410c77890354 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -3,8 +3,8 @@ import {Env} from 'onnxruntime-common'; -import type {OrtWasmModule} from '../wasm-types'; import {DataType, getTensorElementSize} from '../wasm-common'; +import type {OrtWasmModule} from '../wasm-types'; import {WebGpuBackend} from './backend-webgpu'; import {LOG_DEBUG} from './log'; @@ -68,24 +68,24 @@ class ComputeContextImpl implements ComputeContext { private customDataSize = 0; constructor(private module: OrtWasmModule, private backend: WebGpuBackend, contextDataOffset: number) { this.adapterInfo = backend.adapterInfo; - const heapU32 = module.HEAPU32; + const heap = module.PTR_SIZE === 4 ? module.HEAPU32 : module.HEAPU64; // extract context data - let dataIndex = (contextDataOffset >>> 2); - this.opKernelContext = heapU32[dataIndex++]; - const inputCount = heapU32[dataIndex++]; - this.outputCount = heapU32[dataIndex++]; - this.customDataOffset = heapU32[dataIndex++]; - this.customDataSize = heapU32[dataIndex++]; + let dataIndex = module.PTR_SIZE === 8 ? (contextDataOffset / 2 ** 3) : (contextDataOffset >> 2); + this.opKernelContext = Number(heap[dataIndex++]); + const inputCount = Number(heap[dataIndex++]); + this.outputCount = Number(heap[dataIndex++]); + this.customDataOffset = Number(heap[dataIndex++]); + this.customDataSize = Number(heap[dataIndex++]); const inputs: TensorView[] = []; for (let i = 0; i < inputCount; i++) { - const dataType = heapU32[dataIndex++]; - const data = heapU32[dataIndex++]; - const dim = heapU32[dataIndex++]; + const dataType = Number(heap[dataIndex++]); + const data = Number(heap[dataIndex++]); + const dim = Number(heap[dataIndex++]); const dims: number[] = []; for (let d = 0; d < dim; d++) { - dims.push(heapU32[dataIndex++]); + dims.push(Number(heap[dataIndex++])); } inputs.push(new TensorViewImpl(module, dataType, data, dims)); } @@ -127,11 +127,11 @@ class ComputeContextImpl implements ComputeContext { output(index: number, dims: readonly number[]): number { const stack = this.module.stackSave(); try { - const data = this.module.stackAlloc((1 + dims.length) * 4 /* sizeof(size_t) */); - let offset = data >> 2; - this.module.HEAPU32[offset++] = dims.length; + const ptrSize = this.module.PTR_SIZE; + const data = this.module.stackAlloc((1 + dims.length) * ptrSize /* sizeof(size_t) */); + this.module.setValue(data, dims.length, '*'); for (let i = 0; i < dims.length; i++) { - this.module.HEAPU32[offset++] = dims[i]; + this.module.setValue(data + ptrSize * (i + 1), dims[i], '*'); } return this.module._JsepOutput!(this.opKernelContext, index, data); } catch (e) { @@ -193,10 +193,15 @@ export const init = // jsepCopy(src, dst, size, isSourceGpu) (src: number, dst: number, size: number, isSourceGpu = false) => { if (isSourceGpu) { - LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyGpuToGpu: src=${src}, dst=${dst}, size=${size}`); + LOG_DEBUG( + 'verbose', + () => `[WebGPU] jsepCopyGpuToGpu: src=${Number(src)}, dst=${Number(dst)}, size=${Number(size)}`); backend.memcpy(src, dst); } else { - LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyCpuToGpu: dataOffset=${src}, gpuDataId=${dst}, size=${size}`); + LOG_DEBUG( + 'verbose', + () => `[WebGPU] jsepCopyCpuToGpu: dataOffset=${Number(src)}, gpuDataId=${Number(dst)}, size=${ + Number(size)}`); const data = module.HEAPU8.subarray(src >>> 0, (src >>> 0) + size); backend.upload(dst, data); } @@ -226,7 +231,7 @@ export const init = 'verbose', () => `[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${ contextDataOffset}`); - const context = new ComputeContextImpl(module, backend, contextDataOffset); + const context = new ComputeContextImpl(module, backend, Number(contextDataOffset)); return backend.computeKernel(kernel, context, errors); }, // jsepCaptureBegin diff --git a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts index a5c0a088efa6e..aa731757651a9 100644 --- a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts @@ -112,7 +112,7 @@ const bucketArr: number[] = []; /** * normalize the buffer size so that it fits the 128-bits (16 bytes) alignment. */ -const calcNormalizedBufferSize = (size: number) => Math.ceil(size / 16) * 16; +const calcNormalizedBufferSize = (size: number) => Math.ceil(Number(size) / 16) * 16; /** * calculate the buffer size so that it fits into buckets. @@ -342,7 +342,7 @@ class GpuDataManagerImpl implements GpuDataManager { } const gpuData = {id: createNewGpuDataId(), type: GpuDataType.default, buffer: gpuBuffer}; - this.storageCache.set(gpuData.id, {gpuData, originalSize: size}); + this.storageCache.set(gpuData.id, {gpuData, originalSize: Number(size)}); LOG_DEBUG('verbose', () => `[WebGPU] GpuDataManager.create(size=${size}) => id=${gpuData.id}`); return gpuData; @@ -352,7 +352,8 @@ class GpuDataManagerImpl implements GpuDataManager { return this.storageCache.get(id)?.gpuData; } - release(id: GpuDataId): number { + release(idInput: GpuDataId): number { + const id = typeof idInput === 'bigint' ? Number(idInput) : idInput; const cachedData = this.storageCache.get(id); if (!cachedData) { throw new Error('releasing data does not exist'); diff --git a/js/web/lib/wasm/jsep/webgpu/types.ts b/js/web/lib/wasm/jsep/webgpu/types.ts index 2a584fc0a2218..6e906cc8497ec 100644 --- a/js/web/lib/wasm/jsep/webgpu/types.ts +++ b/js/web/lib/wasm/jsep/webgpu/types.ts @@ -31,6 +31,7 @@ export interface GpuData { export interface TensorInfo { dims: readonly number[]; dataType: number; + outputIndex?: number; } export interface ProgramUniform { diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts index 70728c82e7753..8b29e24cb2143 100644 --- a/js/web/lib/wasm/wasm-types.ts +++ b/js/web/lib/wasm/wasm-types.ts @@ -209,10 +209,15 @@ export interface OrtTrainingAPIs { */ export interface OrtWasmModule extends EmscriptenModule, OrtInferenceAPIs, Partial, Partial { + HEAP64: BigInt64Array; + HEAPU64: BigUint64Array; + PTR_SIZE: number; // #region emscripten functions stackSave(): number; stackRestore(stack: number): void; stackAlloc(size: number): number; + getValue(ptr: number, type: string): number; + setValue(ptr: number, value: number, type: string): void; UTF8ToString(offset: number, maxBytesToRead?: number): string; lengthBytesUTF8(str: string): number; diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 77323f268a27d..199c6646aa704 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -958,7 +958,7 @@ Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& mo try { // Copy the file data (fileData,offset,length) into WebAssembly memory // (HEAPU8,buffer,length). - HEAPU8.set(fileData.subarray(offset, offset + length), buffer); + HEAPU8.set(fileData.subarray(Number(offset), Number(offset) + length), buffer); return 0; } catch { return 4; diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index e9d1b4e944edd..4878907ac0554 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -553,7 +553,7 @@ static Status SaveModel(Model& model, const T& file_path) { const buffer_size = $1; const file_path = UTF8ToString($2); const bytes = new Uint8Array(buffer_size); - bytes.set(HEAPU8.subarray(buffer, buffer + buffer_size)); + bytes.set(HEAPU8.subarray(Number(buffer), Number(buffer) + buffer_size)); if (typeof process == 'object' && typeof process.versions == 'object' && typeof process.versions.node == 'string') { // Node.js require('fs').writeFileSync(file_path, bytes); diff --git a/onnxruntime/core/providers/js/js_kernel.h b/onnxruntime/core/providers/js/js_kernel.h index 7324b0d69474c..e77ebb9d06559 100644 --- a/onnxruntime/core/providers/js/js_kernel.h +++ b/onnxruntime/core/providers/js/js_kernel.h @@ -110,16 +110,28 @@ class JsKernel : public OpKernel { temp_data_size += sizeof(size_t) * 3; } } +#ifdef WASM_MEMORY64 + uintptr_t* p_serialized_kernel_context = reinterpret_cast(alloc->Alloc(temp_data_size)); +#else uint32_t* p_serialized_kernel_context = reinterpret_cast(alloc->Alloc(temp_data_size)); +#endif if (p_serialized_kernel_context == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to allocate memory for serialized kernel context."); } +#ifdef WASM_MEMORY64 + p_serialized_kernel_context[0] = reinterpret_cast(context); + p_serialized_kernel_context[1] = static_cast(context->InputCount()); + p_serialized_kernel_context[2] = static_cast(context->OutputCount()); + p_serialized_kernel_context[3] = reinterpret_cast(custom_data_ptr); + p_serialized_kernel_context[4] = static_cast(custom_data_size); +#else p_serialized_kernel_context[0] = reinterpret_cast(context); p_serialized_kernel_context[1] = static_cast(context->InputCount()); p_serialized_kernel_context[2] = static_cast(context->OutputCount()); p_serialized_kernel_context[3] = reinterpret_cast(custom_data_ptr); p_serialized_kernel_context[4] = static_cast(custom_data_size); +#endif size_t index = 5; for (int i = 0; i < context->InputCount(); i++) { const auto* input_ptr = context->Input(i); @@ -130,12 +142,21 @@ class JsKernel : public OpKernel { p_serialized_kernel_context[index++] = 0; continue; } +#ifdef WASM_MEMORY64 + p_serialized_kernel_context[index++] = static_cast(input_ptr->GetElementType()); + p_serialized_kernel_context[index++] = reinterpret_cast(input_ptr->DataRaw()); + p_serialized_kernel_context[index++] = static_cast(input_ptr->Shape().NumDimensions()); + for (size_t d = 0; d < input_ptr->Shape().NumDimensions(); d++) { + p_serialized_kernel_context[index++] = static_cast(input_ptr->Shape()[d]); + } +#else p_serialized_kernel_context[index++] = static_cast(input_ptr->GetElementType()); p_serialized_kernel_context[index++] = reinterpret_cast(input_ptr->DataRaw()); p_serialized_kernel_context[index++] = static_cast(input_ptr->Shape().NumDimensions()); for (size_t d = 0; d < input_ptr->Shape().NumDimensions(); d++) { p_serialized_kernel_context[index++] = static_cast(input_ptr->Shape()[d]); } +#endif } #ifndef NDEBUG @@ -199,9 +220,15 @@ class JsKernel : public OpKernel { return status; } +#ifdef WASM_MEMORY64 + intptr_t status_code = EM_ASM_INT( + { return Module.jsepRunKernel($0, $1, Module.jsepSessionState.sessionHandle, Module.jsepSessionState.errors); }, + this, reinterpret_cast(p_serialized_kernel_context)); +#else int status_code = EM_ASM_INT( { return Module.jsepRunKernel($0, $1, Module.jsepSessionState.sessionHandle, Module.jsepSessionState.errors); }, this, reinterpret_cast(p_serialized_kernel_context)); +#endif LOGS_DEFAULT(VERBOSE) << "outputs = " << context->OutputCount() << ". Y.data=" << (size_t)(context->Output(0)->DataRaw()) << "."; diff --git a/onnxruntime/core/providers/js/operators/conv.h b/onnxruntime/core/providers/js/operators/conv.h index 32e8e1facafcd..a471044597b02 100644 --- a/onnxruntime/core/providers/js/operators/conv.h +++ b/onnxruntime/core/providers/js/operators/conv.h @@ -52,14 +52,14 @@ class ConvBase : public JsKernel { JSEP_INIT_KERNEL_ATTRIBUTE(Conv, ({ "format" : $11 ? "NHWC" : "NCHW", "auto_pad" : $1, - "dilations" : $2 ? Array.from(HEAP32.subarray($2, $3)) : [], + "dilations" : $2 ? Array.from(HEAP32.subarray(Number($2), Number($3))) : [], "group" : $4, - "kernel_shape" : $5 ? Array.from(HEAP32.subarray($5, $6)) : [], - "pads" : $7 ? Array.from(HEAP32.subarray($7, $8)) : [], - "strides" : $9 ? Array.from(HEAP32.subarray($9, $10)) : [], - "w_is_const" : () JS_ARROW(!!HEAP8[$12]), + "kernel_shape" : $5 ? Array.from(HEAP32.subarray(Number($5), Number($6))) : [], + "pads" : $7 ? Array.from(HEAP32.subarray($7, Number($8))) : [], + "strides" : $9 ? Array.from(HEAP32.subarray($9, Number($10))) : [], + "w_is_const" : () JS_ARROW(!!HEAP8[Number($12)]), "activation" : UTF8ToString($13), - "activation_params" : $14 ? Array.from(HEAPF32.subarray($14, $15)) : [] + "activation_params" : $14 ? Array.from(HEAPF32.subarray(Number($14), Number($15))) : [] }), static_cast(conv_attrs_.auto_pad), JSEP_HEAP32_INDEX_START(dilations), diff --git a/onnxruntime/core/providers/js/operators/conv_transpose.h b/onnxruntime/core/providers/js/operators/conv_transpose.h index 59dfbb0e94922..522a9cb2cfead 100644 --- a/onnxruntime/core/providers/js/operators/conv_transpose.h +++ b/onnxruntime/core/providers/js/operators/conv_transpose.h @@ -48,8 +48,8 @@ class ConvTranspose : public JsKernel { "pads" : [ $5, $6 ], "strides" : [$7], "wIsConst" : () JS_ARROW(!!HEAP8[$9]), - "outputPadding" : $10 ? Array.from(HEAP32.subarray($10, $11)) : [], - "outputShape" : $12 ? Array.from(HEAP32.subarray($12, $13)) : [], + "outputPadding" : $10 ? Array.from(HEAP32.subarray(Number($10), Number($11))) : [], + "outputShape" : $12 ? Array.from(HEAP32.subarray(Number($12), Number($13))) : [], "activation" : UTF8ToString($14) }), static_cast(conv_transpose_attrs_.auto_pad), @@ -99,14 +99,14 @@ class ConvTranspose : public JsKernel { JSEP_INIT_KERNEL_ATTRIBUTE(ConvTranspose, ({ "format" : $7 ? "NHWC" : "NCHW", "autoPad" : $1, - "dilations" : Array.from(HEAP32.subarray($2, ($2 >>> 0) + /* dialations_vec_size */ 2)), + "dilations" : Array.from(HEAP32.subarray(Number($2), Number(($2 >>> 0) + /* dialations_vec_size */ 2))), "group" : $3, - "kernelShape" : Array.from(HEAP32.subarray($4, ($4 >>> 0) + /* kernel_shape_vec_size */ 2)), - "pads" : Array.from(HEAP32.subarray($5, ($5 >>> 0) + /* pads_vec_size */ 4)), - "strides" : Array.from(HEAP32.subarray($6, ($6 >>> 0) + /* strides_vec_size */ 2)), + "kernelShape" : Array.from(HEAP32.subarray(Number($4), Number(($4 >>> 0) + /* kernel_shape_vec_size */ 2))), + "pads" : Array.from(HEAP32.subarray(Number($5), Number(($5 >>> 0) + /* pads_vec_size */ 4))), + "strides" : Array.from(HEAP32.subarray(Number($6), Number(($6 >>> 0) + /* strides_vec_size */ 2))), "wIsConst" : () JS_ARROW(!!HEAP8[$8]), - "outputPadding" : $9 ? Array.from(HEAP32.subarray($9, $10)) : [], - "outputShape" : $11 ? Array.from(HEAP32.subarray($11, $12)) : [], + "outputPadding" : $9 ? Array.from(HEAP32.subarray(Number($9), Number($10))) : [], + "outputShape" : $11 ? Array.from(HEAP32.subarray(Number($11), Number($12))) : [], "activation" : UTF8ToString($13) }), static_cast(conv_transpose_attrs_.auto_pad), diff --git a/onnxruntime/core/providers/js/operators/gather.cc b/onnxruntime/core/providers/js/operators/gather.cc index 485cd3da9b91b..e9c6f5c79294f 100644 --- a/onnxruntime/core/providers/js/operators/gather.cc +++ b/onnxruntime/core/providers/js/operators/gather.cc @@ -15,11 +15,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 10, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", BuildKernelDefConstraintsFromTypeList>()) + .TypeConstraint("T", JsepSupportedDataTypes()) .TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList>()), Gather); @@ -30,11 +26,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 12, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", BuildKernelDefConstraintsFromTypeList>()) + .TypeConstraint("T", JsepSupportedDataTypes()) .TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList>()), Gather); @@ -44,11 +36,7 @@ ONNX_OPERATOR_KERNEL_EX( 13, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", BuildKernelDefConstraintsFromTypeList>()) + .TypeConstraint("T", JsepSupportedDataTypes()) .TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList>()), Gather); diff --git a/onnxruntime/core/providers/js/operators/gemm.h b/onnxruntime/core/providers/js/operators/gemm.h index 74091526f8411..d7f8fa6289c1a 100644 --- a/onnxruntime/core/providers/js/operators/gemm.h +++ b/onnxruntime/core/providers/js/operators/gemm.h @@ -23,8 +23,8 @@ class Gemm : public JsKernel { "transA" : $3, "transB" : $4 }), - static_cast(alpha), - static_cast(beta), + static_cast(alpha), + static_cast(beta), static_cast(transA), static_cast(transB)); } diff --git a/onnxruntime/core/providers/js/operators/pad.h b/onnxruntime/core/providers/js/operators/pad.h index c18c7dd456dc2..f656462285bc4 100644 --- a/onnxruntime/core/providers/js/operators/pad.h +++ b/onnxruntime/core/providers/js/operators/pad.h @@ -22,7 +22,7 @@ class Pad : public JsKernel, public PadBase { JSEP_INIT_KERNEL_ATTRIBUTE(Pad, ({"mode" : $1, "value" : $2, - "pads" : $3 ? Array.from(HEAP32.subarray($3, $4)) : []}), + "pads" : $3 ? Array.from(HEAP32.subarray(Number($3), Number($4))) : []}), static_cast(mode_), static_cast(value_), JSEP_HEAP32_INDEX_START(pads), diff --git a/onnxruntime/core/providers/js/operators/reduce.h b/onnxruntime/core/providers/js/operators/reduce.h index 937f1f990dc67..5f6f3e2fd1a1f 100644 --- a/onnxruntime/core/providers/js/operators/reduce.h +++ b/onnxruntime/core/providers/js/operators/reduce.h @@ -24,7 +24,7 @@ namespace js { JSEP_INIT_KERNEL_ATTRIBUTE(ReduceKernel, ({ \ "keepDims" : !!$1, \ "noopWithEmptyAxes" : !!$2, \ - "axes" : $3 ? (Array.from(HEAP32.subarray($3, $4))) : [], \ + "axes" : $3 ? (Array.from(HEAP32.subarray(Number($3), Number($4)))) : [], \ }), \ static_cast(keepdims_), \ static_cast(noop_with_empty_axes_), \ diff --git a/onnxruntime/core/providers/js/operators/resize.h b/onnxruntime/core/providers/js/operators/resize.h index 134eb4bf5a7f4..3e8ccf40753c8 100644 --- a/onnxruntime/core/providers/js/operators/resize.h +++ b/onnxruntime/core/providers/js/operators/resize.h @@ -23,7 +23,7 @@ class Resize : public JsKernel, public UpsampleBase { std::transform(axes_.begin(), axes_.end(), std::back_inserter(axes), [](auto& axis) { return gsl::narrow_cast(axis); }); JSEP_INIT_KERNEL_ATTRIBUTE(Resize, ({ "antialias" : $1, - "axes" : $2 ? Array.from(HEAP32.subarray($2, $3)) : [], + "axes" : $2 ? Array.from(HEAP32.subarray(Number($2), Number($3))) : [], "coordinateTransformMode" : UTF8ToString($4), "cubicCoeffA" : $5, "excludeOutside" : $6, diff --git a/onnxruntime/core/providers/js/operators/slice.h b/onnxruntime/core/providers/js/operators/slice.h index daeffaa664741..f30e7bf01ec7b 100644 --- a/onnxruntime/core/providers/js/operators/slice.h +++ b/onnxruntime/core/providers/js/operators/slice.h @@ -20,9 +20,9 @@ class Slice : public JsKernel, public SliceBase { std::vector starts(attr_starts.begin(), attr_starts.end()); std::vector ends(attr_ends.begin(), attr_ends.end()); - JSEP_INIT_KERNEL_ATTRIBUTE(Slice, ({"starts" : $1 ? Array.from(HEAP32.subarray($1, $2)) : [], - "ends" : $3 ? Array.from(HEAP32.subarray($3, $4)) : [], - "axes" : $5 ? Array.from(HEAP32.subarray($5, $6)) : []}), + JSEP_INIT_KERNEL_ATTRIBUTE(Slice, ({"starts" : $1 ? Array.from(HEAP32.subarray(Number($1), Number($2))) : [], + "ends" : $3 ? Array.from(HEAP32.subarray(Number($3), Number($4))) : [], + "axes" : $5 ? Array.from(HEAP32.subarray(Number($5), Number($6))) : []}), JSEP_HEAP32_INDEX_START(starts), JSEP_HEAP32_INDEX_END(starts), JSEP_HEAP32_INDEX_START(ends), diff --git a/onnxruntime/core/providers/js/operators/split.h b/onnxruntime/core/providers/js/operators/split.h index 4fdbab00e739c..3f6cfcb8921f3 100644 --- a/onnxruntime/core/providers/js/operators/split.h +++ b/onnxruntime/core/providers/js/operators/split.h @@ -49,7 +49,7 @@ class Split : public JsKernel, public SplitBase { JSEP_INIT_KERNEL_ATTRIBUTE(Split, ({"axis" : $1, "numOutputs" : $2, - "splitSizes" : $3 ? Array.from(HEAP32.subarray($3, $4)) : []}), + "splitSizes" : $3 ? Array.from(HEAP32.subarray(Number($3), Number($4))) : []}), static_cast(axis_), static_cast(num_outputs_), JSEP_HEAP32_INDEX_START(split_sizes), diff --git a/onnxruntime/core/providers/js/operators/transpose.h b/onnxruntime/core/providers/js/operators/transpose.h index f43dd814aa959..20586232ef1a9 100644 --- a/onnxruntime/core/providers/js/operators/transpose.h +++ b/onnxruntime/core/providers/js/operators/transpose.h @@ -21,7 +21,7 @@ class Transpose final : public JsKernel, public TransposeBase { } } JSEP_INIT_KERNEL_ATTRIBUTE(Transpose, ({ - "perm" : $1 ? Array.from(HEAP32.subarray($1, $2)) : [] + "perm" : $1 ? Array.from(HEAP32.subarray(Number($1), Number($2))) : [] }), JSEP_HEAP32_INDEX_START(perm), JSEP_HEAP32_INDEX_END(perm)); diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index 0e58bb4f93f7f..aeee3827dd9cd 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -27,7 +27,11 @@ enum DataLocation { }; static_assert(sizeof(const char*) == sizeof(size_t), "size of a pointer and a size_t value should be the same."); +#ifdef WASM_MEMORY64 +static_assert(sizeof(size_t) == 8, "size of size_t should be 8 in this build (wasm64)."); +#else static_assert(sizeof(size_t) == 4, "size of size_t should be 4 in this build (wasm32)."); +#endif OrtErrorCode CheckStatus(OrtStatusPtr status) { if (status) { diff --git a/onnxruntime/wasm/api.h b/onnxruntime/wasm/api.h index 2cd1515d191c8..f6dfc57171768 100644 --- a/onnxruntime/wasm/api.h +++ b/onnxruntime/wasm/api.h @@ -183,7 +183,11 @@ ort_tensor_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateTensor(int data_type, void* da * 'dims' (for all types of tensor), 'data' (only for string tensor) * @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message. */ +#ifdef WASM_MEMORY64 +int EMSCRIPTEN_KEEPALIVE OrtGetTensorData(ort_tensor_handle_t tensor, size_t* data_type, void** data, size_t** dims, size_t* dims_length); +#else int EMSCRIPTEN_KEEPALIVE OrtGetTensorData(ort_tensor_handle_t tensor, int* data_type, void** data, size_t** dims, size_t* dims_length); +#endif /** * release the specified tensor. diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index ae4c9b27544ba..dfa3f3a83f439 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -401,7 +401,7 @@ def convert_arg_line_to_args(self, arg_line): help="Build with a specific GDK edition. Defaults to the latest installed.", ) parser.add_argument("--gdk_platform", default="Scarlett", help="Sets the GDK target platform.") - + parser.add_argument("--enable_wasm_memory64", action="store_true", help="Enable WebAssembly 64bit support") platform_group = parser.add_mutually_exclusive_group() platform_group.add_argument("--ios", action="store_true", help="build for ios") platform_group.add_argument("--visionos", action="store_true", help="build for visionOS") @@ -1090,6 +1090,7 @@ def generate_build_tree( + ("ON" if args.enable_wasm_exception_throwing_override else "OFF"), "-Donnxruntime_WEBASSEMBLY_RUN_TESTS_IN_BROWSER=" + ("ON" if args.wasm_run_tests_in_browser else "OFF"), "-Donnxruntime_ENABLE_WEBASSEMBLY_THREADS=" + ("ON" if args.enable_wasm_threads else "OFF"), + "-Donnxruntime_ENABLE_WEBASSEMBLY_MEMORY64=" + ("ON" if args.enable_wasm_memory64 else "OFF"), "-Donnxruntime_ENABLE_WEBASSEMBLY_DEBUG_INFO=" + ("ON" if args.enable_wasm_debug_info else "OFF"), "-Donnxruntime_ENABLE_WEBASSEMBLY_PROFILING=" + ("ON" if args.enable_wasm_profiling else "OFF"), "-Donnxruntime_ENABLE_LAZY_TENSOR=" + ("ON" if args.enable_lazy_tensor else "OFF"),