Skip to content

Commit

Permalink
[JS/WebGPU] Support WASM64 (microsoft#21836)
Browse files Browse the repository at this point in the history
### Description
Support wasm64



### Motivation and Context
Overcome memory limitations

---------

Co-authored-by: Yulong Wang <[email protected]>
  • Loading branch information
2 people authored and ankitm3k committed Dec 11, 2024
1 parent d431773 commit 991e8fb
Show file tree
Hide file tree
Showing 32 changed files with 435 additions and 237 deletions.
1 change: 1 addition & 0 deletions cmake/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ option(onnxruntime_WEBASSEMBLY_RUN_TESTS_IN_BROWSER "Enable this option to run t
option(onnxruntime_ENABLE_WEBASSEMBLY_DEBUG_INFO "Enable this option to turn on DWARF format debug info" OFF)
option(onnxruntime_ENABLE_WEBASSEMBLY_PROFILING "Enable this option to turn on WebAssembly profiling and preserve function names" OFF)
option(onnxruntime_ENABLE_WEBASSEMBLY_OUTPUT_OPTIMIZED_MODEL "Enable this option to allow WebAssembly to output optimized model" OFF)
option(onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64 "Enable this option to allow WebAssembly to use 64bit memory" OFF)

# Enable bitcode for iOS
option(onnxruntime_ENABLE_BITCODE "Enable bitcode for iOS only" OFF)
Expand Down
5 changes: 5 additions & 0 deletions cmake/adjust_global_compile_flags.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ if (CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
string(APPEND CMAKE_CXX_FLAGS " -s DISABLE_EXCEPTION_CATCHING=0")
endif()

if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64)
string(APPEND CMAKE_C_FLAGS " -DORT_WASM64")
string(APPEND CMAKE_CXX_FLAGS " -DORT_WASM64")
endif()

# Build WebAssembly with multi-threads support.
if (onnxruntime_ENABLE_WEBASSEMBLY_THREADS)
string(APPEND CMAKE_C_FLAGS " -pthread -Wno-pthreads-mem-growth")
Expand Down
153 changes: 146 additions & 7 deletions cmake/onnxruntime_webassembly.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ target_compile_options(onnx PRIVATE -Wno-unused-parameter -Wno-unused-variable)

if (onnxruntime_BUILD_WEBASSEMBLY_STATIC_LIB)
bundle_static_library(onnxruntime_webassembly

${PROTOBUF_LIB}
onnx
onnx_proto
Expand Down Expand Up @@ -175,7 +174,6 @@ else()
endif()

target_link_libraries(onnxruntime_webassembly PRIVATE

${PROTOBUF_LIB}
onnx
onnx_proto
Expand All @@ -194,9 +192,7 @@ else()
onnxruntime_util
re2::re2
)

set(EXPORTED_RUNTIME_METHODS "'stackAlloc','stackRestore','stackSave','UTF8ToString','stringToUTF8','lengthBytesUTF8'")

set(EXPORTED_RUNTIME_METHODS "'stackAlloc','stackRestore','stackSave','UTF8ToString','stringToUTF8','lengthBytesUTF8','getValue','setValue'")
if (onnxruntime_USE_XNNPACK)
target_link_libraries(onnxruntime_webassembly PRIVATE XNNPACK)
string(APPEND EXPORTED_RUNTIME_METHODS ",'addFunction'")
Expand All @@ -217,10 +213,114 @@ else()
set(EXPORTED_FUNCTIONS "_malloc,_free")
endif()

if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64)
set(MAXIMUM_MEMORY "17179869184")
target_link_options(onnxruntime_webassembly PRIVATE
"SHELL:-s MEMORY64=1"
)
string(APPEND CMAKE_C_FLAGS " -sMEMORY64 -Wno-experimental")
string(APPEND CMAKE_CXX_FLAGS " -sMEMORY64 -Wno-experimental")
set(SMEMORY_FLAG "-sMEMORY64")

target_compile_options(onnx PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(onnxruntime_common PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(onnxruntime_session PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(onnxruntime_framework PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(nsync_cpp PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(onnx_proto PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
# target_compile_options(protoc PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(libprotobuf-lite PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(onnxruntime_providers PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(onnxruntime_optimizer PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(onnxruntime_mlas PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(onnxruntime_optimizer PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(onnxruntime_graph PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(onnxruntime_flatbuffers PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(onnxruntime_util PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(re2 PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_flags_private_handle_accessor PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_flags_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_flags_commandlineflag PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_flags_commandlineflag_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_flags_marshalling PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_flags_reflection PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_flags_config PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_flags_program_name PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_cord PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_cordz_info PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_cord_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_cordz_functions PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_cordz_handle PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_crc_cord_state PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_crc32c PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_crc_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_crc_cpu_detect PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_raw_hash_set PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_hashtablez_sampler PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_exponential_biased PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_log_internal_conditions PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_log_internal_check_op PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_log_internal_message PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_log_internal_format PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_str_format_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_log_internal_log_sink_set PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_log_internal_globals PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_log_sink PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_log_entry PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_log_globals PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_hash PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_city PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_low_level_hash PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_bad_variant_access PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_vlog_config_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_synchronization PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_kernel_timeout_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_time PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_time_zone PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_civil_time PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_graphcycles_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_bad_optional_access PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_log_internal_fnmatch PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_examine_stack PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_symbolize PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_malloc_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_demangle_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_demangle_rust PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_decode_rust_punycode PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_utf8_for_code_point PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_stacktrace PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_debugging_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_log_internal_proto PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_strerror PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_log_internal_nullguard PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_strings PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_strings_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_int128 PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_string_view PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_base PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_spinlock_wait PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_throw_delegate PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_raw_logging_internal PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(absl_log_severity PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
if (onnxruntime_USE_EXTENSIONS)
target_compile_options(ortcustomops PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(ocos_operators PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
target_compile_options(noexcep_operators PRIVATE ${SMEMORY_FLAG} -Wno-experimental)
endif()
target_link_options(onnxruntime_webassembly PRIVATE
--post-js "${ONNXRUNTIME_ROOT}/wasm/js_post_js_64.js"
)
else ()
set(MAXIMUM_MEMORY "4294967296")
target_link_options(onnxruntime_webassembly PRIVATE
--post-js "${ONNXRUNTIME_ROOT}/wasm/js_post_js.js"
)
endif ()

target_link_options(onnxruntime_webassembly PRIVATE
"SHELL:-s EXPORTED_RUNTIME_METHODS=[${EXPORTED_RUNTIME_METHODS}]"
"SHELL:-s EXPORTED_FUNCTIONS=${EXPORTED_FUNCTIONS}"
"SHELL:-s MAXIMUM_MEMORY=4294967296"
"SHELL:-s MAXIMUM_MEMORY=${MAXIMUM_MEMORY}"
"SHELL:-s EXIT_RUNTIME=0"
"SHELL:-s ALLOW_MEMORY_GROWTH=1"
"SHELL:-s MODULARIZE=1"
Expand All @@ -233,6 +333,41 @@ else()
--no-entry
"SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre.js\""
)
if (onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64)
set(SIGNATURE_CONVERSIONS "OrtRun:_pppppppp,\
OrtRunWithBinding:_ppppp,\
OrtGetTensorData:_ppppp,\
OrtCreateTensor:p_pppp_,\
OrtCreateSession:pppp,\
OrtReleaseSession:_p,\
OrtGetInputOutputCount:_ppp,\
OrtCreateSessionOptions:pp__p_ppppp,\
OrtReleaseSessionOptions:_p,\
OrtAppendExecutionProvider:_pp,\
OrtAddSessionConfigEntry:_ppp,\
OrtGetInputName:ppp,\
OrtGetOutputName:ppp,\
OrtCreateRunOptions:ppp_p,\
OrtReleaseRunOptions:_p,\
OrtReleaseTensor:_p,\
OrtFree:_p,\
OrtCreateBinding:_p,\
OrtBindInput:_ppp,\
OrtBindOutput:_ppp_,\
OrtClearBoundOutputs:_p,\
OrtReleaseBinding:_p,\
OrtGetLastError:_pp,\
JsepOutput:pp_p,\
JsepGetNodeName:pp,\
JsepOutput:pp_p,\
jsepCopy:_pp_,\
jsepCopyAsync:_pp_,\
jsepDownload:_pp_")
target_link_options(onnxruntime_webassembly PRIVATE
"SHELL:-s ERROR_ON_UNDEFINED_SYMBOLS=0"
"SHELL:-s SIGNATURE_CONVERSIONS='${SIGNATURE_CONVERSIONS}'"
)
endif ()
set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre.js)

if (onnxruntime_USE_JSEP)
Expand All @@ -245,6 +380,8 @@ else()
"SHELL:--pre-js \"${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js\""
"SHELL:-s ASYNCIFY=1"
"SHELL:-s ASYNCIFY_STACK_SIZE=65536"
"SHELL:-s ASYNCIFY_EXPORTS=['OrtRun']"
"SHELL:-s ASYNCIFY_IMPORTS=['Module.jsepCopy','Module.jsepCopyAsync','jsepDownload']"
)
set_target_properties(onnxruntime_webassembly PROPERTIES LINK_DEPENDS ${ONNXRUNTIME_ROOT}/wasm/pre-jsep.js)
endif()
Expand Down Expand Up @@ -281,7 +418,9 @@ else()
endif()

# Set link flag to enable exceptions support, this will override default disabling exception throwing behavior when disable exceptions.
target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s DISABLE_EXCEPTION_THROWING=0")
if (NOT onnxruntime_ENABLE_WEBASSEMBLY_MEMORY64)
target_link_options(onnxruntime_webassembly PRIVATE "SHELL:-s DISABLE_EXCEPTION_THROWING=0")
endif()

if (onnxruntime_ENABLE_WEBASSEMBLY_PROFILING)
target_link_options(onnxruntime_webassembly PRIVATE --profiling --profiling-funcs)
Expand Down
66 changes: 41 additions & 25 deletions js/web/lib/wasm/jsep/init.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,24 +87,25 @@ class ComputeContextImpl implements ComputeContext {
contextDataOffset: number,
) {
this.adapterInfo = backend.adapterInfo;
const heapU32 = module.HEAPU32;

// extract context data
let dataIndex = contextDataOffset >>> 2;
this.opKernelContext = heapU32[dataIndex++];
const inputCount = heapU32[dataIndex++];
this.outputCount = heapU32[dataIndex++];
this.customDataOffset = heapU32[dataIndex++];
this.customDataSize = heapU32[dataIndex++];
const ptrSize = module.PTR_SIZE;
let dataIndex = contextDataOffset / module.PTR_SIZE;
const type = ptrSize === 4 ? 'i32' : 'i64';
this.opKernelContext = Number(module.getValue(ptrSize * dataIndex++, type));
const inputCount = Number(module.getValue(ptrSize * dataIndex++, type));
this.outputCount = Number(module.getValue(ptrSize * dataIndex++, type));
this.customDataOffset = Number(module.getValue(ptrSize * dataIndex++, '*'));
this.customDataSize = Number(module.getValue(ptrSize * dataIndex++, type));

const inputs: TensorView[] = [];
for (let i = 0; i < inputCount; i++) {
const dataType = heapU32[dataIndex++];
const data = heapU32[dataIndex++];
const dim = heapU32[dataIndex++];
const dataType = Number(module.getValue(ptrSize * dataIndex++, type));
const data = Number(module.getValue(ptrSize * dataIndex++, '*'));
const dim = Number(module.getValue(ptrSize * dataIndex++, type));
const dims: number[] = [];
for (let d = 0; d < dim; d++) {
dims.push(heapU32[dataIndex++]);
dims.push(Number(module.getValue(ptrSize * dataIndex++, type)));
}
inputs.push(new TensorViewImpl(module, dataType, data, dims));
}
Expand Down Expand Up @@ -152,11 +153,12 @@ class ComputeContextImpl implements ComputeContext {
output(index: number, dims: readonly number[]): number {
const stack = this.module.stackSave();
try {
const data = this.module.stackAlloc((1 + dims.length) * 4 /* sizeof(size_t) */);
let offset = data >> 2;
this.module.HEAPU32[offset++] = dims.length;
const ptrSize = this.module.PTR_SIZE;
const type = ptrSize === 4 ? 'i32' : 'i64';
const data = this.module.stackAlloc((1 + dims.length) * ptrSize /* sizeof(size_t) */);
this.module.setValue(data, dims.length, type);
for (let i = 0; i < dims.length; i++) {
this.module.HEAPU32[offset++] = dims[i];
this.module.setValue(data + ptrSize * (i + 1), dims[i], type);
}
return this.module._JsepOutput!(this.opKernelContext, index, data);
} catch (e) {
Expand Down Expand Up @@ -215,20 +217,27 @@ export const init = async (
backend,

// jsepAlloc()
(size: number) => backend.alloc(size),
(size: number) => backend.alloc(Number(size)),

// jsepFree()
(ptr: number) => backend.free(ptr),

// jsepCopy(src, dst, size, isSourceGpu)
(src: number, dst: number, size: number, isSourceGpu = false) => {
if (isSourceGpu) {
LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyGpuToGpu: src=${src}, dst=${dst}, size=${size}`);
backend.memcpy(src, dst);
LOG_DEBUG(
'verbose',
() => `[WebGPU] jsepCopyGpuToGpu: src=${Number(src)}, dst=${Number(dst)}, size=${Number(size)}`,
);
backend.memcpy(Number(src), Number(dst));
} else {
LOG_DEBUG('verbose', () => `[WebGPU] jsepCopyCpuToGpu: dataOffset=${src}, gpuDataId=${dst}, size=${size}`);
const data = module.HEAPU8.subarray(src >>> 0, (src >>> 0) + size);
backend.upload(dst, data);
LOG_DEBUG(
'verbose',
() =>
`[WebGPU] jsepCopyCpuToGpu: dataOffset=${Number(src)}, gpuDataId=${Number(dst)}, size=${Number(size)}`,
);
const data = module.HEAPU8.subarray(Number(src >>> 0), Number(src >>> 0) + Number(size));
backend.upload(Number(dst), data);
}
},

Expand All @@ -239,12 +248,19 @@ export const init = async (
() => `[WebGPU] jsepCopyGpuToCpu: gpuDataId=${gpuDataId}, dataOffset=${dataOffset}, size=${size}`,
);

await backend.download(gpuDataId, () => module.HEAPU8.subarray(dataOffset >>> 0, (dataOffset >>> 0) + size));
await backend.download(Number(gpuDataId), () =>
module.HEAPU8.subarray(Number(dataOffset) >>> 0, Number(dataOffset + size) >>> 0),
);
},

// jsepCreateKernel
(kernelType: string, kernelId: number, attribute: unknown) =>
backend.createKernel(kernelType, kernelId, attribute, module.UTF8ToString(module._JsepGetNodeName!(kernelId))),
backend.createKernel(
kernelType,
Number(kernelId),
attribute,
module.UTF8ToString(module._JsepGetNodeName!(Number(kernelId))),
),

// jsepReleaseKernel
(kernel: number) => backend.releaseKernel(kernel),
Expand All @@ -256,8 +272,8 @@ export const init = async (
() =>
`[WebGPU] jsepRun: sessionHandle=${sessionHandle}, kernel=${kernel}, contextDataOffset=${contextDataOffset}`,
);
const context = new ComputeContextImpl(module, backend, contextDataOffset);
return backend.computeKernel(kernel, context, errors);
const context = new ComputeContextImpl(module, backend, Number(contextDataOffset));
return backend.computeKernel(Number(kernel), context, errors);
},
// jsepCaptureBegin
() => backend.captureBegin(),
Expand Down
2 changes: 1 addition & 1 deletion js/web/lib/wasm/jsep/util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ export class ShapeUtil {
'cannot get valid size from specified dimension range. Most likely the range contains negative values in them.',
);
}
size *= dims[i];
size *= Number(dims[i]);
}
return size;
}
Expand Down
Loading

0 comments on commit 991e8fb

Please sign in to comment.