From 60cd3db4bc049a04f9b2b3fb479795468bf31872 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Tue, 9 Jul 2024 09:42:44 -0700 Subject: [PATCH 01/45] Initial changes to support wasm64. --- cmake/CMakeLists.txt | 3 +- cmake/adjust_global_compile_flags.cmake | 15 +++- cmake/onnxruntime_webassembly.cmake | 74 +++++++++++++++++-- js/web/lib/index.ts | 8 +- js/web/lib/wasm/jsep/backend-webgpu.ts | 3 +- js/web/lib/wasm/jsep/init.ts | 43 ++++++----- .../lib/wasm/jsep/webgpu/gpu-data-manager.ts | 7 +- js/web/lib/wasm/jsep/webgpu/types.ts | 1 + js/web/lib/wasm/wasm-types.ts | 5 ++ .../core/framework/tensorprotoutils.cc | 2 +- onnxruntime/core/graph/model.cc | 5 +- onnxruntime/core/providers/js/js_kernel.h | 27 +++++++ .../core/providers/js/operators/conv.h | 12 +-- .../providers/js/operators/conv_transpose.h | 16 ++-- .../core/providers/js/operators/gather.cc | 18 +---- .../core/providers/js/operators/gemm.h | 4 +- onnxruntime/core/providers/js/operators/pad.h | 2 +- .../core/providers/js/operators/reduce.h | 2 +- .../core/providers/js/operators/resize.h | 2 +- .../core/providers/js/operators/slice.h | 6 +- .../core/providers/js/operators/split.h | 2 +- .../core/providers/js/operators/transpose.h | 2 +- onnxruntime/wasm/api.cc | 4 + onnxruntime/wasm/api.h | 4 + tools/ci_build/build.py | 3 +- 25 files changed, 189 insertions(+), 81 deletions(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 5555fa692eae8..476745c39aee9 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -201,6 +201,7 @@ option(onnxruntime_WEBASSEMBLY_RUN_TESTS_IN_BROWSER "Enable this option to run t option(onnxruntime_ENABLE_WEBASSEMBLY_DEBUG_INFO "Enable this option to turn on DWARF format debug info" OFF) option(onnxruntime_ENABLE_WEBASSEMBLY_PROFILING "Enable this option to turn on WebAssembly profiling and preserve function names" OFF) option(onnxruntime_ENABLE_WEBASSEMBLY_OUTPUT_OPTIMIZED_MODEL "Enable this option to allow WebAssembly to output optimized model" OFF) +option(onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64 "Enable this option to allow WebAssembly to use 64bit memory" OFF) # Enable bitcode for iOS option(onnxruntime_ENABLE_BITCODE "Enable bitcode for iOS only" OFF) @@ -241,7 +242,7 @@ option(onnxruntime_ENABLE_TRITON "Enable Triton" OFF) # composable kernel is managed automatically, unless user want to explicitly disable it, it should not be manually set option(onnxruntime_USE_COMPOSABLE_KERNEL "Enable composable kernel for ROCm EP" ON) -cmake_dependent_option(onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE "Enable ck_tile for composable kernel" ON "onnxruntime_USE_COMPOSABLE_KERNEL" OFF) +option(onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE "Enable ck_tile for composable kernel" ON) option(onnxruntime_USE_ROCBLAS_EXTENSION_API "Enable rocblas tuning for ROCm EP" OFF) option(onnxruntime_USE_TRITON_KERNEL "Enable triton compiled kernel" OFF) option(onnxruntime_BUILD_KERNEL_EXPLORER "Build Kernel Explorer for testing and profiling GPU kernels" OFF) diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake index 6eb784a4063ed..2189f728ad5ad 100644 --- a/cmake/adjust_global_compile_flags.cmake +++ b/cmake/adjust_global_compile_flags.cmake @@ -52,8 +52,13 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") endif() if (onnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_CATCHING) - string(APPEND CMAKE_C_FLAGS " -s DISABLE_EXCEPTION_CATCHING=0") - string(APPEND CMAKE_CXX_FLAGS " -s DISABLE_EXCEPTION_CATCHING=0") + if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) + string(APPEND 
CMAKE_C_FLAGS " -fwasm-exceptions") + string(APPEND CMAKE_CXX_FLAGS " -fwasm-exceptions") + else() + string(APPEND CMAKE_C_FLAGS " -s DISABLE_EXCEPTION_CATCHING=0") + string(APPEND CMAKE_CXX_FLAGS " -s DISABLE_EXCEPTION_CATCHING=0") + endif() endif() # Build WebAssembly with multi-threads support. @@ -61,6 +66,12 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") string(APPEND CMAKE_C_FLAGS " -pthread -Wno-pthreads-mem-growth") string(APPEND CMAKE_CXX_FLAGS " -pthread -Wno-pthreads-mem-growth") endif() + + # Build WebAssembly with 64bit support. + if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) + string(APPEND CMAKE_C_FLAGS " -sMEMORY64 -Wno-experimental") + string(APPEND CMAKE_CXX_FLAGS " -sMEMORY64 -Wno-experimental") + endif() endif() if (onnxruntime_EXTERNAL_TRANSFORMER_SRC_PATH) diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index 7a49e90c00bce..dbd62a7cca0f5 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -168,9 +168,11 @@ else() "${ONNXRUNTIME_ROOT}/wasm/api.cc" "${ONNXRUNTIME_ROOT}/core/session/onnxruntime_c_api.cc" ) - set (WASM_API_EXCEPTION_CATCHING "-s DISABLE_EXCEPTION_CATCHING=0") - message(STATUS "onnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_CATCHING_ON_API set") - set_source_files_properties(${onnxruntime_webassembly_src_exc} PROPERTIES COMPILE_FLAGS ${WASM_API_EXCEPTION_CATCHING}) + if (NOT onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) + set (WASM_API_EXCEPTION_CATCHING "-s DISABLE_EXCEPTION_CATCHING=0") + message(STATUS "onnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_CATCHING_ON_API set") + set_source_files_properties(${onnxruntime_webassembly_src_exc} PROPERTIES COMPILE_FLAGS ${WASM_API_EXCEPTION_CATCHING}) + endif() endif() target_link_libraries(onnxruntime_webassembly PRIVATE @@ -193,7 +195,7 @@ else() re2::re2 ) - set(EXPORTED_RUNTIME_METHODS "'stackAlloc','stackRestore','stackSave','UTF8ToString','stringToUTF8','lengthBytesUTF8'") + set(EXPORTED_RUNTIME_METHODS "'stackAlloc','stackRestore','stackSave','UTF8ToString','stringToUTF8','lengthBytesUTF8','getValue','setValue'") if (onnxruntime_USE_XNNPACK) target_link_libraries(onnxruntime_webassembly PRIVATE XNNPACK) @@ -215,10 +217,55 @@ else() set(EXPORTED_FUNCTIONS "_malloc,_free") endif() + if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) + set(ASYNCIFY 2) + set(MAXIMUM_MEMORY "17179869184") + target_link_options(onnxruntime_webassembly PRIVATE + "SHELL:-s MEMORY64=1" + ) + string(APPEND CMAKE_C_FLAGS " -DWASM_MEMORY64 -sMEMORY64 -Wno-experimental") + string(APPEND CMAKE_CXX_FLAGS " -DWASM_MEMORY64 -sMEMORY64 -Wno-experimental") + set(SMEMORY_FLAG "-sMEMORY64") + + target_compile_options(onnx PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnxruntime_common PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnxruntime_session PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnxruntime_framework PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(nsync_cpp PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(nsync_cpp PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnx_proto PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + # target_compile_options(protoc PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(libprotobuf-lite PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnxruntime_providers PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnxruntime_optimizer PRIVATE ${SMEMORY_FLAG} 
-Wno-experimental) + target_compile_options(onnxruntime_mlas PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnxruntime_optimizer PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnxruntime_graph PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnxruntime_flatbuffers PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnxruntime_util PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(re2 PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_base PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_hash PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_raw_hash_set PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_throw_delegate PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_city PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_low_level_hash PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + + target_link_options(onnxruntime_webassembly PRIVATE + --post-js "${ONNXRUNTIME_ROOT}/wasm/js_post_js_64.js" + ) + else () + set(ASYNCIFY 1) + set(MAXIMUM_MEMORY "4294967296") + target_link_options(onnxruntime_webassembly PRIVATE + --post-js "${ONNXRUNTIME_ROOT}/wasm/js_post_js.js" + ) + endif () + target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s EXPORTED_RUNTIME_METHODS=[${EXPORTED_RUNTIME_METHODS}]" "SHELL:-s EXPORTED_FUNCTIONS=${EXPORTED_FUNCTIONS}" - "SHELL:-s MAXIMUM_MEMORY=4294967296" + "SHELL:-s MAXIMUM_MEMORY=${MAXIMUM_MEMORY}" "SHELL:-s EXIT_RUNTIME=0" "SHELL:-s ALLOW_MEMORY_GROWTH=1" "SHELL:-s MODULARIZE=1" @@ -231,6 +278,12 @@ else() --no-entry "SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre.js\"" ) + if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) + target_link_options(onnxruntime_webassembly PRIVATE + "SHELL:-s ERROR_ON_UNDEFINED_SYMBOLS=0" + "SHELL:-s SIGNATURE_CONVERSIONS=OrtRun:_pppppppp,OrtGetTensorData:_ppppp,OrtCreateTensor:p_pppp,OrtCreateSession:pppp,OrtReleaseSession:_p,OrtGetInputOutputCount:pppp,OrtCreateSessionOptions:pp__p_ppppp,OrtAddSessionConfigEntry:pppp,OrtReleaseSessionOptions:_p,OrtAppendExecutionProvider:ppp,OrtAddSessionConfigEntry:pppp,OrtGetInputName:ppp,OrtGetOutputName:ppp,OrtCreateRunOptions:ppp_p,OrtReleaseRunOptions:pp,OrtReleaseTensor:_p,OrtFree:_p,OrtGetLastError:_pp,JsepOutput:pp_p" + ) + endif () set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre.js) if (onnxruntime_USE_JSEP) @@ -241,8 +294,11 @@ else() target_compile_definitions(onnxruntime_webassembly PRIVATE USE_JSEP=1) target_link_options(onnxruntime_webassembly PRIVATE "SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js\"" - "SHELL:-s ASYNCIFY=1" + "SHELL:-s ASYNCIFY=${ASYNCIFY}" + #"SHELL:-s JSPI" + #"SHELL:-s ASYNCIFY_IGNORE_INDIRECT=1" "SHELL:-s ASYNCIFY_STACK_SIZE=65536" + "SHELL:-s ASYNCIFY_EXPORTS=['OrtRun']" ) set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js) endif() @@ -279,7 +335,11 @@ else() endif() # Set link flag to enable exceptions support, this will override default disabling exception throwing behavior when disable exceptions. 
- target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s DISABLE_EXCEPTION_THROWING=0") + if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) + target_link_options(onnxruntime_webassembly PRIVATE "-fwasm-exceptions") + else() + target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s DISABLE_EXCEPTION_THROWING=0") + endif() if (onnxruntime_ENABLE_WEBASSEMBLY_PROFILING) target_link_options(onnxruntime_webassembly PRIVATE --profiling --profiling-funcs) diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index 86c05b9a2fa15..ee4cc0067727b 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -23,11 +23,11 @@ if (!BUILD_DEFS.DISABLE_WASM) { const wasmBackend = BUILD_DEFS.DISABLE_TRAINING ? require('./backend-wasm-inference').wasmBackend : require('./backend-wasm-training').wasmBackend; if (!BUILD_DEFS.DISABLE_JSEP) { - registerBackend('webgpu', wasmBackend, 5); - registerBackend('webnn', wasmBackend, 5); + registerBackend('webgpu', wasmBackend, 1); + registerBackend('webnn', wasmBackend, 1); } - registerBackend('cpu', wasmBackend, 10); - registerBackend('wasm', wasmBackend, 10); + registerBackend('cpu', wasmBackend, 1); + registerBackend('wasm', wasmBackend, 1); } Object.defineProperty(env.versions, 'web', {value: version, enumerable: true}); diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index c701cf3a6df85..faa08ccca38ee 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -219,6 +219,7 @@ export class WebGpuBackend { maxComputeWorkgroupSizeX: adapter.limits.maxComputeWorkgroupSizeX, maxComputeWorkgroupSizeY: adapter.limits.maxComputeWorkgroupSizeY, maxComputeWorkgroupSizeZ: adapter.limits.maxComputeWorkgroupSizeZ, + maxBindingsPerBindGroup: adapter.limits.maxBindingsPerBindGroup, }, requiredFeatures, }; @@ -449,7 +450,7 @@ export class WebGpuBackend { const isPersistent = validatedOutputIndices[i] === -2; const tensorView = (isTemporary || isPersistent) ? createIntermediateOutput(outputs[i].dataType, outputs[i].dims) : - createKernelOutput(validatedOutputIndices[i], outputs[i].dataType, outputs[i].dims); + createKernelOutput(outputs[i].outputIndex || validatedOutputIndices[i], outputs[i].dataType, outputs[i].dims); outputTensorViews.push(tensorView); // if tensor view data is 0, it means the output is zero-sized tensor, and there is no GPU data for it. if (tensorView.data === 0) { diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index 242f7e939cda0..a410c77890354 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -3,8 +3,8 @@ import {Env} from 'onnxruntime-common'; -import type {OrtWasmModule} from '../wasm-types'; import {DataType, getTensorElementSize} from '../wasm-common'; +import type {OrtWasmModule} from '../wasm-types'; import {WebGpuBackend} from './backend-webgpu'; import {LOG_DEBUG} from './log'; @@ -68,24 +68,24 @@ class ComputeContextImpl implements ComputeContext { private customDataSize = 0; constructor(private module: OrtWasmModule, private backend: WebGpuBackend, contextDataOffset: number) { this.adapterInfo = backend.adapterInfo; - const heapU32 = module.HEAPU32; + const heap = module.PTR_SIZE === 4 ? 
module.HEAPU32 : module.HEAPU64; // extract context data - let dataIndex = (contextDataOffset >>> 2); - this.opKernelContext = heapU32[dataIndex++]; - const inputCount = heapU32[dataIndex++]; - this.outputCount = heapU32[dataIndex++]; - this.customDataOffset = heapU32[dataIndex++]; - this.customDataSize = heapU32[dataIndex++]; + let dataIndex = module.PTR_SIZE === 8 ? (contextDataOffset / 2 ** 3) : (contextDataOffset >> 2); + this.opKernelContext = Number(heap[dataIndex++]); + const inputCount = Number(heap[dataIndex++]); + this.outputCount = Number(heap[dataIndex++]); + this.customDataOffset = Number(heap[dataIndex++]); + this.customDataSize = Number(heap[dataIndex++]); const inputs: TensorView[] = []; for (let i = 0; i < inputCount; i++) { - const dataType = heapU32[dataIndex++]; - const data = heapU32[dataIndex++]; - const dim = heapU32[dataIndex++]; + const dataType = Number(heap[dataIndex++]); + const data = Number(heap[dataIndex++]); + const dim = Number(heap[dataIndex++]); const dims: number[] = []; for (let d = 0; d < dim; d++) { - dims.push(heapU32[dataIndex++]); + dims.push(Number(heap[dataIndex++])); } inputs.push(new TensorViewImpl(module, dataType, data, dims)); } @@ -127,11 +127,11 @@ class ComputeContextImpl implements ComputeContext { output(index: number, dims: readonly number[]): number { const stack = this.module.stackSave(); try { - const data = this.module.stackAlloc((1 + dims.length) * 4 /* sizeof(size_t) */); - let offset = data >> 2; - this.module.HEAPU32[offset++] = dims.length; + const ptrSize = this.module.PTR_SIZE; + const data = this.module.stackAlloc((1 + dims.length) * ptrSize /* sizeof(size_t) */); + this.module.setValue(data, dims.length, '*'); for (let i = 0; i < dims.length; i++) { - this.module.HEAPU32[offset++] = dims[i]; + this.module.setValue(data + ptrSize * (i + 1), dims[i], '*'); } return this.module._JsepOutput!(this.opKernelContext, index, data); } catch (e) { @@ -193,10 +193,15 @@ export const init = // jsepCopy(src, dst, size, isSourceGpu) (src: number, dst: number, size: number, isSourceGpu = false) => { if (isSourceGpu) { - LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyGpuToGpu: src=${src}, dst=${dst}, size=${size}`); + LOG_DEBUG( + 'verbose', + () => `[WebGPU] jsepCopyGpuToGpu: src=${Number(src)}, dst=${Number(dst)}, size=${Number(size)}`); backend.memcpy(src, dst); } else { - LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyCpuToGpu: dataOffset=${src}, gpuDataId=${dst}, size=${size}`); + LOG_DEBUG( + 'verbose', + () => `[WebGPU] jsepCopyCpuToGpu: dataOffset=${Number(src)}, gpuDataId=${Number(dst)}, size=${ + Number(size)}`); const data = module.HEAPU8.subarray(src >>> 0, (src >>> 0) + size); backend.upload(dst, data); } @@ -226,7 +231,7 @@ export const init = 'verbose', () => `[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${ contextDataOffset}`); - const context = new ComputeContextImpl(module, backend, contextDataOffset); + const context = new ComputeContextImpl(module, backend, Number(contextDataOffset)); return backend.computeKernel(kernel, context, errors); }, // jsepCaptureBegin diff --git a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts index a5c0a088efa6e..aa731757651a9 100644 --- a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts @@ -112,7 +112,7 @@ const bucketArr: number[] = []; /** * normalize the buffer size so that it fits the 128-bits (16 bytes) alignment. 
*/ -const calcNormalizedBufferSize = (size: number) => Math.ceil(size / 16) * 16; +const calcNormalizedBufferSize = (size: number) => Math.ceil(Number(size) / 16) * 16; /** * calculate the buffer size so that it fits into buckets. @@ -342,7 +342,7 @@ class GpuDataManagerImpl implements GpuDataManager { } const gpuData = {id: createNewGpuDataId(), type: GpuDataType.default, buffer: gpuBuffer}; - this.storageCache.set(gpuData.id, {gpuData, originalSize: size}); + this.storageCache.set(gpuData.id, {gpuData, originalSize: Number(size)}); LOG_DEBUG('verbose', () => `[WebGPU] GpuDataManager.create(size=${size}) => id=${gpuData.id}`); return gpuData; @@ -352,7 +352,8 @@ class GpuDataManagerImpl implements GpuDataManager { return this.storageCache.get(id)?.gpuData; } - release(id: GpuDataId): number { + release(idInput: GpuDataId): number { + const id = typeof idInput === 'bigint' ? Number(idInput) : idInput; const cachedData = this.storageCache.get(id); if (!cachedData) { throw new Error('releasing data does not exist'); diff --git a/js/web/lib/wasm/jsep/webgpu/types.ts b/js/web/lib/wasm/jsep/webgpu/types.ts index 2a584fc0a2218..6e906cc8497ec 100644 --- a/js/web/lib/wasm/jsep/webgpu/types.ts +++ b/js/web/lib/wasm/jsep/webgpu/types.ts @@ -31,6 +31,7 @@ export interface GpuData { export interface TensorInfo { dims: readonly number[]; dataType: number; + outputIndex?: number; } export interface ProgramUniform { diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts index 70728c82e7753..8b29e24cb2143 100644 --- a/js/web/lib/wasm/wasm-types.ts +++ b/js/web/lib/wasm/wasm-types.ts @@ -209,10 +209,15 @@ export interface OrtTrainingAPIs { */ export interface OrtWasmModule extends EmscriptenModule, OrtInferenceAPIs, Partial, Partial { + HEAP64: BigInt64Array; + HEAPU64: BigUint64Array; + PTR_SIZE: number; // #region emscripten functions stackSave(): number; stackRestore(stack: number): void; stackAlloc(size: number): number; + getValue(ptr: number, type: string): number; + setValue(ptr: number, value: number, type: string): void; UTF8ToString(offset: number, maxBytesToRead?: number): string; lengthBytesUTF8(str: string): number; diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index 4ecd61962d797..b989be2c0fb7c 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -1041,7 +1041,7 @@ Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& mo try { // Copy the file data (fileData,offset,length) into WebAssembly memory // (HEAPU8,buffer,length). 
- HEAPU8.set(fileData.subarray(offset, offset + length), buffer); + HEAPU8.set(fileData.subarray(Number(offset), Number(offset) + length), buffer); return 0; } catch { return 4; diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index ee4d9f9154971..b90ac73ef1e34 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -556,9 +556,8 @@ static Status SaveModel(Model& model, const T& file_path) { const buffer_size = $1; const file_path = UTF8ToString($2); const bytes = new Uint8Array(buffer_size); - bytes.set(HEAPU8.subarray(buffer, buffer + buffer_size)); - if (typeof process == 'object' && typeof process.versions == 'object' && - typeof process.versions.node == 'string') { + bytes.set(HEAPU8.subarray(Number(buffer), Number(buffer) + buffer_size)); + if (typeof process == 'object' && typeof process.versions == 'object' && typeof process.versions.node == 'string') { // Node.js require('fs').writeFileSync(file_path, bytes); } else { diff --git a/onnxruntime/core/providers/js/js_kernel.h b/onnxruntime/core/providers/js/js_kernel.h index 7324b0d69474c..e77ebb9d06559 100644 --- a/onnxruntime/core/providers/js/js_kernel.h +++ b/onnxruntime/core/providers/js/js_kernel.h @@ -110,16 +110,28 @@ class JsKernel : public OpKernel { temp_data_size += sizeof(size_t) * 3; } } +#ifdef WASM_MEMORY64 + uintptr_t* p_serialized_kernel_context = reinterpret_cast(alloc->Alloc(temp_data_size)); +#else uint32_t* p_serialized_kernel_context = reinterpret_cast(alloc->Alloc(temp_data_size)); +#endif if (p_serialized_kernel_context == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to allocate memory for serialized kernel context."); } +#ifdef WASM_MEMORY64 + p_serialized_kernel_context[0] = reinterpret_cast(context); + p_serialized_kernel_context[1] = static_cast(context->InputCount()); + p_serialized_kernel_context[2] = static_cast(context->OutputCount()); + p_serialized_kernel_context[3] = reinterpret_cast(custom_data_ptr); + p_serialized_kernel_context[4] = static_cast(custom_data_size); +#else p_serialized_kernel_context[0] = reinterpret_cast(context); p_serialized_kernel_context[1] = static_cast(context->InputCount()); p_serialized_kernel_context[2] = static_cast(context->OutputCount()); p_serialized_kernel_context[3] = reinterpret_cast(custom_data_ptr); p_serialized_kernel_context[4] = static_cast(custom_data_size); +#endif size_t index = 5; for (int i = 0; i < context->InputCount(); i++) { const auto* input_ptr = context->Input(i); @@ -130,12 +142,21 @@ class JsKernel : public OpKernel { p_serialized_kernel_context[index++] = 0; continue; } +#ifdef WASM_MEMORY64 + p_serialized_kernel_context[index++] = static_cast(input_ptr->GetElementType()); + p_serialized_kernel_context[index++] = reinterpret_cast(input_ptr->DataRaw()); + p_serialized_kernel_context[index++] = static_cast(input_ptr->Shape().NumDimensions()); + for (size_t d = 0; d < input_ptr->Shape().NumDimensions(); d++) { + p_serialized_kernel_context[index++] = static_cast(input_ptr->Shape()[d]); + } +#else p_serialized_kernel_context[index++] = static_cast(input_ptr->GetElementType()); p_serialized_kernel_context[index++] = reinterpret_cast(input_ptr->DataRaw()); p_serialized_kernel_context[index++] = static_cast(input_ptr->Shape().NumDimensions()); for (size_t d = 0; d < input_ptr->Shape().NumDimensions(); d++) { p_serialized_kernel_context[index++] = static_cast(input_ptr->Shape()[d]); } +#endif } #ifndef NDEBUG @@ -199,9 +220,15 @@ class JsKernel : public OpKernel { return status; 
} +#ifdef WASM_MEMORY64 + intptr_t status_code = EM_ASM_INT( + { return Module.jsepRunKernel($0, $1, Module.jsepSessionState.sessionHandle, Module.jsepSessionState.errors); }, + this, reinterpret_cast(p_serialized_kernel_context)); +#else int status_code = EM_ASM_INT( { return Module.jsepRunKernel($0, $1, Module.jsepSessionState.sessionHandle, Module.jsepSessionState.errors); }, this, reinterpret_cast(p_serialized_kernel_context)); +#endif LOGS_DEFAULT(VERBOSE) << "outputs = " << context->OutputCount() << ". Y.data=" << (size_t)(context->Output(0)->DataRaw()) << "."; diff --git a/onnxruntime/core/providers/js/operators/conv.h b/onnxruntime/core/providers/js/operators/conv.h index 32e8e1facafcd..a471044597b02 100644 --- a/onnxruntime/core/providers/js/operators/conv.h +++ b/onnxruntime/core/providers/js/operators/conv.h @@ -52,14 +52,14 @@ class ConvBase : public JsKernel { JSEP_INIT_KERNEL_ATTRIBUTE(Conv, ({ "format" : $11 ? "NHWC" : "NCHW", "auto_pad" : $1, - "dilations" : $2 ? Array.from(HEAP32.subarray($2, $3)) : [], + "dilations" : $2 ? Array.from(HEAP32.subarray(Number($2), Number($3))) : [], "group" : $4, - "kernel_shape" : $5 ? Array.from(HEAP32.subarray($5, $6)) : [], - "pads" : $7 ? Array.from(HEAP32.subarray($7, $8)) : [], - "strides" : $9 ? Array.from(HEAP32.subarray($9, $10)) : [], - "w_is_const" : () JS_ARROW(!!HEAP8[$12]), + "kernel_shape" : $5 ? Array.from(HEAP32.subarray(Number($5), Number($6))) : [], + "pads" : $7 ? Array.from(HEAP32.subarray($7, Number($8))) : [], + "strides" : $9 ? Array.from(HEAP32.subarray($9, Number($10))) : [], + "w_is_const" : () JS_ARROW(!!HEAP8[Number($12)]), "activation" : UTF8ToString($13), - "activation_params" : $14 ? Array.from(HEAPF32.subarray($14, $15)) : [] + "activation_params" : $14 ? Array.from(HEAPF32.subarray(Number($14), Number($15))) : [] }), static_cast(conv_attrs_.auto_pad), JSEP_HEAP32_INDEX_START(dilations), diff --git a/onnxruntime/core/providers/js/operators/conv_transpose.h b/onnxruntime/core/providers/js/operators/conv_transpose.h index c51bf5ce9d4a6..1d3a5d75b68c4 100644 --- a/onnxruntime/core/providers/js/operators/conv_transpose.h +++ b/onnxruntime/core/providers/js/operators/conv_transpose.h @@ -48,8 +48,8 @@ class ConvTranspose : public JsKernel { "pads" : [ $5, $6 ], "strides" : [$7], "wIsConst" : () JS_ARROW(!!HEAP8[$9]), - "outputPadding" : $10 ? Array.from(HEAP32.subarray($10, $11)) : [], - "outputShape" : $12 ? Array.from(HEAP32.subarray($12, $13)) : [], + "outputPadding" : $10 ? Array.from(HEAP32.subarray(Number($10), Number($11))) : [], + "outputShape" : $12 ? Array.from(HEAP32.subarray(Number($12), Number($13))) : [], "activation" : UTF8ToString($14) }), static_cast(conv_transpose_attrs_.auto_pad), @@ -99,14 +99,14 @@ class ConvTranspose : public JsKernel { JSEP_INIT_KERNEL_ATTRIBUTE(ConvTranspose, ({ "format" : $7 ? 
"NHWC" : "NCHW", "autoPad" : $1, - "dilations" : Array.from(HEAP32.subarray($2, ($2 >>> 0) + /* dialations_vec_size */ 2)), + "dilations" : Array.from(HEAP32.subarray(Number($2), Number(($2 >>> 0) + /* dialations_vec_size */ 2))), "group" : $3, - "kernelShape" : Array.from(HEAP32.subarray($4, ($4 >>> 0) + /* kernel_shape_vec_size */ 2)), - "pads" : Array.from(HEAP32.subarray($5, ($5 >>> 0) + /* pads_vec_size */ 4)), - "strides" : Array.from(HEAP32.subarray($6, ($6 >>> 0) + /* strides_vec_size */ 2)), + "kernelShape" : Array.from(HEAP32.subarray(Number($4), Number(($4 >>> 0) + /* kernel_shape_vec_size */ 2))), + "pads" : Array.from(HEAP32.subarray(Number($5), Number(($5 >>> 0) + /* pads_vec_size */ 4))), + "strides" : Array.from(HEAP32.subarray(Number($6), Number(($6 >>> 0) + /* strides_vec_size */ 2))), "wIsConst" : () JS_ARROW(!!HEAP8[$8]), - "outputPadding" : $9 ? Array.from(HEAP32.subarray($9, $10)) : [], - "outputShape" : $11 ? Array.from(HEAP32.subarray($11, $12)) : [], + "outputPadding" : $9 ? Array.from(HEAP32.subarray(Number($9), Number($10))) : [], + "outputShape" : $11 ? Array.from(HEAP32.subarray(Number($11), Number($12))) : [], "activation" : UTF8ToString($13) }), static_cast(conv_transpose_attrs_.auto_pad), diff --git a/onnxruntime/core/providers/js/operators/gather.cc b/onnxruntime/core/providers/js/operators/gather.cc index 485cd3da9b91b..e9c6f5c79294f 100644 --- a/onnxruntime/core/providers/js/operators/gather.cc +++ b/onnxruntime/core/providers/js/operators/gather.cc @@ -15,11 +15,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 10, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", BuildKernelDefConstraintsFromTypeList>()) + .TypeConstraint("T", JsepSupportedDataTypes()) .TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList>()), Gather); @@ -30,11 +26,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 12, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", BuildKernelDefConstraintsFromTypeList>()) + .TypeConstraint("T", JsepSupportedDataTypes()) .TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList>()), Gather); @@ -44,11 +36,7 @@ ONNX_OPERATOR_KERNEL_EX( 13, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", BuildKernelDefConstraintsFromTypeList>()) + .TypeConstraint("T", JsepSupportedDataTypes()) .TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList>()), Gather); diff --git a/onnxruntime/core/providers/js/operators/gemm.h b/onnxruntime/core/providers/js/operators/gemm.h index 74091526f8411..d7f8fa6289c1a 100644 --- a/onnxruntime/core/providers/js/operators/gemm.h +++ b/onnxruntime/core/providers/js/operators/gemm.h @@ -23,8 +23,8 @@ class Gemm : public JsKernel { "transA" : $3, "transB" : $4 }), - static_cast(alpha), - static_cast(beta), + static_cast(alpha), + static_cast(beta), static_cast(transA), static_cast(transB)); } diff --git a/onnxruntime/core/providers/js/operators/pad.h b/onnxruntime/core/providers/js/operators/pad.h index c18c7dd456dc2..f656462285bc4 100644 --- a/onnxruntime/core/providers/js/operators/pad.h +++ b/onnxruntime/core/providers/js/operators/pad.h @@ -22,7 +22,7 @@ class Pad : public JsKernel, public PadBase { JSEP_INIT_KERNEL_ATTRIBUTE(Pad, ({"mode" : $1, "value" : $2, - "pads" : $3 ? Array.from(HEAP32.subarray($3, $4)) : []}), + "pads" : $3 ? 
Array.from(HEAP32.subarray(Number($3), Number($4))) : []}), static_cast(mode_), static_cast(value_), JSEP_HEAP32_INDEX_START(pads), diff --git a/onnxruntime/core/providers/js/operators/reduce.h b/onnxruntime/core/providers/js/operators/reduce.h index 937f1f990dc67..5f6f3e2fd1a1f 100644 --- a/onnxruntime/core/providers/js/operators/reduce.h +++ b/onnxruntime/core/providers/js/operators/reduce.h @@ -24,7 +24,7 @@ namespace js { JSEP_INIT_KERNEL_ATTRIBUTE(ReduceKernel, ({ \ "keepDims" : !!$1, \ "noopWithEmptyAxes" : !!$2, \ - "axes" : $3 ? (Array.from(HEAP32.subarray($3, $4))) : [], \ + "axes" : $3 ? (Array.from(HEAP32.subarray(Number($3), Number($4)))) : [], \ }), \ static_cast(keepdims_), \ static_cast(noop_with_empty_axes_), \ diff --git a/onnxruntime/core/providers/js/operators/resize.h b/onnxruntime/core/providers/js/operators/resize.h index 134eb4bf5a7f4..3e8ccf40753c8 100644 --- a/onnxruntime/core/providers/js/operators/resize.h +++ b/onnxruntime/core/providers/js/operators/resize.h @@ -23,7 +23,7 @@ class Resize : public JsKernel, public UpsampleBase { std::transform(axes_.begin(), axes_.end(), std::back_inserter(axes), [](auto& axis) { return gsl::narrow_cast(axis); }); JSEP_INIT_KERNEL_ATTRIBUTE(Resize, ({ "antialias" : $1, - "axes" : $2 ? Array.from(HEAP32.subarray($2, $3)) : [], + "axes" : $2 ? Array.from(HEAP32.subarray(Number($2), Number($3))) : [], "coordinateTransformMode" : UTF8ToString($4), "cubicCoeffA" : $5, "excludeOutside" : $6, diff --git a/onnxruntime/core/providers/js/operators/slice.h b/onnxruntime/core/providers/js/operators/slice.h index daeffaa664741..f30e7bf01ec7b 100644 --- a/onnxruntime/core/providers/js/operators/slice.h +++ b/onnxruntime/core/providers/js/operators/slice.h @@ -20,9 +20,9 @@ class Slice : public JsKernel, public SliceBase { std::vector starts(attr_starts.begin(), attr_starts.end()); std::vector ends(attr_ends.begin(), attr_ends.end()); - JSEP_INIT_KERNEL_ATTRIBUTE(Slice, ({"starts" : $1 ? Array.from(HEAP32.subarray($1, $2)) : [], - "ends" : $3 ? Array.from(HEAP32.subarray($3, $4)) : [], - "axes" : $5 ? Array.from(HEAP32.subarray($5, $6)) : []}), + JSEP_INIT_KERNEL_ATTRIBUTE(Slice, ({"starts" : $1 ? Array.from(HEAP32.subarray(Number($1), Number($2))) : [], + "ends" : $3 ? Array.from(HEAP32.subarray(Number($3), Number($4))) : [], + "axes" : $5 ? Array.from(HEAP32.subarray(Number($5), Number($6))) : []}), JSEP_HEAP32_INDEX_START(starts), JSEP_HEAP32_INDEX_END(starts), JSEP_HEAP32_INDEX_START(ends), diff --git a/onnxruntime/core/providers/js/operators/split.h b/onnxruntime/core/providers/js/operators/split.h index 4fdbab00e739c..3f6cfcb8921f3 100644 --- a/onnxruntime/core/providers/js/operators/split.h +++ b/onnxruntime/core/providers/js/operators/split.h @@ -49,7 +49,7 @@ class Split : public JsKernel, public SplitBase { JSEP_INIT_KERNEL_ATTRIBUTE(Split, ({"axis" : $1, "numOutputs" : $2, - "splitSizes" : $3 ? Array.from(HEAP32.subarray($3, $4)) : []}), + "splitSizes" : $3 ? Array.from(HEAP32.subarray(Number($3), Number($4))) : []}), static_cast(axis_), static_cast(num_outputs_), JSEP_HEAP32_INDEX_START(split_sizes), diff --git a/onnxruntime/core/providers/js/operators/transpose.h b/onnxruntime/core/providers/js/operators/transpose.h index 7a945471c7701..f6b2b4faba850 100644 --- a/onnxruntime/core/providers/js/operators/transpose.h +++ b/onnxruntime/core/providers/js/operators/transpose.h @@ -21,7 +21,7 @@ class Transpose final : public JsKernel, public TransposeBase { } } JSEP_INIT_KERNEL_ATTRIBUTE(Transpose, ({ - "perm" : $1 ? 
Array.from(HEAP32.subarray($1, $2)) : [] + "perm" : $1 ? Array.from(HEAP32.subarray(Number($1), Number($2))) : [] }), JSEP_HEAP32_INDEX_START(perm), JSEP_HEAP32_INDEX_END(perm)); diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index 0e58bb4f93f7f..aeee3827dd9cd 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -27,7 +27,11 @@ enum DataLocation { }; static_assert(sizeof(const char*) == sizeof(size_t), "size of a pointer and a size_t value should be the same."); +#ifdef WASM_MEMORY64 +static_assert(sizeof(size_t) == 8, "size of size_t should be 8 in this build (wasm64)."); +#else static_assert(sizeof(size_t) == 4, "size of size_t should be 4 in this build (wasm32)."); +#endif OrtErrorCode CheckStatus(OrtStatusPtr status) { if (status) { diff --git a/onnxruntime/wasm/api.h b/onnxruntime/wasm/api.h index 0730559c4375b..0d2a8034af448 100644 --- a/onnxruntime/wasm/api.h +++ b/onnxruntime/wasm/api.h @@ -183,7 +183,11 @@ ort_tensor_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateTensor(int data_type, void* da * 'dims' (for all types of tensor), 'data' (only for string tensor) * @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message. */ +#ifdef WASM_MEMORY64 +int EMSCRIPTEN_KEEPALIVE OrtGetTensorData(ort_tensor_handle_t tensor, size_t* data_type, void** data, size_t** dims, size_t* dims_length); +#else int EMSCRIPTEN_KEEPALIVE OrtGetTensorData(ort_tensor_handle_t tensor, int* data_type, void** data, size_t** dims, size_t* dims_length); +#endif /** * release the specified tensor. diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 98d9ba22b7190..1cdb7b4494e4b 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -398,7 +398,7 @@ def convert_arg_line_to_args(self, arg_line): help="Build with a specific GDK edition. 
Defaults to the latest installed.", ) parser.add_argument("--gdk_platform", default="Scarlett", help="Sets the GDK target platform.") - + parser.add_argument("--enable_wasm_memory64", action="store_true", help="Enable WebAssembly 64bit support") platform_group = parser.add_mutually_exclusive_group() platform_group.add_argument("--ios", action="store_true", help="build for ios") platform_group.add_argument("--visionos", action="store_true", help="build for visionOS") @@ -1078,6 +1078,7 @@ def generate_build_tree( + ("ON" if args.enable_wasm_exception_throwing_override else "OFF"), "-Donnxruntime_WEBASSEMBLY_RUN_TESTS_IN_BROWSER=" + ("ON" if args.wasm_run_tests_in_browser else "OFF"), "-Donnxruntime_ENABLE_WEBASSEMBLY_THREADS=" + ("ON" if args.enable_wasm_threads else "OFF"), + "-Donnxruntime_ENABLE_WEBASSEMBLY_MEMORY64=" + ("ON" if args.enable_wasm_memory64 else "OFF"), "-Donnxruntime_ENABLE_WEBASSEMBLY_DEBUG_INFO=" + ("ON" if args.enable_wasm_debug_info else "OFF"), "-Donnxruntime_ENABLE_WEBASSEMBLY_PROFILING=" + ("ON" if args.enable_wasm_profiling else "OFF"), "-Donnxruntime_ENABLE_LAZY_TENSOR=" + ("ON" if args.enable_lazy_tensor else "OFF"), From 95d1707bd75ad7aa5da5333a17f7b2259b3f0de9 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Wed, 10 Jul 2024 10:22:38 -0700 Subject: [PATCH 02/45] Lint --- js/web/lib/wasm/jsep/init.ts | 1 + .../ops/3rd-party/conv3d_naive_webgpu.ts | 12 +++-- .../core/providers/js/operators/reduce.h | 46 +++++++++---------- 3 files changed, 32 insertions(+), 27 deletions(-) diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index a410c77890354..5f431c42bfc45 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -4,6 +4,7 @@ import {Env} from 'onnxruntime-common'; import {DataType, getTensorElementSize} from '../wasm-common'; + import type {OrtWasmModule} from '../wasm-types'; import {WebGpuBackend} from './backend-webgpu'; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv3d_naive_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv3d_naive_webgpu.ts index f428293add599..26e6ec0f46bc6 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv3d_naive_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv3d_naive_webgpu.ts @@ -298,13 +298,17 @@ export const createConv3DNaiveProgramInfo = let xRCorner = xFRCCorner.y; let xCCorner = xFRCCorner.z; let xShapeY = ${ - isChannelsLast ? getElementAt('uniforms.x_shape', 1, x.rank) : getElementAt('uniforms.x_shape', 2, x.rank)}; + isChannelsLast ? getElementAt('uniforms.x_shape', 1, x.rank) : + getElementAt('uniforms.x_shape', 2, x.rank)}; let xShapeZ = ${ - isChannelsLast ? getElementAt('uniforms.x_shape', 2, x.rank) : getElementAt('uniforms.x_shape', 3, x.rank)}; + isChannelsLast ? getElementAt('uniforms.x_shape', 2, x.rank) : + getElementAt('uniforms.x_shape', 3, x.rank)}; let xShapeW = ${ - isChannelsLast ? getElementAt('uniforms.x_shape', 3, x.rank) : getElementAt('uniforms.x_shape', 4, x.rank)}; + isChannelsLast ? getElementAt('uniforms.x_shape', 3, x.rank) : + getElementAt('uniforms.x_shape', 4, x.rank)}; let xShapeU = ${ - isChannelsLast ? getElementAt('uniforms.x_shape', 4, x.rank) : getElementAt('uniforms.x_shape', 1, x.rank)}; + isChannelsLast ? 
getElementAt('uniforms.x_shape', 4, x.rank) : + getElementAt('uniforms.x_shape', 1, x.rank)}; let inputDepthNearestVec4 = (xShapeU / 4) * 4; let inputDepthVec4Remainder = xShapeU % 4; diff --git a/onnxruntime/core/providers/js/operators/reduce.h b/onnxruntime/core/providers/js/operators/reduce.h index 5f6f3e2fd1a1f..4ae558f9dfc00 100644 --- a/onnxruntime/core/providers/js/operators/reduce.h +++ b/onnxruntime/core/providers/js/operators/reduce.h @@ -8,29 +8,29 @@ namespace onnxruntime { namespace js { -#define JSEP_DEFINE_REDUCE_KERNEL(ReduceKernel) \ - template \ - class ReduceKernel : public JsKernel, public ReduceKernelBase { \ - public: \ - using ReduceKernelBase::axes_; \ - using ReduceKernelBase::noop_with_empty_axes_; \ - using ReduceKernelBase::keepdims_; \ - ReduceKernel(const OpKernelInfo& info) : JsKernel(info), ReduceKernelBase(info) { \ - std::vector axes(axes_.size()); \ - if (axes_.size() > 0) { \ - std::transform(axes_.begin(), axes_.end(), axes.begin(), \ - [](int64_t axis) { return gsl::narrow_cast(axis); }); \ - } \ - JSEP_INIT_KERNEL_ATTRIBUTE(ReduceKernel, ({ \ - "keepDims" : !!$1, \ - "noopWithEmptyAxes" : !!$2, \ - "axes" : $3 ? (Array.from(HEAP32.subarray(Number($3), Number($4)))) : [], \ - }), \ - static_cast(keepdims_), \ - static_cast(noop_with_empty_axes_), \ - JSEP_HEAP32_INDEX_START(axes), \ - JSEP_HEAP32_INDEX_END(axes)); \ - } \ +#define JSEP_DEFINE_REDUCE_KERNEL(ReduceKernel) \ + template \ + class ReduceKernel : public JsKernel, public ReduceKernelBase { \ + public: \ + using ReduceKernelBase::axes_; \ + using ReduceKernelBase::noop_with_empty_axes_; \ + using ReduceKernelBase::keepdims_; \ + ReduceKernel(const OpKernelInfo& info) : JsKernel(info), ReduceKernelBase(info) { \ + std::vector axes(axes_.size()); \ + if (axes_.size() > 0) { \ + std::transform(axes_.begin(), axes_.end(), axes.begin(), \ + [](int64_t axis) { return gsl::narrow_cast(axis); }); \ + } \ + JSEP_INIT_KERNEL_ATTRIBUTE(ReduceKernel, ({ \ + "keepDims" : !!$1, \ + "noopWithEmptyAxes" : !!$2, \ + "axes" : $3 ? (Array.from(HEAP32.subarray(Number($3), Number($4)))) : [], \ + }), \ + static_cast(keepdims_), \ + static_cast(noop_with_empty_axes_), \ + JSEP_HEAP32_INDEX_START(axes), \ + JSEP_HEAP32_INDEX_END(axes)); \ + } \ }; JSEP_DEFINE_REDUCE_KERNEL(ReduceMax); From b8d5db7a9e115ae3519ed9eb0658896d94401572 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Wed, 10 Jul 2024 13:02:49 -0700 Subject: [PATCH 03/45] Add missing files. --- onnxruntime/wasm/js_post_js.js | 7 +++++++ onnxruntime/wasm/js_post_js_64.js | 7 +++++++ 2 files changed, 14 insertions(+) create mode 100644 onnxruntime/wasm/js_post_js.js create mode 100644 onnxruntime/wasm/js_post_js_64.js diff --git a/onnxruntime/wasm/js_post_js.js b/onnxruntime/wasm/js_post_js.js new file mode 100644 index 0000000000000..b77d82fbd7d10 --- /dev/null +++ b/onnxruntime/wasm/js_post_js.js @@ -0,0 +1,7 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. + +// Licensed under the MIT License. + +'use strict'; + +Module["PTR_SIZE"] = 4; diff --git a/onnxruntime/wasm/js_post_js_64.js b/onnxruntime/wasm/js_post_js_64.js new file mode 100644 index 0000000000000..b140df927ebbd --- /dev/null +++ b/onnxruntime/wasm/js_post_js_64.js @@ -0,0 +1,7 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. + +// Licensed under the MIT License. 
+ +'use strict'; + +Module["PTR_SIZE"] = 8; From 4726758cba75704b4a6cb9162079f98b7235f14d Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Wed, 10 Jul 2024 11:24:16 -0700 Subject: [PATCH 04/45] Revert changes to conv3d_native_webgpu.ts --- .../jsep/webgpu/ops/3rd-party/conv3d_naive_webgpu.ts | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv3d_naive_webgpu.ts b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv3d_naive_webgpu.ts index 26e6ec0f46bc6..f428293add599 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv3d_naive_webgpu.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/3rd-party/conv3d_naive_webgpu.ts @@ -298,17 +298,13 @@ export const createConv3DNaiveProgramInfo = let xRCorner = xFRCCorner.y; let xCCorner = xFRCCorner.z; let xShapeY = ${ - isChannelsLast ? getElementAt('uniforms.x_shape', 1, x.rank) : - getElementAt('uniforms.x_shape', 2, x.rank)}; + isChannelsLast ? getElementAt('uniforms.x_shape', 1, x.rank) : getElementAt('uniforms.x_shape', 2, x.rank)}; let xShapeZ = ${ - isChannelsLast ? getElementAt('uniforms.x_shape', 2, x.rank) : - getElementAt('uniforms.x_shape', 3, x.rank)}; + isChannelsLast ? getElementAt('uniforms.x_shape', 2, x.rank) : getElementAt('uniforms.x_shape', 3, x.rank)}; let xShapeW = ${ - isChannelsLast ? getElementAt('uniforms.x_shape', 3, x.rank) : - getElementAt('uniforms.x_shape', 4, x.rank)}; + isChannelsLast ? getElementAt('uniforms.x_shape', 3, x.rank) : getElementAt('uniforms.x_shape', 4, x.rank)}; let xShapeU = ${ - isChannelsLast ? getElementAt('uniforms.x_shape', 4, x.rank) : - getElementAt('uniforms.x_shape', 1, x.rank)}; + isChannelsLast ? getElementAt('uniforms.x_shape', 4, x.rank) : getElementAt('uniforms.x_shape', 1, x.rank)}; let inputDepthNearestVec4 = (xShapeU / 4) * 4; let inputDepthVec4Remainder = xShapeU % 4; From 7001f395e5dab2e665983787aae8dfbfa0489a6d Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Wed, 10 Jul 2024 19:18:53 -0700 Subject: [PATCH 05/45] Set ASYNCIFY=1 with MEMORY64. 
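This keeps plain Asyncify (ASYNCIFY=1) together with -sMEMORY64 instead of the
JSPI/ASYNCIFY=2 combination tried earlier, and leaves the -fwasm-exceptions switch
commented out for now. The working assumption for JS callers is that an asyncified
export such as OrtRun can still suspend and surface as a Promise, so call sites
should tolerate both forms. The TypeScript sketch below only illustrates that
calling pattern; the helper name and parameters are illustrative assumptions, not
code changed by this patch.

    // Minimal sketch: await the result of an Asyncify-wrapped export either way,
    // since awaiting a plain number is a no-op while a suspended call yields a Promise.
    const callMaybeAsync = async (
        fn: (...args: number[]) => number | Promise<number>,
        ...args: number[]): Promise<number> => await fn(...args);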
--- cmake/adjust_global_compile_flags.cmake | 4 ++-- cmake/onnxruntime_webassembly.cmake | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake index 2189f728ad5ad..ac98403c70071 100644 --- a/cmake/adjust_global_compile_flags.cmake +++ b/cmake/adjust_global_compile_flags.cmake @@ -53,8 +53,8 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") if (onnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_CATCHING) if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) - string(APPEND CMAKE_C_FLAGS " -fwasm-exceptions") - string(APPEND CMAKE_CXX_FLAGS " -fwasm-exceptions") + # string(APPEND CMAKE_C_FLAGS " -fwasm-exceptions") + # string(APPEND CMAKE_CXX_FLAGS " -fwasm-exceptions") else() string(APPEND CMAKE_C_FLAGS " -s DISABLE_EXCEPTION_CATCHING=0") string(APPEND CMAKE_CXX_FLAGS " -s DISABLE_EXCEPTION_CATCHING=0") diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index dbd62a7cca0f5..ea3452b78273b 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -294,7 +294,7 @@ else() target_compile_definitions(onnxruntime_webassembly PRIVATE USE_JSEP=1) target_link_options(onnxruntime_webassembly PRIVATE "SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js\"" - "SHELL:-s ASYNCIFY=${ASYNCIFY}" + "SHELL:-s ASYNCIFY=1" #"SHELL:-s JSPI" #"SHELL:-s ASYNCIFY_IGNORE_INDIRECT=1" "SHELL:-s ASYNCIFY_STACK_SIZE=65536" @@ -336,7 +336,7 @@ else() # Set link flag to enable exceptions support, this will override default disabling exception throwing behavior when disable exceptions. if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) - target_link_options(onnxruntime_webassembly PRIVATE "-fwasm-exceptions") + # target_link_options(onnxruntime_webassembly PRIVATE "-fwasm-exceptions") else() target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s DISABLE_EXCEPTION_THROWING=0") endif() From 29fc6c0006c3f68b6903644576689e74a2fe464d Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Fri, 12 Jul 2024 11:05:15 -0700 Subject: [PATCH 06/45] Added Chrome Canary as a browser option. --- js/web/script/test-runner-cli-args.ts | 4 ++-- js/web/script/test-runner-cli.ts | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/js/web/script/test-runner-cli-args.ts b/js/web/script/test-runner-cli-args.ts index adcd940178e07..e9b058038e45d 100644 --- a/js/web/script/test-runner-cli-args.ts +++ b/js/web/script/test-runner-cli-args.ts @@ -131,7 +131,7 @@ Examples: export declare namespace TestRunnerCliArgs { type Mode = 'suite0'|'suite1'|'model'|'unittest'|'op'; type Backend = 'cpu'|'webgl'|'webgpu'|'wasm'|'onnxruntime'|'webnn'; - type Environment = 'chrome'|'edge'|'firefox'|'electron'|'safari'|'node'|'bs'; + type Environment = 'chrome'|'edge'|'firefox'|'electron'|'safari'|'node'|'bs'|'canary'; type BundleMode = 'dev'|'perf'; type IOBindingMode = 'none'|'gpu-tensor'|'gpu-location'; } @@ -384,7 +384,7 @@ export function parseTestRunnerCliArgs(cmdlineArgs: string[]): TestRunnerCliArgs // Option: -e=<...>, --env=<...> const envArg = args.env || args.e; const env = (typeof envArg !== 'string') ? 
'chrome' : envArg; - if (['chrome', 'edge', 'firefox', 'electron', 'safari', 'node', 'bs'].indexOf(env) === -1) { + if (['chrome', 'edge', 'firefox', 'electron', 'safari', 'node', 'bs', 'canary'].indexOf(env) === -1) { throw new Error(`not supported env ${env}`); } diff --git a/js/web/script/test-runner-cli.ts b/js/web/script/test-runner-cli.ts index fbde81524ccec..bdd0bfac00724 100644 --- a/js/web/script/test-runner-cli.ts +++ b/js/web/script/test-runner-cli.ts @@ -673,6 +673,8 @@ async function main() { switch (env) { case 'chrome': return 'ChromeTest'; + case 'canary': + return 'ChromeCanaryTest'; case 'edge': return 'EdgeTest'; case 'firefox': From a6357e5da8baf55bc611a2e7c153dfc7608c43b4 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Fri, 12 Jul 2024 11:06:36 -0700 Subject: [PATCH 07/45] Use wasm.PTR_SIZE instead of hardcoding to 4. --- js/web/lib/wasm/jsep/init.ts | 4 +- js/web/lib/wasm/wasm-core-impl.ts | 61 ++++++++++++++++--------------- js/web/lib/wasm/wasm-utils.ts | 9 +++-- 3 files changed, 39 insertions(+), 35 deletions(-) diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index 5f431c42bfc45..b0a0b6f826cb6 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -203,7 +203,7 @@ export const init = 'verbose', () => `[WebGPU] jsepCopyCpuToGpu: dataOffset=${Number(src)}, gpuDataId=${Number(dst)}, size=${ Number(size)}`); - const data = module.HEAPU8.subarray(src >>> 0, (src >>> 0) + size); + const data = module.HEAPU8.subarray(Number(src >>> 0), Number((src >>> 0) + size)); backend.upload(dst, data); } }, @@ -216,7 +216,7 @@ export const init = () => `[WebGPU] jsepCopyGpuToCpu: gpuDataId=${gpuDataId}, dataOffset=${dataOffset}, size=${size}`); await backend.download( - gpuDataId, () => module.HEAPU8.subarray(dataOffset >>> 0, (dataOffset >>> 0) + size)); + gpuDataId, () => module.HEAPU8.subarray(Number(dataOffset >>> 0), Number((dataOffset >>> 0) + size))); }, // jsepCreateKernel diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 9fc8786192c5c..60faeef2faa6e 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -183,12 +183,13 @@ const getSessionInputOutputCount = (sessionHandle: number): [number, number] => const wasm = getInstance(); const stack = wasm.stackSave(); try { - const dataOffset = wasm.stackAlloc(8); - const errorCode = wasm._OrtGetInputOutputCount(sessionHandle, dataOffset, dataOffset + 4); + const ptrSize = wasm.PTR_SIZE; + const dataOffset = wasm.stackAlloc(2 * ptrSize); + const errorCode = wasm._OrtGetInputOutputCount(sessionHandle, dataOffset, dataOffset + ptrSize); if (errorCode !== 0) { checkLastError('Can\'t get session input/output count.'); } - return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]]; + return [wasm.getValue(dataOffset, '*'), wasm.getValue(dataOffset + ptrSize, '*')]; } finally { wasm.stackRestore(stack); } @@ -413,6 +414,7 @@ export const prepareInputOutputTensor = } const wasm = getInstance(); + const ptrSize = wasm.PTR_SIZE; const dataType = tensor[0]; const dims = tensor[1]; @@ -448,7 +450,7 @@ export const prepareInputOutputTensor = dataByteLength = 4 * data.length; rawData = wasm._malloc(dataByteLength); allocs.push(rawData); - let dataIndex = rawData / 4; + let dataIndex = rawData / ptrSize; for (let i = 0; i < data.length; i++) { if (typeof data[i] !== 'string') { throw new TypeError(`tensor data at index ${i} is not a string`); @@ -464,10 +466,9 @@ export const prepareInputOutputTensor = } const 
stack = wasm.stackSave(); - const dimsOffset = wasm.stackAlloc(4 * dims.length); + const dimsOffset = wasm.stackAlloc(ptrSize * dims.length); try { - let dimIndex = dimsOffset / 4; - dims.forEach(d => wasm.HEAP32[dimIndex++] = d); + dims.forEach((d, index) => wasm.setValue(dimsOffset + (index * ptrSize), d, '*')); const tensor = wasm._OrtCreateTensor( tensorDataTypeStringToEnum(dataType), rawData, dataByteLength, dimsOffset, dims.length, dataLocationStringToEnum(location)); @@ -487,6 +488,7 @@ export const run = async( sessionId: number, inputIndices: number[], inputTensors: TensorMetadata[], outputIndices: number[], outputTensors: Array, options: InferenceSession.RunOptions): Promise => { const wasm = getInstance(); + const ptrSize = wasm.PTR_SIZE; const session = activeSessions.get(sessionId); if (!session) { throw new Error(`cannot run inference. invalid session id: ${sessionId}`); @@ -509,10 +511,10 @@ export const run = async( const inputOutputAllocs: number[] = []; const beforeRunStack = wasm.stackSave(); - const inputValuesOffset = wasm.stackAlloc(inputCount * 4); - const inputNamesOffset = wasm.stackAlloc(inputCount * 4); - const outputValuesOffset = wasm.stackAlloc(outputCount * 4); - const outputNamesOffset = wasm.stackAlloc(outputCount * 4); + const inputValuesOffset = wasm.stackAlloc(inputCount * ptrSize); + const inputNamesOffset = wasm.stackAlloc(inputCount * ptrSize); + const outputValuesOffset = wasm.stackAlloc(outputCount * ptrSize); + const outputNamesOffset = wasm.stackAlloc(outputCount * ptrSize); try { [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); @@ -530,17 +532,17 @@ export const run = async( enableGraphCapture); } - let inputValuesIndex = inputValuesOffset / 4; - let inputNamesIndex = inputNamesOffset / 4; - let outputValuesIndex = outputValuesOffset / 4; - let outputNamesIndex = outputNamesOffset / 4; + let inputValuesIndex = inputValuesOffset / ptrSize; + let inputNamesIndex = inputNamesOffset / ptrSize; + let outputValuesIndex = outputValuesOffset / ptrSize; + let outputNamesIndex = outputNamesOffset / ptrSize; for (let i = 0; i < inputCount; i++) { - wasm.HEAPU32[inputValuesIndex++] = inputTensorHandles[i]; - wasm.HEAPU32[inputNamesIndex++] = inputNamesUTF8Encoded[inputIndices[i]]; + wasm.HEAPU64[inputValuesIndex++] = BigInt(inputTensorHandles[i]); + wasm.HEAPU64[inputNamesIndex++] = BigInt(inputNamesUTF8Encoded[inputIndices[i]]); } for (let i = 0; i < outputCount; i++) { - wasm.HEAPU32[outputValuesIndex++] = outputTensorHandles[i]; - wasm.HEAPU32[outputNamesIndex++] = outputNamesUTF8Encoded[outputIndices[i]]; + wasm.HEAPU64[outputValuesIndex++] = BigInt(outputTensorHandles[i]); + wasm.HEAPU64[outputNamesIndex++] = BigInt(outputNamesUTF8Encoded[outputIndices[i]]); } if (!BUILD_DEFS.DISABLE_JSEP && ioBindingState && !inputOutputBound) { @@ -603,7 +605,7 @@ export const run = async( const output: TensorMetadata[] = []; for (let i = 0; i < outputCount; i++) { - const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i]; + const tensor = wasm.getValue(outputValuesOffset + i * ptrSize, '*'); if (tensor === outputTensorHandles[i]) { // output tensor is pre-allocated. no need to copy data. 
output.push(outputTensors[i]!); @@ -612,24 +614,25 @@ export const run = async( const beforeGetTensorDataStack = wasm.stackSave(); // stack allocate 4 pointer value - const tensorDataOffset = wasm.stackAlloc(4 * 4); + const tensorDataOffset = wasm.stackAlloc(4 * ptrSize); let keepOutputTensor = false; let type: Tensor.Type|undefined, dataOffset = 0; try { const errorCode = wasm._OrtGetTensorData( - tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12); + tensor, tensorDataOffset, tensorDataOffset + ptrSize, tensorDataOffset + 2 * ptrSize, + tensorDataOffset + 3 * ptrSize); if (errorCode !== 0) { checkLastError(`Can't access output tensor data on index ${i}.`); } - let tensorDataIndex = tensorDataOffset / 4; - const dataType = wasm.HEAPU32[tensorDataIndex++]; - dataOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsLength = wasm.HEAPU32[tensorDataIndex++]; + + const dataType = wasm.getValue(tensorDataOffset, '*'); + dataOffset = wasm.getValue(tensorDataOffset + ptrSize, '*'); + const dimsOffset = wasm.getValue(tensorDataOffset + ptrSize * 2, '*'); + const dimsLength = wasm.getValue(tensorDataOffset + ptrSize * 3, '*'); const dims = []; for (let i = 0; i < dimsLength; i++) { - dims.push(wasm.HEAPU32[dimsOffset / 4 + i]); + dims.push(wasm.getValue(dimsOffset + i * ptrSize, '*')); } wasm._OrtFree(dimsOffset); @@ -643,7 +646,7 @@ export const run = async( throw new Error('String tensor is not supported on GPU.'); } const stringData: string[] = []; - let dataIndex = dataOffset / 4; + let dataIndex = dataOffset / ptrSize; for (let i = 0; i < size; i++) { const offset = wasm.HEAPU32[dataIndex++]; const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset; diff --git a/js/web/lib/wasm/wasm-utils.ts b/js/web/lib/wasm/wasm-utils.ts index 37762b353f575..703ccb08addf5 100644 --- a/js/web/lib/wasm/wasm-utils.ts +++ b/js/web/lib/wasm/wasm-utils.ts @@ -52,10 +52,11 @@ export const checkLastError = (message: string): void => { const stack = wasm.stackSave(); try { - const paramsOffset = wasm.stackAlloc(8); - wasm._OrtGetLastError(paramsOffset, paramsOffset + 4); - const errorCode = wasm.HEAP32[paramsOffset / 4]; - const errorMessagePointer = wasm.HEAPU32[paramsOffset / 4 + 1]; + const ptrSize = wasm.PTR_SIZE; + const paramsOffset = wasm.stackAlloc(2 * ptrSize); + wasm._OrtGetLastError(paramsOffset, paramsOffset + ptrSize); + const errorCode = wasm.getValue(paramsOffset, 'i32'); + const errorMessagePointer = wasm.getValue(paramsOffset + ptrSize, '*'); const errorMessage = errorMessagePointer ? wasm.UTF8ToString(errorMessagePointer) : ''; throw new Error(`${message} ERROR_CODE: ${errorCode}, ERROR_MESSAGE: ${errorMessage}`); } finally { From 37a6d15711554196afa1f97f7d8b05107d945b36 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Fri, 12 Jul 2024 11:09:25 -0700 Subject: [PATCH 08/45] Modified SIGNATURE_CONVERSION for OrtCreateTensor. 
--- cmake/onnxruntime_webassembly.cmake | 3 ++- js/web/lib/wasm/jsep/init.ts | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index ea3452b78273b..f1f5e3ad4aa0e 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -281,7 +281,7 @@ else() if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s ERROR_ON_UNDEFINED_SYMBOLS=0" - "SHELL:-s SIGNATURE_CONVERSIONS=OrtRun:_pppppppp,OrtGetTensorData:_ppppp,OrtCreateTensor:p_pppp,OrtCreateSession:pppp,OrtReleaseSession:_p,OrtGetInputOutputCount:pppp,OrtCreateSessionOptions:pp__p_ppppp,OrtAddSessionConfigEntry:pppp,OrtReleaseSessionOptions:_p,OrtAppendExecutionProvider:ppp,OrtAddSessionConfigEntry:pppp,OrtGetInputName:ppp,OrtGetOutputName:ppp,OrtCreateRunOptions:ppp_p,OrtReleaseRunOptions:pp,OrtReleaseTensor:_p,OrtFree:_p,OrtGetLastError:_pp,JsepOutput:pp_p" + "SHELL:-s SIGNATURE_CONVERSIONS=OrtRun:_pppppppp,OrtGetTensorData:_ppppp,OrtCreateTensor:p_pppp_,OrtCreateSession:pppp,OrtReleaseSession:_p,OrtGetInputOutputCount:pppp,OrtCreateSessionOptions:pp__p_ppppp,OrtAddSessionConfigEntry:pppp,OrtReleaseSessionOptions:_p,OrtAppendExecutionProvider:ppp,OrtAddSessionConfigEntry:pppp,OrtGetInputName:ppp,OrtGetOutputName:ppp,OrtCreateRunOptions:ppp_p,OrtReleaseRunOptions:pp,OrtReleaseTensor:_p,OrtFree:_p,OrtGetLastError:_pp,JsepOutput:pp_p,JsepGetNodeName:pp,JsepOutput:pp_p" ) endif () set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre.js) @@ -297,6 +297,7 @@ else() "SHELL:-s ASYNCIFY=1" #"SHELL:-s JSPI" #"SHELL:-s ASYNCIFY_IGNORE_INDIRECT=1" + #"SHELL:-s WASM_BIGINT" "SHELL:-s ASYNCIFY_STACK_SIZE=65536" "SHELL:-s ASYNCIFY_EXPORTS=['OrtRun']" ) diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index b0a0b6f826cb6..f4149396abdcb 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -203,7 +203,7 @@ export const init = 'verbose', () => `[WebGPU] jsepCopyCpuToGpu: dataOffset=${Number(src)}, gpuDataId=${Number(dst)}, size=${ Number(size)}`); - const data = module.HEAPU8.subarray(Number(src >>> 0), Number((src >>> 0) + size)); + const data = module.HEAPU8.subarray(Number(src >>> 0), Number(src >>> 0) + Number(size)); backend.upload(dst, data); } }, From acc7e5a59ec042842752a5047efcf68af8f5acd9 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Fri, 12 Jul 2024 14:32:06 -0700 Subject: [PATCH 09/45] Fix data type. 
---
 onnxruntime/wasm/api.cc | 2 +-
 onnxruntime/wasm/api.h  | 4 ----
 2 files changed, 1 insertion(+), 5 deletions(-)

diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc
index aeee3827dd9cd..b86805fcd0239 100644
--- a/onnxruntime/wasm/api.cc
+++ b/onnxruntime/wasm/api.cc
@@ -284,7 +284,7 @@ OrtValue* OrtCreateTensor(int data_type, void* data, size_t data_length, size_t*
   }
 }

-int OrtGetTensorData(OrtValue* tensor, int* data_type, void** data, size_t** dims, size_t* dims_length) {
+int OrtGetTensorData(OrtValue* tensor, size_t* data_type, void** data, size_t** dims, size_t* dims_length) {
   ONNXType tensor_type;
   RETURN_ERROR_CODE_IF_ERROR(GetValueType, tensor, &tensor_type);
   if (tensor_type != ONNX_TYPE_TENSOR) {
diff --git a/onnxruntime/wasm/api.h b/onnxruntime/wasm/api.h
index 0d2a8034af448..f8b5dc49fb875 100644
--- a/onnxruntime/wasm/api.h
+++ b/onnxruntime/wasm/api.h
@@ -183,11 +183,7 @@ ort_tensor_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateTensor(int data_type, void* da
  * 'dims' (for all types of tensor), 'data' (only for string tensor)
  * @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message.
  */
-#ifdef WASM_MEMORY64
 int EMSCRIPTEN_KEEPALIVE OrtGetTensorData(ort_tensor_handle_t tensor, size_t* data_type, void** data, size_t** dims, size_t* dims_length);
-#else
-int EMSCRIPTEN_KEEPALIVE OrtGetTensorData(ort_tensor_handle_t tensor, int* data_type, void** data, size_t** dims, size_t* dims_length);
-#endif

 /**
  * release the specified tensor.

From 3b46edd9f15740eeb0017dd5d21994b06c0976be Mon Sep 17 00:00:00 2001
From: Satya Jandhyala
Date: Fri, 12 Jul 2024 17:19:34 -0700
Subject: [PATCH 10/45] Use getValue/setValue instead of directly accessing heap.

---
 js/web/lib/wasm/jsep/init.ts               | 24 ++++++------
 js/web/lib/wasm/wasm-core-impl.ts          | 23 +++++------
 js/web/lib/wasm/wasm-training-core-impl.ts | 45 +++++++++++-----------
 3 files changed, 44 insertions(+), 48 deletions(-)

diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts
index f4149396abdcb..ee10ad52340fc 100644
--- a/js/web/lib/wasm/jsep/init.ts
+++ b/js/web/lib/wasm/jsep/init.ts
@@ -69,24 +69,24 @@ class ComputeContextImpl implements ComputeContext {
   private customDataSize = 0;
   constructor(private module: OrtWasmModule, private backend: WebGpuBackend, contextDataOffset: number) {
     this.adapterInfo = backend.adapterInfo;
-    const heap = module.PTR_SIZE === 4 ? module.HEAPU32 : module.HEAPU64;
     // extract context data
+    const ptrSize = module.PTR_SIZE;
     let dataIndex = module.PTR_SIZE === 8 ?
(contextDataOffset / 2 ** 3) : (contextDataOffset >> 2); - this.opKernelContext = Number(heap[dataIndex++]); - const inputCount = Number(heap[dataIndex++]); - this.outputCount = Number(heap[dataIndex++]); - this.customDataOffset = Number(heap[dataIndex++]); - this.customDataSize = Number(heap[dataIndex++]); + this.opKernelContext = module.getValue(dataIndex++ * ptrSize, 'i32'); + const inputCount = module.getValue(dataIndex++ * ptrSize, 'i32'); + this.outputCount = module.getValue(dataIndex++ * ptrSize, 'i32'); + this.customDataOffset = module.getValue(dataIndex++ * ptrSize, 'i32'); + this.customDataSize = module.getValue(dataIndex++ * ptrSize, 'i32'); const inputs: TensorView[] = []; for (let i = 0; i < inputCount; i++) { - const dataType = Number(heap[dataIndex++]); - const data = Number(heap[dataIndex++]); - const dim = Number(heap[dataIndex++]); + const dataType = module.getValue(dataIndex++ * ptrSize, 'i32'); + const data = module.getValue(dataIndex++ * ptrSize, '*'); + const dim = module.getValue(dataIndex++ * ptrSize, 'i32'); const dims: number[] = []; for (let d = 0; d < dim; d++) { - dims.push(Number(heap[dataIndex++])); + dims.push(module.getValue(dataIndex++ * ptrSize, 'i32')); } inputs.push(new TensorViewImpl(module, dataType, data, dims)); } @@ -130,9 +130,9 @@ class ComputeContextImpl implements ComputeContext { try { const ptrSize = this.module.PTR_SIZE; const data = this.module.stackAlloc((1 + dims.length) * ptrSize /* sizeof(size_t) */); - this.module.setValue(data, dims.length, '*'); + this.module.setValue(data, dims.length, 'i32'); for (let i = 0; i < dims.length; i++) { - this.module.setValue(data + ptrSize * (i + 1), dims[i], '*'); + this.module.setValue(data + ptrSize * (i + 1), dims[i], 'i32'); } return this.module._JsepOutput!(this.opKernelContext, index, data); } catch (e) { diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 60faeef2faa6e..33ecdc12f13cf 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -450,12 +450,11 @@ export const prepareInputOutputTensor = dataByteLength = 4 * data.length; rawData = wasm._malloc(dataByteLength); allocs.push(rawData); - let dataIndex = rawData / ptrSize; for (let i = 0; i < data.length; i++) { if (typeof data[i] !== 'string') { throw new TypeError(`tensor data at index ${i} is not a string`); } - wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], allocs); + wasm.setValue(rawData + i * ptrSize, allocWasmString(data[i], allocs), '*'); } } else { dataByteLength = data.byteLength; @@ -468,7 +467,7 @@ export const prepareInputOutputTensor = const stack = wasm.stackSave(); const dimsOffset = wasm.stackAlloc(ptrSize * dims.length); try { - dims.forEach((d, index) => wasm.setValue(dimsOffset + (index * ptrSize), d, '*')); + dims.forEach((d, index) => wasm.setValue(dimsOffset + (index * ptrSize), d, 'i32')); const tensor = wasm._OrtCreateTensor( tensorDataTypeStringToEnum(dataType), rawData, dataByteLength, dimsOffset, dims.length, dataLocationStringToEnum(location)); @@ -532,17 +531,13 @@ export const run = async( enableGraphCapture); } - let inputValuesIndex = inputValuesOffset / ptrSize; - let inputNamesIndex = inputNamesOffset / ptrSize; - let outputValuesIndex = outputValuesOffset / ptrSize; - let outputNamesIndex = outputNamesOffset / ptrSize; for (let i = 0; i < inputCount; i++) { - wasm.HEAPU64[inputValuesIndex++] = BigInt(inputTensorHandles[i]); - wasm.HEAPU64[inputNamesIndex++] = BigInt(inputNamesUTF8Encoded[inputIndices[i]]); + 
wasm.setValue(inputValuesOffset + i * ptrSize, inputTensorHandles[i], 'i64'); + wasm.setValue(inputNamesOffset + i * ptrSize, inputNamesUTF8Encoded[inputIndices[i]], 'i64'); } for (let i = 0; i < outputCount; i++) { - wasm.HEAPU64[outputValuesIndex++] = BigInt(outputTensorHandles[i]); - wasm.HEAPU64[outputNamesIndex++] = BigInt(outputNamesUTF8Encoded[outputIndices[i]]); + wasm.setValue(outputValuesOffset + i * ptrSize, outputTensorHandles[i], 'i64'); + wasm.setValue(outputNamesOffset + i * ptrSize, outputNamesUTF8Encoded[outputIndices[i]], 'i64'); } if (!BUILD_DEFS.DISABLE_JSEP && ioBindingState && !inputOutputBound) { @@ -646,10 +641,10 @@ export const run = async( throw new Error('String tensor is not supported on GPU.'); } const stringData: string[] = []; - let dataIndex = dataOffset / ptrSize; for (let i = 0; i < size; i++) { - const offset = wasm.HEAPU32[dataIndex++]; - const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset; + const offset = wasm.getValue(dataOffset + i * ptrSize, '*'); + const nextOffset = wasm.getValue(dataOffset + (i + 1) * ptrSize, '*'); + const maxBytesToRead = i === size - 1 ? undefined : nextOffset - offset; stringData.push(wasm.UTF8ToString(offset, maxBytesToRead)); } output.push([type, dims, stringData, 'cpu']); diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts index c65178e2358d2..c6d70de983617 100644 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -62,12 +62,13 @@ const getModelInputOutputCount = (trainingSessionId: number, isEvalModel: boolea const wasm = getInstance(); const stack = wasm.stackSave(); try { - const dataOffset = wasm.stackAlloc(8); + const ptrSize = wasm.PTR_SIZE; + const dataOffset = wasm.stackAlloc(2 * ptrSize); if (wasm._OrtTrainingGetModelInputOutputCount) { const errorCode = - wasm._OrtTrainingGetModelInputOutputCount(trainingSessionId, dataOffset, dataOffset + 4, isEvalModel); + wasm._OrtTrainingGetModelInputOutputCount(trainingSessionId, dataOffset, dataOffset + ptrSize, isEvalModel); ifErrCodeCheckLastError(errorCode, 'Can\'t get session input/output count.'); - return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]]; + return [wasm.getValue(dataOffset, 'i32'), wasm.getValue(dataOffset + ptrSize, 'i32')]; } else { throw new Error(NO_TRAIN_FUNCS_MSG); } @@ -170,10 +171,10 @@ const createAndAllocateTensors = // moves to heap const wasm = getInstance(); - const valuesOffset = wasm.stackAlloc(count * 4); - let valuesIndex = valuesOffset / 4; + const ptrSize = wasm.PTR_SIZE; + const valuesOffset = wasm.stackAlloc(count * ptrSize); for (let i = 0; i < count; i++) { - wasm.HEAPU32[valuesIndex++] = tensorHandles[i]; + wasm.setValue(valuesOffset + i * ptrSize, tensorHandles[i], '*'); } return valuesOffset; @@ -191,10 +192,11 @@ const moveOutputToTensorMetadataArr = (outputValuesOffset: number, outputCount: number, outputTensorHandles: number[], outputTensors: Array) => { const wasm = getInstance(); + const ptrSize = wasm.PTR_SIZE; const output: TensorMetadata[] = []; for (let i = 0; i < outputCount; i++) { - const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i]; + const tensor = wasm.getValue(outputValuesOffset + i * ptrSize, '*'); if (tensor === outputTensorHandles[i]) { // output tensor is pre-allocated. no need to copy data. 
output.push(outputTensors[i]!); @@ -211,14 +213,13 @@ const moveOutputToTensorMetadataArr = tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12); ifErrCodeCheckLastError(errorCode, `Can't access output tensor data on index ${i}.`); - let tensorDataIndex = tensorDataOffset / 4; - const dataType = wasm.HEAPU32[tensorDataIndex++]; - dataOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsLength = wasm.HEAPU32[tensorDataIndex++]; + const dataType = wasm.getValue(tensorDataOffset, '*'); + dataOffset = wasm.getValue(tensorDataOffset + ptrSize, '*'); + const dimsOffset = wasm.getValue(tensorDataOffset + 2 * ptrSize, '*'); + const dimsLength = wasm.getValue(tensorDataOffset + 3 * ptrSize, '*'); const dims = []; for (let i = 0; i < dimsLength; i++) { - dims.push(wasm.HEAPU32[dimsOffset / 4 + i]); + dims.push(wasm.getValue(dimsOffset + i * ptrSize, '*')); } wasm._OrtFree(dimsOffset); @@ -227,10 +228,10 @@ const moveOutputToTensorMetadataArr = if (type === 'string') { const stringData: string[] = []; - let dataIndex = dataOffset / 4; for (let i = 0; i < size; i++) { - const offset = wasm.HEAPU32[dataIndex++]; - const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset; + const offset = wasm.getValue(dataOffset + i * ptrSize, '*'); + const nextOffset = wasm.getValue(dataOffset + (i + 1) * ptrSize, '*'); + const maxBytesToRead = i === size - 1 ? undefined : nextOffset - offset; stringData.push(wasm.UTF8ToString(offset, maxBytesToRead)); } output.push([type, dims, stringData, 'cpu']); @@ -396,14 +397,14 @@ export const runEvalStep = async( export const getParametersSize = (trainingSessionId: number, trainableOnly: boolean): number => { const wasm = getInstance(); const stack = wasm.stackSave(); - + const ptrSize = wasm.PTR_SIZE; try { - const sizeOffset = wasm.stackAlloc(4); + const sizeOffset = wasm.stackAlloc(ptrSize); if (wasm._OrtTrainingGetParametersSize) { const errorCode = wasm._OrtTrainingGetParametersSize(trainingSessionId, sizeOffset, trainableOnly); ifErrCodeCheckLastError(errorCode, 'Can\'t get parameters size'); - return wasm.HEAP32[sizeOffset / 4]; + return wasm.getValue(sizeOffset, '*'); } else { throw new Error(NO_TRAIN_FUNCS_MSG); } @@ -432,7 +433,7 @@ export const getContiguousParameters = const dimsOffset = wasm.stackAlloc(4); const dimsIndex = dimsOffset / 4; - wasm.HEAP32[dimsIndex] = parametersSize; + wasm.setValue(dimsIndex, parametersSize, '*'); try { // wraps allocated array in a tensor @@ -488,8 +489,8 @@ export const loadParametersBuffer = wasm.HEAPU8.set(buffer, bufferOffset); // allocates and handles moving dimensions information to WASM memory - const dimsOffset = wasm.stackAlloc(4); - wasm.HEAP32[dimsOffset / 4] = bufferCount; + const dimsOffset = wasm.stackAlloc(wasm.PTR_SIZE); + wasm.setValue(dimsOffset, bufferCount, '*'); const dimsLength = 1; let tensor = 0; From edcaa6487dc47ec596dfa07a62d43c340adc2274 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Fri, 12 Jul 2024 17:21:30 -0700 Subject: [PATCH 11/45] Use uintptr_t instead of uint32_t. 
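The blob handed to JsepOutput is built on the JS side out of pointer-sized slots (one slot for the dim count followed by one slot per dimension, as in the ComputeContextImpl.output change earlier in this series), so the C++ side should step through it in uintptr_t units rather than uint32_t. A rough sketch of the writer side, with the module shape narrowed for illustration:

    interface WasmStackWriter {
      PTR_SIZE: number;
      stackAlloc(size: number): number;
      setValue(ptr: number, value: number, type: string): void;
    }

    // Packs [dims.length, dims[0], dims[1], ...] into pointer-sized slots; the returned
    // offset is what gets passed to _JsepOutput(kernelContext, outputIndex, data).
    function packOutputDims(module: WasmStackWriter, dims: readonly number[]): number {
      const ptrSize = module.PTR_SIZE;
      const data = module.stackAlloc((1 + dims.length) * ptrSize);
      module.setValue(data, dims.length, '*');
      for (let i = 0; i < dims.length; i++) {
        module.setValue(data + ptrSize * (i + 1), dims[i], '*');
      }
      return data;
    }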
--- onnxruntime/core/providers/js/js_export.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/js/js_export.cc b/onnxruntime/core/providers/js/js_export.cc index 2402bb33ce9d0..f99e90bcb13f6 100644 --- a/onnxruntime/core/providers/js/js_export.cc +++ b/onnxruntime/core/providers/js/js_export.cc @@ -6,8 +6,8 @@ #include "core/framework/op_kernel.h" const void* JsepOutput(void* context, int index, const void* data) { - const uint32_t* data_offset = reinterpret_cast(data); - uint32_t dim = *data_offset++; + const uintptr_t* data_offset = reinterpret_cast(data); + uintptr_t dim = *data_offset++; size_t dim_size = static_cast(dim); std::vector dims(dim_size); for (size_t i = 0; i < dim_size; i++) { From 05d0426024c6b4fe9780717a03f4b78d290b18f2 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Fri, 12 Jul 2024 17:53:54 -0700 Subject: [PATCH 12/45] Removed WASM_MEMORY64 macro --- cmake/onnxruntime_webassembly.cmake | 4 ++-- onnxruntime/core/providers/js/js_kernel.h | 28 +---------------------- onnxruntime/wasm/api.cc | 5 ---- 3 files changed, 3 insertions(+), 34 deletions(-) diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index f1f5e3ad4aa0e..e1b86f2367a51 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -223,8 +223,8 @@ else() target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s MEMORY64=1" ) - string(APPEND CMAKE_C_FLAGS " -DWASM_MEMORY64 -sMEMORY64 -Wno-experimental") - string(APPEND CMAKE_CXX_FLAGS " -DWASM_MEMORY64 -sMEMORY64 -Wno-experimental") + string(APPEND CMAKE_C_FLAGS " -sMEMORY64 -Wno-experimental") + string(APPEND CMAKE_CXX_FLAGS " -sMEMORY64 -Wno-experimental") set(SMEMORY_FLAG "-sMEMORY64") target_compile_options(onnx PRIVATE ${SMEMORY_FLAG} -Wno-experimental) diff --git a/onnxruntime/core/providers/js/js_kernel.h b/onnxruntime/core/providers/js/js_kernel.h index e77ebb9d06559..aeaf40f957c70 100644 --- a/onnxruntime/core/providers/js/js_kernel.h +++ b/onnxruntime/core/providers/js/js_kernel.h @@ -110,28 +110,17 @@ class JsKernel : public OpKernel { temp_data_size += sizeof(size_t) * 3; } } -#ifdef WASM_MEMORY64 uintptr_t* p_serialized_kernel_context = reinterpret_cast(alloc->Alloc(temp_data_size)); -#else - uint32_t* p_serialized_kernel_context = reinterpret_cast(alloc->Alloc(temp_data_size)); -#endif if (p_serialized_kernel_context == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to allocate memory for serialized kernel context."); } -#ifdef WASM_MEMORY64 p_serialized_kernel_context[0] = reinterpret_cast(context); p_serialized_kernel_context[1] = static_cast(context->InputCount()); p_serialized_kernel_context[2] = static_cast(context->OutputCount()); p_serialized_kernel_context[3] = reinterpret_cast(custom_data_ptr); p_serialized_kernel_context[4] = static_cast(custom_data_size); -#else - p_serialized_kernel_context[0] = reinterpret_cast(context); - p_serialized_kernel_context[1] = static_cast(context->InputCount()); - p_serialized_kernel_context[2] = static_cast(context->OutputCount()); - p_serialized_kernel_context[3] = reinterpret_cast(custom_data_ptr); - p_serialized_kernel_context[4] = static_cast(custom_data_size); -#endif + size_t index = 5; for (int i = 0; i < context->InputCount(); i++) { const auto* input_ptr = context->Input(i); @@ -142,21 +131,12 @@ class JsKernel : public OpKernel { p_serialized_kernel_context[index++] = 0; continue; } -#ifdef WASM_MEMORY64 p_serialized_kernel_context[index++] = 
static_cast(input_ptr->GetElementType()); p_serialized_kernel_context[index++] = reinterpret_cast(input_ptr->DataRaw()); p_serialized_kernel_context[index++] = static_cast(input_ptr->Shape().NumDimensions()); for (size_t d = 0; d < input_ptr->Shape().NumDimensions(); d++) { p_serialized_kernel_context[index++] = static_cast(input_ptr->Shape()[d]); } -#else - p_serialized_kernel_context[index++] = static_cast(input_ptr->GetElementType()); - p_serialized_kernel_context[index++] = reinterpret_cast(input_ptr->DataRaw()); - p_serialized_kernel_context[index++] = static_cast(input_ptr->Shape().NumDimensions()); - for (size_t d = 0; d < input_ptr->Shape().NumDimensions(); d++) { - p_serialized_kernel_context[index++] = static_cast(input_ptr->Shape()[d]); - } -#endif } #ifndef NDEBUG @@ -220,15 +200,9 @@ class JsKernel : public OpKernel { return status; } -#ifdef WASM_MEMORY64 intptr_t status_code = EM_ASM_INT( { return Module.jsepRunKernel($0, $1, Module.jsepSessionState.sessionHandle, Module.jsepSessionState.errors); }, this, reinterpret_cast(p_serialized_kernel_context)); -#else - int status_code = EM_ASM_INT( - { return Module.jsepRunKernel($0, $1, Module.jsepSessionState.sessionHandle, Module.jsepSessionState.errors); }, - this, reinterpret_cast(p_serialized_kernel_context)); -#endif LOGS_DEFAULT(VERBOSE) << "outputs = " << context->OutputCount() << ". Y.data=" << (size_t)(context->Output(0)->DataRaw()) << "."; diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index b86805fcd0239..aefbdab0215ea 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -27,11 +27,6 @@ enum DataLocation { }; static_assert(sizeof(const char*) == sizeof(size_t), "size of a pointer and a size_t value should be the same."); -#ifdef WASM_MEMORY64 -static_assert(sizeof(size_t) == 8, "size of size_t should be 8 in this build (wasm64)."); -#else -static_assert(sizeof(size_t) == 4, "size of size_t should be 4 in this build (wasm32)."); -#endif OrtErrorCode CheckStatus(OrtStatusPtr status) { if (status) { From af8a685fa94149e93fe781d3a3c0094852fb1f90 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Fri, 12 Jul 2024 22:45:51 -0700 Subject: [PATCH 13/45] Clean up --- cmake/CMakeLists.txt | 2 +- cmake/onnxruntime_webassembly.cmake | 9 +-------- js/web/lib/index.ts | 8 ++++---- js/web/lib/wasm/jsep/init.ts | 18 +++++++++--------- js/web/lib/wasm/wasm-core-impl.ts | 8 ++++---- 5 files changed, 19 insertions(+), 26 deletions(-) diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 476745c39aee9..aa700418a0a94 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -242,7 +242,7 @@ option(onnxruntime_ENABLE_TRITON "Enable Triton" OFF) # composable kernel is managed automatically, unless user want to explicitly disable it, it should not be manually set option(onnxruntime_USE_COMPOSABLE_KERNEL "Enable composable kernel for ROCm EP" ON) -option(onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE "Enable ck_tile for composable kernel" ON) +cmake_dependent_option(onnxruntime_USE_COMPOSABLE_KERNEL_CK_TILE "Enable ck_tile for composable kernel" ON "onnxruntime_USE_COMPOSABLE_KERNEL" OFF) option(onnxruntime_USE_ROCBLAS_EXTENSION_API "Enable rocblas tuning for ROCm EP" OFF) option(onnxruntime_USE_TRITON_KERNEL "Enable triton compiled kernel" OFF) option(onnxruntime_BUILD_KERNEL_EXPLORER "Build Kernel Explorer for testing and profiling GPU kernels" OFF) diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index e1b86f2367a51..35e0f219b7894 100644 --- 
a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -218,7 +218,6 @@ else() endif() if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) - set(ASYNCIFY 2) set(MAXIMUM_MEMORY "17179869184") target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s MEMORY64=1" @@ -255,7 +254,6 @@ else() --post-js "${ONNXRUNTIME_ROOT}/wasm/js_post_js_64.js" ) else () - set(ASYNCIFY 1) set(MAXIMUM_MEMORY "4294967296") target_link_options(onnxruntime_webassembly PRIVATE --post-js "${ONNXRUNTIME_ROOT}/wasm/js_post_js.js" @@ -295,9 +293,6 @@ else() target_link_options(onnxruntime_webassembly PRIVATE "SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js\"" "SHELL:-s ASYNCIFY=1" - #"SHELL:-s JSPI" - #"SHELL:-s ASYNCIFY_IGNORE_INDIRECT=1" - #"SHELL:-s WASM_BIGINT" "SHELL:-s ASYNCIFY_STACK_SIZE=65536" "SHELL:-s ASYNCIFY_EXPORTS=['OrtRun']" ) @@ -336,9 +331,7 @@ else() endif() # Set link flag to enable exceptions support, this will override default disabling exception throwing behavior when disable exceptions. - if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) - # target_link_options(onnxruntime_webassembly PRIVATE "-fwasm-exceptions") - else() + if (NOT onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s DISABLE_EXCEPTION_THROWING=0") endif() diff --git a/js/web/lib/index.ts b/js/web/lib/index.ts index ee4cc0067727b..86c05b9a2fa15 100644 --- a/js/web/lib/index.ts +++ b/js/web/lib/index.ts @@ -23,11 +23,11 @@ if (!BUILD_DEFS.DISABLE_WASM) { const wasmBackend = BUILD_DEFS.DISABLE_TRAINING ? require('./backend-wasm-inference').wasmBackend : require('./backend-wasm-training').wasmBackend; if (!BUILD_DEFS.DISABLE_JSEP) { - registerBackend('webgpu', wasmBackend, 1); - registerBackend('webnn', wasmBackend, 1); + registerBackend('webgpu', wasmBackend, 5); + registerBackend('webnn', wasmBackend, 5); } - registerBackend('cpu', wasmBackend, 1); - registerBackend('wasm', wasmBackend, 1); + registerBackend('cpu', wasmBackend, 10); + registerBackend('wasm', wasmBackend, 10); } Object.defineProperty(env.versions, 'web', {value: version, enumerable: true}); diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index ee10ad52340fc..b202f3d772f94 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -73,20 +73,20 @@ class ComputeContextImpl implements ComputeContext { // extract context data const ptrSize = module.PTR_SIZE; let dataIndex = module.PTR_SIZE === 8 ? 
(contextDataOffset / 2 ** 3) : (contextDataOffset >> 2); - this.opKernelContext = module.getValue(dataIndex++ * ptrSize, 'i32'); - const inputCount = module.getValue(dataIndex++ * ptrSize, 'i32'); - this.outputCount = module.getValue(dataIndex++ * ptrSize, 'i32'); - this.customDataOffset = module.getValue(dataIndex++ * ptrSize, 'i32'); - this.customDataSize = module.getValue(dataIndex++ * ptrSize, 'i32'); + this.opKernelContext = module.getValue(ptrSize * dataIndex++, 'i32'); + const inputCount = module.getValue(ptrSize * dataIndex++, 'i32'); + this.outputCount = module.getValue(ptrSize * dataIndex++, 'i32'); + this.customDataOffset = module.getValue(ptrSize * dataIndex++, 'i32'); + this.customDataSize = module.getValue(ptrSize * dataIndex++, 'i32'); const inputs: TensorView[] = []; for (let i = 0; i < inputCount; i++) { - const dataType = module.getValue(dataIndex++ * ptrSize, 'i32'); - const data = module.getValue(dataIndex++ * ptrSize, '*'); - const dim = module.getValue(dataIndex++ * ptrSize, 'i32'); + const dataType = module.getValue(ptrSize * dataIndex++, 'i32'); + const data = module.getValue(ptrSize * dataIndex++, '*'); + const dim = module.getValue(ptrSize * dataIndex++, 'i32'); const dims: number[] = []; for (let d = 0; d < dim; d++) { - dims.push(module.getValue(dataIndex++ * ptrSize, 'i32')); + dims.push(module.getValue(ptrSize * dataIndex++, 'i32')); } inputs.push(new TensorViewImpl(module, dataType, data, dims)); } diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 33ecdc12f13cf..6a777c596537f 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -532,12 +532,12 @@ export const run = async( } for (let i = 0; i < inputCount; i++) { - wasm.setValue(inputValuesOffset + i * ptrSize, inputTensorHandles[i], 'i64'); - wasm.setValue(inputNamesOffset + i * ptrSize, inputNamesUTF8Encoded[inputIndices[i]], 'i64'); + wasm.setValue(inputValuesOffset + i * ptrSize, inputTensorHandles[i], '*'); + wasm.setValue(inputNamesOffset + i * ptrSize, inputNamesUTF8Encoded[inputIndices[i]], '*'); } for (let i = 0; i < outputCount; i++) { - wasm.setValue(outputValuesOffset + i * ptrSize, outputTensorHandles[i], 'i64'); - wasm.setValue(outputNamesOffset + i * ptrSize, outputNamesUTF8Encoded[outputIndices[i]], 'i64'); + wasm.setValue(outputValuesOffset + i * ptrSize, outputTensorHandles[i], '*'); + wasm.setValue(outputNamesOffset + i * ptrSize, outputNamesUTF8Encoded[outputIndices[i]], '*'); } if (!BUILD_DEFS.DISABLE_JSEP && ioBindingState && !inputOutputBound) { From 757229d436b0e5b7f5bea8fb2c86d30ac8838bf0 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Sun, 14 Jul 2024 16:35:56 -0700 Subject: [PATCH 14/45] Fix OrtRun integer arguments type. 
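Assuming the same reading of the conversion string as for OrtCreateTensor above (first character for the return value, then one per argument, 'p' pointer-sized and '_' plain i32), the updated OrtRun:_ppp_p_pp entry matches the new C signature in which the two count arguments are plain ints. A hypothetical TypeScript view of the export, for illustration only:

    // OrtRun:_ppp_p_pp  =>  int return, then arguments: p p p _ p _ p p
    interface OrtRunExport {
      _OrtRun(sessionHandle: number,    // p
              inputNamesPtr: number,    // p  (const char**)
              inputsPtr: number,        // p  (ort_tensor_handle_t*)
              inputCount: number,       // _  (now int instead of size_t)
              outputNamesPtr: number,   // p  (const char**)
              outputCount: number,      // _  (now int instead of size_t)
              outputsPtr: number,       // p  (ort_tensor_handle_t*)
              runOptions: number        // p
             ): number;                 // _  (ORT error code)
    }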
--- cmake/onnxruntime_webassembly.cmake | 4 ++-- onnxruntime/wasm/api.cc | 4 ++-- onnxruntime/wasm/api.h | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index 35e0f219b7894..556f27b17f69b 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -271,7 +271,7 @@ else() "SHELL:-s VERBOSE=0" "SHELL:-s FILESYSTEM=0" "SHELL:-s INCOMING_MODULE_JS_API=[preRun,locateFile,arguments,onExit,wasmMemory,buffer,instantiateWasm,mainScriptUrlOrBlob]" - "SHELL:-s WASM_BIGINT=1" + "SHELL:-s WASM_BIGINT" ${WASM_API_EXCEPTION_CATCHING} --no-entry "SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre.js\"" @@ -279,7 +279,7 @@ else() if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s ERROR_ON_UNDEFINED_SYMBOLS=0" - "SHELL:-s SIGNATURE_CONVERSIONS=OrtRun:_pppppppp,OrtGetTensorData:_ppppp,OrtCreateTensor:p_pppp_,OrtCreateSession:pppp,OrtReleaseSession:_p,OrtGetInputOutputCount:pppp,OrtCreateSessionOptions:pp__p_ppppp,OrtAddSessionConfigEntry:pppp,OrtReleaseSessionOptions:_p,OrtAppendExecutionProvider:ppp,OrtAddSessionConfigEntry:pppp,OrtGetInputName:ppp,OrtGetOutputName:ppp,OrtCreateRunOptions:ppp_p,OrtReleaseRunOptions:pp,OrtReleaseTensor:_p,OrtFree:_p,OrtGetLastError:_pp,JsepOutput:pp_p,JsepGetNodeName:pp,JsepOutput:pp_p" + "SHELL:-s SIGNATURE_CONVERSIONS=OrtRun:_ppp_p_pp,OrtGetTensorData:_ppppp,OrtCreateTensor:p_pppp_,OrtCreateSession:pppp,OrtReleaseSession:_p,OrtGetInputOutputCount:pppp,OrtCreateSessionOptions:pp__p_ppppp,OrtAddSessionConfigEntry:pppp,OrtReleaseSessionOptions:_p,OrtAppendExecutionProvider:ppp,OrtAddSessionConfigEntry:pppp,OrtGetInputName:ppp,OrtGetOutputName:ppp,OrtCreateRunOptions:ppp_p,OrtReleaseRunOptions:pp,OrtReleaseTensor:_p,OrtFree:_p,OrtGetLastError:_pp,JsepOutput:pp_p,JsepGetNodeName:pp,JsepOutput:pp_p" ) endif () set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre.js) diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index aefbdab0215ea..566cd10aa287f 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -470,8 +470,8 @@ int OrtRunWithBinding(OrtSession* session, } int OrtRun(OrtSession* session, - const char** input_names, const ort_tensor_handle_t* inputs, size_t input_count, - const char** output_names, size_t output_count, ort_tensor_handle_t* outputs, + const char** input_names, const ort_tensor_handle_t* inputs, int input_count, + const char** output_names, int output_count, ort_tensor_handle_t* outputs, OrtRunOptions* run_options) { return CHECK_STATUS(Run, session, run_options, input_names, inputs, input_count, output_names, output_count, outputs); } diff --git a/onnxruntime/wasm/api.h b/onnxruntime/wasm/api.h index f8b5dc49fb875..ffba89c4fe4f9 100644 --- a/onnxruntime/wasm/api.h +++ b/onnxruntime/wasm/api.h @@ -280,9 +280,9 @@ int EMSCRIPTEN_KEEPALIVE OrtRunWithBinding(ort_session_handle_t session, int EMSCRIPTEN_KEEPALIVE OrtRun(ort_session_handle_t session, const char** input_names, const ort_tensor_handle_t* inputs, - size_t input_count, + int input_count, const char** output_names, - size_t output_count, + int output_count, ort_tensor_handle_t* outputs, ort_run_options_handle_t run_options); From 343f81234af776520c7f3a12692ba601b6258a99 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Tue, 16 Jul 2024 12:08:29 -0700 Subject: [PATCH 15/45] Added Number type conversions. 
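The Number() wrappers are needed because, with -sMEMORY64, pointer and size values cross the JS boundary as BigInt, and BigInt cannot be mixed with plain numbers in arithmetic or passed as typed-array offsets without an explicit conversion. A small self-contained illustration (the helper name is made up):

    // Without the conversions this throws when src/size arrive as BigInt:
    // mixing BigInt and Number, or passing BigInt to subarray(), is a TypeError.
    function sliceWasmHeap(heapU8: Uint8Array, src: number | bigint, size: number | bigint): Uint8Array {
      const begin = Number(src);
      const length = Number(size);
      return heapU8.subarray(begin, begin + length);
    }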
--- js/web/lib/wasm/jsep/init.ts | 11 ++++++----- js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts | 2 +- onnxruntime/core/providers/js/operators/conv.h | 4 ++-- .../core/providers/js/operators/conv_transpose.h | 8 ++++---- 4 files changed, 13 insertions(+), 12 deletions(-) diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index b202f3d772f94..d237bb850c054 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -197,14 +197,14 @@ export const init = LOG_DEBUG( 'verbose', () => `[WebGPU] jsepCopyGpuToGpu: src=${Number(src)}, dst=${Number(dst)}, size=${Number(size)}`); - backend.memcpy(src, dst); + backend.memcpy(Number(src), Number(dst)); } else { LOG_DEBUG( 'verbose', () => `[WebGPU] jsepCopyCpuToGpu: dataOffset=${Number(src)}, gpuDataId=${Number(dst)}, size=${ Number(size)}`); const data = module.HEAPU8.subarray(Number(src >>> 0), Number(src >>> 0) + Number(size)); - backend.upload(dst, data); + backend.upload(Number(dst), data); } }, @@ -216,12 +216,13 @@ export const init = () => `[WebGPU] jsepCopyGpuToCpu: gpuDataId=${gpuDataId}, dataOffset=${dataOffset}, size=${size}`); await backend.download( - gpuDataId, () => module.HEAPU8.subarray(Number(dataOffset >>> 0), Number((dataOffset >>> 0) + size))); + Number(gpuDataId), + () => module.HEAPU8.subarray(Number(dataOffset) >>> 0, Number(dataOffset + size) >>> 0)); }, // jsepCreateKernel (kernelType: string, kernelId: number, attribute: unknown) => backend.createKernel( - kernelType, kernelId, attribute, module.UTF8ToString(module._JsepGetNodeName!(kernelId))), + kernelType, Number(kernelId), attribute, module.UTF8ToString(module._JsepGetNodeName!(Number(kernelId)))), // jsepReleaseKernel (kernel: number) => backend.releaseKernel(kernel), @@ -233,7 +234,7 @@ export const init = () => `[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${ contextDataOffset}`); const context = new ComputeContextImpl(module, backend, Number(contextDataOffset)); - return backend.computeKernel(kernel, context, errors); + return backend.computeKernel(Number(kernel), context, errors); }, // jsepCaptureBegin () => backend.captureBegin(), diff --git a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts index aa731757651a9..bbe65786aa41b 100644 --- a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts @@ -369,7 +369,7 @@ class GpuDataManagerImpl implements GpuDataManager { } async download(id: GpuDataId, getTargetBuffer: () => Uint8Array): Promise { - const cachedData = this.storageCache.get(id); + const cachedData = this.storageCache.get(Number(id)); if (!cachedData) { throw new Error('data does not exist'); } diff --git a/onnxruntime/core/providers/js/operators/conv.h b/onnxruntime/core/providers/js/operators/conv.h index a471044597b02..f73b7ebc850e7 100644 --- a/onnxruntime/core/providers/js/operators/conv.h +++ b/onnxruntime/core/providers/js/operators/conv.h @@ -55,8 +55,8 @@ class ConvBase : public JsKernel { "dilations" : $2 ? Array.from(HEAP32.subarray(Number($2), Number($3))) : [], "group" : $4, "kernel_shape" : $5 ? Array.from(HEAP32.subarray(Number($5), Number($6))) : [], - "pads" : $7 ? Array.from(HEAP32.subarray($7, Number($8))) : [], - "strides" : $9 ? Array.from(HEAP32.subarray($9, Number($10))) : [], + "pads" : $7 ? Array.from(HEAP32.subarray(Number($7), Number($8))) : [], + "strides" : $9 ? 
Array.from(HEAP32.subarray(Number($9), Number($10))) : [],
                     "w_is_const" : () JS_ARROW(!!HEAP8[Number($12)]),
                     "activation" : UTF8ToString($13),
                     "activation_params" : $14 ? Array.from(HEAPF32.subarray(Number($14), Number($15))) : []
diff --git a/onnxruntime/core/providers/js/operators/conv_transpose.h b/onnxruntime/core/providers/js/operators/conv_transpose.h
index 1d3a5d75b68c4..5ff52e8fda4fa 100644
--- a/onnxruntime/core/providers/js/operators/conv_transpose.h
+++ b/onnxruntime/core/providers/js/operators/conv_transpose.h
@@ -99,11 +99,11 @@ class ConvTranspose : public JsKernel {
       JSEP_INIT_KERNEL_ATTRIBUTE(ConvTranspose, ({
                                    "format" : $7 ? "NHWC" : "NCHW",
                                    "autoPad" : $1,
-                                   "dilations" : Array.from(HEAP32.subarray(Number($2), Number(($2 >>> 0) + /* dialations_vec_size */ 2))),
+                                   "dilations" : Array.from(HEAP32.subarray(Number($2), (Number($2) >>> 0) + /* dialations_vec_size */ 2)),
                                    "group" : $3,
-                                   "kernelShape" : Array.from(HEAP32.subarray(Number($4), Number(($4 >>> 0) + /* kernel_shape_vec_size */ 2))),
-                                   "pads" : Array.from(HEAP32.subarray(Number($5), Number(($5 >>> 0) + /* pads_vec_size */ 4))),
-                                   "strides" : Array.from(HEAP32.subarray(Number($6), Number(($6 >>> 0) + /* strides_vec_size */ 2))),
+                                   "kernelShape" : Array.from(HEAP32.subarray(Number($4), (Number($4) >>> 0) + /* kernel_shape_vec_size */ 2)),
+                                   "pads" : Array.from(HEAP32.subarray(Number($5), (Number($5) >>> 0) + /* pads_vec_size */ 4)),
+                                   "strides" : Array.from(HEAP32.subarray(Number($6), (Number($6) >>> 0) + /* strides_vec_size */ 2)),
                                    "wIsConst" : () JS_ARROW(!!HEAP8[$8]),
                                    "outputPadding" : $9 ? Array.from(HEAP32.subarray(Number($9), Number($10))) : [],
                                    "outputShape" : $11 ? Array.from(HEAP32.subarray(Number($11), Number($12))) : [],

From 7862a8ff53c3239c94b71193bb5c28ef93021c4a Mon Sep 17 00:00:00 2001
From: Satya Jandhyala
Date: Tue, 16 Jul 2024 16:25:41 -0700
Subject: [PATCH 16/45] Number type conversion.

---
 onnxruntime/core/providers/js/data_transfer.cc | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/onnxruntime/core/providers/js/data_transfer.cc b/onnxruntime/core/providers/js/data_transfer.cc
index ebea041b80128..3809df2c82e4c 100644
--- a/onnxruntime/core/providers/js/data_transfer.cc
+++ b/onnxruntime/core/providers/js/data_transfer.cc
@@ -6,7 +6,7 @@
 #include "core/providers/js/data_transfer.h"

 EM_ASYNC_JS(void, jsepDownload, (const void* src_data, void* dst_data, size_t bytes), {
-  await Module.jsepCopyAsync(src_data, dst_data, bytes);
+  await Module.jsepCopyAsync(Number(src_data), Number(dst_data), Number(bytes));
 });

 namespace onnxruntime {
@@ -30,10 +30,10 @@ common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const {
   if (dst_device.Type() == OrtDevice::GPU) {
     if (src_device.Type() == OrtDevice::GPU) {
       // copy from GPU to GPU
-      EM_ASM({ Module.jsepCopy($0, $1, $2, true); }, src_data, dst_data, bytes);
+      EM_ASM({ Module.jsepCopy(Number($0), Number($1), Number($2), true); }, src_data, dst_data, bytes);
     } else {
       // copy from CPU to GPU
-      EM_ASM({ Module.jsepCopy($0, $1, $2); }, src_data, dst_data, bytes);
+      EM_ASM({ Module.jsepCopy(Number($0), Number($1), Number($2)); }, src_data, dst_data, bytes);
     }
   } else /* if (src_device.Type() == OrtDevice::GPU) */ {
     // copy from GPU to CPU

From 0b4b04087894af8b4854a46b34d7f55be5681c8c Mon Sep 17 00:00:00 2001
From: Satya Jandhyala
Date: Tue, 16 Jul 2024 23:10:29 -0700
Subject: [PATCH 17/45] Added ASYNCIFY_IMPORTS and signature conversions.
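The GPU-to-CPU copy resolves a JS promise, so the wasm side can only wait on it if Asyncify knows the import may suspend; listing the jsep copy helpers in ASYNCIFY_IMPORTS is what tells it so. A rough sketch of the kind of async import involved, with the backend and module shapes narrowed for illustration (the actual wiring in pre-jsep.js may differ):

    interface JsepBackendLike {
      download(gpuDataId: number, getTargetBuffer: () => Uint8Array): Promise<void>;
    }

    // The wasm caller is suspended by Asyncify until this promise resolves, i.e. until
    // the GPU buffer has been copied into the wasm heap at dst.
    const makeJsepDownload = (module: {HEAPU8: Uint8Array}, backend: JsepBackendLike) =>
        async (gpuDataId: number, dst: number, size: number): Promise<void> => {
          await backend.download(gpuDataId, () => module.HEAPU8.subarray(dst, dst + size));
        };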
--- cmake/onnxruntime_webassembly.cmake | 5 ++++- onnxruntime/core/providers/js/data_transfer.cc | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index 556f27b17f69b..d924b8aa7ac20 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -219,6 +219,7 @@ else() if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) set(MAXIMUM_MEMORY "17179869184") + set(ASYNCIFY 2) target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s MEMORY64=1" ) @@ -255,6 +256,7 @@ else() ) else () set(MAXIMUM_MEMORY "4294967296") + set(ASYNCIFY 1) target_link_options(onnxruntime_webassembly PRIVATE --post-js "${ONNXRUNTIME_ROOT}/wasm/js_post_js.js" ) @@ -279,7 +281,7 @@ else() if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s ERROR_ON_UNDEFINED_SYMBOLS=0" - "SHELL:-s SIGNATURE_CONVERSIONS=OrtRun:_ppp_p_pp,OrtGetTensorData:_ppppp,OrtCreateTensor:p_pppp_,OrtCreateSession:pppp,OrtReleaseSession:_p,OrtGetInputOutputCount:pppp,OrtCreateSessionOptions:pp__p_ppppp,OrtAddSessionConfigEntry:pppp,OrtReleaseSessionOptions:_p,OrtAppendExecutionProvider:ppp,OrtAddSessionConfigEntry:pppp,OrtGetInputName:ppp,OrtGetOutputName:ppp,OrtCreateRunOptions:ppp_p,OrtReleaseRunOptions:pp,OrtReleaseTensor:_p,OrtFree:_p,OrtGetLastError:_pp,JsepOutput:pp_p,JsepGetNodeName:pp,JsepOutput:pp_p" + "SHELL:-s SIGNATURE_CONVERSIONS=OrtRun:_ppp_p_pp,OrtGetTensorData:_ppppp,OrtCreateTensor:p_pppp_,OrtCreateSession:pppp,OrtReleaseSession:_p,OrtGetInputOutputCount:pppp,OrtCreateSessionOptions:pp__p_ppppp,OrtAddSessionConfigEntry:pppp,OrtReleaseSessionOptions:_p,OrtAppendExecutionProvider:ppp,OrtAddSessionConfigEntry:pppp,OrtGetInputName:ppp,OrtGetOutputName:ppp,OrtCreateRunOptions:ppp_p,OrtReleaseRunOptions:pp,OrtReleaseTensor:_p,OrtFree:_p,OrtGetLastError:_pp,JsepOutput:pp_p,JsepGetNodeName:pp,JsepOutput:pp_p,jsepCopy:_pp_,jsepCopyAsync:_pp_,jsepDownload:_pp_" ) endif () set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre.js) @@ -295,6 +297,7 @@ else() "SHELL:-s ASYNCIFY=1" "SHELL:-s ASYNCIFY_STACK_SIZE=65536" "SHELL:-s ASYNCIFY_EXPORTS=['OrtRun']" + "SHELL:-s ASYNCIFY_IMPORTS=['Module.jsepCopy','Module.jsepCopyAsync,jsepDownload']" ) set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js) endif() diff --git a/onnxruntime/core/providers/js/data_transfer.cc b/onnxruntime/core/providers/js/data_transfer.cc index 3809df2c82e4c..e18bad836a223 100644 --- a/onnxruntime/core/providers/js/data_transfer.cc +++ b/onnxruntime/core/providers/js/data_transfer.cc @@ -37,7 +37,7 @@ common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { } } else /* if (src_device.Type() == OrtDevice::GPU) */ { // copy from GPU to CPU - jsepDownload(src_data, dst_data, bytes); + EM_ASM({ Module.jsepCopyAsync(Number($0), Number($1), Number($2)); }, src_data, dst_data, bytes); } } From 786b58cdeb7eb841a56493ebc56939c2c3d411d4 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Tue, 16 Jul 2024 23:15:35 -0700 Subject: [PATCH 18/45] Removed unused settings. 
--- cmake/onnxruntime_webassembly.cmake | 2 -- 1 file changed, 2 deletions(-) diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index d924b8aa7ac20..90817ca48ec0a 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -219,7 +219,6 @@ else() if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) set(MAXIMUM_MEMORY "17179869184") - set(ASYNCIFY 2) target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s MEMORY64=1" ) @@ -256,7 +255,6 @@ else() ) else () set(MAXIMUM_MEMORY "4294967296") - set(ASYNCIFY 1) target_link_options(onnxruntime_webassembly PRIVATE --post-js "${ONNXRUNTIME_ROOT}/wasm/js_post_js.js" ) From 4523acc3c8045aa8656600d0932134e421bffa9d Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Wed, 17 Jul 2024 15:43:17 -0700 Subject: [PATCH 19/45] Added missing function in SIGNATURE_CONVERSIONS. --- cmake/onnxruntime_webassembly.cmake | 34 +++++++++++++++++-- .../core/providers/js/data_transfer.cc | 2 +- 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index 90817ca48ec0a..c90703cb0d2e8 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -271,15 +271,45 @@ else() "SHELL:-s VERBOSE=0" "SHELL:-s FILESYSTEM=0" "SHELL:-s INCOMING_MODULE_JS_API=[preRun,locateFile,arguments,onExit,wasmMemory,buffer,instantiateWasm,mainScriptUrlOrBlob]" - "SHELL:-s WASM_BIGINT" + "SHELL:-s WASM_BIGINT=1" ${WASM_API_EXCEPTION_CATCHING} --no-entry "SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre.js\"" ) if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) + set(SIGNATURE_CONVERSIONS "\ +OrtRun:_ppp_p_pp,\ +OrtGetTensorData:_ppppp,\ +OrtCreateTensor:p_pppp_,\ +OrtCreateSession:pppp,\ +OrtReleaseSession:_p,\ +OrtGetInputOutputCount:pppp,\ +OrtCreateSessionOptions:pp__p_ppppp,\ +OrtAddSessionConfigEntry:pppp,\ +OrtReleaseSessionOptions:_p,\ +OrtAppendExecutionProvider:ppp,\ +OrtAddSessionConfigEntry:pppp,\ +OrtGetInputName:ppp,\ +OrtGetOutputName:ppp,\ +OrtCreateRunOptions:ppp_p,\ +OrtReleaseRunOptions:pp,\ +OrtReleaseTensor:_p,\ +OrtFree:_p,\ +OrtCreateBinding:_p,\ +OrtBindInput:_ppp,\ +OrtBindOutput:_ppp,\ +OrtClearBoundOutputs:_p,\ +OrtReleaseBinding:_p,\ +OrtGetLastError:_pp,\ +JsepOutput:pp_p,\ +JsepGetNodeN:pp,\ +JsepOutput:pp_p,\ +jsepCopy:_pp_,\ +jsepCopyAsync:_pp_,\ +jsepDownload:_pp_") target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s ERROR_ON_UNDEFINED_SYMBOLS=0" - "SHELL:-s SIGNATURE_CONVERSIONS=OrtRun:_ppp_p_pp,OrtGetTensorData:_ppppp,OrtCreateTensor:p_pppp_,OrtCreateSession:pppp,OrtReleaseSession:_p,OrtGetInputOutputCount:pppp,OrtCreateSessionOptions:pp__p_ppppp,OrtAddSessionConfigEntry:pppp,OrtReleaseSessionOptions:_p,OrtAppendExecutionProvider:ppp,OrtAddSessionConfigEntry:pppp,OrtGetInputName:ppp,OrtGetOutputName:ppp,OrtCreateRunOptions:ppp_p,OrtReleaseRunOptions:pp,OrtReleaseTensor:_p,OrtFree:_p,OrtGetLastError:_pp,JsepOutput:pp_p,JsepGetNodeName:pp,JsepOutput:pp_p,jsepCopy:_pp_,jsepCopyAsync:_pp_,jsepDownload:_pp_" + "SHELL:-s SIGNATURE_CONVERSIONS='${SIGNATURE_CONVERSIONS}'" ) endif () set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre.js) diff --git a/onnxruntime/core/providers/js/data_transfer.cc b/onnxruntime/core/providers/js/data_transfer.cc index e18bad836a223..3809df2c82e4c 100644 --- a/onnxruntime/core/providers/js/data_transfer.cc +++ b/onnxruntime/core/providers/js/data_transfer.cc @@ -37,7 +37,7 @@ common::Status 
DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { } } else /* if (src_device.Type() == OrtDevice::GPU) */ { // copy from GPU to CPU - EM_ASM({ Module.jsepCopyAsync(Number($0), Number($1), Number($2)); }, src_data, dst_data, bytes); + jsepDownload(src_data, dst_data, bytes); } } From 4d563c35f9ee21a51a5b8b625afa9cb251c83e3d Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Wed, 17 Jul 2024 18:26:42 -0700 Subject: [PATCH 20/45] clean-up --- cmake/adjust_global_compile_flags.cmake | 5 +---- cmake/onnxruntime_webassembly.cmake | 4 ++-- js/web/lib/wasm/jsep/backend-webgpu.ts | 3 +-- js/web/lib/wasm/jsep/webgpu/types.ts | 1 - js/web/lib/wasm/wasm-types.ts | 2 -- 5 files changed, 4 insertions(+), 11 deletions(-) diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake index ac98403c70071..e12ac0ae4605d 100644 --- a/cmake/adjust_global_compile_flags.cmake +++ b/cmake/adjust_global_compile_flags.cmake @@ -52,10 +52,7 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") endif() if (onnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_CATCHING) - if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) - # string(APPEND CMAKE_C_FLAGS " -fwasm-exceptions") - # string(APPEND CMAKE_CXX_FLAGS " -fwasm-exceptions") - else() + if (NOT onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) string(APPEND CMAKE_C_FLAGS " -s DISABLE_EXCEPTION_CATCHING=0") string(APPEND CMAKE_CXX_FLAGS " -s DISABLE_EXCEPTION_CATCHING=0") endif() diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index c90703cb0d2e8..c42f8e0ce6a56 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -302,7 +302,7 @@ OrtClearBoundOutputs:_p,\ OrtReleaseBinding:_p,\ OrtGetLastError:_pp,\ JsepOutput:pp_p,\ -JsepGetNodeN:pp,\ +JsepGetNodeName:pp,\ JsepOutput:pp_p,\ jsepCopy:_pp_,\ jsepCopyAsync:_pp_,\ @@ -325,7 +325,7 @@ jsepDownload:_pp_") "SHELL:-s ASYNCIFY=1" "SHELL:-s ASYNCIFY_STACK_SIZE=65536" "SHELL:-s ASYNCIFY_EXPORTS=['OrtRun']" - "SHELL:-s ASYNCIFY_IMPORTS=['Module.jsepCopy','Module.jsepCopyAsync,jsepDownload']" + "SHELL:-s ASYNCIFY_IMPORTS=['Module.jsepCopy','Module.jsepCopyAsync','Module.jsepDownload']" ) set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js) endif() diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index faa08ccca38ee..c701cf3a6df85 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -219,7 +219,6 @@ export class WebGpuBackend { maxComputeWorkgroupSizeX: adapter.limits.maxComputeWorkgroupSizeX, maxComputeWorkgroupSizeY: adapter.limits.maxComputeWorkgroupSizeY, maxComputeWorkgroupSizeZ: adapter.limits.maxComputeWorkgroupSizeZ, - maxBindingsPerBindGroup: adapter.limits.maxBindingsPerBindGroup, }, requiredFeatures, }; @@ -450,7 +449,7 @@ export class WebGpuBackend { const isPersistent = validatedOutputIndices[i] === -2; const tensorView = (isTemporary || isPersistent) ? createIntermediateOutput(outputs[i].dataType, outputs[i].dims) : - createKernelOutput(outputs[i].outputIndex || validatedOutputIndices[i], outputs[i].dataType, outputs[i].dims); + createKernelOutput(validatedOutputIndices[i], outputs[i].dataType, outputs[i].dims); outputTensorViews.push(tensorView); // if tensor view data is 0, it means the output is zero-sized tensor, and there is no GPU data for it. 
if (tensorView.data === 0) { diff --git a/js/web/lib/wasm/jsep/webgpu/types.ts b/js/web/lib/wasm/jsep/webgpu/types.ts index 6e906cc8497ec..2a584fc0a2218 100644 --- a/js/web/lib/wasm/jsep/webgpu/types.ts +++ b/js/web/lib/wasm/jsep/webgpu/types.ts @@ -31,7 +31,6 @@ export interface GpuData { export interface TensorInfo { dims: readonly number[]; dataType: number; - outputIndex?: number; } export interface ProgramUniform { diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts index 8b29e24cb2143..9a4500822d457 100644 --- a/js/web/lib/wasm/wasm-types.ts +++ b/js/web/lib/wasm/wasm-types.ts @@ -209,8 +209,6 @@ export interface OrtTrainingAPIs { */ export interface OrtWasmModule extends EmscriptenModule, OrtInferenceAPIs, Partial, Partial { - HEAP64: BigInt64Array; - HEAPU64: BigUint64Array; PTR_SIZE: number; // #region emscripten functions stackSave(): number; From 5f504c5355e9ac4f20c6d1cea5534575999f5c19 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Thu, 18 Jul 2024 13:47:24 -0700 Subject: [PATCH 21/45] Miscellaneous edits. --- js/web/lib/wasm/jsep/init.ts | 2 +- onnxruntime/core/framework/tensorprotoutils.cc | 8 ++++---- onnxruntime/core/graph/model.cc | 9 +++++---- onnxruntime/core/providers/js/js_kernel.h | 2 +- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index d237bb850c054..11fa9dd62b8ec 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -217,7 +217,7 @@ export const init = await backend.download( Number(gpuDataId), - () => module.HEAPU8.subarray(Number(dataOffset) >>> 0, Number(dataOffset + size) >>> 0)); + () => module.HEAPU8.subarray(Number(dataOffset) >>> 0, Number(dataOffset) >>> 0 + Number(size))); }, // jsepCreateKernel diff --git a/onnxruntime/core/framework/tensorprotoutils.cc b/onnxruntime/core/framework/tensorprotoutils.cc index b989be2c0fb7c..305235e457f6a 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -1030,9 +1030,9 @@ Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& mo if (!fileData) { return 2; // File not found in preloaded files. } - const offset = $1 >>> 0; - const length = $2 >>> 0; - const buffer = $3 >>> 0; + const offset = Number($1 >>> 0); + const length = Number($2 >>> 0); + const buffer = Number($3 >>> 0); if (offset + length > fileData.byteLength) { return 3; // Out of bounds. @@ -1041,7 +1041,7 @@ Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& mo try { // Copy the file data (fileData,offset,length) into WebAssembly memory // (HEAPU8,buffer,length). 
- HEAPU8.set(fileData.subarray(Number(offset), Number(offset) + length), buffer); + HEAPU8.set(fileData.subarray(offset, offset + length), buffer); return 0; } catch { return 4; diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index b90ac73ef1e34..280c9f77dea79 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -552,12 +552,13 @@ static Status SaveModel(Model& model, const T& file_path) { model_proto.SerializeToArray(buffer, buffer_size); EM_ASM(({ - const buffer = $0; - const buffer_size = $1; + const buffer = Number($0); + const buffer_size = Number($1); const file_path = UTF8ToString($2); const bytes = new Uint8Array(buffer_size); - bytes.set(HEAPU8.subarray(Number(buffer), Number(buffer) + buffer_size)); - if (typeof process == 'object' && typeof process.versions == 'object' && typeof process.versions.node == 'string') { + bytes.set(HEAPU8.subarray(buffer, buffer + buffer_size)); + if (typeof process == 'object' && typeof process.versions == 'object' && + typeof process.versions.node == 'string') { // Node.js require('fs').writeFileSync(file_path, bytes); } else { diff --git a/onnxruntime/core/providers/js/js_kernel.h b/onnxruntime/core/providers/js/js_kernel.h index aeaf40f957c70..25809811acd0c 100644 --- a/onnxruntime/core/providers/js/js_kernel.h +++ b/onnxruntime/core/providers/js/js_kernel.h @@ -200,7 +200,7 @@ class JsKernel : public OpKernel { return status; } - intptr_t status_code = EM_ASM_INT( + int status_code = EM_ASM_INT( { return Module.jsepRunKernel($0, $1, Module.jsepSessionState.sessionHandle, Module.jsepSessionState.errors); }, this, reinterpret_cast(p_serialized_kernel_context)); From 5e07c97005a411d242caa9e64f41771da4f8bae0 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Fri, 19 Jul 2024 14:49:17 -0700 Subject: [PATCH 22/45] Use Number cast to jsepRunKernel --- cmake/onnxruntime_webassembly.cmake | 53 ++++++++++++++++++++--- onnxruntime/core/providers/js/js_kernel.h | 4 +- 2 files changed, 49 insertions(+), 8 deletions(-) diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index c42f8e0ce6a56..62afaec9fde07 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -243,13 +243,55 @@ else() target_compile_options(onnxruntime_flatbuffers PRIVATE ${SMEMORY_FLAG} -Wno-experimental) target_compile_options(onnxruntime_util PRIVATE ${SMEMORY_FLAG} -Wno-experimental) target_compile_options(re2 PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_base PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_hash PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_flags_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_flags_marshalling PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_flags_reflection PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_flags_config PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_flags_program_name PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_flags_private_handle_accessor PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_flags_commandlineflag PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_flags_commandlineflag_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) target_compile_options(absl_raw_hash_set PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - 
target_compile_options(absl_throw_delegate PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_hashtablez_sampler PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_hash PRIVATE ${SMEMORY_FLAG} -Wno-experimental) target_compile_options(absl_city PRIVATE ${SMEMORY_FLAG} -Wno-experimental) target_compile_options(absl_low_level_hash PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - + target_compile_options(absl_bad_variant_access PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_cord PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_cordz_info PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_cord_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_cordz_functions PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_exponential_biased PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_cordz_handle PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_crc_cord_state PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_crc32c PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_crc_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_crc_cpu_detect PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_bad_optional_access PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_str_format_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_synchronization PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_graphcycles_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_kernel_timeout_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_stacktrace PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_symbolize PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_debugging_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_demangle_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_demangle_rust PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_decode_rust_punycode PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_utf8_for_code_point PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_malloc_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_time PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_civil_time PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_time_zone PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_strings PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_int128 PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_strings_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_string_view PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_base PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_spinlock_wait PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_throw_delegate PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_raw_logging_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_log_severity PRIVATE ${SMEMORY_FLAG} 
-Wno-experimental) target_link_options(onnxruntime_webassembly PRIVATE --post-js "${ONNXRUNTIME_ROOT}/wasm/js_post_js_64.js" ) @@ -277,8 +319,7 @@ else() "SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre.js\"" ) if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) - set(SIGNATURE_CONVERSIONS "\ -OrtRun:_ppp_p_pp,\ + set(SIGNATURE_CONVERSIONS "OrtRun:_ppp_p_pp,\ OrtGetTensorData:_ppppp,\ OrtCreateTensor:p_pppp_,\ OrtCreateSession:pppp,\ diff --git a/onnxruntime/core/providers/js/js_kernel.h b/onnxruntime/core/providers/js/js_kernel.h index 25809811acd0c..5ed3b7f3e8131 100644 --- a/onnxruntime/core/providers/js/js_kernel.h +++ b/onnxruntime/core/providers/js/js_kernel.h @@ -200,8 +200,8 @@ class JsKernel : public OpKernel { return status; } - int status_code = EM_ASM_INT( - { return Module.jsepRunKernel($0, $1, Module.jsepSessionState.sessionHandle, Module.jsepSessionState.errors); }, + intptr_t status_code = EM_ASM_INT( + { return Module.jsepRunKernel(Number($0), Number($1), Module.jsepSessionState.sessionHandle, Module.jsepSessionState.errors); }, this, reinterpret_cast(p_serialized_kernel_context)); LOGS_DEFAULT(VERBOSE) << "outputs = " << context->OutputCount() << ". Y.data=" From c8b7d205421531594c6df3273e3a00870fa41a13 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Thu, 25 Jul 2024 18:04:12 -0700 Subject: [PATCH 23/45] Use uint32_t instead of size_t. --- onnxruntime/wasm/api.cc | 72 +++++++++++++++++++------------------ onnxruntime/wasm/api.h | 80 ++++++++++++++++++++--------------------- 2 files changed, 78 insertions(+), 74 deletions(-) diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index 566cd10aa287f..272851df6a4c3 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -97,15 +97,15 @@ void OrtGetLastError(int* error_code, const char** error_message) { *error_message = g_last_error_message.empty() ? 
nullptr : g_last_error_message.c_str(); } -OrtSessionOptions* OrtCreateSessionOptions(size_t graph_optimization_level, +OrtSessionOptions* OrtCreateSessionOptions(uint32_t graph_optimization_level, bool enable_cpu_mem_arena, bool enable_mem_pattern, - size_t execution_mode, + uint32_t execution_mode, bool enable_profiling, const char* /*profile_file_prefix*/, const char* log_id, - size_t log_severity_level, - size_t log_verbosity_level, + uint32_t log_severity_level, + uint32_t log_verbosity_level, const char* optimized_model_filepath) { OrtSessionOptions* session_options = nullptr; RETURN_NULLPTR_IF_ERROR(CreateSessionOptions, &session_options); @@ -179,7 +179,7 @@ void OrtReleaseSessionOptions(OrtSessionOptions* session_options) { Ort::GetApi().ReleaseSessionOptions(session_options); } -OrtSession* OrtCreateSession(void* data, size_t data_length, OrtSessionOptions* session_options) { +OrtSession* OrtCreateSession(void* data, uint32_t data_length, OrtSessionOptions* session_options) { #if defined(__EMSCRIPTEN_PTHREADS__) RETURN_NULLPTR_IF_ERROR(DisablePerSessionThreads, session_options); #else @@ -198,13 +198,17 @@ void OrtReleaseSession(OrtSession* session) { Ort::GetApi().ReleaseSession(session); } -int OrtGetInputOutputCount(OrtSession* session, size_t* input_count, size_t* output_count) { - RETURN_ERROR_CODE_IF_ERROR(SessionGetInputCount, session, input_count); - RETURN_ERROR_CODE_IF_ERROR(SessionGetOutputCount, session, output_count); +int OrtGetInputOutputCount(OrtSession* session, uint32_t* input_count, uint32_t* output_count) { + size_t input_count_tmp = 0; + size_t output_count_tmp = 0; + RETURN_ERROR_CODE_IF_ERROR(SessionGetInputCount, session, &input_count_tmp); + RETURN_ERROR_CODE_IF_ERROR(SessionGetOutputCount, session, &output_count_tmp); + *input_count = static_cast(input_count_tmp); + *output_count = static_cast(output_count_tmp); return ORT_OK; } -char* OrtGetInputName(OrtSession* session, size_t index) { +char* OrtGetInputName(OrtSession* session, uint32_t index) { OrtAllocator* allocator = nullptr; RETURN_NULLPTR_IF_ERROR(GetAllocatorWithDefaultOptions, &allocator); @@ -214,7 +218,7 @@ char* OrtGetInputName(OrtSession* session, size_t index) { : nullptr; } -char* OrtGetOutputName(OrtSession* session, size_t index) { +char* OrtGetOutputName(OrtSession* session, uint32_t index) { OrtAllocator* allocator = nullptr; RETURN_NULLPTR_IF_ERROR(GetAllocatorWithDefaultOptions, &allocator); @@ -231,7 +235,7 @@ void OrtFree(void* ptr) { } } -OrtValue* OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* dims, size_t dims_length, int data_location) { +OrtValue* OrtCreateTensor(int data_type, void* data, uint32_t data_length, uint32_t* dims, uint32_t dims_length, int data_location) { if (data_location != DATA_LOCATION_CPU && data_location != DATA_LOCATION_CPU_PINNED && data_location != DATA_LOCATION_GPU_BUFFER) { @@ -279,7 +283,7 @@ OrtValue* OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* } } -int OrtGetTensorData(OrtValue* tensor, size_t* data_type, void** data, size_t** dims, size_t* dims_length) { +int OrtGetTensorData(OrtValue* tensor, uint32_t* data_type, void** data, uint32_t** dims, uint32_t* dims_length) { ONNXType tensor_type; RETURN_ERROR_CODE_IF_ERROR(GetValueType, tensor, &tensor_type); if (tensor_type != ONNX_TYPE_TENSOR) { @@ -297,8 +301,8 @@ int OrtGetTensorData(OrtValue* tensor, size_t* data_type, void** data, size_t** OrtAllocator* allocator = nullptr; RETURN_ERROR_CODE_IF_ERROR(GetAllocatorWithDefaultOptions, &allocator); - 
size_t* p_dims = reinterpret_cast(allocator->Alloc(allocator, sizeof(size_t) * dims_len)); - REGISTER_AUTO_RELEASE_BUFFER(size_t, p_dims, allocator); + uint32_t* p_dims = reinterpret_cast(allocator->Alloc(allocator, sizeof(uint32_t) * dims_len)); + REGISTER_AUTO_RELEASE_BUFFER(uint32_t, p_dims, allocator); ONNXTensorElementDataType type; RETURN_ERROR_CODE_IF_ERROR(GetTensorElementType, info, &type); @@ -306,7 +310,7 @@ int OrtGetTensorData(OrtValue* tensor, size_t* data_type, void** data, size_t** std::vector shape(dims_len, 0); RETURN_ERROR_CODE_IF_ERROR(GetDimensions, info, shape.data(), shape.size()); for (size_t i = 0; i < dims_len; i++) { - p_dims[i] = static_cast(shape[i]); + p_dims[i] = static_cast(shape[i]); } if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) { @@ -359,8 +363,8 @@ void OrtReleaseTensor(OrtValue* tensor) { Ort::GetApi().ReleaseValue(tensor); } -OrtRunOptions* OrtCreateRunOptions(size_t log_severity_level, - size_t log_verbosity_level, +OrtRunOptions* OrtCreateRunOptions(uint32_t log_severity_level, + uint32_t log_verbosity_level, bool terminate, const char* tag) { OrtRunOptions* run_options = nullptr; @@ -444,7 +448,7 @@ void OrtReleaseBinding(OrtIoBinding* io_binding) { int OrtRunWithBinding(OrtSession* session, OrtIoBinding* io_binding, - size_t output_count, + uint32_t output_count, OrtValue** outputs, OrtRunOptions* run_options) { RETURN_ERROR_CODE_IF_ERROR(RunWithBinding, session, run_options, io_binding); @@ -470,8 +474,8 @@ int OrtRunWithBinding(OrtSession* session, } int OrtRun(OrtSession* session, - const char** input_names, const ort_tensor_handle_t* inputs, int input_count, - const char** output_names, int output_count, ort_tensor_handle_t* outputs, + const char** input_names, const ort_tensor_handle_t* inputs, uint32_t input_count, + const char** output_names, uint32_t output_count, ort_tensor_handle_t* outputs, OrtRunOptions* run_options) { return CHECK_STATUS(Run, session, run_options, input_names, inputs, input_count, output_names, output_count, outputs); } @@ -501,7 +505,7 @@ char* OrtEndProfiling(ort_session_handle_t session) { } while (false) ort_training_checkpoint_handle_t EMSCRIPTEN_KEEPALIVE OrtTrainingLoadCheckpoint(void* checkpoint_data_buffer, - size_t checkpoint_size) { + uint32_t checkpoint_size) { OrtCheckpointState* checkpoint_state = nullptr; return (CHECK_TRAINING_STATUS(LoadCheckpointFromBuffer, checkpoint_data_buffer, checkpoint_size, &checkpoint_state) == ORT_OK) @@ -516,11 +520,11 @@ void EMSCRIPTEN_KEEPALIVE OrtTrainingReleaseCheckpoint(ort_training_checkpoint_h ort_training_session_handle_t EMSCRIPTEN_KEEPALIVE OrtTrainingCreateSession(const ort_session_options_handle_t options, ort_training_checkpoint_handle_t training_checkpoint_state_handle, void* train_model, - size_t train_size, + uint32_t train_size, void* eval_model, - size_t eval_size, + uint32_t eval_size, void* optimizer_model, - size_t optimizer_size) { + uint32_t optimizer_size) { OrtTrainingSession* training_session = nullptr; return (CHECK_TRAINING_STATUS(CreateTrainingSessionFromBuffer, g_env, options, training_checkpoint_state_handle, train_model, train_size, @@ -536,9 +540,9 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingLazyResetGrad(ort_training_session_handle_t int EMSCRIPTEN_KEEPALIVE OrtTrainingRunTrainStep(ort_training_session_handle_t training_handle, ort_tensor_handle_t* inputs, - size_t input_count, + uint32_t input_count, ort_tensor_handle_t* outputs, - size_t output_count, + uint32_t output_count, ort_run_options_handle_t options) { return 
CHECK_TRAINING_STATUS(TrainStep, training_handle, options, input_count, inputs, output_count, outputs); } @@ -550,37 +554,37 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingOptimizerStep(ort_training_session_handle_t int EMSCRIPTEN_KEEPALIVE OrtTrainingEvalStep(ort_training_session_handle_t training_handle, ort_tensor_handle_t* inputs, - size_t input_count, + uint32_t input_count, ort_tensor_handle_t* outputs, - size_t output_count, + uint32_t output_count, ort_run_options_handle_t options) { return CHECK_TRAINING_STATUS(EvalStep, training_handle, options, input_count, inputs, output_count, outputs); } int EMSCRIPTEN_KEEPALIVE OrtTrainingGetParametersSize(ort_training_session_handle_t training_handle, - size_t* param_size, + uint32_t* param_size, bool trainable_only) { return CHECK_TRAINING_STATUS(GetParametersSize, training_handle, param_size, trainable_only); } int EMSCRIPTEN_KEEPALIVE OrtTrainingCopyParametersToBuffer(ort_training_session_handle_t training_handle, ort_tensor_handle_t parameters_buffer, - size_t parameter_count, + uint32_t parameter_count, bool trainable_only) { return CHECK_TRAINING_STATUS(CopyParametersToBuffer, training_handle, parameters_buffer, trainable_only); } int EMSCRIPTEN_KEEPALIVE OrtTrainingCopyParametersFromBuffer(ort_training_session_handle_t training_handle, ort_tensor_handle_t parameters_buffer, - size_t parameter_count, + uint32_t parameter_count, bool trainable_only) { return CHECK_TRAINING_STATUS(CopyBufferToParameters, training_handle, parameters_buffer, trainable_only); } int EMSCRIPTEN_KEEPALIVE OrtTrainingGetModelInputOutputCount(ort_training_session_handle_t training_handle, - size_t* input_count, - size_t* output_count, + uint32_t* input_count, + uint32_t* output_count, bool isEvalModel) { if (isEvalModel) { RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetEvalModelInputCount, training_handle, input_count); @@ -594,7 +598,7 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingGetModelInputOutputCount(ort_training_sessio } char* EMSCRIPTEN_KEEPALIVE OrtTrainingGetModelInputOutputName(ort_training_session_handle_t training_handle, - size_t index, + uint32_t index, bool isInput, bool isEvalModel) { OrtAllocator* allocator = nullptr; diff --git a/onnxruntime/wasm/api.h b/onnxruntime/wasm/api.h index ffba89c4fe4f9..49fba6bf93071 100644 --- a/onnxruntime/wasm/api.h +++ b/onnxruntime/wasm/api.h @@ -68,16 +68,16 @@ void EMSCRIPTEN_KEEPALIVE OrtGetLastError(int* error_code, const char** error_me * @param optimized_model_filepath filepath of the optimized model to dump. * @returns a session option handle. Caller must release it after use by calling OrtReleaseSessionOptions(). */ -ort_session_options_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateSessionOptions(size_t graph_optimization_level, - bool enable_cpu_mem_arena, - bool enable_mem_pattern, - size_t execution_mode, - bool enable_profiling, - const char* profile_file_prefix, - const char* log_id, - size_t log_severity_level, - size_t log_verbosity_level, - const char* optimized_model_filepath); +ort_session_options_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateSessionOptions(uint32_t graph_optimization_level, + bool enable_cpu_mem_arena, + bool enable_mem_pattern, + uint32_t execution_mode, + bool enable_profiling, + const char* profile_file_prefix, + const char* log_id, + uint32_t log_severity_level, + uint32_t log_verbosity_level, + const char* optimized_model_filepath); /** * append an execution provider for a session. 
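For orientation, this is roughly what the size_t-to-uint32_t change in this commit means on the JavaScript side: a 32-bit out-parameter is read from the wasm heap with getValue('i32') and arrives as a plain number, whereas a 64-bit slot on a -sMEMORY64 build is read with 'i64' and comes back as a BigInt that has to be narrowed with Number(). A minimal sketch only, assuming `wasm` stands for the loaded Emscripten module (PTR_SIZE, stackSave/stackAlloc/stackRestore, getValue and the _Ort* exports, as used elsewhere in this series); the helper name is illustrative, not part of the patch:

declare const wasm: any; // assumption for this sketch: the loaded ort-wasm module object

// Read the two counts written by OrtGetInputOutputCount. The slot width and the
// getValue type follow the build's pointer size: 'i32' on wasm32, 'i64' (BigInt,
// hence Number()) on a -sMEMORY64 build.
const readInputOutputCount = (sessionHandle: number): [number, number] => {
  const stack = wasm.stackSave();
  try {
    const ptrSize = wasm.PTR_SIZE;
    const dataOffset = wasm.stackAlloc(2 * ptrSize);
    const errorCode = wasm._OrtGetInputOutputCount(sessionHandle, dataOffset, dataOffset + ptrSize);
    if (errorCode !== 0) {
      throw new Error('OrtGetInputOutputCount failed');
    }
    const type = ptrSize === 4 ? 'i32' : 'i64';
    return [Number(wasm.getValue(dataOffset, type)), Number(wasm.getValue(dataOffset + ptrSize, type))];
  } finally {
    wasm.stackRestore(stack);
  }
};

The same pattern appears later in this series in js/web/lib/wasm/wasm-core-impl.ts.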
@@ -118,7 +118,7 @@ void EMSCRIPTEN_KEEPALIVE OrtReleaseSessionOptions(ort_session_options_handle_t * @returns an ORT session handle. Caller must release it after use by calling OrtReleaseSession(). */ ort_session_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateSession(void* data, - size_t data_length, + uint32_t data_length, ort_session_options_handle_t session_options); /** @@ -129,13 +129,13 @@ void EMSCRIPTEN_KEEPALIVE OrtReleaseSession(ort_session_handle_t session); /** * get model's input count and output count. * @param session handle of the specified session - * @param input_count [out] a pointer to a size_t variable to accept input_count. - * @param output_count [out] a pointer to a size_t variable to accept output_count. + * @param input_count [out] a pointer to a uint32_t variable to accept input_count. + * @param output_count [out] a pointer to a uint32_t variable to accept output_count. * @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message. */ int EMSCRIPTEN_KEEPALIVE OrtGetInputOutputCount(ort_session_handle_t session, - size_t* input_count, - size_t* output_count); + uint32_t* input_count, + uint32_t* output_count); /** * get the model's input name. @@ -144,7 +144,7 @@ int EMSCRIPTEN_KEEPALIVE OrtGetInputOutputCount(ort_session_handle_t session, * @returns a pointer to a buffer which contains C-style string. Caller must release the C style string after use by * calling OrtFree(). */ -char* EMSCRIPTEN_KEEPALIVE OrtGetInputName(ort_session_handle_t session, size_t index); +char* EMSCRIPTEN_KEEPALIVE OrtGetInputName(ort_session_handle_t session, uint32_t index); /** * get the model's output name. * @param session handle of the specified session @@ -152,7 +152,7 @@ char* EMSCRIPTEN_KEEPALIVE OrtGetInputName(ort_session_handle_t session, size_t * @returns a pointer to a buffer which contains C-style string. Caller must release the C style string after use by * calling OrtFree(). */ -char* EMSCRIPTEN_KEEPALIVE OrtGetOutputName(ort_session_handle_t session, size_t index); +char* EMSCRIPTEN_KEEPALIVE OrtGetOutputName(ort_session_handle_t session, uint32_t index); /** * free the specified buffer. @@ -170,7 +170,7 @@ void EMSCRIPTEN_KEEPALIVE OrtFree(void* ptr); * @param data_location specify the memory location of the tensor data. 0 for CPU, 1 for GPU buffer. * @returns a tensor handle. Caller must release it after use by calling OrtReleaseTensor(). */ -ort_tensor_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* dims, size_t dims_length, int data_location); +ort_tensor_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateTensor(int data_type, void* data, uint32_t data_length, uint32_t* dims, uint32_t dims_length, int data_location); /** * get type, shape info and data of the specified tensor. @@ -183,7 +183,7 @@ ort_tensor_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateTensor(int data_type, void* da * 'dims' (for all types of tensor), 'data' (only for string tensor) * @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message. */ -int EMSCRIPTEN_KEEPALIVE OrtGetTensorData(ort_tensor_handle_t tensor, size_t* data_type, void** data, size_t** dims, size_t* dims_length); +int EMSCRIPTEN_KEEPALIVE OrtGetTensorData(ort_tensor_handle_t tensor, uint32_t* data_type, void** data, uint32_t** dims, uint32_t* dims_length); /** * release the specified tensor. 
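The dims parameters of OrtCreateTensor follow the same rule in the other direction: the JS side packs each dimension into one pointer-sized slot before calling the export, using setValue with the type that matches the build. Another minimal sketch under the same assumptions (`wasm` is the loaded module; the helper name is illustrative); the real packing lives in prepareInputOutputTensor later in this series:

declare const wasm: any; // assumption for this sketch: the loaded ort-wasm module object

// Pack a dims array into stack memory for _OrtCreateTensor: one pointer-sized
// slot per dimension, written as 'i32' on wasm32 or 'i64' on a -sMEMORY64 build.
// The returned offset is only valid until the surrounding stackRestore().
const packDims = (dims: readonly number[]): number => {
  const ptrSize = wasm.PTR_SIZE;
  const dimsOffset = wasm.stackAlloc(ptrSize * dims.length);
  dims.forEach((d, index) =>
    wasm.setValue(dimsOffset + index * ptrSize, d, ptrSize === 4 ? 'i32' : 'i64'));
  return dimsOffset;
};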
@@ -198,8 +198,8 @@ void EMSCRIPTEN_KEEPALIVE OrtReleaseTensor(ort_tensor_handle_t tensor); * @param tag tag for this run * @returns a run option handle. Caller must release it after use by calling OrtReleaseRunOptions(). */ -ort_run_options_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateRunOptions(size_t log_severity_level, - size_t log_verbosity_level, +ort_run_options_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateRunOptions(uint32_t log_severity_level, + uint32_t log_verbosity_level, bool terminate, const char* tag); @@ -268,7 +268,7 @@ void EMSCRIPTEN_KEEPALIVE OrtReleaseBinding(ort_io_binding_handle_t io_binding); */ int EMSCRIPTEN_KEEPALIVE OrtRunWithBinding(ort_session_handle_t session, ort_io_binding_handle_t io_binding, - size_t output_count, + uint32_t output_count, ort_tensor_handle_t* outputs, ort_run_options_handle_t run_options); @@ -280,9 +280,9 @@ int EMSCRIPTEN_KEEPALIVE OrtRunWithBinding(ort_session_handle_t session, int EMSCRIPTEN_KEEPALIVE OrtRun(ort_session_handle_t session, const char** input_names, const ort_tensor_handle_t* inputs, - int input_count, + uint32_t input_count, const char** output_names, - int output_count, + uint32_t output_count, ort_tensor_handle_t* outputs, ort_run_options_handle_t run_options); @@ -304,7 +304,7 @@ char* EMSCRIPTEN_KEEPALIVE OrtEndProfiling(ort_session_handle_t session); * @param checkpoint_size size of the CheckpointState in bytes * @return ort_training_checkpoint_handle_t */ -ort_training_checkpoint_handle_t EMSCRIPTEN_KEEPALIVE OrtTrainingLoadCheckpoint(void* checkpoint_data_buffer, size_t checkpoint_size); +ort_training_checkpoint_handle_t EMSCRIPTEN_KEEPALIVE OrtTrainingLoadCheckpoint(void* checkpoint_data_buffer, uint32_t checkpoint_size); /** * @brief Release the specified ORT training checkpoint state. @@ -330,11 +330,11 @@ void EMSCRIPTEN_KEEPALIVE OrtTrainingReleaseCheckpoint(ort_training_checkpoint_h ort_training_session_handle_t EMSCRIPTEN_KEEPALIVE OrtTrainingCreateSession(ort_session_options_handle_t options, ort_training_checkpoint_handle_t training_checkpoint_state_handle, void* train_model, - size_t train_size, + uint32_t train_size, void* eval_model, - size_t eval_size, + uint32_t eval_size, void* optimizer_model, - size_t optimizer_size); + uint32_t optimizer_size); /** * Resets the gradients of all trainable parameters to zero for the specified TrainingSession @@ -355,9 +355,9 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingLazyResetGrad(ort_training_session_handle_t * @return int ORT error code. If not zero, call OrtGetLastError() to get detailed error message. */ int EMSCRIPTEN_KEEPALIVE OrtTrainingRunTrainStep(ort_training_session_handle_t training_handle, - ort_tensor_handle_t* inputs, size_t input_count, + ort_tensor_handle_t* inputs, uint32_t input_count, ort_tensor_handle_t* outputs, - size_t output_count, + uint32_t output_count, ort_run_options_handle_t run_options = nullptr); /** @@ -381,9 +381,9 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingOptimizerStep(ort_training_session_handle_t */ int EMSCRIPTEN_KEEPALIVE OrtTrainingEvalStep(ort_training_session_handle_t training_handle, ort_tensor_handle_t* inputs, - size_t input_count, + uint32_t input_count, ort_tensor_handle_t* outputs, - size_t output_count, + uint32_t output_count, ort_run_options_handle_t options = nullptr); /** @@ -396,7 +396,7 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingEvalStep(ort_training_session_handle_t train * @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message. 
*/ int EMSCRIPTEN_KEEPALIVE OrtTrainingGetParametersSize(ort_training_session_handle_t training_handle, - size_t* param_size, + uint32_t* param_size, bool trainable_only); /** @@ -414,7 +414,7 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingGetParametersSize(ort_training_session_handl */ int EMSCRIPTEN_KEEPALIVE OrtTrainingCopyParametersToBuffer(ort_training_session_handle_t training_handle, ort_tensor_handle_t parameters_buffer, - size_t parameter_count, + uint32_t parameter_count, bool trainable_only); /** @@ -429,21 +429,21 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingCopyParametersToBuffer(ort_training_session_ */ int EMSCRIPTEN_KEEPALIVE OrtTrainingCopyParametersFromBuffer(ort_training_session_handle_t training_handle, ort_tensor_handle_t parameters_buffer, - size_t parameter_count, + uint32_t parameter_count, bool trainable_only); /** * Gets the input count and output count of the training or eval model associated with the given training handle. * @param traning_handle handle of the traning session - * @param input_count [out] a pointer to a size_t variable to accept input_count - * @param output_count [out] a pointer to a size_t variable to accept output_count + * @param input_count [out] a pointer to a uint32_t variable to accept input_count + * @param output_count [out] a pointer to a uint32_t variable to accept output_count * @param isEvalModel when false, returns input & output count of the training model. When true, returns input & output * count of the eval model. * @returns ORT error code. If not zero, call OrtGetLastError() to get a detailed error message. */ int EMSCRIPTEN_KEEPALIVE OrtTrainingGetModelInputOutputCount(ort_training_session_handle_t training_handle, - size_t* input_count, - size_t* output_count, + uint32_t* input_count, + uint32_t* output_count, bool isEvalModel); /** @@ -457,7 +457,7 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingGetModelInputOutputCount(ort_training_sessio * @returns a pointer to a buffer which contains C-style string. Caller must release the C style string after use by */ char* EMSCRIPTEN_KEEPALIVE OrtTrainingGetModelInputOutputName(ort_training_session_handle_t training_handle, - size_t index, + uint32_t index, bool isInput, bool isEvalModel); From a31c5de3736639441783c288e532cb0c6eb6cfb2 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Thu, 25 Jul 2024 18:06:44 -0700 Subject: [PATCH 24/45] Revert unnecessary compiler flags. --- cmake/adjust_global_compile_flags.cmake | 6 ------ 1 file changed, 6 deletions(-) diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake index e12ac0ae4605d..ad5a4ff87c7b5 100644 --- a/cmake/adjust_global_compile_flags.cmake +++ b/cmake/adjust_global_compile_flags.cmake @@ -63,12 +63,6 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") string(APPEND CMAKE_C_FLAGS " -pthread -Wno-pthreads-mem-growth") string(APPEND CMAKE_CXX_FLAGS " -pthread -Wno-pthreads-mem-growth") endif() - - # Build WebAssembly with 64bit support. 
- if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) - string(APPEND CMAKE_C_FLAGS " -sMEMORY64 -Wno-experimental") - string(APPEND CMAKE_CXX_FLAGS " -sMEMORY64 -Wno-experimental") - endif() endif() if (onnxruntime_EXTERNAL_TRANSFORMER_SRC_PATH) From bfbb2d6de9818806b186f28a4aefe2c8680bee79 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Thu, 25 Jul 2024 18:07:41 -0700 Subject: [PATCH 25/45] Fixed SIGNATURE_CONVERSIONS --- cmake/onnxruntime_webassembly.cmake | 76 +++++++++++++++++------------ 1 file changed, 45 insertions(+), 31 deletions(-) diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index 62afaec9fde07..25de7d503cf81 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -231,7 +231,6 @@ else() target_compile_options(onnxruntime_session PRIVATE ${SMEMORY_FLAG} -Wno-experimental) target_compile_options(onnxruntime_framework PRIVATE ${SMEMORY_FLAG} -Wno-experimental) target_compile_options(nsync_cpp PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(nsync_cpp PRIVATE ${SMEMORY_FLAG} -Wno-experimental) target_compile_options(onnx_proto PRIVATE ${SMEMORY_FLAG} -Wno-experimental) # target_compile_options(protoc PRIVATE ${SMEMORY_FLAG} -Wno-experimental) target_compile_options(libprotobuf-lite PRIVATE ${SMEMORY_FLAG} -Wno-experimental) @@ -243,49 +242,64 @@ else() target_compile_options(onnxruntime_flatbuffers PRIVATE ${SMEMORY_FLAG} -Wno-experimental) target_compile_options(onnxruntime_util PRIVATE ${SMEMORY_FLAG} -Wno-experimental) target_compile_options(re2 PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_flags_private_handle_accessor PRIVATE ${SMEMORY_FLAG} -Wno-experimental) target_compile_options(absl_flags_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_flags_commandlineflag PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_flags_commandlineflag_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) target_compile_options(absl_flags_marshalling PRIVATE ${SMEMORY_FLAG} -Wno-experimental) target_compile_options(absl_flags_reflection PRIVATE ${SMEMORY_FLAG} -Wno-experimental) target_compile_options(absl_flags_config PRIVATE ${SMEMORY_FLAG} -Wno-experimental) target_compile_options(absl_flags_program_name PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_flags_private_handle_accessor PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_flags_commandlineflag PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_flags_commandlineflag_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_raw_hash_set PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_hashtablez_sampler PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_hash PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_city PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_low_level_hash PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_bad_variant_access PRIVATE ${SMEMORY_FLAG} -Wno-experimental) target_compile_options(absl_cord PRIVATE ${SMEMORY_FLAG} -Wno-experimental) target_compile_options(absl_cordz_info PRIVATE ${SMEMORY_FLAG} -Wno-experimental) target_compile_options(absl_cord_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) target_compile_options(absl_cordz_functions PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - 
target_compile_options(absl_exponential_biased PRIVATE ${SMEMORY_FLAG} -Wno-experimental) target_compile_options(absl_cordz_handle PRIVATE ${SMEMORY_FLAG} -Wno-experimental) target_compile_options(absl_crc_cord_state PRIVATE ${SMEMORY_FLAG} -Wno-experimental) target_compile_options(absl_crc32c PRIVATE ${SMEMORY_FLAG} -Wno-experimental) target_compile_options(absl_crc_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) target_compile_options(absl_crc_cpu_detect PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_bad_optional_access PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_raw_hash_set PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_hashtablez_sampler PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_exponential_biased PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_log_internal_conditions PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_log_internal_check_op PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_log_internal_message PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_log_internal_format PRIVATE ${SMEMORY_FLAG} -Wno-experimental) target_compile_options(absl_str_format_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_log_internal_log_sink_set PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_log_internal_globals PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_log_sink PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_log_entry PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_log_globals PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_hash PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_city PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_low_level_hash PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_bad_variant_access PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_vlog_config_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) target_compile_options(absl_synchronization PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_graphcycles_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) target_compile_options(absl_kernel_timeout_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_stacktrace PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_time PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_time_zone PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_civil_time PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_graphcycles_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_bad_optional_access PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_log_internal_fnmatch PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_examine_stack PRIVATE ${SMEMORY_FLAG} -Wno-experimental) target_compile_options(absl_symbolize PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_debugging_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_malloc_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) target_compile_options(absl_demangle_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) 
target_compile_options(absl_demangle_rust PRIVATE ${SMEMORY_FLAG} -Wno-experimental) target_compile_options(absl_decode_rust_punycode PRIVATE ${SMEMORY_FLAG} -Wno-experimental) target_compile_options(absl_utf8_for_code_point PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_malloc_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_time PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_civil_time PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_time_zone PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_stacktrace PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_debugging_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_log_internal_proto PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_strerror PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_log_internal_nullguard PRIVATE ${SMEMORY_FLAG} -Wno-experimental) target_compile_options(absl_strings PRIVATE ${SMEMORY_FLAG} -Wno-experimental) - target_compile_options(absl_int128 PRIVATE ${SMEMORY_FLAG} -Wno-experimental) target_compile_options(absl_strings_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_int128 PRIVATE ${SMEMORY_FLAG} -Wno-experimental) target_compile_options(absl_string_view PRIVATE ${SMEMORY_FLAG} -Wno-experimental) target_compile_options(absl_base PRIVATE ${SMEMORY_FLAG} -Wno-experimental) target_compile_options(absl_spinlock_wait PRIVATE ${SMEMORY_FLAG} -Wno-experimental) @@ -320,25 +334,25 @@ else() ) if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) set(SIGNATURE_CONVERSIONS "OrtRun:_ppp_p_pp,\ +OrtRunWithBinding:_pp_pp,\ OrtGetTensorData:_ppppp,\ -OrtCreateTensor:p_pppp_,\ -OrtCreateSession:pppp,\ +OrtCreateTensor:p_p_p__,\ +OrtCreateSession:pp_p,\ OrtReleaseSession:_p,\ -OrtGetInputOutputCount:pppp,\ -OrtCreateSessionOptions:pp__p_ppppp,\ -OrtAddSessionConfigEntry:pppp,\ +OrtGetInputOutputCount:_ppp,\ +OrtCreateSessionOptions:p_____pp__p,\ OrtReleaseSessionOptions:_p,\ OrtAppendExecutionProvider:ppp,\ -OrtAddSessionConfigEntry:pppp,\ -OrtGetInputName:ppp,\ -OrtGetOutputName:ppp,\ -OrtCreateRunOptions:ppp_p,\ -OrtReleaseRunOptions:pp,\ +OrtAddSessionConfigEntry:_pp,\ +OrtGetInputName:pp_,\ +OrtGetOutputName:pp_,\ +OrtCreateRunOptions:p___p,\ +OrtReleaseRunOptions:_p,\ OrtReleaseTensor:_p,\ OrtFree:_p,\ OrtCreateBinding:_p,\ OrtBindInput:_ppp,\ -OrtBindOutput:_ppp,\ +OrtBindOutput:_ppp_,\ OrtClearBoundOutputs:_p,\ OrtReleaseBinding:_p,\ OrtGetLastError:_pp,\ From 603013c02d4807758c3b56a50d44d84a0d1097e3 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Thu, 25 Jul 2024 18:08:14 -0700 Subject: [PATCH 26/45] minor change --- js/web/lib/wasm/jsep/init.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index 11fa9dd62b8ec..d237bb850c054 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -217,7 +217,7 @@ export const init = await backend.download( Number(gpuDataId), - () => module.HEAPU8.subarray(Number(dataOffset) >>> 0, Number(dataOffset) >>> 0 + Number(size))); + () => module.HEAPU8.subarray(Number(dataOffset) >>> 0, Number(dataOffset + size) >>> 0)); }, // jsepCreateKernel From ab919147cb82a3934e8dd9fdc7dd61ab23c7c655 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Fri, 26 Jul 2024 10:03:39 -0700 Subject: [PATCH 27/45] lint --- onnxruntime/wasm/api.cc | 2 +- 
onnxruntime/wasm/api.h | 18 +++++++++--------- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index 272851df6a4c3..fb985dba82bfd 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -235,7 +235,7 @@ void OrtFree(void* ptr) { } } -OrtValue* OrtCreateTensor(int data_type, void* data, uint32_t data_length, uint32_t* dims, uint32_t dims_length, int data_location) { +OrtValue* OrtCreateTensor(int data_type, void* data, uint32_t data_length, uint32_t* dims, uint32_t dims_length, int data_location) { if (data_location != DATA_LOCATION_CPU && data_location != DATA_LOCATION_CPU_PINNED && data_location != DATA_LOCATION_GPU_BUFFER) { diff --git a/onnxruntime/wasm/api.h b/onnxruntime/wasm/api.h index 49fba6bf93071..9c790d98c3a74 100644 --- a/onnxruntime/wasm/api.h +++ b/onnxruntime/wasm/api.h @@ -69,15 +69,15 @@ void EMSCRIPTEN_KEEPALIVE OrtGetLastError(int* error_code, const char** error_me * @returns a session option handle. Caller must release it after use by calling OrtReleaseSessionOptions(). */ ort_session_options_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateSessionOptions(uint32_t graph_optimization_level, - bool enable_cpu_mem_arena, - bool enable_mem_pattern, - uint32_t execution_mode, - bool enable_profiling, - const char* profile_file_prefix, - const char* log_id, - uint32_t log_severity_level, - uint32_t log_verbosity_level, - const char* optimized_model_filepath); + bool enable_cpu_mem_arena, + bool enable_mem_pattern, + uint32_t execution_mode, + bool enable_profiling, + const char* profile_file_prefix, + const char* log_id, + uint32_t log_severity_level, + uint32_t log_verbosity_level, + const char* optimized_model_filepath); /** * append an execution provider for a session. From a4fad865ce766fa624430f4b7718df0bd7610dd5 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Fri, 26 Jul 2024 11:00:54 -0700 Subject: [PATCH 28/45] Revert changes to gemm.h --- onnxruntime/core/providers/js/operators/gemm.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/js/operators/gemm.h b/onnxruntime/core/providers/js/operators/gemm.h index d7f8fa6289c1a..74091526f8411 100644 --- a/onnxruntime/core/providers/js/operators/gemm.h +++ b/onnxruntime/core/providers/js/operators/gemm.h @@ -23,8 +23,8 @@ class Gemm : public JsKernel { "transA" : $3, "transB" : $4 }), - static_cast(alpha), - static_cast(beta), + static_cast(alpha), + static_cast(beta), static_cast(transA), static_cast(transB)); } From 11bbf268848a5f15976d6242e7823fba68b07ecb Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Fri, 26 Jul 2024 11:29:33 -0700 Subject: [PATCH 29/45] Keep static assertion guarded by ifdef. --- cmake/adjust_global_compile_flags.cmake | 5 +++++ onnxruntime/wasm/api.cc | 3 +++ 2 files changed, 8 insertions(+) diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake index ad5a4ff87c7b5..232b18032baf6 100644 --- a/cmake/adjust_global_compile_flags.cmake +++ b/cmake/adjust_global_compile_flags.cmake @@ -58,6 +58,11 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") endif() endif() + if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) + string(APPEND CMAKE_C_FLAGS " -DORT_WASM64") + string(APPEND CMAKE_CXX_FLAGS " -DORT_WASM64") + endif() + # Build WebAssembly with multi-threads support. 
if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) string(APPEND CMAKE_C_FLAGS " -pthread -Wno-pthreads-mem-growth") diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index fb985dba82bfd..831067151a241 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -27,6 +27,9 @@ enum DataLocation { }; static_assert(sizeof(const char*) == sizeof(size_t), "size of a pointer and a size_t value should be the same."); +#ifndef ORT_WASM64 +static_assert(sizeof(size_t) == 4, "size of size_t should be 4 in this build (wasm32)."); +#endif OrtErrorCode CheckStatus(OrtStatusPtr status) { if (status) { From 939a7404a85740b653a88926fda767f2b1937651 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Mon, 29 Jul 2024 11:36:47 -0700 Subject: [PATCH 30/45] Switch back to using size_t instead of uint32_t --- onnxruntime/wasm/api.cc | 72 +++++++++++++++++++---------------------- onnxruntime/wasm/api.h | 68 +++++++++++++++++++------------------- 2 files changed, 68 insertions(+), 72 deletions(-) diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index 831067151a241..72a7cbde03c01 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -100,15 +100,15 @@ void OrtGetLastError(int* error_code, const char** error_message) { *error_message = g_last_error_message.empty() ? nullptr : g_last_error_message.c_str(); } -OrtSessionOptions* OrtCreateSessionOptions(uint32_t graph_optimization_level, +OrtSessionOptions* OrtCreateSessionOptions(size_t graph_optimization_level, bool enable_cpu_mem_arena, bool enable_mem_pattern, - uint32_t execution_mode, + size_t execution_mode, bool enable_profiling, const char* /*profile_file_prefix*/, const char* log_id, - uint32_t log_severity_level, - uint32_t log_verbosity_level, + size_t log_severity_level, + size_t log_verbosity_level, const char* optimized_model_filepath) { OrtSessionOptions* session_options = nullptr; RETURN_NULLPTR_IF_ERROR(CreateSessionOptions, &session_options); @@ -182,7 +182,7 @@ void OrtReleaseSessionOptions(OrtSessionOptions* session_options) { Ort::GetApi().ReleaseSessionOptions(session_options); } -OrtSession* OrtCreateSession(void* data, uint32_t data_length, OrtSessionOptions* session_options) { +OrtSession* OrtCreateSession(void* data, size_t data_length, OrtSessionOptions* session_options) { #if defined(__EMSCRIPTEN_PTHREADS__) RETURN_NULLPTR_IF_ERROR(DisablePerSessionThreads, session_options); #else @@ -201,17 +201,13 @@ void OrtReleaseSession(OrtSession* session) { Ort::GetApi().ReleaseSession(session); } -int OrtGetInputOutputCount(OrtSession* session, uint32_t* input_count, uint32_t* output_count) { - size_t input_count_tmp = 0; - size_t output_count_tmp = 0; - RETURN_ERROR_CODE_IF_ERROR(SessionGetInputCount, session, &input_count_tmp); - RETURN_ERROR_CODE_IF_ERROR(SessionGetOutputCount, session, &output_count_tmp); - *input_count = static_cast(input_count_tmp); - *output_count = static_cast(output_count_tmp); +int OrtGetInputOutputCount(OrtSession* session, size_t* input_count, size_t* output_count) { + RETURN_ERROR_CODE_IF_ERROR(SessionGetInputCount, session, input_count); + RETURN_ERROR_CODE_IF_ERROR(SessionGetOutputCount, session, output_count); return ORT_OK; } -char* OrtGetInputName(OrtSession* session, uint32_t index) { +char* OrtGetInputName(OrtSession* session, size_t index) { OrtAllocator* allocator = nullptr; RETURN_NULLPTR_IF_ERROR(GetAllocatorWithDefaultOptions, &allocator); @@ -221,7 +217,7 @@ char* OrtGetInputName(OrtSession* session, uint32_t index) { : nullptr; } -char* 
OrtGetOutputName(OrtSession* session, uint32_t index) { +char* OrtGetOutputName(OrtSession* session, size_t index) { OrtAllocator* allocator = nullptr; RETURN_NULLPTR_IF_ERROR(GetAllocatorWithDefaultOptions, &allocator); @@ -238,7 +234,7 @@ void OrtFree(void* ptr) { } } -OrtValue* OrtCreateTensor(int data_type, void* data, uint32_t data_length, uint32_t* dims, uint32_t dims_length, int data_location) { +OrtValue* OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* dims, size_t dims_length, int data_location) { if (data_location != DATA_LOCATION_CPU && data_location != DATA_LOCATION_CPU_PINNED && data_location != DATA_LOCATION_GPU_BUFFER) { @@ -286,7 +282,7 @@ OrtValue* OrtCreateTensor(int data_type, void* data, uint32_t data_length, uint3 } } -int OrtGetTensorData(OrtValue* tensor, uint32_t* data_type, void** data, uint32_t** dims, uint32_t* dims_length) { +int OrtGetTensorData(OrtValue* tensor, size_t* data_type, void** data, size_t** dims, size_t* dims_length) { ONNXType tensor_type; RETURN_ERROR_CODE_IF_ERROR(GetValueType, tensor, &tensor_type); if (tensor_type != ONNX_TYPE_TENSOR) { @@ -304,8 +300,8 @@ int OrtGetTensorData(OrtValue* tensor, uint32_t* data_type, void** data, uint32_ OrtAllocator* allocator = nullptr; RETURN_ERROR_CODE_IF_ERROR(GetAllocatorWithDefaultOptions, &allocator); - uint32_t* p_dims = reinterpret_cast(allocator->Alloc(allocator, sizeof(uint32_t) * dims_len)); - REGISTER_AUTO_RELEASE_BUFFER(uint32_t, p_dims, allocator); + size_t* p_dims = reinterpret_cast(allocator->Alloc(allocator, sizeof(size_t) * dims_len)); + REGISTER_AUTO_RELEASE_BUFFER(size_t, p_dims, allocator); ONNXTensorElementDataType type; RETURN_ERROR_CODE_IF_ERROR(GetTensorElementType, info, &type); @@ -313,7 +309,7 @@ int OrtGetTensorData(OrtValue* tensor, uint32_t* data_type, void** data, uint32_ std::vector shape(dims_len, 0); RETURN_ERROR_CODE_IF_ERROR(GetDimensions, info, shape.data(), shape.size()); for (size_t i = 0; i < dims_len; i++) { - p_dims[i] = static_cast(shape[i]); + p_dims[i] = static_cast(shape[i]); } if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING) { @@ -366,8 +362,8 @@ void OrtReleaseTensor(OrtValue* tensor) { Ort::GetApi().ReleaseValue(tensor); } -OrtRunOptions* OrtCreateRunOptions(uint32_t log_severity_level, - uint32_t log_verbosity_level, +OrtRunOptions* OrtCreateRunOptions(size_t log_severity_level, + size_t log_verbosity_level, bool terminate, const char* tag) { OrtRunOptions* run_options = nullptr; @@ -451,7 +447,7 @@ void OrtReleaseBinding(OrtIoBinding* io_binding) { int OrtRunWithBinding(OrtSession* session, OrtIoBinding* io_binding, - uint32_t output_count, + size_t output_count, OrtValue** outputs, OrtRunOptions* run_options) { RETURN_ERROR_CODE_IF_ERROR(RunWithBinding, session, run_options, io_binding); @@ -477,8 +473,8 @@ int OrtRunWithBinding(OrtSession* session, } int OrtRun(OrtSession* session, - const char** input_names, const ort_tensor_handle_t* inputs, uint32_t input_count, - const char** output_names, uint32_t output_count, ort_tensor_handle_t* outputs, + const char** input_names, const ort_tensor_handle_t* inputs, size_t input_count, + const char** output_names, size_t output_count, ort_tensor_handle_t* outputs, OrtRunOptions* run_options) { return CHECK_STATUS(Run, session, run_options, input_names, inputs, input_count, output_names, output_count, outputs); } @@ -508,7 +504,7 @@ char* OrtEndProfiling(ort_session_handle_t session) { } while (false) ort_training_checkpoint_handle_t EMSCRIPTEN_KEEPALIVE OrtTrainingLoadCheckpoint(void* 
checkpoint_data_buffer, - uint32_t checkpoint_size) { + size_t checkpoint_size) { OrtCheckpointState* checkpoint_state = nullptr; return (CHECK_TRAINING_STATUS(LoadCheckpointFromBuffer, checkpoint_data_buffer, checkpoint_size, &checkpoint_state) == ORT_OK) @@ -523,11 +519,11 @@ void EMSCRIPTEN_KEEPALIVE OrtTrainingReleaseCheckpoint(ort_training_checkpoint_h ort_training_session_handle_t EMSCRIPTEN_KEEPALIVE OrtTrainingCreateSession(const ort_session_options_handle_t options, ort_training_checkpoint_handle_t training_checkpoint_state_handle, void* train_model, - uint32_t train_size, + size_t train_size, void* eval_model, - uint32_t eval_size, + size_t eval_size, void* optimizer_model, - uint32_t optimizer_size) { + size_t optimizer_size) { OrtTrainingSession* training_session = nullptr; return (CHECK_TRAINING_STATUS(CreateTrainingSessionFromBuffer, g_env, options, training_checkpoint_state_handle, train_model, train_size, @@ -543,9 +539,9 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingLazyResetGrad(ort_training_session_handle_t int EMSCRIPTEN_KEEPALIVE OrtTrainingRunTrainStep(ort_training_session_handle_t training_handle, ort_tensor_handle_t* inputs, - uint32_t input_count, + size_t input_count, ort_tensor_handle_t* outputs, - uint32_t output_count, + size_t output_count, ort_run_options_handle_t options) { return CHECK_TRAINING_STATUS(TrainStep, training_handle, options, input_count, inputs, output_count, outputs); } @@ -557,37 +553,37 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingOptimizerStep(ort_training_session_handle_t int EMSCRIPTEN_KEEPALIVE OrtTrainingEvalStep(ort_training_session_handle_t training_handle, ort_tensor_handle_t* inputs, - uint32_t input_count, + size_t input_count, ort_tensor_handle_t* outputs, - uint32_t output_count, + size_t output_count, ort_run_options_handle_t options) { return CHECK_TRAINING_STATUS(EvalStep, training_handle, options, input_count, inputs, output_count, outputs); } int EMSCRIPTEN_KEEPALIVE OrtTrainingGetParametersSize(ort_training_session_handle_t training_handle, - uint32_t* param_size, + size_t* param_size, bool trainable_only) { return CHECK_TRAINING_STATUS(GetParametersSize, training_handle, param_size, trainable_only); } int EMSCRIPTEN_KEEPALIVE OrtTrainingCopyParametersToBuffer(ort_training_session_handle_t training_handle, ort_tensor_handle_t parameters_buffer, - uint32_t parameter_count, + size_t parameter_count, bool trainable_only) { return CHECK_TRAINING_STATUS(CopyParametersToBuffer, training_handle, parameters_buffer, trainable_only); } int EMSCRIPTEN_KEEPALIVE OrtTrainingCopyParametersFromBuffer(ort_training_session_handle_t training_handle, ort_tensor_handle_t parameters_buffer, - uint32_t parameter_count, + size_t parameter_count, bool trainable_only) { return CHECK_TRAINING_STATUS(CopyBufferToParameters, training_handle, parameters_buffer, trainable_only); } int EMSCRIPTEN_KEEPALIVE OrtTrainingGetModelInputOutputCount(ort_training_session_handle_t training_handle, - uint32_t* input_count, - uint32_t* output_count, + size_t* input_count, + size_t* output_count, bool isEvalModel) { if (isEvalModel) { RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetEvalModelInputCount, training_handle, input_count); @@ -601,7 +597,7 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingGetModelInputOutputCount(ort_training_sessio } char* EMSCRIPTEN_KEEPALIVE OrtTrainingGetModelInputOutputName(ort_training_session_handle_t training_handle, - uint32_t index, + size_t index, bool isInput, bool isEvalModel) { OrtAllocator* allocator = nullptr; diff --git 
a/onnxruntime/wasm/api.h b/onnxruntime/wasm/api.h index 9c790d98c3a74..f8b5dc49fb875 100644 --- a/onnxruntime/wasm/api.h +++ b/onnxruntime/wasm/api.h @@ -68,15 +68,15 @@ void EMSCRIPTEN_KEEPALIVE OrtGetLastError(int* error_code, const char** error_me * @param optimized_model_filepath filepath of the optimized model to dump. * @returns a session option handle. Caller must release it after use by calling OrtReleaseSessionOptions(). */ -ort_session_options_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateSessionOptions(uint32_t graph_optimization_level, +ort_session_options_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateSessionOptions(size_t graph_optimization_level, bool enable_cpu_mem_arena, bool enable_mem_pattern, - uint32_t execution_mode, + size_t execution_mode, bool enable_profiling, const char* profile_file_prefix, const char* log_id, - uint32_t log_severity_level, - uint32_t log_verbosity_level, + size_t log_severity_level, + size_t log_verbosity_level, const char* optimized_model_filepath); /** @@ -118,7 +118,7 @@ void EMSCRIPTEN_KEEPALIVE OrtReleaseSessionOptions(ort_session_options_handle_t * @returns an ORT session handle. Caller must release it after use by calling OrtReleaseSession(). */ ort_session_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateSession(void* data, - uint32_t data_length, + size_t data_length, ort_session_options_handle_t session_options); /** @@ -129,13 +129,13 @@ void EMSCRIPTEN_KEEPALIVE OrtReleaseSession(ort_session_handle_t session); /** * get model's input count and output count. * @param session handle of the specified session - * @param input_count [out] a pointer to a uint32_t variable to accept input_count. - * @param output_count [out] a pointer to a uint32_t variable to accept output_count. + * @param input_count [out] a pointer to a size_t variable to accept input_count. + * @param output_count [out] a pointer to a size_t variable to accept output_count. * @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message. */ int EMSCRIPTEN_KEEPALIVE OrtGetInputOutputCount(ort_session_handle_t session, - uint32_t* input_count, - uint32_t* output_count); + size_t* input_count, + size_t* output_count); /** * get the model's input name. @@ -144,7 +144,7 @@ int EMSCRIPTEN_KEEPALIVE OrtGetInputOutputCount(ort_session_handle_t session, * @returns a pointer to a buffer which contains C-style string. Caller must release the C style string after use by * calling OrtFree(). */ -char* EMSCRIPTEN_KEEPALIVE OrtGetInputName(ort_session_handle_t session, uint32_t index); +char* EMSCRIPTEN_KEEPALIVE OrtGetInputName(ort_session_handle_t session, size_t index); /** * get the model's output name. * @param session handle of the specified session @@ -152,7 +152,7 @@ char* EMSCRIPTEN_KEEPALIVE OrtGetInputName(ort_session_handle_t session, uint32_ * @returns a pointer to a buffer which contains C-style string. Caller must release the C style string after use by * calling OrtFree(). */ -char* EMSCRIPTEN_KEEPALIVE OrtGetOutputName(ort_session_handle_t session, uint32_t index); +char* EMSCRIPTEN_KEEPALIVE OrtGetOutputName(ort_session_handle_t session, size_t index); /** * free the specified buffer. @@ -170,7 +170,7 @@ void EMSCRIPTEN_KEEPALIVE OrtFree(void* ptr); * @param data_location specify the memory location of the tensor data. 0 for CPU, 1 for GPU buffer. * @returns a tensor handle. Caller must release it after use by calling OrtReleaseTensor(). 
*/ -ort_tensor_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateTensor(int data_type, void* data, uint32_t data_length, uint32_t* dims, uint32_t dims_length, int data_location); +ort_tensor_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* dims, size_t dims_length, int data_location); /** * get type, shape info and data of the specified tensor. @@ -183,7 +183,7 @@ ort_tensor_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateTensor(int data_type, void* da * 'dims' (for all types of tensor), 'data' (only for string tensor) * @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message. */ -int EMSCRIPTEN_KEEPALIVE OrtGetTensorData(ort_tensor_handle_t tensor, uint32_t* data_type, void** data, uint32_t** dims, uint32_t* dims_length); +int EMSCRIPTEN_KEEPALIVE OrtGetTensorData(ort_tensor_handle_t tensor, size_t* data_type, void** data, size_t** dims, size_t* dims_length); /** * release the specified tensor. @@ -198,8 +198,8 @@ void EMSCRIPTEN_KEEPALIVE OrtReleaseTensor(ort_tensor_handle_t tensor); * @param tag tag for this run * @returns a run option handle. Caller must release it after use by calling OrtReleaseRunOptions(). */ -ort_run_options_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateRunOptions(uint32_t log_severity_level, - uint32_t log_verbosity_level, +ort_run_options_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateRunOptions(size_t log_severity_level, + size_t log_verbosity_level, bool terminate, const char* tag); @@ -268,7 +268,7 @@ void EMSCRIPTEN_KEEPALIVE OrtReleaseBinding(ort_io_binding_handle_t io_binding); */ int EMSCRIPTEN_KEEPALIVE OrtRunWithBinding(ort_session_handle_t session, ort_io_binding_handle_t io_binding, - uint32_t output_count, + size_t output_count, ort_tensor_handle_t* outputs, ort_run_options_handle_t run_options); @@ -280,9 +280,9 @@ int EMSCRIPTEN_KEEPALIVE OrtRunWithBinding(ort_session_handle_t session, int EMSCRIPTEN_KEEPALIVE OrtRun(ort_session_handle_t session, const char** input_names, const ort_tensor_handle_t* inputs, - uint32_t input_count, + size_t input_count, const char** output_names, - uint32_t output_count, + size_t output_count, ort_tensor_handle_t* outputs, ort_run_options_handle_t run_options); @@ -304,7 +304,7 @@ char* EMSCRIPTEN_KEEPALIVE OrtEndProfiling(ort_session_handle_t session); * @param checkpoint_size size of the CheckpointState in bytes * @return ort_training_checkpoint_handle_t */ -ort_training_checkpoint_handle_t EMSCRIPTEN_KEEPALIVE OrtTrainingLoadCheckpoint(void* checkpoint_data_buffer, uint32_t checkpoint_size); +ort_training_checkpoint_handle_t EMSCRIPTEN_KEEPALIVE OrtTrainingLoadCheckpoint(void* checkpoint_data_buffer, size_t checkpoint_size); /** * @brief Release the specified ORT training checkpoint state. @@ -330,11 +330,11 @@ void EMSCRIPTEN_KEEPALIVE OrtTrainingReleaseCheckpoint(ort_training_checkpoint_h ort_training_session_handle_t EMSCRIPTEN_KEEPALIVE OrtTrainingCreateSession(ort_session_options_handle_t options, ort_training_checkpoint_handle_t training_checkpoint_state_handle, void* train_model, - uint32_t train_size, + size_t train_size, void* eval_model, - uint32_t eval_size, + size_t eval_size, void* optimizer_model, - uint32_t optimizer_size); + size_t optimizer_size); /** * Resets the gradients of all trainable parameters to zero for the specified TrainingSession @@ -355,9 +355,9 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingLazyResetGrad(ort_training_session_handle_t * @return int ORT error code. If not zero, call OrtGetLastError() to get detailed error message. 
*/ int EMSCRIPTEN_KEEPALIVE OrtTrainingRunTrainStep(ort_training_session_handle_t training_handle, - ort_tensor_handle_t* inputs, uint32_t input_count, + ort_tensor_handle_t* inputs, size_t input_count, ort_tensor_handle_t* outputs, - uint32_t output_count, + size_t output_count, ort_run_options_handle_t run_options = nullptr); /** @@ -381,9 +381,9 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingOptimizerStep(ort_training_session_handle_t */ int EMSCRIPTEN_KEEPALIVE OrtTrainingEvalStep(ort_training_session_handle_t training_handle, ort_tensor_handle_t* inputs, - uint32_t input_count, + size_t input_count, ort_tensor_handle_t* outputs, - uint32_t output_count, + size_t output_count, ort_run_options_handle_t options = nullptr); /** @@ -396,7 +396,7 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingEvalStep(ort_training_session_handle_t train * @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message. */ int EMSCRIPTEN_KEEPALIVE OrtTrainingGetParametersSize(ort_training_session_handle_t training_handle, - uint32_t* param_size, + size_t* param_size, bool trainable_only); /** @@ -414,7 +414,7 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingGetParametersSize(ort_training_session_handl */ int EMSCRIPTEN_KEEPALIVE OrtTrainingCopyParametersToBuffer(ort_training_session_handle_t training_handle, ort_tensor_handle_t parameters_buffer, - uint32_t parameter_count, + size_t parameter_count, bool trainable_only); /** @@ -429,21 +429,21 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingCopyParametersToBuffer(ort_training_session_ */ int EMSCRIPTEN_KEEPALIVE OrtTrainingCopyParametersFromBuffer(ort_training_session_handle_t training_handle, ort_tensor_handle_t parameters_buffer, - uint32_t parameter_count, + size_t parameter_count, bool trainable_only); /** * Gets the input count and output count of the training or eval model associated with the given training handle. * @param traning_handle handle of the traning session - * @param input_count [out] a pointer to a uint32_t variable to accept input_count - * @param output_count [out] a pointer to a uint32_t variable to accept output_count + * @param input_count [out] a pointer to a size_t variable to accept input_count + * @param output_count [out] a pointer to a size_t variable to accept output_count * @param isEvalModel when false, returns input & output count of the training model. When true, returns input & output * count of the eval model. * @returns ORT error code. If not zero, call OrtGetLastError() to get a detailed error message. */ int EMSCRIPTEN_KEEPALIVE OrtTrainingGetModelInputOutputCount(ort_training_session_handle_t training_handle, - uint32_t* input_count, - uint32_t* output_count, + size_t* input_count, + size_t* output_count, bool isEvalModel); /** @@ -457,7 +457,7 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingGetModelInputOutputCount(ort_training_sessio * @returns a pointer to a buffer which contains C-style string. 
Caller must release the C style string after use by */ char* EMSCRIPTEN_KEEPALIVE OrtTrainingGetModelInputOutputName(ort_training_session_handle_t training_handle, - uint32_t index, + size_t index, bool isInput, bool isEvalModel); From b3aaea933b7f962e7128e93eda911c8db5849542 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Mon, 29 Jul 2024 11:42:54 -0700 Subject: [PATCH 31/45] Specify setValue/getValue type argument 'i32' or 'i64' based on wasm32 or wasm64' --- js/web/lib/wasm/jsep/init.ts | 24 ++++++++++++---------- js/web/lib/wasm/jsep/util.ts | 2 +- js/web/lib/wasm/jsep/webgpu/ops/common.ts | 2 +- js/web/lib/wasm/wasm-core-impl.ts | 15 +++++++------- js/web/lib/wasm/wasm-training-core-impl.ts | 11 +++++----- js/web/lib/wasm/wasm-utils.ts | 2 +- 6 files changed, 30 insertions(+), 26 deletions(-) diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index d237bb850c054..0846fc711748c 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -72,21 +72,22 @@ class ComputeContextImpl implements ComputeContext { // extract context data const ptrSize = module.PTR_SIZE; - let dataIndex = module.PTR_SIZE === 8 ? (contextDataOffset / 2 ** 3) : (contextDataOffset >> 2); - this.opKernelContext = module.getValue(ptrSize * dataIndex++, 'i32'); - const inputCount = module.getValue(ptrSize * dataIndex++, 'i32'); - this.outputCount = module.getValue(ptrSize * dataIndex++, 'i32'); - this.customDataOffset = module.getValue(ptrSize * dataIndex++, 'i32'); - this.customDataSize = module.getValue(ptrSize * dataIndex++, 'i32'); + let dataIndex = module.PTR_SIZE === 4 ? (contextDataOffset >> 2) : (contextDataOffset / 2 ** 3); + const type = ptrSize === 4 ? 'i32' : 'i64'; + this.opKernelContext = Number(module.getValue(ptrSize * dataIndex++, type)); + const inputCount = Number(module.getValue(ptrSize * dataIndex++, type)); + this.outputCount = Number(module.getValue(ptrSize * dataIndex++, type)); + this.customDataOffset = Number(module.getValue(ptrSize * dataIndex++, '*')); + this.customDataSize = Number(module.getValue(ptrSize * dataIndex++, type)); const inputs: TensorView[] = []; for (let i = 0; i < inputCount; i++) { - const dataType = module.getValue(ptrSize * dataIndex++, 'i32'); + const dataType = module.getValue(ptrSize * dataIndex++, type); const data = module.getValue(ptrSize * dataIndex++, '*'); - const dim = module.getValue(ptrSize * dataIndex++, 'i32'); + const dim = module.getValue(ptrSize * dataIndex++, type); const dims: number[] = []; for (let d = 0; d < dim; d++) { - dims.push(module.getValue(ptrSize * dataIndex++, 'i32')); + dims.push(module.getValue(ptrSize * dataIndex++, type)); } inputs.push(new TensorViewImpl(module, dataType, data, dims)); } @@ -129,10 +130,11 @@ class ComputeContextImpl implements ComputeContext { const stack = this.module.stackSave(); try { const ptrSize = this.module.PTR_SIZE; + const type = ptrSize === 4 ? 
'i32' : 'i64'; const data = this.module.stackAlloc((1 + dims.length) * ptrSize /* sizeof(size_t) */); - this.module.setValue(data, dims.length, 'i32'); + this.module.setValue(data, dims.length, type); for (let i = 0; i < dims.length; i++) { - this.module.setValue(data + ptrSize * (i + 1), dims[i], 'i32'); + this.module.setValue(data + ptrSize * (i + 1), dims[i], type); } return this.module._JsepOutput!(this.opKernelContext, index, data); } catch (e) { diff --git a/js/web/lib/wasm/jsep/util.ts b/js/web/lib/wasm/jsep/util.ts index 9a1d5463f7843..cd581a1239127 100644 --- a/js/web/lib/wasm/jsep/util.ts +++ b/js/web/lib/wasm/jsep/util.ts @@ -162,7 +162,7 @@ export class ShapeUtil { // eslint-disable-next-line max-len 'cannot get valid size from specified dimension range. Most likely the range contains negative values in them.'); } - size *= dims[i]; + size *= Number(dims[i]); } return size; } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index ec2831a3cca04..9fbd78c86db29 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -219,7 +219,7 @@ const getWgslMappedType = (type: number, components: 1|2|3|4): string|[string, s } // return type is [ storage type, runtime type ] or a single string for both - switch (type) { + switch (Number(type)) { case DataType.float16: return components > 1 ? `vec${components}` : 'f16'; case DataType.float: diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 6a777c596537f..a6c6f72def66c 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -189,7 +189,8 @@ const getSessionInputOutputCount = (sessionHandle: number): [number, number] => if (errorCode !== 0) { checkLastError('Can\'t get session input/output count.'); } - return [wasm.getValue(dataOffset, '*'), wasm.getValue(dataOffset + ptrSize, '*')]; + const type = ptrSize === 4 ? 'i32' : 'i64'; + return [Number(wasm.getValue(dataOffset, type)), Number(wasm.getValue(dataOffset + ptrSize, type))]; } finally { wasm.stackRestore(stack); } @@ -467,7 +468,7 @@ export const prepareInputOutputTensor = const stack = wasm.stackSave(); const dimsOffset = wasm.stackAlloc(ptrSize * dims.length); try { - dims.forEach((d, index) => wasm.setValue(dimsOffset + (index * ptrSize), d, 'i32')); + dims.forEach((d, index) => wasm.setValue(dimsOffset + (index * ptrSize), d, ptrSize === 4 ? 'i32' : 'i64')); const tensor = wasm._OrtCreateTensor( tensorDataTypeStringToEnum(dataType), rawData, dataByteLength, dimsOffset, dims.length, dataLocationStringToEnum(location)); @@ -600,7 +601,7 @@ export const run = async( const output: TensorMetadata[] = []; for (let i = 0; i < outputCount; i++) { - const tensor = wasm.getValue(outputValuesOffset + i * ptrSize, '*'); + const tensor = Number(wasm.getValue(outputValuesOffset + i * ptrSize, '*')); if (tensor === outputTensorHandles[i]) { // output tensor is pre-allocated. no need to copy data. output.push(outputTensors[i]!); @@ -620,14 +621,14 @@ export const run = async( if (errorCode !== 0) { checkLastError(`Can't access output tensor data on index ${i}.`); } - - const dataType = wasm.getValue(tensorDataOffset, '*'); + const valueType = ptrSize === 4 ? 
'i32' : 'i64'; + const dataType = Number(wasm.getValue(tensorDataOffset, valueType)); dataOffset = wasm.getValue(tensorDataOffset + ptrSize, '*'); const dimsOffset = wasm.getValue(tensorDataOffset + ptrSize * 2, '*'); - const dimsLength = wasm.getValue(tensorDataOffset + ptrSize * 3, '*'); + const dimsLength = Number(wasm.getValue(tensorDataOffset + ptrSize * 3, valueType)); const dims = []; for (let i = 0; i < dimsLength; i++) { - dims.push(wasm.getValue(dimsOffset + i * ptrSize, '*')); + dims.push(Number(wasm.getValue(dimsOffset + i * ptrSize, valueType))); } wasm._OrtFree(dimsOffset); diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts index c6d70de983617..ff8f78fc52129 100644 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -68,7 +68,8 @@ const getModelInputOutputCount = (trainingSessionId: number, isEvalModel: boolea const errorCode = wasm._OrtTrainingGetModelInputOutputCount(trainingSessionId, dataOffset, dataOffset + ptrSize, isEvalModel); ifErrCodeCheckLastError(errorCode, 'Can\'t get session input/output count.'); - return [wasm.getValue(dataOffset, 'i32'), wasm.getValue(dataOffset + ptrSize, 'i32')]; + const valueType = ptrSize === 4 ? 'i32' : 'i64'; + return [Number(wasm.getValue(dataOffset, valueType)), Number(wasm.getValue(dataOffset + ptrSize, valueType))]; } else { throw new Error(NO_TRAIN_FUNCS_MSG); } @@ -212,14 +213,14 @@ const moveOutputToTensorMetadataArr = const errorCode = wasm._OrtGetTensorData( tensor, tensorDataOffset, tensorDataOffset + 4, tensorDataOffset + 8, tensorDataOffset + 12); ifErrCodeCheckLastError(errorCode, `Can't access output tensor data on index ${i}.`); - - const dataType = wasm.getValue(tensorDataOffset, '*'); + const valueType = ptrSize === 4 ? 'i32' : 'i64'; + const dataType = Number(wasm.getValue(tensorDataOffset, valueType)); dataOffset = wasm.getValue(tensorDataOffset + ptrSize, '*'); const dimsOffset = wasm.getValue(tensorDataOffset + 2 * ptrSize, '*'); - const dimsLength = wasm.getValue(tensorDataOffset + 3 * ptrSize, '*'); + const dimsLength = Number(wasm.getValue(tensorDataOffset + 3 * ptrSize, valueType)); const dims = []; for (let i = 0; i < dimsLength; i++) { - dims.push(wasm.getValue(dimsOffset + i * ptrSize, '*')); + Number(dims.push(wasm.getValue(dimsOffset + i * ptrSize, valueType))); } wasm._OrtFree(dimsOffset); diff --git a/js/web/lib/wasm/wasm-utils.ts b/js/web/lib/wasm/wasm-utils.ts index 703ccb08addf5..bf65fa5efa75f 100644 --- a/js/web/lib/wasm/wasm-utils.ts +++ b/js/web/lib/wasm/wasm-utils.ts @@ -55,7 +55,7 @@ export const checkLastError = (message: string): void => { const ptrSize = wasm.PTR_SIZE; const paramsOffset = wasm.stackAlloc(2 * ptrSize); wasm._OrtGetLastError(paramsOffset, paramsOffset + ptrSize); - const errorCode = wasm.getValue(paramsOffset, 'i32'); + const errorCode = Number(wasm.getValue(paramsOffset, ptrSize === 4 ? 'i32' : 'i64')); const errorMessagePointer = wasm.getValue(paramsOffset + ptrSize, '*'); const errorMessage = errorMessagePointer ? 
wasm.UTF8ToString(errorMessagePointer) : ''; throw new Error(`${message} ERROR_CODE: ${errorCode}, ERROR_MESSAGE: ${errorMessage}`); From 07d6e3ea9fba0715db9758bc17b6c36beea3c5ab Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Mon, 29 Jul 2024 11:44:48 -0700 Subject: [PATCH 32/45] Enable exception catching with wasm64 --- cmake/adjust_global_compile_flags.cmake | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake index 232b18032baf6..251f5d6bd62c2 100644 --- a/cmake/adjust_global_compile_flags.cmake +++ b/cmake/adjust_global_compile_flags.cmake @@ -52,10 +52,8 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") endif() if (onnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_CATCHING) - if (NOT onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) - string(APPEND CMAKE_C_FLAGS " -s DISABLE_EXCEPTION_CATCHING=0") - string(APPEND CMAKE_CXX_FLAGS " -s DISABLE_EXCEPTION_CATCHING=0") - endif() + string(APPEND CMAKE_C_FLAGS " -s DISABLE_EXCEPTION_CATCHING=0") + string(APPEND CMAKE_CXX_FLAGS " -s DISABLE_EXCEPTION_CATCHING=0") endif() if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) From b8bf40d641eab995903fd9229546043aee15dc4f Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Mon, 29 Jul 2024 15:59:17 -0700 Subject: [PATCH 33/45] Convert dims to Number --- js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts | 28 +++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts index a094fffe239c4..6b25d6c934033 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts @@ -119,9 +119,11 @@ const createBinaryOpProgramShader = const createBinaryOpProgramInfo = (name: string, cacheKey: string, a: TensorView, b: TensorView, funcCall: BinaryFunctionCall, additionalImplementation?: string, outputDataType: number = a.dataType): ProgramInfo => { - const isBroadcast = !ShapeUtil.areEqual(a.dims, b.dims); - let outputShape = a.dims; - let outputSize = ShapeUtil.size(a.dims); + const aDims = a.dims.map((x) => Number(x) ?? 1); + const bDims = b.dims.map((x) => Number(x) ?? 1); + const isBroadcast = !ShapeUtil.areEqual(aDims, bDims); + let outputShape = aDims; + let outputSize = ShapeUtil.size(aDims); let vectorize = false; let sharedDimensionDivisibleBy4 = false; @@ -129,16 +131,16 @@ const createBinaryOpProgramInfo = // TODO: deal with zero-sized tensors (eg. 
dims=[1,0]) const cacheKeyAux = [isBroadcast]; if (isBroadcast) { - const calculatedShape = BroadcastUtil.calcShape(a.dims, b.dims, false); + const calculatedShape = BroadcastUtil.calcShape(aDims, bDims, false); if (!calculatedShape) { throw new Error('Can\'t perform binary op on the given tensors'); } - outputShape = calculatedShape; + outputShape = calculatedShape.slice(); outputSize = ShapeUtil.size(outputShape); - const isAOneElement = ShapeUtil.size(a.dims) === 1; - const isBOneElement = ShapeUtil.size(b.dims) === 1; - const aLastDimDivisibleBy4 = a.dims.length > 0 && a.dims[a.dims.length - 1] % 4 === 0; - const bLastDimDivisibleBy4 = b.dims.length > 0 && b.dims[b.dims.length - 1] % 4 === 0; + const isAOneElement = ShapeUtil.size(aDims) === 1; + const isBOneElement = ShapeUtil.size(bDims) === 1; + const aLastDimDivisibleBy4 = aDims.length > 0 && aDims[aDims.length - 1] % 4 === 0; + const bLastDimDivisibleBy4 = bDims.length > 0 && bDims[bDims.length - 1] % 4 === 0; cacheKeyAux.push(isAOneElement); cacheKeyAux.push(isBOneElement); cacheKeyAux.push(aLastDimDivisibleBy4); @@ -146,8 +148,8 @@ const createBinaryOpProgramInfo = // check whether vectorize can be enabled let sharedDimension = 1; for (let i = 1; i < outputShape.length; i++) { - const dimA = a.dims[a.dims.length - i] ?? 1; - const dimB = b.dims[b.dims.length - i] ?? 1; + const dimA = aDims[aDims.length - i]; + const dimB = bDims[bDims.length - i]; if (dimA === dimB) { sharedDimension *= dimA; } else { @@ -173,14 +175,14 @@ const createBinaryOpProgramInfo = inputDependencies: ['rank', 'rank'], }, getShaderSource: (shaderHelper) => createBinaryOpProgramShader( - shaderHelper, a.dims, b.dims, outputShape, vectorize, isBroadcast, sharedDimensionDivisibleBy4, funcCall, + shaderHelper, aDims, bDims, outputShape, vectorize, isBroadcast, sharedDimensionDivisibleBy4, funcCall, a.dataType, b.dataType, outputDataType, additionalImplementation), getRunData: () => ({ outputs: [{dims: outputShape, dataType: outputDataType}], dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */)}, programUniforms: [ {type: DataType.uint32, data: Math.ceil(ShapeUtil.size(outputShape) / 4)}, - ...createTensorShapeVariables(a.dims, b.dims, outputShape) + ...createTensorShapeVariables(aDims, bDims, outputShape) ], }), }; From 33e6f54d58daa42d984eef61e6310129057dfbc3 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Mon, 29 Jul 2024 16:04:03 -0700 Subject: [PATCH 34/45] Make Ort api functions return --- onnxruntime/wasm/api.cc | 30 ++++++++++++++++++++---------- onnxruntime/wasm/api.h | 20 ++++++++++---------- 2 files changed, 30 insertions(+), 20 deletions(-) diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index 72a7cbde03c01..347d0a6785894 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -95,9 +95,10 @@ int OrtInit(int num_threads, int logging_level) { #endif } -void OrtGetLastError(int* error_code, const char** error_message) { +int OrtGetLastError(int* error_code, const char** error_message) { *error_code = g_last_error_code; *error_message = g_last_error_message.empty() ? 
nullptr : g_last_error_message.c_str(); + return ORT_OK; } OrtSessionOptions* OrtCreateSessionOptions(size_t graph_optimization_level, @@ -178,8 +179,9 @@ int OrtAddSessionConfigEntry(OrtSessionOptions* session_options, return CHECK_STATUS(AddSessionConfigEntry, session_options, config_key, config_value); } -void OrtReleaseSessionOptions(OrtSessionOptions* session_options) { +int OrtReleaseSessionOptions(OrtSessionOptions* session_options) { Ort::GetApi().ReleaseSessionOptions(session_options); + return ORT_OK; } OrtSession* OrtCreateSession(void* data, size_t data_length, OrtSessionOptions* session_options) { @@ -197,8 +199,9 @@ OrtSession* OrtCreateSession(void* data, size_t data_length, OrtSessionOptions* : nullptr; } -void OrtReleaseSession(OrtSession* session) { +int OrtReleaseSession(OrtSession* session) { Ort::GetApi().ReleaseSession(session); + return ORT_OK; } int OrtGetInputOutputCount(OrtSession* session, size_t* input_count, size_t* output_count) { @@ -227,11 +230,12 @@ char* OrtGetOutputName(OrtSession* session, size_t index) { : nullptr; } -void OrtFree(void* ptr) { +int OrtFree(void* ptr) { OrtAllocator* allocator = nullptr; if (CHECK_STATUS(GetAllocatorWithDefaultOptions, &allocator) == ORT_OK) { allocator->Free(allocator, ptr); } + return ORT_OK; } OrtValue* OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* dims, size_t dims_length, int data_location) { @@ -358,8 +362,9 @@ int OrtGetTensorData(OrtValue* tensor, size_t* data_type, void** data, size_t** return ORT_OK; } -void OrtReleaseTensor(OrtValue* tensor) { +int OrtReleaseTensor(OrtValue* tensor) { Ort::GetApi().ReleaseValue(tensor); + return ORT_OK; } OrtRunOptions* OrtCreateRunOptions(size_t log_severity_level, @@ -394,8 +399,9 @@ int OrtAddRunConfigEntry(OrtRunOptions* run_options, return CHECK_STATUS(AddRunConfigEntry, run_options, config_key, config_value); } -void OrtReleaseRunOptions(OrtRunOptions* run_options) { +int OrtReleaseRunOptions(OrtRunOptions* run_options) { Ort::GetApi().ReleaseRunOptions(run_options); + return ORT_OK; } OrtIoBinding* OrtCreateBinding(OrtSession* session) { @@ -437,12 +443,14 @@ int EMSCRIPTEN_KEEPALIVE OrtBindOutput(OrtIoBinding* io_binding, } } -void OrtClearBoundOutputs(OrtIoBinding* io_binding) { +int OrtClearBoundOutputs(OrtIoBinding* io_binding) { Ort::GetApi().ClearBoundOutputs(io_binding); + return ORT_OK; } -void OrtReleaseBinding(OrtIoBinding* io_binding) { +int OrtReleaseBinding(OrtIoBinding* io_binding) { Ort::GetApi().ReleaseIoBinding(io_binding); + return ORT_OK; } int OrtRunWithBinding(OrtSession* session, @@ -512,8 +520,9 @@ ort_training_checkpoint_handle_t EMSCRIPTEN_KEEPALIVE OrtTrainingLoadCheckpoint( : nullptr; } -void EMSCRIPTEN_KEEPALIVE OrtTrainingReleaseCheckpoint(ort_training_checkpoint_handle_t training_checkpoint_state_handle) { +int EMSCRIPTEN_KEEPALIVE OrtTrainingReleaseCheckpoint(ort_training_checkpoint_handle_t training_checkpoint_state_handle) { Ort::GetTrainingApi().ReleaseCheckpointState(training_checkpoint_state_handle); + return ORT_OK; } ort_training_session_handle_t EMSCRIPTEN_KEEPALIVE OrtTrainingCreateSession(const ort_session_options_handle_t options, @@ -632,8 +641,9 @@ char* EMSCRIPTEN_KEEPALIVE OrtTrainingGetModelInputOutputName(ort_training_sessi } } -void EMSCRIPTEN_KEEPALIVE OrtTrainingReleaseSession(ort_training_session_handle_t training_handle) { +int EMSCRIPTEN_KEEPALIVE OrtTrainingReleaseSession(ort_training_session_handle_t training_handle) { Ort::GetTrainingApi().ReleaseTrainingSession(training_handle); + return 
ORT_OK; } #endif diff --git a/onnxruntime/wasm/api.h b/onnxruntime/wasm/api.h index f8b5dc49fb875..f44c515d98f6b 100644 --- a/onnxruntime/wasm/api.h +++ b/onnxruntime/wasm/api.h @@ -50,7 +50,7 @@ int EMSCRIPTEN_KEEPALIVE OrtInit(int num_threads, int logging_level); * @param error_code [out] a pointer to accept the error code. * @param error_message [out] a pointer to accept the error message. The message buffer is only available before any ORT API is called. */ -void EMSCRIPTEN_KEEPALIVE OrtGetLastError(int* error_code, const char** error_message); +int EMSCRIPTEN_KEEPALIVE OrtGetLastError(int* error_code, const char** error_message); /** * create an instance of ORT session options. @@ -109,7 +109,7 @@ int EMSCRIPTEN_KEEPALIVE OrtAddSessionConfigEntry(ort_session_options_handle_t s /** * release the specified ORT session options. */ -void EMSCRIPTEN_KEEPALIVE OrtReleaseSessionOptions(ort_session_options_handle_t session_options); +int EMSCRIPTEN_KEEPALIVE OrtReleaseSessionOptions(ort_session_options_handle_t session_options); /** * create an instance of ORT session. @@ -124,7 +124,7 @@ ort_session_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateSession(void* data, /** * release the specified ORT session. */ -void EMSCRIPTEN_KEEPALIVE OrtReleaseSession(ort_session_handle_t session); +int EMSCRIPTEN_KEEPALIVE OrtReleaseSession(ort_session_handle_t session); /** * get model's input count and output count. @@ -158,7 +158,7 @@ char* EMSCRIPTEN_KEEPALIVE OrtGetOutputName(ort_session_handle_t session, size_t * free the specified buffer. * @param ptr a pointer to the buffer. */ -void EMSCRIPTEN_KEEPALIVE OrtFree(void* ptr); +int EMSCRIPTEN_KEEPALIVE OrtFree(void* ptr); /** * create an instance of ORT tensor. @@ -188,7 +188,7 @@ int EMSCRIPTEN_KEEPALIVE OrtGetTensorData(ort_tensor_handle_t tensor, size_t* da /** * release the specified tensor. */ -void EMSCRIPTEN_KEEPALIVE OrtReleaseTensor(ort_tensor_handle_t tensor); +int EMSCRIPTEN_KEEPALIVE OrtReleaseTensor(ort_tensor_handle_t tensor); /** * create an instance of ORT run options. @@ -218,7 +218,7 @@ int EMSCRIPTEN_KEEPALIVE OrtAddRunConfigEntry(ort_run_options_handle_t run_optio /** * release the specified ORT run options. */ -void EMSCRIPTEN_KEEPALIVE OrtReleaseRunOptions(ort_run_options_handle_t run_options); +int EMSCRIPTEN_KEEPALIVE OrtReleaseRunOptions(ort_run_options_handle_t run_options); /** * create an instance of ORT IO binding. @@ -252,12 +252,12 @@ int EMSCRIPTEN_KEEPALIVE OrtBindOutput(ort_io_binding_handle_t io_binding, /** * clear all bound outputs. */ -void EMSCRIPTEN_KEEPALIVE OrtClearBoundOutputs(ort_io_binding_handle_t io_binding); +int EMSCRIPTEN_KEEPALIVE OrtClearBoundOutputs(ort_io_binding_handle_t io_binding); /** * release the specified ORT IO binding. */ -void EMSCRIPTEN_KEEPALIVE OrtReleaseBinding(ort_io_binding_handle_t io_binding); +int EMSCRIPTEN_KEEPALIVE OrtReleaseBinding(ort_io_binding_handle_t io_binding); /** * inference the model. 
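With every release/clear entry point above now returning an int instead of void, the JavaScript side can treat any non-zero result as a failure and surface it through OrtGetLastError, which is what the later "Check return value" patch in this series does. The TypeScript fragment below is only a minimal sketch of that calling convention and is not part of the patch; the safeRelease helper and the declare stubs are assumptions made for illustration, while _OrtReleaseSession, _OrtReleaseBinding and checkLastError are the existing names used elsewhere in this series.

    // Sketch only: uniform handling of the int-returning release APIs declared above.
    // Assumed context: `wasm` is the initialized Emscripten module and `checkLastError`
    // is the helper from js/web/lib/wasm/wasm-utils.ts that reads _OrtGetLastError and throws.
    declare const wasm: {
      _OrtReleaseSession(handle: number): number;
      _OrtReleaseBinding(handle: number): number;
    };
    declare function checkLastError(message: string): void;

    const safeRelease = (release: () => number, what: string): void => {
      if (release() !== 0) {
        // A non-zero return code signals failure; the code/message come from OrtGetLastError.
        checkLastError(`Can't release ${what}.`);
      }
    };

    // Example usage, mirroring what releaseSession() does after this change:
    //   safeRelease(() => wasm._OrtReleaseSession(sessionHandle), 'session');
    //   safeRelease(() => wasm._OrtReleaseBinding(ioBindingState.handle), 'IO binding');

The point of the change is that callers no longer need to distinguish void-returning from int-returning entry points: every ORT API exported to the wasm module can be checked the same way.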
@@ -311,7 +311,7 @@ ort_training_checkpoint_handle_t EMSCRIPTEN_KEEPALIVE OrtTrainingLoadCheckpoint( * * @param training_checkpoint_state_handle handle for the CheckpointState */ -void EMSCRIPTEN_KEEPALIVE OrtTrainingReleaseCheckpoint(ort_training_checkpoint_handle_t training_checkpoint_state_handle); +int EMSCRIPTEN_KEEPALIVE OrtTrainingReleaseCheckpoint(ort_training_checkpoint_handle_t training_checkpoint_state_handle); /** * Creates an instance of a training session that can be used to begin or resume training from a given checkpoint state @@ -466,7 +466,7 @@ char* EMSCRIPTEN_KEEPALIVE OrtTrainingGetModelInputOutputName(ort_training_sessi * * @param training_session_handle handle of the training session */ -void EMSCRIPTEN_KEEPALIVE OrtTrainingReleaseSession(ort_training_session_handle_t training_session_handle); +int EMSCRIPTEN_KEEPALIVE OrtTrainingReleaseSession(ort_training_session_handle_t training_session_handle); #endif }; From b6604c6256c23af92a20ea024e5016f0d1b2a078 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Mon, 29 Jul 2024 16:14:39 -0700 Subject: [PATCH 35/45] Modified SIGNATURE_CONVERSIONS. --- cmake/onnxruntime_webassembly.cmake | 28 +++++++++++------------ onnxruntime/core/providers/js/js_kernel.h | 2 +- onnxruntime/wasm/api.cc | 2 +- 3 files changed, 15 insertions(+), 17 deletions(-) diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index 25de7d503cf81..833746214367e 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -168,11 +168,9 @@ else() "${ONNXRUNTIME_ROOT}/wasm/api.cc" "${ONNXRUNTIME_ROOT}/core/session/onnxruntime_c_api.cc" ) - if (NOT onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) - set (WASM_API_EXCEPTION_CATCHING "-s DISABLE_EXCEPTION_CATCHING=0") - message(STATUS "onnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_CATCHING_ON_API set") - set_source_files_properties(${onnxruntime_webassembly_src_exc} PROPERTIES COMPILE_FLAGS ${WASM_API_EXCEPTION_CATCHING}) - endif() + set (WASM_API_EXCEPTION_CATCHING "-s DISABLE_EXCEPTION_CATCHING=0") + message(STATUS "onnxruntime_ENABLE_WEBASSEMBLY_EXCEPTION_CATCHING_ON_API set") + set_source_files_properties(${onnxruntime_webassembly_src_exc} PROPERTIES COMPILE_FLAGS ${WASM_API_EXCEPTION_CATCHING}) endif() target_link_libraries(onnxruntime_webassembly PRIVATE @@ -333,20 +331,20 @@ else() "SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre.js\"" ) if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) - set(SIGNATURE_CONVERSIONS "OrtRun:_ppp_p_pp,\ -OrtRunWithBinding:_pp_pp,\ + set(SIGNATURE_CONVERSIONS "OrtRun:_pppppppp,\ +OrtRunWithBinding:_ppppp,\ OrtGetTensorData:_ppppp,\ -OrtCreateTensor:p_p_p__,\ -OrtCreateSession:pp_p,\ +OrtCreateTensor:p_pppp_,\ +OrtCreateSession:pppp,\ OrtReleaseSession:_p,\ OrtGetInputOutputCount:_ppp,\ -OrtCreateSessionOptions:p_____pp__p,\ +OrtCreateSessionOptions:pp__p_ppppp,\ OrtReleaseSessionOptions:_p,\ -OrtAppendExecutionProvider:ppp,\ -OrtAddSessionConfigEntry:_pp,\ -OrtGetInputName:pp_,\ -OrtGetOutputName:pp_,\ -OrtCreateRunOptions:p___p,\ +OrtAppendExecutionProvider:_pp,\ +OrtAddSessionConfigEntry:_ppp,\ +OrtGetInputName:ppp,\ +OrtGetOutputName:ppp,\ +OrtCreateRunOptions:ppp_p,\ OrtReleaseRunOptions:_p,\ OrtReleaseTensor:_p,\ OrtFree:_p,\ diff --git a/onnxruntime/core/providers/js/js_kernel.h b/onnxruntime/core/providers/js/js_kernel.h index 5ed3b7f3e8131..68d89c96d96f7 100644 --- a/onnxruntime/core/providers/js/js_kernel.h +++ b/onnxruntime/core/providers/js/js_kernel.h @@ -202,7 +202,7 @@ class JsKernel : public OpKernel { intptr_t 
status_code = EM_ASM_INT( { return Module.jsepRunKernel(Number($0), Number($1), Module.jsepSessionState.sessionHandle, Module.jsepSessionState.errors); }, - this, reinterpret_cast(p_serialized_kernel_context)); + this, reinterpret_cast(p_serialized_kernel_context)); LOGS_DEFAULT(VERBOSE) << "outputs = " << context->OutputCount() << ". Y.data=" << (size_t)(context->Output(0)->DataRaw()) << "."; diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index 347d0a6785894..c2dc1eb474816 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -356,7 +356,7 @@ int OrtGetTensorData(OrtValue* tensor, size_t* data_type, void** data, size_t** *data = p_tensor_raw_data; } - *data_type = static_cast(type); + *data_type = static_cast(type); *dims_length = dims_len; *dims = UNREGISTER_AUTO_RELEASE(p_dims); return ORT_OK; From 3bf0347a33d7434de11234e850d5e5891227f050 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Tue, 30 Jul 2024 10:51:58 -0700 Subject: [PATCH 36/45] Fixed ORT api functions return type consistently --- cmake/onnxruntime_webassembly.cmake | 2 +- js/web/lib/wasm/wasm-types.ts | 20 ++++++++++---------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index 833746214367e..e6db20f6cc44b 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -378,7 +378,7 @@ jsepDownload:_pp_") "SHELL:-s ASYNCIFY=1" "SHELL:-s ASYNCIFY_STACK_SIZE=65536" "SHELL:-s ASYNCIFY_EXPORTS=['OrtRun']" - "SHELL:-s ASYNCIFY_IMPORTS=['Module.jsepCopy','Module.jsepCopyAsync','Module.jsepDownload']" + "SHELL:-s ASYNCIFY_IMPORTS=['Module.jsepCopy','Module.jsepCopyAsync','jsepDownload']" ) set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js) endif() diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts index 9a4500822d457..6183650115856 100644 --- a/js/web/lib/wasm/wasm-types.ts +++ b/js/web/lib/wasm/wasm-types.ts @@ -123,27 +123,27 @@ export declare namespace JSEP { export interface OrtInferenceAPIs { _OrtInit(numThreads: number, loggingLevel: number): number; - _OrtGetLastError(errorCodeOffset: number, errorMessageOffset: number): void; + _OrtGetLastError(errorCodeOffset: number, errorMessageOffset: number): number; _OrtCreateSession(dataOffset: number, dataLength: number, sessionOptionsHandle: number): Promise; - _OrtReleaseSession(sessionHandle: number): void; + _OrtReleaseSession(sessionHandle: number): number; _OrtGetInputOutputCount(sessionHandle: number, inputCountOffset: number, outputCountOffset: number): number; _OrtGetInputName(sessionHandle: number, index: number): number; _OrtGetOutputName(sessionHandle: number, index: number): number; - _OrtFree(stringHandle: number): void; + _OrtFree(stringHandle: number): number; _OrtCreateTensor( dataType: number, dataOffset: number, dataLength: number, dimsOffset: number, dimsLength: number, dataLocation: number): number; _OrtGetTensorData(tensorHandle: number, dataType: number, dataOffset: number, dimsOffset: number, dimsLength: number): number; - _OrtReleaseTensor(tensorHandle: number): void; + _OrtReleaseTensor(tensorHandle: number): number; _OrtCreateBinding(sessionHandle: number): number; _OrtBindInput(bindingHandle: number, nameOffset: number, tensorHandle: number): Promise; _OrtBindOutput(bindingHandle: number, nameOffset: number, tensorHandle: number, location: number): number; - _OrtClearBoundOutputs(ioBindingHandle: number): void; - 
_OrtReleaseBinding(ioBindingHandle: number): void; + _OrtClearBoundOutputs(ioBindingHandle: number): number; + _OrtReleaseBinding(ioBindingHandle: number): number; _OrtRunWithBinding( sessionHandle: number, ioBindingHandle: number, outputCount: number, outputsOffset: number, runOptionsHandle: number): Promise; @@ -158,11 +158,11 @@ export interface OrtInferenceAPIs { _OrtAppendExecutionProvider(sessionOptionsHandle: number, name: number): number; _OrtAddFreeDimensionOverride(sessionOptionsHandle: number, name: number, dim: number): number; _OrtAddSessionConfigEntry(sessionOptionsHandle: number, configKey: number, configValue: number): number; - _OrtReleaseSessionOptions(sessionOptionsHandle: number): void; + _OrtReleaseSessionOptions(sessionOptionsHandle: number): number; _OrtCreateRunOptions(logSeverityLevel: number, logVerbosityLevel: number, terminate: boolean, tag: number): number; _OrtAddRunConfigEntry(runOptionsHandle: number, configKey: number, configValue: number): number; - _OrtReleaseRunOptions(runOptionsHandle: number): void; + _OrtReleaseRunOptions(runOptionsHandle: number): number; _OrtEndProfiling(sessionHandle: number): number; } @@ -170,7 +170,7 @@ export interface OrtInferenceAPIs { export interface OrtTrainingAPIs { _OrtTrainingLoadCheckpoint(dataOffset: number, dataLength: number): number; - _OrtTrainingReleaseCheckpoint(checkpointHandle: number): void; + _OrtTrainingReleaseCheckpoint(checkpointHandle: number): number; _OrtTrainingCreateSession( sessionOptionsHandle: number, checkpointHandle: number, trainOffset: number, trainLength: number, @@ -201,7 +201,7 @@ export interface OrtTrainingAPIs { _OrtTrainingGetModelInputOutputName(trainingHandle: number, index: number, isInput: boolean, isEvalModel: boolean): number; - _OrtTrainingReleaseSession(trainingHandle: number): void; + _OrtTrainingReleaseSession(trainingHandle: number): number; } /** From 8b1be685ac29c0279a748955a0ed679134ee0d83 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Tue, 30 Jul 2024 14:35:53 -0700 Subject: [PATCH 37/45] Skip jsepDownload --- cmake/onnxruntime_webassembly.cmake | 2 +- onnxruntime/core/providers/js/data_transfer.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index e6db20f6cc44b..497147ddb988c 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -378,7 +378,7 @@ jsepDownload:_pp_") "SHELL:-s ASYNCIFY=1" "SHELL:-s ASYNCIFY_STACK_SIZE=65536" "SHELL:-s ASYNCIFY_EXPORTS=['OrtRun']" - "SHELL:-s ASYNCIFY_IMPORTS=['Module.jsepCopy','Module.jsepCopyAsync','jsepDownload']" + "SHELL:-s ASYNCIFY_IMPORTS=['Module.jsepCopyAsync']" ) set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js) endif() diff --git a/onnxruntime/core/providers/js/data_transfer.cc b/onnxruntime/core/providers/js/data_transfer.cc index 3809df2c82e4c..22e60ad967c27 100644 --- a/onnxruntime/core/providers/js/data_transfer.cc +++ b/onnxruntime/core/providers/js/data_transfer.cc @@ -37,7 +37,7 @@ common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { } } else /* if (src_device.Type() == OrtDevice::GPU) */ { // copy from GPU to CPU - jsepDownload(src_data, dst_data, bytes); + EM_ASM({Module.jsepCopyAsync(Number($0), Number($1), Number($2));}, src_data, dst_data, bytes); } } From ea944ff93c7ea7204a4b29a4f4f25073521c12c6 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Wed, 31 Jul 2024 15:13:04 -0700 Subject: 
[PATCH 38/45] Remove unnecessary SIGNATURE_CONVERSIONS. --- cmake/onnxruntime_webassembly.cmake | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index bffbaf5707b6d..75f4a207be864 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -356,10 +356,7 @@ OrtReleaseBinding:_p,\ OrtGetLastError:_pp,\ JsepOutput:pp_p,\ JsepGetNodeName:pp,\ -JsepOutput:pp_p,\ -jsepCopy:_pp_,\ -jsepCopyAsync:_pp_,\ -jsepDownload:_pp_") +JsepOutput:pp_p") target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s ERROR_ON_UNDEFINED_SYMBOLS=0" "SHELL:-s SIGNATURE_CONVERSIONS='${SIGNATURE_CONVERSIONS}'" From 523527e9b08529dd3b78ed3de4042e67a6c62593 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Wed, 31 Jul 2024 15:16:11 -0700 Subject: [PATCH 39/45] Check return value --- js/web/lib/wasm/session-options.ts | 4 ++- js/web/lib/wasm/wasm-core-impl.ts | 37 ++++++++++++++++------ js/web/lib/wasm/wasm-training-core-impl.ts | 25 +++++++++++---- 3 files changed, 48 insertions(+), 18 deletions(-) diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts index f289fc20bba40..d5c7f6d5df6a0 100644 --- a/js/web/lib/wasm/session-options.ts +++ b/js/web/lib/wasm/session-options.ts @@ -189,7 +189,9 @@ export const setSessionOptions = (options?: InferenceSession.SessionOptions): [n return [sessionOptionsHandle, allocs]; } catch (e) { if (sessionOptionsHandle !== 0) { - wasm._OrtReleaseSessionOptions(sessionOptionsHandle); + if (wasm._OrtReleaseSessionOptions(sessionOptionsHandle) !== 0) { + checkLastError('Can\'t release session options.'); + } } allocs.forEach(alloc => wasm._free(alloc)); throw e; diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index a6c6f72def66c..1c5915a1e78ad 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -364,17 +364,23 @@ export const createSession = async( outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); if (ioBindingHandle !== 0) { - wasm._OrtReleaseBinding(ioBindingHandle); + if (wasm._OrtReleaseBinding(ioBindingHandle) !== 0) { + checkLastError('Can\'t release IO binding.'); + } } if (sessionHandle !== 0) { - wasm._OrtReleaseSession(sessionHandle); + if (wasm._OrtReleaseSession(sessionHandle) !== 0) { + checkLastError('Can\'t release session.'); + } } throw e; } finally { wasm._free(modelDataOffset); if (sessionOptionsHandle !== 0) { - wasm._OrtReleaseSessionOptions(sessionOptionsHandle); + if (wasm._OrtReleaseSessionOptions(sessionOptionsHandle) !== 0) { + checkLastError('Can\'t release session options.'); + } } allocs.forEach(alloc => wasm._free(alloc)); @@ -393,16 +399,22 @@ export const releaseSession = (sessionId: number): void => { if (ioBindingState) { if (enableGraphCapture) { - wasm._OrtClearBoundOutputs(ioBindingState.handle); + if ( wasm._OrtClearBoundOutputs(ioBindingState.handle) !== 0) { + checkLastError('Can\'t clear bound outputs.'); + } + } + if (wasm._OrtReleaseBinding(ioBindingState.handle) !== 0) { + checkLastError('Can\'t release IO binding.'); } - wasm._OrtReleaseBinding(ioBindingState.handle); } wasm.jsepOnReleaseSession?.(sessionId); inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); - wasm._OrtReleaseSession(sessionHandle); + if (wasm._OrtReleaseSession(sessionHandle) !== 0) { + checkLastError('Can\'t release session.'); + } activeSessions.delete(sessionId); }; @@ 
-630,8 +642,9 @@ export const run = async( for (let i = 0; i < dimsLength; i++) { dims.push(Number(wasm.getValue(dimsOffset + i * ptrSize, valueType))); } - wasm._OrtFree(dimsOffset); - + if (wasm._OrtFree(dimsOffset) !== 0) { + checkLastError('Can\'t free memory for tensor dims.'); + } const size = dims.reduce((a, b) => a * b, 1); type = tensorDataTypeEnumToString(dataType); @@ -671,7 +684,9 @@ export const run = async( gpuBuffer, download: wasm.jsepCreateDownloader!(gpuBuffer, size * elementSize, type), dispose: () => { - wasm._OrtReleaseTensor(tensor); + if (wasm._OrtReleaseTensor(tensor) !== 0) { + checkLastError('Can\'t release tensor.'); + } } }, 'gpu-buffer' @@ -696,7 +711,9 @@ export const run = async( } if (ioBindingState && !enableGraphCapture) { - wasm._OrtClearBoundOutputs(ioBindingState.handle); + if (wasm._OrtClearBoundOutputs(ioBindingState.handle) !== 0) { + checkLastError('Can\'t clear bound outputs.'); + } activeSessions.set( sessionId, [sessionHandle, inputNamesUTF8Encoded, outputNamesUTF8Encoded, ioBindingState, enableGraphCapture, false]); diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts index ff8f78fc52129..5d39398c80501 100644 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -54,7 +54,9 @@ export const createCheckpointHandle = (checkpointData: SerializableInternalBuffe throw e; } finally { // free buffer from wasm heap - wasm._OrtFree(checkpointData[0]); + if (wasm._OrtFree(checkpointData[0]) !== 0) { + checkLastError('Error occurred when trying to free the checkpoint buffer'); + } } }; @@ -141,7 +143,9 @@ export const createTrainingSessionHandle = wasm._free(optimizerModelData[0]); if (sessionOptionsHandle !== 0) { - wasm._OrtReleaseSessionOptions(sessionOptionsHandle); + if (wasm._OrtReleaseSessionOptions(sessionOptionsHandle) !== 0) { + checkLastError('Error occurred when trying to release the session options'); + } } allocs.forEach(alloc => wasm._free(alloc)); } @@ -222,8 +226,9 @@ const moveOutputToTensorMetadataArr = for (let i = 0; i < dimsLength; i++) { Number(dims.push(wasm.getValue(dimsOffset + i * ptrSize, valueType))); } - wasm._OrtFree(dimsOffset); - + if (wasm._OrtFree(dimsOffset) !== 0) { + checkLastError('Error occurred when trying to free the dims buffer'); + } const size = dims.reduce((a, b) => a * b, 1); type = tensorDataTypeEnumToString(dataType); @@ -248,7 +253,9 @@ const moveOutputToTensorMetadataArr = if (type === 'string' && dataOffset) { wasm._free(dataOffset); } - wasm._OrtReleaseTensor(tensor); + if (wasm._OrtReleaseTensor(tensor) !== 0) { + checkLastError('Error occurred when trying to release the tensor'); + } } } @@ -467,7 +474,9 @@ export const getContiguousParameters = } } finally { if (tensor !== 0) { - wasm._OrtReleaseTensor(tensor); + if ( wasm._OrtReleaseTensor(tensor) !== 0) { + checkLastError('Error occurred when trying to release the tensor'); + } } wasm._free(paramsOffset); wasm._free(dimsOffset); @@ -509,7 +518,9 @@ export const loadParametersBuffer = } } finally { if (tensor !== 0) { - wasm._OrtReleaseTensor(tensor); + if (wasm._OrtReleaseTensor(tensor) !== 0) { + checkLastError('Error occurred when trying to release the tensor'); + } } wasm.stackRestore(stack); wasm._free(bufferOffset); From cd5ed5c6443092136459afb4e4a8a45baca8f992 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Wed, 31 Jul 2024 20:29:23 -0700 Subject: [PATCH 40/45] Lint/format --- js/web/lib/wasm/wasm-core-impl.ts | 2 +- 
js/web/lib/wasm/wasm-training-core-impl.ts | 2 +- onnxruntime/core/providers/js/data_transfer.cc | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 1c5915a1e78ad..2edefda0d8e0d 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -399,7 +399,7 @@ export const releaseSession = (sessionId: number): void => { if (ioBindingState) { if (enableGraphCapture) { - if ( wasm._OrtClearBoundOutputs(ioBindingState.handle) !== 0) { + if (wasm._OrtClearBoundOutputs(ioBindingState.handle) !== 0) { checkLastError('Can\'t clear bound outputs.'); } } diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts index 5d39398c80501..e18e021ff9d1f 100644 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -474,7 +474,7 @@ export const getContiguousParameters = } } finally { if (tensor !== 0) { - if ( wasm._OrtReleaseTensor(tensor) !== 0) { + if (wasm._OrtReleaseTensor(tensor) !== 0) { checkLastError('Error occurred when trying to release the tensor'); } } diff --git a/onnxruntime/core/providers/js/data_transfer.cc b/onnxruntime/core/providers/js/data_transfer.cc index 22e60ad967c27..e18bad836a223 100644 --- a/onnxruntime/core/providers/js/data_transfer.cc +++ b/onnxruntime/core/providers/js/data_transfer.cc @@ -37,7 +37,7 @@ common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { } } else /* if (src_device.Type() == OrtDevice::GPU) */ { // copy from GPU to CPU - EM_ASM({Module.jsepCopyAsync(Number($0), Number($1), Number($2));}, src_data, dst_data, bytes); + EM_ASM({ Module.jsepCopyAsync(Number($0), Number($1), Number($2)); }, src_data, dst_data, bytes); } } From 80f3f497a22299a3e6b6f43add8d84febae8a897 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Wed, 7 Aug 2024 15:38:47 -0700 Subject: [PATCH 41/45] Revert "Remove unnecessary SIGNATURE_CONVERSIONS." This reverts commit ea944ff93c7ea7204a4b29a4f4f25073521c12c6. --- cmake/onnxruntime_webassembly.cmake | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index 75f4a207be864..bffbaf5707b6d 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -356,7 +356,10 @@ OrtReleaseBinding:_p,\ OrtGetLastError:_pp,\ JsepOutput:pp_p,\ JsepGetNodeName:pp,\ -JsepOutput:pp_p") +JsepOutput:pp_p,\ +jsepCopy:_pp_,\ +jsepCopyAsync:_pp_,\ +jsepDownload:_pp_") target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s ERROR_ON_UNDEFINED_SYMBOLS=0" "SHELL:-s SIGNATURE_CONVERSIONS='${SIGNATURE_CONVERSIONS}'" From 26d5dda4b560528bd8584cb77e2aafe8586a3e23 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Mon, 12 Aug 2024 14:19:01 -0700 Subject: [PATCH 42/45] Revert "Skip jsepDownload" This reverts commit 8b1be685ac29c0279a748955a0ed679134ee0d83. 
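The two patches that follow ("User Number convertion." and "Added Number conversion") keep applying the same wasm64 convention: when the build uses -sMEMORY64, pointer- and size_t-sized values reach JavaScript as BigInt, both as EM_ASM arguments and as the result of getValue(..., 'i64'), so they must be wrapped in Number() before arithmetic or before being used as typed-array indices. The TypeScript below is a minimal sketch of that convention under those assumptions and is not part of any patch; the readSizeTArray helper and the declare stub are illustrative only, while PTR_SIZE, getValue and HEAP32 are the module members the real code relies on (the same `wasm` object used in wasm-core-impl.ts).

    // Sketch only: the Number() convention used for wasm64 in the following patches.
    // Assumed context: `wasm` is the Emscripten module exposing PTR_SIZE, getValue and HEAP32;
    // under -sMEMORY64, getValue(..., 'i64') returns a BigInt.
    declare const wasm: {
      PTR_SIZE: number;
      getValue(ptr: number, type: string): number | bigint;
      HEAP32: Int32Array;
    };

    // Read `count` pointer-sized integers starting at `ptr` and return plain JS numbers.
    const readSizeTArray = (ptr: number, count: number): number[] => {
      const ptrSize = wasm.PTR_SIZE;              // 4 on wasm32, 8 on wasm64
      const type = ptrSize === 4 ? 'i32' : 'i64'; // 'i64' values come back as BigInt
      const values: number[] = [];
      for (let i = 0; i < count; i++) {
        values.push(Number(wasm.getValue(ptr + i * ptrSize, type)));
      }
      return values;
    };

    // Typed-array views need plain-number indices as well, which is why the EM_ASM macros
    // below switch to forms like: Array.from(HEAP32.subarray(Number($5), Number($6)))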
--- cmake/onnxruntime_webassembly.cmake | 2 +- onnxruntime/core/providers/js/data_transfer.cc | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index bffbaf5707b6d..efb1420173656 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -378,7 +378,7 @@ jsepDownload:_pp_") "SHELL:-s ASYNCIFY=1" "SHELL:-s ASYNCIFY_STACK_SIZE=65536" "SHELL:-s ASYNCIFY_EXPORTS=['OrtRun']" - "SHELL:-s ASYNCIFY_IMPORTS=['Module.jsepCopyAsync']" + "SHELL:-s ASYNCIFY_IMPORTS=['Module.jsepCopy','Module.jsepCopyAsync','jsepDownload']" ) set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js) endif() diff --git a/onnxruntime/core/providers/js/data_transfer.cc b/onnxruntime/core/providers/js/data_transfer.cc index e18bad836a223..3809df2c82e4c 100644 --- a/onnxruntime/core/providers/js/data_transfer.cc +++ b/onnxruntime/core/providers/js/data_transfer.cc @@ -37,7 +37,7 @@ common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { } } else /* if (src_device.Type() == OrtDevice::GPU) */ { // copy from GPU to CPU - EM_ASM({ Module.jsepCopyAsync(Number($0), Number($1), Number($2)); }, src_data, dst_data, bytes); + jsepDownload(src_data, dst_data, bytes); } } From e8bc234e1914b21fe20949bd40e9c2284aecf457 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Wed, 14 Aug 2024 05:59:06 -0700 Subject: [PATCH 43/45] User Number convertion. --- js/web/lib/wasm/jsep/init.ts | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index 0846fc711748c..9e7ade7b8a41d 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -82,12 +82,12 @@ class ComputeContextImpl implements ComputeContext { const inputs: TensorView[] = []; for (let i = 0; i < inputCount; i++) { - const dataType = module.getValue(ptrSize * dataIndex++, type); - const data = module.getValue(ptrSize * dataIndex++, '*'); - const dim = module.getValue(ptrSize * dataIndex++, type); + const dataType = Number(module.getValue(ptrSize * dataIndex++, type)); + const data = Number(module.getValue(ptrSize * dataIndex++, '*')); + const dim = Number(module.getValue(ptrSize * dataIndex++, type)); const dims: number[] = []; for (let d = 0; d < dim; d++) { - dims.push(module.getValue(ptrSize * dataIndex++, type)); + dims.push(Number(module.getValue(ptrSize * dataIndex++, type))); } inputs.push(new TensorViewImpl(module, dataType, data, dims)); } From bcfb312bd7f3e1a2871e30d0c26df1729aa0c050 Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Fri, 16 Aug 2024 10:49:14 -0700 Subject: [PATCH 44/45] Added Number conversion --- .../core/providers/js/operators/pool.h | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/onnxruntime/core/providers/js/operators/pool.h b/onnxruntime/core/providers/js/operators/pool.h index 66bcde86020b6..32556eeaeefe4 100644 --- a/onnxruntime/core/providers/js/operators/pool.h +++ b/onnxruntime/core/providers/js/operators/pool.h @@ -3,22 +3,22 @@ #pragma once -#include "core/providers/js/js_kernel.h" #include "core/providers/cpu/nn/pool_base.h" +#include "core/providers/js/js_kernel.h" namespace onnxruntime { namespace js { -#define POOL_ATTRIBUTES_JS_OBJ_MAPPING ({ \ - "format" : $13 ? "NHWC" : "NCHW", \ - "auto_pad" : $1, \ - "ceil_mode" : $2, \ - "count_include_pad" : $3, \ - "storage_order" : $4, \ - "dilations" : $5 ? 
Array.from(HEAP32.subarray($5, $6)) : [], \ - "kernel_shape" : $7 ? Array.from(HEAP32.subarray($7, $8)) : [], \ - "pads" : $9 ? Array.from(HEAP32.subarray($9, $10)) : [], \ - "strides" : $11 ? Array.from(HEAP32.subarray($11, $12)) : [] \ +#define POOL_ATTRIBUTES_JS_OBJ_MAPPING ({ \ + "format" : $13 ? "NHWC" : "NCHW", \ + "auto_pad" : $1, \ + "ceil_mode" : $2, \ + "count_include_pad" : $3, \ + "storage_order" : $4, \ + "dilations" : $5 ? Array.from(HEAP32.subarray(Number($5), Number($6))) : [], \ + "kernel_shape" : $7 ? Array.from(HEAP32.subarray(Number($7), Number($8))) : [], \ + "pads" : $9 ? Array.from(HEAP32.subarray(Number($9), Number($10))) : [], \ + "strides" : $11 ? Array.from(HEAP32.subarray(Number($11), Number($12))) : [] \ }) #define POOL_ATTRIBUTES_PARAM_LIST \ From 8a7ecdb397084ff1e7f8bd0e2e3bd42a469a9b5f Mon Sep 17 00:00:00 2001 From: Satya Jandhyala Date: Fri, 16 Aug 2024 10:49:54 -0700 Subject: [PATCH 45/45] Formatting and other minor changes. --- js/web/lib/wasm/jsep/init.ts | 2 +- js/web/lib/wasm/session-options.ts | 2 +- js/web/lib/wasm/wasm-core-impl.ts | 92 ++++++++++++++++-------------- 3 files changed, 50 insertions(+), 46 deletions(-) diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index e0f13631bfca7..58e6696c807d5 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -80,7 +80,7 @@ class ComputeContextImpl implements ComputeContext { // extract context data const ptrSize = module.PTR_SIZE; - let dataIndex = module.PTR_SIZE === 4 ? contextDataOffset >> 2 : contextDataOffset / 2 ** 3; + let dataIndex = contextDataOffset / module.PTR_SIZE; const type = ptrSize === 4 ? 'i32' : 'i64'; this.opKernelContext = Number(module.getValue(ptrSize * dataIndex++, type)); const inputCount = Number(module.getValue(ptrSize * dataIndex++, type)); diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts index 1446cad0f6e14..17e564247863d 100644 --- a/js/web/lib/wasm/session-options.ts +++ b/js/web/lib/wasm/session-options.ts @@ -201,7 +201,7 @@ export const setSessionOptions = (options?: InferenceSession.SessionOptions): [n } catch (e) { if (sessionOptionsHandle !== 0) { if (wasm._OrtReleaseSessionOptions(sessionOptionsHandle) !== 0) { - checkLastError('Can\'t release session options.'); + checkLastError("Can't release session options."); } } allocs.forEach((alloc) => wasm._free(alloc)); diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 82789bcaafc47..592f725fa7ca1 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -401,13 +401,13 @@ export const createSession = async ( if (ioBindingHandle !== 0) { if (wasm._OrtReleaseBinding(ioBindingHandle) !== 0) { - checkLastError('Can\'t release IO binding.'); + checkLastError("Can't release IO binding."); } } if (sessionHandle !== 0) { if (wasm._OrtReleaseSession(sessionHandle) !== 0) { - checkLastError('Can\'t release session.'); + checkLastError("Can't release session."); } } throw e; @@ -415,7 +415,7 @@ export const createSession = async ( wasm._free(modelDataOffset); if (sessionOptionsHandle !== 0) { if (wasm._OrtReleaseSessionOptions(sessionOptionsHandle) !== 0) { - checkLastError('Can\'t release session options.'); + checkLastError("Can't release session options."); } } allocs.forEach((alloc) => wasm._free(alloc)); @@ -436,20 +436,20 @@ export const releaseSession = (sessionId: number): void => { if (ioBindingState) { if (enableGraphCapture) { if 
(wasm._OrtClearBoundOutputs(ioBindingState.handle) !== 0) { - checkLastError('Can\'t clear bound outputs.'); + checkLastError("Can't clear bound outputs."); } } if (wasm._OrtReleaseBinding(ioBindingState.handle) !== 0) { - checkLastError('Can\'t release IO binding.'); + checkLastError("Can't release IO binding."); } } wasm.jsepOnReleaseSession?.(sessionId); - inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); - outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); - if (wasm._OrtReleaseSession(sessionHandle) !=== 0) { - checkLastError('Can\'t release session.'); + inputNamesUTF8Encoded.forEach((buf) => wasm._OrtFree(buf)); + outputNamesUTF8Encoded.forEach((buf) => wasm._OrtFree(buf)); + if (wasm._OrtReleaseSession(sessionHandle) !== 0) { + checkLastError("Can't release session."); } activeSessions.delete(sessionId); }; @@ -500,41 +500,45 @@ export const prepareInputOutputTensor = ( } else { const data = tensor[2]; - if (Array.isArray(data)) { - // string tensor - dataByteLength = ptrSize * data.length; - rawData = wasm._malloc(dataByteLength); - allocs.push(rawData); - for (let i = 0; i < data.length; i++) { - if (typeof data[i] !== 'string') { - throw new TypeError(`tensor data at index ${i} is not a string`); - } - wasm.setValue(rawData + i * ptrSize, allocWasmString(data[i], allocs), '*'); - } - } else { - dataByteLength = data.byteLength; - rawData = wasm._malloc(dataByteLength); - allocs.push(rawData); - wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData); + if (Array.isArray(data)) { + // string tensor + dataByteLength = ptrSize * data.length; + rawData = wasm._malloc(dataByteLength); + allocs.push(rawData); + for (let i = 0; i < data.length; i++) { + if (typeof data[i] !== 'string') { + throw new TypeError(`tensor data at index ${i} is not a string`); } + wasm.setValue(rawData + i * ptrSize, allocWasmString(data[i], allocs), '*'); } + } else { + dataByteLength = data.byteLength; + rawData = wasm._malloc(dataByteLength); + allocs.push(rawData); + wasm.HEAPU8.set(new Uint8Array(data.buffer, data.byteOffset, dataByteLength), rawData); + } + } - const stack = wasm.stackSave(); - const dimsOffset = wasm.stackAlloc(4 * dims.length); - try { - let dimIndex = dimsOffset / 4; - dims.forEach((d) => wasm.HEAP32[dimIndex++] = d); - const tensor = wasm._OrtCreateTensor( - tensorDataTypeStringToEnum(dataType), rawData, dataByteLength, dimsOffset, dims.length, - dataLocationStringToEnum(location)); - if (tensor === 0) { - checkLastError(`Can't create tensor for input/output. session=${sessionId}, index=${index}.`); - } - tensorHandles.push(tensor); - } finally { - wasm.stackRestore(stack); - } - }; + const stack = wasm.stackSave(); + const dimsOffset = wasm.stackAlloc(4 * dims.length); + try { + dims.forEach((d, index) => wasm.setValue(dimsOffset + index * ptrSize, d, ptrSize === 4 ? 'i32' : 'i64')); + const tensor = wasm._OrtCreateTensor( + tensorDataTypeStringToEnum(dataType), + rawData, + dataByteLength, + dimsOffset, + dims.length, + dataLocationStringToEnum(location), + ); + if (tensor === 0) { + checkLastError(`Can't create tensor for input/output. 
session=${sessionId}, index=${index}.`); + } + tensorHandles.push(tensor); + } finally { + wasm.stackRestore(stack); + } +}; /** * perform inference run @@ -732,7 +736,7 @@ export const run = async ( dims.push(Number(wasm.getValue(dimsOffset + i * ptrSize, valueType))); } if (wasm._OrtFree(dimsOffset) !== 0) { - checkLastError('Can\'t free memory for tensor dims.'); + checkLastError("Can't free memory for tensor dims."); } const size = dims.reduce((a, b) => a * b, 1); type = tensorDataTypeEnumToString(dataType); @@ -776,7 +780,7 @@ export const run = async ( download: wasm.jsepCreateDownloader!(gpuBuffer, size * elementSize, type), dispose: () => { if (wasm._OrtReleaseTensor(tensor) !== 0) { - checkLastError('Can\'t release tensor.'); + checkLastError("Can't release tensor."); } }, }, @@ -804,7 +808,7 @@ export const run = async ( if (ioBindingState && !enableGraphCapture) { if (wasm._OrtClearBoundOutputs(ioBindingState.handle) !== 0) { - checkLastError('Can\'t clear bound outputs.'); + checkLastError("Can't clear bound outputs."); } activeSessions.set(sessionId, [ sessionHandle,