diff --git a/.lintrunner.toml b/.lintrunner.toml
index c44a66200ad1b..4e5d077b08ff4 100644
--- a/.lintrunner.toml
+++ b/.lintrunner.toml
@@ -45,6 +45,7 @@ exclude_patterns = [
'cmake/external/**',
# ignore generated flatbuffers code
'onnxruntime/core/flatbuffers/ort_flatbuffers_py/**',
+ 'orttraining/orttraining/python/training/optim/_ds_code_store.py',
]
command = [
'python',
@@ -76,6 +77,7 @@ exclude_patterns = [
'cmake/**',
'orttraining/*',
'onnxruntime/core/flatbuffers/**',
+ 'orttraining/orttraining/python/training/optim/_ds_code_store.py',
]
command = [
'python',
diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json
index 08ca90d7c3b7f..f9f2fbdab7b10 100644
--- a/cgmanifests/generated/cgmanifest.json
+++ b/cgmanifests/generated/cgmanifest.json
@@ -2,6 +2,36 @@
"$schema": "https://json.schemastore.org/component-detection-manifest.json",
"Version": 1,
"Registrations": [
+ {
+ "component": {
+ "type": "git",
+ "git": {
+ "commitHash": "a896e3d066448b3530dbcaa48869fafefd738f57",
+ "repositoryUrl": "https://github.com/emscripten-core/emsdk.git"
+ },
+ "comments": "git submodule at cmake/external/emsdk"
+ }
+ },
+ {
+ "component": {
+ "type": "git",
+ "git": {
+ "commitHash": "7a2ed51a6b682a83e345ff49fc4cfd7ca47550db",
+ "repositoryUrl": "https://github.com/google/libprotobuf-mutator.git"
+ },
+ "comments": "git submodule at cmake/external/libprotobuf-mutator"
+ }
+ },
+ {
+ "component": {
+ "type": "git",
+ "git": {
+ "commitHash": "0c296085f9f65f0f8ef7aec7b9eed55faf37dc40",
+ "repositoryUrl": "https://github.com/onnx/onnx.git"
+ },
+ "comments": "git submodule at cmake/external/onnx"
+ }
+ },
{
"component": {
"type": "git",
@@ -166,7 +196,7 @@
"component": {
"type": "git",
"git": {
- "commitHash": "fdefbe85ed9c362b95b9b401cd19db068a76141f",
+ "commitHash": "6a20ba82b439ea1fd650da4d389e96b60a1dd828",
"repositoryUrl": "https://github.com/onnx/onnx.git"
},
"comments": "onnx"
diff --git a/cmake/deps.txt b/cmake/deps.txt
index 7cf49f02333a4..26fd35075c4b9 100644
--- a/cmake/deps.txt
+++ b/cmake/deps.txt
@@ -24,7 +24,7 @@ microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf36
microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5
mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee7d34223d0567892db5179849939c8769dc41
mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063
-onnx;https://github.com/onnx/onnx/archive/14303de049144035dfd94ace5f7a3b44773b1aad.zip;250eab9690392b248d75b56e605fb49eca373442
+onnx;https://github.com/onnx/onnx/archive/6a20ba82b439ea1fd650da4d389e96b60a1dd828.zip;179a22ad4cd67109c60031ae4b6cf2f434d8bd7e
#use the commit of supporting all the plugins and TRT 8.6-GA (https://github.com/onnx/onnx-tensorrt/commit/0462dc31ae78f48744b6141ae376df1f96d3f459)
onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/0462dc31ae78f48744b6141ae376df1f96d3f459.zip;5ff086361956cceb81ed17453a1fd8db2aa4328d
protobuf;https://github.com/protocolbuffers/protobuf/archive/refs/tags/v21.12.zip;7cf2733949036c7d52fda017badcab093fe73bfa
diff --git a/cmake/external/onnx b/cmake/external/onnx
index e2525550194ce..6a20ba82b439e 160000
--- a/cmake/external/onnx
+++ b/cmake/external/onnx
@@ -1 +1 @@
-Subproject commit e2525550194ce3d8a2c4a3af451c9d9b3ae6650e
+Subproject commit 6a20ba82b439ea1fd650da4d389e96b60a1dd828
diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake
index 9985f5c8bc516..a76664adff207 100644
--- a/cmake/onnxruntime_mlas.cmake
+++ b/cmake/onnxruntime_mlas.cmake
@@ -344,12 +344,18 @@ else()
set(mlas_platform_srcs
${mlas_platform_srcs}
${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S
+ ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S
+ ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S
${MLAS_SRC_DIR}/activate_fp16.cpp
${MLAS_SRC_DIR}/dwconv.cpp
${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/pooling_fp16.cpp
+ ${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp
+ ${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp
)
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
+ set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
+ set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake
index fe3e577b4fc36..de1458c120016 100644
--- a/cmake/onnxruntime_rocm_hipify.cmake
+++ b/cmake/onnxruntime_rocm_hipify.cmake
@@ -54,6 +54,11 @@ set(contrib_ops_excluded_files
"quantization/attention_quantization_impl.cuh"
"quantization/dequantize_blockwise.cuh"
"quantization/dequantize_blockwise.cu"
+ "quantization/dequantize_blockwise_bnb4.cuh"
+ "quantization/dequantize_blockwise_bnb4.cu"
+ "quantization/matmul_bnb4.cc"
+ "quantization/matmul_bnb4.cuh"
+ "quantization/matmul_bnb4.cu"
"quantization/matmul_nbits.cc"
"quantization/matmul_nbits.cuh"
"quantization/matmul_nbits.cu"
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 7e67ec6d0c94e..1a76c18a6a8e0 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -47,6 +47,7 @@ Do not modify directly.*
* com.microsoft.Inverse
* com.microsoft.Irfft
* com.microsoft.LongformerAttention
+ * com.microsoft.MatMulBnb4
* com.microsoft.MatMulFpQ4
* com.microsoft.MatMulInteger16
* com.microsoft.MatMulIntegerToFloat
@@ -90,6 +91,7 @@ Do not modify directly.*
* com.microsoft.RemovePadding
* com.microsoft.RestorePadding
* com.microsoft.Rfft
+ * com.microsoft.RotaryEmbedding
* com.microsoft.SampleOp
* com.microsoft.Sampling
* com.microsoft.SkipLayerNormalization
@@ -2503,6 +2505,62 @@ This version of the operator has been available since version 1 of the 'com.micr
+### **com.microsoft.MatMulBnb4**
+
+ MatMulBnb4 is a MatMul with its weight quantized to 4 bits using either the FP4 or NF4 data type (https://arxiv.org/pdf/2305.14314.pdf). It performs matrix multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul), with the following differences:
+ 1. Input B is a 2D constant matrix. Its input feature count and output feature count are specified by attributes 'K' and 'N'.
+ 2. Input B is quantized to 4 bits with the quantization data type specified by attribute 'quant_type'. It is transposed, flattened and quantized block-wise with the block size specified by attribute 'block_size'.
+    block_size is not an arbitrary number; it must be a power of 2 and not smaller than 16 (e.g. 16, 32, 64, 128, ...).
+ 3. Input B's quantization constants (scales) are specified by input 'absmax'.
+
+ Input B is stored as uint8_t with shape [(N * K + 1) / 2].
+ Input absmax is stored in the same type as the original type of B (float32 or float16) with shape [(N * K + block_size - 1) / block_size].
+
+#### Version
+
+This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
+
+#### Attributes
+
+
+- K : int (required)
+- size of each input feature
+- N : int (required)
+- size of each output feature
+- block_size : int (required)
+- group size used for weight quantization. It needs to be a power of 2 and not smaller than 16.
+- quant_type : int (required)
+- quantization data type. 0 for FP4, 1 for NF4.
+
+
+#### Inputs
+
+
+- A : T1
+- The input tensor, not quantized
+- B : T2
+- 1-dimensional quantized data for weight
+- absmax : T1
+- quantization constants
+
+
+#### Outputs
+
+
+- Y : T1
+- tensor. The output tensor has the same rank as the input.
+
+
+#### Type Constraints
+
+
+- T1 : tensor(float), tensor(float16)
+- Constrain input and output types to float/half_float tensors.
+- T2 : tensor(uint8)
+- Constrain quantized weight types to uint8.
+
+
+
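As an editorial aside to the MatMulBnb4 description above, a minimal Python sketch (not part of this change; the function name and example shapes are illustrative) of the buffer-size arithmetic it specifies for the packed weight and the absmax scales:

```python
def matmul_bnb4_buffer_sizes(N: int, K: int, block_size: int):
    # block_size must be a power of 2 and not smaller than 16, per the description above.
    assert block_size >= 16 and (block_size & (block_size - 1)) == 0
    num_elements = N * K                          # B is transposed and flattened before packing
    packed_b_size = (num_elements + 1) // 2       # two 4-bit values per uint8 byte
    absmax_size = (num_elements + block_size - 1) // block_size  # one scale per block
    return packed_b_size, absmax_size

# e.g. a 4096x4096 weight with block_size 64 -> (8388608, 262144)
print(matmul_bnb4_buffer_sizes(4096, 4096, 64))
```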
### **com.microsoft.MatMulFpQ4**
Matrix product with right hand matrix being pre-packed and quantized int4 data blob.
@@ -2834,7 +2892,7 @@ This version of the operator has been available since version 1 of the 'com.micr
bias (optional) : T
Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection
key_padding_mask (optional) : M
-Key padding mask with shape (batch_size) or (3 * batch_size + 2) or (batch_size, kv_sequence_length)
+Key padding mask with shape (batch_size), (3 * batch_size + 2), (batch_size, kv_sequence_length), (batch_size, total_sequence_length), or (batch_size, sequence_length, total_sequence_length)
relative_position_bias (optional) : T
relative position bias: addition to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length) or (1, num_heads, sequence_length, total_sequence_length)
past_key (optional) : T
@@ -4796,6 +4854,54 @@ This version of the operator has been available since version 1 of the 'com.micr
+### **com.microsoft.RotaryEmbedding**
+
+ RotaryEmbedding is the implementation of rotary positional embeddings (RoPE). The positions are represented as rotation matrices
+ that are applied to the query and key before the inner product of the query and key is taken.
+
+#### Version
+
+This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
+
+#### Attributes
+
+
+- interleaved : int
+- Rotate using interleaved pattern. Default value is 0 (False).
+- scale : float
+- A custom scale will be used if specified. Default value is 1.0.
+
+
+#### Inputs
+
+
+- input : T
+- 3D tensor with shape (batch_size, sequence_length, hidden_size)
+- position_ids : M
+- 1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)
+- cos_cache : T
+- 2D tensor with shape (max_sequence_length, head_size / 2).
+- sin_cache : T
+- 2D tensor with shape (max_sequence_length, head_size / 2).
+
+
+#### Outputs
+
+
+- output : T
+- 3D tensor with shape (batch_size, sequence_length, hidden_size)
+
+
+#### Type Constraints
+
+
+- T : tensor(float), tensor(float16)
+- Constrain input and output types to float tensors.
+- M : tensor(int64)
+- Constrain input and output types to integer tensors
+
+
+
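A minimal numpy sketch of the rotation described above, assuming the standard non-interleaved RoPE formulation, 2D position_ids, and that the (batch_size, sequence_length, hidden_size) input has already been reshaped into heads; it illustrates the math only and is not the kernel implementation:

```python
import numpy as np

def rotary_embedding_ref(x, position_ids, cos_cache, sin_cache):
    # x: (batch, seq_len, num_heads, head_size); position_ids: (batch, seq_len)
    # cos_cache / sin_cache: (max_sequence_length, head_size // 2)
    cos = cos_cache[position_ids][:, :, None, :]  # broadcast over heads
    sin = sin_cache[position_ids][:, :, None, :]
    x1, x2 = np.split(x, 2, axis=-1)              # rotate first half against second half
    return np.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
```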
### **com.microsoft.SampleOp**
Sample echo operator.
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index e2d500006b05f..84249df92231b 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -25,6 +25,7 @@ Do not modify directly.*
|||13|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|||[7, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)|
|Affine|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)|
+|AffineGrid|*in* theta:**T1**
*in* size:**T2**
*out* grid:**T1**|20+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int64)|
|And|*in* A:**T**
*in* B:**T**
*out* C:**T1**|7+|**T** = tensor(bool)
**T1** = tensor(bool)|
|ArgMax|*in* data:**T**
*out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int8), tensor(uint8)|
|||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int8), tensor(uint8)|
@@ -156,8 +157,10 @@ Do not modify directly.*
|||[1, 10]|**B** = tensor(bool)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|ImageScaler|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float)|
|InstanceNormalization|*in* input:**T**
*in* scale:**T**
*in* B:**T**
*out* output:**T**|6+|**T** = tensor(float)|
-|IsInf|*in* X:**T1**
*out* Y:**T2**|10+|**T1** = tensor(double), tensor(float)
**T2** = tensor(bool)|
-|IsNaN|*in* X:**T1**
*out* Y:**T2**|13+|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)|
+|IsInf|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)|
+|||[10, 19]|**T1** = tensor(double), tensor(float)
**T2** = tensor(bool)|
+|IsNaN|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)|
+|||[13, 19]|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)|
|||[9, 12]|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)|
|LRN|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(float)|
|||[1, 12]|**T** = tensor(float)|
@@ -454,6 +457,7 @@ Do not modify directly.*
|GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float)|
|GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)|
|Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)|
|MatMulFpQ4|*in* A:**T1**
*in* B:**T2**
*in* B_shape:**T3**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(int64)|
|MatMulInteger16|*in* A:**T1**
*in* B:**T2**
*out* Y:**T3**|1+|**T1** = tensor(int16)
**T2** = tensor(int16)
**T3** = tensor(int32)|
|MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float)|
@@ -477,9 +481,11 @@ Do not modify directly.*
|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(int16), tensor(int8), tensor(uint16), tensor(uint8)|
|QuickGelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)|
|Range|*in* start:**T**
*in* limit:**T**
*in* delta:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
+|RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(float)|
|SampleOp|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)|
|Sampling|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*in* presence_mask:**I**
*in* seed:**I**
*out* sequences:**I**
*out* filtered_logits:**T**|1+|**T** = tensor(float)|
|SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)|
+|SkipSimplifiedLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)|
|SparseToDenseMatMul|*in* A:**T**
*in* B:**T1**
*out* Y:**T1**|1+|**T** = sparse_tensor(double), sparse_tensor(float), sparse_tensor(int32), sparse_tensor(int64), sparse_tensor(uint32), sparse_tensor(uint64)
**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|Tokenizer|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(string)|
|TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)|
@@ -847,6 +853,7 @@ Do not modify directly.*
|Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
+|MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)|
|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)|
|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)|
|NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)|
@@ -866,6 +873,7 @@ Do not modify directly.*
|RemovePadding|*in* input:**T**
*in* sequence_token_count:**M**
*out* output:**T**
*out* token_offset:**M**
*out* cumulated_seq_len:**M**
*out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)|
|RestorePadding|*in* input:**T**
*in* token_offset:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|Rfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(float), tensor(float16)|
|Sampling|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*in* presence_mask:**I**
*in* seed:**I**
*out* sequences:**I**
*out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipSimplifiedLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
diff --git a/include/onnxruntime/core/framework/float8.h b/include/onnxruntime/core/framework/float8.h
index 0fd04f28d44b7..dd607cbbc6952 100644
--- a/include/onnxruntime/core/framework/float8.h
+++ b/include/onnxruntime/core/framework/float8.h
@@ -208,9 +208,10 @@ struct Float8E4M3FNUZ {
val = static_cast((b & 0x80000000) >> 24); // sign
if ((b & 0x7fffffff) == 0x7f800000) { // infinity
if (saturate) {
+ // the highest available value
val |= 0x7F;
} else {
- // infinity
+ // NaN
val = 0x80;
}
} else if ((b & 0x7F800000) == 0x7F800000) { // NaN
@@ -362,8 +363,10 @@ struct Float8E5M2 {
val = (b & 0x80000000) >> 24; // sign
if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf
if (saturate) {
+ // the highest available value
val |= 0x7B;
} else {
+ // infinity
val |= 0x7C;
}
} else if ((b & 0x7F800000) == 0x7F800000) { // NaN
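To make the constants in the comments above concrete, a small Python sketch (editorial, not part of the change) that decodes the saturation bit patterns: 0x7F is the largest finite Float8E4M3FNUZ value (240.0) and 0x80 is that format's NaN encoding, while for Float8E5M2 0x7B is the largest finite value (57344.0) and 0x7C is infinity.

```python
def fp8_to_float(bits: int, exp_bits: int, man_bits: int, bias: int) -> float:
    # Generic decoder for finite fp8 patterns (does not handle inf/NaN encodings).
    sign = -1.0 if (bits >> (exp_bits + man_bits)) & 1 else 1.0
    exp = (bits >> man_bits) & ((1 << exp_bits) - 1)
    man = bits & ((1 << man_bits) - 1)
    if exp == 0:  # subnormal
        return sign * (man / (1 << man_bits)) * 2.0 ** (1 - bias)
    return sign * (1.0 + man / (1 << man_bits)) * 2.0 ** (exp - bias)

print(fp8_to_float(0x7F, exp_bits=4, man_bits=3, bias=8))   # 240.0   (E4M3FNUZ saturation)
print(fp8_to_float(0x7B, exp_bits=5, man_bits=2, bias=15))  # 57344.0 (E5M2 saturation)
```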
diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h
index f153e88909b8d..462d410e13769 100644
--- a/include/onnxruntime/core/graph/graph.h
+++ b/include/onnxruntime/core/graph/graph.h
@@ -3,6 +3,7 @@
#pragma once
+#include
#include
#include
#include
@@ -83,10 +84,10 @@ class Node {
gsl::span output_args,
const NodeAttributes* attributes,
std::string_view domain) {
- Init(std::string{name}, std::string{op_type}, std::string{description},
- std::vector{input_args.begin(), input_args.end()},
- std::vector{output_args.begin(), output_args.end()},
- attributes, std::string{domain});
+ Init(name, op_type, description,
+ input_args,
+ output_args,
+ attributes, domain);
}
#endif
@@ -563,13 +564,13 @@ class Node {
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Node);
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
- void Init(const std::string& name,
- const std::string& op_type,
- const std::string& description,
- const std::vector& input_args,
- const std::vector& output_args,
+ void Init(std::string_view name,
+ std::string_view op_type,
+ std::string_view description,
+ gsl::span input_args,
+ gsl::span output_args,
const NodeAttributes* attributes,
- const std::string& domain);
+ std::string_view domain);
#endif
#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
@@ -1141,8 +1142,22 @@ class Graph {
*/
Status InlineFunction(Node& node);
+ /**
+ Directly insert the nodes in the function proto provided into the graph.
+ The function converts Constant nodes into the initializers in the graph.
+ It then creates a node in the graph for each of the function nodes.
+ All of the names are expected to be specialized, and therefore unique.
+ See function_utils::Specialize().
+
+ The Graph needs to be Resolve()d after this call.
+ @param func_to_inline
+ @returns Status indicating success or providing an error message.
+ */
+
+ Status InlineFunctionProto(const ONNX_NAMESPACE::FunctionProto& func_to_inline);
+
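For readers less familiar with the inlining step documented here, the same idea expressed against onnx's Python API (an editorial sketch under the stated assumptions, not the C++ implementation being added): Constant nodes become graph initializers and every other function node is appended to the graph.

```python
from onnx import FunctionProto, GraphProto, helper

def inline_function_proto(graph: GraphProto, func: FunctionProto) -> None:
    # Assumes all names in `func` are already specialized/unique and that Constant
    # nodes use the 'value' attribute form; other Constant forms are omitted for brevity.
    for node in func.node:
        if node.op_type == "Constant" and node.domain in ("", "ai.onnx"):
            attr = next(a for a in node.attribute if a.name == "value")
            tensor = helper.get_attribute_value(attr)
            tensor.name = node.output[0]
            graph.initializer.append(tensor)
        else:
            graph.node.append(node)
    # As with the C++ API, the graph still has to be resolved/checked afterwards.
```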
/** Mark a NodeArg name as coming from the outer scope when programmatically constructing a Graph that will
- be used as a GraphProto attribute in another Node..
+ be used as a GraphProto attribute in another Node.
e.g. when creating a Graph instance that will be used as a subgraph in a control flow operator, it is necessary to
define placeholder NodeArgs for outer scope values. This prevents these values from becoming explicit graph inputs
when the Graph is resolved.
@@ -1391,6 +1406,13 @@ class Graph {
Node& AddNode(const ONNX_NAMESPACE::NodeProto& node_proto,
const ArgNameToTypeMap& name_to_type);
+ /** Helper that converts a constant node proto to an initializer and adds it to the graph.
+ @param constant_node_proto Constant node to convert.
+ @param new_name if provided, use this name for the initializer.
+ */
+ Status AddConstantProtoAsInitializer(const ONNX_NAMESPACE::NodeProto& constant_node_proto,
+ std::optional new_name);
+
#endif
Version IrVersion() const noexcept {
diff --git a/include/onnxruntime/core/providers/cuda/cuda_context.h b/include/onnxruntime/core/providers/cuda/cuda_context.h
index 13c176dad3cc5..646f33ed952a4 100644
--- a/include/onnxruntime/core/providers/cuda/cuda_context.h
+++ b/include/onnxruntime/core/providers/cuda/cuda_context.h
@@ -19,6 +19,7 @@ struct CudaContext : public CustomOpContext {
cudaStream_t cuda_stream = {};
cudnnHandle_t cudnn_handle = {};
cublasHandle_t cublas_handle = {};
+ OrtAllocator* deferred_cpu_allocator = {};
void Init(const OrtKernelContext& kernel_ctx) override {
const auto& ort_api = Ort::GetApi();
@@ -44,6 +45,36 @@ struct CudaContext : public CustomOpContext {
ORT_CXX_API_THROW("failed to fetch cublas handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
}
cublas_handle = reinterpret_cast(resource);
+
+ resource = {};
+ status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::deferred_cpu_allocator_t, &resource);
+ if (status) {
+ ORT_CXX_API_THROW("failed to fetch deferred cpu allocator", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
+ }
+ deferred_cpu_allocator = reinterpret_cast(resource);
+ }
+
+ void* AllocDeferredCpuMem(size_t size) const {
+ if (0 == size) {
+ return {};
+ }
+ const auto& ort_api = Ort::GetApi();
+ void* mem = {};
+ auto status = ort_api.AllocatorAlloc(deferred_cpu_allocator, size, &mem);
+ if (status) {
+ ORT_CXX_API_THROW("failed to allocate deferred cpu memory", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
+ }
+ return mem;
+ }
+
+ void FreeDeferredCpuMem(void* mem) const {
+ if (mem) {
+ const auto& ort_api = Ort::GetApi();
+ auto status = ort_api.AllocatorFree(deferred_cpu_allocator, mem);
+ if (status) {
+ ORT_CXX_API_THROW("failed to free deferred cpu memory", OrtErrorCode::ORT_RUNTIME_EXCEPTION);
+ }
+ }
}
};
diff --git a/include/onnxruntime/core/providers/cuda/cuda_resource.h b/include/onnxruntime/core/providers/cuda/cuda_resource.h
index e46fc5b4219dd..8c3ed46ade6a1 100644
--- a/include/onnxruntime/core/providers/cuda/cuda_resource.h
+++ b/include/onnxruntime/core/providers/cuda/cuda_resource.h
@@ -3,10 +3,11 @@
#include "core/providers/resource.h"
-#define ORT_CUDA_RESOUCE_VERSION 1
+#define ORT_CUDA_RESOUCE_VERSION 2
enum CudaResource : int {
cuda_stream_t = cuda_resource_offset,
cudnn_handle_t,
- cublas_handle_t
+ cublas_handle_t,
+ deferred_cpu_allocator_t,
};
\ No newline at end of file
diff --git a/include/onnxruntime/core/providers/dml/dml_provider_factory.h b/include/onnxruntime/core/providers/dml/dml_provider_factory.h
index dd4ffb835d51c..cf3ddc3f125f9 100644
--- a/include/onnxruntime/core/providers/dml/dml_provider_factory.h
+++ b/include/onnxruntime/core/providers/dml/dml_provider_factory.h
@@ -128,7 +128,7 @@ struct OrtDmlApi {
/**
* SessionOptionsAppendExecutionProvider_DML2
* Creates a DirectML Execution Provider given the supplied device options that contain a performance preference
- * (high power, low power, or defult) and a device filter (None, GPU, or NPU).
+ * (high power, low power, or default) and a device filter (None, GPU, or NPU).
*/
ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_DML2, _In_ OrtSessionOptions* options, OrtDmlDeviceOptions* device_opts);
};
diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
index 37545f41b43dd..831def24e4f5e 100644
--- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
+++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
@@ -67,6 +67,16 @@ static const char* const kOrtSessionOptionsEnableQuantQDQCleanup = "session.enab
// GeluApproximation has side effects which may change the inference results. It is disabled by default due to this.
static const char* const kOrtSessionOptionsEnableGeluApproximation = "optimization.enable_gelu_approximation";
+// This setting controls whether to enable AheadOfTime function inlining.
+// AOT function inlining examines the graph and attempts to inline as many locally defined functions in the model
+// as possible with the help of enabled execution providers.
+// This can reduce the number of function calls and improve performance because it is done before
+// Level1 optimizers and constant folding. However, under some circumstances, when the EPs are not available,
+// one can disable the AOT inlining, produce an optimized model and postpone AOT until run time.
+// "0": enable; "1": disable.
+// Its default value is "0".
+static const char* const kOrtSessionOptionsDisableAheadOfTimeFunctionInlining = "session.disable_aot_function_inlining";
+
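The new key is set like any other session config entry; a short Python sketch (the model path is a placeholder):

```python
import onnxruntime as ort

so = ort.SessionOptions()
# "0" (default) keeps AOT function inlining enabled; "1" disables it.
so.add_session_config_entry("session.disable_aot_function_inlining", "1")
sess = ort.InferenceSession("model.onnx", so)  # "model.onnx" is a placeholder
```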
#ifdef ENABLE_TRAINING
// Specifies a list of op types for memory footprint reduction.
// The value should be a ","-delimited list of pair of
diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts
index f06d06bda035f..47e67879e66ce 100644
--- a/js/common/lib/training-session-impl.ts
+++ b/js/common/lib/training-session-impl.ts
@@ -1,11 +1,14 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
+import {resolveBackend} from './backend-impl.js';
import {TrainingSessionHandler} from './backend.js';
import {InferenceSession as InferenceSession} from './inference-session.js';
import {TrainingSession as TrainingSessionInterface, TrainingSessionCreateOptions} from './training-session.js';
type SessionOptions = InferenceSession.SessionOptions;
+const noBackendErrMsg: string = 'Training backend could not be resolved. ' +
+ 'Make sure you\'re using the correct configuration & WebAssembly files.';
export class TrainingSession implements TrainingSessionInterface {
private constructor(handler: TrainingSessionHandler) {
@@ -20,9 +23,23 @@ export class TrainingSession implements TrainingSessionInterface {
return this.handler.outputNames;
}
- static async create(_trainingOptions: TrainingSessionCreateOptions, _sessionOptions?: SessionOptions):
+ static async create(trainingOptions: TrainingSessionCreateOptions, sessionOptions?: SessionOptions):
Promise {
- throw new Error('Method not implemented');
+ const evalModel: string|Uint8Array = trainingOptions.evalModel || '';
+ const optimizerModel: string|Uint8Array = trainingOptions.optimizerModel || '';
+ const options: SessionOptions = sessionOptions || {};
+
+ // get backend hints
+ const eps = options.executionProviders || [];
+ const backendHints = eps.map(i => typeof i === 'string' ? i : i.name);
+ const backend = await resolveBackend(backendHints);
+ if (backend.createTrainingSessionHandler) {
+ const handler = await backend.createTrainingSessionHandler(
+ trainingOptions.checkpointState, trainingOptions.trainModel, evalModel, optimizerModel, options);
+ return new TrainingSession(handler);
+ } else {
+ throw new Error(noBackendErrMsg);
+ }
}
async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise {
diff --git a/js/web/docs/webgl-operators.md b/js/web/docs/webgl-operators.md
index de84134ddbb3f..7c129b66bfa3d 100644
--- a/js/web/docs/webgl-operators.md
+++ b/js/web/docs/webgl-operators.md
@@ -12,6 +12,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat
| [Acos](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Acos) | [7+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Acos-7) |
| [Acosh](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Acosh) | |
| [Add](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Add) | [7-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Add-7), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Add-13), [14+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Add-14) |
+| [AffineGrid](https://github.com/onnx/onnx/blob/main/docs/Operators.md#AffineGrid) | |
| [And](https://github.com/onnx/onnx/blob/main/docs/Operators.md#And) | [7+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#And-7) |
| [ArgMax](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ArgMax) | |
| [ArgMin](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ArgMin) | |
@@ -67,6 +68,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat
| [Gather](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gather) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gather-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gather-11), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gather-13) |
| [GatherElements](https://github.com/onnx/onnx/blob/main/docs/Operators.md#GatherElements) | |
| [GatherND](https://github.com/onnx/onnx/blob/main/docs/Operators.md#GatherND) | |
+| [Gelu](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gelu) | |
| [Gemm](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gemm) | [7-8](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gemm-7), [9-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gemm-9), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gemm-11), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gemm-13) |
| [GlobalAveragePool](https://github.com/onnx/onnx/blob/main/docs/Operators.md#GlobalAveragePool) | [1+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#GlobalAveragePool-1) |
| [GlobalLpPool](https://github.com/onnx/onnx/blob/main/docs/Operators.md#GlobalLpPool) | |
@@ -82,6 +84,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat
| [Hardmax](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Hardmax) | |
| [Identity](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Identity) | [1-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-1), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-13), [14-15](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-14), [16-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-16), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-19) |
| [If](https://github.com/onnx/onnx/blob/main/docs/Operators.md#If) | |
+| [ImageDecoder](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ImageDecoder) | |
| [InstanceNormalization](https://github.com/onnx/onnx/blob/main/docs/Operators.md#InstanceNormalization) | [6+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#InstanceNormalization-6) |
| [IsInf](https://github.com/onnx/onnx/blob/main/docs/Operators.md#IsInf) | |
| [IsNaN](https://github.com/onnx/onnx/blob/main/docs/Operators.md#IsNaN) | |
@@ -137,12 +140,13 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat
| [ReduceL2](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceL2) | |
| [ReduceLogSum](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceLogSum) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceLogSum-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceLogSum-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceLogSum-13), [18+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceLogSum-18) |
| [ReduceLogSumExp](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceLogSumExp) | |
-| [ReduceMax](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceMax) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-1), [11](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-11), [12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-12), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-13), [18+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-18) |
+| [ReduceMax](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceMax) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-1), [11](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-11), [12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-12), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-13), [18-19](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-18), [20+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-20) |
| [ReduceMean](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceMean) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMean-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMean-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMean-13), [18+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMean-18) |
-| [ReduceMin](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceMin) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-1), [11](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-11), [12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-12), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-13), [18+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-18) |
+| [ReduceMin](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceMin) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-1), [11](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-11), [12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-12), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-13), [18-19](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-18), [20+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-20) |
| [ReduceProd](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceProd) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceProd-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceProd-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceProd-13), [18+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceProd-18) |
| [ReduceSum](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceSum) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSum-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSum-11) |
| [ReduceSumSquare](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceSumSquare) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSumSquare-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSumSquare-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSumSquare-13), [18+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSumSquare-18) |
+| [RegexFullMatch](https://github.com/onnx/onnx/blob/main/docs/Operators.md#RegexFullMatch) | |
| [Relu](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Relu) | [6-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Relu-6), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Relu-13), [14+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Relu-14) |
| [Reshape](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Reshape) | [5-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-5), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-13), [14-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-14), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-19) |
| [Resize](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Resize) | [10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-10), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-13), [18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-18), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-19) |
@@ -179,7 +183,9 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat
| [SplitToSequence](https://github.com/onnx/onnx/blob/main/docs/Operators.md#SplitToSequence) | |
| [Sqrt](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sqrt) | [6-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sqrt-6), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sqrt-13) |
| [Squeeze](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Squeeze) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Squeeze-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Squeeze-11), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Squeeze-13) |
+| [StringConcat](https://github.com/onnx/onnx/blob/main/docs/Operators.md#StringConcat) | |
| [StringNormalizer](https://github.com/onnx/onnx/blob/main/docs/Operators.md#StringNormalizer) | |
+| [StringSplit](https://github.com/onnx/onnx/blob/main/docs/Operators.md#StringSplit) | |
| [Sub](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sub) | [7-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sub-7), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sub-13), [14+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sub-14) |
| [Sum](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sum) | [6-7](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sum-6), [8-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sum-8), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sum-13) |
| [Tan](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Tan) | [7+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Tan-7) |
diff --git a/js/web/lib/backend-wasm-training.ts b/js/web/lib/backend-wasm-training.ts
index af5b575c87a7f..98e40807aa29c 100644
--- a/js/web/lib/backend-wasm-training.ts
+++ b/js/web/lib/backend-wasm-training.ts
@@ -4,13 +4,17 @@
import {InferenceSession, TrainingSessionHandler} from 'onnxruntime-common';
import {OnnxruntimeWebAssemblyBackend} from './backend-wasm';
+import {OnnxruntimeWebAssemblyTrainingSessionHandler} from './wasm/session-handler-for-training';
class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend {
async createTrainingSessionHandler(
- _checkpointStateUriOrBuffer: string|Uint8Array, _trainModelUriOrBuffer: string|Uint8Array,
- _evalModelUriOrBuffer: string|Uint8Array, _optimizerModelUriOrBuffer: string|Uint8Array,
- _options: InferenceSession.SessionOptions): Promise {
- throw new Error('Method not implemented yet.');
+ checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array,
+ evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array,
+ options: InferenceSession.SessionOptions): Promise {
+ const handler = new OnnxruntimeWebAssemblyTrainingSessionHandler();
+ await handler.createTrainingSession(
+ checkpointStateUriOrBuffer, trainModelUriOrBuffer, evalModelUriOrBuffer, optimizerModelUriOrBuffer, options);
+ return Promise.resolve(handler);
}
}
diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts
index b7b2ff4537095..00431a4e86d5b 100644
--- a/js/web/lib/wasm/binding/ort-wasm.d.ts
+++ b/js/web/lib/wasm/binding/ort-wasm.d.ts
@@ -102,6 +102,11 @@ export interface OrtWasmModule extends EmscriptenModule {
_OrtTrainingCopyParametersFromBuffer?
(trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number;
+ _OrtTrainingGetModelInputOutputCount?
+ (trainingHandle: number, inputCount: number, outputCount: number, isEvalModel: boolean): number;
+ _OrtTrainingGetModelInputOutputName?
+ (trainingHandle: number, index: number, isInput: boolean, isEvalModel: boolean): number;
+
_OrtTrainingReleaseSession?(trainingHandle: number): void;
// #endregion
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts
index 0a64d1ad1792a..1d3fc78fe368a 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts
@@ -803,3 +803,6 @@ export const getBroadcastDims = (inShape: readonly number[], outShape: readonly
}
return dims;
};
+
+// TODO: remove this limitation once >4D dims are supported by uniform.
+export const enableShapesUniforms = (rank: number): boolean => rank <= 4;
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts
index d241b8b92a669..e880afe09a5d8 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts
@@ -232,7 +232,7 @@ const convTranspose2d =
// STEP.1: transpose weight
const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ??
context.compute(
- createTransposeProgramInfo(inputs[1].dataType, inputs[1].dims.length, weightTransposePerm),
+ createTransposeProgramInfo(inputs[1], weightTransposePerm),
{inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0];
if (attributes.wIsConst && !context.kernelCustomData.wT) {
context.kernelCustomData.wT = transposedWeight;
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts
index b323a36cee5c8..c7ea0cffe51c3 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts
@@ -168,7 +168,7 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut
if (isChannelsLast) {
const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ??
context.compute(
- createTransposeProgramInfo(inputs[1].dataType, inputs[1].dims.length, weightTransposeAttribute),
+ createTransposeProgramInfo(inputs[1], weightTransposeAttribute),
{inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0];
if (attributes.wIsConst && !context.kernelCustomData.wT) {
context.kernelCustomData.wT = transposedWeight;
@@ -208,7 +208,7 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut
// STEP.1: transpose weight
const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ??
context.compute(
- createTransposeProgramInfo(inputs[1].dataType, inputs[1].dims.length, weightTransposeAttribute),
+ createTransposeProgramInfo(inputs[1], weightTransposeAttribute),
{inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0];
if (attributes.wIsConst && !context.kernelCustomData.wT) {
context.kernelCustomData.wT = transposedWeight;
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts
index 05f02b07c4d89..1538644412afd 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts
@@ -18,16 +18,18 @@ const validateInputs = (inputs: readonly TensorView[]): void => {
if (!inputs || inputs.length !== 1) {
throw new Error('Pool ops requires 1 input.');
}
- if (inputs[0].dims.length !== 4) {
- throw new Error('Pool ops supports 2-D inputs only for now.');
+ if (inputs[0].dims.length !== 4 && inputs[0].dims.length !== 3) {
+ throw new Error('Pool ops supports 1-D or 2-D inputs only for now.');
}
};
const getAdjustedPoolAttributesAndOutputShape = (
input: TensorView, attributes: AttributeType, isGlobalOperator: boolean): [AttributeType, number[]] => {
const isChannelsLast = attributes.format === 'NHWC';
- const inputShapeAsChannelFirst =
- isChannelsLast ? [input.dims[0], input.dims[3], input.dims[1], input.dims[2]] : input.dims.slice();
+ const inputShapeAsChannelFirst = input.dims.slice();
+ if (isChannelsLast) {
+ inputShapeAsChannelFirst.splice(1, 0, inputShapeAsChannelFirst.pop()!); // Move channel to the second position.
+ }
const hasDilations = Object.hasOwnProperty.call(attributes, 'dilations');
const kernelShape = attributes.kernelShape.slice();
const strides = attributes.strides.slice();
@@ -44,15 +46,9 @@ const getAdjustedPoolAttributesAndOutputShape = (
@@ -76,22 +72,22 @@ const generatePoolingCode = = ${inputDims[dimIdxW]}) {
- pad++;
- continue;
- }
- let x_val = x[${x.indicesToOffset('xIndices')}];
- ${op1}
- }`;
+ for (var i: u32 = 0u; i < ${kw}u; i++) {
+ xIndices[${dimIdxW}] = indices[${dimIdxW}] * ${sw} - ${pwStart} + i;
+ if (xIndices[${dimIdxW}] < 0 || xIndices[${dimIdxW}] >= ${inputDims[dimIdxW]}) {
+ pad++;
+ continue;
+ }
+ let x_val = x[${x.indicesToOffset('xIndices')}];
+ ${op1}
+ }`;
} else {
codeW = `
- for (var i: u32 = 0u; i < ${kw}u; i++) {
- xIndices[${dimIdxW}] = indices[${dimIdxW}] * ${sw} - ${pwStart} + i;
- let x_val = x[${x.indicesToOffset('xIndices')}];
- ${op1}
- }`;
+ for (var i: u32 = 0u; i < ${kw}u; i++) {
+ xIndices[${dimIdxW}] = indices[${dimIdxW}] * ${sw} - ${pwStart} + i;
+ let x_val = x[${x.indicesToOffset('xIndices')}];
+ ${op1}
+ }`;
}
if (attributes.kernelShape.length === 2) {
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts
index fe556a7fd8552..c4d43e9f466f5 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts
@@ -6,7 +6,7 @@ import {ShapeUtil} from '../../util';
import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key';
import {ComputeContext, ProgramInfo} from '../types';
-import {createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
+import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common';
export interface TransposeAttributes extends AttributeWithCacheKey {
readonly perm: number[];
@@ -35,13 +35,18 @@ const permFunctionBody = (perm: number[], rank: number, input: IndicesHelper, ou
return reverseFunc.join('\n');
};
-export const createTransposeProgramInfo =
- (inputDataType: number, inputRank: number, permAttr: number[]): ProgramInfo => {
- const perm = getAdjustedPerm(inputRank, permAttr);
- const output = outputVariable('output', inputDataType, (permAttr && permAttr.length) || inputRank);
- const input = inputVariable('a', inputDataType, inputRank);
+export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: number[]): ProgramInfo => {
+ const inputDataType = inputTensor.dataType;
+ const inputRank = inputTensor.dims.length;
+ const perm = getAdjustedPerm(inputRank, permAttr);
+ const useShapesUniforms = enableShapesUniforms(inputRank);
+ const outputShape = getOutputShape(inputTensor.dims, perm);
+ const outShapeOrRank = useShapesUniforms ? outputShape.length : outputShape;
+ const inShapeOrRank = useShapesUniforms ? inputRank : inputTensor.dims;
+ const output = outputVariable('output', inputDataType, outShapeOrRank);
+ const input = inputVariable('a', inputDataType, inShapeOrRank);
- const getShaderSource = (shaderHelper: ShaderHelper) => `
+ const getShaderSource = (shaderHelper: ShaderHelper) => `
${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)}
${permFunctionBody(perm, inputRank, input, output)}
@@ -54,30 +59,32 @@ export const createTransposeProgramInfo =
${output.setByOffset('global_idx', input.getByIndices('aIndices'))}
}`;
+ return {
+ name: 'Transpose',
+ shaderCache: {hint: `${permAttr}`, inputDependencies: useShapesUniforms ? ['rank'] : ['dims']},
+ getRunData: (inputs) => {
+ const outputSize = ShapeUtil.size(outputShape);
return {
- name: 'Transpose',
- shaderCache: {hint: `${permAttr}`, inputDependencies: ['rank']},
- getRunData: (inputs) => {
- const outputShape = getOutputShape(inputs[0].dims, perm);
- const outputSize = ShapeUtil.size(outputShape);
- return {
- outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
- dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
- programUniforms: [
+ outputs: [{dims: outputShape, dataType: inputs[0].dataType}],
+ dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)},
+ programUniforms: useShapesUniforms ?
+ [
{type: 'uint32', data: outputSize},
...createTensorShapeVariables(inputs[0].dims),
...createTensorShapeVariables(outputShape),
+ ] :
+ [
+ {type: 'uint32', data: outputSize},
],
- };
- },
- getShaderSource,
};
- };
+ },
+ getShaderSource,
+ };
+};
export const transpose = (context: ComputeContext, attributes: TransposeAttributes): void => {
validateInputs(context.inputs);
- context.compute(
- createTransposeProgramInfo(context.inputs[0].dataType, context.inputs[0].dims.length, attributes.perm));
+ context.compute(createTransposeProgramInfo(context.inputs[0], attributes.perm));
};
export const parseTransposeAttributes = (attributes: Record): TransposeAttributes =>
diff --git a/js/web/lib/wasm/proxy-messages.ts b/js/web/lib/wasm/proxy-messages.ts
index 7aa866773bcb1..efeb086256cf3 100644
--- a/js/web/lib/wasm/proxy-messages.ts
+++ b/js/web/lib/wasm/proxy-messages.ts
@@ -73,5 +73,10 @@ interface MesssageEndProfiling extends MessageError {
in ?: number;
}
+interface MessageIsOrtEnvInitialized extends MessageError {
+ type: 'is-ort-env-initialized';
+ out?: boolean;
+}
+
export type OrtWasmMessage = MessageInitWasm|MessageInitOrt|MessageCreateSessionAllocate|MessageCreateSessionFinalize|
- MessageCreateSession|MessageReleaseSession|MessageRun|MesssageEndProfiling;
+ MessageCreateSession|MessageReleaseSession|MessageRun|MesssageEndProfiling|MessageIsOrtEnvInitialized;
diff --git a/js/web/lib/wasm/proxy-worker/main.ts b/js/web/lib/wasm/proxy-worker/main.ts
index fe8bd9b11b191..1f4595818e5c0 100644
--- a/js/web/lib/wasm/proxy-worker/main.ts
+++ b/js/web/lib/wasm/proxy-worker/main.ts
@@ -4,7 +4,7 @@
///
import {OrtWasmMessage} from '../proxy-messages';
-import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, extractTransferableBuffers, initRuntime, releaseSession, run} from '../wasm-core-impl';
+import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, extractTransferableBuffers, initRuntime, isOrtEnvInitialized, releaseSession, run} from '../wasm-core-impl';
import {initializeWebAssembly} from '../wasm-factory';
self.onmessage = (ev: MessageEvent): void => {
@@ -89,6 +89,14 @@ self.onmessage = (ev: MessageEvent): void => {
postMessage({type: 'end-profiling', err} as OrtWasmMessage);
}
break;
+ case 'is-ort-env-initialized':
+ try {
+ const ortEnvInitialized = isOrtEnvInitialized();
+ postMessage({type: 'is-ort-env-initialized', out: ortEnvInitialized} as OrtWasmMessage);
+ } catch (err) {
+ postMessage({type: 'is-ort-env-initialized', err} as OrtWasmMessage);
+ }
+ break;
default:
}
};
diff --git a/js/web/lib/wasm/proxy-wrapper.ts b/js/web/lib/wasm/proxy-wrapper.ts
index a3e4a1ef1fc75..069a1fa452dbc 100644
--- a/js/web/lib/wasm/proxy-wrapper.ts
+++ b/js/web/lib/wasm/proxy-wrapper.ts
@@ -24,6 +24,7 @@ const createSessionCallbacks: Array> = [];
const runCallbacks: Array> = [];
const endProfilingCallbacks: Array> = [];
+const isOrtEnvInitializedCallbacks: Array> = [];
const ensureWorker = (): void => {
if (initializing || !initialized || aborted || !proxyWorker) {
@@ -92,6 +93,13 @@ const onProxyWorkerMessage = (ev: MessageEvent): void => {
endProfilingCallbacks.shift()![0]();
}
break;
+ case 'is-ort-env-initialized':
+ if (ev.data.err) {
+ isOrtEnvInitializedCallbacks.shift()![1](ev.data.err);
+ } else {
+ isOrtEnvInitializedCallbacks.shift()![0](ev.data.out!);
+ }
+ break;
default:
}
};
@@ -251,3 +259,16 @@ export const endProfiling = async(sessionId: number): Promise => {
core.endProfiling(sessionId);
}
};
+
+export const isOrtEnvInitialized = async(): Promise => {
+ if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) {
+ ensureWorker();
+ return new Promise((resolve, reject) => {
+ isOrtEnvInitializedCallbacks.push([resolve, reject]);
+ const message: OrtWasmMessage = {type: 'is-ort-env-initialized'};
+ proxyWorker!.postMessage(message);
+ });
+ } else {
+ return core.isOrtEnvInitialized();
+ }
+};
diff --git a/js/web/lib/wasm/session-handler-for-training.ts b/js/web/lib/wasm/session-handler-for-training.ts
new file mode 100644
index 0000000000000..83d133b9a5157
--- /dev/null
+++ b/js/web/lib/wasm/session-handler-for-training.ts
@@ -0,0 +1,73 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+import {env, InferenceSession, SessionHandler, TrainingSessionHandler} from 'onnxruntime-common';
+
+import {SerializableModeldata} from './proxy-messages';
+import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl';
+import {createCheckpointHandle, createTrainingSessionHandle, releaseTrainingSessionAndCheckpoint} from './wasm-training-core-impl';
+
+export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler {
+ async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise {
+ throw new Error('Method not implemented.');
+ }
+ async getContiguousParameters(_trainableOnly: boolean): Promise {
+ throw new Error('Method not implemented.');
+ }
+ private sessionId: number;
+ private checkpointId: number;
+
+ inputNames: string[];
+ outputNames: string[];
+
+ inputEncodedNames: number[];
+ outputEncodedNames: number[];
+
+ async uriOrBufferToHeap(uriOrBuffer: string|Uint8Array): Promise {
+ let buffer: Uint8Array;
+ if (typeof uriOrBuffer === 'string') {
+ const response = await fetch(uriOrBuffer);
+ const arrayBuffer = await response.arrayBuffer();
+ buffer = new Uint8Array(arrayBuffer);
+ } else {
+ buffer = uriOrBuffer;
+ }
+ return createSessionAllocate(buffer);
+ }
+
+ async createTrainingSession(
+ checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array,
+ evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array,
+ options: InferenceSession.SessionOptions) {
+ if (!isOrtEnvInitialized()) {
+ await initRuntime(env);
+ }
+ const checkpointData: SerializableModeldata = await this.uriOrBufferToHeap(checkpointStateUriOrBuffer);
+ const trainModelData: SerializableModeldata = await this.uriOrBufferToHeap(trainModelUriOrBuffer);
+ // 0 is supposed to be the nullptr
+ let evalModelData: SerializableModeldata = [0, 0];
+ let optimizerModelData: SerializableModeldata = [0, 0];
+
+ if (evalModelUriOrBuffer !== '') {
+ evalModelData = await this.uriOrBufferToHeap(evalModelUriOrBuffer);
+ }
+ if (optimizerModelUriOrBuffer !== '') {
+ optimizerModelData = await this.uriOrBufferToHeap(optimizerModelUriOrBuffer);
+ }
+
+ this.checkpointId = createCheckpointHandle(checkpointData);
+ [[this.sessionId, this.inputNames, this.outputNames], this.inputEncodedNames, this.outputEncodedNames] =
+ createTrainingSessionHandle(this.checkpointId, trainModelData, evalModelData, optimizerModelData, options);
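+    // The UTF-8-encoded name pointers are kept so dispose() can free them via releaseTrainingSessionAndCheckpoint.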
+ }
+
+ async dispose(): Promise {
+ return releaseTrainingSessionAndCheckpoint(
+ this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames);
+ }
+
+ async runTrainStep(
+ _feeds: SessionHandler.FeedsType, _fetches: SessionHandler.FetchesType,
+      _options: InferenceSession.RunOptions): Promise<SessionHandler.ReturnType> {
+ throw new Error('Method not implemented yet.');
+ }
+}
diff --git a/js/web/lib/wasm/session-handler.ts b/js/web/lib/wasm/session-handler.ts
index d1760e37c93f7..a5017a920f38b 100644
--- a/js/web/lib/wasm/session-handler.ts
+++ b/js/web/lib/wasm/session-handler.ts
@@ -5,10 +5,9 @@ import {readFile} from 'node:fs/promises';
import {env, InferenceSession, InferenceSessionHandler, SessionHandler, Tensor} from 'onnxruntime-common';
import {SerializableModeldata, TensorMetadata} from './proxy-messages';
-import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, releaseSession, run} from './proxy-wrapper';
+import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, isOrtEnvInitialized, releaseSession, run} from './proxy-wrapper';
import {isGpuBufferSupportedType} from './wasm-common';
-let runtimeInitialized: boolean;
let runtimeInitializationPromise: Promise<void>|undefined;
const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMetadata => {
@@ -57,13 +56,12 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan
}
async loadModel(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): Promise<void> {
- if (!runtimeInitialized) {
+ if (!(await isOrtEnvInitialized())) {
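+      // The initialized flag now lives in wasm-core-impl and is queried through the proxy worker when one is active.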
if (!runtimeInitializationPromise) {
runtimeInitializationPromise = initializeRuntime(env);
}
await runtimeInitializationPromise;
runtimeInitializationPromise = undefined;
- runtimeInitialized = true;
}
if (typeof pathOrBuffer === 'string') {
diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts
index 5b49a1d4202e3..947242945c665 100644
--- a/js/web/lib/wasm/wasm-core-impl.ts
+++ b/js/web/lib/wasm/wasm-core-impl.ts
@@ -10,6 +10,8 @@ import {dataLocationStringToEnum, getTensorElementSize, isGpuBufferSupportedType
import {getInstance} from './wasm-factory';
import {allocWasmString, checkLastError} from './wasm-utils';
+let ortEnvInitialized = false;
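+// Set to true at the end of initRuntime() and exposed through isOrtEnvInitialized() below.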
+
/**
* get the input/output count of the session.
* @param sessionHandle the handle representing the session. should be non-zero.
@@ -57,6 +59,8 @@ export const initRuntime = async(env: Env): Promise => {
const initJsep = require('./jsep/init').init;
await initJsep(getInstance(), env);
}
+
+ ortEnvInitialized = true;
};
/**
@@ -93,6 +97,8 @@ type SessionMetadata = [
const activeSessions = new Map<number, SessionMetadata>();
+export const isOrtEnvInitialized = (): boolean => ortEnvInitialized;
+
/**
* allocate the memory and memcpy the model bytes, preparing for creating an instance of InferenceSession.
* @returns a 2-elements tuple - the pointer and size of the allocated buffer
diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts
new file mode 100644
index 0000000000000..4830b5d2b5e80
--- /dev/null
+++ b/js/web/lib/wasm/wasm-training-core-impl.ts
@@ -0,0 +1,162 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+import {InferenceSession} from 'onnxruntime-common';
+
+import {SerializableModeldata, SerializableSessionMetadata} from './proxy-messages';
+import {setSessionOptions} from './session-options';
+import {getInstance} from './wasm-factory';
+import {checkLastError} from './wasm-utils';
+
+const NO_TRAIN_FUNCS_MSG =
+    'Built without training APIs enabled. Use the onnxruntime-web/training import for training ' +
+ 'functionality, and make sure that all the correct artifacts are built & moved to the correct folder if ' +
+ 'using a custom build. Check https://onnxruntime.ai/docs/build/web.html for more information.';
+
+export const createCheckpointHandle = (checkpointData: SerializableModeldata): number => {
+ const wasm = getInstance();
+
+ const [checkpointDataOffset, checkpointDataLength] = checkpointData;
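+  // checkpointData is a [wasm-heap pointer, byte length] pair produced earlier by createSessionAllocate().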
+ let checkpointHandle = 0;
+
+ try {
+ if (wasm._OrtTrainingLoadCheckpoint) {
+ checkpointHandle = wasm._OrtTrainingLoadCheckpoint(checkpointDataOffset, checkpointDataLength);
+ } else {
+ throw new Error(NO_TRAIN_FUNCS_MSG);
+ }
+
+ if (checkpointHandle === 0) {
+ checkLastError('Error occurred when trying to create a CheckpointState.');
+ }
+ return checkpointHandle;
+ } catch (e) {
+ if (wasm._OrtTrainingReleaseCheckpoint && checkpointHandle !== 0) {
+ wasm._OrtTrainingReleaseCheckpoint(checkpointHandle);
+ }
+ throw e;
+ } finally {
+ // free buffer from wasm heap
+ wasm._OrtFree(checkpointData[0]);
+ }
+};
+
+const getModelInputOutputCount = (trainingSessionId: number, isEvalModel: boolean): [number, number] => {
+ const wasm = getInstance();
+ const stack = wasm.stackSave();
+ try {
+ const dataOffset = wasm.stackAlloc(8);
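+    // Two int32 out-parameters: input count at dataOffset, output count at dataOffset + 4.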
+ if (wasm._OrtTrainingGetModelInputOutputCount) {
+ const errorCode =
+ wasm._OrtTrainingGetModelInputOutputCount(trainingSessionId, dataOffset, dataOffset + 4, isEvalModel);
+ if (errorCode !== 0) {
+ checkLastError('Can\'t get session input/output count.');
+ }
+ return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]];
+ } else {
+ throw new Error(NO_TRAIN_FUNCS_MSG);
+ }
+ } finally {
+ wasm.stackRestore(stack);
+ }
+};
+
+const getModelInputOutputNamesLoop =
+ (trainingSessionId: number, count: number, isInput: boolean, isEvalModel: boolean): [string[], number[]] => {
+ const names = [];
+ const wasm = getInstance();
+
+ const namesUTF8Encoded = [];
+
+ for (let i = 0; i < count; i++) {
+ if (wasm._OrtTrainingGetModelInputOutputName) {
+ const name = wasm._OrtTrainingGetModelInputOutputName(trainingSessionId, i, isInput, isEvalModel);
+ if (name === 0) {
+ checkLastError('Can\'t get input or output name');
+ }
+
+ namesUTF8Encoded.push(name);
+ names.push(wasm.UTF8ToString(name));
+ } else {
+ throw new Error(NO_TRAIN_FUNCS_MSG);
+ }
+ }
+ return [names, namesUTF8Encoded];
+ };
+
+const getTrainingModelInputOutputNames = (trainingSessionId: number): [string[], number[], string[], number[]] => {
+ const [inputCount, outputCount] = getModelInputOutputCount(trainingSessionId, false);
+
+ const [inputNames, inputNamesUTF8Encoded] = getModelInputOutputNamesLoop(trainingSessionId, inputCount, true, false);
+ const [outputNames, outputNamesUTF8Encoded] =
+ getModelInputOutputNamesLoop(trainingSessionId, outputCount, false, false);
+
+ return [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded];
+};
+
+export const createTrainingSessionHandle =
+ (checkpointHandle: number, trainModelData: SerializableModeldata, evalModelData: SerializableModeldata,
+ optimizerModelData: SerializableModeldata,
+ options: InferenceSession.SessionOptions): [SerializableSessionMetadata, number[], number[]] => {
+ const wasm = getInstance();
+
+ let trainingSessionHandle = 0;
+ let sessionOptionsHandle = 0;
+ let allocs: number[] = [];
+ let inputNamesUTF8Encoded: number[] = [];
+ let outputNamesUTF8Encoded: number[] = [];
+
+ let inputNames: string[] = [];
+ let outputNames: string[] = [];
+
+ try {
+ [sessionOptionsHandle, allocs] = setSessionOptions(options);
+ if (wasm._OrtTrainingCreateSession) {
+ trainingSessionHandle = wasm._OrtTrainingCreateSession(
+ sessionOptionsHandle, checkpointHandle, trainModelData[0], trainModelData[1], evalModelData[0],
+ evalModelData[1], optimizerModelData[0], optimizerModelData[1]);
+ } else {
+ throw new Error(NO_TRAIN_FUNCS_MSG);
+ }
+
+ if (trainingSessionHandle === 0) {
+ checkLastError('Error occurred when trying to create a TrainingSession.');
+ }
+
+ [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded] =
+ getTrainingModelInputOutputNames(trainingSessionHandle);
+ return [[trainingSessionHandle, inputNames, outputNames], inputNamesUTF8Encoded, outputNamesUTF8Encoded];
+
+ } catch (e) {
+ if (wasm._OrtTrainingReleaseSession && trainingSessionHandle !== 0) {
+ wasm._OrtTrainingReleaseSession(trainingSessionHandle);
+ }
+ throw e;
+ } finally {
+ wasm._free(trainModelData[0]);
+ wasm._free(evalModelData[0]);
+ wasm._free(optimizerModelData[0]);
+
+ if (sessionOptionsHandle !== 0) {
+ wasm._OrtReleaseSessionOptions(sessionOptionsHandle);
+ }
+ allocs.forEach(alloc => wasm._free(alloc));
+ inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
+ outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
+ }
+ };
+
+export const releaseTrainingSessionAndCheckpoint =
+ (checkpointId: number, sessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[]):
+ void => {
+ const wasm = getInstance();
+ inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
+ outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf));
+
+ if (wasm._OrtTrainingReleaseSession) {
+ wasm._OrtTrainingReleaseSession(sessionId);
+ }
+ if (wasm._OrtTrainingReleaseCheckpoint) {
+ wasm._OrtTrainingReleaseCheckpoint(checkpointId);
+ }
+ };
diff --git a/js/web/test/data/ops/transpose.jsonc b/js/web/test/data/ops/transpose.jsonc
index 285d14018e74d..e1edfa7e41513 100644
--- a/js/web/test/data/ops/transpose.jsonc
+++ b/js/web/test/data/ops/transpose.jsonc
@@ -166,5 +166,29 @@
]
}
]
+ },
+ {
+ "name": "Transpose 5D - perms:[4, 3, 1, 0, 2]",
+ "operator": "Transpose",
+ "attributes": [{ "name": "perm", "data": [4, 3, 1, 0, 2], "type": "ints" }],
+ "cases": [
+ {
+ "name": "T[3, 1, 2, 1, 4]",
+ "inputs": [
+ {
+ "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24],
+ "dims": [3, 1, 2, 1, 4],
+ "type": "float32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [1, 5, 9, 13, 17, 21, 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23, 4, 8, 12, 16, 20, 24],
+ "dims": [4, 1, 1, 3, 2],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
}
]
diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
index 0b55cb7804c61..694c40bf3eda6 100644
--- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
@@ -16,7 +16,6 @@
#include
#include
-#include
using onnxruntime::concurrency::ThreadPool;
diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
index 73b83057bdbe9..00e82c9844b3d 100644
--- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
+++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
@@ -206,6 +206,7 @@ Status CheckInputs(const T* query,
}
}
+ int total_sequence_length = past_sequence_length + kv_sequence_length;
AttentionMaskType mask_type = AttentionMaskType::MASK_NONE;
if (key_padding_mask != nullptr) {
mask_type = AttentionMaskType::MASK_UNKNOWN;
@@ -216,13 +217,21 @@ Status CheckInputs(const T* query,
} else if (mask_dims[0] == static_cast<int64_t>(3) * static_cast<int64_t>(batch_size) + static_cast<int64_t>(2)) {
mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START;
}
-  } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast<int64_t>(batch_size) && mask_dims[1] == static_cast<int64_t>(kv_sequence_length)) {
+  } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast<int64_t>(batch_size) &&
+             mask_dims[1] == static_cast<int64_t>(kv_sequence_length)) {
+    mask_type = AttentionMaskType::MASK_2D_KEY_PADDING;
+  } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast<int64_t>(batch_size) &&
+             mask_dims[1] == static_cast<int64_t>(total_sequence_length)) {
    mask_type = AttentionMaskType::MASK_2D_KEY_PADDING;
+  } else if (mask_dims.size() == 3 && mask_dims[0] == static_cast<int64_t>(batch_size) &&
+             mask_dims[1] == static_cast<int64_t>(sequence_length) &&
+             mask_dims[2] == static_cast<int64_t>(total_sequence_length)) {
+    mask_type = AttentionMaskType::MASK_3D_ATTENTION;
}
if (mask_type == AttentionMaskType::MASK_UNKNOWN) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'key_padding_mask' shape shall be (batch_size) or (batch_size, kv_sequence_length)");
+ "Input 'key_padding_mask' shape shall be 1D, 2D, or 3D");
}
}
@@ -257,7 +266,6 @@ Status CheckInputs(const T* query,
}
}
- int total_sequence_length = past_sequence_length + kv_sequence_length;
bool broadcast_res_pos_bias = false;
if (relative_position_bias != nullptr) {
const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims();
diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
new file mode 100644
index 0000000000000..4a266af789250
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
@@ -0,0 +1,115 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "contrib_ops/cpu/bert/rotary_embedding.h"
+#include "contrib_ops/cpu/bert/rotary_embedding_helper.h"
+
+#include "core/platform/threadpool.h"
+
+using onnxruntime::concurrency::ThreadPool;
+using namespace onnxruntime::contrib::rotary_embedding_helper;
+
+namespace onnxruntime {
+namespace contrib {
+
+// These ops are internal-only, so register outside of onnx
+ONNX_OPERATOR_TYPED_KERNEL_EX(
+ RotaryEmbedding,
+ kMSDomain,
+ 1,
+ float,
+ kCpuExecutionProvider,
+ KernelDefBuilder()
+        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
+        .TypeConstraint("M", DataTypeImpl::GetTensorType<int64_t>()),
+ RotaryEmbedding);
+
+template <typename T>
+RotaryEmbedding<T>::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) {
+  scale = info.GetAttrOrDefault<float>("scale", 1.0);
+  interleaved = (info.GetAttrOrDefault<int64_t>("interleaved", 0) == 1);
+}
+
+template <typename T>
+Status RotaryEmbedding<T>::Compute(OpKernelContext* context) const {
+  const Tensor* input = context->Input<Tensor>(0);
+  const Tensor* position_ids = context->Input<Tensor>(1);
+  const Tensor* cos_cache = context->Input<Tensor>(2);
+  const Tensor* sin_cache = context->Input<Tensor>(3);
+
+ RotaryParameters parameters = {};
+ ORT_RETURN_IF_ERROR(rotary_embedding_helper::CheckInputs(input,
+ position_ids,
+ cos_cache,
+ sin_cache,
+ ¶meters));
+
+ Tensor* output = context->Output(0, input->Shape());
+
+ if (parameters.sequence_length > parameters.max_sequence_length) {
+ // Launch update_cos_sin_cache kernel with scale
+ ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported");
+ }
+
+  const T* input_src = input->Data<T>();
+  const int64_t* pos_ids_data = position_ids->Data<int64_t>();
+  const T* cos_cache_data = cos_cache->Data<T>();
+  const T* sin_cache_data = sin_cache->Data<T>();
+  T* output_dest = output->MutableData<T>();
+
+ const int batch_size = parameters.batch_size;
+ const int sequence_length = parameters.sequence_length;
+ const int num_heads = parameters.num_heads;
+ const int head_size = parameters.head_size;
+ const int position_ids_format = parameters.position_ids_format;
+ const int half_head_size = head_size / 2;
+
+ AllocatorPtr allocator;
+ ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
+ auto* tp = context->GetOperatorThreadPool();
+
+ const int loop_len = batch_size * sequence_length * num_heads;
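+  // One parallel work item per (batch, position, head); each item rotates one head_size-long vector.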
+  const double cost = static_cast<double>(head_size);
+ ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
+ for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) {
+      const int b = static_cast<int>((ptr / num_heads) / sequence_length);
+      const int s = static_cast<int>((ptr / num_heads) % sequence_length);
+      const int n = static_cast<int>(ptr % num_heads);
+
+ const int block_offset = b * sequence_length * num_heads + s * num_heads + n;
+ const int data_offset = block_offset * head_size;
+
+ const T* input_data = input_src + data_offset;
+ T* output_data = output_dest + data_offset;
+
+ // Cache is (M, H/2)
+ const int position_id = (position_ids_format == 0)
+                                  ? static_cast<int>(pos_ids_data[0]) + s
+                                  : static_cast<int>(pos_ids_data[b * sequence_length + s]);
+ const int cache_offset = position_id * half_head_size;
+ const T* cos_data = cos_cache_data + cache_offset;
+ const T* sin_data = sin_cache_data + cache_offset;
+
+ int cache_idx = 0;
+ T sign = 0;
+ int j = 0;
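+      // Pair element i with its partner j and rotate: out[i] = in[i] * cos + sign * in[j] * sin.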
+ for (int i = 0; i < head_size; i++) {
+ if (interleaved) {
+ cache_idx = (i / 2) % half_head_size;
+          sign = (i % 2 == 0) ? static_cast<T>(-1) : static_cast<T>(1);
+ j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign
+ } else {
+ cache_idx = i % half_head_size;
+          sign = (i < half_head_size) ? static_cast<T>(-1) : static_cast<T>(1);
+ j = (i + half_head_size) % head_size;
+ }
+ output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx];
+ }
+ }
+ });
+
+ return Status::OK();
+}
+
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h
new file mode 100644
index 0000000000000..be834a66cdc69
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h
@@ -0,0 +1,23 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/common/common.h"
+#include "core/framework/op_kernel.h"
+
+namespace onnxruntime {
+namespace contrib {
+
+template <typename T>
+class RotaryEmbedding final : public OpKernel {
+ public:
+ RotaryEmbedding(const OpKernelInfo& info);
+ Status Compute(OpKernelContext* context) const override;
+
+ protected:
+ float scale;
+ bool interleaved;
+};
+
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h
new file mode 100644
index 0000000000000..cf8080800e072
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h
@@ -0,0 +1,121 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/common/common.h"
+#include "core/providers/common.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace rotary_embedding_helper {
+
+// Parameters deduced from node attributes and inputs/outputs.
+struct RotaryParameters {
+ int batch_size; // Batch size used by input
+ int sequence_length; // Sequence length used by input
+ int hidden_size; // Hidden size used by input
+ int head_size; // Head size used by cos/sin cache * 2
+ int num_heads; // num_heads = hidden_size / head_size
+ int max_sequence_length; // Sequence length used by cos/sin cache
+ int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length)
+};
+
+template <typename T>
+Status CheckInputs(const T* input,
+ const T* position_ids,
+ const T* cos_cache,
+ const T* sin_cache,
+ void* parameters) {
+ // input : (batch_size, sequence_length, hidden_size)
+ // position ids : (1) or (batch_size, sequence_length)
+ // cos cache : (max_sequence_length, head_size / 2)
+ // sin cache : (max_sequence_length, head_size / 2)
+
+ // Check input
+ const auto& input_dims = input->Shape().GetDims();
+ if (input_dims.size() != 3) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'x' is expected to have 3 dimensions, got ",
+ input_dims.size());
+ }
+ // Check position_ids
+ const auto& position_ids_dims = position_ids->Shape().GetDims();
+ if (!onnxruntime::IsScalarOr1ElementVector(position_ids) && position_ids_dims.size() != 2) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' is expected to have 0, 1, or 2 ",
+ "dimensions, got ", position_ids_dims.size());
+ }
+ // Check cos_cache and sin_cache
+ const auto& cos_cache_dims = cos_cache->Shape().GetDims();
+ if (cos_cache_dims.size() != 2) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' is expected to have 2 dimensions, got ",
+ cos_cache_dims.size());
+ }
+ const auto& sin_cache_dims = sin_cache->Shape().GetDims();
+ if (sin_cache_dims.size() != 2) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' is expected to have 2 dimensions, got ",
+ sin_cache_dims.size());
+ }
+ if (cos_cache_dims[0] != sin_cache_dims[0] || cos_cache_dims[1] != sin_cache_dims[1]) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'cos_cache' and 'sin_cache' are expected to have ",
+ "the same shape");
+ }
+
+ // Get attributes from inputs
+  int batch_size = static_cast<int>(input_dims[0]);
+  int sequence_length = static_cast<int>(input_dims[1]);
+  int hidden_size = static_cast<int>(input_dims[2]);
+  int max_sequence_length = static_cast<int>(cos_cache_dims[0]);
+  int head_size = static_cast<int>(cos_cache_dims[1]) * 2;
+ int num_heads = hidden_size / head_size;
+ int position_ids_format = -1;
+
+ // Check position_ids input shapes
+ if (!onnxruntime::IsScalarOr1ElementVector(position_ids)) {
+    if (batch_size != static_cast<int>(position_ids_dims[0])) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' dimension 0 should be of size ",
+ "batch_size, got ", position_ids_dims[0]);
+ }
+    if (sequence_length != static_cast<int>(position_ids_dims[1])) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' dimension 1 should be of size ",
+ "sequence_length, got ", position_ids_dims[1]);
+ }
+ position_ids_format = 1;
+ } else {
+ position_ids_format = 0;
+ }
+ // Check cos_cache input shapes
+  if (max_sequence_length != static_cast<int>(cos_cache_dims[0])) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 0 should be same as ",
+ "max_sequence_length, got ", cos_cache_dims[0]);
+ }
+  if ((head_size / 2) != static_cast<int>(cos_cache_dims[1])) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 1 should be same as ",
+ "head_size / 2, got ", cos_cache_dims[1]);
+ }
+ // Check sin_cache input shapes
+  if (max_sequence_length != static_cast<int>(sin_cache_dims[0])) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' dimension 0 should be same as ",
+ "max_sequence_length, got ", sin_cache_dims[0]);
+ }
+  if ((head_size / 2) != static_cast<int>(sin_cache_dims[1])) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' dimension 1 should be same as ",
+ "head_size / 2, got ", sin_cache_dims[1]);
+ }
+
+ // Set rotary parameters
+ if (parameters != nullptr) {
+    RotaryParameters* output_parameters = reinterpret_cast<RotaryParameters*>(parameters);
+ output_parameters->batch_size = batch_size;
+ output_parameters->sequence_length = sequence_length;
+ output_parameters->hidden_size = hidden_size;
+ output_parameters->head_size = head_size;
+ output_parameters->num_heads = num_heads;
+ output_parameters->max_sequence_length = max_sequence_length;
+ output_parameters->position_ids_format = position_ids_format;
+ }
+
+ return Status::OK();
+}
+
+} // namespace rotary_embedding_helper
+} // namespace contrib
+} // namespace onnxruntime
\ No newline at end of file
diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
index b4c51ab290eb7..f9d9b13f0fedc 100644
--- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -20,6 +20,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MultiHeadAttention);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RotaryEmbedding);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer);
@@ -29,6 +30,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Gathe
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeMatMul); // backward compatibility
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FusedMatMul);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulNBits);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulBnb4);
#ifndef ORT_MINIMAL_BUILD
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulFpQ4);
#endif
@@ -124,6 +126,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, SimplifiedLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, SkipLayerNormalization);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, SkipSimplifiedLayerNormalization);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Inverse);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Trilu);
@@ -253,6 +257,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RotaryEmbedding)>,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -266,6 +271,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo, // backward compatibility
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+      BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulBnb4)>,
#ifndef ORT_MINIMAL_BUILD
BuildKernelCreateInfo,
#endif
@@ -299,6 +305,8 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, SkipSimplifiedLayerNormalization)>,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h b/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h
new file mode 100644
index 0000000000000..cb8e97a592d8c
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h
@@ -0,0 +1,202 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include <algorithm>
+#include <cmath>
+#include <cstdint>
+
+namespace onnxruntime {
+namespace contrib {
+
+#if defined(_MSC_VER)
+#define FORCEINLINE __forceinline
+#else
+#define FORCEINLINE __attribute__((always_inline)) inline
+#endif
+
+typedef enum Bnb_DataType_t {
+ FP4 = 0,
+ NF4 = 1,
+} Bnb_DataType_t;
+
+FORCEINLINE uint8_t QuantizeOneFP4(float x) {
+ // FP4 with bias of 3
+ // first bit is a sign
+ // subnormals
+ // 0b000 = 0
+ // 0b001 = 0.0625
+ // 0b110 = 2
+ // 0b111 = 3
+ // 0b100 = 4
+ // 0b101 = 6
+ // 0b010 = 8
+ // 0b011 = 12
+
+ // we do a binary search
+ // the pivots are divided by 12 (the FP4 absmax)
+  // since we assume input data is in [-1.0, 1.0]
+
+  // !be careful here, it's easy to make a mistake
+  // that is difficult to notice if you add an extra
+ // zero somewhere!
+
+ uint8_t sign = x < 0 ? 0b1000 : 0b0000;
+ x = fabsf(x);
+ if (x > 0.29166667f) {
+ if (x > 0.583333f) {
+ if (x > 0.8333333f) {
+ return 0b0011 + sign;
+ } else {
+ return 0b0010 + sign;
+ }
+ } else if (x > 0.4166667f) {
+ return 0b101 + sign;
+ } else {
+ return 0b100 + sign;
+ }
+ } else if (x > 0.0859375f) {
+ if (x > 0.20833333f) {
+ return 0b0111 + sign;
+ } else {
+ return 0b0110 + sign;
+ }
+ } else if (x > 0.00260417f) {
+ return 0b0001 + sign;
+ } else {
+ return 0b0000 + sign;
+ }
+}
+
+FORCEINLINE uint8_t QuantizeOneNF4(float x) {
+ if (x > 0.03979014977812767f) {
+ if (x > 0.3893125355243683f) { // 1
+ if (x > 0.6427869200706482f) { // 11
+ if (x > 0.8614784181118011f) { // 111
+ return 0b1111;
+ } else {
+ return 0b1110;
+ }
+ } else if (x > 0.5016634166240692f) { // 110
+ return 0b1101;
+ } else {
+ return 0b1100;
+ }
+ } else if (x > 0.2035212516784668f) { // 10
+ if (x > 0.2920137718319893f) { // 101
+ return 0b1011;
+ } else {
+ return 0b1010;
+ }
+ } else if (x > 0.1202552504837513f) { // 100
+ return 0b1001;
+ } else {
+ return 0b1000;
+ }
+ } else if (x > -0.33967943489551544f) { // 0
+ if (x > -0.13791173323988914f) { // 01
+ if (x > -0.045525018125772476f) { // 011
+ return 0b0111;
+ } else {
+ return 0b0110;
+ }
+ } else if (x > -0.23460740596055984f) { // 010
+ return 0b0101;
+ } else {
+ return 0b0100;
+ }
+ } else if (x > -0.6106329262256622f) { // 00
+ if (x > -0.4599952697753906f) { // 001
+ return 0b0011;
+ } else {
+ return 0b0010;
+ }
+ } else if (x > -0.8480964004993439f) { // 000
+ return 0b0001;
+ } else {
+ return 0b0000;
+ }
+}
+
+template <int32_t DATA_TYPE>
+FORCEINLINE uint8_t QuantizeOneBnb4(float x) {
+ if constexpr (DATA_TYPE == FP4)
+ return QuantizeOneFP4(x);
+ else
+ return QuantizeOneNF4(x);
+}
+
+template <typename T, int32_t block_size, int32_t DATA_TYPE>
+FORCEINLINE void QuantizeBlockBnb4(const T* src, uint8_t* dst, T& absmax_block, int32_t block_idx, int32_t numel) {
+ float local_absmax = 0.0f;
+
+ int32_t block_len = std::min(block_size, numel - block_idx * block_size);
+ int32_t src_offset = block_idx * block_size;
+ int32_t dst_offset = block_idx * block_size / 2;
+
+ for (int32_t idx = 0; idx < block_len; idx++) {
+ const float v = static_cast(src[src_offset + idx]);
+ local_absmax = fmaxf(local_absmax, fabsf(v));
+ }
+
+ absmax_block = static_cast(local_absmax);
+ const float reciprocal_absmax = local_absmax ? 1.0f / local_absmax : 0.0f;
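+  // Normalize each value into [-1, 1] by the block absmax, then pack two 4-bit codes per byte (even index in the high nibble).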
+
+ for (int32_t idx = 0; idx < block_len; idx += 2) {
+ const float v0 = static_cast(src[src_offset + idx]) * reciprocal_absmax;
+    const uint8_t vi0 = QuantizeOneBnb4<DATA_TYPE>(v0);
+
+ const float v1 = (idx + 1 < block_len) ? static_cast(src[src_offset + idx + 1]) * reciprocal_absmax : 0;
+    const uint8_t vi1 = QuantizeOneBnb4<DATA_TYPE>(v1);
+
+ dst[dst_offset + idx / 2] = (vi0 << 4) | vi1;
+ }
+}
+
+static float fp4_qaunt_map[16] = {0.00000000f, 5.208333333e-03f, 0.66666667f, 1.00000000f,
+ 0.33333333f, 0.50000000f, 0.16666667f, 0.25000000f,
+ -0.00000000f, -5.208333333e-03f, -0.66666667f, -1.00000000f,
+ -0.33333333f, -0.50000000f, -0.16666667f, -0.25000000f};
+
+static float nf4_qaunt_map[16] = {-1.0f,
+ -0.6961928009986877f,
+ -0.5250730514526367f,
+ -0.39491748809814453f,
+ -0.28444138169288635f,
+ -0.18477343022823334f,
+ -0.09105003625154495f,
+ 0.0f,
+ 0.07958029955625534f,
+ 0.16093020141124725f,
+ 0.24611230194568634f,
+ 0.33791524171829224f,
+ 0.44070982933044434f,
+ 0.5626170039176941f,
+ 0.7229568362236023f,
+ 1.0f};
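+// The two tables above map a 4-bit code back to its normalized value; dequantization multiplies by the block's absmax.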
+
+template <typename T, int32_t DATA_TYPE>
+FORCEINLINE T DequantizeOneBnb4(uint8_t x) {
+ if constexpr (DATA_TYPE == FP4)
+ return static_cast(fp4_qaunt_map[x]);
+ else
+ return static_cast(nf4_qaunt_map[x]);
+}
+
+template <typename T, int32_t block_size, int32_t DATA_TYPE>
+FORCEINLINE void DequantizeBlockBnb4(const uint8_t* src, T* dst, T absmax_block, int32_t block_idx, int32_t numel) {
+ int32_t block_len = std::min(block_size, numel - block_idx * block_size);
+ int32_t src_offset = block_idx * block_size / 2;
+ int32_t dst_offset = block_idx * block_size;
+
+ for (int32_t idx = 0; idx < block_len; idx += 2) {
+ const uint8_t val = src[src_offset + idx / 2];
+
+    dst[dst_offset + idx] = DequantizeOneBnb4<T, DATA_TYPE>(val >> 4) * absmax_block;
+    if (idx + 1 < block_len) dst[dst_offset + idx + 1] = DequantizeOneBnb4<T, DATA_TYPE>(val & 0xF) * absmax_block;
+ }
+}
+
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h b/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h
new file mode 100644
index 0000000000000..5ddb77e5b5ee3
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h
@@ -0,0 +1,143 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "blockwise_quant_block_bnb4.h"
+
+#include
+
+#include "core/common/safeint.h"
+#include "core/framework/float16.h"
+#include "core/platform/threadpool.h"
+#include
+
+namespace onnxruntime {
+namespace contrib {
+
+template <typename T, int32_t block_size, int32_t DATA_TYPE>
+void QuantizeBlockwiseBnb4(
+ uint8_t* dst, // shape: [(N * K + 1) / 2]
+ const T* src, // shape: [N, K]
+ T* absmax, // shape: [(N * K + block_size - 1) / block_size]
+ int32_t N,
+ int32_t K,
+ onnxruntime::concurrency::ThreadPool* thread_pool) {
+ int32_t numel = N * K;
+ int32_t total_block_count = (numel + block_size - 1) / block_size;
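+  // Each block_size-element block is quantized independently; TryBatchParallelFor shards the blocks across the thread pool.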
+
+ concurrency::ThreadPool::TryBatchParallelFor(
+ thread_pool,
+ total_block_count,
+ [&](ptrdiff_t block_idx) {
+        QuantizeBlockBnb4<T, block_size, DATA_TYPE>(
+ src,
+ dst,
+ absmax[block_idx],
+ static_cast(block_idx),
+ numel);
+ },
+ 0);
+}
+
+#define QuantizeBlockwiseBn4DataTyped(block_size, quant_type) \
+ if (quant_type == FP4) \
+    QuantizeBlockwiseBnb4<T, block_size, FP4>(dst, src, absmax, N, K, thread_pool); \
+ else \
+    QuantizeBlockwiseBnb4<T, block_size, NF4>(dst, src, absmax, N, K, thread_pool);
+
+template <typename T>
+void QuantizeBlockwiseBnb4(
+ uint8_t* dst, // shape: [(N * K + 1) / 2]
+ const T* src, // shape: [N, K]
+ T* absmax, // shape: [(N * K + block_size - 1) / block_size]
+ int32_t block_size,
+ int32_t quant_type,
+ int32_t N,
+ int32_t K,
+ onnxruntime::concurrency::ThreadPool* thread_pool) {
+ ORT_ENFORCE(
+ quant_type == FP4 || quant_type == NF4,
+ "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
+
+ if (block_size == 16) {
+ QuantizeBlockwiseBn4DataTyped(16, quant_type);
+ } else if (block_size == 32) {
+ QuantizeBlockwiseBn4DataTyped(32, quant_type);
+ } else if (block_size == 64) {
+ QuantizeBlockwiseBn4DataTyped(64, quant_type);
+ } else if (block_size == 128) {
+ QuantizeBlockwiseBn4DataTyped(128, quant_type);
+ } else if (block_size == 256) {
+ QuantizeBlockwiseBn4DataTyped(256, quant_type);
+ } else {
+ ORT_NOT_IMPLEMENTED("only block size 16, 32, 64, 128, 256 are supported.");
+ }
+}
+
+#undef QuantizeBlockwiseBn4DataTyped
+
+template <typename T, int32_t block_size, int32_t DATA_TYPE>
+void DequantizeBlockwiseBnb4(
+ T* dst, // shape: [N, K]
+ const uint8_t* src, // shape: [(N * K + 1) / 2)]
+ const T* absmax, // shape: [(N * K + block_size - 1) / block_size]
+ int32_t N,
+ int32_t K,
+ onnxruntime::concurrency::ThreadPool* thread_pool) {
+ int32_t numel = N * K;
+ int32_t total_block_count = (numel + block_size - 1) / block_size;
+
+ concurrency::ThreadPool::TryBatchParallelFor(
+ thread_pool,
+ total_block_count,
+ [&](ptrdiff_t block_idx) {
+        DequantizeBlockBnb4<T, block_size, DATA_TYPE>(
+ src,
+ dst,
+ absmax[block_idx],
+ static_cast(block_idx),
+ numel);
+ },
+ 0);
+}
+
+#define DequantizeBlockwiseBn4DataTyped(block_size, quant_type) \
+ if (quant_type == FP4) \
+    DequantizeBlockwiseBnb4<T, block_size, FP4>(dst, src, absmax, N, K, thread_pool); \
+ else \
+    DequantizeBlockwiseBnb4<T, block_size, NF4>(dst, src, absmax, N, K, thread_pool);
+
+template <typename T>
+void DequantizeBlockwiseBnb4(
+ T* dst, // shape: [N, K]
+ const uint8_t* src, // shape: [(N * K + 1) / 2)]
+ const T* absmax, // shape: [(N * K + block_size - 1) / block_size]
+ int32_t block_size,
+ int32_t quant_type,
+ int32_t N,
+ int32_t K,
+ onnxruntime::concurrency::ThreadPool* thread_pool) {
+ ORT_ENFORCE(
+ quant_type == FP4 || quant_type == NF4,
+ "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
+
+ if (block_size == 16) {
+ DequantizeBlockwiseBn4DataTyped(16, quant_type);
+ } else if (block_size == 32) {
+ DequantizeBlockwiseBn4DataTyped(32, quant_type);
+ } else if (block_size == 64) {
+ DequantizeBlockwiseBn4DataTyped(64, quant_type);
+ } else if (block_size == 128) {
+ DequantizeBlockwiseBn4DataTyped(128, quant_type);
+ } else if (block_size == 256) {
+ DequantizeBlockwiseBn4DataTyped(256, quant_type);
+ } else {
+ ORT_NOT_IMPLEMENTED("only block size 16, 32, 64, 128, 256 are supported.");
+ }
+}
+
+#undef DequantizeBlockwiseBn4DataTyped
+
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc
new file mode 100644
index 0000000000000..2f3ede49c3650
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc
@@ -0,0 +1,109 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/common/safeint.h"
+#include "core/framework/op_kernel.h"
+#include "core/providers/cpu/math/matmul_helper.h"
+#include "core/providers/common.h"
+#include "dequantize_blockwise_bnb4.h"
+#include "core/mlas/inc/mlas.h"
+
+namespace onnxruntime {
+namespace contrib {
+
+class MatMulBnb4 final : public OpKernel {
+ public:
+ MatMulBnb4(const OpKernelInfo& info) : OpKernel(info) {
+    ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("K", &K_));
+    ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("N", &N_));
+    ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("block_size", &block_size_));
+    ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("quant_type", &quant_type_));
+ ORT_ENFORCE(
+ quant_type_ == FP4 || quant_type_ == NF4,
+ "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
+ }
+
+ Status Compute(OpKernelContext* context) const override;
+
+ private:
+ int64_t K_;
+ int64_t N_;
+ int64_t block_size_;
+ int64_t quant_type_;
+};
+
+Status MatMulBnb4::Compute(OpKernelContext* ctx) const {
+ concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();
+
+ const Tensor* a = ctx->Input(0);
+ const Tensor* b_quant = ctx->Input(1);
+ const Tensor* absmax = ctx->Input(2);
+
+  const float* a_data = a->Data<float>();
+  const uint8_t* b_quant_data = b_quant->Data<uint8_t>();
+  const float* absmax_data = absmax->Data<float>();
+
+ AllocatorPtr allocator;
+ auto status = ctx->GetTempSpaceAllocator(&allocator);
+ ORT_RETURN_IF_ERROR(status);
+  auto tmp_b_data_ptr = IAllocator::MakeUniquePtr<float>(allocator, SafeInt<size_t>(K_) * N_);
+  DequantizeBlockwiseBnb4<float>(
+ tmp_b_data_ptr.get(),
+ b_quant_data,
+ absmax_data,
+      static_cast<int32_t>(block_size_),
+      static_cast<int32_t>(quant_type_),
+      static_cast<int32_t>(N_),
+      static_cast<int32_t>(K_),
+ thread_pool);
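+  // B was quantized as a row-major [N, K] matrix; after dequantizing to float we run a plain SGEMM with B transposed.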
+
+ constexpr bool transa = false;
+ constexpr bool transb = true;
+ TensorShape b_shape({N_, K_});
+ MatMulComputeHelper helper;
+ ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, transa, transb));
+
+ Tensor* y = ctx->Output(0, helper.OutputShape());
+
+ // Bail out early if the output is going to be empty
+ if (y->Shape().Size() == 0) return Status::OK();
+
+  auto* y_data = y->MutableData<float>();
+
+ const size_t max_len = helper.OutputOffsets().size();
+  const size_t M = static_cast<size_t>(helper.M());
+  const size_t N = static_cast<size_t>(helper.N());
+  const size_t K = static_cast<size_t>(helper.K());
+ const size_t lda = helper.Lda(transa);
+ const size_t ldb = helper.Ldb(transb);
+
+ // TODO: implement with native kernel
+  std::vector<MLAS_SGEMM_DATA_PARAMS> data(max_len);
+ for (size_t i = 0; i < max_len; i++) {
+ data[i].BIsPacked = false;
+ data[i].A = a_data + helper.LeftOffsets()[i];
+ data[i].lda = lda;
+ data[i].B = tmp_b_data_ptr.get() + helper.RightOffsets()[i];
+ data[i].ldb = ldb;
+ data[i].C = y_data + helper.OutputOffsets()[i];
+ data[i].ldc = N;
+ data[i].alpha = 1.f;
+ data[i].beta = 0.0f;
+ }
+ MlasGemmBatch(CblasNoTrans, CblasTrans, M, N, K, data.data(), max_len, thread_pool);
+
+ return Status::OK();
+}
+
+ONNX_OPERATOR_KERNEL_EX(
+ MatMulBnb4,
+ kMSDomain,
+ 1,
+ kCpuExecutionProvider,
+ KernelDefBuilder()
+        .TypeConstraint("T1", DataTypeImpl::GetTensorType<float>())
+        .TypeConstraint("T2", DataTypeImpl::GetTensorType<uint8_t>()),
+ MatMulBnb4);
+
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
index e86a12d9fb873..4e103c2556a7a 100644
--- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
+++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc
@@ -20,20 +20,29 @@ namespace contrib {
kCpuExecutionProvider, \
KernelDefBuilder() \
          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
-      SkipLayerNorm<T>);
+      SkipLayerNorm<T, false>); \
+ ONNX_OPERATOR_TYPED_KERNEL_EX( \
+ SkipSimplifiedLayerNormalization, \
+ kMSDomain, \
+ 1, \
+ T, \
+ kCpuExecutionProvider, \
+ KernelDefBuilder() \
+          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
+      SkipLayerNorm<T, true>);
REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(double)
-template <typename T>
-SkipLayerNorm<T>::SkipLayerNorm(const OpKernelInfo& op_kernel_info)
+template <typename T, bool simplified>
+SkipLayerNorm<T, simplified>::SkipLayerNorm(const OpKernelInfo& op_kernel_info)
: OpKernel(op_kernel_info) {
ORT_ENFORCE(op_kernel_info.GetAttr("epsilon", &epsilon_).IsOK());
ORT_ENFORCE(epsilon_ >= 0);
}
-template <typename T>
-Status SkipLayerNorm<T>::Compute(OpKernelContext* p_ctx) const {
+template <typename T, bool simplified>
+Status SkipLayerNorm<T, simplified>::Compute(OpKernelContext* p_ctx) const {
const Tensor* input = p_ctx->Input<Tensor>(0);
const Tensor* skip = p_ctx->Input<Tensor>(1);
const Tensor* gamma = p_ctx->Input<Tensor>(2);
@@ -102,10 +111,16 @@ Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const {
}
mean = mean / hidden_size;
- mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon_);
+ if (simplified) {
+ mean_square = sqrt(mean_square / hidden_size + epsilon_);
+ } else {
+ mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon_);
+ }
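+    // The simplified variant is RMSNorm-style: no mean subtraction and no beta, only scaling by gamma.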
for (int64_t h = 0; h < hidden_size; h++) {
- if (nullptr == beta_data) {
+ if (simplified) {
+ p_output[h] = p_output[h] / mean_square * gamma_data[h];
+ } else if (nullptr == beta_data) {
p_output[h] = (p_output[h] - mean) / mean_square * gamma_data[h];
} else {
p_output[h] = (p_output[h] - mean) / mean_square * gamma_data[h] + beta_data[h];
diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.h b/onnxruntime/contrib_ops/cpu/skip_layer_norm.h
index 7723541cb6b18..69edf4609e340 100644
--- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.h
+++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.h
@@ -10,7 +10,7 @@
namespace onnxruntime {
namespace contrib {
-template <typename T>
+template <typename T, bool simplified>
class SkipLayerNorm final : public OpKernel {
public:
SkipLayerNorm(const OpKernelInfo& op_kernel_info);
diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc
new file mode 100644
index 0000000000000..b4b5dac1fbe19
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc
@@ -0,0 +1,84 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include "core/providers/cuda/cuda_common.h"
+#include "contrib_ops/cpu/bert/rotary_embedding_helper.h"
+#include "contrib_ops/cuda/bert/rotary_embedding.h"
+#include "contrib_ops/cuda/bert/rotary_embedding_impl.h"
+
+using namespace onnxruntime::cuda;
+using namespace ::onnxruntime::common;
+using namespace ONNX_NAMESPACE;
+using namespace onnxruntime::contrib::rotary_embedding_helper;
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+#define REGISTER_KERNEL_TYPED(T) \
+ ONNX_OPERATOR_TYPED_KERNEL_EX( \
+ RotaryEmbedding, \
+ kMSDomain, \
+ 1, \
+ T, \
+ kCudaExecutionProvider, \
+ (*KernelDefBuilder::Create()) \
+          .TypeConstraint("T", DataTypeImpl::GetTensorType<T>())       \
+          .TypeConstraint("M", DataTypeImpl::GetTensorType<int64_t>()), \
+ RotaryEmbedding);
+
+REGISTER_KERNEL_TYPED(float)
+REGISTER_KERNEL_TYPED(MLFloat16)
+
+template <typename T>
+RotaryEmbedding<T>::RotaryEmbedding(const OpKernelInfo& info) : CudaKernel(info) {
+  scale = info.GetAttrOrDefault<float>("scale", 1.0);
+  interleaved = (info.GetAttrOrDefault<int64_t>("interleaved", 0) == 1);
+}
+
+template <typename T>
+Status RotaryEmbedding<T>::ComputeInternal(OpKernelContext* context) const {
+  const Tensor* input = context->Input<Tensor>(0);
+  const Tensor* position_ids = context->Input<Tensor>(1);
+  const Tensor* cos_cache = context->Input<Tensor>(2);
+  const Tensor* sin_cache = context->Input<Tensor>(3);
+
+ RotaryParameters parameters = {};
+ ORT_RETURN_IF_ERROR(rotary_embedding_helper::CheckInputs(input,
+ position_ids,
+ cos_cache,
+ sin_cache,
+ ¶meters));
+
+ Tensor* output = context->Output(0, input->Shape());
+
+ if (parameters.sequence_length > parameters.max_sequence_length) {
+ // Launch update_cos_sin_cache kernel with scale
+ ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported");
+ }
+
+ // Launch rotary embedding kernel
+  typedef typename ToCudaType<T>::MappedType CudaT;
+  auto& device_prop = GetDeviceProp();
+  return LaunchRotaryEmbeddingKernel<CudaT>(
+      Stream(context),
+      reinterpret_cast<CudaT*>(output->template MutableData<T>()),
+      reinterpret_cast<const CudaT*>(input->template Data<T>()),
+      position_ids->Data<int64_t>(),
+      reinterpret_cast<const CudaT*>(cos_cache->template Data<T>()),
+      reinterpret_cast<const CudaT*>(sin_cache->template Data<T>()),
+ parameters.batch_size,
+ parameters.sequence_length,
+ parameters.num_heads,
+ parameters.head_size,
+ parameters.max_sequence_length,
+ parameters.position_ids_format,
+ interleaved,
+ device_prop.maxThreadsPerBlock);
+
+ return Status::OK();
+}
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h
new file mode 100644
index 0000000000000..6dab2ad56749e
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h
@@ -0,0 +1,27 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/common/common.h"
+#include "core/providers/cuda/cuda_kernel.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+using namespace onnxruntime::cuda;
+
+template <typename T>
+class RotaryEmbedding final : public CudaKernel {
+ public:
+ RotaryEmbedding(const OpKernelInfo& info);
+ Status ComputeInternal(OpKernelContext* context) const override;
+
+ protected:
+ float scale;
+ bool interleaved;
+};
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu
new file mode 100644
index 0000000000000..c54e72dcfce13
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu
@@ -0,0 +1,141 @@
+/*
+Copyright (c) Microsoft Corporation.
+Licensed under the MIT License.
+*/
+
+/*
+Kernel implementation for rotary embeddings.
+*/
+
+#include
+#include "core/providers/cuda/cu_inc/common.cuh"
+#include "contrib_ops/cuda/bert/rotary_embedding_impl.h"
+
+using namespace onnxruntime::cuda;
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+template <typename T>
+__global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH
+ const T* input, // BxSxNxH
+ const T* cos_cache, // Mx(H/2)
+ const T* sin_cache, // Mx(H/2)
+ const int64_t* position_ids, // (1) or BxS
+ const int sequence_length,
+ const int num_heads,
+ const int head_size,
+ const int position_ids_format,
+ const bool interleaved) {
+ // B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length
+ // Use .x in innermost loop to access global memory efficiently
+
+ const int b = blockIdx.z;
+ const int s = blockIdx.y;
+ const int n = blockIdx.x;
+
+ const int i = threadIdx.x;
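+  // Launch config is grid(num_heads, sequence_length, batch_size) with head_size threads per block,
+  // so each thread handles one element of one (b, s, n) vector.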
+
+ const int block_offset = b * sequence_length * num_heads + s * num_heads + n;
+ const int data_offset = block_offset * head_size;
+
+ const T* input_data = input + data_offset;
+ T* output_data = output + data_offset;
+
+ // Cache is (M, H/2)
+ const int half_head_size = head_size / 2;
+ const int position_id = (position_ids_format == 0) ? \
+      static_cast<int>(position_ids[0]) + s \
+      : static_cast<int>(position_ids[b * sequence_length + s]);
+ const int cache_offset = position_id * half_head_size;
+ const T* cos_data = cos_cache + cache_offset;
+ const T* sin_data = sin_cache + cache_offset;
+
+ int cache_idx = 0;
+ T sign = 0;
+ int j = 0;
+ if (interleaved) {
+ cache_idx = (i / 2) % half_head_size;
+ sign = (i % 2 == 0) ? -1 : 1;
+ j = (i % 2 == 0) ? i+1 : i-1; // i - sign
+ } else {
+ cache_idx = i % half_head_size;
+ sign = (i < half_head_size) ? -1 : 1;
+ j = (i + half_head_size) % head_size;
+ }
+ output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx];
+}
+
+
+template <typename T>
+Status LaunchRotaryEmbeddingKernel(
+ cudaStream_t stream,
+ T* output,
+ const T* input,
+ const int64_t* position_ids,
+ const T* cos_cache,
+ const T* sin_cache,
+ const int batch_size,
+ const int sequence_length,
+ const int num_heads,
+ const int head_size,
+ const int max_sequence_length,
+ const int position_ids_format,
+ const bool interleaved,
+ const int max_threads_per_block) {
+
+ constexpr int smem_size = 0;
+ const dim3 grid(num_heads, sequence_length, batch_size);
+ const dim3 block(head_size, 1, 1);
+
+ // Note: Current implementation assumes head_size <= max_threads_per_block
+ // because head_size is currently large for LLaMA-2. For smaller head_size
+ // and num_heads values, we can create a block as `block(num_heads, head_size, 1)`
+ // instead. This will require kernel changes to support.
+
+ assert(head_size <= max_threads_per_block);
+  RotaryEmbeddingBSNH<<<grid, block, smem_size, stream>>>(
+ output, input, cos_cache, sin_cache, position_ids,
+ sequence_length, num_heads, head_size, position_ids_format, interleaved
+ );
+
+ return CUDA_CALL(cudaGetLastError());
+}
+
+template Status LaunchRotaryEmbeddingKernel<float>(
+ cudaStream_t stream,
+ float* output,
+ const float* input,
+ const int64_t* position_ids,
+ const float* cos_cache,
+ const float* sin_cache,
+ const int batch_size,
+ const int sequence_length,
+ const int num_heads,
+ const int head_size,
+ const int max_sequence_length,
+ const int position_ids_format,
+ const bool interleaved,
+ const int max_threads_per_block);
+
+template Status LaunchRotaryEmbeddingKernel<half>(
+ cudaStream_t stream,
+ half* output,
+ const half* input,
+ const int64_t* position_ids,
+ const half* cos_cache,
+ const half* sin_cache,
+ const int batch_size,
+ const int sequence_length,
+ const int num_heads,
+ const int head_size,
+ const int max_sequence_length,
+ const int position_ids_format,
+ const bool interleaved,
+ const int max_threads_per_block);
+
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h
new file mode 100644
index 0000000000000..29ff48a8ad0fb
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h
@@ -0,0 +1,31 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/common/common.h"
+#include "core/providers/cuda/shared_inc/cuda_utils.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+template <typename T>
+Status LaunchRotaryEmbeddingKernel(
+ cudaStream_t stream,
+ T* output,
+ const T* input,
+ const int64_t* position_ids,
+ const T* cos_cache,
+ const T* sin_cache,
+ const int batch_size,
+ const int sequence_length,
+ const int num_heads,
+ const int head_size,
+ const int max_sequence_length,
+ const int position_ids_format,
+ const bool interleaved,
+ const int max_threads_per_block);
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
index 52ff285539360..e762a80cb0e2f 100644
--- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -91,6 +91,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ParametricSoftplus);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ParametricSoftplus);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ParametricSoftplus);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RotaryEmbedding);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Sampling);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh);
@@ -116,6 +118,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulNBits);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulBnb4);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulBnb4);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Trilu);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, UnfoldTensor);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, DynamicTimeWarping);
@@ -250,6 +254,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RotaryEmbedding)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding)>,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
@@ -275,6 +281,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulBnb4)>,
+      BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulBnb4)>,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu
new file mode 100644
index 0000000000000..e58723f0b31e1
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu
@@ -0,0 +1,129 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include
+#include
+#include "core/providers/cuda/cuda_common.h"
+#include "contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h"
+#include "dequantize_blockwise_bnb4.cuh"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+template <typename T>
+Status SetBnbQuantMap(int quant_type, T* quant_map_buffer, cudaStream_t stream) {
+ ORT_ENFORCE(
+ quant_type == FP4 || quant_type == NF4,
+ "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported.");
+
+ T host_quant_map[16];
+ switch (quant_type) {
+ case FP4:
+      for (int i = 0; i < 16; i++) host_quant_map[i] = static_cast<T>(fp4_qaunt_map[i]);
+      break;
+    case NF4:
+      for (int i = 0; i < 16; i++) host_quant_map[i] = static_cast<T>(nf4_qaunt_map[i]);
+ break;
+ }
+ CUDA_CALL_THROW(cudaMemcpyAsync(quant_map_buffer, host_quant_map, sizeof(T) * 16, cudaMemcpyHostToDevice, stream));
+
+ return Status::OK();
+}
+
+template Status SetBnbQuantMap<float>(int quant_type, float* quant_map_buffer, cudaStream_t stream);
+
+template Status SetBnbQuantMap<half>(int quant_type, half* quant_map_buffer, cudaStream_t stream);
+
+template
+__global__ void kDequantizeBlockwise(
+ const T* quant_map,
+ T* output,
+ const uint8_t* quant_data,
+ const T* absmax,
+ const int block_size,
+ const int n) {
+ const int n_load = (gridDim.x * TILE_SIZE);
+ int valid_items_load = 0;
+ int valid_items_store = 0;
+ const int base_idx = (blockIdx.x * TILE_SIZE);
+
+ T vals[NUM_PER_TH * 2];
+ uint8_t qvals[NUM_PER_TH];
+ T local_abs_max = T(0.0f);
+
+ typedef cub::BlockLoad