diff --git a/cmake/CMakeLists.txt b/cmake/CMakeLists.txt index 2e9a50e522171..65ae2ecadd60f 100644 --- a/cmake/CMakeLists.txt +++ b/cmake/CMakeLists.txt @@ -201,6 +201,7 @@ option(onnxruntime_WEBASSEMBLY_RUN_TESTS_IN_BROWSER "Enable this option to run t option(onnxruntime_ENABLE_WEBASSEMBLY_DEBUG_INFO "Enable this option to turn on DWARF format debug info" OFF) option(onnxruntime_ENABLE_WEBASSEMBLY_PROFILING "Enable this option to turn on WebAssembly profiling and preserve function names" OFF) option(onnxruntime_ENABLE_WEBASSEMBLY_OUTPUT_OPTIMIZED_MODEL "Enable this option to allow WebAssembly to output optimized model" OFF) +option(onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64 "Enable this option to allow WebAssembly to use 64bit memory" OFF) # Enable bitcode for iOS option(onnxruntime_ENABLE_BITCODE "Enable bitcode for iOS only" OFF) diff --git a/cmake/adjust_global_compile_flags.cmake b/cmake/adjust_global_compile_flags.cmake index 6eb784a4063ed..251f5d6bd62c2 100644 --- a/cmake/adjust_global_compile_flags.cmake +++ b/cmake/adjust_global_compile_flags.cmake @@ -56,6 +56,11 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten") string(APPEND CMAKE_CXX_FLAGS " -s DISABLE_EXCEPTION_CATCHING=0") endif() + if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) + string(APPEND CMAKE_C_FLAGS " -DORT_WASM64") + string(APPEND CMAKE_CXX_FLAGS " -DORT_WASM64") + endif() + # Build WebAssembly with multi-threads support. 
if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS) string(APPEND CMAKE_C_FLAGS " -pthread -Wno-pthreads-mem-growth") diff --git a/cmake/onnxruntime_webassembly.cmake b/cmake/onnxruntime_webassembly.cmake index 0686b66876d9f..efb1420173656 100644 --- a/cmake/onnxruntime_webassembly.cmake +++ b/cmake/onnxruntime_webassembly.cmake @@ -193,7 +193,7 @@ else() re2::re2 ) - set(EXPORTED_RUNTIME_METHODS "'stackAlloc','stackRestore','stackSave','UTF8ToString','stringToUTF8','lengthBytesUTF8'") + set(EXPORTED_RUNTIME_METHODS "'stackAlloc','stackRestore','stackSave','UTF8ToString','stringToUTF8','lengthBytesUTF8','getValue','setValue'") if (onnxruntime_USE_XNNPACK) target_link_libraries(onnxruntime_webassembly PRIVATE XNNPACK) @@ -215,10 +215,109 @@ else() set(EXPORTED_FUNCTIONS "_malloc,_free") endif() + if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) + set(MAXIMUM_MEMORY "17179869184") + target_link_options(onnxruntime_webassembly PRIVATE + "SHELL:-s MEMORY64=1" + ) + string(APPEND CMAKE_C_FLAGS " -sMEMORY64 -Wno-experimental") + string(APPEND CMAKE_CXX_FLAGS " -sMEMORY64 -Wno-experimental") + set(SMEMORY_FLAG "-sMEMORY64") + + target_compile_options(onnx PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnxruntime_common PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnxruntime_session PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnxruntime_framework PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(nsync_cpp PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnx_proto PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + # target_compile_options(protoc PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(libprotobuf-lite PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnxruntime_providers PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnxruntime_optimizer PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + 
target_compile_options(onnxruntime_mlas PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnxruntime_optimizer PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnxruntime_graph PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnxruntime_flatbuffers PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(onnxruntime_util PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(re2 PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_flags_private_handle_accessor PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_flags_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_flags_commandlineflag PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_flags_commandlineflag_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_flags_marshalling PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_flags_reflection PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_flags_config PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_flags_program_name PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_cord PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_cordz_info PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_cord_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_cordz_functions PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_cordz_handle PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_crc_cord_state PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_crc32c PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_crc_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_crc_cpu_detect PRIVATE 
${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_raw_hash_set PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_hashtablez_sampler PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_exponential_biased PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_log_internal_conditions PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_log_internal_check_op PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_log_internal_message PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_log_internal_format PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_str_format_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_log_internal_log_sink_set PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_log_internal_globals PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_log_sink PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_log_entry PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_log_globals PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_hash PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_city PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_low_level_hash PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_bad_variant_access PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_vlog_config_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_synchronization PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_kernel_timeout_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_time PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_time_zone PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + 
target_compile_options(absl_civil_time PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_graphcycles_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_bad_optional_access PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_log_internal_fnmatch PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_examine_stack PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_symbolize PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_malloc_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_demangle_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_demangle_rust PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_decode_rust_punycode PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_utf8_for_code_point PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_stacktrace PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_debugging_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_log_internal_proto PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_strerror PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_log_internal_nullguard PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_strings PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_strings_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_int128 PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_string_view PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_base PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_spinlock_wait PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_throw_delegate PRIVATE 
${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_raw_logging_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_compile_options(absl_log_severity PRIVATE ${SMEMORY_FLAG} -Wno-experimental) + target_link_options(onnxruntime_webassembly PRIVATE + --post-js "${ONNXRUNTIME_ROOT}/wasm/js_post_js_64.js" + ) + else () + set(MAXIMUM_MEMORY "4294967296") + target_link_options(onnxruntime_webassembly PRIVATE + --post-js "${ONNXRUNTIME_ROOT}/wasm/js_post_js.js" + ) + endif () + target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s EXPORTED_RUNTIME_METHODS=[${EXPORTED_RUNTIME_METHODS}]" "SHELL:-s EXPORTED_FUNCTIONS=${EXPORTED_FUNCTIONS}" - "SHELL:-s MAXIMUM_MEMORY=4294967296" + "SHELL:-s MAXIMUM_MEMORY=${MAXIMUM_MEMORY}" "SHELL:-s EXIT_RUNTIME=0" "SHELL:-s ALLOW_MEMORY_GROWTH=1" "SHELL:-s MODULARIZE=1" @@ -231,6 +330,41 @@ else() --no-entry "SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre.js\"" ) + if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) + set(SIGNATURE_CONVERSIONS "OrtRun:_pppppppp,\ +OrtRunWithBinding:_ppppp,\ +OrtGetTensorData:_ppppp,\ +OrtCreateTensor:p_pppp_,\ +OrtCreateSession:pppp,\ +OrtReleaseSession:_p,\ +OrtGetInputOutputCount:_ppp,\ +OrtCreateSessionOptions:pp__p_ppppp,\ +OrtReleaseSessionOptions:_p,\ +OrtAppendExecutionProvider:_pp,\ +OrtAddSessionConfigEntry:_ppp,\ +OrtGetInputName:ppp,\ +OrtGetOutputName:ppp,\ +OrtCreateRunOptions:ppp_p,\ +OrtReleaseRunOptions:_p,\ +OrtReleaseTensor:_p,\ +OrtFree:_p,\ +OrtCreateBinding:_p,\ +OrtBindInput:_ppp,\ +OrtBindOutput:_ppp_,\ +OrtClearBoundOutputs:_p,\ +OrtReleaseBinding:_p,\ +OrtGetLastError:_pp,\ +JsepOutput:pp_p,\ +JsepGetNodeName:pp,\ +JsepOutput:pp_p,\ +jsepCopy:_pp_,\ +jsepCopyAsync:_pp_,\ +jsepDownload:_pp_") + target_link_options(onnxruntime_webassembly PRIVATE + "SHELL:-s ERROR_ON_UNDEFINED_SYMBOLS=0" + "SHELL:-s SIGNATURE_CONVERSIONS='${SIGNATURE_CONVERSIONS}'" + ) + endif () set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS 
${ONNXRUNTIME_ROOT}/wasm/pre.js) if (onnxruntime_USE_JSEP) @@ -243,6 +377,8 @@ else() "SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js\"" "SHELL:-s ASYNCIFY=1" "SHELL:-s ASYNCIFY_STACK_SIZE=65536" + "SHELL:-s ASYNCIFY_EXPORTS=['OrtRun']" + "SHELL:-s ASYNCIFY_IMPORTS=['Module.jsepCopy','Module.jsepCopyAsync','jsepDownload']" ) set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js) endif() @@ -279,7 +415,9 @@ else() endif() # Set link flag to enable exceptions support, this will override default disabling exception throwing behavior when disable exceptions. - target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s DISABLE_EXCEPTION_THROWING=0") + if (NOT onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64) + target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s DISABLE_EXCEPTION_THROWING=0") + endif() if (onnxruntime_ENABLE_WEBASSEMBLY_PROFILING) target_link_options(onnxruntime_webassembly PRIVATE --profiling --profiling-funcs) diff --git a/js/web/lib/wasm/jsep/init.ts b/js/web/lib/wasm/jsep/init.ts index 3f326881079f0..9909d262fffe4 100644 --- a/js/web/lib/wasm/jsep/init.ts +++ b/js/web/lib/wasm/jsep/init.ts @@ -77,24 +77,25 @@ class ComputeContextImpl implements ComputeContext { contextDataOffset: number, ) { this.adapterInfo = backend.adapterInfo; - const heapU32 = module.HEAPU32; // extract context data - let dataIndex = contextDataOffset >>> 2; - this.opKernelContext = heapU32[dataIndex++]; - const inputCount = heapU32[dataIndex++]; - this.outputCount = heapU32[dataIndex++]; - this.customDataOffset = heapU32[dataIndex++]; - this.customDataSize = heapU32[dataIndex++]; + const ptrSize = module.PTR_SIZE; + let dataIndex = contextDataOffset / module.PTR_SIZE; + const type = ptrSize === 4 ? 
'i32' : 'i64'; + this.opKernelContext = Number(module.getValue(ptrSize * dataIndex++, type)); + const inputCount = Number(module.getValue(ptrSize * dataIndex++, type)); + this.outputCount = Number(module.getValue(ptrSize * dataIndex++, type)); + this.customDataOffset = Number(module.getValue(ptrSize * dataIndex++, '*')); + this.customDataSize = Number(module.getValue(ptrSize * dataIndex++, type)); const inputs: TensorView[] = []; for (let i = 0; i < inputCount; i++) { - const dataType = heapU32[dataIndex++]; - const data = heapU32[dataIndex++]; - const dim = heapU32[dataIndex++]; + const dataType = Number(module.getValue(ptrSize * dataIndex++, type)); + const data = Number(module.getValue(ptrSize * dataIndex++, '*')); + const dim = Number(module.getValue(ptrSize * dataIndex++, type)); const dims: number[] = []; for (let d = 0; d < dim; d++) { - dims.push(heapU32[dataIndex++]); + dims.push(Number(module.getValue(ptrSize * dataIndex++, type))); } inputs.push(new TensorViewImpl(module, dataType, data, dims)); } @@ -142,11 +143,12 @@ class ComputeContextImpl implements ComputeContext { output(index: number, dims: readonly number[]): number { const stack = this.module.stackSave(); try { - const data = this.module.stackAlloc((1 + dims.length) * 4 /* sizeof(size_t) */); - let offset = data >> 2; - this.module.HEAPU32[offset++] = dims.length; + const ptrSize = this.module.PTR_SIZE; + const type = ptrSize === 4 ? 
'i32' : 'i64'; + const data = this.module.stackAlloc((1 + dims.length) * ptrSize /* sizeof(size_t) */); + this.module.setValue(data, dims.length, type); for (let i = 0; i < dims.length; i++) { - this.module.HEAPU32[offset++] = dims[i]; + this.module.setValue(data + ptrSize * (i + 1), dims[i], type); } return this.module._JsepOutput!(this.opKernelContext, index, data); } catch (e) { @@ -213,12 +215,19 @@ export const init = async ( // jsepCopy(src, dst, size, isSourceGpu) (src: number, dst: number, size: number, isSourceGpu = false) => { if (isSourceGpu) { - LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyGpuToGpu: src=${src}, dst=${dst}, size=${size}`); - backend.memcpy(src, dst); + LOG_DEBUG( + 'verbose', + () => `[WebGPU] jsepCopyGpuToGpu: src=${Number(src)}, dst=${Number(dst)}, size=${Number(size)}`, + ); + backend.memcpy(Number(src), Number(dst)); } else { - LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyCpuToGpu: dataOffset=${src}, gpuDataId=${dst}, size=${size}`); - const data = module.HEAPU8.subarray(src >>> 0, (src >>> 0) + size); - backend.upload(dst, data); + LOG_DEBUG( + 'verbose', + () => + `[WebGPU] jsepCopyCpuToGpu: dataOffset=${Number(src)}, gpuDataId=${Number(dst)}, size=${Number(size)}`, + ); + const data = module.HEAPU8.subarray(Number(src) >>> 0, (Number(src) >>> 0) + Number(size)); + backend.upload(Number(dst), data); } }, @@ -229,12 +238,19 @@ export const init = async ( () => `[WebGPU] jsepCopyGpuToCpu: gpuDataId=${gpuDataId}, dataOffset=${dataOffset}, size=${size}`, ); - await backend.download(gpuDataId, () => module.HEAPU8.subarray(dataOffset >>> 0, (dataOffset >>> 0) + size)); + await backend.download(Number(gpuDataId), () => + module.HEAPU8.subarray(Number(dataOffset) >>> 0, Number(dataOffset + size) >>> 0), + ); }, // jsepCreateKernel (kernelType: string, kernelId: number, attribute: unknown) => - backend.createKernel(kernelType, kernelId, attribute, module.UTF8ToString(module._JsepGetNodeName!(kernelId))), + backend.createKernel( + kernelType, 
+ Number(kernelId), + attribute, + module.UTF8ToString(module._JsepGetNodeName!(Number(kernelId))), + ), // jsepReleaseKernel (kernel: number) => backend.releaseKernel(kernel), @@ -246,8 +262,8 @@ export const init = async ( () => `[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${contextDataOffset}`, ); - const context = new ComputeContextImpl(module, backend, contextDataOffset); - return backend.computeKernel(kernel, context, errors); + const context = new ComputeContextImpl(module, backend, Number(contextDataOffset)); + return backend.computeKernel(Number(kernel), context, errors); }, // jsepCaptureBegin () => backend.captureBegin(), diff --git a/js/web/lib/wasm/jsep/util.ts b/js/web/lib/wasm/jsep/util.ts index 5ae16d5625dc8..85aca96057df2 100644 --- a/js/web/lib/wasm/jsep/util.ts +++ b/js/web/lib/wasm/jsep/util.ts @@ -167,7 +167,7 @@ export class ShapeUtil { 'cannot get valid size from specified dimension range. Most likely the range contains negative values in them.', ); } - size *= dims[i]; + size *= Number(dims[i]); } return size; } diff --git a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts index 8e18a28acc364..638c3c991bd95 100644 --- a/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts +++ b/js/web/lib/wasm/jsep/webgpu/gpu-data-manager.ts @@ -112,7 +112,7 @@ const bucketArr: number[] = []; /** * normalize the buffer size so that it fits the 128-bits (16 bytes) alignment. */ -const calcNormalizedBufferSize = (size: number) => Math.ceil(size / 16) * 16; +const calcNormalizedBufferSize = (size: number) => Math.ceil(Number(size) / 16) * 16; /** * calculate the buffer size so that it fits into buckets. 
@@ -295,9 +295,7 @@ class GpuDataManagerImpl implements GpuDataManager { LOG_DEBUG( 'verbose', () => - `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${ - id - }, buffer is the same, skip.`, + `[WebGPU] GpuDataManager.registerExternalBuffer(size=${originalSize}) => id=${id}, buffer is the same, skip.`, ); return id; } else if (this.backend.capturedCommandList.has(this.backend.currentSessionId!)) { @@ -358,7 +356,7 @@ class GpuDataManagerImpl implements GpuDataManager { } const gpuData = { id: createNewGpuDataId(), type: GpuDataType.default, buffer: gpuBuffer }; - this.storageCache.set(gpuData.id, { gpuData, originalSize: size }); + this.storageCache.set(gpuData.id, { gpuData, originalSize: Number(size) }); LOG_DEBUG('verbose', () => `[WebGPU] GpuDataManager.create(size=${size}) => id=${gpuData.id}`); return gpuData; @@ -368,7 +366,8 @@ class GpuDataManagerImpl implements GpuDataManager { return this.storageCache.get(id)?.gpuData; } - release(id: GpuDataId): number { + release(idInput: GpuDataId): number { + const id = typeof idInput === 'bigint' ? 
Number(idInput) : idInput; const cachedData = this.storageCache.get(id); if (!cachedData) { throw new Error('releasing data does not exist'); } @@ -384,7 +383,7 @@ } async download(id: GpuDataId, getTargetBuffer: () => Uint8Array): Promise { - const cachedData = this.storageCache.get(id); + const cachedData = this.storageCache.get(Number(id)); if (!cachedData) { throw new Error('data does not exist'); } diff --git a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts index 53c2ca2fa47d6..c695a71568c97 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/binary-op.ts @@ -143,9 +143,11 @@ const createBinaryOpProgramInfo = ( additionalImplementation?: string, outputDataType: number = a.dataType, ): ProgramInfo => { - const isBroadcast = !ShapeUtil.areEqual(a.dims, b.dims); - let outputShape = a.dims; - let outputSize = ShapeUtil.size(a.dims); + const aDims = a.dims.map((x) => Number(x)); + const bDims = b.dims.map((x) => Number(x)); + const isBroadcast = !ShapeUtil.areEqual(aDims, bDims); + let outputShape = aDims; + let outputSize = ShapeUtil.size(aDims); let vectorize = false; let sharedDimensionDivisibleBy4 = false; @@ -153,16 +155,16 @@ const createBinaryOpProgramInfo = ( // TODO: deal with zero-sized tensors (eg. 
dims=[1,0]) const cacheKeyAux = [isBroadcast]; if (isBroadcast) { - const calculatedShape = BroadcastUtil.calcShape(a.dims, b.dims, false); + const calculatedShape = BroadcastUtil.calcShape(aDims, bDims, false); if (!calculatedShape) { throw new Error("Can't perform binary op on the given tensors"); } - outputShape = calculatedShape; + outputShape = calculatedShape.slice(); outputSize = ShapeUtil.size(outputShape); - const isAOneElement = ShapeUtil.size(a.dims) === 1; - const isBOneElement = ShapeUtil.size(b.dims) === 1; - const aLastDimDivisibleBy4 = a.dims.length > 0 && a.dims[a.dims.length - 1] % 4 === 0; - const bLastDimDivisibleBy4 = b.dims.length > 0 && b.dims[b.dims.length - 1] % 4 === 0; + const isAOneElement = ShapeUtil.size(aDims) === 1; + const isBOneElement = ShapeUtil.size(bDims) === 1; + const aLastDimDivisibleBy4 = aDims.length > 0 && aDims[aDims.length - 1] % 4 === 0; + const bLastDimDivisibleBy4 = bDims.length > 0 && bDims[bDims.length - 1] % 4 === 0; cacheKeyAux.push(isAOneElement); cacheKeyAux.push(isBOneElement); cacheKeyAux.push(aLastDimDivisibleBy4); @@ -170,8 +172,8 @@ const createBinaryOpProgramInfo = ( // check whether vectorize can be enabled let sharedDimension = 1; for (let i = 1; i < outputShape.length; i++) { - const dimA = a.dims[a.dims.length - i] ?? 1; - const dimB = b.dims[b.dims.length - i] ?? 
1; + const dimA = aDims[aDims.length - i] ?? 1; + const dimB = bDims[bDims.length - i] ?? 1; if (dimA === dimB) { sharedDimension *= dimA; } else { @@ -199,8 +201,8 @@ const createBinaryOpProgramInfo = ( getShaderSource: (shaderHelper) => createBinaryOpProgramShader( shaderHelper, - a.dims, - b.dims, + aDims, + bDims, outputShape, vectorize, isBroadcast, @@ -216,7 +218,7 @@ const createBinaryOpProgramInfo = ( dispatchGroup: { x: Math.ceil(outputSize / 64 /* workgroup size */ / 4 /* component size */) }, programUniforms: [ { type: DataType.uint32, data: Math.ceil(ShapeUtil.size(outputShape) / 4) }, - ...createTensorShapeVariables(a.dims, b.dims, outputShape), + ...createTensorShapeVariables(aDims, bDims, outputShape), ], }), }; @@ -280,9 +282,7 @@ export const pow = (context: ComputeContext): void => { } else if (a < ${type}(0.0) && f32(b) != floor(f32(b))) { return ${type}(pow(f32(a), f32(b))); // NaN } - return select(sign(a), ${type}(1.0), round(f32(abs(b) % ${type}(2.0))) != 1.0) * ${type}(${ roundStr }(pow(f32(abs(a)), f32(b)))); + return select(sign(a), ${type}(1.0), round(f32(abs(b) % ${type}(2.0))) != 1.0) * ${type}(${roundStr}(pow(f32(abs(a)), f32(b)))); } fn pow_vector_custom(a : vec4<${type}>, b : vec4<${type}>) -> vec4<${type}> { // TODO: implement vectorized pow diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 7696f22d44abd..4780f04bc2351 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -219,7 +219,7 @@ const getWgslMappedType = (type: number, components: 1 | 2 | 3 | 4): string | [s } // return type is [ storage type, runtime type ] or a single string for both - switch (type) { + switch (Number(type)) { case DataType.float16: return components > 1 ? 
`vec${components}` : 'f16'; case DataType.float: diff --git a/js/web/lib/wasm/session-options.ts b/js/web/lib/wasm/session-options.ts index b2594267a595a..17e564247863d 100644 --- a/js/web/lib/wasm/session-options.ts +++ b/js/web/lib/wasm/session-options.ts @@ -200,7 +200,9 @@ export const setSessionOptions = (options?: InferenceSession.SessionOptions): [n return [sessionOptionsHandle, allocs]; } catch (e) { if (sessionOptionsHandle !== 0) { - wasm._OrtReleaseSessionOptions(sessionOptionsHandle); + if (wasm._OrtReleaseSessionOptions(sessionOptionsHandle) !== 0) { + checkLastError("Can't release session options."); + } } allocs.forEach((alloc) => wasm._free(alloc)); throw e; diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 6c4e28df62f23..bde852fd7c52d 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -207,12 +207,14 @@ const getSessionInputOutputCount = (sessionHandle: number): [number, number] => const wasm = getInstance(); const stack = wasm.stackSave(); try { - const dataOffset = wasm.stackAlloc(8); - const errorCode = wasm._OrtGetInputOutputCount(sessionHandle, dataOffset, dataOffset + 4); + const ptrSize = wasm.PTR_SIZE; + const dataOffset = wasm.stackAlloc(2 * ptrSize); + const errorCode = wasm._OrtGetInputOutputCount(sessionHandle, dataOffset, dataOffset + ptrSize); if (errorCode !== 0) { checkLastError("Can't get session input/output count."); } - return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]]; + const type = ptrSize === 4 ? 
'i32' : 'i64'; + return [Number(wasm.getValue(dataOffset, type)), Number(wasm.getValue(dataOffset + ptrSize, type))]; } finally { wasm.stackRestore(stack); } @@ -396,17 +398,23 @@ export const createSession = async ( outputNamesUTF8Encoded.forEach((buf) => wasm._OrtFree(buf)); if (ioBindingHandle !== 0) { - wasm._OrtReleaseBinding(ioBindingHandle); + if (wasm._OrtReleaseBinding(ioBindingHandle) !== 0) { + checkLastError("Can't release IO binding."); + } } if (sessionHandle !== 0) { - wasm._OrtReleaseSession(sessionHandle); + if (wasm._OrtReleaseSession(sessionHandle) !== 0) { + checkLastError("Can't release session."); + } } throw e; } finally { wasm._free(modelDataOffset); if (sessionOptionsHandle !== 0) { - wasm._OrtReleaseSessionOptions(sessionOptionsHandle); + if (wasm._OrtReleaseSessionOptions(sessionOptionsHandle) !== 0) { + checkLastError("Can't release session options."); + } } allocs.forEach((alloc) => wasm._free(alloc)); @@ -425,16 +433,22 @@ export const releaseSession = (sessionId: number): void => { if (ioBindingState) { if (enableGraphCapture) { - wasm._OrtClearBoundOutputs(ioBindingState.handle); + if (wasm._OrtClearBoundOutputs(ioBindingState.handle) !== 0) { + checkLastError("Can't clear bound outputs."); + } + } + if (wasm._OrtReleaseBinding(ioBindingState.handle) !== 0) { + checkLastError("Can't release IO binding."); } - wasm._OrtReleaseBinding(ioBindingState.handle); } wasm.jsepOnReleaseSession?.(sessionId); inputNamesUTF8Encoded.forEach((buf) => wasm._OrtFree(buf)); outputNamesUTF8Encoded.forEach((buf) => wasm._OrtFree(buf)); - wasm._OrtReleaseSession(sessionHandle); + if (wasm._OrtReleaseSession(sessionHandle) !== 0) { + checkLastError("Can't release session."); + } activeSessions.delete(sessionId); }; @@ -452,6 +466,7 @@ export const prepareInputOutputTensor = ( } const wasm = getInstance(); + const ptrSize = wasm.PTR_SIZE; const dataType = tensor[0]; const dims = tensor[1]; @@ -484,15 +499,14 @@ export const prepareInputOutputTensor = ( if 
(Array.isArray(data)) { // string tensor - dataByteLength = 4 * data.length; + dataByteLength = ptrSize * data.length; rawData = wasm._malloc(dataByteLength); allocs.push(rawData); - let dataIndex = rawData / 4; for (let i = 0; i < data.length; i++) { if (typeof data[i] !== 'string') { throw new TypeError(`tensor data at index ${i} is not a string`); } - wasm.HEAPU32[dataIndex++] = allocWasmString(data[i], allocs); + wasm.setValue(rawData + i * ptrSize, allocWasmString(data[i], allocs), '*'); } } else { dataByteLength = data.byteLength; @@ -505,8 +519,7 @@ const stack = wasm.stackSave(); - const dimsOffset = wasm.stackAlloc(4 * dims.length); + const dimsOffset = wasm.stackAlloc(ptrSize * dims.length); try { - let dimIndex = dimsOffset / 4; - dims.forEach((d) => (wasm.HEAP32[dimIndex++] = d)); + dims.forEach((d, index) => wasm.setValue(dimsOffset + index * ptrSize, d, ptrSize === 4 ? 'i32' : 'i64')); const tensor = wasm._OrtCreateTensor( tensorDataTypeStringToEnum(dataType), rawData, @@ -536,6 +549,7 @@ options: InferenceSession.RunOptions, ): Promise => { const wasm = getInstance(); + const ptrSize = wasm.PTR_SIZE; const session = activeSessions.get(sessionId); if (!session) { throw new Error(`cannot run inference. 
invalid session id: ${sessionId}`); @@ -558,10 +572,10 @@ export const run = async ( const inputOutputAllocs: number[] = []; const beforeRunStack = wasm.stackSave(); - const inputValuesOffset = wasm.stackAlloc(inputCount * 4); - const inputNamesOffset = wasm.stackAlloc(inputCount * 4); - const outputValuesOffset = wasm.stackAlloc(outputCount * 4); - const outputNamesOffset = wasm.stackAlloc(outputCount * 4); + const inputValuesOffset = wasm.stackAlloc(inputCount * ptrSize); + const inputNamesOffset = wasm.stackAlloc(inputCount * ptrSize); + const outputValuesOffset = wasm.stackAlloc(outputCount * ptrSize); + const outputNamesOffset = wasm.stackAlloc(outputCount * ptrSize); try { [runOptionsHandle, runOptionsAllocs] = setRunOptions(options); @@ -590,17 +604,13 @@ export const run = async ( ); } - let inputValuesIndex = inputValuesOffset / 4; - let inputNamesIndex = inputNamesOffset / 4; - let outputValuesIndex = outputValuesOffset / 4; - let outputNamesIndex = outputNamesOffset / 4; for (let i = 0; i < inputCount; i++) { - wasm.HEAPU32[inputValuesIndex++] = inputTensorHandles[i]; - wasm.HEAPU32[inputNamesIndex++] = inputNamesUTF8Encoded[inputIndices[i]]; + wasm.setValue(inputValuesOffset + i * ptrSize, inputTensorHandles[i], '*'); + wasm.setValue(inputNamesOffset + i * ptrSize, inputNamesUTF8Encoded[inputIndices[i]], '*'); } for (let i = 0; i < outputCount; i++) { - wasm.HEAPU32[outputValuesIndex++] = outputTensorHandles[i]; - wasm.HEAPU32[outputNamesIndex++] = outputNamesUTF8Encoded[outputIndices[i]]; + wasm.setValue(outputValuesOffset + i * ptrSize, outputTensorHandles[i], '*'); + wasm.setValue(outputNamesOffset + i * ptrSize, outputNamesUTF8Encoded[outputIndices[i]], '*'); } if (!BUILD_DEFS.DISABLE_JSEP && ioBindingState && !inputOutputBound) { @@ -685,7 +695,7 @@ export const run = async ( const output: TensorMetadata[] = []; for (let i = 0; i < outputCount; i++) { - const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i]; + const tensor = 
Number(wasm.getValue(outputValuesOffset + i * ptrSize, '*')); if (tensor === outputTensorHandles[i]) { // output tensor is pre-allocated. no need to copy data. output.push(outputTensors[i]!); @@ -694,7 +704,7 @@ export const run = async ( const beforeGetTensorDataStack = wasm.stackSave(); // stack allocate 4 pointer value - const tensorDataOffset = wasm.stackAlloc(4 * 4); + const tensorDataOffset = wasm.stackAlloc(4 * ptrSize); let keepOutputTensor = false; let type: Tensor.Type | undefined, @@ -703,24 +713,26 @@ export const run = async ( const errorCode = wasm._OrtGetTensorData( tensor, tensorDataOffset, - tensorDataOffset + 4, - tensorDataOffset + 8, - tensorDataOffset + 12, + tensorDataOffset + ptrSize, + tensorDataOffset + 2 * ptrSize, + + tensorDataOffset + 3 * ptrSize, ); if (errorCode !== 0) { checkLastError(`Can't access output tensor data on index ${i}.`); } - let tensorDataIndex = tensorDataOffset / 4; - const dataType = wasm.HEAPU32[tensorDataIndex++]; - dataOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsLength = wasm.HEAPU32[tensorDataIndex++]; + const valueType = ptrSize === 4 ? 
'i32' : 'i64'; + const dataType = Number(wasm.getValue(tensorDataOffset, valueType)); + dataOffset = wasm.getValue(tensorDataOffset + ptrSize, '*'); + const dimsOffset = wasm.getValue(tensorDataOffset + ptrSize * 2, '*'); + const dimsLength = Number(wasm.getValue(tensorDataOffset + ptrSize * 3, valueType)); const dims = []; for (let i = 0; i < dimsLength; i++) { - dims.push(wasm.HEAPU32[dimsOffset / 4 + i]); + dims.push(Number(wasm.getValue(dimsOffset + i * ptrSize, valueType))); + } + if (wasm._OrtFree(dimsOffset) !== 0) { + checkLastError("Can't free memory for tensor dims."); } - wasm._OrtFree(dimsOffset); - const size = dims.reduce((a, b) => a * b, 1); type = tensorDataTypeEnumToString(dataType); @@ -731,10 +743,10 @@ export const run = async ( throw new Error('String tensor is not supported on GPU.'); } const stringData: string[] = []; - let dataIndex = dataOffset / 4; for (let i = 0; i < size; i++) { - const offset = wasm.HEAPU32[dataIndex++]; - const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset; + const offset = wasm.getValue(dataOffset + i * ptrSize, '*'); + const nextOffset = wasm.getValue(dataOffset + (i + 1) * ptrSize, '*'); + const maxBytesToRead = i === size - 1 ? 
undefined : nextOffset - offset; stringData.push(wasm.UTF8ToString(offset, maxBytesToRead)); } output.push([type, dims, stringData, 'cpu']); @@ -762,7 +774,9 @@ export const run = async ( gpuBuffer, download: wasm.jsepCreateDownloader!(gpuBuffer, bufferSize, type), dispose: () => { - wasm._OrtReleaseTensor(tensor); + if (wasm._OrtReleaseTensor(tensor) !== 0) { + checkLastError("Can't release tensor."); + } }, }, 'gpu-buffer', @@ -788,7 +802,9 @@ export const run = async ( } if (ioBindingState && !enableGraphCapture) { - wasm._OrtClearBoundOutputs(ioBindingState.handle); + if (wasm._OrtClearBoundOutputs(ioBindingState.handle) !== 0) { + checkLastError("Can't clear bound outputs."); + } activeSessions.set(sessionId, [ sessionHandle, inputNamesUTF8Encoded, diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts index 22cd6ec30732c..145e5702a3dd6 100644 --- a/js/web/lib/wasm/wasm-training-core-impl.ts +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -59,7 +59,9 @@ export const createCheckpointHandle = (checkpointData: SerializableInternalBuffe throw e; } finally { // free buffer from wasm heap - wasm._OrtFree(checkpointData[0]); + if (wasm._OrtFree(checkpointData[0]) !== 0) { + checkLastError('Error occurred when trying to free the checkpoint buffer'); + } } }; @@ -67,16 +69,18 @@ const getModelInputOutputCount = (trainingSessionId: number, isEvalModel: boolea const wasm = getInstance(); const stack = wasm.stackSave(); try { - const dataOffset = wasm.stackAlloc(8); + const ptrSize = wasm.PTR_SIZE; + const dataOffset = wasm.stackAlloc(2 * ptrSize); if (wasm._OrtTrainingGetModelInputOutputCount) { const errorCode = wasm._OrtTrainingGetModelInputOutputCount( trainingSessionId, dataOffset, - dataOffset + 4, + dataOffset + ptrSize, isEvalModel, ); ifErrCodeCheckLastError(errorCode, "Can't get session input/output count."); - return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]]; + const valueType = ptrSize === 
4 ? 'i32' : 'i64'; + return [Number(wasm.getValue(dataOffset, valueType)), Number(wasm.getValue(dataOffset + ptrSize, valueType))]; } else { throw new Error(NO_TRAIN_FUNCS_MSG); } @@ -163,7 +167,9 @@ export const createTrainingSessionHandle = ( wasm._free(optimizerModelData[0]); if (sessionOptionsHandle !== 0) { - wasm._OrtReleaseSessionOptions(sessionOptionsHandle); + if (wasm._OrtReleaseSessionOptions(sessionOptionsHandle) !== 0) { + checkLastError('Error occurred when trying to release the session options'); + } } allocs.forEach((alloc) => wasm._free(alloc)); } @@ -198,10 +204,10 @@ const createAndAllocateTensors = ( // moves to heap const wasm = getInstance(); - const valuesOffset = wasm.stackAlloc(count * 4); - let valuesIndex = valuesOffset / 4; + const ptrSize = wasm.PTR_SIZE; + const valuesOffset = wasm.stackAlloc(count * ptrSize); for (let i = 0; i < count; i++) { - wasm.HEAPU32[valuesIndex++] = tensorHandles[i]; + wasm.setValue(valuesOffset + i * ptrSize, tensorHandles[i], '*'); } return valuesOffset; @@ -222,10 +228,11 @@ const moveOutputToTensorMetadataArr = ( outputTensors: Array, ) => { const wasm = getInstance(); + const ptrSize = wasm.PTR_SIZE; const output: TensorMetadata[] = []; for (let i = 0; i < outputCount; i++) { - const tensor = wasm.HEAPU32[outputValuesOffset / 4 + i]; + const tensor = wasm.getValue(outputValuesOffset + i * ptrSize, '*'); if (tensor === outputTensorHandles[i]) { // output tensor is pre-allocated. no need to copy data. output.push(outputTensors[i]!); @@ -247,27 +254,27 @@ const moveOutputToTensorMetadataArr = ( tensorDataOffset + 12, ); ifErrCodeCheckLastError(errorCode, `Can't access output tensor data on index ${i}.`); - - let tensorDataIndex = tensorDataOffset / 4; - const dataType = wasm.HEAPU32[tensorDataIndex++]; - dataOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsOffset = wasm.HEAPU32[tensorDataIndex++]; - const dimsLength = wasm.HEAPU32[tensorDataIndex++]; + const valueType = ptrSize === 4 ? 
'i32' : 'i64'; + const dataType = Number(wasm.getValue(tensorDataOffset, valueType)); + dataOffset = wasm.getValue(tensorDataOffset + ptrSize, '*'); + const dimsOffset = wasm.getValue(tensorDataOffset + 2 * ptrSize, '*'); + const dimsLength = Number(wasm.getValue(tensorDataOffset + 3 * ptrSize, valueType)); const dims = []; for (let i = 0; i < dimsLength; i++) { - dims.push(wasm.HEAPU32[dimsOffset / 4 + i]); + dims.push(Number(wasm.getValue(dimsOffset + i * ptrSize, valueType))); + } + if (wasm._OrtFree(dimsOffset) !== 0) { + checkLastError('Error occurred when trying to free the dims buffer'); } - wasm._OrtFree(dimsOffset); - const size = dims.reduce((a, b) => a * b, 1); type = tensorDataTypeEnumToString(dataType); if (type === 'string') { const stringData: string[] = []; - let dataIndex = dataOffset / 4; for (let i = 0; i < size; i++) { - const offset = wasm.HEAPU32[dataIndex++]; - const maxBytesToRead = i === size - 1 ? undefined : wasm.HEAPU32[dataIndex] - offset; + const offset = wasm.getValue(dataOffset + i * ptrSize, '*'); + const nextOffset = wasm.getValue(dataOffset + (i + 1) * ptrSize, '*'); + const maxBytesToRead = i === size - 1 ? 
undefined : nextOffset - offset; stringData.push(wasm.UTF8ToString(offset, maxBytesToRead)); } output.push([type, dims, stringData, 'cpu']); @@ -284,7 +291,9 @@ const moveOutputToTensorMetadataArr = ( if (type === 'string' && dataOffset) { wasm._free(dataOffset); } - wasm._OrtReleaseTensor(tensor); + if (wasm._OrtReleaseTensor(tensor) !== 0) { + checkLastError('Error occurred when trying to release the tensor'); + } } } @@ -482,14 +491,14 @@ export const runEvalStep = async ( export const getParametersSize = (trainingSessionId: number, trainableOnly: boolean): number => { const wasm = getInstance(); const stack = wasm.stackSave(); - + const ptrSize = wasm.PTR_SIZE; try { - const sizeOffset = wasm.stackAlloc(4); + const sizeOffset = wasm.stackAlloc(ptrSize); if (wasm._OrtTrainingGetParametersSize) { const errorCode = wasm._OrtTrainingGetParametersSize(trainingSessionId, sizeOffset, trainableOnly); ifErrCodeCheckLastError(errorCode, "Can't get parameters size"); - return wasm.HEAP32[sizeOffset / 4]; + return wasm.getValue(sizeOffset, '*'); } else { throw new Error(NO_TRAIN_FUNCS_MSG); } @@ -520,7 +529,7 @@ export const getContiguousParameters = async ( const dimsOffset = wasm.stackAlloc(4); const dimsIndex = dimsOffset / 4; - wasm.HEAP32[dimsIndex] = parametersSize; + wasm.setValue(dimsOffset, parametersSize, '*'); try { // wraps allocated array in a tensor @@ -561,7 +570,9 @@ } } finally { if (tensor !== 0) { - wasm._OrtReleaseTensor(tensor); + if (wasm._OrtReleaseTensor(tensor) !== 0) { + checkLastError('Error occurred when trying to release the tensor'); + } } wasm._free(paramsOffset); wasm._free(dimsOffset); @@ -587,8 +598,8 @@ export const loadParametersBuffer = async ( wasm.HEAPU8.set(buffer, bufferOffset); // allocates and handles moving dimensions information to WASM memory - const dimsOffset = wasm.stackAlloc(4); - wasm.HEAP32[dimsOffset / 4] = bufferCount; + const dimsOffset = wasm.stackAlloc(wasm.PTR_SIZE); + 
wasm.setValue(dimsOffset, bufferCount, '*'); const dimsLength = 1; let tensor = 0; @@ -611,7 +622,9 @@ export const loadParametersBuffer = async ( } } finally { if (tensor !== 0) { - wasm._OrtReleaseTensor(tensor); + if (wasm._OrtReleaseTensor(tensor) !== 0) { + checkLastError('Error occurred when trying to release the tensor'); + } } wasm.stackRestore(stack); wasm._free(bufferOffset); diff --git a/js/web/lib/wasm/wasm-types.ts b/js/web/lib/wasm/wasm-types.ts index 70b6cceab0eef..6c2ee06314eb0 100644 --- a/js/web/lib/wasm/wasm-types.ts +++ b/js/web/lib/wasm/wasm-types.ts @@ -140,15 +140,15 @@ export declare namespace JSEP { export interface OrtInferenceAPIs { _OrtInit(numThreads: number, loggingLevel: number): number; - _OrtGetLastError(errorCodeOffset: number, errorMessageOffset: number): void; + _OrtGetLastError(errorCodeOffset: number, errorMessageOffset: number): number; _OrtCreateSession(dataOffset: number, dataLength: number, sessionOptionsHandle: number): Promise; - _OrtReleaseSession(sessionHandle: number): void; + _OrtReleaseSession(sessionHandle: number): number; _OrtGetInputOutputCount(sessionHandle: number, inputCountOffset: number, outputCountOffset: number): number; _OrtGetInputName(sessionHandle: number, index: number): number; _OrtGetOutputName(sessionHandle: number, index: number): number; - _OrtFree(stringHandle: number): void; + _OrtFree(stringHandle: number): number; _OrtCreateTensor( dataType: number, @@ -165,12 +165,12 @@ export interface OrtInferenceAPIs { dimsOffset: number, dimsLength: number, ): number; - _OrtReleaseTensor(tensorHandle: number): void; + _OrtReleaseTensor(tensorHandle: number): number; _OrtCreateBinding(sessionHandle: number): number; _OrtBindInput(bindingHandle: number, nameOffset: number, tensorHandle: number): Promise; _OrtBindOutput(bindingHandle: number, nameOffset: number, tensorHandle: number, location: number): number; - _OrtClearBoundOutputs(ioBindingHandle: number): void; - _OrtReleaseBinding(ioBindingHandle: 
number): void; + _OrtClearBoundOutputs(ioBindingHandle: number): number; + _OrtReleaseBinding(ioBindingHandle: number): number; _OrtRunWithBinding( sessionHandle: number, ioBindingHandle: number, @@ -204,11 +204,11 @@ export interface OrtInferenceAPIs { _OrtAppendExecutionProvider(sessionOptionsHandle: number, name: number): number; _OrtAddFreeDimensionOverride(sessionOptionsHandle: number, name: number, dim: number): number; _OrtAddSessionConfigEntry(sessionOptionsHandle: number, configKey: number, configValue: number): number; - _OrtReleaseSessionOptions(sessionOptionsHandle: number): void; + _OrtReleaseSessionOptions(sessionOptionsHandle: number): number; _OrtCreateRunOptions(logSeverityLevel: number, logVerbosityLevel: number, terminate: boolean, tag: number): number; _OrtAddRunConfigEntry(runOptionsHandle: number, configKey: number, configValue: number): number; - _OrtReleaseRunOptions(runOptionsHandle: number): void; + _OrtReleaseRunOptions(runOptionsHandle: number): number; _OrtEndProfiling(sessionHandle: number): number; } @@ -216,7 +216,7 @@ export interface OrtInferenceAPIs { export interface OrtTrainingAPIs { _OrtTrainingLoadCheckpoint(dataOffset: number, dataLength: number): number; - _OrtTrainingReleaseCheckpoint(checkpointHandle: number): void; + _OrtTrainingReleaseCheckpoint(checkpointHandle: number): number; _OrtTrainingCreateSession( sessionOptionsHandle: number, @@ -280,7 +280,7 @@ export interface OrtTrainingAPIs { isEvalModel: boolean, ): number; - _OrtTrainingReleaseSession(trainingHandle: number): void; + _OrtTrainingReleaseSession(trainingHandle: number): number; } /** @@ -291,10 +291,13 @@ export interface OrtWasmModule OrtInferenceAPIs, Partial, Partial { + PTR_SIZE: number; // #region emscripten functions stackSave(): number; stackRestore(stack: number): void; stackAlloc(size: number): number; + getValue(ptr: number, type: string): number; + setValue(ptr: number, value: number, type: string): void; UTF8ToString(offset: number, 
maxBytesToRead?: number): string; lengthBytesUTF8(str: string): number; diff --git a/js/web/lib/wasm/wasm-utils.ts index a820fd216ee03..9ce39c366dc77 100644 --- a/js/web/lib/wasm/wasm-utils.ts +++ b/js/web/lib/wasm/wasm-utils.ts @@ -55,10 +55,11 @@ export const checkLastError = (message: string): void => { const stack = wasm.stackSave(); try { - const paramsOffset = wasm.stackAlloc(8); - wasm._OrtGetLastError(paramsOffset, paramsOffset + 4); - const errorCode = wasm.HEAP32[paramsOffset / 4]; - const errorMessagePointer = wasm.HEAPU32[paramsOffset / 4 + 1]; + const ptrSize = wasm.PTR_SIZE; + const paramsOffset = wasm.stackAlloc(2 * ptrSize); + wasm._OrtGetLastError(paramsOffset, paramsOffset + ptrSize); + const errorCode = Number(wasm.getValue(paramsOffset, 'i32')); + const errorMessagePointer = wasm.getValue(paramsOffset + ptrSize, '*'); const errorMessage = errorMessagePointer ? wasm.UTF8ToString(errorMessagePointer) : ''; throw new Error(`${message} ERROR_CODE: ${errorCode}, ERROR_MESSAGE: ${errorMessage}`); } finally { diff --git a/onnxruntime/core/framework/tensorprotoutils.cc index 42f491825462c..6257e1b8bb955 100644 --- a/onnxruntime/core/framework/tensorprotoutils.cc +++ b/onnxruntime/core/framework/tensorprotoutils.cc @@ -1036,9 +1036,9 @@ Status GetExtDataFromTensorProto(const Env& env, const std::filesystem::path& mo if (!fileData) { return 2; // File not found in preloaded files. } - const offset = $1 >>> 0; - const length = $2 >>> 0; - const buffer = $3 >>> 0; + const offset = typeof $1 === 'bigint' ? Number($1) : $1 >>> 0; + const length = typeof $2 === 'bigint' ? Number($2) : $2 >>> 0; + const buffer = typeof $3 === 'bigint' ? Number($3) : $3 >>> 0; if (offset + length > fileData.byteLength) { return 3; // Out of bounds. 
diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index d38c1ace7d7a8..f885f249bfd1e 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -552,8 +552,8 @@ static Status SaveModel(Model& model, const T& file_path) { model_proto.SerializeToArray(buffer, buffer_size); EM_ASM(({ - const buffer = $0; - const buffer_size = $1; + const buffer = Number($0); + const buffer_size = Number($1); const file_path = UTF8ToString($2); const bytes = new Uint8Array(buffer_size); bytes.set(HEAPU8.subarray(buffer, buffer + buffer_size)); diff --git a/onnxruntime/core/providers/js/data_transfer.cc b/onnxruntime/core/providers/js/data_transfer.cc index ebea041b80128..3809df2c82e4c 100644 --- a/onnxruntime/core/providers/js/data_transfer.cc +++ b/onnxruntime/core/providers/js/data_transfer.cc @@ -6,7 +6,7 @@ #include "core/providers/js/data_transfer.h" EM_ASYNC_JS(void, jsepDownload, (const void* src_data, void* dst_data, size_t bytes), { - await Module.jsepCopyAsync(src_data, dst_data, bytes); + await Module.jsepCopyAsync(Number(src_data), Number(dst_data), Number(bytes)); }); namespace onnxruntime { @@ -30,10 +30,10 @@ common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const { if (dst_device.Type() == OrtDevice::GPU) { if (src_device.Type() == OrtDevice::GPU) { // copy from GPU to GPU - EM_ASM({ Module.jsepCopy($0, $1, $2, true); }, src_data, dst_data, bytes); + EM_ASM({ Module.jsepCopy(Number($0), Number($1), Number($2), true); }, src_data, dst_data, bytes); } else { // copy from CPU to GPU - EM_ASM({ Module.jsepCopy($0, $1, $2); }, src_data, dst_data, bytes); + EM_ASM({ Module.jsepCopy(Number($0), Number($1), Number($2)); }, src_data, dst_data, bytes); } } else /* if (src_device.Type() == OrtDevice::GPU) */ { // copy from GPU to CPU diff --git a/onnxruntime/core/providers/js/js_export.cc b/onnxruntime/core/providers/js/js_export.cc index 2402bb33ce9d0..f99e90bcb13f6 100644 --- 
a/onnxruntime/core/providers/js/js_export.cc +++ b/onnxruntime/core/providers/js/js_export.cc @@ -6,8 +6,8 @@ #include "core/framework/op_kernel.h" const void* JsepOutput(void* context, int index, const void* data) { - const uint32_t* data_offset = reinterpret_cast(data); - uint32_t dim = *data_offset++; + const uintptr_t* data_offset = reinterpret_cast(data); + uintptr_t dim = *data_offset++; size_t dim_size = static_cast(dim); std::vector dims(dim_size); for (size_t i = 0; i < dim_size; i++) { diff --git a/onnxruntime/core/providers/js/js_kernel.h b/onnxruntime/core/providers/js/js_kernel.h index 7324b0d69474c..68d89c96d96f7 100644 --- a/onnxruntime/core/providers/js/js_kernel.h +++ b/onnxruntime/core/providers/js/js_kernel.h @@ -110,16 +110,17 @@ class JsKernel : public OpKernel { temp_data_size += sizeof(size_t) * 3; } } - uint32_t* p_serialized_kernel_context = reinterpret_cast(alloc->Alloc(temp_data_size)); + uintptr_t* p_serialized_kernel_context = reinterpret_cast(alloc->Alloc(temp_data_size)); if (p_serialized_kernel_context == nullptr) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to allocate memory for serialized kernel context."); } - p_serialized_kernel_context[0] = reinterpret_cast(context); - p_serialized_kernel_context[1] = static_cast(context->InputCount()); - p_serialized_kernel_context[2] = static_cast(context->OutputCount()); - p_serialized_kernel_context[3] = reinterpret_cast(custom_data_ptr); - p_serialized_kernel_context[4] = static_cast(custom_data_size); + p_serialized_kernel_context[0] = reinterpret_cast(context); + p_serialized_kernel_context[1] = static_cast(context->InputCount()); + p_serialized_kernel_context[2] = static_cast(context->OutputCount()); + p_serialized_kernel_context[3] = reinterpret_cast(custom_data_ptr); + p_serialized_kernel_context[4] = static_cast(custom_data_size); + size_t index = 5; for (int i = 0; i < context->InputCount(); i++) { const auto* input_ptr = context->Input(i); @@ -130,11 +131,11 @@ class 
JsKernel : public OpKernel { p_serialized_kernel_context[index++] = 0; continue; } - p_serialized_kernel_context[index++] = static_cast(input_ptr->GetElementType()); - p_serialized_kernel_context[index++] = reinterpret_cast(input_ptr->DataRaw()); - p_serialized_kernel_context[index++] = static_cast(input_ptr->Shape().NumDimensions()); + p_serialized_kernel_context[index++] = static_cast(input_ptr->GetElementType()); + p_serialized_kernel_context[index++] = reinterpret_cast(input_ptr->DataRaw()); + p_serialized_kernel_context[index++] = static_cast(input_ptr->Shape().NumDimensions()); for (size_t d = 0; d < input_ptr->Shape().NumDimensions(); d++) { - p_serialized_kernel_context[index++] = static_cast(input_ptr->Shape()[d]); + p_serialized_kernel_context[index++] = static_cast(input_ptr->Shape()[d]); } } @@ -199,9 +200,9 @@ class JsKernel : public OpKernel { return status; } - int status_code = EM_ASM_INT( - { return Module.jsepRunKernel($0, $1, Module.jsepSessionState.sessionHandle, Module.jsepSessionState.errors); }, - this, reinterpret_cast(p_serialized_kernel_context)); + intptr_t status_code = EM_ASM_INT( + { return Module.jsepRunKernel(Number($0), Number($1), Module.jsepSessionState.sessionHandle, Module.jsepSessionState.errors); }, + this, reinterpret_cast(p_serialized_kernel_context)); LOGS_DEFAULT(VERBOSE) << "outputs = " << context->OutputCount() << ". Y.data=" << (size_t)(context->Output(0)->DataRaw()) << "."; diff --git a/onnxruntime/core/providers/js/operators/conv.h b/onnxruntime/core/providers/js/operators/conv.h index 0357c2f02a7a2..b04df44954295 100644 --- a/onnxruntime/core/providers/js/operators/conv.h +++ b/onnxruntime/core/providers/js/operators/conv.h @@ -51,14 +51,14 @@ class ConvBase : public JsKernel { JSEP_INIT_KERNEL_ATTRIBUTE(Conv, ({ "format" : $11 ? "NHWC" : "NCHW", "auto_pad" : $1, - "dilations" : $2 ? Array.from(HEAP32.subarray($2, $3)) : [], + "dilations" : $2 ? 
Array.from(HEAP32.subarray(Number($2), Number($3))) : [], "group" : $4, - "kernel_shape" : $5 ? Array.from(HEAP32.subarray($5, $6)) : [], - "pads" : $7 ? Array.from(HEAP32.subarray($7, $8)) : [], - "strides" : $9 ? Array.from(HEAP32.subarray($9, $10)) : [], - "w_is_const" : () JS_ARROW(!!HEAP8[$12]), + "kernel_shape" : $5 ? Array.from(HEAP32.subarray(Number($5), Number($6))) : [], + "pads" : $7 ? Array.from(HEAP32.subarray(Number($7), Number($8))) : [], + "strides" : $9 ? Array.from(HEAP32.subarray(Number($9), Number($10))) : [], + "w_is_const" : () JS_ARROW(!!HEAP8[Number($12)]), "activation" : UTF8ToString($13), - "activation_params" : $14 ? Array.from(HEAPF32.subarray($14, $15)) : [] + "activation_params" : $14 ? Array.from(HEAPF32.subarray(Number($14), Number($15))) : [] }), static_cast(conv_attrs_.auto_pad), JSEP_HEAP32_INDEX_START(dilations), diff --git a/onnxruntime/core/providers/js/operators/conv_transpose.h b/onnxruntime/core/providers/js/operators/conv_transpose.h index c51bf5ce9d4a6..5ff52e8fda4fa 100644 --- a/onnxruntime/core/providers/js/operators/conv_transpose.h +++ b/onnxruntime/core/providers/js/operators/conv_transpose.h @@ -48,8 +48,8 @@ class ConvTranspose : public JsKernel { "pads" : [ $5, $6 ], "strides" : [$7], "wIsConst" : () JS_ARROW(!!HEAP8[$9]), - "outputPadding" : $10 ? Array.from(HEAP32.subarray($10, $11)) : [], - "outputShape" : $12 ? Array.from(HEAP32.subarray($12, $13)) : [], + "outputPadding" : $10 ? Array.from(HEAP32.subarray(Number($10), Number($11))) : [], + "outputShape" : $12 ? Array.from(HEAP32.subarray(Number($12), Number($13))) : [], "activation" : UTF8ToString($14) }), static_cast(conv_transpose_attrs_.auto_pad), @@ -99,14 +99,14 @@ class ConvTranspose : public JsKernel { JSEP_INIT_KERNEL_ATTRIBUTE(ConvTranspose, ({ "format" : $7 ? 
"NHWC" : "NCHW", "autoPad" : $1, - "dilations" : Array.from(HEAP32.subarray($2, ($2 >>> 0) + /* dialations_vec_size */ 2)), + "dilations" : Array.from(HEAP32.subarray(Number($2), (Number($2) >>> 0) + /* dialations_vec_size */ 2)), "group" : $3, - "kernelShape" : Array.from(HEAP32.subarray($4, ($4 >>> 0) + /* kernel_shape_vec_size */ 2)), - "pads" : Array.from(HEAP32.subarray($5, ($5 >>> 0) + /* pads_vec_size */ 4)), - "strides" : Array.from(HEAP32.subarray($6, ($6 >>> 0) + /* strides_vec_size */ 2)), + "kernelShape" : Array.from(HEAP32.subarray(Number($4), (Number($4) >>> 0) + /* kernel_shape_vec_size */ 2)), + "pads" : Array.from(HEAP32.subarray(Number($5), (Number($5) >>> 0) + /* pads_vec_size */ 4)), + "strides" : Array.from(HEAP32.subarray(Number($6), (Number($6) >>> 0) + /* strides_vec_size */ 2)), "wIsConst" : () JS_ARROW(!!HEAP8[$8]), - "outputPadding" : $9 ? Array.from(HEAP32.subarray($9, $10)) : [], - "outputShape" : $11 ? Array.from(HEAP32.subarray($11, $12)) : [], + "outputPadding" : $9 ? Array.from(HEAP32.subarray(Number($9), Number($10))) : [], + "outputShape" : $11 ? 
Array.from(HEAP32.subarray(Number($11), Number($12))) : [], "activation" : UTF8ToString($13) }), static_cast(conv_transpose_attrs_.auto_pad), diff --git a/onnxruntime/core/providers/js/operators/gather.cc b/onnxruntime/core/providers/js/operators/gather.cc index 485cd3da9b91b..e9c6f5c79294f 100644 --- a/onnxruntime/core/providers/js/operators/gather.cc +++ b/onnxruntime/core/providers/js/operators/gather.cc @@ -15,11 +15,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 10, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", BuildKernelDefConstraintsFromTypeList>()) + .TypeConstraint("T", JsepSupportedDataTypes()) .TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList>()), Gather); @@ -30,11 +26,7 @@ ONNX_OPERATOR_VERSIONED_KERNEL_EX( 12, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", BuildKernelDefConstraintsFromTypeList>()) + .TypeConstraint("T", JsepSupportedDataTypes()) .TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList>()), Gather); @@ -44,11 +36,7 @@ ONNX_OPERATOR_KERNEL_EX( 13, kJsExecutionProvider, (*KernelDefBuilder::Create()) - .TypeConstraint("T", BuildKernelDefConstraintsFromTypeList>()) + .TypeConstraint("T", JsepSupportedDataTypes()) .TypeConstraint("Tind", BuildKernelDefConstraintsFromTypeList>()), Gather); diff --git a/onnxruntime/core/providers/js/operators/pad.h b/onnxruntime/core/providers/js/operators/pad.h index c18c7dd456dc2..f656462285bc4 100644 --- a/onnxruntime/core/providers/js/operators/pad.h +++ b/onnxruntime/core/providers/js/operators/pad.h @@ -22,7 +22,7 @@ class Pad : public JsKernel, public PadBase { JSEP_INIT_KERNEL_ATTRIBUTE(Pad, ({"mode" : $1, "value" : $2, - "pads" : $3 ? Array.from(HEAP32.subarray($3, $4)) : []}), + "pads" : $3 ? 
Array.from(HEAP32.subarray(Number($3), Number($4))) : []}), static_cast(mode_), static_cast(value_), JSEP_HEAP32_INDEX_START(pads), diff --git a/onnxruntime/core/providers/js/operators/pool.h b/onnxruntime/core/providers/js/operators/pool.h index 66bcde86020b6..32556eeaeefe4 100644 --- a/onnxruntime/core/providers/js/operators/pool.h +++ b/onnxruntime/core/providers/js/operators/pool.h @@ -3,22 +3,22 @@ #pragma once -#include "core/providers/js/js_kernel.h" #include "core/providers/cpu/nn/pool_base.h" +#include "core/providers/js/js_kernel.h" namespace onnxruntime { namespace js { -#define POOL_ATTRIBUTES_JS_OBJ_MAPPING ({ \ - "format" : $13 ? "NHWC" : "NCHW", \ - "auto_pad" : $1, \ - "ceil_mode" : $2, \ - "count_include_pad" : $3, \ - "storage_order" : $4, \ - "dilations" : $5 ? Array.from(HEAP32.subarray($5, $6)) : [], \ - "kernel_shape" : $7 ? Array.from(HEAP32.subarray($7, $8)) : [], \ - "pads" : $9 ? Array.from(HEAP32.subarray($9, $10)) : [], \ - "strides" : $11 ? Array.from(HEAP32.subarray($11, $12)) : [] \ +#define POOL_ATTRIBUTES_JS_OBJ_MAPPING ({ \ + "format" : $13 ? "NHWC" : "NCHW", \ + "auto_pad" : $1, \ + "ceil_mode" : $2, \ + "count_include_pad" : $3, \ + "storage_order" : $4, \ + "dilations" : $5 ? Array.from(HEAP32.subarray(Number($5), Number($6))) : [], \ + "kernel_shape" : $7 ? Array.from(HEAP32.subarray(Number($7), Number($8))) : [], \ + "pads" : $9 ? Array.from(HEAP32.subarray(Number($9), Number($10))) : [], \ + "strides" : $11 ? 
Array.from(HEAP32.subarray(Number($11), Number($12))) : [] \ }) #define POOL_ATTRIBUTES_PARAM_LIST \ diff --git a/onnxruntime/core/providers/js/operators/reduce.h b/onnxruntime/core/providers/js/operators/reduce.h index 937f1f990dc67..4ae558f9dfc00 100644 --- a/onnxruntime/core/providers/js/operators/reduce.h +++ b/onnxruntime/core/providers/js/operators/reduce.h @@ -8,29 +8,29 @@ namespace onnxruntime { namespace js { -#define JSEP_DEFINE_REDUCE_KERNEL(ReduceKernel) \ - template \ - class ReduceKernel : public JsKernel, public ReduceKernelBase { \ - public: \ - using ReduceKernelBase::axes_; \ - using ReduceKernelBase::noop_with_empty_axes_; \ - using ReduceKernelBase::keepdims_; \ - ReduceKernel(const OpKernelInfo& info) : JsKernel(info), ReduceKernelBase(info) { \ - std::vector axes(axes_.size()); \ - if (axes_.size() > 0) { \ - std::transform(axes_.begin(), axes_.end(), axes.begin(), \ - [](int64_t axis) { return gsl::narrow_cast(axis); }); \ - } \ - JSEP_INIT_KERNEL_ATTRIBUTE(ReduceKernel, ({ \ - "keepDims" : !!$1, \ - "noopWithEmptyAxes" : !!$2, \ - "axes" : $3 ? (Array.from(HEAP32.subarray($3, $4))) : [], \ - }), \ - static_cast(keepdims_), \ - static_cast(noop_with_empty_axes_), \ - JSEP_HEAP32_INDEX_START(axes), \ - JSEP_HEAP32_INDEX_END(axes)); \ - } \ +#define JSEP_DEFINE_REDUCE_KERNEL(ReduceKernel) \ + template \ + class ReduceKernel : public JsKernel, public ReduceKernelBase { \ + public: \ + using ReduceKernelBase::axes_; \ + using ReduceKernelBase::noop_with_empty_axes_; \ + using ReduceKernelBase::keepdims_; \ + ReduceKernel(const OpKernelInfo& info) : JsKernel(info), ReduceKernelBase(info) { \ + std::vector axes(axes_.size()); \ + if (axes_.size() > 0) { \ + std::transform(axes_.begin(), axes_.end(), axes.begin(), \ + [](int64_t axis) { return gsl::narrow_cast(axis); }); \ + } \ + JSEP_INIT_KERNEL_ATTRIBUTE(ReduceKernel, ({ \ + "keepDims" : !!$1, \ + "noopWithEmptyAxes" : !!$2, \ + "axes" : $3 ? 
(Array.from(HEAP32.subarray(Number($3), Number($4)))) : [], \ + }), \ + static_cast(keepdims_), \ + static_cast(noop_with_empty_axes_), \ + JSEP_HEAP32_INDEX_START(axes), \ + JSEP_HEAP32_INDEX_END(axes)); \ + } \ }; JSEP_DEFINE_REDUCE_KERNEL(ReduceMax); diff --git a/onnxruntime/core/providers/js/operators/resize.h b/onnxruntime/core/providers/js/operators/resize.h index 134eb4bf5a7f4..3e8ccf40753c8 100644 --- a/onnxruntime/core/providers/js/operators/resize.h +++ b/onnxruntime/core/providers/js/operators/resize.h @@ -23,7 +23,7 @@ class Resize : public JsKernel, public UpsampleBase { std::transform(axes_.begin(), axes_.end(), std::back_inserter(axes), [](auto& axis) { return gsl::narrow_cast(axis); }); JSEP_INIT_KERNEL_ATTRIBUTE(Resize, ({ "antialias" : $1, - "axes" : $2 ? Array.from(HEAP32.subarray($2, $3)) : [], + "axes" : $2 ? Array.from(HEAP32.subarray(Number($2), Number($3))) : [], "coordinateTransformMode" : UTF8ToString($4), "cubicCoeffA" : $5, "excludeOutside" : $6, diff --git a/onnxruntime/core/providers/js/operators/slice.h b/onnxruntime/core/providers/js/operators/slice.h index daeffaa664741..f30e7bf01ec7b 100644 --- a/onnxruntime/core/providers/js/operators/slice.h +++ b/onnxruntime/core/providers/js/operators/slice.h @@ -20,9 +20,9 @@ class Slice : public JsKernel, public SliceBase { std::vector starts(attr_starts.begin(), attr_starts.end()); std::vector ends(attr_ends.begin(), attr_ends.end()); - JSEP_INIT_KERNEL_ATTRIBUTE(Slice, ({"starts" : $1 ? Array.from(HEAP32.subarray($1, $2)) : [], - "ends" : $3 ? Array.from(HEAP32.subarray($3, $4)) : [], - "axes" : $5 ? Array.from(HEAP32.subarray($5, $6)) : []}), + JSEP_INIT_KERNEL_ATTRIBUTE(Slice, ({"starts" : $1 ? Array.from(HEAP32.subarray(Number($1), Number($2))) : [], + "ends" : $3 ? Array.from(HEAP32.subarray(Number($3), Number($4))) : [], + "axes" : $5 ? 
Array.from(HEAP32.subarray(Number($5), Number($6))) : []}), JSEP_HEAP32_INDEX_START(starts), JSEP_HEAP32_INDEX_END(starts), JSEP_HEAP32_INDEX_START(ends), diff --git a/onnxruntime/core/providers/js/operators/split.h b/onnxruntime/core/providers/js/operators/split.h index 4fdbab00e739c..3f6cfcb8921f3 100644 --- a/onnxruntime/core/providers/js/operators/split.h +++ b/onnxruntime/core/providers/js/operators/split.h @@ -49,7 +49,7 @@ class Split : public JsKernel, public SplitBase { JSEP_INIT_KERNEL_ATTRIBUTE(Split, ({"axis" : $1, "numOutputs" : $2, - "splitSizes" : $3 ? Array.from(HEAP32.subarray($3, $4)) : []}), + "splitSizes" : $3 ? Array.from(HEAP32.subarray(Number($3), Number($4))) : []}), static_cast(axis_), static_cast(num_outputs_), JSEP_HEAP32_INDEX_START(split_sizes), diff --git a/onnxruntime/core/providers/js/operators/transpose.h b/onnxruntime/core/providers/js/operators/transpose.h index 7a945471c7701..f6b2b4faba850 100644 --- a/onnxruntime/core/providers/js/operators/transpose.h +++ b/onnxruntime/core/providers/js/operators/transpose.h @@ -21,7 +21,7 @@ class Transpose final : public JsKernel, public TransposeBase { } } JSEP_INIT_KERNEL_ATTRIBUTE(Transpose, ({ - "perm" : $1 ? Array.from(HEAP32.subarray($1, $2)) : [] + "perm" : $1 ? 
Array.from(HEAP32.subarray(Number($1), Number($2))) : [] }), JSEP_HEAP32_INDEX_START(perm), JSEP_HEAP32_INDEX_END(perm)); diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index 0e58bb4f93f7f..c2dc1eb474816 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -27,7 +27,9 @@ enum DataLocation { }; static_assert(sizeof(const char*) == sizeof(size_t), "size of a pointer and a size_t value should be the same."); +#ifndef ORT_WASM64 static_assert(sizeof(size_t) == 4, "size of size_t should be 4 in this build (wasm32)."); +#endif OrtErrorCode CheckStatus(OrtStatusPtr status) { if (status) { @@ -93,9 +95,10 @@ int OrtInit(int num_threads, int logging_level) { #endif } -void OrtGetLastError(int* error_code, const char** error_message) { +int OrtGetLastError(int* error_code, const char** error_message) { *error_code = g_last_error_code; *error_message = g_last_error_message.empty() ? nullptr : g_last_error_message.c_str(); + return ORT_OK; } OrtSessionOptions* OrtCreateSessionOptions(size_t graph_optimization_level, @@ -176,8 +179,9 @@ int OrtAddSessionConfigEntry(OrtSessionOptions* session_options, return CHECK_STATUS(AddSessionConfigEntry, session_options, config_key, config_value); } -void OrtReleaseSessionOptions(OrtSessionOptions* session_options) { +int OrtReleaseSessionOptions(OrtSessionOptions* session_options) { Ort::GetApi().ReleaseSessionOptions(session_options); + return ORT_OK; } OrtSession* OrtCreateSession(void* data, size_t data_length, OrtSessionOptions* session_options) { @@ -195,8 +199,9 @@ OrtSession* OrtCreateSession(void* data, size_t data_length, OrtSessionOptions* : nullptr; } -void OrtReleaseSession(OrtSession* session) { +int OrtReleaseSession(OrtSession* session) { Ort::GetApi().ReleaseSession(session); + return ORT_OK; } int OrtGetInputOutputCount(OrtSession* session, size_t* input_count, size_t* output_count) { @@ -225,11 +230,12 @@ char* OrtGetOutputName(OrtSession* session, size_t index) { : nullptr; } -void 
OrtFree(void* ptr) { +int OrtFree(void* ptr) { OrtAllocator* allocator = nullptr; if (CHECK_STATUS(GetAllocatorWithDefaultOptions, &allocator) == ORT_OK) { allocator->Free(allocator, ptr); } + return ORT_OK; } OrtValue* OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* dims, size_t dims_length, int data_location) { @@ -280,7 +286,7 @@ OrtValue* OrtCreateTensor(int data_type, void* data, size_t data_length, size_t* } } -int OrtGetTensorData(OrtValue* tensor, int* data_type, void** data, size_t** dims, size_t* dims_length) { +int OrtGetTensorData(OrtValue* tensor, size_t* data_type, void** data, size_t** dims, size_t* dims_length) { ONNXType tensor_type; RETURN_ERROR_CODE_IF_ERROR(GetValueType, tensor, &tensor_type); if (tensor_type != ONNX_TYPE_TENSOR) { @@ -350,14 +356,15 @@ int OrtGetTensorData(OrtValue* tensor, int* data_type, void** data, size_t** dim *data = p_tensor_raw_data; } - *data_type = static_cast(type); + *data_type = static_cast(type); *dims_length = dims_len; *dims = UNREGISTER_AUTO_RELEASE(p_dims); return ORT_OK; } -void OrtReleaseTensor(OrtValue* tensor) { +int OrtReleaseTensor(OrtValue* tensor) { Ort::GetApi().ReleaseValue(tensor); + return ORT_OK; } OrtRunOptions* OrtCreateRunOptions(size_t log_severity_level, @@ -392,8 +399,9 @@ int OrtAddRunConfigEntry(OrtRunOptions* run_options, return CHECK_STATUS(AddRunConfigEntry, run_options, config_key, config_value); } -void OrtReleaseRunOptions(OrtRunOptions* run_options) { +int OrtReleaseRunOptions(OrtRunOptions* run_options) { Ort::GetApi().ReleaseRunOptions(run_options); + return ORT_OK; } OrtIoBinding* OrtCreateBinding(OrtSession* session) { @@ -435,12 +443,14 @@ int EMSCRIPTEN_KEEPALIVE OrtBindOutput(OrtIoBinding* io_binding, } } -void OrtClearBoundOutputs(OrtIoBinding* io_binding) { +int OrtClearBoundOutputs(OrtIoBinding* io_binding) { Ort::GetApi().ClearBoundOutputs(io_binding); + return ORT_OK; } -void OrtReleaseBinding(OrtIoBinding* io_binding) { +int 
OrtReleaseBinding(OrtIoBinding* io_binding) { Ort::GetApi().ReleaseIoBinding(io_binding); + return ORT_OK; } int OrtRunWithBinding(OrtSession* session, @@ -510,8 +520,9 @@ ort_training_checkpoint_handle_t EMSCRIPTEN_KEEPALIVE OrtTrainingLoadCheckpoint( : nullptr; } -void EMSCRIPTEN_KEEPALIVE OrtTrainingReleaseCheckpoint(ort_training_checkpoint_handle_t training_checkpoint_state_handle) { +int EMSCRIPTEN_KEEPALIVE OrtTrainingReleaseCheckpoint(ort_training_checkpoint_handle_t training_checkpoint_state_handle) { Ort::GetTrainingApi().ReleaseCheckpointState(training_checkpoint_state_handle); + return ORT_OK; } ort_training_session_handle_t EMSCRIPTEN_KEEPALIVE OrtTrainingCreateSession(const ort_session_options_handle_t options, @@ -630,8 +641,9 @@ char* EMSCRIPTEN_KEEPALIVE OrtTrainingGetModelInputOutputName(ort_training_sessi } } -void EMSCRIPTEN_KEEPALIVE OrtTrainingReleaseSession(ort_training_session_handle_t training_handle) { +int EMSCRIPTEN_KEEPALIVE OrtTrainingReleaseSession(ort_training_session_handle_t training_handle) { Ort::GetTrainingApi().ReleaseTrainingSession(training_handle); + return ORT_OK; } #endif diff --git a/onnxruntime/wasm/api.h b/onnxruntime/wasm/api.h index 0730559c4375b..f44c515d98f6b 100644 --- a/onnxruntime/wasm/api.h +++ b/onnxruntime/wasm/api.h @@ -50,7 +50,7 @@ int EMSCRIPTEN_KEEPALIVE OrtInit(int num_threads, int logging_level); * @param error_code [out] a pointer to accept the error code. * @param error_message [out] a pointer to accept the error message. The message buffer is only available before any ORT API is called. */ -void EMSCRIPTEN_KEEPALIVE OrtGetLastError(int* error_code, const char** error_message); +int EMSCRIPTEN_KEEPALIVE OrtGetLastError(int* error_code, const char** error_message); /** * create an instance of ORT session options. @@ -109,7 +109,7 @@ int EMSCRIPTEN_KEEPALIVE OrtAddSessionConfigEntry(ort_session_options_handle_t s /** * release the specified ORT session options. 
*/ -void EMSCRIPTEN_KEEPALIVE OrtReleaseSessionOptions(ort_session_options_handle_t session_options); +int EMSCRIPTEN_KEEPALIVE OrtReleaseSessionOptions(ort_session_options_handle_t session_options); /** * create an instance of ORT session. @@ -124,7 +124,7 @@ ort_session_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateSession(void* data, /** * release the specified ORT session. */ -void EMSCRIPTEN_KEEPALIVE OrtReleaseSession(ort_session_handle_t session); +int EMSCRIPTEN_KEEPALIVE OrtReleaseSession(ort_session_handle_t session); /** * get model's input count and output count. @@ -158,7 +158,7 @@ char* EMSCRIPTEN_KEEPALIVE OrtGetOutputName(ort_session_handle_t session, size_t * free the specified buffer. * @param ptr a pointer to the buffer. */ -void EMSCRIPTEN_KEEPALIVE OrtFree(void* ptr); +int EMSCRIPTEN_KEEPALIVE OrtFree(void* ptr); /** * create an instance of ORT tensor. @@ -183,12 +183,12 @@ ort_tensor_handle_t EMSCRIPTEN_KEEPALIVE OrtCreateTensor(int data_type, void* da * 'dims' (for all types of tensor), 'data' (only for string tensor) * @returns ORT error code. If not zero, call OrtGetLastError() to get detailed error message. */ -int EMSCRIPTEN_KEEPALIVE OrtGetTensorData(ort_tensor_handle_t tensor, int* data_type, void** data, size_t** dims, size_t* dims_length); +int EMSCRIPTEN_KEEPALIVE OrtGetTensorData(ort_tensor_handle_t tensor, size_t* data_type, void** data, size_t** dims, size_t* dims_length); /** * release the specified tensor. */ -void EMSCRIPTEN_KEEPALIVE OrtReleaseTensor(ort_tensor_handle_t tensor); +int EMSCRIPTEN_KEEPALIVE OrtReleaseTensor(ort_tensor_handle_t tensor); /** * create an instance of ORT run options. @@ -218,7 +218,7 @@ int EMSCRIPTEN_KEEPALIVE OrtAddRunConfigEntry(ort_run_options_handle_t run_optio /** * release the specified ORT run options. 
*/ -void EMSCRIPTEN_KEEPALIVE OrtReleaseRunOptions(ort_run_options_handle_t run_options); +int EMSCRIPTEN_KEEPALIVE OrtReleaseRunOptions(ort_run_options_handle_t run_options); /** * create an instance of ORT IO binding. @@ -252,12 +252,12 @@ int EMSCRIPTEN_KEEPALIVE OrtBindOutput(ort_io_binding_handle_t io_binding, /** * clear all bound outputs. */ -void EMSCRIPTEN_KEEPALIVE OrtClearBoundOutputs(ort_io_binding_handle_t io_binding); +int EMSCRIPTEN_KEEPALIVE OrtClearBoundOutputs(ort_io_binding_handle_t io_binding); /** * release the specified ORT IO binding. */ -void EMSCRIPTEN_KEEPALIVE OrtReleaseBinding(ort_io_binding_handle_t io_binding); +int EMSCRIPTEN_KEEPALIVE OrtReleaseBinding(ort_io_binding_handle_t io_binding); /** * inference the model. @@ -311,7 +311,7 @@ ort_training_checkpoint_handle_t EMSCRIPTEN_KEEPALIVE OrtTrainingLoadCheckpoint( * * @param training_checkpoint_state_handle handle for the CheckpointState */ -void EMSCRIPTEN_KEEPALIVE OrtTrainingReleaseCheckpoint(ort_training_checkpoint_handle_t training_checkpoint_state_handle); +int EMSCRIPTEN_KEEPALIVE OrtTrainingReleaseCheckpoint(ort_training_checkpoint_handle_t training_checkpoint_state_handle); /** * Creates an instance of a training session that can be used to begin or resume training from a given checkpoint state @@ -466,7 +466,7 @@ char* EMSCRIPTEN_KEEPALIVE OrtTrainingGetModelInputOutputName(ort_training_sessi * * @param training_session_handle handle of the training session */ -void EMSCRIPTEN_KEEPALIVE OrtTrainingReleaseSession(ort_training_session_handle_t training_session_handle); +int EMSCRIPTEN_KEEPALIVE OrtTrainingReleaseSession(ort_training_session_handle_t training_session_handle); #endif }; diff --git a/onnxruntime/wasm/js_post_js.js b/onnxruntime/wasm/js_post_js.js new file mode 100644 index 0000000000000..b77d82fbd7d10 --- /dev/null +++ b/onnxruntime/wasm/js_post_js.js @@ -0,0 +1,7 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+ +// Licensed under the MIT License. + +'use strict'; + +Module["PTR_SIZE"] = 4; diff --git a/onnxruntime/wasm/js_post_js_64.js b/onnxruntime/wasm/js_post_js_64.js new file mode 100644 index 0000000000000..b140df927ebbd --- /dev/null +++ b/onnxruntime/wasm/js_post_js_64.js @@ -0,0 +1,7 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. + +// Licensed under the MIT License. + +'use strict'; + +Module["PTR_SIZE"] = 8; diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 6489babc562e8..1833a0a39dcf3 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -398,7 +398,7 @@ def convert_arg_line_to_args(self, arg_line): help="Build with a specific GDK edition. Defaults to the latest installed.", ) parser.add_argument("--gdk_platform", default="Scarlett", help="Sets the GDK target platform.") - + parser.add_argument("--enable_wasm_memory64", action="store_true", help="Enable WebAssembly 64bit support") platform_group = parser.add_mutually_exclusive_group() platform_group.add_argument("--ios", action="store_true", help="build for ios") platform_group.add_argument("--visionos", action="store_true", help="build for visionOS") @@ -1078,6 +1078,7 @@ def generate_build_tree( + ("ON" if args.enable_wasm_exception_throwing_override else "OFF"), "-Donnxruntime_WEBASSEMBLY_RUN_TESTS_IN_BROWSER=" + ("ON" if args.wasm_run_tests_in_browser else "OFF"), "-Donnxruntime_ENABLE_WEBASSEMBLY_THREADS=" + ("ON" if args.enable_wasm_threads else "OFF"), + "-Donnxruntime_ENABLE_WEBASSEMBLY_MEMORY64=" + ("ON" if args.enable_wasm_memory64 else "OFF"), "-Donnxruntime_ENABLE_WEBASSEMBLY_DEBUG_INFO=" + ("ON" if args.enable_wasm_debug_info else "OFF"), "-Donnxruntime_ENABLE_WEBASSEMBLY_PROFILING=" + ("ON" if args.enable_wasm_profiling else "OFF"), "-Donnxruntime_ENABLE_LAZY_TENSOR=" + ("ON" if args.enable_lazy_tensor else "OFF"),