[ORT 1.17.0 Release] Cherry pick 1st round #19243

Merged
23 commits merged on Jan 27, 2024
Changes from 18 commits
Commits (23)
357d385  more inputs support for LLM exporter (#19005)  (wejoncy, Jan 17, 2024)
c1760bf  Fix untyped float values in quantization tool missing from PR #18043 …  (xadupre, Jan 17, 2024)
de93132  [js/web] allow proxy to load model with 1GB <= size < 2GB (#19178)  (fs-eire, Jan 17, 2024)
89dae2d  [js/web] show warning when numThreads is set but threads is not suppo…  (fs-eire, Jan 17, 2024)
eb77b5b  Check the ep_cache_context and don't allow access outside the directo…  (HectorSVC, Jan 18, 2024)
3d06a91  Update x64 template kernel library for 'sqnbitgemm' (#19016)  (luoyu-intel, Jan 18, 2024)
2b86515  [js/web] upgrade dependency packages version (#19193)  (fs-eire, Jan 18, 2024)
5eebd09  Update LLaMA attention fusions (#19200)  (kunal-vaishnavi, Jan 19, 2024)
3c2065c  Fix issue that the generated context cache model inputs/outputs order…  (HectorSVC, Jan 19, 2024)
9b30485  [TensorRT EP] Enhance EP context configs in session options and provi…  (chilo-ms, Jan 21, 2024)
f9ad177  phi2 contrib ops changes (#19112)  (wangyems, Jan 22, 2024)
6aa7f79  [QNN EP] Expose device-level session options (#19212)  (adrianlizarraga, Jan 22, 2024)
041dfd5  [aarch64] Add Sbgemm kernel to accelerate fp32 tensor matmul with bfl…  (snadampal, Jan 22, 2024)
95a746f  [QNN EP] Create Windows ARM64 nightly python package (#19128)  (adrianlizarraga, Jan 23, 2024)
c1def2e  unet fusion for stable diffusion webui (#19227)  (tianleiwu, Jan 23, 2024)
3c6adf3  Modified the condition to load the optimiser model (#18891)  (heflinstephenraj, Jan 23, 2024)
d968811  Fix Fuzz Testing CI (#19228)  (yf711, Jan 22, 2024)
8bdea53  Fix AMD pipeline test failures (#19250)  (wangyems, Jan 24, 2024)
0714f0c  remove old quantization tool file (#19247)  (yufenglee, Jan 24, 2024)
aaa577d  [TensorRT EP] Avoid calling unavailable function with cpu python pack…  (chilo-ms, Jan 24, 2024)
7d941f4  [TensorRT EP] Fix mem leak for TRT plugins custom ops (#19248)  (chilo-ms, Jan 25, 2024)
5bb85b6  Remove USE_CUTLASS flag (#19271)  (tianleiwu, Jan 26, 2024)
099cefb  Update abseil to a release tag and register neural_speed (#19255)  (snnn, Jan 24, 2024)
18 changes: 9 additions & 9 deletions cmake/CMakeLists.txt
@@ -87,7 +87,7 @@ option(onnxruntime_USE_QNN "Build with QNN support" OFF)
option(onnxruntime_USE_SNPE "Build with SNPE support" OFF)
option(onnxruntime_USE_RKNPU "Build with RKNPU support" OFF)
option(onnxruntime_USE_DNNL "Build with DNNL support" OFF)
option(onnxruntime_USE_JBLAS "Build MLAS with JBLAS support" ON)
option(onnxruntime_USE_NEURAL_SPEED "Build with Neural Speed support" ON)
option(onnxruntime_USE_JSEP "Build with JavaScript implemented kernels support" OFF)
option(onnxruntime_BUILD_UNIT_TESTS "Build ONNXRuntime unit tests" ON)
option(onnxruntime_BUILD_CSHARP "Build C# library" OFF)
@@ -909,6 +909,10 @@ function(onnxruntime_set_compile_flags target_name)
target_compile_definitions(${target_name} PRIVATE USE_CUTLASS)
endif()

if(USE_NEURAL_SPEED)
target_compile_definitions(${target_name} PRIVATE ORT_NEURAL_SPEED)
endif()

set_target_properties(${target_name} PROPERTIES COMPILE_WARNING_AS_ERROR ON)
if (onnxruntime_USE_CUDA)
# Suppress a "conversion_function_not_usable" warning in gsl/span
@@ -1193,14 +1197,10 @@ if (onnxruntime_USE_DNNL)
add_compile_definitions(DNNL_OPENMP)
endif()

set(USE_JBLAS FALSE)
if (onnxruntime_USE_JBLAS AND NOT onnxruntime_MINIMAL_BUILD)
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND onnxruntime_target_platform STREQUAL "x86_64")
add_compile_definitions(MLAS_JBLAS)
set(USE_JBLAS TRUE)
elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC" AND onnxruntime_target_platform STREQUAL "x64")
add_compile_definitions(MLAS_JBLAS)
set(USE_JBLAS TRUE)
if (onnxruntime_USE_NEURAL_SPEED AND NOT onnxruntime_MINIMAL_BUILD)
include(neural_speed)
if (USE_NEURAL_SPEED)
list(APPEND onnxruntime_EXTERNAL_LIBRARIES neural_speed::bestla)
endif()
endif()

2 changes: 1 addition & 1 deletion cmake/deps.txt
@@ -54,4 +54,4 @@ tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2
cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.1.0.zip;757f90a795034a89d4f48a79d1f009f7a04c8dee
utf8_range;https://github.com/protocolbuffers/utf8_range/archive/72c943dea2b9240cd09efde15191e144bc7c7d38.zip;9925739c9debc0efa2adcb194d371a35b6a03156
extensions;https://github.com/microsoft/onnxruntime-extensions/archive/94142d8391c9791ec71c38336436319a2d4ac7a0.zip;4365ac5140338b4cb75a39944a4be276e3829b3c
composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/5356c4a943a35e74d7cdc69486afcb8703b9a59a.zip;522382c2af437e09124287e5879ab64af5b2e299
composable_kernel;https://github.com/ROCmSoftwarePlatform/composable_kernel/archive/5356c4a943a35e74d7cdc69486afcb8703b9a59a.zip;522382c2af437e09124287e5879ab64af5b2e299
18 changes: 18 additions & 0 deletions cmake/external/neural_speed.cmake
@@ -0,0 +1,18 @@
if ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU" AND onnxruntime_target_platform STREQUAL "x86_64")
set(USE_NEURAL_SPEED TRUE)
elseif ("${CMAKE_CXX_COMPILER_ID}" STREQUAL "MSVC" AND onnxruntime_target_platform STREQUAL "x64")
set(USE_NEURAL_SPEED TRUE)
endif()

if(USE_NEURAL_SPEED)
FetchContent_Declare(
neural_speed
URL https://github.com/intel/neural-speed/archive/refs/tags/bestlav0.1.1.zip
URL_HASH SHA1=65b0f7a0d04f72f0d5a8d48af70f0366f2ab3939
)
set(BTLA_USE_OPENMP OFF)
FetchContent_MakeAvailable(neural_speed)
if(NOT neural_speed_POPULATED)
FetchContent_Populate(neural_speed)
endif()
endif()
17 changes: 4 additions & 13 deletions cmake/onnxruntime_mlas.cmake
@@ -57,15 +57,6 @@ endif()

set(ONNXRUNTIME_MLAS_LIBS onnxruntime_mlas)

function(add_jblas)
add_subdirectory(${MLAS_SRC_DIR}/x86_64/jblas jblas)
target_link_libraries(onnxruntime_mlas PRIVATE jblas::jblas)
target_sources(onnxruntime_mlas PRIVATE
${MLAS_SRC_DIR}/jblas_gemm.cpp
)
set_target_properties(${target_name} PROPERTIES COMPILE_WARNING_AS_ERROR OFF)
endfunction()

#TODO: set MASM flags properly
function(setup_mlas_source_for_windows)

@@ -364,19 +355,23 @@ else()
${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S
${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S
${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S
${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S
${MLAS_SRC_DIR}/activate_fp16.cpp
${MLAS_SRC_DIR}/dwconv.cpp
${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/pooling_fp16.cpp
${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp
${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp
${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp
)
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
endif()

if(ONNXRUNTIME_MLAS_MULTI_ARCH)
@@ -622,10 +617,6 @@ else()
target_sources(onnxruntime_mlas PRIVATE ${mlas_platform_srcs})
endif()

if(USE_JBLAS)
add_jblas()
endif()

foreach(mlas_target ${ONNXRUNTIME_MLAS_LIBS})
target_include_directories(${mlas_target} PRIVATE ${MLAS_INC_DIR} ${MLAS_SRC_DIR})
onnxruntime_add_include_to_target(${mlas_target} ${GSL_TARGET})
15 changes: 15 additions & 0 deletions cmake/onnxruntime_providers_cpu.cmake
@@ -60,6 +60,15 @@ if(NOT onnxruntime_DISABLE_CONTRIB_OPS)
"${ONNXRUNTIME_ROOT}/contrib_ops/cpu/aten_ops/aten_op_executor.cc"
)
endif()
set(onnxruntime_cpu_neural_speed_srcs
"${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_wrapper.h"
"${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_defs.h"
"${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_gemm.cc"
"${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_gemm.h"
)
if(NOT USE_NEURAL_SPEED)
list(REMOVE_ITEM onnxruntime_cpu_contrib_ops_srcs ${onnxruntime_cpu_neural_speed_srcs})
endif()
# add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio
source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_cpu_contrib_ops_srcs})
list(APPEND onnxruntime_providers_src ${onnxruntime_cpu_contrib_ops_srcs})
@@ -144,6 +153,12 @@ if (HAS_BITWISE_INSTEAD_OF_LOGICAL)
target_compile_options(onnxruntime_providers PRIVATE "-Wno-bitwise-instead-of-logical")
endif()

if(NOT onnxruntime_DISABLE_CONTRIB_OPS)
if(USE_NEURAL_SPEED)
onnxruntime_add_include_to_target(onnxruntime_providers neural_speed::bestla)
endif()
endif()

if (MSVC)
target_compile_options(onnxruntime_providers PRIVATE "/bigobj")
# if(NOT CMAKE_SIZEOF_VOID_P EQUAL 8)
12 changes: 9 additions & 3 deletions docs/ContribOperators.md
@@ -3031,6 +3031,8 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Number of attention heads</dd>
<dt><tt>scale</tt> : float</dt>
<dd>Custom scale will be used if specified. Default value is 1/sqrt(head_size)</dd>
<dt><tt>unidirectional</tt> : int</dt>
<dd>Whether every token can only attend to previous tokens. Default value is 0.</dd>
</dl>

#### Inputs (1 - 8)
@@ -5021,6 +5023,10 @@ This version of the operator has been available since version 1 of the 'com.micr
<dl>
<dt><tt>interleaved</tt> : int</dt>
<dd>Rotate using interleaved pattern. Default value is 0 (False).</dd>
<dt><tt>num_heads</tt> : int</dt>
<dd>Number of attention heads. Default value is 0. Must use with rotary_embedding_dim</dd>
<dt><tt>rotary_embedding_dim</tt> : int</dt>
<dd>Rotary embedding dimension. Default value is 0.</dd>
<dt><tt>scale</tt> : float</dt>
<dd>Custom scale will be used if specified. Default value is 1.0</dd>
</dl>
@@ -5033,9 +5039,9 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>position_ids</tt> : M</dt>
<dd>1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)</dd>
<dt><tt>cos_cache</tt> : T</dt>
<dd>2D tensor with shape (max_sequence_length, head_size / 2).</dd>
<dd>2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)</dd>
<dt><tt>sin_cache</tt> : T</dt>
<dd>2D tensor with shape (max_sequence_length, head_size / 2).</dd>
<dd>2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)</dd>
</dl>

#### Outputs
@@ -5048,7 +5054,7 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float), tensor(float16)</dt>
<dt><tt>T</tt> : tensor(float), tensor(float16), tensor(bfloat16)</dt>
<dd>Constrain input and output types to float tensors.</dd>
<dt><tt>M</tt> : tensor(int64)</dt>
<dd>Constrain input and output types to integer tensors</dd>
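The new `rotary_embedding_dim` attribute above controls which slice of each head gets rotated. As a rough illustration, here is a minimal numpy sketch of the non-interleaved rotation, assuming precomputed caches of shape (max_sequence_length, rotary_embedding_dim / 2); the function and argument names are illustrative, not the ORT kernel's API.

```python
import numpy as np

def rotary_embedding(x, cos_cache, sin_cache, position_ids, rotary_dim=None):
    # x:             (batch, seq_len, num_heads, head_size)
    # cos/sin_cache: (max_seq_len, rotary_dim // 2)
    # position_ids:  (batch, seq_len), int64
    d = rotary_dim or x.shape[-1]                  # rotary_embedding_dim == 0 falls back to head_size
    cos = cos_cache[position_ids][:, :, None, :]   # -> (batch, seq_len, 1, d // 2)
    sin = sin_cache[position_ids][:, :, None, :]
    x_rot, x_pass = x[..., :d], x[..., d:]         # only the first d dims are rotated
    x1, x2 = x_rot[..., : d // 2], x_rot[..., d // 2:]
    out = np.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
    return np.concatenate([out, x_pass], axis=-1)
```

With `rotary_embedding_dim` left at 0, the whole head is rotated, matching the previously documented behavior.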
2 changes: 1 addition & 1 deletion docs/OperatorKernels.md
@@ -868,7 +868,7 @@ Do not modify directly.*
|RemovePadding|*in* input:**T**<br> *in* sequence_token_count:**M**<br> *out* output:**T**<br> *out* token_offset:**M**<br> *out* cumulated_seq_len:**M**<br> *out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)|
|RestorePadding|*in* input:**T**<br> *in* token_offset:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|Rfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|RotaryEmbedding|*in* input:**T**<br> *in* position_ids:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**|1+|**M** = tensor(int64)<br/> **T** = tensor(float), tensor(float16)|
|RotaryEmbedding|*in* input:**T**<br> *in* position_ids:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**|1+|**M** = tensor(int64)<br/> **T** = tensor(bfloat16), tensor(float), tensor(float16)|
|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *in* seed:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipGroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *in* skip:**T**<br> *in* bias:**T**<br> *out* Y:**T**<br> *out* S:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
9 changes: 9 additions & 0 deletions include/onnxruntime/core/framework/execution_provider.h
@@ -326,6 +326,15 @@ class IExecutionProvider {
*/
virtual std::vector<AllocatorPtr> CreatePreferredAllocators() { return std::vector<AllocatorPtr>(); };

/**
* Get the array of pointers to the EPContext nodes.
* An EP needs to implement this if it needs to generate the EP context cache model. Otherwise leave it as-is.
* Returns an empty vector by default if not provided by the execution provider.
*/
virtual const InlinedVector<const Node*> GetEpContextNodes() const {
return InlinedVector<const Node*>();
}

private:
const std::string type_;

include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
@@ -11,6 +11,8 @@
/// User can only get the instance of OrtTensorRTProviderOptionsV2 via CreateTensorRTProviderOptions.
/// </summary>
struct OrtTensorRTProviderOptionsV2 {
OrtTensorRTProviderOptionsV2& operator=(const OrtTensorRTProviderOptionsV2& other); // copy assignment operator

int device_id{0}; // cuda device id.
int has_user_compute_stream{0}; // indicator of user specified CUDA compute stream.
void* user_compute_stream{nullptr}; // user specified CUDA compute stream.
@@ -46,8 +48,26 @@
const char* trt_profile_max_shapes{nullptr}; // Specify the range of the input shapes to build the engine with
const char* trt_profile_opt_shapes{nullptr}; // Specify the range of the input shapes to build the engine with
int trt_cuda_graph_enable{0}; // Enable CUDA graph in ORT TRT
int trt_dump_ep_context_model{0}; // Dump EP context node model
int trt_ep_context_embed_mode{0}; // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data
int trt_ep_context_compute_capability_enable{1}; // Add GPU compute capability as an EP context node's attribute
const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix

/*
* Please note that there are rules for using the following context-model-related provider options:
*
* 1. In the case of dumping the context model and loading the context model,
*    for security reasons, the TRT EP doesn't allow the "ep_cache_context" node attribute of the EP context node
*    to be an absolute path, or a relative path that points outside of the context model directory.
*    This means the engine cache needs to be in the same directory as, or a sub-directory of, the context model.
*
* 2. In the case of dumping the context model, the engine cache path will be changed to a path relative to the context model directory.
*    For example:
*    If "trt_dump_ep_context_model" and "trt_engine_cache_enable" are both enabled,
*    and "trt_ep_context_file_path" is "./context_model_dir",
*    - if "trt_engine_cache_path" is "" -> the engine cache will be saved to "./context_model_dir"
*    - if "trt_engine_cache_path" is "engine_dir" -> the engine cache will be saved to "./context_model_dir/engine_dir"
*/
int trt_dump_ep_context_model{0}; // Dump EP context node model
const char* trt_ep_context_file_path{nullptr}; // Specify file name to dump EP context node model. Can be a path or a file name or a file name with path.

int trt_ep_context_embed_mode{0}; // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data


const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix
};
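To illustrate the path rules spelled out in the comment above, a session using the new EP context options from Python might look like the following sketch; the model path and directory names are placeholders, and it assumes the Python bindings pass these keys through as TRT provider options.

```python
import onnxruntime as ort

# Dump an EP context model; the engine cache must stay inside the
# context model directory per the rules documented above.
trt_options = {
    "trt_engine_cache_enable": "1",
    "trt_dump_ep_context_model": "1",
    "trt_ep_context_file_path": "./context_model_dir",  # context model directory
    "trt_engine_cache_path": "engine_dir",  # resolved to ./context_model_dir/engine_dir
}
session = ort.InferenceSession(
    "model.onnx",  # placeholder model path
    providers=[("TensorrtExecutionProvider", trt_options)],
)
```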
8 changes: 8 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -3608,6 +3608,14 @@ struct OrtApi {
* - "1": Faster preparation time, less optimal graph.
* - "2": Longer preparation time, more optimal graph.
* - "3": Longest preparation time, most likely even more optimal graph. See QNN SDK documentation for specific details.
* "soc_model": The SoC model number. Refer to the QNN SDK documentation for valid values. Defaults to "0" (unknown).
* "htp_arch": The minimum HTP architecture the driver will use to select compatible QNN operators. Available options:
* - "0": Default (none).
* - "68"
* - "69"
* - "73"
* - "75"
* "device_id": The ID of the device to use when setting 'htp_arch'. Defaults to "0" (for single device).
*
* SNPE supported keys:
* "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16",
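As a hedged sketch of how the new QNN keys documented above could be passed from Python (the backend library and model path are placeholders, and it assumes the QNN EP is available in the installed package):

```python
import onnxruntime as ort

qnn_options = {
    "backend_path": "QnnHtp.dll",  # placeholder; the required QNN backend library
    "soc_model": "0",              # unknown SoC model (default)
    "htp_arch": "73",              # minimum HTP architecture to target
    "device_id": "0",              # device to use when setting htp_arch
}
session = ort.InferenceSession(
    "model.onnx",  # placeholder model path
    providers=[("QNNExecutionProvider", qnn_options)],
)
```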
include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
@@ -236,7 +236,7 @@ static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersFil
static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes =
"session.optimized_model_external_initializers_min_size_in_bytes";

// Enable EP context feature to dump the partitioned graph which include the EP context into Onnx file.
// Enable EP context feature to dump the partitioned graph which includes the EP context into Onnx file.
// The dumped Onnx model with EP context can be used for future inference to avoid the EP graph partitioning/compile overhead.
// "0": disable. (default)
// "1": enable.
@@ -249,4 +249,10 @@ static const char* const kOrtSessionOptionEpContextFilePath = "ep.context_file_p
// Flag to specify whether to dump the EP context into the Onnx model.
// "0": dump the EP context into separate file, keep the file name in the Onnx model.
// "1": dump the EP context into the Onnx model. (default).
static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode";
static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode";

// Gemm fastmath mode provides fp32 gemm acceleration with bfloat16 based matmul.
// Option values:
// - "0": Gemm FastMath mode is not enabled. [DEFAULT]
// - "1": Gemm FastMath mode is enabled.
static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16";
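Session configuration keys like the new fastmath flag are set per session via config entries; a minimal Python sketch (the model path is a placeholder):

```python
import onnxruntime as ort

so = ort.SessionOptions()
# Opt in to bfloat16-based fp32 GEMM acceleration on ARM64 (off by default).
so.add_session_config_entry("mlas.enable_gemm_fastmath_arm64_bfloat16", "1")
session = ort.InferenceSession("model.onnx", sess_options=so)  # placeholder model
```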
2 changes: 1 addition & 1 deletion java/src/main/native/ai_onnxruntime_OrtTrainingSession.c
@@ -66,7 +66,7 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtTrainingSession_createTrainingSes
}
}
wchar_t* optimizerStr = NULL;
if (optimizerPath == NULL) {
if (optimizerPath != NULL) {
optimizerStr = copyAndPad(jniEnv, optimizerPath);
if (optimizerStr == NULL) {
// exception has been thrown in Java, go to cleanup and return null.
6 changes: 6 additions & 0 deletions js/web/lib/backend-wasm.ts
@@ -31,6 +31,12 @@ export const initializeFlags = (): void => {
}

if (typeof env.wasm.numThreads !== 'number' || !Number.isInteger(env.wasm.numThreads) || env.wasm.numThreads <= 0) {
// Web: when crossOriginIsolated is false, SharedArrayBuffer is not available so WebAssembly threads will not work.
// Node.js: onnxruntime-web does not support multi-threads in Node.js.
if ((typeof self !== 'undefined' && !self.crossOriginIsolated) ||
(typeof process !== 'undefined' && process.versions && process.versions.node)) {
env.wasm.numThreads = 1;
}
const numCpuLogicalCores = typeof navigator === 'undefined' ? cpus().length : navigator.hardwareConcurrency;
env.wasm.numThreads = Math.min(4, Math.ceil((numCpuLogicalCores || 1) / 2));
}
33 changes: 27 additions & 6 deletions js/web/lib/wasm/wasm-factory.ts
@@ -28,13 +28,34 @@ let initialized = false;
let initializing = false;
let aborted = false;

const isMultiThreadSupported = (): boolean => {
try {
// If 'SharedArrayBuffer' is not available, WebAssembly threads will not work.
if (typeof SharedArrayBuffer === 'undefined') {
return false;
const isMultiThreadSupported = (numThreads: number): boolean => {
// WebAssembly threads are set to 1 (single thread).
if (numThreads === 1) {
return false;
}

// If 'SharedArrayBuffer' is not available, WebAssembly threads will not work.
if (typeof SharedArrayBuffer === 'undefined') {
if (typeof self !== 'undefined' && !self.crossOriginIsolated) {
// eslint-disable-next-line no-console
console.warn(
'env.wasm.numThreads is set to ' + numThreads +
', but this will not work unless you enable crossOriginIsolated mode. ' +
'See https://web.dev/cross-origin-isolation-guide/ for more info.');
}
return false;
}

// onnxruntime-web does not support multi-threads in Node.js.
if (typeof process !== 'undefined' && process.versions && process.versions.node) {
// eslint-disable-next-line no-console
console.warn(
'env.wasm.numThreads is set to ' + numThreads +
', however, currently onnxruntime-web does not support multi-threads in Node.js. ' +
'Please consider using onnxruntime-node for performance critical scenarios.');
}

try {
// Test for transferability of SABs (for browsers. needed for Firefox)
// https://groups.google.com/forum/#!msg/mozilla.dev.platform/IHkBZlHETpA/dwsMNchWEQAJ
if (typeof MessageChannel !== 'undefined') {
@@ -106,7 +127,7 @@ export const initializeWebAssembly = async(flags: Env.WebAssemblyFlags): Promise
const numThreads = flags.numThreads!;
const simd = flags.simd!;

const useThreads = numThreads > 1 && isMultiThreadSupported();
const useThreads = isMultiThreadSupported(numThreads);
const useSimd = simd && isSimdSupported();

const wasmPaths = flags.wasmPaths;