diff --git a/.lintrunner.toml b/.lintrunner.toml index c44a66200ad1b..4e5d077b08ff4 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -45,6 +45,7 @@ exclude_patterns = [ 'cmake/external/**', # ignore generated flatbuffers code 'onnxruntime/core/flatbuffers/ort_flatbuffers_py/**', + 'orttraining/orttraining/python/training/optim/_ds_code_store.py', ] command = [ 'python', @@ -76,6 +77,7 @@ exclude_patterns = [ 'cmake/**', 'orttraining/*', 'onnxruntime/core/flatbuffers/**', + 'orttraining/orttraining/python/training/optim/_ds_code_store.py', ] command = [ 'python', diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index 08ca90d7c3b7f..f9f2fbdab7b10 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -2,6 +2,36 @@ "$schema": "https://json.schemastore.org/component-detection-manifest.json", "Version": 1, "Registrations": [ + { + "component": { + "type": "git", + "git": { + "commitHash": "a896e3d066448b3530dbcaa48869fafefd738f57", + "repositoryUrl": "https://github.com/emscripten-core/emsdk.git" + }, + "comments": "git submodule at cmake/external/emsdk" + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "7a2ed51a6b682a83e345ff49fc4cfd7ca47550db", + "repositoryUrl": "https://github.com/google/libprotobuf-mutator.git" + }, + "comments": "git submodule at cmake/external/libprotobuf-mutator" + } + }, + { + "component": { + "type": "git", + "git": { + "commitHash": "0c296085f9f65f0f8ef7aec7b9eed55faf37dc40", + "repositoryUrl": "https://github.com/onnx/onnx.git" + }, + "comments": "git submodule at cmake/external/onnx" + } + }, { "component": { "type": "git", @@ -166,7 +196,7 @@ "component": { "type": "git", "git": { - "commitHash": "fdefbe85ed9c362b95b9b401cd19db068a76141f", + "commitHash": "6a20ba82b439ea1fd650da4d389e96b60a1dd828", "repositoryUrl": "https://github.com/onnx/onnx.git" }, "comments": "onnx" diff --git a/cmake/deps.txt b/cmake/deps.txt index 7cf49f02333a4..26fd35075c4b9 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -24,7 +24,7 @@ microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf36 microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5 mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee7d34223d0567892db5179849939c8769dc41 mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063 -onnx;https://github.com/onnx/onnx/archive/14303de049144035dfd94ace5f7a3b44773b1aad.zip;250eab9690392b248d75b56e605fb49eca373442 +onnx;https://github.com/onnx/onnx/archive/6a20ba82b439ea1fd650da4d389e96b60a1dd828.zip;179a22ad4cd67109c60031ae4b6cf2f434d8bd7e #use the commit of supporting all the plugins and TRT 8.6-GA (https://github.com/onnx/onnx-tensorrt/commit/0462dc31ae78f48744b6141ae376df1f96d3f459) onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/0462dc31ae78f48744b6141ae376df1f96d3f459.zip;5ff086361956cceb81ed17453a1fd8db2aa4328d protobuf;https://github.com/protocolbuffers/protobuf/archive/refs/tags/v21.12.zip;7cf2733949036c7d52fda017badcab093fe73bfa diff --git a/cmake/external/onnx b/cmake/external/onnx index e2525550194ce..6a20ba82b439e 160000 --- a/cmake/external/onnx +++ b/cmake/external/onnx @@ -1 +1 @@ -Subproject commit e2525550194ce3d8a2c4a3af451c9d9b3ae6650e +Subproject commit 6a20ba82b439ea1fd650da4d389e96b60a1dd828 diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index 9985f5c8bc516..a76664adff207 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -344,12 +344,18 @@ else() set(mlas_platform_srcs ${mlas_platform_srcs} ${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S + ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S + ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S ${MLAS_SRC_DIR}/activate_fp16.cpp ${MLAS_SRC_DIR}/dwconv.cpp ${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/pooling_fp16.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp + ${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp ) set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index fe3e577b4fc36..de1458c120016 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -54,6 +54,11 @@ set(contrib_ops_excluded_files "quantization/attention_quantization_impl.cuh" "quantization/dequantize_blockwise.cuh" "quantization/dequantize_blockwise.cu" + "quantization/dequantize_blockwise_bnb4.cuh" + "quantization/dequantize_blockwise_bnb4.cu" + "quantization/matmul_bnb4.cc" + "quantization/matmul_bnb4.cuh" + "quantization/matmul_bnb4.cu" "quantization/matmul_nbits.cc" "quantization/matmul_nbits.cuh" "quantization/matmul_nbits.cu" diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 7e67ec6d0c94e..1a76c18a6a8e0 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -47,6 +47,7 @@ Do not modify directly.* * com.microsoft.Inverse * com.microsoft.Irfft * com.microsoft.LongformerAttention + * com.microsoft.MatMulBnb4 * com.microsoft.MatMulFpQ4 * com.microsoft.MatMulInteger16 * com.microsoft.MatMulIntegerToFloat @@ -90,6 +91,7 @@ Do not modify directly.* * com.microsoft.RemovePadding * com.microsoft.RestorePadding * com.microsoft.Rfft + * com.microsoft.RotaryEmbedding * com.microsoft.SampleOp * com.microsoft.Sampling * com.microsoft.SkipLayerNormalization @@ -2503,6 +2505,62 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.MatMulBnb4** + + MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4 data type (https://arxiv.org/pdf/2305.14314.pdf). It does Matrix Multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul) with differences: + 1. Input B is a 2D constant Matrix. Its input feature count and output feature count are specified by attribute 'K' and 'N'. + 2. Input B is quantized with 4 bits with quantization data type specified by attribute 'quant_type'. It is transposed, flattened and quantized blockwisely with block size specified by attribute 'block_size'. + And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. + 3. Input B's quantization constants or scales are specified by input 'absmax'. + + Input B is stored as uint8_t with shape: [(N * K + 1) / 2]. + Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size]. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
K : int (required)
+
size of each input feature
+
N : int (required)
+
size of each output feature
+
block_size : int (required)
+
number of groupsize used for weight quantization. It needs to be a power of 2 and not smaller than 16.
+
quant_type : int (required)
+
quantization data type. 0 for FP4, 1 for NF4.
+
+ +#### Inputs + +
+
A : T1
+
The input tensor, not quantized
+
B : T2
+
1-dimensional quantized data for weight
+
absmax : T1
+
quantization constants
+
+ +#### Outputs + +
+
Y : T1
+
tensor. The output tensor has the same rank as the input.
+
+ +#### Type Constraints + +
+
T1 : tensor(float), tensor(float16)
+
Constrain input and output types to float/half_float tensors.
+
T2 : tensor(uint8)
+
Constrain quantized weight types to uint8.
+
+ + ### **com.microsoft.MatMulFpQ4** Matrix product with right hand matrix being pre-packed and quantized int4 data blob. @@ -2834,7 +2892,7 @@ This version of the operator has been available since version 1 of the 'com.micr
bias (optional) : T
Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection
key_padding_mask (optional) : M
-
Key padding mask with shape (batch_size) or (3 * batch_size + 2) or (batch_size, kv_sequence_length)
+
Key padding mask with shape (batch_size), (3 * batch_size + 2), (batch_size, kv_sequence_length), (batch_size, total_sequence_length), or (batch_size, sequence_length, total_sequence_length)
relative_position_bias (optional) : T
relative position bias: addition to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length) or (1, num_heads, sequence_length, total_sequence_length)
past_key (optional) : T
@@ -4796,6 +4854,54 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.RotaryEmbedding** + + RotaryEmbedding is the implementation of rotary positional embeddings (RoPE). The positions are represented as rotation matrices + that are multiplied to query and key before the inner product of query and key is taken. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
interleaved : int
+
Rotate using interleaved pattern. Default value is 0 (False).
+
scale : float
+
Custom scale will be used if specified. Default value is 1.0
+
+ +#### Inputs + +
+
input : T
+
3D tensor with shape (batch_size, sequence_length, hidden_size)
+
position_ids : M
+
1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)
+
cos_cache : T
+
2D tensor with shape (max_sequence_length, head_size / 2).
+
sin_cache : T
+
2D tensor with shape (max_sequence_length, head_size / 2).
+
+ +#### Outputs + +
+
output : T
+
3D tensor with shape (batch_size, sequence_length, hidden_size)
+
+ +#### Type Constraints + +
+
T : tensor(float), tensor(float16)
+
Constrain input and output types to float tensors.
+
M : tensor(int64)
+
Constrain input and output types to integer tensors
+
+ + ### **com.microsoft.SampleOp** Sample echo operator. diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index e2d500006b05f..84249df92231b 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -25,6 +25,7 @@ Do not modify directly.* |||13|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |||[7, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int64)| |Affine|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| +|AffineGrid|*in* theta:**T1**
*in* size:**T2**
*out* grid:**T1**|20+|**T1** = tensor(double), tensor(float)
**T2** = tensor(int64)| |And|*in* A:**T**
*in* B:**T**
*out* C:**T1**|7+|**T** = tensor(bool)
**T1** = tensor(bool)| |ArgMax|*in* data:**T**
*out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(int32), tensor(int8), tensor(uint8)| |||[11, 12]|**T** = tensor(double), tensor(float), tensor(int32), tensor(int8), tensor(uint8)| @@ -156,8 +157,10 @@ Do not modify directly.* |||[1, 10]|**B** = tensor(bool)
**V** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |ImageScaler|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(float)| |InstanceNormalization|*in* input:**T**
*in* scale:**T**
*in* B:**T**
*out* output:**T**|6+|**T** = tensor(float)| -|IsInf|*in* X:**T1**
*out* Y:**T2**|10+|**T1** = tensor(double), tensor(float)
**T2** = tensor(bool)| -|IsNaN|*in* X:**T1**
*out* Y:**T2**|13+|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| +|IsInf|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| +|||[10, 19]|**T1** = tensor(double), tensor(float)
**T2** = tensor(bool)| +|IsNaN|*in* X:**T1**
*out* Y:**T2**|20+|**T1** = tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz)
**T2** = tensor(bool)| +|||[13, 19]|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| |||[9, 12]|**T1** = tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)| |LRN|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(float)| |||[1, 12]|**T** = tensor(float)| @@ -454,6 +457,7 @@ Do not modify directly.* |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)| |MatMulFpQ4|*in* A:**T1**
*in* B:**T2**
*in* B_shape:**T3**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(int64)| |MatMulInteger16|*in* A:**T1**
*in* B:**T2**
*out* Y:**T3**|1+|**T1** = tensor(int16)
**T2** = tensor(int16)
**T3** = tensor(int32)| |MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float)| @@ -477,9 +481,11 @@ Do not modify directly.* |QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(int16), tensor(int8), tensor(uint16), tensor(uint8)| |QuickGelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |Range|*in* start:**T**
*in* limit:**T**
*in* delta:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)| +|RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(float)| |SampleOp|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |Sampling|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*in* presence_mask:**I**
*in* seed:**I**
*out* sequences:**I**
*out* filtered_logits:**T**|1+|**T** = tensor(float)| |SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)| +|SkipSimplifiedLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)| |SparseToDenseMatMul|*in* A:**T**
*in* B:**T1**
*out* Y:**T1**|1+|**T** = sparse_tensor(double), sparse_tensor(float), sparse_tensor(int32), sparse_tensor(int64), sparse_tensor(uint32), sparse_tensor(uint64)
**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |Tokenizer|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(string)| |TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)| @@ -847,6 +853,7 @@ Do not modify directly.* |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| +|MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T2**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)| |MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)| |NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)| @@ -866,6 +873,7 @@ Do not modify directly.* |RemovePadding|*in* input:**T**
*in* sequence_token_count:**M**
*out* output:**T**
*out* token_offset:**M**
*out* cumulated_seq_len:**M**
*out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)| |RestorePadding|*in* input:**T**
*in* token_offset:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |Rfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(float), tensor(float16)| |Sampling|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*in* presence_mask:**I**
*in* seed:**I**
*out* sequences:**I**
*out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)| |SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)| |SkipSimplifiedLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)| diff --git a/include/onnxruntime/core/framework/float8.h b/include/onnxruntime/core/framework/float8.h index 0fd04f28d44b7..dd607cbbc6952 100644 --- a/include/onnxruntime/core/framework/float8.h +++ b/include/onnxruntime/core/framework/float8.h @@ -208,9 +208,10 @@ struct Float8E4M3FNUZ { val = static_cast((b & 0x80000000) >> 24); // sign if ((b & 0x7fffffff) == 0x7f800000) { // infinity if (saturate) { + // the highest available value val |= 0x7F; } else { - // infinity + // NaN val = 0x80; } } else if ((b & 0x7F800000) == 0x7F800000) { // NaN @@ -362,8 +363,10 @@ struct Float8E5M2 { val = (b & 0x80000000) >> 24; // sign if ((b & 0x7FFFFFFF) == 0x7F800000) { // inf if (saturate) { + // the highest available value val |= 0x7B; } else { + // the infinity val |= 0x7C; } } else if ((b & 0x7F800000) == 0x7F800000) { // NaN diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index f153e88909b8d..462d410e13769 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -3,6 +3,7 @@ #pragma once +#include #include #include #include @@ -83,10 +84,10 @@ class Node { gsl::span output_args, const NodeAttributes* attributes, std::string_view domain) { - Init(std::string{name}, std::string{op_type}, std::string{description}, - std::vector{input_args.begin(), input_args.end()}, - std::vector{output_args.begin(), output_args.end()}, - attributes, std::string{domain}); + Init(name, op_type, description, + input_args, + output_args, + attributes, domain); } #endif @@ -563,13 +564,13 @@ class Node { ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Node); #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) - void Init(const std::string& name, - const std::string& op_type, - const std::string& description, - const std::vector& input_args, - const std::vector& output_args, + void Init(std::string_view name, + std::string_view op_type, + std::string_view description, + gsl::span input_args, + gsl::span output_args, const NodeAttributes* attributes, - const std::string& domain); + std::string_view domain); #endif #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) @@ -1141,8 +1142,22 @@ class Graph { */ Status InlineFunction(Node& node); + /** + Directly insert the nodes in the function proto provided into the graph. + The function converts Constant nodes into the initializers in the graph. + It then creates a node in the graph for each of the function nodes. + All of the names are expected to be specialized, and, therefore unique. + See function_utils::Specialize(). + + The Graph needs to be Resolve()d after this call. + @param func_to_inline + @returns Status indicating success or providing an error message. + */ + + Status InlineFunctionProto(const ONNX_NAMESPACE::FunctionProto& func_to_inline); + /** Mark a NodeArg name as coming from the outer scope when programmatically constructing a Graph that will - be used as a GraphProto attribute in another Node.. + be used as a GraphProto attribute in another Node. e.g. when creating a Graph instance that will be used as a subgraph in a control flow operator, it is necessary to define placeholder NodeArgs for outer scope values. This prevents these values from becoming explicit graph inputs when the Graph is resolved. @@ -1391,6 +1406,13 @@ class Graph { Node& AddNode(const ONNX_NAMESPACE::NodeProto& node_proto, const ArgNameToTypeMap& name_to_type); + /** Helper that converts and adds constant node proto to an initializer in the graph. + @param constant_node_proto Constant node to convert + @param new_name use the new name for the initializer. + */ + Status AddConstantProtoAsInitializer(const ONNX_NAMESPACE::NodeProto& constant_node_proto, + std::optional new_name); + #endif Version IrVersion() const noexcept { diff --git a/include/onnxruntime/core/providers/cuda/cuda_context.h b/include/onnxruntime/core/providers/cuda/cuda_context.h index 13c176dad3cc5..646f33ed952a4 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_context.h +++ b/include/onnxruntime/core/providers/cuda/cuda_context.h @@ -19,6 +19,7 @@ struct CudaContext : public CustomOpContext { cudaStream_t cuda_stream = {}; cudnnHandle_t cudnn_handle = {}; cublasHandle_t cublas_handle = {}; + OrtAllocator* deferred_cpu_allocator = {}; void Init(const OrtKernelContext& kernel_ctx) override { const auto& ort_api = Ort::GetApi(); @@ -44,6 +45,36 @@ struct CudaContext : public CustomOpContext { ORT_CXX_API_THROW("failed to fetch cublas handle", OrtErrorCode::ORT_RUNTIME_EXCEPTION); } cublas_handle = reinterpret_cast(resource); + + resource = {}; + status = ort_api.KernelContext_GetResource(&kernel_ctx, ORT_CUDA_RESOUCE_VERSION, CudaResource::deferred_cpu_allocator_t, &resource); + if (status) { + ORT_CXX_API_THROW("failed to fetch deferred cpu allocator", OrtErrorCode::ORT_RUNTIME_EXCEPTION); + } + deferred_cpu_allocator = reinterpret_cast(resource); + } + + void* AllocDeferredCpuMem(size_t size) const { + if (0 == size) { + return {}; + } + const auto& ort_api = Ort::GetApi(); + void* mem = {}; + auto status = ort_api.AllocatorAlloc(deferred_cpu_allocator, size, &mem); + if (status) { + ORT_CXX_API_THROW("failed to allocate deferred cpu memory", OrtErrorCode::ORT_RUNTIME_EXCEPTION); + } + return mem; + } + + void FreeDeferredCpuMem(void* mem) const { + if (mem) { + const auto& ort_api = Ort::GetApi(); + auto status = ort_api.AllocatorFree(deferred_cpu_allocator, mem); + if (status) { + ORT_CXX_API_THROW("failed to free deferred cpu memory", OrtErrorCode::ORT_RUNTIME_EXCEPTION); + } + } } }; diff --git a/include/onnxruntime/core/providers/cuda/cuda_resource.h b/include/onnxruntime/core/providers/cuda/cuda_resource.h index e46fc5b4219dd..8c3ed46ade6a1 100644 --- a/include/onnxruntime/core/providers/cuda/cuda_resource.h +++ b/include/onnxruntime/core/providers/cuda/cuda_resource.h @@ -3,10 +3,11 @@ #include "core/providers/resource.h" -#define ORT_CUDA_RESOUCE_VERSION 1 +#define ORT_CUDA_RESOUCE_VERSION 2 enum CudaResource : int { cuda_stream_t = cuda_resource_offset, cudnn_handle_t, - cublas_handle_t + cublas_handle_t, + deferred_cpu_allocator_t, }; \ No newline at end of file diff --git a/include/onnxruntime/core/providers/dml/dml_provider_factory.h b/include/onnxruntime/core/providers/dml/dml_provider_factory.h index dd4ffb835d51c..cf3ddc3f125f9 100644 --- a/include/onnxruntime/core/providers/dml/dml_provider_factory.h +++ b/include/onnxruntime/core/providers/dml/dml_provider_factory.h @@ -128,7 +128,7 @@ struct OrtDmlApi { /** * SessionOptionsAppendExecutionProvider_DML2 * Creates a DirectML Execution Provider given the supplied device options that contain a performance preference - * (high power, low power, or defult) and a device filter (None, GPU, or NPU). + * (high power, low power, or default) and a device filter (None, GPU, or NPU). */ ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_DML2, _In_ OrtSessionOptions* options, OrtDmlDeviceOptions* device_opts); }; diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 37545f41b43dd..831def24e4f5e 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -67,6 +67,16 @@ static const char* const kOrtSessionOptionsEnableQuantQDQCleanup = "session.enab // GeluApproximation has side effects which may change the inference results. It is disabled by default due to this. static const char* const kOrtSessionOptionsEnableGeluApproximation = "optimization.enable_gelu_approximation"; +// This setting controls whether to enable AheadOfTime function inlining. +// AOT function inlining examines the graph and attempts to inline as many locally defined functions in the model +// as possible with the help of enabled execution providers. +// This can reduce the number of function calls and improve performance because it is done before +// Level1 optimizers and constant folding. However, under some circumstances, when the EPs are not available, +// one can disable the AOT inlining, produce an optimized model and postpone AOT until run time. +// "0": enable; "1": disable. +// Its default value is "0". +static const char* const kOrtSessionOptionsDisableAheadOfTimeFunctionInlining = "session.disable_aot_function_inlining"; + #ifdef ENABLE_TRAINING // Specifies a list of op types for memory footprint reduction. // The value should be a ","-delimited list of pair of diff --git a/js/common/lib/training-session-impl.ts b/js/common/lib/training-session-impl.ts index f06d06bda035f..47e67879e66ce 100644 --- a/js/common/lib/training-session-impl.ts +++ b/js/common/lib/training-session-impl.ts @@ -1,11 +1,14 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +import {resolveBackend} from './backend-impl.js'; import {TrainingSessionHandler} from './backend.js'; import {InferenceSession as InferenceSession} from './inference-session.js'; import {TrainingSession as TrainingSessionInterface, TrainingSessionCreateOptions} from './training-session.js'; type SessionOptions = InferenceSession.SessionOptions; +const noBackendErrMsg: string = 'Training backend could not be resolved. ' + + 'Make sure you\'re using the correct configuration & WebAssembly files.'; export class TrainingSession implements TrainingSessionInterface { private constructor(handler: TrainingSessionHandler) { @@ -20,9 +23,23 @@ export class TrainingSession implements TrainingSessionInterface { return this.handler.outputNames; } - static async create(_trainingOptions: TrainingSessionCreateOptions, _sessionOptions?: SessionOptions): + static async create(trainingOptions: TrainingSessionCreateOptions, sessionOptions?: SessionOptions): Promise { - throw new Error('Method not implemented'); + const evalModel: string|Uint8Array = trainingOptions.evalModel || ''; + const optimizerModel: string|Uint8Array = trainingOptions.optimizerModel || ''; + const options: SessionOptions = sessionOptions || {}; + + // get backend hints + const eps = options.executionProviders || []; + const backendHints = eps.map(i => typeof i === 'string' ? i : i.name); + const backend = await resolveBackend(backendHints); + if (backend.createTrainingSessionHandler) { + const handler = await backend.createTrainingSessionHandler( + trainingOptions.checkpointState, trainingOptions.trainModel, evalModel, optimizerModel, options); + return new TrainingSession(handler); + } else { + throw new Error(noBackendErrMsg); + } } async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { diff --git a/js/web/docs/webgl-operators.md b/js/web/docs/webgl-operators.md index de84134ddbb3f..7c129b66bfa3d 100644 --- a/js/web/docs/webgl-operators.md +++ b/js/web/docs/webgl-operators.md @@ -12,6 +12,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [Acos](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Acos) | [7+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Acos-7) | | [Acosh](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Acosh) | | | [Add](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Add) | [7-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Add-7), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Add-13), [14+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Add-14) | +| [AffineGrid](https://github.com/onnx/onnx/blob/main/docs/Operators.md#AffineGrid) | | | [And](https://github.com/onnx/onnx/blob/main/docs/Operators.md#And) | [7+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#And-7) | | [ArgMax](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ArgMax) | | | [ArgMin](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ArgMin) | | @@ -67,6 +68,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [Gather](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gather) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gather-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gather-11), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gather-13) | | [GatherElements](https://github.com/onnx/onnx/blob/main/docs/Operators.md#GatherElements) | | | [GatherND](https://github.com/onnx/onnx/blob/main/docs/Operators.md#GatherND) | | +| [Gelu](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gelu) | | | [Gemm](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Gemm) | [7-8](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gemm-7), [9-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gemm-9), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gemm-11), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Gemm-13) | | [GlobalAveragePool](https://github.com/onnx/onnx/blob/main/docs/Operators.md#GlobalAveragePool) | [1+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#GlobalAveragePool-1) | | [GlobalLpPool](https://github.com/onnx/onnx/blob/main/docs/Operators.md#GlobalLpPool) | | @@ -82,6 +84,7 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [Hardmax](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Hardmax) | | | [Identity](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Identity) | [1-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-1), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-13), [14-15](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-14), [16-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-16), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Identity-19) | | [If](https://github.com/onnx/onnx/blob/main/docs/Operators.md#If) | | +| [ImageDecoder](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ImageDecoder) | | | [InstanceNormalization](https://github.com/onnx/onnx/blob/main/docs/Operators.md#InstanceNormalization) | [6+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#InstanceNormalization-6) | | [IsInf](https://github.com/onnx/onnx/blob/main/docs/Operators.md#IsInf) | | | [IsNaN](https://github.com/onnx/onnx/blob/main/docs/Operators.md#IsNaN) | | @@ -137,12 +140,13 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [ReduceL2](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceL2) | | | [ReduceLogSum](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceLogSum) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceLogSum-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceLogSum-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceLogSum-13), [18+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceLogSum-18) | | [ReduceLogSumExp](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceLogSumExp) | | -| [ReduceMax](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceMax) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-1), [11](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-11), [12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-12), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-13), [18+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-18) | +| [ReduceMax](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceMax) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-1), [11](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-11), [12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-12), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-13), [18-19](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-18), [20+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMax-20) | | [ReduceMean](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceMean) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMean-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMean-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMean-13), [18+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMean-18) | -| [ReduceMin](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceMin) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-1), [11](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-11), [12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-12), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-13), [18+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-18) | +| [ReduceMin](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceMin) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-1), [11](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-11), [12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-12), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-13), [18-19](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-18), [20+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceMin-20) | | [ReduceProd](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceProd) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceProd-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceProd-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceProd-13), [18+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceProd-18) | | [ReduceSum](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceSum) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSum-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSum-11) | | [ReduceSumSquare](https://github.com/onnx/onnx/blob/main/docs/Operators.md#ReduceSumSquare) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSumSquare-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSumSquare-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSumSquare-13), [18+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#ReduceSumSquare-18) | +| [RegexFullMatch](https://github.com/onnx/onnx/blob/main/docs/Operators.md#RegexFullMatch) | | | [Relu](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Relu) | [6-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Relu-6), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Relu-13), [14+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Relu-14) | | [Reshape](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Reshape) | [5-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-5), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-13), [14-18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-14), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Reshape-19) | | [Resize](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Resize) | [10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-10), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-11), [13-17](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-13), [18](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-18), [19+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Resize-19) | @@ -179,7 +183,9 @@ See [Compatibility](../README.md#Compatibility) for a list of the supported plat | [SplitToSequence](https://github.com/onnx/onnx/blob/main/docs/Operators.md#SplitToSequence) | | | [Sqrt](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sqrt) | [6-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sqrt-6), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sqrt-13) | | [Squeeze](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Squeeze) | [1-10](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Squeeze-1), [11-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Squeeze-11), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Squeeze-13) | +| [StringConcat](https://github.com/onnx/onnx/blob/main/docs/Operators.md#StringConcat) | | | [StringNormalizer](https://github.com/onnx/onnx/blob/main/docs/Operators.md#StringNormalizer) | | +| [StringSplit](https://github.com/onnx/onnx/blob/main/docs/Operators.md#StringSplit) | | | [Sub](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sub) | [7-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sub-7), [13](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sub-13), [14+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sub-14) | | [Sum](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Sum) | [6-7](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sum-6), [8-12](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sum-8), [13+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Sum-13) | | [Tan](https://github.com/onnx/onnx/blob/main/docs/Operators.md#Tan) | [7+](https://github.com/onnx/onnx/blob/main/docs/Changelog.md#Tan-7) | diff --git a/js/web/lib/backend-wasm-training.ts b/js/web/lib/backend-wasm-training.ts index af5b575c87a7f..98e40807aa29c 100644 --- a/js/web/lib/backend-wasm-training.ts +++ b/js/web/lib/backend-wasm-training.ts @@ -4,13 +4,17 @@ import {InferenceSession, TrainingSessionHandler} from 'onnxruntime-common'; import {OnnxruntimeWebAssemblyBackend} from './backend-wasm'; +import {OnnxruntimeWebAssemblyTrainingSessionHandler} from './wasm/session-handler-for-training'; class OnnxruntimeTrainingWebAssemblyBackend extends OnnxruntimeWebAssemblyBackend { async createTrainingSessionHandler( - _checkpointStateUriOrBuffer: string|Uint8Array, _trainModelUriOrBuffer: string|Uint8Array, - _evalModelUriOrBuffer: string|Uint8Array, _optimizerModelUriOrBuffer: string|Uint8Array, - _options: InferenceSession.SessionOptions): Promise { - throw new Error('Method not implemented yet.'); + checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array, + evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array, + options: InferenceSession.SessionOptions): Promise { + const handler = new OnnxruntimeWebAssemblyTrainingSessionHandler(); + await handler.createTrainingSession( + checkpointStateUriOrBuffer, trainModelUriOrBuffer, evalModelUriOrBuffer, optimizerModelUriOrBuffer, options); + return Promise.resolve(handler); } } diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index b7b2ff4537095..00431a4e86d5b 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -102,6 +102,11 @@ export interface OrtWasmModule extends EmscriptenModule { _OrtTrainingCopyParametersFromBuffer? (trainingHandle: number, parametersBuffer: number, parameterCount: number, trainableOnly: boolean): number; + _OrtTrainingGetModelInputOutputCount? + (trainingHandle: number, inputCount: number, outputCount: number, isEvalModel: boolean): number; + _OrtTrainingGetModelInputOutputName? + (trainingHandle: number, index: number, isInput: boolean, isEvalModel: boolean): number; + _OrtTrainingReleaseSession?(trainingHandle: number): void; // #endregion diff --git a/js/web/lib/wasm/jsep/webgpu/ops/common.ts b/js/web/lib/wasm/jsep/webgpu/ops/common.ts index 0a64d1ad1792a..1d3fc78fe368a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/common.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/common.ts @@ -803,3 +803,6 @@ export const getBroadcastDims = (inShape: readonly number[], outShape: readonly } return dims; }; + +// TODO: remove this limitation once >4D dims are supported by uniform. +export const enableShapesUniforms = (rank: number): boolean => rank <= 4; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts index d241b8b92a669..e880afe09a5d8 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv-transpose.ts @@ -232,7 +232,7 @@ const convTranspose2d = // STEP.1: transpose weight const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ?? context.compute( - createTransposeProgramInfo(inputs[1].dataType, inputs[1].dims.length, weightTransposePerm), + createTransposeProgramInfo(inputs[1], weightTransposePerm), {inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0]; if (attributes.wIsConst && !context.kernelCustomData.wT) { context.kernelCustomData.wT = transposedWeight; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts index b323a36cee5c8..c7ea0cffe51c3 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/conv.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/conv.ts @@ -168,7 +168,7 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut if (isChannelsLast) { const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ?? context.compute( - createTransposeProgramInfo(inputs[1].dataType, inputs[1].dims.length, weightTransposeAttribute), + createTransposeProgramInfo(inputs[1], weightTransposeAttribute), {inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0]; if (attributes.wIsConst && !context.kernelCustomData.wT) { context.kernelCustomData.wT = transposedWeight; @@ -208,7 +208,7 @@ const conv2d = (context: ComputeContext, inputs: readonly TensorView[], attribut // STEP.1: transpose weight const transposedWeight = (context.kernelCustomData.wT as TensorView | undefined) ?? context.compute( - createTransposeProgramInfo(inputs[1].dataType, inputs[1].dims.length, weightTransposeAttribute), + createTransposeProgramInfo(inputs[1], weightTransposeAttribute), {inputs: [1], outputs: [attributes.wIsConst ? -2 : -1]})[0]; if (attributes.wIsConst && !context.kernelCustomData.wT) { context.kernelCustomData.wT = transposedWeight; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts index 05f02b07c4d89..1538644412afd 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/pool.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/pool.ts @@ -18,16 +18,18 @@ const validateInputs = (inputs: readonly TensorView[]): void => { if (!inputs || inputs.length !== 1) { throw new Error('Pool ops requires 1 input.'); } - if (inputs[0].dims.length !== 4) { - throw new Error('Pool ops supports 2-D inputs only for now.'); + if (inputs[0].dims.length !== 4 && inputs[0].dims.length !== 3) { + throw new Error('Pool ops supports 1-D or 2-D inputs only for now.'); } }; const getAdjustedPoolAttributesAndOutputShape = ( input: TensorView, attributes: AttributeType, isGlobalOperator: boolean): [AttributeType, number[]] => { const isChannelsLast = attributes.format === 'NHWC'; - const inputShapeAsChannelFirst = - isChannelsLast ? [input.dims[0], input.dims[3], input.dims[1], input.dims[2]] : input.dims.slice(); + const inputShapeAsChannelFirst = input.dims.slice(); + if (isChannelsLast) { + inputShapeAsChannelFirst.splice(1, 0, inputShapeAsChannelFirst.pop()!); // Move channel to the second position. + } const hasDilations = Object.hasOwnProperty.call(attributes, 'dilations'); const kernelShape = attributes.kernelShape.slice(); const strides = attributes.strides.slice(); @@ -44,15 +46,9 @@ const getAdjustedPoolAttributesAndOutputShape = ( @@ -76,22 +72,22 @@ const generatePoolingCode = = ${inputDims[dimIdxW]}) { - pad++; - continue; - } - let x_val = x[${x.indicesToOffset('xIndices')}]; - ${op1} - }`; + for (var i: u32 = 0u; i < ${kw}u; i++) { + xIndices[${dimIdxW}] = indices[${dimIdxW}] * ${sw} - ${pwStart} + i; + if (xIndices[${dimIdxW}] < 0 || xIndices[${dimIdxW}] >= ${inputDims[dimIdxW]}) { + pad++; + continue; + } + let x_val = x[${x.indicesToOffset('xIndices')}]; + ${op1} + }`; } else { codeW = ` - for (var i: u32 = 0u; i < ${kw}u; i++) { - xIndices[${dimIdxW}] = indices[${dimIdxW}] * ${sw} - ${pwStart} + i; - let x_val = x[${x.indicesToOffset('xIndices')}]; - ${op1} - }`; + for (var i: u32 = 0u; i < ${kw}u; i++) { + xIndices[${dimIdxW}] = indices[${dimIdxW}] * ${sw} - ${pwStart} + i; + let x_val = x[${x.indicesToOffset('xIndices')}]; + ${op1} + }`; } if (attributes.kernelShape.length === 2) { diff --git a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts index fe556a7fd8552..c4d43e9f466f5 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/transpose.ts @@ -6,7 +6,7 @@ import {ShapeUtil} from '../../util'; import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; import {ComputeContext, ProgramInfo} from '../types'; -import {createTensorShapeVariables, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; +import {createTensorShapeVariables, enableShapesUniforms, IndicesHelper, inputVariable, outputVariable, ShaderHelper} from './common'; export interface TransposeAttributes extends AttributeWithCacheKey { readonly perm: number[]; @@ -35,13 +35,18 @@ const permFunctionBody = (perm: number[], rank: number, input: IndicesHelper, ou return reverseFunc.join('\n'); }; -export const createTransposeProgramInfo = - (inputDataType: number, inputRank: number, permAttr: number[]): ProgramInfo => { - const perm = getAdjustedPerm(inputRank, permAttr); - const output = outputVariable('output', inputDataType, (permAttr && permAttr.length) || inputRank); - const input = inputVariable('a', inputDataType, inputRank); +export const createTransposeProgramInfo = (inputTensor: TensorView, permAttr: number[]): ProgramInfo => { + const inputDataType = inputTensor.dataType; + const inputRank = inputTensor.dims.length; + const perm = getAdjustedPerm(inputRank, permAttr); + const useShapesUniforms = enableShapesUniforms(inputRank); + const outputShape = getOutputShape(inputTensor.dims, perm); + const outShapeOrRank = useShapesUniforms ? outputShape.length : outputShape; + const inShapeOrRank = useShapesUniforms ? inputRank : inputTensor.dims; + const output = outputVariable('output', inputDataType, outShapeOrRank); + const input = inputVariable('a', inputDataType, inShapeOrRank); - const getShaderSource = (shaderHelper: ShaderHelper) => ` + const getShaderSource = (shaderHelper: ShaderHelper) => ` ${shaderHelper.registerUniform('output_size', 'u32').declareVariables(input, output)} ${permFunctionBody(perm, inputRank, input, output)} @@ -54,30 +59,32 @@ export const createTransposeProgramInfo = ${output.setByOffset('global_idx', input.getByIndices('aIndices'))} }`; + return { + name: 'Transpose', + shaderCache: {hint: `${permAttr}`, inputDependencies: useShapesUniforms ? ['rank'] : ['dims']}, + getRunData: (inputs) => { + const outputSize = ShapeUtil.size(outputShape); return { - name: 'Transpose', - shaderCache: {hint: `${permAttr}`, inputDependencies: ['rank']}, - getRunData: (inputs) => { - const outputShape = getOutputShape(inputs[0].dims, perm); - const outputSize = ShapeUtil.size(outputShape); - return { - outputs: [{dims: outputShape, dataType: inputs[0].dataType}], - dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, - programUniforms: [ + outputs: [{dims: outputShape, dataType: inputs[0].dataType}], + dispatchGroup: {x: Math.ceil(outputSize / 64 /* workgroup size */)}, + programUniforms: useShapesUniforms ? + [ {type: 'uint32', data: outputSize}, ...createTensorShapeVariables(inputs[0].dims), ...createTensorShapeVariables(outputShape), + ] : + [ + {type: 'uint32', data: outputSize}, ], - }; - }, - getShaderSource, }; - }; + }, + getShaderSource, + }; +}; export const transpose = (context: ComputeContext, attributes: TransposeAttributes): void => { validateInputs(context.inputs); - context.compute( - createTransposeProgramInfo(context.inputs[0].dataType, context.inputs[0].dims.length, attributes.perm)); + context.compute(createTransposeProgramInfo(context.inputs[0], attributes.perm)); }; export const parseTransposeAttributes = (attributes: Record): TransposeAttributes => diff --git a/js/web/lib/wasm/proxy-messages.ts b/js/web/lib/wasm/proxy-messages.ts index 7aa866773bcb1..efeb086256cf3 100644 --- a/js/web/lib/wasm/proxy-messages.ts +++ b/js/web/lib/wasm/proxy-messages.ts @@ -73,5 +73,10 @@ interface MesssageEndProfiling extends MessageError { in ?: number; } +interface MessageIsOrtEnvInitialized extends MessageError { + type: 'is-ort-env-initialized'; + out?: boolean; +} + export type OrtWasmMessage = MessageInitWasm|MessageInitOrt|MessageCreateSessionAllocate|MessageCreateSessionFinalize| - MessageCreateSession|MessageReleaseSession|MessageRun|MesssageEndProfiling; + MessageCreateSession|MessageReleaseSession|MessageRun|MesssageEndProfiling|MessageIsOrtEnvInitialized; diff --git a/js/web/lib/wasm/proxy-worker/main.ts b/js/web/lib/wasm/proxy-worker/main.ts index fe8bd9b11b191..1f4595818e5c0 100644 --- a/js/web/lib/wasm/proxy-worker/main.ts +++ b/js/web/lib/wasm/proxy-worker/main.ts @@ -4,7 +4,7 @@ /// import {OrtWasmMessage} from '../proxy-messages'; -import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, extractTransferableBuffers, initRuntime, releaseSession, run} from '../wasm-core-impl'; +import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, extractTransferableBuffers, initRuntime, isOrtEnvInitialized, releaseSession, run} from '../wasm-core-impl'; import {initializeWebAssembly} from '../wasm-factory'; self.onmessage = (ev: MessageEvent): void => { @@ -89,6 +89,14 @@ self.onmessage = (ev: MessageEvent): void => { postMessage({type: 'end-profiling', err} as OrtWasmMessage); } break; + case 'is-ort-env-initialized': + try { + const ortEnvInitialized = isOrtEnvInitialized(); + postMessage({type: 'is-ort-env-initialized', out: ortEnvInitialized} as OrtWasmMessage); + } catch (err) { + postMessage({type: 'is-ort-env-initialized', err} as OrtWasmMessage); + } + break; default: } }; diff --git a/js/web/lib/wasm/proxy-wrapper.ts b/js/web/lib/wasm/proxy-wrapper.ts index a3e4a1ef1fc75..069a1fa452dbc 100644 --- a/js/web/lib/wasm/proxy-wrapper.ts +++ b/js/web/lib/wasm/proxy-wrapper.ts @@ -24,6 +24,7 @@ const createSessionCallbacks: Array> = []; const runCallbacks: Array> = []; const endProfilingCallbacks: Array> = []; +const isOrtEnvInitializedCallbacks: Array> = []; const ensureWorker = (): void => { if (initializing || !initialized || aborted || !proxyWorker) { @@ -92,6 +93,13 @@ const onProxyWorkerMessage = (ev: MessageEvent): void => { endProfilingCallbacks.shift()![0](); } break; + case 'is-ort-env-initialized': + if (ev.data.err) { + isOrtEnvInitializedCallbacks.shift()![1](ev.data.err); + } else { + isOrtEnvInitializedCallbacks.shift()![0](ev.data.out!); + } + break; default: } }; @@ -251,3 +259,16 @@ export const endProfiling = async(sessionId: number): Promise => { core.endProfiling(sessionId); } }; + +export const isOrtEnvInitialized = async(): Promise => { + if (!BUILD_DEFS.DISABLE_WASM_PROXY && isProxy()) { + ensureWorker(); + return new Promise((resolve, reject) => { + isOrtEnvInitializedCallbacks.push([resolve, reject]); + const message: OrtWasmMessage = {type: 'is-ort-env-initialized'}; + proxyWorker!.postMessage(message); + }); + } else { + return core.isOrtEnvInitialized(); + } +}; diff --git a/js/web/lib/wasm/session-handler-for-training.ts b/js/web/lib/wasm/session-handler-for-training.ts new file mode 100644 index 0000000000000..83d133b9a5157 --- /dev/null +++ b/js/web/lib/wasm/session-handler-for-training.ts @@ -0,0 +1,73 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {env, InferenceSession, SessionHandler, TrainingSessionHandler} from 'onnxruntime-common'; + +import {SerializableModeldata} from './proxy-messages'; +import {createSessionAllocate, initRuntime, isOrtEnvInitialized} from './wasm-core-impl'; +import {createCheckpointHandle, createTrainingSessionHandle, releaseTrainingSessionAndCheckpoint} from './wasm-training-core-impl'; + +export class OnnxruntimeWebAssemblyTrainingSessionHandler implements TrainingSessionHandler { + async loadParametersBuffer(_array: Uint8Array, _trainableOnly: boolean): Promise { + throw new Error('Method not implemented.'); + } + async getContiguousParameters(_trainableOnly: boolean): Promise { + throw new Error('Method not implemented.'); + } + private sessionId: number; + private checkpointId: number; + + inputNames: string[]; + outputNames: string[]; + + inputEncodedNames: number[]; + outputEncodedNames: number[]; + + async uriOrBufferToHeap(uriOrBuffer: string|Uint8Array): Promise { + let buffer: Uint8Array; + if (typeof uriOrBuffer === 'string') { + const response = await fetch(uriOrBuffer); + const arrayBuffer = await response.arrayBuffer(); + buffer = new Uint8Array(arrayBuffer); + } else { + buffer = uriOrBuffer; + } + return createSessionAllocate(buffer); + } + + async createTrainingSession( + checkpointStateUriOrBuffer: string|Uint8Array, trainModelUriOrBuffer: string|Uint8Array, + evalModelUriOrBuffer: string|Uint8Array, optimizerModelUriOrBuffer: string|Uint8Array, + options: InferenceSession.SessionOptions) { + if (!isOrtEnvInitialized()) { + await initRuntime(env); + } + const checkpointData: SerializableModeldata = await this.uriOrBufferToHeap(checkpointStateUriOrBuffer); + const trainModelData: SerializableModeldata = await this.uriOrBufferToHeap(trainModelUriOrBuffer); + // 0 is supposed to be the nullptr + let evalModelData: SerializableModeldata = [0, 0]; + let optimizerModelData: SerializableModeldata = [0, 0]; + + if (evalModelUriOrBuffer !== '') { + evalModelData = await this.uriOrBufferToHeap(evalModelUriOrBuffer); + } + if (optimizerModelUriOrBuffer !== '') { + optimizerModelData = await this.uriOrBufferToHeap(optimizerModelUriOrBuffer); + } + + this.checkpointId = createCheckpointHandle(checkpointData); + [[this.sessionId, this.inputNames, this.outputNames], this.inputEncodedNames, this.outputEncodedNames] = + createTrainingSessionHandle(this.checkpointId, trainModelData, evalModelData, optimizerModelData, options); + } + + async dispose(): Promise { + return releaseTrainingSessionAndCheckpoint( + this.checkpointId, this.sessionId, this.inputEncodedNames, this.outputEncodedNames); + } + + async runTrainStep( + _feeds: SessionHandler.FeedsType, _fetches: SessionHandler.FetchesType, + _options: InferenceSession.RunOptions): Promise { + throw new Error('Method not implemented yet.'); + } +} diff --git a/js/web/lib/wasm/session-handler.ts b/js/web/lib/wasm/session-handler.ts index d1760e37c93f7..a5017a920f38b 100644 --- a/js/web/lib/wasm/session-handler.ts +++ b/js/web/lib/wasm/session-handler.ts @@ -5,10 +5,9 @@ import {readFile} from 'node:fs/promises'; import {env, InferenceSession, InferenceSessionHandler, SessionHandler, Tensor} from 'onnxruntime-common'; import {SerializableModeldata, TensorMetadata} from './proxy-messages'; -import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, releaseSession, run} from './proxy-wrapper'; +import {createSession, createSessionAllocate, createSessionFinalize, endProfiling, initializeRuntime, isOrtEnvInitialized, releaseSession, run} from './proxy-wrapper'; import {isGpuBufferSupportedType} from './wasm-common'; -let runtimeInitialized: boolean; let runtimeInitializationPromise: Promise|undefined; const encodeTensorMetadata = (tensor: Tensor, getName: () => string): TensorMetadata => { @@ -57,13 +56,12 @@ export class OnnxruntimeWebAssemblySessionHandler implements InferenceSessionHan } async loadModel(pathOrBuffer: string|Uint8Array, options?: InferenceSession.SessionOptions): Promise { - if (!runtimeInitialized) { + if (!(await isOrtEnvInitialized())) { if (!runtimeInitializationPromise) { runtimeInitializationPromise = initializeRuntime(env); } await runtimeInitializationPromise; runtimeInitializationPromise = undefined; - runtimeInitialized = true; } if (typeof pathOrBuffer === 'string') { diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 5b49a1d4202e3..947242945c665 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -10,6 +10,8 @@ import {dataLocationStringToEnum, getTensorElementSize, isGpuBufferSupportedType import {getInstance} from './wasm-factory'; import {allocWasmString, checkLastError} from './wasm-utils'; +let ortEnvInitialized = false; + /** * get the input/output count of the session. * @param sessionHandle the handle representing the session. should be non-zero. @@ -57,6 +59,8 @@ export const initRuntime = async(env: Env): Promise => { const initJsep = require('./jsep/init').init; await initJsep(getInstance(), env); } + + ortEnvInitialized = true; }; /** @@ -93,6 +97,8 @@ type SessionMetadata = [ const activeSessions = new Map(); +export const isOrtEnvInitialized = (): boolean => ortEnvInitialized; + /** * allocate the memory and memcpy the model bytes, preparing for creating an instance of InferenceSession. * @returns a 2-elements tuple - the pointer and size of the allocated buffer diff --git a/js/web/lib/wasm/wasm-training-core-impl.ts b/js/web/lib/wasm/wasm-training-core-impl.ts new file mode 100644 index 0000000000000..4830b5d2b5e80 --- /dev/null +++ b/js/web/lib/wasm/wasm-training-core-impl.ts @@ -0,0 +1,162 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import {InferenceSession} from 'onnxruntime-common'; + +import {SerializableModeldata, SerializableSessionMetadata} from './proxy-messages'; +import {setSessionOptions} from './session-options'; +import {getInstance} from './wasm-factory'; +import {checkLastError} from './wasm-utils'; + +const NO_TRAIN_FUNCS_MSG = + 'Built without training API\'s enabled. Use the onnxruntime-web/training import for training ' + + 'functionality, and make sure that all the correct artifacts are built & moved to the correct folder if ' + + 'using a custom build. Check https://onnxruntime.ai/docs/build/web.html for more information.'; + +export const createCheckpointHandle = (checkpointData: SerializableModeldata): number => { + const wasm = getInstance(); + + const [checkpointDataOffset, checkpointDataLength] = checkpointData; + let checkpointHandle = 0; + + try { + if (wasm._OrtTrainingLoadCheckpoint) { + checkpointHandle = wasm._OrtTrainingLoadCheckpoint(checkpointDataOffset, checkpointDataLength); + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + + if (checkpointHandle === 0) { + checkLastError('Error occurred when trying to create a CheckpointState.'); + } + return checkpointHandle; + } catch (e) { + if (wasm._OrtTrainingReleaseCheckpoint && checkpointHandle !== 0) { + wasm._OrtTrainingReleaseCheckpoint(checkpointHandle); + } + throw e; + } finally { + // free buffer from wasm heap + wasm._OrtFree(checkpointData[0]); + } +}; + +const getModelInputOutputCount = (trainingSessionId: number, isEvalModel: boolean): [number, number] => { + const wasm = getInstance(); + const stack = wasm.stackSave(); + try { + const dataOffset = wasm.stackAlloc(8); + if (wasm._OrtTrainingGetModelInputOutputCount) { + const errorCode = + wasm._OrtTrainingGetModelInputOutputCount(trainingSessionId, dataOffset, dataOffset + 4, isEvalModel); + if (errorCode !== 0) { + checkLastError('Can\'t get session input/output count.'); + } + return [wasm.HEAP32[dataOffset / 4], wasm.HEAP32[dataOffset / 4 + 1]]; + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + } finally { + wasm.stackRestore(stack); + } +}; + +const getModelInputOutputNamesLoop = + (trainingSessionId: number, count: number, isInput: boolean, isEvalModel: boolean): [string[], number[]] => { + const names = []; + const wasm = getInstance(); + + const namesUTF8Encoded = []; + + for (let i = 0; i < count; i++) { + if (wasm._OrtTrainingGetModelInputOutputName) { + const name = wasm._OrtTrainingGetModelInputOutputName(trainingSessionId, i, isInput, isEvalModel); + if (name === 0) { + checkLastError('Can\'t get input or output name'); + } + + namesUTF8Encoded.push(name); + names.push(wasm.UTF8ToString(name)); + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + } + return [names, namesUTF8Encoded]; + }; + +const getTrainingModelInputOutputNames = (trainingSessionId: number): [string[], number[], string[], number[]] => { + const [inputCount, outputCount] = getModelInputOutputCount(trainingSessionId, false); + + const [inputNames, inputNamesUTF8Encoded] = getModelInputOutputNamesLoop(trainingSessionId, inputCount, true, false); + const [outputNames, outputNamesUTF8Encoded] = + getModelInputOutputNamesLoop(trainingSessionId, outputCount, false, false); + + return [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded]; +}; + +export const createTrainingSessionHandle = + (checkpointHandle: number, trainModelData: SerializableModeldata, evalModelData: SerializableModeldata, + optimizerModelData: SerializableModeldata, + options: InferenceSession.SessionOptions): [SerializableSessionMetadata, number[], number[]] => { + const wasm = getInstance(); + + let trainingSessionHandle = 0; + let sessionOptionsHandle = 0; + let allocs: number[] = []; + let inputNamesUTF8Encoded: number[] = []; + let outputNamesUTF8Encoded: number[] = []; + + let inputNames: string[] = []; + let outputNames: string[] = []; + + try { + [sessionOptionsHandle, allocs] = setSessionOptions(options); + if (wasm._OrtTrainingCreateSession) { + trainingSessionHandle = wasm._OrtTrainingCreateSession( + sessionOptionsHandle, checkpointHandle, trainModelData[0], trainModelData[1], evalModelData[0], + evalModelData[1], optimizerModelData[0], optimizerModelData[1]); + } else { + throw new Error(NO_TRAIN_FUNCS_MSG); + } + + if (trainingSessionHandle === 0) { + checkLastError('Error occurred when trying to create a TrainingSession.'); + } + + [inputNames, inputNamesUTF8Encoded, outputNames, outputNamesUTF8Encoded] = + getTrainingModelInputOutputNames(trainingSessionHandle); + return [[trainingSessionHandle, inputNames, outputNames], inputNamesUTF8Encoded, outputNamesUTF8Encoded]; + + } catch (e) { + if (wasm._OrtTrainingReleaseSession && trainingSessionHandle !== 0) { + wasm._OrtTrainingReleaseSession(trainingSessionHandle); + } + throw e; + } finally { + wasm._free(trainModelData[0]); + wasm._free(evalModelData[0]); + wasm._free(optimizerModelData[0]); + + if (sessionOptionsHandle !== 0) { + wasm._OrtReleaseSessionOptions(sessionOptionsHandle); + } + allocs.forEach(alloc => wasm._free(alloc)); + inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + } + }; + +export const releaseTrainingSessionAndCheckpoint = + (checkpointId: number, sessionId: number, inputNamesUTF8Encoded: number[], outputNamesUTF8Encoded: number[]): + void => { + const wasm = getInstance(); + inputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + outputNamesUTF8Encoded.forEach(buf => wasm._OrtFree(buf)); + + if (wasm._OrtTrainingReleaseSession) { + wasm._OrtTrainingReleaseSession(sessionId); + } + if (wasm._OrtTrainingReleaseCheckpoint) { + wasm._OrtTrainingReleaseCheckpoint(checkpointId); + } + }; diff --git a/js/web/test/data/ops/transpose.jsonc b/js/web/test/data/ops/transpose.jsonc index 285d14018e74d..e1edfa7e41513 100644 --- a/js/web/test/data/ops/transpose.jsonc +++ b/js/web/test/data/ops/transpose.jsonc @@ -166,5 +166,29 @@ ] } ] + }, + { + "name": "Transpose 5D - perms:[4, 3, 1, 0, 2]", + "operator": "Transpose", + "attributes": [{ "name": "perm", "data": [4, 3, 1, 0, 2], "type": "ints" }], + "cases": [ + { + "name": "T[3, 1, 2, 1, 4]", + "inputs": [ + { + "data": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24], + "dims": [3, 1, 2, 1, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [1, 5, 9, 13, 17, 21, 2, 6, 10, 14, 18, 22, 3, 7, 11, 15, 19, 23, 4, 8, 12, 16, 20, 24], + "dims": [4, 1, 1, 3, 2], + "type": "float32" + } + ] + } + ] } ] diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc index 0b55cb7804c61..694c40bf3eda6 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc @@ -16,7 +16,6 @@ #include #include -#include using onnxruntime::concurrency::ThreadPool; diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h index 73b83057bdbe9..00e82c9844b3d 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h @@ -206,6 +206,7 @@ Status CheckInputs(const T* query, } } + int total_sequence_length = past_sequence_length + kv_sequence_length; AttentionMaskType mask_type = AttentionMaskType::MASK_NONE; if (key_padding_mask != nullptr) { mask_type = AttentionMaskType::MASK_UNKNOWN; @@ -216,13 +217,21 @@ Status CheckInputs(const T* query, } else if (mask_dims[0] == static_cast(3) * static_cast(batch_size) + static_cast(2)) { mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START; } - } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast(batch_size) && mask_dims[1] == static_cast(kv_sequence_length)) { + } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast(batch_size) && + mask_dims[1] == static_cast(kv_sequence_length)) { + mask_type = AttentionMaskType::MASK_2D_KEY_PADDING; + } else if (mask_dims.size() == 2 && mask_dims[0] == static_cast(batch_size) && + mask_dims[1] == static_cast(total_sequence_length)) { mask_type = AttentionMaskType::MASK_2D_KEY_PADDING; + } else if (mask_dims.size() == 3 && mask_dims[0] == static_cast(batch_size) && + mask_dims[1] == static_cast(sequence_length) && + mask_dims[2] == static_cast(total_sequence_length)) { + mask_type = AttentionMaskType::MASK_3D_ATTENTION; } if (mask_type == AttentionMaskType::MASK_UNKNOWN) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'key_padding_mask' shape shall be (batch_size) or (batch_size, kv_sequence_length)"); + "Input 'key_padding_mask' shape shall be 1D, 2D, or 3D"); } } @@ -257,7 +266,6 @@ Status CheckInputs(const T* query, } } - int total_sequence_length = past_sequence_length + kv_sequence_length; bool broadcast_res_pos_bias = false; if (relative_position_bias != nullptr) { const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims(); diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc new file mode 100644 index 0000000000000..4a266af789250 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc @@ -0,0 +1,115 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "contrib_ops/cpu/bert/rotary_embedding.h" +#include "contrib_ops/cpu/bert/rotary_embedding_helper.h" + +#include "core/platform/threadpool.h" + +using onnxruntime::concurrency::ThreadPool; +using namespace onnxruntime::contrib::rotary_embedding_helper; + +namespace onnxruntime { +namespace contrib { + +// These ops are internal-only, so register outside of onnx +ONNX_OPERATOR_TYPED_KERNEL_EX( + RotaryEmbedding, + kMSDomain, + 1, + float, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T", DataTypeImpl::GetTensorType()) + .TypeConstraint("M", DataTypeImpl::GetTensorType()), + RotaryEmbedding); + +template +RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) { + scale = info.GetAttrOrDefault("scale", 1.0); + interleaved = (info.GetAttrOrDefault("interleaved", 0) == 1); +} + +template +Status RotaryEmbedding::Compute(OpKernelContext* context) const { + const Tensor* input = context->Input(0); + const Tensor* position_ids = context->Input(1); + const Tensor* cos_cache = context->Input(2); + const Tensor* sin_cache = context->Input(3); + + RotaryParameters parameters = {}; + ORT_RETURN_IF_ERROR(rotary_embedding_helper::CheckInputs(input, + position_ids, + cos_cache, + sin_cache, + ¶meters)); + + Tensor* output = context->Output(0, input->Shape()); + + if (parameters.sequence_length > parameters.max_sequence_length) { + // Launch update_cos_sin_cache kernel with scale + ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported"); + } + + const T* input_src = input->Data(); + const int64_t* pos_ids_data = position_ids->Data(); + const T* cos_cache_data = cos_cache->Data(); + const T* sin_cache_data = sin_cache->Data(); + T* output_dest = output->MutableData(); + + const int batch_size = parameters.batch_size; + const int sequence_length = parameters.sequence_length; + const int num_heads = parameters.num_heads; + const int head_size = parameters.head_size; + const int position_ids_format = parameters.position_ids_format; + const int half_head_size = head_size / 2; + + AllocatorPtr allocator; + ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + auto* tp = context->GetOperatorThreadPool(); + + const int loop_len = batch_size * sequence_length * num_heads; + const double cost = static_cast(head_size); + ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { + for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) { + const int b = static_cast((ptr / num_heads) / sequence_length); + const int s = static_cast((ptr / num_heads) % sequence_length); + const int n = static_cast(ptr % num_heads); + + const int block_offset = b * sequence_length * num_heads + s * num_heads + n; + const int data_offset = block_offset * head_size; + + const T* input_data = input_src + data_offset; + T* output_data = output_dest + data_offset; + + // Cache is (M, H/2) + const int position_id = (position_ids_format == 0) + ? static_cast(pos_ids_data[0]) + s + : static_cast(pos_ids_data[b * sequence_length + s]); + const int cache_offset = position_id * half_head_size; + const T* cos_data = cos_cache_data + cache_offset; + const T* sin_data = sin_cache_data + cache_offset; + + int cache_idx = 0; + T sign = 0; + int j = 0; + for (int i = 0; i < head_size; i++) { + if (interleaved) { + cache_idx = (i / 2) % half_head_size; + sign = (i % 2 == 0) ? static_cast(-1) : static_cast(1); + j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign + } else { + cache_idx = i % half_head_size; + sign = (i < half_head_size) ? static_cast(-1) : static_cast(1); + j = (i + half_head_size) % head_size; + } + output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx]; + } + } + }); + + return Status::OK(); +} + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h new file mode 100644 index 0000000000000..be834a66cdc69 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/framework/op_kernel.h" + +namespace onnxruntime { +namespace contrib { + +template +class RotaryEmbedding final : public OpKernel { + public: + RotaryEmbedding(const OpKernelInfo& info); + Status Compute(OpKernelContext* context) const override; + + protected: + float scale; + bool interleaved; +}; + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h new file mode 100644 index 0000000000000..cf8080800e072 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h @@ -0,0 +1,121 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/providers/common.h" + +namespace onnxruntime { +namespace contrib { +namespace rotary_embedding_helper { + +// Parameters deduced from node attributes and inputs/outputs. +struct RotaryParameters { + int batch_size; // Batch size used by input + int sequence_length; // Sequence length used by input + int hidden_size; // Hidden size used by input + int head_size; // Head size used by cos/sin cache * 2 + int num_heads; // num_heads = hidden_size / head_size + int max_sequence_length; // Sequence length used by cos/sin cache + int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length) +}; + +template +Status CheckInputs(const T* input, + const T* position_ids, + const T* cos_cache, + const T* sin_cache, + void* parameters) { + // input : (batch_size, sequence_length, hidden_size) + // position ids : (1) or (batch_size, sequence_length) + // cos cache : (max_sequence_length, head_size / 2) + // sin cache : (max_sequence_length, head_size / 2) + + // Check input + const auto& input_dims = input->Shape().GetDims(); + if (input_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'x' is expected to have 3 dimensions, got ", + input_dims.size()); + } + // Check position_ids + const auto& position_ids_dims = position_ids->Shape().GetDims(); + if (!onnxruntime::IsScalarOr1ElementVector(position_ids) && position_ids_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' is expected to have 0, 1, or 2 ", + "dimensions, got ", position_ids_dims.size()); + } + // Check cos_cache and sin_cache + const auto& cos_cache_dims = cos_cache->Shape().GetDims(); + if (cos_cache_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' is expected to have 2 dimensions, got ", + cos_cache_dims.size()); + } + const auto& sin_cache_dims = sin_cache->Shape().GetDims(); + if (sin_cache_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' is expected to have 2 dimensions, got ", + sin_cache_dims.size()); + } + if (cos_cache_dims[0] != sin_cache_dims[0] || cos_cache_dims[1] != sin_cache_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'cos_cache' and 'sin_cache' are expected to have ", + "the same shape"); + } + + // Get attributes from inputs + int batch_size = static_cast(input_dims[0]); + int sequence_length = static_cast(input_dims[1]); + int hidden_size = static_cast(input_dims[2]); + int max_sequence_length = static_cast(cos_cache_dims[0]); + int head_size = static_cast(cos_cache_dims[1]) * 2; + int num_heads = hidden_size / head_size; + int position_ids_format = -1; + + // Check position_ids input shapes + if (!onnxruntime::IsScalarOr1ElementVector(position_ids)) { + if (batch_size != static_cast(position_ids_dims[0])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' dimension 0 should be of size ", + "batch_size, got ", position_ids_dims[0]); + } + if (sequence_length != static_cast(position_ids_dims[1])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'position_ids' dimension 1 should be of size ", + "sequence_length, got ", position_ids_dims[1]); + } + position_ids_format = 1; + } else { + position_ids_format = 0; + } + // Check cos_cache input shapes + if (max_sequence_length != static_cast(cos_cache_dims[0])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 0 should be same as ", + "max_sequence_length, got ", cos_cache_dims[0]); + } + if ((head_size / 2) != static_cast(cos_cache_dims[1])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 1 should be same as ", + "head_size / 2, got ", cos_cache_dims[1]); + } + // Check sin_cache input shapes + if (max_sequence_length != static_cast(sin_cache_dims[0])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' dimension 0 should be same as ", + "max_sequence_length, got ", sin_cache_dims[0]); + } + if ((head_size / 2) != static_cast(sin_cache_dims[1])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' dimension 1 should be same as ", + "head_size / 2, got ", sin_cache_dims[1]); + } + + // Set rotary parameters + if (parameters != nullptr) { + RotaryParameters* output_parameters = reinterpret_cast(parameters); + output_parameters->batch_size = batch_size; + output_parameters->sequence_length = sequence_length; + output_parameters->hidden_size = hidden_size; + output_parameters->head_size = head_size; + output_parameters->num_heads = num_heads; + output_parameters->max_sequence_length = max_sequence_length; + output_parameters->position_ids_format = position_ids_format; + } + + return Status::OK(); +} + +} // namespace rotary_embedding_helper +} // namespace contrib +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index b4c51ab290eb7..f9d9b13f0fedc 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -20,6 +20,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MultiHeadAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, RotaryEmbedding); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer); @@ -29,6 +30,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Gathe class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, TransposeMatMul); // backward compatibility class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, FusedMatMul); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulNBits); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulBnb4); #ifndef ORT_MINIMAL_BUILD class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulFpQ4); #endif @@ -124,6 +126,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 1, double, SimplifiedLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SkipLayerNormalization); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, SkipLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, SkipSimplifiedLayerNormalization); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, double, SkipSimplifiedLayerNormalization); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Inverse); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Trilu); @@ -253,6 +257,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -266,6 +271,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, // backward compatibility BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, #ifndef ORT_MINIMAL_BUILD BuildKernelCreateInfo, #endif @@ -299,6 +305,8 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h b/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h new file mode 100644 index 0000000000000..cb8e97a592d8c --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h @@ -0,0 +1,202 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +namespace onnxruntime { +namespace contrib { + +#if defined(_MSC_VER) +#define FORCEINLINE __forceinline +#else +#define FORCEINLINE __attribute__((always_inline)) inline +#endif + +typedef enum Bnb_DataType_t { + FP4 = 0, + NF4 = 1, +} Bnb_DataType_t; + +FORCEINLINE uint8_t QuantizeOneFP4(float x) { + // FP4 with bias of 3 + // first bit is a sign + // subnormals + // 0b000 = 0 + // 0b001 = 0.0625 + // 0b110 = 2 + // 0b111 = 3 + // 0b100 = 4 + // 0b101 = 6 + // 0b010 = 8 + // 0b011 = 12 + + // we do a binary search + // the pivots are divided by 12 (the FP4 absmax) + // since we assum input data is in [-1.0, 1.0] + + // !be careful here, its easy to make a mistake + // that is difficult to noice if you add an extra + // zero somewhere! + + uint8_t sign = x < 0 ? 0b1000 : 0b0000; + x = fabsf(x); + if (x > 0.29166667f) { + if (x > 0.583333f) { + if (x > 0.8333333f) { + return 0b0011 + sign; + } else { + return 0b0010 + sign; + } + } else if (x > 0.4166667f) { + return 0b101 + sign; + } else { + return 0b100 + sign; + } + } else if (x > 0.0859375f) { + if (x > 0.20833333f) { + return 0b0111 + sign; + } else { + return 0b0110 + sign; + } + } else if (x > 0.00260417f) { + return 0b0001 + sign; + } else { + return 0b0000 + sign; + } +} + +FORCEINLINE uint8_t QuantizeOneNF4(float x) { + if (x > 0.03979014977812767f) { + if (x > 0.3893125355243683f) { // 1 + if (x > 0.6427869200706482f) { // 11 + if (x > 0.8614784181118011f) { // 111 + return 0b1111; + } else { + return 0b1110; + } + } else if (x > 0.5016634166240692f) { // 110 + return 0b1101; + } else { + return 0b1100; + } + } else if (x > 0.2035212516784668f) { // 10 + if (x > 0.2920137718319893f) { // 101 + return 0b1011; + } else { + return 0b1010; + } + } else if (x > 0.1202552504837513f) { // 100 + return 0b1001; + } else { + return 0b1000; + } + } else if (x > -0.33967943489551544f) { // 0 + if (x > -0.13791173323988914f) { // 01 + if (x > -0.045525018125772476f) { // 011 + return 0b0111; + } else { + return 0b0110; + } + } else if (x > -0.23460740596055984f) { // 010 + return 0b0101; + } else { + return 0b0100; + } + } else if (x > -0.6106329262256622f) { // 00 + if (x > -0.4599952697753906f) { // 001 + return 0b0011; + } else { + return 0b0010; + } + } else if (x > -0.8480964004993439f) { // 000 + return 0b0001; + } else { + return 0b0000; + } +} + +template +FORCEINLINE uint8_t QuantizeOneBnb4(float x) { + if constexpr (DATA_TYPE == FP4) + return QuantizeOneFP4(x); + else + return QuantizeOneNF4(x); +} + +template +FORCEINLINE void QuantizeBlockBnb4(const T* src, uint8_t* dst, T& absmax_block, int32_t block_idx, int32_t numel) { + float local_absmax = 0.0f; + + int32_t block_len = std::min(block_size, numel - block_idx * block_size); + int32_t src_offset = block_idx * block_size; + int32_t dst_offset = block_idx * block_size / 2; + + for (int32_t idx = 0; idx < block_len; idx++) { + const float v = static_cast(src[src_offset + idx]); + local_absmax = fmaxf(local_absmax, fabsf(v)); + } + + absmax_block = static_cast(local_absmax); + const float reciprocal_absmax = local_absmax ? 1.0f / local_absmax : 0.0f; + + for (int32_t idx = 0; idx < block_len; idx += 2) { + const float v0 = static_cast(src[src_offset + idx]) * reciprocal_absmax; + const uint8_t vi0 = QuantizeOneBnb4(v0); + + const float v1 = (idx + 1 < block_len) ? static_cast(src[src_offset + idx + 1]) * reciprocal_absmax : 0; + const uint8_t vi1 = QuantizeOneBnb4(v1); + + dst[dst_offset + idx / 2] = (vi0 << 4) | vi1; + } +} + +static float fp4_qaunt_map[16] = {0.00000000f, 5.208333333e-03f, 0.66666667f, 1.00000000f, + 0.33333333f, 0.50000000f, 0.16666667f, 0.25000000f, + -0.00000000f, -5.208333333e-03f, -0.66666667f, -1.00000000f, + -0.33333333f, -0.50000000f, -0.16666667f, -0.25000000f}; + +static float nf4_qaunt_map[16] = {-1.0f, + -0.6961928009986877f, + -0.5250730514526367f, + -0.39491748809814453f, + -0.28444138169288635f, + -0.18477343022823334f, + -0.09105003625154495f, + 0.0f, + 0.07958029955625534f, + 0.16093020141124725f, + 0.24611230194568634f, + 0.33791524171829224f, + 0.44070982933044434f, + 0.5626170039176941f, + 0.7229568362236023f, + 1.0f}; + +template +FORCEINLINE T DequantizeOneBnb4(uint8_t x) { + if constexpr (DATA_TYPE == FP4) + return static_cast(fp4_qaunt_map[x]); + else + return static_cast(nf4_qaunt_map[x]); +} + +template +FORCEINLINE void DequantizeBlockBnb4(const uint8_t* src, T* dst, T absmax_block, int32_t block_idx, int32_t numel) { + int32_t block_len = std::min(block_size, numel - block_idx * block_size); + int32_t src_offset = block_idx * block_size / 2; + int32_t dst_offset = block_idx * block_size; + + for (int32_t idx = 0; idx < block_len; idx += 2) { + const uint8_t val = src[src_offset + idx / 2]; + + dst[dst_offset + idx] = DequantizeOneBnb4(val >> 4) * absmax_block; + if (idx + 1 < block_len) dst[dst_offset + idx + 1] = DequantizeOneBnb4(val & 0xF) * absmax_block; + } +} + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h b/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h new file mode 100644 index 0000000000000..5ddb77e5b5ee3 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h @@ -0,0 +1,143 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "blockwise_quant_block_bnb4.h" + +#include + +#include "core/common/safeint.h" +#include "core/framework/float16.h" +#include "core/platform/threadpool.h" +#include + +namespace onnxruntime { +namespace contrib { + +template +void QuantizeBlockwiseBnb4( + uint8_t* dst, // shape: [(N * K + 1) / 2] + const T* src, // shape: [N, K] + T* absmax, // shape: [(N * K + block_size - 1) / block_size] + int32_t N, + int32_t K, + onnxruntime::concurrency::ThreadPool* thread_pool) { + int32_t numel = N * K; + int32_t total_block_count = (numel + block_size - 1) / block_size; + + concurrency::ThreadPool::TryBatchParallelFor( + thread_pool, + total_block_count, + [&](ptrdiff_t block_idx) { + QuantizeBlockBnb4( + src, + dst, + absmax[block_idx], + static_cast(block_idx), + numel); + }, + 0); +} + +#define QuantizeBlockwiseBn4DataTyped(block_size, quant_type) \ + if (quant_type == FP4) \ + QuantizeBlockwiseBnb4(dst, src, absmax, N, K, thread_pool); \ + else \ + QuantizeBlockwiseBnb4(dst, src, absmax, N, K, thread_pool); + +template +void QuantizeBlockwiseBnb4( + uint8_t* dst, // shape: [(N * K + 1) / 2] + const T* src, // shape: [N, K] + T* absmax, // shape: [(N * K + block_size - 1) / block_size] + int32_t block_size, + int32_t quant_type, + int32_t N, + int32_t K, + onnxruntime::concurrency::ThreadPool* thread_pool) { + ORT_ENFORCE( + quant_type == FP4 || quant_type == NF4, + "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported."); + + if (block_size == 16) { + QuantizeBlockwiseBn4DataTyped(16, quant_type); + } else if (block_size == 32) { + QuantizeBlockwiseBn4DataTyped(32, quant_type); + } else if (block_size == 64) { + QuantizeBlockwiseBn4DataTyped(64, quant_type); + } else if (block_size == 128) { + QuantizeBlockwiseBn4DataTyped(128, quant_type); + } else if (block_size == 256) { + QuantizeBlockwiseBn4DataTyped(256, quant_type); + } else { + ORT_NOT_IMPLEMENTED("only block size 16, 32, 64, 128, 256 are supported."); + } +} + +#undef QuantizeBlockwiseBn4DataTyped + +template +void DequantizeBlockwiseBnb4( + T* dst, // shape: [N, K] + const uint8_t* src, // shape: [(N * K + 1) / 2)] + const T* absmax, // shape: [(N * K + block_size - 1) / block_size] + int32_t N, + int32_t K, + onnxruntime::concurrency::ThreadPool* thread_pool) { + int32_t numel = N * K; + int32_t total_block_count = (numel + block_size - 1) / block_size; + + concurrency::ThreadPool::TryBatchParallelFor( + thread_pool, + total_block_count, + [&](ptrdiff_t block_idx) { + DequantizeBlockBnb4( + src, + dst, + absmax[block_idx], + static_cast(block_idx), + numel); + }, + 0); +} + +#define DequantizeBlockwiseBn4DataTyped(block_size, quant_type) \ + if (quant_type == FP4) \ + DequantizeBlockwiseBnb4(dst, src, absmax, N, K, thread_pool); \ + else \ + DequantizeBlockwiseBnb4(dst, src, absmax, N, K, thread_pool); + +template +void DequantizeBlockwiseBnb4( + T* dst, // shape: [N, K] + const uint8_t* src, // shape: [(N * K + 1) / 2)] + const T* absmax, // shape: [(N * K + block_size - 1) / block_size] + int32_t block_size, + int32_t quant_type, + int32_t N, + int32_t K, + onnxruntime::concurrency::ThreadPool* thread_pool) { + ORT_ENFORCE( + quant_type == FP4 || quant_type == NF4, + "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported."); + + if (block_size == 16) { + DequantizeBlockwiseBn4DataTyped(16, quant_type); + } else if (block_size == 32) { + DequantizeBlockwiseBn4DataTyped(32, quant_type); + } else if (block_size == 64) { + DequantizeBlockwiseBn4DataTyped(64, quant_type); + } else if (block_size == 128) { + DequantizeBlockwiseBn4DataTyped(128, quant_type); + } else if (block_size == 256) { + DequantizeBlockwiseBn4DataTyped(256, quant_type); + } else { + ORT_NOT_IMPLEMENTED("only block size 16, 32, 64, 128, 256 are supported."); + } +} + +#undef DequantizeBlockwiseBn4DataTyped + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc new file mode 100644 index 0000000000000..2f3ede49c3650 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_bnb4.cc @@ -0,0 +1,109 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/safeint.h" +#include "core/framework/op_kernel.h" +#include "core/providers/cpu/math/matmul_helper.h" +#include "core/providers/common.h" +#include "dequantize_blockwise_bnb4.h" +#include "core/mlas/inc/mlas.h" + +namespace onnxruntime { +namespace contrib { + +class MatMulBnb4 final : public OpKernel { + public: + MatMulBnb4(const OpKernelInfo& info) : OpKernel(info) { + ORT_ENFORCE(Status::OK() == info.GetAttr("K", &K_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("quant_type", &quant_type_)); + ORT_ENFORCE( + quant_type_ == FP4 || quant_type_ == NF4, + "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported."); + } + + Status Compute(OpKernelContext* context) const override; + + private: + int64_t K_; + int64_t N_; + int64_t block_size_; + int64_t quant_type_; +}; + +Status MatMulBnb4::Compute(OpKernelContext* ctx) const { + concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); + + const Tensor* a = ctx->Input(0); + const Tensor* b_quant = ctx->Input(1); + const Tensor* absmax = ctx->Input(2); + + const float* a_data = a->Data(); + const uint8_t* b_quant_data = b_quant->Data(); + const float* absmax_data = absmax->Data(); + + AllocatorPtr allocator; + auto status = ctx->GetTempSpaceAllocator(&allocator); + ORT_RETURN_IF_ERROR(status); + auto tmp_b_data_ptr = IAllocator::MakeUniquePtr(allocator, SafeInt(K_) * N_); + DequantizeBlockwiseBnb4( + tmp_b_data_ptr.get(), + b_quant_data, + absmax_data, + static_cast(block_size_), + static_cast(quant_type_), + static_cast(N_), + static_cast(K_), + thread_pool); + + constexpr bool transa = false; + constexpr bool transb = true; + TensorShape b_shape({N_, K_}); + MatMulComputeHelper helper; + ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, transa, transb)); + + Tensor* y = ctx->Output(0, helper.OutputShape()); + + // Bail out early if the output is going to be empty + if (y->Shape().Size() == 0) return Status::OK(); + + auto* y_data = y->MutableData(); + + const size_t max_len = helper.OutputOffsets().size(); + const size_t M = static_cast(helper.M()); + const size_t N = static_cast(helper.N()); + const size_t K = static_cast(helper.K()); + const size_t lda = helper.Lda(transa); + const size_t ldb = helper.Ldb(transb); + + // TODO: implement with native kernel + std::vector data(max_len); + for (size_t i = 0; i < max_len; i++) { + data[i].BIsPacked = false; + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].B = tmp_b_data_ptr.get() + helper.RightOffsets()[i]; + data[i].ldb = ldb; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; + data[i].alpha = 1.f; + data[i].beta = 0.0f; + } + MlasGemmBatch(CblasNoTrans, CblasTrans, M, N, K, data.data(), max_len, thread_pool); + + return Status::OK(); +} + +ONNX_OPERATOR_KERNEL_EX( + MatMulBnb4, + kMSDomain, + 1, + kCpuExecutionProvider, + KernelDefBuilder() + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulBnb4); + +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc index e86a12d9fb873..4e103c2556a7a 100644 --- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.cc @@ -20,20 +20,29 @@ namespace contrib { kCpuExecutionProvider, \ KernelDefBuilder() \ .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ - SkipLayerNorm); + SkipLayerNorm); \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + SkipSimplifiedLayerNormalization, \ + kMSDomain, \ + 1, \ + T, \ + kCpuExecutionProvider, \ + KernelDefBuilder() \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + SkipLayerNorm); REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(double) -template -SkipLayerNorm::SkipLayerNorm(const OpKernelInfo& op_kernel_info) +template +SkipLayerNorm::SkipLayerNorm(const OpKernelInfo& op_kernel_info) : OpKernel(op_kernel_info) { ORT_ENFORCE(op_kernel_info.GetAttr("epsilon", &epsilon_).IsOK()); ORT_ENFORCE(epsilon_ >= 0); } -template -Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const { +template +Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const { const Tensor* input = p_ctx->Input(0); const Tensor* skip = p_ctx->Input(1); const Tensor* gamma = p_ctx->Input(2); @@ -102,10 +111,16 @@ Status SkipLayerNorm::Compute(OpKernelContext* p_ctx) const { } mean = mean / hidden_size; - mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon_); + if (simplified) { + mean_square = sqrt(mean_square / hidden_size + epsilon_); + } else { + mean_square = sqrt(mean_square / hidden_size - mean * mean + epsilon_); + } for (int64_t h = 0; h < hidden_size; h++) { - if (nullptr == beta_data) { + if (simplified) { + p_output[h] = p_output[h] / mean_square * gamma_data[h]; + } else if (nullptr == beta_data) { p_output[h] = (p_output[h] - mean) / mean_square * gamma_data[h]; } else { p_output[h] = (p_output[h] - mean) / mean_square * gamma_data[h] + beta_data[h]; diff --git a/onnxruntime/contrib_ops/cpu/skip_layer_norm.h b/onnxruntime/contrib_ops/cpu/skip_layer_norm.h index 7723541cb6b18..69edf4609e340 100644 --- a/onnxruntime/contrib_ops/cpu/skip_layer_norm.h +++ b/onnxruntime/contrib_ops/cpu/skip_layer_norm.h @@ -10,7 +10,7 @@ namespace onnxruntime { namespace contrib { -template +template class SkipLayerNorm final : public OpKernel { public: SkipLayerNorm(const OpKernelInfo& op_kernel_info); diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc new file mode 100644 index 0000000000000..b4b5dac1fbe19 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc @@ -0,0 +1,84 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cpu/bert/rotary_embedding_helper.h" +#include "contrib_ops/cuda/bert/rotary_embedding.h" +#include "contrib_ops/cuda/bert/rotary_embedding_impl.h" + +using namespace onnxruntime::cuda; +using namespace ::onnxruntime::common; +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::contrib::rotary_embedding_helper; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + RotaryEmbedding, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("M", DataTypeImpl::GetTensorType()), \ + RotaryEmbedding); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(MLFloat16) + +template +RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : CudaKernel(info) { + scale = info.GetAttrOrDefault("scale", 1.0); + interleaved = (info.GetAttrOrDefault("interleaved", 0) == 1); +} + +template +Status RotaryEmbedding::ComputeInternal(OpKernelContext* context) const { + const Tensor* input = context->Input(0); + const Tensor* position_ids = context->Input(1); + const Tensor* cos_cache = context->Input(2); + const Tensor* sin_cache = context->Input(3); + + RotaryParameters parameters = {}; + ORT_RETURN_IF_ERROR(rotary_embedding_helper::CheckInputs(input, + position_ids, + cos_cache, + sin_cache, + ¶meters)); + + Tensor* output = context->Output(0, input->Shape()); + + if (parameters.sequence_length > parameters.max_sequence_length) { + // Launch update_cos_sin_cache kernel with scale + ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported"); + } + + // Launch rotary embedding kernel + typedef typename ToCudaType::MappedType CudaT; + auto& device_prop = GetDeviceProp(); + return LaunchRotaryEmbeddingKernel( + Stream(context), + reinterpret_cast(output->template MutableData()), + reinterpret_cast(input->template Data()), + position_ids->Data(), + reinterpret_cast(cos_cache->template Data()), + reinterpret_cast(sin_cache->template Data()), + parameters.batch_size, + parameters.sequence_length, + parameters.num_heads, + parameters.head_size, + parameters.max_sequence_length, + parameters.position_ids_format, + interleaved, + device_prop.maxThreadsPerBlock); + + return Status::OK(); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h new file mode 100644 index 0000000000000..6dab2ad56749e --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/providers/cuda/cuda_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using namespace onnxruntime::cuda; + +template +class RotaryEmbedding final : public CudaKernel { + public: + RotaryEmbedding(const OpKernelInfo& info); + Status ComputeInternal(OpKernelContext* context) const override; + + protected: + float scale; + bool interleaved; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu new file mode 100644 index 0000000000000..c54e72dcfce13 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu @@ -0,0 +1,141 @@ +/* +Copyright (c) Microsoft Corporation. +Licensed under the MIT License. +*/ + +/* +Kernel implementation for rotary embeddings. +*/ + +#include +#include "core/providers/cuda/cu_inc/common.cuh" +#include "contrib_ops/cuda/bert/rotary_embedding_impl.h" + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +__global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH + const T* input, // BxSxNxH + const T* cos_cache, // Mx(H/2) + const T* sin_cache, // Mx(H/2) + const int64_t* position_ids, // (1) or BxS + const int sequence_length, + const int num_heads, + const int head_size, + const int position_ids_format, + const bool interleaved) { + // B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length + // Use .x in innermost loop to access global memory efficiently + + const int b = blockIdx.z; + const int s = blockIdx.y; + const int n = blockIdx.x; + + const int i = threadIdx.x; + + const int block_offset = b * sequence_length * num_heads + s * num_heads + n; + const int data_offset = block_offset * head_size; + + const T* input_data = input + data_offset; + T* output_data = output + data_offset; + + // Cache is (M, H/2) + const int half_head_size = head_size / 2; + const int position_id = (position_ids_format == 0) ? \ + static_cast(position_ids[0]) + s \ + : static_cast(position_ids[b * sequence_length + s]); + const int cache_offset = position_id * half_head_size; + const T* cos_data = cos_cache + cache_offset; + const T* sin_data = sin_cache + cache_offset; + + int cache_idx = 0; + T sign = 0; + int j = 0; + if (interleaved) { + cache_idx = (i / 2) % half_head_size; + sign = (i % 2 == 0) ? -1 : 1; + j = (i % 2 == 0) ? i+1 : i-1; // i - sign + } else { + cache_idx = i % half_head_size; + sign = (i < half_head_size) ? -1 : 1; + j = (i + half_head_size) % head_size; + } + output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx]; +} + + +template +Status LaunchRotaryEmbeddingKernel( + cudaStream_t stream, + T* output, + const T* input, + const int64_t* position_ids, + const T* cos_cache, + const T* sin_cache, + const int batch_size, + const int sequence_length, + const int num_heads, + const int head_size, + const int max_sequence_length, + const int position_ids_format, + const bool interleaved, + const int max_threads_per_block) { + + constexpr int smem_size = 0; + const dim3 grid(num_heads, sequence_length, batch_size); + const dim3 block(head_size, 1, 1); + + // Note: Current implementation assumes head_size <= max_threads_per_block + // because head_size is currently large for LLaMA-2. For smaller head_size + // and num_heads values, we can create a block as `block(num_heads, head_size, 1)` + // instead. This will require kernel changes to support. + + assert(head_size <= max_threads_per_block); + RotaryEmbeddingBSNH<<>>( + output, input, cos_cache, sin_cache, position_ids, + sequence_length, num_heads, head_size, position_ids_format, interleaved + ); + + return CUDA_CALL(cudaGetLastError()); +} + +template Status LaunchRotaryEmbeddingKernel( + cudaStream_t stream, + float* output, + const float* input, + const int64_t* position_ids, + const float* cos_cache, + const float* sin_cache, + const int batch_size, + const int sequence_length, + const int num_heads, + const int head_size, + const int max_sequence_length, + const int position_ids_format, + const bool interleaved, + const int max_threads_per_block); + +template Status LaunchRotaryEmbeddingKernel( + cudaStream_t stream, + half* output, + const half* input, + const int64_t* position_ids, + const half* cos_cache, + const half* sin_cache, + const int batch_size, + const int sequence_length, + const int num_heads, + const int head_size, + const int max_sequence_length, + const int position_ids_format, + const bool interleaved, + const int max_threads_per_block); + + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h new file mode 100644 index 0000000000000..29ff48a8ad0fb --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h @@ -0,0 +1,31 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +Status LaunchRotaryEmbeddingKernel( + cudaStream_t stream, + T* output, + const T* input, + const int64_t* position_ids, + const T* cos_cache, + const T* sin_cache, + const int batch_size, + const int sequence_length, + const int num_heads, + const int head_size, + const int max_sequence_length, + const int position_ids_format, + const bool interleaved, + const int max_threads_per_block); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 52ff285539360..e762a80cb0e2f 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -91,6 +91,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ParametricSoftplus); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ParametricSoftplus); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ParametricSoftplus); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RotaryEmbedding); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Sampling); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh); @@ -116,6 +118,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Inverse); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulNBits); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, MatMulBnb4); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, MatMulBnb4); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Trilu); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, UnfoldTensor); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, DynamicTimeWarping); @@ -250,6 +254,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -275,6 +281,8 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu new file mode 100644 index 0000000000000..e58723f0b31e1 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cu @@ -0,0 +1,129 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h" +#include "dequantize_blockwise_bnb4.cuh" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +Status SetBnbQuantMap(int quant_type, T* quant_map_buffer, cudaStream_t stream) { + ORT_ENFORCE( + quant_type == FP4 || quant_type == NF4, + "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported."); + + T host_quant_map[16]; + switch (quant_type) { + case FP4: + for (int i = 0; i < 16; i++) host_quant_map[i] = static_cast(fp4_qaunt_map[i]); + break; + case NF4: + for (int i = 0; i < 16; i++) host_quant_map[i] = static_cast(nf4_qaunt_map[i]); + break; + } + CUDA_CALL_THROW(cudaMemcpyAsync(quant_map_buffer, host_quant_map, sizeof(T) * 16, cudaMemcpyHostToDevice, stream)); + + return Status::OK(); +} + +template Status SetBnbQuantMap(int quant_type, float* quant_map_buffer, cudaStream_t stream); + +template Status SetBnbQuantMap(int quant_type, half* quant_map_buffer, cudaStream_t stream); + +template +__global__ void kDequantizeBlockwise( + const T* quant_map, + T* output, + const uint8_t* quant_data, + const T* absmax, + const int block_size, + const int n) { + const int n_load = (gridDim.x * TILE_SIZE); + int valid_items_load = 0; + int valid_items_store = 0; + const int base_idx = (blockIdx.x * TILE_SIZE); + + T vals[NUM_PER_TH * 2]; + uint8_t qvals[NUM_PER_TH]; + T local_abs_max = T(0.0f); + + typedef cub::BlockLoad LoadChar; + typedef cub::BlockStore StoreT; + + __shared__ typename LoadChar::TempStorage loadchar; + __shared__ typename StoreT::TempStorage storet; + + for (unsigned int i = base_idx; i < n_load; i += gridDim.x * TILE_SIZE) { + valid_items_load = (n + 1) / 2 - i > TILE_SIZE ? TILE_SIZE : (n + 1) / 2 - i; + valid_items_store = n - i * 2 > TILE_SIZE * 2 ? TILE_SIZE * 2 : n - i * 2; + + local_abs_max = __ldg(&absmax[(i + threadIdx.x * NUM_PER_TH) / (block_size)]); + + __syncthreads(); + LoadChar(loadchar).Load(&(quant_data[i]), qvals, valid_items_load, 128); + + #pragma unroll NUM_PER_TH + for (int j = 0; j < NUM_PER_TH; j++) { + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 + vals[j * 2] = quant_map[qvals[j] >> 4] * local_abs_max; + vals[j * 2 + 1] = quant_map[qvals[j] & 0x0F] * local_abs_max; + #else + // half multiplication not supported + vals[j * 2] = static_cast(static_cast(quant_map[qvals[j] >> 4]) * static_cast(local_abs_max)); + vals[j * 2 + 1] = + static_cast(static_cast(quant_map[qvals[j] & 0x0F]) * static_cast(local_abs_max)); + #endif + } + + __syncthreads(); + StoreT(storet).Store(&(output[i * 2]), vals, valid_items_store); + } +} + +template +Status DequantizeBnb4( + const T* quant_map, + T* output, + const uint8_t* quant_data, + const T* absmax, + int block_size, + int numel, + cudaStream_t stream) { + int tile_size = 1024; + kDequantizeBlockwise<<<(numel + tile_size - 1) / tile_size, 64, 0, stream>>>( + quant_map, + output, + quant_data, + absmax, + block_size / 2, + numel); + + return Status::OK(); +} + +template Status DequantizeBnb4( + const float* quant_map, + float* output, + const uint8_t* quant_data, + const float* absmax, + int block_size, + int numel, + cudaStream_t stream); + +template Status DequantizeBnb4( + const half* quant_map, + half* output, + const uint8_t* quant_data, + const half *absmax, + int block_size, + int numel, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh new file mode 100644 index 0000000000000..4aef3ab699f9c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +Status SetBnbQuantMap(int quant_type, T* quant_map_buffer, cudaStream_t stream); + +template +Status DequantizeBnb4( + const T* quant_map, + T* output, + const uint8_t* quant_data, + const T* absmax, + int block_size, + int numel, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc new file mode 100644 index 0000000000000..bd5b6e0a8a1ce --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cc @@ -0,0 +1,144 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/common/safeint.h" +#include "core/providers/cuda/cuda_kernel.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "core/providers/cpu/math/matmul_helper.h" +#include "contrib_ops/cpu/quantization/blockwise_quant_block_bnb4.h" +#include "matmul_bnb4.cuh" +#include "dequantize_blockwise_bnb4.cuh" + +namespace onnxruntime { +namespace contrib { +namespace cuda { +using namespace onnxruntime::cuda; + +template +class MatMulBnb4 final : public CudaKernel { + public: + MatMulBnb4(const OpKernelInfo& info) : CudaKernel(info) { + ORT_ENFORCE(Status::OK() == info.GetAttr("K", &K_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("N", &N_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("block_size", &block_size_)); + ORT_ENFORCE(Status::OK() == info.GetAttr("quant_type", &quant_type_)); + ORT_ENFORCE( + quant_type_ == FP4 || quant_type_ == NF4, + "Invalid quant_type, only 0 (FP4) and 1 (NF4) are supported."); + } + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + int64_t K_; + int64_t N_; + int64_t block_size_; + int64_t quant_type_; +}; + +template +Status MatMulBnb4::ComputeInternal(OpKernelContext* ctx) const { + const Tensor* a = ctx->Input(0); + const Tensor* b_quant = ctx->Input(1); + const Tensor* absmax = ctx->Input(2); + + const auto* a_data = a->Data(); + const uint8_t* b_quant_data = b_quant->Data(); + const auto* absmax_data = absmax->Data(); + + typedef typename ToCudaType::MappedType CudaT; + + // TODO: find a better way to create the quant_map without using a buffer + // don't want to use malloc directly so asking from the caller + // can create a __device__ static array for float but doesn't work for half + IAllocatorUniquePtr quant_map_buffer = GetScratchBuffer(16, ctx->GetComputeStream()); + auto* quant_map_buffer_data = quant_map_buffer.get(); + ORT_RETURN_IF_ERROR(SetBnbQuantMap( + SafeInt(quant_type_), + reinterpret_cast(quant_map_buffer_data), + static_cast(ctx->GetComputeStream()->GetHandle()))); + + constexpr bool transa = false; + constexpr bool transb = true; + MatMulComputeHelper helper; + TensorShape b_shape({N_, K_}); + ORT_RETURN_IF_ERROR( + helper.Compute(a->Shape(), b_shape, transa, transb)); + + Tensor* Y = ctx->Output(0, helper.OutputShape()); + // Bail out early if the output is going to be empty + if (Y->Shape().Size() == 0) return Status::OK(); + + bool is_4bit_done = TryMatMulBnb4( + reinterpret_cast(quant_map_buffer_data), + reinterpret_cast(Y->MutableData()), + reinterpret_cast(a_data), + b_quant_data, + reinterpret_cast(absmax_data), + SafeInt(helper.M()), + SafeInt(helper.N()), + SafeInt(helper.K()), + SafeInt(block_size_), + static_cast(ctx->GetComputeStream()->GetHandle())); + + if (!is_4bit_done) { + IAllocatorUniquePtr b_dequant_ptr = GetScratchBuffer(N_ * K_, ctx->GetComputeStream()); + auto* b_dequant_data = b_dequant_ptr.get(); + ORT_RETURN_IF_ERROR(DequantizeBnb4( + reinterpret_cast(quant_map_buffer_data), + reinterpret_cast(b_dequant_data), + b_quant_data, + reinterpret_cast(absmax_data), + SafeInt(block_size_), + SafeInt(N_ * K_), + static_cast(ctx->GetComputeStream()->GetHandle()))); + + const CudaT alpha = ToCudaType::FromFloat(1.f); + const CudaT zero = ToCudaType::FromFloat(0.f); + + CUBLAS_RETURN_IF_ERROR(cublasGemmHelper( + GetCublasHandle(ctx), + CUBLAS_OP_T, + CUBLAS_OP_N, + SafeInt(helper.N()), + SafeInt(helper.M()), + SafeInt(helper.K()), + &alpha, + reinterpret_cast(b_dequant_data), + SafeInt(K_), + reinterpret_cast(a_data), + helper.Lda(transa), + &zero, + reinterpret_cast(Y->MutableData()), + helper.Ldc(), + GetDeviceProp())); + } + + return Status::OK(); +} + +ONNX_OPERATOR_TYPED_KERNEL_EX( + MatMulBnb4, + kMSDomain, + 1, + float, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulBnb4); + +ONNX_OPERATOR_TYPED_KERNEL_EX( + MatMulBnb4, + kMSDomain, + 1, + MLFloat16, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + MatMulBnb4); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu new file mode 100644 index 0000000000000..1d9aa75ff3701 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cu @@ -0,0 +1,192 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include + +#include +#include +#include +#include "matmul_bnb4.cuh" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#define num_values_4bit 32 +template +__global__ void kgemm_4bit_inference_naive( + int M, + int N, + int K, + const T* __restrict__ A, + const uint8_t* B, + const T* absmax, + const T* datatype, + T* out, + int lda, + int ldb, + int ldc, + int block_size) { + // per threadblock: + // load step-by-step in chunks of [32,warps]: 1x32 * [32,warps] -> [1,warps] + // 4 warps -> 4 loads per iter + // 1x32 * 32x4 -> 1x4 outputs per thread block + typedef cub::WarpReduce WarpReduce; + __shared__ typename WarpReduce::TempStorage temp_storage[THREADS / 32]; + + const int warp_idx = threadIdx.x / 32; + const int warp_lane = threadIdx.x % 32; + const int row_B = (THREADS / 32) * blockIdx.x + warp_idx; + const int num_values_8bit = num_values_4bit / 2; + float local_C = 0.0f; + + uint8_t local_B_4bit[num_values_8bit]; + T local_B[num_values_4bit / 4]; + T local_A[num_values_4bit / 4]; + __shared__ T quant_map[16]; + T local_absmax = T(0.0f); + + for (int i = threadIdx.x; i < 16; i++) quant_map[i] = T(datatype[i]); + __syncthreads(); + + // A: [1, K] + // B: [N, K] + for (int inner_idx = warp_lane * num_values_4bit; inner_idx < K; inner_idx += 32 * num_values_4bit) { + int inner_idx_halved = inner_idx / 2; + int offset_B = ldb * row_B; + int absidx = ((2 * offset_B) + inner_idx) / block_size; + local_absmax = __ldg(&(absmax[absidx])); + + if (row_B < N) { + if ((inner_idx_halved + num_values_8bit) < (K / 2)) { + // this is the most important for performance considerations + reinterpret_cast(local_B_4bit)[0] = + reinterpret_cast(B)[(offset_B + (inner_idx_halved)) / (num_values_8bit)]; + } else { + #pragma unroll + for (int j = 0; j < (num_values_8bit); j++) + if ((inner_idx_halved) + j < (K / 2)) + local_B_4bit[j] = B[offset_B + inner_idx_halved + j]; + else + local_B_4bit[j] = 0b01110111; + } + } else { + #pragma unroll + for (int j = 0; j < (num_values_8bit); j++) local_B_4bit[j] = 0b01110111; + } + + for (int i = 0; i < 4; i++) { + #pragma unroll + for (int k = 0; k < num_values_8bit / 4; k++) { + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 + local_B[k * 2] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * local_absmax; + local_B[k * 2 + 1] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F] * local_absmax; + #else + // half multiplication not supported + local_B[k * 2] = + static_cast(static_cast(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4]) * + static_cast(local_absmax)); + local_B[k * 2 + 1] = + static_cast(static_cast(quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F]) * + static_cast(local_absmax)); + #endif + } + + if (inner_idx + (num_values_4bit / 4) + (i * num_values_4bit / 4) < K) { + // this is also relatively important for performance + if (BITS == 16) { + reinterpret_cast(local_A)[0] = + reinterpret_cast(A)[inner_idx / (num_values_4bit / 4) + i]; + } else { + reinterpret_cast(local_A)[0] = + reinterpret_cast(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 0]; + reinterpret_cast(local_A)[1] = + reinterpret_cast(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 1]; + } + } else { + #pragma unroll + for (int k = 0; k < num_values_4bit / 4; k++) { + if (inner_idx + (i * num_values_4bit / 4) + k < K) + local_A[k] = A[inner_idx + k + (i * num_values_4bit / 4)]; + else + local_A[k] = T(0.0f); + } + } + + // accumulate in float; small performance hit for Ampere, but lower error for outputs + #pragma unroll + for (int k = 0; k < num_values_4bit / 4; k++) { + #if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 530 + local_C += static_cast(local_A[k] * local_B[k]); + #else + // half multiplication not supported + local_C += static_cast(local_A[k]) * static_cast(local_B[k]); + #endif + } + } + } + + local_C = WarpReduce(temp_storage[warp_idx]).Sum(local_C); + + if (row_B < N && warp_lane == 0) out[row_B] = T(local_C); +} + +template +bool TryMatMulBnb4( + const T* quant_map, + T* output, + const T* a_data, + const uint8_t* b_data_quant, + const T* absmax, + int m, + int n, + int k, + int block_size, + cudaStream_t stream) { + if (k % block_size != 0 || m > 1) { + return false; + } + // supported block_sizes are [4096, 2048, 1024, 512, 256, 128, 64, 32] + if (block_size % 32 != 0 || block_size > 4096) { + return false; + } + + int lda = k; + int ldb = (k + 1) / 2; + int ldc = n; + int num_blocks = (n + 3) / 4; + + constexpr int bits = std::is_same_v ? 16 : 32; + kgemm_4bit_inference_naive<<>>( + m, n, k, a_data, b_data_quant, absmax, quant_map, output, lda, ldb, ldc, block_size); + + return true; +} + +template bool TryMatMulBnb4( + const float* quant_map, + float* output, + const float* a_data, + const uint8_t* b_data_quant, + const float* absmax, + int m, + int n, + int k, + int block_size, + cudaStream_t stream); + +template bool TryMatMulBnb4( + const half* quant_map, + half* output, + const half* a_data, + const uint8_t* b_data_quant, + const half* absmax, + int m, + int n, + int k, + int block_size, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cuh b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cuh new file mode 100644 index 0000000000000..743234282fbf3 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/matmul_bnb4.cuh @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +bool TryMatMulBnb4( + const T* quant_map, + T* output, + const T* a_data, + const uint8_t* b_data_quant, + const T* absmax, + int m, + int n, + int k, + int block_size, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index 6a82b3fcc734d..655d5014f3d60 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -22,6 +22,14 @@ #define HWCAP_ASIMDDP (1 << 20) #endif +#ifndef HWCAP2_I8MM +#define HWCAP2_I8MM (1 << 13) +#endif + +#ifndef HWCAP2_SVEI8MM +#define HWCAP2_SVEI8MM (1 << 9) +#endif + #endif // ARM #endif // Linux @@ -138,6 +146,9 @@ void CPUIDInfo::ArmLinuxInit() { is_hybrid_ = cpuinfo_get_uarchs_count() > 1; has_arm_neon_dot_ = cpuinfo_has_arm_neon_dot(); has_fp16_ = cpuinfo_has_arm_neon_fp16_arith(); + has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm(); + has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm(); + const uint32_t core_cnt = cpuinfo_get_cores_count(); core_uarchs_.resize(core_cnt, cpuinfo_uarch_unknown); is_armv8_narrow_ld_.resize(core_cnt, false); @@ -162,6 +173,10 @@ void CPUIDInfo::ArmLinuxInit() { pytorch_cpuinfo_init_ = false; has_arm_neon_dot_ = ((getauxval(AT_HWCAP) & HWCAP_ASIMDDP) != 0); has_fp16_ |= has_arm_neon_dot_; + + has_arm_neon_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_I8MM) != 0); + has_arm_sve_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_SVEI8MM) != 0); + #endif } @@ -256,6 +271,9 @@ void CPUIDInfo::ArmWindowsInit() { has_arm_neon_dot_ = (IsProcessorFeaturePresent(PF_ARM_V82_DP_INSTRUCTIONS_AVAILABLE) != 0); has_fp16_ |= has_arm_neon_dot_; + /* TODO: implement them when hw+sw is available for testing these features */ + has_arm_neon_i8mm_ = false; + has_arm_sve_i8mm_ = false; } #endif /* (arm or arm64) and windows */ diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h index 386db347c669d..a15c75104b83a 100644 --- a/onnxruntime/core/common/cpuid_info.h +++ b/onnxruntime/core/common/cpuid_info.h @@ -28,6 +28,8 @@ class CPUIDInfo { // ARM bool HasArmNeonDot() const { return has_arm_neon_dot_; } + bool HasArmNeon_I8MM() const { return has_arm_neon_i8mm_; } + bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; } uint32_t GetCurrentCoreIdx() const; @@ -121,6 +123,8 @@ class CPUIDInfo { bool has_arm_neon_dot_{false}; bool has_fp16_{false}; + bool has_arm_neon_i8mm_{false}; + bool has_arm_sve_i8mm_{false}; #ifdef CPUIDINFO_ARCH_X86 diff --git a/onnxruntime/core/framework/graph_partitioner.cc b/onnxruntime/core/framework/graph_partitioner.cc index 1b492a3561396..b028596fe4e6d 100644 --- a/onnxruntime/core/framework/graph_partitioner.cc +++ b/onnxruntime/core/framework/graph_partitioner.cc @@ -13,7 +13,9 @@ #include "core/framework/kernel_registry_manager.h" #include "core/framework/kernel_registry.h" #include "core/graph/function.h" +#include "core/graph/function_utils.h" #include "core/graph/graph_viewer.h" +#include "core/graph/model.h" // uncomment this line to count non-CUDA ops in ONNX domain // #define COUNT_NON_CUDA_OPS @@ -129,6 +131,21 @@ struct GetCapabilityForEPParams { std::reference_wrapper debug_graph_fn; #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) }; + +auto get_capabilities = [](const IExecutionProvider& ep, + const GraphViewer& graph_viewer, + const IExecutionProvider::IKernelLookup& kernel_lookup) { + auto capabilities = ep.GetCapability(graph_viewer, kernel_lookup); + + // In theory an EP could return an empty capability. Remove those. + capabilities.erase(std::remove_if(capabilities.begin(), capabilities.end(), + [](const std::unique_ptr& capability) { + return !capability || !capability->sub_graph; + }), + capabilities.end()); + + return capabilities; +}; } // namespace static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) { @@ -143,21 +160,6 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) { } #endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) - auto get_capabilities = [](const IExecutionProvider& ep, - const GraphViewer& graph_viewer, - const IExecutionProvider::IKernelLookup& kernel_lookup) { - auto capabilities = ep.GetCapability(graph_viewer, kernel_lookup); - - // In theory an EP could return an empty capability. Remove those. - capabilities.erase(std::remove_if(capabilities.begin(), capabilities.end(), - [](const std::unique_ptr& capability) { - return !capability || !capability->sub_graph; - }), - capabilities.end()); - - return capabilities; - }; - const auto& kernel_registry_mgr = params.kernel_registry_mgr.get(); const auto kernel_registries_for_ep = kernel_registry_mgr.GetKernelRegistriesByProviderType(ep_type); const KernelLookup kernel_lookup{ep_type, @@ -239,6 +241,26 @@ static Status GetCapabilityForEP(const GetCapabilityForEPParams& params) { } #if !defined(ORT_MINIMAL_BUILD) + +// This function queries the capabilities for a given EP, but it does not assign the nodes. +// It also does not perform layout transformation. This will be done during normal partitioning. +static Status GetCapabilityForEPForAotInlining(const GraphViewer& graph_viewer, + const KernelRegistryManager& kernel_registry_mgr, + const IExecutionProvider& current_ep, + std::vector>& capabilities) { + const auto& ep_type = current_ep.Type(); + + const auto kernel_registries_for_ep = kernel_registry_mgr.GetKernelRegistriesByProviderType(ep_type); + const KernelLookup kernel_lookup{ep_type, + kernel_registries_for_ep, + kernel_registry_mgr.GetKernelTypeStrResolver()}; + + // TODO: Provide EP with a capability to look inside the functions. + capabilities = get_capabilities(current_ep, graph_viewer, kernel_lookup); + + return Status::OK(); +} + /** * Check if a node can be placed on a specific provider. * Do nothing if the node is already assigned @@ -518,7 +540,7 @@ static Status InlineNodes(Graph& graph, bool& modified_graph) { // successfully inlined, we re-run the partitioner on the modified graph. // NOTE: Inlining the function will change the nodes in the Graph instance, so we can't do that while iterating // using graph.Nodes(). - std::vector nodes_to_inline; + InlinedVector nodes_to_inline; for (auto& node : graph.Nodes()) { if (node.GetExecutionProviderType().empty() && node.CanBeInlined()) { nodes_to_inline.push_back(&node); @@ -533,6 +555,85 @@ static Status InlineNodes(Graph& graph, bool& modified_graph) { return Status::OK(); } +static Status InlineFunctionsAOTImpl(const ExecutionProviders& execution_providers, + const KernelRegistryManager& kernel_registry_mgr, + Graph& graph, + InlinedHashSet& not_inlined, + size_t& inlined_count) { + // handle testing edge case where optimizers or constant lifting results in graph with no nodes. + // doing it here saves all providers checking for this in GetCapability + if (graph.NumberOfNodes() == 0) { + return Status::OK(); + } + + for (auto& node : graph.Nodes()) { + for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) { + Graph* subgraph = entry.second; + // we pass through the FuncManager from the top level graph + ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers, + kernel_registry_mgr, + *subgraph, + not_inlined, + inlined_count)); + } + } + + // Gather the candidates + InlinedVector inline_candidates; + for (auto& node : graph.Nodes()) { + if (node.CanBeInlined()) { + inline_candidates.push_back(node.Index()); + } + } + + if (inline_candidates.empty()) { + return Status::OK(); + } + + // Find out all the nodes that are already taken + const GraphViewer graph_viewer(graph); + + InlinedHashSet claimed_by_ep; + for (const auto& ep : execution_providers) { + std::vector> capabilities; + ORT_RETURN_IF_ERROR(GetCapabilityForEPForAotInlining(graph_viewer, kernel_registry_mgr, *ep, capabilities)); + for (auto& capability : capabilities) { + const auto& nodes = capability->sub_graph->nodes; + if (nodes.size() == 1) { + // Single node capability. + ORT_IGNORE_RETURN_VALUE(claimed_by_ep.insert(nodes[0])); + } else { + // Make sure none is claimed by other EPs mirroring the logic in PartitionOnnxFormatModelImpl. + if (std::all_of(nodes.cbegin(), nodes.cend(), [&claimed_by_ep](NodeIndex node_index) { + return claimed_by_ep.count(node_index) == 0; + })) { + claimed_by_ep.insert(nodes.cbegin(), nodes.cend()); + } + } + } + } + + // TODO: Insert version check. We need to collect all the versions + // that imported by the model. If the version is not supported by + // the model, we can not inline it. + + for (auto node_index : inline_candidates) { + auto* node = graph.GetNode(node_index); + if (node != nullptr) { + if (claimed_by_ep.count(node_index) == 0) { + ORT_RETURN_IF_ERROR(graph.InlineFunction(*node)); + ++inlined_count; + } else { + // OpType is the same as function name. + auto function_id = function_utils::GetFunctionIdentifier(node->Domain(), node->OpType()); + ORT_IGNORE_RETURN_VALUE(not_inlined.insert(std::move(function_id))); + } + } + } + + return Status::OK(); +} + static Status PartitionOnnxFormatModel(const PartitionParams& partition_params, GraphPartitioner::Mode mode, const ExecutionProviders& execution_providers, KernelRegistryManager& kernel_registry_manager) { @@ -693,6 +794,35 @@ static Status PartitionOrtFormatModel(const PartitionParams& partition_params, return Status::OK(); } +#ifndef ORT_MINIMAL_BUILD + +Status GraphPartitioner::InlineFunctionsAOT(Model& model, + const ExecutionProviders& execution_providers, + const KernelRegistryManager& kernel_registry_manager) const { + auto& graph = model.MainGraph(); + InlinedHashSet not_inlined; + do { + size_t inlined_count = 0; + ORT_RETURN_IF_ERROR(InlineFunctionsAOTImpl(execution_providers, + kernel_registry_manager, + graph, + not_inlined, + inlined_count)); + + if (inlined_count == 0) { + break; + } + + ORT_RETURN_IF_ERROR(graph.Resolve()); + } while (true); + + model.RemoveLocalFunctionsProtos(not_inlined); + + return Status::OK(); +} + +#endif + Status GraphPartitioner::Partition(Graph& graph, FuncManager& func_mgr, const layout_transformation::TransformLayoutFunction& transform_layout_function, Mode mode, diff --git a/onnxruntime/core/framework/graph_partitioner.h b/onnxruntime/core/framework/graph_partitioner.h index 36a27e906c651..c1fa46de9145d 100644 --- a/onnxruntime/core/framework/graph_partitioner.h +++ b/onnxruntime/core/framework/graph_partitioner.h @@ -12,6 +12,7 @@ namespace onnxruntime { class ExecutionProviders; class KernelRegistryManager; +class Model; class GraphPartitioner { public: @@ -33,6 +34,26 @@ class GraphPartitioner { Mode mode = Mode::kNormal, const layout_transformation::DebugGraphFn& debug_graph_fn = {}) const; +#ifndef ORT_MINIMAL_BUILD + /// + // Ahead of Time Function inlining. The main purpose of the function is to inline as many + // functions as possible and delete locally defined functions to reduce the size of the model. + // This would make other optimizations to be more effective. + // + // This function performs GetCapability on the graph and its subgraphs bottom up + // and inlines any functions that are not claimed by any of the execution providers. + // This function does not attempt to run layout transformation, and it does not assign EPs. + // The latter will be done by graph partitioning after Level1 optimizations are done. + /// + /// model instance + /// execution providers considered + /// registry manager + /// + Status InlineFunctionsAOT(Model& model, + const ExecutionProviders& execution_providers, + const KernelRegistryManager& kernel_registry_manager) const; +#endif + private: ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(GraphPartitioner); diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 3a75b29ffe3c7..76c3f8716ff09 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -946,7 +946,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA( OpSchema::Optional) .Input(4, "key_padding_mask", - "Key padding mask with shape (batch_size) or (3 * batch_size + 2) or (batch_size, kv_sequence_length)", + "Key padding mask with shape (batch_size), (3 * batch_size + 2), (batch_size, kv_sequence_length), (batch_size, total_sequence_length), " + "or (batch_size, sequence_length, total_sequence_length)", "M", OpSchema::Optional) .Input(5, @@ -1129,6 +1130,49 @@ ONNX_MS_OPERATOR_SET_SCHEMA( DecoderAttentionTypeAndShapeInference(ctx); })); +constexpr const char* RotaryEmbedding_ver1_doc = R"DOC( +RotaryEmbedding is the implementation of rotary positional embeddings (RoPE). The positions are represented as rotation matrices +that are multiplied to query and key before the inner product of query and key is taken. +)DOC"; +ONNX_MS_OPERATOR_SET_SCHEMA( + RotaryEmbedding, 1, + OpSchema() + .SetDoc(RotaryEmbedding_ver1_doc) + .Attr("scale", + "Custom scale will be used if specified. Default value is 1.0", + AttributeProto::FLOAT, + OPTIONAL_VALUE) + .Attr("interleaved", + "Rotate using interleaved pattern. Default value is 0 (False).", + AttributeProto::INT, + OPTIONAL_VALUE) + .Input(0, + "input", + "3D tensor with shape (batch_size, sequence_length, hidden_size)", + "T") + .Input(1, + "position_ids", + "1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)", + "M") + .Input(2, + "cos_cache", + "2D tensor with shape (max_sequence_length, head_size / 2).", + "T") + .Input(3, + "sin_cache", + "2D tensor with shape (max_sequence_length, head_size / 2).", + "T") + .Output(0, + "output", + "3D tensor with shape (batch_size, sequence_length, hidden_size)", + "T") + .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.") + .TypeConstraint("M", {"tensor(int64)"}, "Constrain input and output types to integer tensors") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 0, 0); + propagateShapeFromInputToOutput(ctx, 0, 0); + })); + constexpr const char* EmbedLayerNormalization_ver1_doc = R"DOC( EmbedLayerNormalization is the fusion of embedding layer in BERT model, with optional mask processing. The embedding layer takes input_ids (word IDs) and segment_ids (sentence IDs) to look up word_embedding, position_embedding, @@ -1500,4 +1544,4 @@ ONNX_MS_OPERATOR_SET_SCHEMA( })); } // namespace contrib -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 5e5eee568fa21..681a728f823da 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -3239,6 +3239,41 @@ Input zero_points is stored as uint8_t. If bits <= 4, two zero points are stored MatmulWithQuantWeightShapeInference(ctx, in_features, out_features); }); + static const char* MatMulBnb4_ver1_doc = R"DOC( +MatMulBnb4 is a MatMul with weight quantized with 4 bits using either FP4 or NF4 data type (https://arxiv.org/pdf/2305.14314.pdf). It does Matrix Multiplication like MatMul (https://github.com/onnx/onnx/blob/main/docs/Operators.md#matmul) with differences: + 1. Input B is a 2D constant Matrix. Its input feature count and output feature count are specified by attribute 'K' and 'N'. + 2. Input B is quantized with 4 bits with quantization data type specified by attribute 'quant_type'. It is transposed, flattened and quantized blockwisely with block size specified by attribute 'block_size'. + And block_size is not an arbitrary number and must be a power of 2 and not smaller than 16, like 16, 32, 64, 128,.. + 3. Input B's quantization constants or scales are specified by input 'absmax'. + +Input B is stored as uint8_t with shape: [(N * K + 1) / 2]. +Input absmax is stored in same type as original type of B(float32, float16) with shape like: [(N * K + block_size - 1) / block_size]. + +)DOC"; + + ONNX_CONTRIB_OPERATOR_SCHEMA(MatMulBnb4) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetDoc(MatMulBnb4_ver1_doc) + .Attr("K", "size of each input feature", AttributeProto::INT) + .Attr("N", "size of each output feature", AttributeProto::INT) + .Attr("block_size", "number of groupsize used for weight quantization. It needs to be a power of 2 and not smaller than 16.", AttributeProto::INT) + .Attr("quant_type", "quantization data type. 0 for FP4, 1 for NF4.", AttributeProto::INT) + .Input(0, "A", "The input tensor, not quantized", "T1") + .Input(1, "B", "1-dimensional quantized data for weight", "T2") + .Input(2, "absmax", "quantization constants", "T1") + .Output(0, "Y", "tensor. The output tensor has the same rank as the input. ", "T1") + .TypeConstraint("T1", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float/half_float tensors.") + .TypeConstraint("T2", {"tensor(uint8)"}, "Constrain quantized weight types to uint8.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + // Type inference + propagateElemTypeFromInputToOutput(ctx, 0, 0); + // Shape inference + int64_t in_features = getAttribute(ctx, "K", -1); + int64_t out_features = getAttribute(ctx, "N", -1); + MatmulWithQuantWeightShapeInference(ctx, in_features, out_features); + }); + #ifdef ENABLE_ATEN ONNX_CONTRIB_OPERATOR_SCHEMA(ATen) .SetDomain(kPytorchAtenDomain) diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index afa5d101bbd8d..afaa380d6ac79 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -95,6 +95,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GatedRelativePositionBia class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RemovePadding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RestorePadding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Rfft); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RotaryEmbedding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SampleOp); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Sampling); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipLayerNormalization); @@ -200,6 +201,7 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); diff --git a/onnxruntime/core/graph/function_utils.cc b/onnxruntime/core/graph/function_utils.cc index 7477f48088a15..7b0a834a7ffc0 100644 --- a/onnxruntime/core/graph/function_utils.cc +++ b/onnxruntime/core/graph/function_utils.cc @@ -373,7 +373,8 @@ class Inliner { // Replace given name with a unique version of the name, and cache the // renaming-binding in current scope. void make_unique(std::string& name) { - auto new_name = prefix_ + name; + auto new_name{prefix_}; + new_name.append("_").append(name); auto& current_scope = rename_scopes_.back(); current_scope[name] = new_name; name = std::move(new_name); @@ -410,7 +411,7 @@ class Inliner { std::string rename_as = actuals.Get(i); if constexpr (isOutput) { if (rename_as.empty()) - rename_as.assign(prefix_).append(formal); + rename_as.assign(prefix_).append("_").append(formal); } current_scope[formal] = rename_as; if (!rename_as.empty()) @@ -420,7 +421,7 @@ class Inliner { std::string& formal = *formals.Mutable(i); std::string rename_as; if constexpr (isOutput) { - rename_as.assign(prefix_).append(formal); + rename_as.assign(prefix_).append("_").append(formal); } current_scope[formal] = rename_as; if (!rename_as.empty()) diff --git a/onnxruntime/core/graph/graph.cc b/onnxruntime/core/graph/graph.cc index 383c1d689d3c3..cab9467501f55 100644 --- a/onnxruntime/core/graph/graph.cc +++ b/onnxruntime/core/graph/graph.cc @@ -860,18 +860,18 @@ Status Node::LoadEdgesFromOrtFormat(const onnxruntime::fbs::NodeEdge& fbs_node_e } #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) -void Node::Init(const std::string& name, - const std::string& op_type, - const std::string& description, - const std::vector& input_args, - const std::vector& output_args, +void Node::Init(std::string_view name, + std::string_view op_type, + std::string_view description, + gsl::span input_args, + gsl::span output_args, const NodeAttributes* attributes, - const std::string& domain) { + std::string_view domain) { name_ = name; op_type_ = op_type; description_ = description; - definitions_.input_defs = input_args; - definitions_.output_defs = output_args; + definitions_.input_defs.assign(input_args.begin(), input_args.end()); + definitions_.output_defs.assign(output_args.begin(), output_args.end()); domain_ = domain; can_be_saved_ = true; priority_ = 0; @@ -1145,7 +1145,8 @@ Graph::Graph(const Model& owning_model, IOnnxRuntimeOpSchemaCollectionPtr schema_registry, const logging::Logger& logger, bool strict_shape_type_inference) - : Graph(owning_model, graph_proto, domain_to_version, ir_version, schema_registry, nullptr, nullptr, logger, strict_shape_type_inference) {} + : Graph(owning_model, graph_proto, domain_to_version, ir_version, + schema_registry, nullptr, nullptr, logger, strict_shape_type_inference) {} Graph::Graph(const Model& owning_model, GraphProto* graph_proto, const std::unordered_map& domain_to_version, Version ir_version, @@ -3261,8 +3262,8 @@ Node& Graph::AddNode(const std::string& name, gsl::span output_args, const NodeAttributes* attributes, const std::string& domain) { - std::vector inputs; - std::vector outputs; + InlinedVector inputs; + InlinedVector outputs; inputs.resize(input_args.size()); outputs.resize(output_args.size()); int i = 0; @@ -4019,69 +4020,100 @@ Node& Graph::FuseSubGraph(const IndexedSubGraph& sub_graph, return fused_node; } +Status Graph::AddConstantProtoAsInitializer(const ONNX_NAMESPACE::NodeProto& node_proto, + std::optional new_name) { + const gsl::not_null tensor{graph_proto_->add_initializer()}; + ORT_RETURN_IF_ERROR(utils::ConstantNodeProtoToTensorProto(node_proto, ModelPath(), *tensor, node_proto.output(0))); + + if (new_name.has_value()) { + tensor->set_name(std::string(new_name.value())); + } + + auto insert_result = name_to_initial_tensor_.emplace(tensor->name(), tensor); + ORT_ENFORCE(insert_result.second, "Constant node name: ", tensor->name(), + " conflicts with graph initializer. Check that the node names have been made unique."); + if (GetNodeArg(tensor->name()) == nullptr) { + TypeProto t{TypeProtoFromTensorProto(*tensor)}; + ORT_IGNORE_RETURN_VALUE(GetOrCreateNodeArg(tensor->name(), &t)); + } + +#if !defined(DISABLE_SPARSE_TENSORS) + if (node_proto.attribute(0).type() == AttributeProto_AttributeType_SPARSE_TENSOR) { + ORT_IGNORE_RETURN_VALUE(sparse_tensor_names_.emplace(tensor->name())); + } +#endif + + return Status::OK(); +} + +Status Graph::InlineFunctionProto(const ONNX_NAMESPACE::FunctionProto& func_to_inline) { + auto to_node_arg = [this](const std::string& name) { + return &this->GetOrCreateNodeArg(name, nullptr); + }; + + // Process constant nodes first and create NodeArg for these as they become initializers + // It is important for the initializers to have NodeArg created, first they are needed + // if the initializer is unused and removed, second if the node depends on the initializer, + // we can have Type attached to it. + InlinedVector non_constant_nodes; + non_constant_nodes.reserve(func_to_inline.node_size()); + for (const auto& inlined_node : func_to_inline.node()) { + if (inlined_node.op_type() == kConstant) { + // Copy constant nodes _value to name_to_initial_tensor_ + ORT_RETURN_IF_ERROR(AddConstantProtoAsInitializer(inlined_node, std::nullopt)); + } else { + non_constant_nodes.push_back(&inlined_node); + } + } + + for (const auto* inlined_node : non_constant_nodes) { + InlinedVector inputs; + InlinedVector outputs; + + for (const auto& tensor_name : inlined_node->input()) + inputs.push_back(to_node_arg(tensor_name)); + + for (const auto& tensor_name : inlined_node->output()) + outputs.push_back(to_node_arg(tensor_name)); + + onnxruntime::NodeAttributes new_attr_map; + new_attr_map.reserve(inlined_node->attribute_size()); + for (const auto& node_attr : inlined_node->attribute()) { + new_attr_map.insert_or_assign(node_attr.name(), node_attr); + } + ORT_IGNORE_RETURN_VALUE(AddNode(inlined_node->name(), inlined_node->op_type(), + inlined_node->doc_string(), inputs, outputs, + &new_attr_map, inlined_node->domain())); + } + + return Status::OK(); +} + Status Graph::InlineFunction(Node& callnode) { - const auto& model_path = ModelPath(); - auto output_edges = callnode.GetRelationships().output_edges; + // Remove output edges. Requirement for RemoveNode() below. + auto output_edges = callnode.GetRelationships().output_edges; // copy so RemoveEdge doesn't invalidate iterator for (const auto& output_edge : output_edges) { RemoveEdge(callnode.Index(), output_edge.GetNode().Index(), output_edge.GetSrcArgIndex(), output_edge.GetDstArgIndex()); } // create a uniq_identifier to append to every node name and intermediate input\outputs // to make sure there are no unintended duplicates - std::stringstream ss; - ss << "_inline_" << callnode.OpType(); - auto uniq_identifier = GenerateNodeName(ss.str()); + std::string base_uniq_identifier{"_inlfunc_"}; + base_uniq_identifier.append(callnode.OpType()); + const auto uniq_identifier = GenerateNodeName(base_uniq_identifier); + // Replace a (function-call) node by an inlined graph. if (!callnode.GetFunctionBody()) { // This is the normal use-case: inlining a FunctionProto (representing // a model-local function or a schema-defined function). - FunctionProto inlined_fp; + ONNX_NAMESPACE::FunctionProto inlined_fp; ORT_ENFORCE(callnode.TryGetFunctionProto(inlined_fp), "Node has no function body and cannot be inlined."); - function_utils::Specialize(inlined_fp, callnode, uniq_identifier); - auto to_node_arg = [this](const std::string& name) { - return &this->GetOrCreateNodeArg(name, nullptr); - }; - - // Process constant nodes first and create NodeArg for these as they become initializers - // It is important for the initializers to have NodeArg created, first they are needed - // if the initializer is unused and removed, second if the node depends on the initializer, - // we can have Type attached to it. - for (const auto& inlined_node : inlined_fp.node()) { - if (inlined_node.op_type() == kConstant) { - // Copy constant nodes _value to name_to_initial_tensor_ - const gsl::not_null tensor{graph_proto_->add_initializer()}; - ORT_RETURN_IF_ERROR(utils::ConstantNodeProtoToTensorProto(inlined_node, model_path, *tensor, inlined_node.output(0))); - auto insert_result = name_to_initial_tensor_.emplace(tensor->name(), tensor); - ORT_ENFORCE(insert_result.second, "Constant node name: ", tensor->name(), " in inlined function: ", - inlined_fp.name(), " conflicts with graph initializer. Check Specializing code."); - TypeProto t{TypeProtoFromTensorProto(*tensor)}; - ORT_IGNORE_RETURN_VALUE(GetOrCreateNodeArg(tensor->name(), &t)); - } - } - - for (const auto& inlined_node : inlined_fp.node()) { - if (inlined_node.op_type() != kConstant) { - InlinedVector inputs; - InlinedVector outputs; - - for (const auto& tensor_name : inlined_node.input()) - inputs.push_back(to_node_arg(tensor_name)); - - for (const auto& tensor_name : inlined_node.output()) - outputs.push_back(to_node_arg(tensor_name)); - - onnxruntime::NodeAttributes new_attr_map; - new_attr_map.reserve(inlined_node.attribute_size()); - for (const auto& node_attr : inlined_node.attribute()) { - onnx::AttributeProto attr_copy = node_attr; - new_attr_map[node_attr.name()] = std::move(attr_copy); - } - AddNode(inlined_node.name(), inlined_node.op_type(), - inlined_node.doc_string(), inputs, outputs, &new_attr_map, inlined_node.domain()); - } - } + // Make all the names unique and resolve nested graphs inputs to the outer scope. + function_utils::Specialize(inlined_fp, callnode, uniq_identifier); + // In this case, global Resolve() will take care of everything. + ORT_RETURN_IF_ERROR(InlineFunctionProto(inlined_fp)); } else { // Uncommon scenario. Inlining a node representing a fused sub-graph. // TODO: Unclear that this feature is needed. Can this be removed? @@ -4115,15 +4147,7 @@ Status Graph::InlineFunction(Node& callnode) { // Copy constant nodes _value to name_to_initial_tensor_ ONNX_NAMESPACE::NodeProto subgraph_node_proto{}; subgraph_node.ToProto(subgraph_node_proto); - const gsl::not_null tensor{graph_proto_->add_initializer()}; - ORT_RETURN_IF_ERROR(utils::ConstantNodeProtoToTensorProto(subgraph_node_proto, model_path, *tensor, subgraph_node_proto.output(0))); - auto insert_result = name_to_initial_tensor_.emplace(tensor->name(), tensor); - ORT_ENFORCE(insert_result.second, "Constant node name: ", tensor->name(), " in inlined subgraph: ", - subgraph.Name(), " conflicts with graph initializer. Check Specializing code."); - if (GetNodeArg(tensor->name()) == nullptr) { - TypeProto t{TypeProtoFromTensorProto(*tensor)}; - ORT_IGNORE_RETURN_VALUE(GetOrCreateNodeArg(tensor->name(), &t)); - } + ORT_RETURN_IF_ERROR(AddConstantProtoAsInitializer(subgraph_node_proto, std::nullopt)); } } diff --git a/onnxruntime/core/graph/model.cc b/onnxruntime/core/graph/model.cc index 05747a7e5124d..076332a65c8f2 100644 --- a/onnxruntime/core/graph/model.cc +++ b/onnxruntime/core/graph/model.cc @@ -41,6 +41,35 @@ namespace onnxruntime { #if !defined(ORT_MINIMAL_BUILD) +void Model::RemoveLocalFunctionsProtos(const InlinedHashSet& retained) { + auto* local_functions = model_proto_.mutable_functions(); + if (retained.empty()) { + model_local_function_templates_maps_.clear(); + model_local_functions_.clear(); + local_functions->erase(local_functions->begin(), local_functions->end()); + } else { + const auto retained_end = retained.cend(); + for (auto it = model_local_functions_.begin(); + it != model_local_functions_.end();) { + if (retained.find(it->first) == retained_end) { + model_local_function_templates_maps_.erase(it->first); + it = model_local_functions_.erase(it); + } else { + ++it; + } + } + + for (auto it = local_functions->begin(); it != local_functions->end();) { + const auto function_id = function_utils::GetFunctionIdentifier(it->domain(), it->name()); + if (retained.find(function_id) == retained_end) { + it = local_functions->erase(it); + } else { + ++it; + } + } + } +} + static constexpr int DEFAULT_PROTOBUF_BLOCK_SIZE = 4 * 1024 * 1024; Model::Model(const std::string& graph_name, @@ -95,10 +124,10 @@ Model::Model(const std::string& graph_name, for (auto& func : model_local_functions) { auto func_ptr = model_proto_.add_functions(); func_ptr->CopyFrom(func); - model_local_functions_.insert_or_assign(function_utils::GetFunctionIdentifier(func_ptr->domain(), func_ptr->name()), func_ptr); + model_local_functions_.insert_or_assign(function_utils::GetFunctionIdentifier(func_ptr->domain(), func_ptr->name()), + func_ptr); } - model_local_function_templates_.reserve(model_proto_.functions().size()); model_local_function_templates_maps_.reserve(model_proto_.functions().size()); for (auto& func : model_proto_.functions()) { auto func_schema_ptr = function_utils::CreateSchema(func.domain(), @@ -111,8 +140,8 @@ Model::Model(const std::string& graph_name, auto func_template_ptr = std::make_unique(); func_template_ptr->op_schema_ = std::move(func_schema_ptr); func_template_ptr->onnx_func_proto_ = &func; - model_local_function_templates_.push_back(std::move(func_template_ptr)); - model_local_function_templates_maps_[function_utils::GetFunctionIdentifier(func.domain(), func.name())] = model_local_function_templates_.back().get(); + model_local_function_templates_maps_.insert_or_assign(function_utils::GetFunctionIdentifier(func.domain(), func.name()), + std::move(func_template_ptr)); } // need to call private ctor so can't use make_shared @@ -220,7 +249,6 @@ Model::Model(ModelProto&& model_proto, const PathString& model_path, model_local_functions_.insert_or_assign(function_utils::GetFunctionIdentifier(func.domain(), func.name()), &func); } - model_local_function_templates_.reserve(model_proto_.functions().size()); model_local_function_templates_maps_.reserve(model_proto_.functions().size()); for (auto& func : model_proto_.functions()) { auto func_schema_ptr = function_utils::CreateSchema(func.domain(), @@ -233,9 +261,7 @@ Model::Model(ModelProto&& model_proto, const PathString& model_path, auto func_template_ptr = std::make_unique(); func_template_ptr->op_schema_ = std::move(func_schema_ptr); func_template_ptr->onnx_func_proto_ = &func; - model_local_function_templates_.push_back(std::move(func_template_ptr)); - model_local_function_templates_maps_[function_utils::GetFunctionIdentifier(func.domain(), func.name())] = - model_local_function_templates_.back().get(); + model_local_function_templates_maps_.insert_or_assign(function_utils::GetFunctionIdentifier(func.domain(), func.name()), std::move(func_template_ptr)); } // create instance. need to call private ctor so can't use make_unique @@ -244,7 +270,7 @@ Model::Model(ModelProto&& model_proto, const PathString& model_path, logger, options.strict_shape_type_inference)); } -const InlinedHashMap& Model::GetModelLocalFunctionTemplates() const { +const NodeHashMap>& Model::GetModelLocalFunctionTemplates() const { return model_local_function_templates_maps_; } @@ -332,7 +358,7 @@ const Graph& Model::MainGraph() const noexcept { } #if !defined(ORT_MINIMAL_BUILD) -ModelProto Model::ToProto() { +ModelProto Model::ToProto() const { // We want to return back the original proto // To that end invoke const overload of ToGraphProto() // that returns by value and, therefore, allows us to filter @@ -346,7 +372,7 @@ ModelProto Model::ToProto() { ModelProto Model::ToGraphProtoWithExternalInitializers(const std::string& external_file_name, const PathString& file_path, - size_t initializer_size_threshold) { + size_t initializer_size_threshold) const { ModelProto result(model_proto_); const auto& graph = *graph_; *(result.mutable_graph()) = graph.ToGraphProtoWithExternalInitializers(external_file_name, diff --git a/onnxruntime/core/graph/model.h b/onnxruntime/core/graph/model.h index 6bdb68dd734f0..4ce6660b794bc 100644 --- a/onnxruntime/core/graph/model.h +++ b/onnxruntime/core/graph/model.h @@ -139,7 +139,7 @@ class Model { // Returns empty string if not specified. const std::string GraphDocString() const; - const InlinedHashMap& GetModelLocalFunctionTemplates() const; + const NodeHashMap>& GetModelLocalFunctionTemplates() const; #else // Get model's IR version. @@ -182,14 +182,14 @@ class Model { #if !defined(ORT_MINIMAL_BUILD) // Get model's serialization proto data. - ONNX_NAMESPACE::ModelProto ToProto(); + ONNX_NAMESPACE::ModelProto ToProto() const; // Get model's serialization proto data. // Save initializer larger than the given threshold (in bytes) into an external binary file // with the given name. This function is useful to avoid hitting the size limit of protobuf files. ONNX_NAMESPACE::ModelProto ToGraphProtoWithExternalInitializers(const std::string& external_file_name, const PathString& file_path, - size_t initializer_size_threshold); + size_t initializer_size_threshold) const; #ifdef _WIN32 static common::Status Save(Model& model, const std::wstring& file_path); @@ -291,6 +291,13 @@ class Model { common::Status SaveToOrtFormat(flatbuffers::FlatBufferBuilder& builder, flatbuffers::Offset& model) const; + /// + /// Frees local function definitions in the model, excluding those in the `retained` set. + /// Called from GraphPartitioner::InlineFunctionsAOT. + /// + /// contains function IDs that should not be removed. + void RemoveLocalFunctionsProtos(const InlinedHashSet& retained); + #endif // !defined(ORT_MINIMAL_BUILD) static common::Status LoadFromOrtFormat(const onnxruntime::fbs::Model& fbs_model, @@ -312,14 +319,12 @@ class Model { // this map will be used for the local functions' schema's type/shape inference. // This container is used by ONNX code and must be an std::unordered_map. std::unordered_map model_local_functions_; - // this is the container that host the generated schemas for model local functions. - // the generated schemare will be used for graph resolving and type/shape inference. - // those schemas' type/shape inference will reference to the model_local_functions_ as context, - // so need to keep them with same lifetime. - InlinedVector> model_local_function_templates_; // this is the map from function id to the local function template. // this map will be used by graph to instantiate the function body. - InlinedHashMap model_local_function_templates_maps_; + // Defined as a node based map so the memory is released when not all of the functions + // are inlined and removed. + NodeHashMap> model_local_function_templates_maps_; + #else // properties that would normally come from ModelProto std::string producer_version_; diff --git a/onnxruntime/core/mlas/lib/aarch64/QgemmS8S8KernelSmmla.S b/onnxruntime/core/mlas/lib/aarch64/QgemmS8S8KernelSmmla.S new file mode 100644 index 0000000000000..e18846c89030e --- /dev/null +++ b/onnxruntime/core/mlas/lib/aarch64/QgemmS8S8KernelSmmla.S @@ -0,0 +1,922 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + QgemmS8S8KernelSmmla.s + +Abstract: + + This module implements the kernels for the Int8 precision matrix/matrix + multiply operation (QGEMM). + +--*/ + +#include "asmmacro.h" + + .text + +// +// Stack frame layout for the smmla kernel. d8-d15, x19-x30 need save +// + .equ .LMlasQgemmKernel_backup_x19_x20, 0 + .equ .LMlasQgemmKernel_backup_x21_x22, 16 + .equ .LMlasQgemmKernel_backup_x23_x24, 32 + .equ .LMlasQgemmKernel_backup_x25_x26, 48 + .equ .LMlasQgemmKernel_backup_x27_x28, 64 + .equ .LMlasQgemmKernel_backup_d8_d9, 80 + .equ .LMlasQgemmKernel_backup_d10_d11, 96 + .equ .LMlasQgemmKernel_backup_d12_d13, 112 + .equ .LMlasQgemmKernel_backup_d14_d15, 128 + .equ .LMlasQgemmKernel_SavedRegisters, 144 + .equ .LMlasQgemmKernel_SavedRegisters_Neg, -144 + + +// +// Init Row Accumulators +// +// Generates the code to initialize the accumulators for a single row of the output +// block. +// +// +// Accumulators are initialized to ZeroPointB * RowSum + ColumnSum +// x7 for RowSumsBuffer pointer +// x10 for ColumnSumBuffer pointer +// x11 for ZeroPointB buffer pointer +// +// v12~v13 for RowSums values +// v14~v15 for ColumnSums values +// v0~v3 for ZeroPointB values +// + .macro InitRowAccumulators Columns, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, RowSumReg + + mul v7.4s, v\RowSumReg\().4s, v8.4s + mov v\Vec1Reg\().16b, v7.16b + add v\Vec1Reg\().4s, v\Vec1Reg\().4s, v0.4s +.if \Columns\() > 2 + mul v7.4s, v\RowSumReg\().4s, v9.4s + mov v\Vec2Reg\().16b, v7.16b + add v\Vec2Reg\().4s, v\Vec2Reg\().4s, v1.4s +.endif +.if \Columns\() > 4 + mul v7.4s, v\RowSumReg\().4s, v10.4s + mov v\Vec3Reg\().16b, v7.16b + add v\Vec3Reg\().4s, v\Vec3Reg\().4s, v2.4s +.endif +.if \Columns\() > 6 + mul v7.4s, v\RowSumReg\().4s, v11.4s + mov v\Vec4Reg\().16b, v7.16b + add v\Vec4Reg\().4s, v\Vec4Reg\().4s, v3.4s +.endif + + .endm + + +// +// InitBlockAccumulators +// +// Generates the code to initialize the accumulators for 8x8 output +// block. +// + .macro InitBlockAccumulators Mode, Columns, Rows + + ld1 {v14.4s},[x10],#16 // load ColumnSumBuffer[0] +.if \Columns\() > 4 + ld1 {v15.4s},[x10],#16 // load ColumnSumBuffer[4] +.endif + // v4~v7 will be set to matrixB after this, so, they can used now + dup v4.4s,v14.s[0] // broadcast column + dup v5.4s,v14.s[1] + dup v6.4s,v14.s[2] + dup v7.4s,v14.s[3] + + zip1 v0.4s, v4.4s, v5.4s + zip2 v1.4s, v6.4s, v7.4s +.if \Columns\() > 4 + dup v4.4s,v15.s[0] // broadcast column + dup v5.4s,v15.s[1] + dup v6.4s,v15.s[2] + dup v7.4s,v15.s[3] + + zip1 v2.4s, v4.4s, v5.4s + zip2 v3.4s, v6.4s, v7.4s +.endif + + // v8~v11 will anyway get set in MatrixA loading, so they are free to use now + movi v8.4s, #1 + movi v9.4s, #1 + movi v10.4s, #1 + movi v11.4s, #1 + + cbz x11,.L\Mode\().InitBlock\Columns\().x\Rows\().SkipScaleByZeroPointB + + ld1 {v4.4s},[x11],#16 // load ZeroPointB[0] + ld1 {v5.4s},[x11],#16 // load ZeroPointB[4] + + dup v6.4s, v4.s[0] + dup v7.4s, v4.s[1] + zip1 v8.4s, v6.4s, v7.4s + + dup v6.4s, v4.s[2] + dup v7.4s, v4.s[3] + zip1 v9.4s, v6.4s, v7.4s + + dup v6.4s, v5.s[0] + dup v7.4s, v5.s[1] + zip1 v10.4s, v6.4s, v7.4s + + dup v6.4s, v5.s[2] + dup v7.4s, v5.s[3] + zip1 v11.4s, v6.4s, v7.4s + +.L\Mode\().InitBlock\Columns\().x\Rows\().SkipScaleByZeroPointB: + dup v4.4s, v12.s[0] //boardcast RowSums + dup v5.4s, v12.s[1] + + uzp1 v6.2d, v4.2d, v5.2d + + InitRowAccumulators \Columns\(),16,17,18,19,6 +.if \Rows\() > 2 + dup v4.4s, v12.s[2] //boardcast RowSums + dup v5.4s, v12.s[3] + + uzp1 v6.2d, v4.2d, v5.2d + + InitRowAccumulators \Columns\(),20,21,22,23,6 +.endif +.if \Rows\() > 4 + dup v4.4s,v13.s[0] // broadcast row sums + dup v5.4s,v13.s[1] + + uzp1 v6.2d, v4.2d, v5.2d + + InitRowAccumulators \Columns\(),24,25,26,27,6 +.endif +.if \Rows\() > 6 + dup v4.4s,v13.s[2] // broadcast row sums + dup v5.4s,v13.s[3] + + uzp1 v6.2d, v4.2d, v5.2d + InitRowAccumulators \Columns\(),28,29,30,31,6 +.endif + + .endm + + +// LoadPackedMatrixABy16Elements +// +// Generates the code to load 16 elements from matrix A. +// + .macro LoadPackedMatrixABy16Elements Rows +.if \Rows\() == 1 + ldr q8,[x0],#8 +.else + ldr q8,[x0],#16 + +.if \Rows\() > 2 + ldr q9,[x0],#16 +.endif + +.if \Rows\() > 4 + ldr q10,[x0],#16 +.endif + +.if \Rows\() > 6 + ldr q11,[x0],#16 +.endif +.endif + .endm + + +// +// MultiplyAccumulateRow +// +// Generates the code to multiply and accumulate a single row of the output +// block. +// + + .macro MultiplyAccumulateRow Columns, MatrixAReg, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg + + smmla v\Vec1Reg\().4s, \MatrixAReg\().16b, v4.16b +.if \Columns\() > 2 + smmla v\Vec2Reg\().4s, \MatrixAReg\().16b, v5.16b +.endif +.if \Columns\() > 4 + smmla v\Vec3Reg\().4s, \MatrixAReg\().16b, v6.16b +.endif +.if \Columns\() > 6 + smmla v\Vec4Reg\().4s, \MatrixAReg\().16b, v7.16b +.endif + + .endm + +// +// MultiplyAccumulateBlock +// +// Generates the code to multiply and accumulate into the output block. +// + + .macro MultiplyAccumulateBlock Columns, Rows + + MultiplyAccumulateRow \Columns\(),v8,16,17,18,19 +.if \Rows\() > 2 + MultiplyAccumulateRow \Columns\(),v9,20,21,22,23 +.endif +.if \Rows\() > 4 + MultiplyAccumulateRow \Columns\(),v10,24,25,26,27 +.endif +.if \Rows\() > 6 + MultiplyAccumulateRow \Columns\(),v11,28,29,30,31 +.endif + + .endm + +// +// ComputeBlockLoop +// +// Generates the code to loop over K entries of the input matrices to produce +// the output block. +// + + .macro ComputeBlockLoop Mode, Columns, Rows + + InitBlockAccumulators \Mode\(), \Columns\(),\Rows\() + + sub x9,x3,#1 // block count to process + tbnz x9,#63,.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks + +.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop: + + LoadPackedMatrixABy16Elements \Rows\() + ld1 {v4.16b - v7.16b}, [x1], #64 + MultiplyAccumulateBlock \Columns\(),\Rows\() + + sub x9,x9,#1 + tbz x9,#63,.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop +.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks: + add x9,x9,#1 // correct for over-subtract above + cbz x9,.L\Mode\().Output\Columns\().x\Rows\().Block + +.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4PaddedLoop: + LoadPackedMatrixABy16Elements \Rows\() + ld1 {v4.16b - v7.16b}, [x1], #64 + MultiplyAccumulateBlock \Columns\(),\Rows\() + +.L\Mode\().Output\Columns\().x\Rows\().Block: + + .endm + + +// +// OutputRow2Element +// OutputRow4Element +// OutputRow6Element +// OutputRow8Element +// OutputRow10Element +// OutputRow12Element +// OutputRow14Element +// OutputRow16Element +// +// Generates the code to store elements to the output block. +// + + .macro OutputRow2Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr s8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr s9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + mov v8.S[2], v9.S[0] + add v8.4s,v8.4s,v\Vec1Reg\().4s + + mov w27, v8.S[0] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov w27, v8.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + mov w27, v\Vec1Reg\().S[0] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov w27, v\Vec1Reg\().S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.endif + + .endm + + + .macro OutputRow4Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr d8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr d9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + + mov v8.D[1], v9.D[0] + + add v8.4s,v8.4s,v\Vec1Reg\().4s + + mov x27, v8.D[0] + mov x28, v8.D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.else + mov x27, v\Vec1Reg\().D[0] + mov x28, v\Vec1Reg\().D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.endif + + .endm + + + .macro OutputRow6Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr d8,[\AddrReg1\()],#8 + ldr w28,[\AddrReg1\()],#-8 + mov v8.S[2], w28 +.if \last_row\() == 0 + ldr d9,[\AddrReg2\()],#8 + ldr w27,[\AddrReg2\()],#-8 + mov v9.S[2], w27 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + mov x27, v8.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v8.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov x27, v9.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v9.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + mov x27, v4.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v4.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov x27, v5.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v5.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.endif + + .endm + + + .macro OutputRow8Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + +.endif + + .endm + + + .macro OutputRow10Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr w28, [\AddrReg1\()],#-16 + +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr w27,[\AddrReg2\()],#-16 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + mov v8.S[0], w28 + mov v8.S[2], w27 + + add v8.4s,v8.4s,v\Vec3Reg\().4s + + mov w27, v8.S[0] + mov w28, v8.S[2] + + str w27, [\AddrReg1\()],#4 +.if \last_row\() == 0 + str w28, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + mov w27, v\Vec3Reg\().S[0] + mov w28, v\Vec3Reg\().S[2] + + str w27, [\AddrReg1\()],#4 +.if \last_row\() == 0 + str w28, [\AddrReg2\()],#4 +.endif +.endif + +.endm + + + .macro OutputRow12Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr d10,[\AddrReg1\()],#-16 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr d11,[\AddrReg2\()],#-16 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + mov v11.D[0],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + + mov v10.D[1], v11.D[0] + + add v10.4s,v10.4s,v\Vec3Reg\().4s + + mov x27, v10.D[0] + mov x28, v10.D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + mov x27, v\Vec3Reg\().D[0] + mov x28, v\Vec3Reg\().D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif +.endif + + .endm + + .macro OutputRow14Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr d10,[\AddrReg1\()],#8 + ldr w28, [\AddrReg1\()],#-24 + mov v10.S[2], w28 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr d11,[\AddrReg2\()],#8 + ldr w27,[\AddrReg2\()],#-24 + mov v11.S[2], w27 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + + mov v11.D[0],x27 + mov v11.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + add v10.4s,v10.4s,v6.4s + add v11.4s,v11.4s,v7.4s + + str q8,[\AddrReg1\()],#16 + + mov x27, v10.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v10.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 + mov x27, v11.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v11.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + str q4,[\AddrReg1\()],#16 + mov x27, v6.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v6.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 + mov x27, v7.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v7.S[2] + str w27, [\AddrReg2\()],#4 +.endif +.endif + + .endm + + + .macro OutputRow16Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldp q8,q10,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldp q9,q11,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + + mov v11.D[0],x27 + mov v11.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + add v10.4s,v10.4s,v6.4s + add v11.4s,v11.4s,v7.4s + + stp q8,q10,[\AddrReg1\()],#32 +.if \last_row\() == 0 + stp q9,q11,[\AddrReg2\()],#32 +.endif +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + stp q4,q6,[\AddrReg1\()],#32 +.if \last_row\() == 0 + stp q5,q7,[\AddrReg2\()],#32 +.endif +.endif + + .endm + +// +// OutputBlock +// +// Generates the code to store the output block. +// + + .macro OutputBlock Mode, Columns, Rows + + OutputRow\Columns\()Element \Mode\(),x2,x13,16,17,18,19,(\Rows\() == 1) + +.if \Rows\() > 2 + OutputRow\Columns\()Element \Mode\(),x14,x15,20,21,22,23,(\Rows\() == 3) +.endif + +.if \Rows\() > 4 + OutputRow\Columns\()Element \Mode\(),x16,x17,24,25,26,27,(\Rows\() == 5) +.endif + +.if \Rows\() > 6 + OutputRow\Columns\()Element \Mode\(),x18,x19,28,29,30,31,(\Rows\() == 7) +.endif + + .endm +// +// ProcessRows +// +// Generates the code to process a compute and store the output block for a +// fixed number of rows. +// + + .macro ProcessRows Mode, Rows + mov x4,#\Rows\() // return number of rows handled + cmp x5,#6 + ble .L\Mode\().ProcessNextColumnLoop6x\Rows\() + +.L\Mode\().ProcessNextColumnLoop8x\Rows\(): + ComputeBlockLoop \Mode\(),8,\Rows\() + + sub x5,x5,#8 + cmp x5,#0 + blt .L\Mode\().Output14ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),16,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#6 + bgt .L\Mode\().ProcessNextColumnLoop8x\Rows\() + cbz x5,.L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop6x\Rows\(): + + cmp x5,#4 + ble .L\Mode\().ProcessNextColumnLoop4x\Rows\() + ComputeBlockLoop \Mode\(),6,\Rows\() + sub x5,x5,#6 + cmp x5,#0 + blt .L\Mode\().Output10ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),12,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#4 + bgt .L\Mode\().ProcessNextColumnLoop6x\Rows\() + b .L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop4x\Rows\(): + cmp x5,#2 + ble .L\Mode\().ProcessNextColumnLoop2x\Rows\() + ComputeBlockLoop \Mode\(),4,\Rows\() + sub x5,x5,#4 + cmp x5,#0 + blt .L\Mode\().Output6ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),8,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#2 + bgt .L\Mode\().ProcessNextColumnLoop4x\Rows\() + b .L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop2x\Rows\(): + ComputeBlockLoop \Mode\(),2,\Rows\() + sub x5,x5,#2 + cmp x5,#0 + blt .L\Mode\().Output2ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),4,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#2 + b .L\Mode\().ExitKernel + +.L\Mode\().Output14ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),14,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output10ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),10,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output6ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),6,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output2ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),2,\Rows\() + b .L\Mode\().ExitKernel + + .endm + + +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. + +Arguments: + + A (x0) - Supplies the address of matrix A. The matrix data has been packed + using MlasGemmQuantCopyPackA. + + B (x1) - Supplies the address of matrix B. The matrix data has been packed + using MlasGemmQuantCopyPackB. + + C (x2) - Supplies the address of matrix C. + + PackedCountK (x3) - Supplies the number of packed columns from matrix A and + the number of packed rows from matrix B to iterate over. + + CountM (x4) - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. + + CountN (x5) - Supplies the number of columns from matrix B and matrix C to + iterate over. + + ldc (x6) - Supplies the first dimension of matrix C. + + RowSumBuffer (x7) - Supplies the sum of each row from matrix A. These values + have been pre-scaled by the zero point offset of matrix B if the offset + is per-tensor (ZeroPointB is nullptr). Otherwise, these values must be + scaled by the per-column zero point offsets of matrix B. These values are + accumulated into every row of matrix C. + + ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied + by the zero point offset of matrix A. These values are accumulated into + every column of matrix C. + + ZeroPointB - Optionally supplies the per-column zero point offsets of matrix + B, else nullptr if the matrix B is using per-tensor quantization. + +Return Value: + + Returns the number of rows handled. + +--*/ + + .macro QgemmS8S8KernelSmmlaFunction Mode + + FUNCTION_ENTRY MlasGemmS8S8KernelSmmla\Mode\() + + ldr x10,[sp, #0] + ldr x11,[sp,#8] + + stp x19, x20, [sp, #.LMlasQgemmKernel_SavedRegisters_Neg]! + stp x21, x22, [sp, #.LMlasQgemmKernel_backup_x21_x22] + stp x23, x24, [sp, #.LMlasQgemmKernel_backup_x23_x24] + stp x25, x26, [sp, #.LMlasQgemmKernel_backup_x25_x26] + stp x27, x28, [sp, #.LMlasQgemmKernel_backup_x27_x28] + stp d8, d9, [sp, #.LMlasQgemmKernel_backup_d8_d9] + stp d10, d11, [sp, #.LMlasQgemmKernel_backup_d10_d11] + stp d12, d13, [sp, #.LMlasQgemmKernel_backup_d12_d13] + stp d14, d15, [sp, #.LMlasQgemmKernel_backup_d14_d15] + + add x13,x2,x6,lsl #2 // compute matrix C plus 1 row + add x14,x13,x6,lsl #2 // compute matrix C plus 2 rows + add x15,x14,x6,lsl #2 // compute matrix C plus 3 rows + add x16,x15,x6,lsl #2 // compute matrix C plus 4 rows + add x17,x16,x6,lsl #2 // compute matrix C plus 5 rows + add x18,x17,x6,lsl #2 // compute matrix C plus 6 rows + add x19,x18,x6,lsl #2 // compute matrix C plus 7 rows + + mov x8,x0 // save matrix A + +// +// Process 8 rows of the matrices. +// + ld1 {v12.4s},[x7],#16 // load row sum 1 ~ 4 + cmp x4,#8 + blt .L\Mode\().ProcessCountMLessThan8 + ld1 {v13.4s},[x7],#16 // load row sum 5 ~ 8 + ProcessRows \Mode\(),8 + +// +// Restore non-volatile registers and return. +// + +.L\Mode\().ExitKernel: + mov x0,x4 + + ldp d14, d15, [sp, #.LMlasQgemmKernel_backup_d14_d15] + ldp d12, d13, [sp, #.LMlasQgemmKernel_backup_d12_d13] + ldp d10, d11, [sp, #.LMlasQgemmKernel_backup_d10_d11] + ldp d8, d9, [sp, #.LMlasQgemmKernel_backup_d8_d9] + ldp x27, x28, [sp, #.LMlasQgemmKernel_backup_x27_x28] + ldp x25, x26, [sp, #.LMlasQgemmKernel_backup_x25_x26] + ldp x23, x24, [sp, #.LMlasQgemmKernel_backup_x23_x24] + ldp x21, x22, [sp, #.LMlasQgemmKernel_backup_x21_x22] + ldp x19, x20, [sp], #.LMlasQgemmKernel_SavedRegisters + + ret + +// +// Process 4 rows of the matrix. +// + +.L\Mode\().ProcessCountMLessThan8: + cmp x4,#4 + blt .L\Mode\().ProcessCountMLessThan4 + ProcessRows \Mode\(),4 + b .L\Mode\().ExitKernel + +// +// Process 2 row of the matrix. +// + +.L\Mode\().ProcessCountMLessThan4: + cmp x4,#2 + blt .L\Mode\().ProcessCountMLessThan2 + + ProcessRows \Mode\(),2 + b .L\Mode\().ExitKernel + + +// +// Process the last row of the matrix. +// + +.L\Mode\().ProcessCountMLessThan2: + ProcessRows \Mode\(),1 + b .L\Mode\().ExitKernel + + + .endm + + QgemmS8S8KernelSmmlaFunction Zero + QgemmS8S8KernelSmmlaFunction Add + + .end diff --git a/onnxruntime/core/mlas/lib/aarch64/QgemmU8X8KernelUmmla.S b/onnxruntime/core/mlas/lib/aarch64/QgemmU8X8KernelUmmla.S new file mode 100644 index 0000000000000..baf6e21e6ff06 --- /dev/null +++ b/onnxruntime/core/mlas/lib/aarch64/QgemmU8X8KernelUmmla.S @@ -0,0 +1,922 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + QgemmU8X8KernelUmmla.s + +Abstract: + + This module implements the kernels for the Int8 precision matrix/matrix + multiply operation (QGEMM). + +--*/ + +#include "asmmacro.h" + + .text + +// +// Stack frame layout for the ummla kernel. d8-d15, x19-x30 need save +// + .equ .LMlasQgemmKernel_backup_x19_x20, 0 + .equ .LMlasQgemmKernel_backup_x21_x22, 16 + .equ .LMlasQgemmKernel_backup_x23_x24, 32 + .equ .LMlasQgemmKernel_backup_x25_x26, 48 + .equ .LMlasQgemmKernel_backup_x27_x28, 64 + .equ .LMlasQgemmKernel_backup_d8_d9, 80 + .equ .LMlasQgemmKernel_backup_d10_d11, 96 + .equ .LMlasQgemmKernel_backup_d12_d13, 112 + .equ .LMlasQgemmKernel_backup_d14_d15, 128 + .equ .LMlasQgemmKernel_SavedRegisters, 144 + .equ .LMlasQgemmKernel_SavedRegisters_Neg, -144 + + +// +// Init Row Accumulators +// +// Generates the code to initialize the accumulators for a single row of the output +// block. +// +// +// Accumulators are initialized to ZeroPointB * RowSum + ColumnSum +// x7 for RowSumsBuffer pointer +// x10 for ColumnSumBuffer pointer +// x11 for ZeroPointB buffer pointer +// +// v12~v13 for RowSums values +// v14~v15 for ColumnSums values +// v0~v3 for ZeroPointB values +// + .macro InitRowAccumulators Columns, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, RowSumReg + + mul v7.4s, v\RowSumReg\().4s, v8.4s + mov v\Vec1Reg\().16b, v7.16b + add v\Vec1Reg\().4s, v\Vec1Reg\().4s, v0.4s +.if \Columns\() > 2 + mul v7.4s, v\RowSumReg\().4s, v9.4s + mov v\Vec2Reg\().16b, v7.16b + add v\Vec2Reg\().4s, v\Vec2Reg\().4s, v1.4s +.endif +.if \Columns\() > 4 + mul v7.4s, v\RowSumReg\().4s, v10.4s + mov v\Vec3Reg\().16b, v7.16b + add v\Vec3Reg\().4s, v\Vec3Reg\().4s, v2.4s +.endif +.if \Columns\() > 6 + mul v7.4s, v\RowSumReg\().4s, v11.4s + mov v\Vec4Reg\().16b, v7.16b + add v\Vec4Reg\().4s, v\Vec4Reg\().4s, v3.4s +.endif + + .endm + + +// +// InitBlockAccumulators +// +// Generates the code to initialize the accumulators for 8x8 output +// block. +// + .macro InitBlockAccumulators Mode, Columns, Rows + + ld1 {v14.4s},[x10],#16 // load ColumnSumBuffer[0] +.if \Columns\() > 4 + ld1 {v15.4s},[x10],#16 // load ColumnSumBuffer[4] +.endif + // v4~v7 will be set to matrixB after this, so, they can used now + dup v4.4s,v14.s[0] // broadcast column + dup v5.4s,v14.s[1] + dup v6.4s,v14.s[2] + dup v7.4s,v14.s[3] + + zip1 v0.4s, v4.4s, v5.4s + zip2 v1.4s, v6.4s, v7.4s +.if \Columns\() > 4 + dup v4.4s,v15.s[0] // broadcast column + dup v5.4s,v15.s[1] + dup v6.4s,v15.s[2] + dup v7.4s,v15.s[3] + + zip1 v2.4s, v4.4s, v5.4s + zip2 v3.4s, v6.4s, v7.4s +.endif + + // v8~v11 will anyway get set in MatrixA loading, so they are free to use now + movi v8.4s, #1 + movi v9.4s, #1 + movi v10.4s, #1 + movi v11.4s, #1 + + cbz x11,.L\Mode\().InitBlock\Columns\().x\Rows\().SkipScaleByZeroPointB + + ld1 {v4.4s},[x11],#16 // load ZeroPointB[0] + ld1 {v5.4s},[x11],#16 // load ZeroPointB[4] + + dup v6.4s, v4.s[0] + dup v7.4s, v4.s[1] + zip1 v8.4s, v6.4s, v7.4s + + dup v6.4s, v4.s[2] + dup v7.4s, v4.s[3] + zip1 v9.4s, v6.4s, v7.4s + + dup v6.4s, v5.s[0] + dup v7.4s, v5.s[1] + zip1 v10.4s, v6.4s, v7.4s + + dup v6.4s, v5.s[2] + dup v7.4s, v5.s[3] + zip1 v11.4s, v6.4s, v7.4s + +.L\Mode\().InitBlock\Columns\().x\Rows\().SkipScaleByZeroPointB: + dup v4.4s, v12.s[0] //boardcast RowSums + dup v5.4s, v12.s[1] + + uzp1 v6.2d, v4.2d, v5.2d + + InitRowAccumulators \Columns\(),16,17,18,19,6 +.if \Rows\() > 2 + dup v4.4s, v12.s[2] //boardcast RowSums + dup v5.4s, v12.s[3] + + uzp1 v6.2d, v4.2d, v5.2d + + InitRowAccumulators \Columns\(),20,21,22,23,6 +.endif +.if \Rows\() > 4 + dup v4.4s,v13.s[0] // broadcast row sums + dup v5.4s,v13.s[1] + + uzp1 v6.2d, v4.2d, v5.2d + + InitRowAccumulators \Columns\(),24,25,26,27,6 +.endif +.if \Rows\() > 6 + dup v4.4s,v13.s[2] // broadcast row sums + dup v5.4s,v13.s[3] + + uzp1 v6.2d, v4.2d, v5.2d + InitRowAccumulators \Columns\(),28,29,30,31,6 +.endif + + .endm + + +// LoadPackedMatrixABy16Elements +// +// Generates the code to load 16 elements from matrix A. +// + .macro LoadPackedMatrixABy16Elements Rows +.if \Rows\() == 1 + ldr q8,[x0],#8 +.else + ldr q8,[x0],#16 + +.if \Rows\() > 2 + ldr q9,[x0],#16 +.endif + +.if \Rows\() > 4 + ldr q10,[x0],#16 +.endif + +.if \Rows\() > 6 + ldr q11,[x0],#16 +.endif +.endif + .endm + + +// +// MultiplyAccumulateRow +// +// Generates the code to multiply and accumulate a single row of the output +// block. +// + + .macro MultiplyAccumulateRow Columns, MatrixAReg, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg + + ummla v\Vec1Reg\().4s, \MatrixAReg\().16b, v4.16b +.if \Columns\() > 2 + ummla v\Vec2Reg\().4s, \MatrixAReg\().16b, v5.16b +.endif +.if \Columns\() > 4 + ummla v\Vec3Reg\().4s, \MatrixAReg\().16b, v6.16b +.endif +.if \Columns\() > 6 + ummla v\Vec4Reg\().4s, \MatrixAReg\().16b, v7.16b +.endif + + .endm + +// +// MultiplyAccumulateBlock +// +// Generates the code to multiply and accumulate into the output block. +// + + .macro MultiplyAccumulateBlock Columns, Rows + + MultiplyAccumulateRow \Columns\(),v8,16,17,18,19 +.if \Rows\() > 2 + MultiplyAccumulateRow \Columns\(),v9,20,21,22,23 +.endif +.if \Rows\() > 4 + MultiplyAccumulateRow \Columns\(),v10,24,25,26,27 +.endif +.if \Rows\() > 6 + MultiplyAccumulateRow \Columns\(),v11,28,29,30,31 +.endif + + .endm + +// +// ComputeBlockLoop +// +// Generates the code to loop over K entries of the input matrices to produce +// the output block. +// + + .macro ComputeBlockLoop Mode, Columns, Rows + + InitBlockAccumulators \Mode\(), \Columns\(),\Rows\() + + sub x9,x3,#1 // block count to process + tbnz x9,#63,.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks + +.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop: + + LoadPackedMatrixABy16Elements \Rows\() + ld1 {v4.16b - v7.16b}, [x1], #64 + MultiplyAccumulateBlock \Columns\(),\Rows\() + + sub x9,x9,#1 + tbz x9,#63,.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop +.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks: + add x9,x9,#1 // correct for over-subtract above + cbz x9,.L\Mode\().Output\Columns\().x\Rows\().Block + +.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4PaddedLoop: + LoadPackedMatrixABy16Elements \Rows\() + ld1 {v4.16b - v7.16b}, [x1], #64 + MultiplyAccumulateBlock \Columns\(),\Rows\() + +.L\Mode\().Output\Columns\().x\Rows\().Block: + + .endm + + +// +// OutputRow2Element +// OutputRow4Element +// OutputRow6Element +// OutputRow8Element +// OutputRow10Element +// OutputRow12Element +// OutputRow14Element +// OutputRow16Element +// +// Generates the code to store elements to the output block. +// + + .macro OutputRow2Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr s8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr s9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + mov v8.S[2], v9.S[0] + add v8.4s,v8.4s,v\Vec1Reg\().4s + + mov w27, v8.S[0] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov w27, v8.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + mov w27, v\Vec1Reg\().S[0] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov w27, v\Vec1Reg\().S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.endif + + .endm + + + .macro OutputRow4Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr d8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr d9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + + mov v8.D[1], v9.D[0] + + add v8.4s,v8.4s,v\Vec1Reg\().4s + + mov x27, v8.D[0] + mov x28, v8.D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.else + mov x27, v\Vec1Reg\().D[0] + mov x28, v\Vec1Reg\().D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.endif + + .endm + + + .macro OutputRow6Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr d8,[\AddrReg1\()],#8 + ldr w28,[\AddrReg1\()],#-8 + mov v8.S[2], w28 +.if \last_row\() == 0 + ldr d9,[\AddrReg2\()],#8 + ldr w27,[\AddrReg2\()],#-8 + mov v9.S[2], w27 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + mov x27, v8.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v8.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov x27, v9.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v9.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + mov x27, v4.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v4.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov x27, v5.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v5.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.endif + + .endm + + + .macro OutputRow8Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + +.endif + + .endm + + + .macro OutputRow10Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr w28, [\AddrReg1\()],#-16 + +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr w27,[\AddrReg2\()],#-16 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + mov v8.S[0], w28 + mov v8.S[2], w27 + + add v8.4s,v8.4s,v\Vec3Reg\().4s + + mov w27, v8.S[0] + mov w28, v8.S[2] + + str w27, [\AddrReg1\()],#4 +.if \last_row\() == 0 + str w28, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + mov w27, v\Vec3Reg\().S[0] + mov w28, v\Vec3Reg\().S[2] + + str w27, [\AddrReg1\()],#4 +.if \last_row\() == 0 + str w28, [\AddrReg2\()],#4 +.endif +.endif + +.endm + + + .macro OutputRow12Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr d10,[\AddrReg1\()],#-16 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr d11,[\AddrReg2\()],#-16 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + mov v11.D[0],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + + mov v10.D[1], v11.D[0] + + add v10.4s,v10.4s,v\Vec3Reg\().4s + + mov x27, v10.D[0] + mov x28, v10.D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + mov x27, v\Vec3Reg\().D[0] + mov x28, v\Vec3Reg\().D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif +.endif + + .endm + + .macro OutputRow14Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr d10,[\AddrReg1\()],#8 + ldr w28, [\AddrReg1\()],#-24 + mov v10.S[2], w28 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr d11,[\AddrReg2\()],#8 + ldr w27,[\AddrReg2\()],#-24 + mov v11.S[2], w27 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + + mov v11.D[0],x27 + mov v11.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + add v10.4s,v10.4s,v6.4s + add v11.4s,v11.4s,v7.4s + + str q8,[\AddrReg1\()],#16 + + mov x27, v10.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v10.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 + mov x27, v11.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v11.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + str q4,[\AddrReg1\()],#16 + mov x27, v6.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v6.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 + mov x27, v7.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v7.S[2] + str w27, [\AddrReg2\()],#4 +.endif +.endif + + .endm + + + .macro OutputRow16Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldp q8,q10,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldp q9,q11,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + + mov v11.D[0],x27 + mov v11.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + add v8.4s,v8.4s,v4.4s + add v9.4s,v9.4s,v5.4s + add v10.4s,v10.4s,v6.4s + add v11.4s,v11.4s,v7.4s + + stp q8,q10,[\AddrReg1\()],#32 +.if \last_row\() == 0 + stp q9,q11,[\AddrReg2\()],#32 +.endif +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + stp q4,q6,[\AddrReg1\()],#32 +.if \last_row\() == 0 + stp q5,q7,[\AddrReg2\()],#32 +.endif +.endif + + .endm + +// +// OutputBlock +// +// Generates the code to store the output block. +// + + .macro OutputBlock Mode, Columns, Rows + + OutputRow\Columns\()Element \Mode\(),x2,x13,16,17,18,19,(\Rows\() == 1) + +.if \Rows\() > 2 + OutputRow\Columns\()Element \Mode\(),x14,x15,20,21,22,23,(\Rows\() == 3) +.endif + +.if \Rows\() > 4 + OutputRow\Columns\()Element \Mode\(),x16,x17,24,25,26,27,(\Rows\() == 5) +.endif + +.if \Rows\() > 6 + OutputRow\Columns\()Element \Mode\(),x18,x19,28,29,30,31,(\Rows\() == 7) +.endif + + .endm +// +// ProcessRows +// +// Generates the code to process a compute and store the output block for a +// fixed number of rows. +// + + .macro ProcessRows Mode, Rows + mov x4,#\Rows\() // return number of rows handled + cmp x5,#6 + ble .L\Mode\().ProcessNextColumnLoop6x\Rows\() + +.L\Mode\().ProcessNextColumnLoop8x\Rows\(): + ComputeBlockLoop \Mode\(),8,\Rows\() + + sub x5,x5,#8 + cmp x5,#0 + blt .L\Mode\().Output14ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),16,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#6 + bgt .L\Mode\().ProcessNextColumnLoop8x\Rows\() + cbz x5,.L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop6x\Rows\(): + + cmp x5,#4 + ble .L\Mode\().ProcessNextColumnLoop4x\Rows\() + ComputeBlockLoop \Mode\(),6,\Rows\() + sub x5,x5,#6 + cmp x5,#0 + blt .L\Mode\().Output10ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),12,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#4 + bgt .L\Mode\().ProcessNextColumnLoop6x\Rows\() + b .L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop4x\Rows\(): + cmp x5,#2 + ble .L\Mode\().ProcessNextColumnLoop2x\Rows\() + ComputeBlockLoop \Mode\(),4,\Rows\() + sub x5,x5,#4 + cmp x5,#0 + blt .L\Mode\().Output6ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),8,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#2 + bgt .L\Mode\().ProcessNextColumnLoop4x\Rows\() + b .L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop2x\Rows\(): + ComputeBlockLoop \Mode\(),2,\Rows\() + sub x5,x5,#2 + cmp x5,#0 + blt .L\Mode\().Output2ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),4,\Rows\() + mov x0,x8 // reload matrix A + cmp x5,#2 + b .L\Mode\().ExitKernel + +.L\Mode\().Output14ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),14,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output10ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),10,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output6ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),6,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output2ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),2,\Rows\() + b .L\Mode\().ExitKernel + + .endm + + +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. + +Arguments: + + A (x0) - Supplies the address of matrix A. The matrix data has been packed + using MlasGemmQuantCopyPackA. + + B (x1) - Supplies the address of matrix B. The matrix data has been packed + using MlasGemmQuantCopyPackB. + + C (x2) - Supplies the address of matrix C. + + PackedCountK (x3) - Supplies the number of packed columns from matrix A and + the number of packed rows from matrix B to iterate over. + + CountM (x4) - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. + + CountN (x5) - Supplies the number of columns from matrix B and matrix C to + iterate over. + + ldc (x6) - Supplies the first dimension of matrix C. + + RowSumBuffer (x7) - Supplies the sum of each row from matrix A. These values + have been pre-scaled by the zero point offset of matrix B if the offset + is per-tensor (ZeroPointB is nullptr). Otherwise, these values must be + scaled by the per-column zero point offsets of matrix B. These values are + accumulated into every row of matrix C. + + ColumnSumBuffer - Supplies the sum of each column from matrix B multiplied + by the zero point offset of matrix A. These values are accumulated into + every column of matrix C. + + ZeroPointB - Optionally supplies the per-column zero point offsets of matrix + B, else nullptr if the matrix B is using per-tensor quantization. + +Return Value: + + Returns the number of rows handled. + +--*/ + + .macro QgemmU8X8KernelUmmlaFunction Mode + + FUNCTION_ENTRY MlasGemmU8X8KernelUmmla\Mode\() + + ldr x10,[sp, #0] + ldr x11,[sp,#8] + + stp x19, x20, [sp, #.LMlasQgemmKernel_SavedRegisters_Neg]! + stp x21, x22, [sp, #.LMlasQgemmKernel_backup_x21_x22] + stp x23, x24, [sp, #.LMlasQgemmKernel_backup_x23_x24] + stp x25, x26, [sp, #.LMlasQgemmKernel_backup_x25_x26] + stp x27, x28, [sp, #.LMlasQgemmKernel_backup_x27_x28] + stp d8, d9, [sp, #.LMlasQgemmKernel_backup_d8_d9] + stp d10, d11, [sp, #.LMlasQgemmKernel_backup_d10_d11] + stp d12, d13, [sp, #.LMlasQgemmKernel_backup_d12_d13] + stp d14, d15, [sp, #.LMlasQgemmKernel_backup_d14_d15] + + add x13,x2,x6,lsl #2 // compute matrix C plus 1 row + add x14,x13,x6,lsl #2 // compute matrix C plus 2 rows + add x15,x14,x6,lsl #2 // compute matrix C plus 3 rows + add x16,x15,x6,lsl #2 // compute matrix C plus 4 rows + add x17,x16,x6,lsl #2 // compute matrix C plus 5 rows + add x18,x17,x6,lsl #2 // compute matrix C plus 6 rows + add x19,x18,x6,lsl #2 // compute matrix C plus 7 rows + + mov x8,x0 // save matrix A + +// +// Process 8 rows of the matrices. +// + ld1 {v12.4s},[x7],#16 // load row sum 1 ~ 4 + cmp x4,#8 + blt .L\Mode\().ProcessCountMLessThan8 + ld1 {v13.4s},[x7],#16 // load row sum 5 ~ 8 + ProcessRows \Mode\(),8 + +// +// Restore non-volatile registers and return. +// + +.L\Mode\().ExitKernel: + mov x0,x4 + + ldp d14, d15, [sp, #.LMlasQgemmKernel_backup_d14_d15] + ldp d12, d13, [sp, #.LMlasQgemmKernel_backup_d12_d13] + ldp d10, d11, [sp, #.LMlasQgemmKernel_backup_d10_d11] + ldp d8, d9, [sp, #.LMlasQgemmKernel_backup_d8_d9] + ldp x27, x28, [sp, #.LMlasQgemmKernel_backup_x27_x28] + ldp x25, x26, [sp, #.LMlasQgemmKernel_backup_x25_x26] + ldp x23, x24, [sp, #.LMlasQgemmKernel_backup_x23_x24] + ldp x21, x22, [sp, #.LMlasQgemmKernel_backup_x21_x22] + ldp x19, x20, [sp], #.LMlasQgemmKernel_SavedRegisters + + ret + +// +// Process 4 rows of the matrix. +// + +.L\Mode\().ProcessCountMLessThan8: + cmp x4,#4 + blt .L\Mode\().ProcessCountMLessThan4 + ProcessRows \Mode\(),4 + b .L\Mode\().ExitKernel + +// +// Process 2 row of the matrix. +// + +.L\Mode\().ProcessCountMLessThan4: + cmp x4,#2 + blt .L\Mode\().ProcessCountMLessThan2 + + ProcessRows \Mode\(),2 + b .L\Mode\().ExitKernel + + +// +// Process the last row of the matrix. +// + +.L\Mode\().ProcessCountMLessThan2: + ProcessRows \Mode\(),1 + b .L\Mode\().ExitKernel + + + .endm + + QgemmU8X8KernelUmmlaFunction Zero + QgemmU8X8KernelUmmlaFunction Add + + .end diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 0d5e425018f37..b5596727c93a4 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -184,11 +184,17 @@ class MLASCPUIDInfo bool IsCurrentCoreArmv8NarrowLd() const { return false; } + bool HasArmNeon_I8MM() const { return has_arm_neon_i8mm_; } + + bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; } + private: MLASCPUIDInfo(); bool has_arm_neon_dot_{false}; bool has_fp16_{false}; + bool has_arm_neon_i8mm_{false}; + bool has_arm_sve_i8mm_{false}; }; using MLAS_CPUIDINFO = MLASCPUIDInfo; @@ -856,6 +862,8 @@ extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchNeon; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmX8S8DispatchNeon; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchUdot; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchSdot; +extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchUmmla; +extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchSmmla; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchWasmSimd; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemmQuantDispatchDefault; extern const MLAS_GEMM_QUANT_DISPATCH MlasGemm8X8DispatchPOWER10; diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 0a4d9e05c4cd2..0ec40d0f53ac5 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -52,6 +52,14 @@ MLASCPUIDInfo::MLASCPUIDInfo() #define HWCAP_ASIMDDP (1 << 20) #endif +#ifndef HWCAP2_I8MM +#define HWCAP2_I8MM (1 << 13) +#endif + +#ifndef HWCAP2_SVEI8MM +#define HWCAP2_SVEI8MM (1 << 9) +#endif + #if defined(BUILD_MLAS_NO_ONNXRUNTIME) MLASCPUIDInfo::MLASCPUIDInfo() { @@ -59,6 +67,9 @@ MLASCPUIDInfo::MLASCPUIDInfo() // raw hack! Need CPUIDInfo implementation for more precise detection has_fp16_ = has_arm_neon_dot_; + + has_arm_neon_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_I8MM) != 0); + has_arm_sve_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_SVEI8MM) != 0); } #endif @@ -482,6 +493,17 @@ Return Value: this->ConvSymS8S8Dispatch = &MlasConvSymS8DispatchDot; } +#if defined(__linux__) + // + // Check if the processor supports ASIMD I8MM instructions. + // + if (MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_I8MM()) { + this->GemmU8U8Dispatch = &MlasGemmU8X8DispatchUmmla; + this->GemmU8S8Dispatch = &MlasGemmU8X8DispatchUmmla; + this->GemmS8S8Dispatch = &MlasGemmS8S8DispatchSmmla; + } +#endif + #endif // MLAS_TARGET_ARM64 #if defined(MLAS_TARGET_POWER) this->GemmFloatKernel = MlasSgemmKernel; diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_smmla.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_smmla.cpp new file mode 100644 index 0000000000000..c41f43ca22d18 --- /dev/null +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_smmla.cpp @@ -0,0 +1,964 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + qgemm_kernel_smmla.cpp + +Abstract: + + This module implements smmla QGEMM kernel. + +--*/ + +#include "mlasi.h" +#include "qgemm.h" + +// +// Define the prototypes of the NEON SMMLA routines written in assembly. +// + +extern "C" { + +size_t MLASCALL +MlasGemmS8S8KernelSmmlaZero(const uint8_t* A, + const uint8_t* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumVector, + const int32_t* ColumnSumVector, + const int32_t* ZeroPointB); + +size_t MLASCALL +MlasGemmS8S8KernelSmmlaAdd(const uint8_t* A, + const uint8_t* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumVector, + const int32_t* ColumnSumVector, + const int32_t* ZeroPointB); +} + +struct MLAS_GEMM_S8S8_KERNEL_SMMLA { + typedef uint8_t PackedAType; + typedef uint8_t PackedBType; + typedef int8_t OffsetAType; + typedef int8_t OffsetBType; + + static constexpr size_t PackedK = 8; + static constexpr MLAS_GEMM_QUANT_STRIDES Strides{24, 128, 256}; + static constexpr MLAS_GEMM_QUANT_STRIDES PackedStrides{24, 128, 384}; +}; + +constexpr size_t MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedK; +constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_S8S8_KERNEL_SMMLA::Strides; +constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedStrides; + +template <> +MLAS_FORCEINLINE int32_t +MlasGemmQuantFixupZeroPointB(int32_t ZeroPointB, bool BIsSigned) +{ + MLAS_UNREFERENCED_PARAMETER(BIsSigned); + return ZeroPointB; +} + +template <> +void +MlasGemmQuantCopyPackA( + MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedAType* D_uint8_t, + const uint8_t* A, + size_t lda, + size_t CountM, + size_t CountK, + int32_t* RowSumBuffer, + bool AIsSigned) +{ + int8_t* D = reinterpret_cast(D_uint8_t); + MLAS_UNREFERENCED_PARAMETER(AIsSigned); + int8_t PaddedMatrixAData[64]; + + // + // Process 8 rows of matrix A. + // + // MMLA kernels load 8x8 block of A with four vector registers. So A is packed + // a series of 64 byte vectors where eight rows are interleaved with the + // following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // [ B0 B1 B2 B3 B4 B5 B6 B7 ] + // [ C0 C1 C2 C3 C4 C5 C6 C7 ] + // [ D0 D1 D2 D3 D4 D5 D6 D7 ] + // [ E0 E1 E2 E3 E4 E5 E6 E7 ] + // [ F0 F1 F2 F3 F4 F5 F6 F7 ] + // [ G0 G1 G2 G3 G4 G5 G6 G7 ] + // [ H0 H1 H2 H3 H4 H5 H6 H7 ] + // + // ... + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of eight, then the vector is padded + // with zeroes. + // + + while (CountM >= 8) { + const int8_t* a0 = reinterpret_cast(A); + const int8_t* a1 = a0 + lda; + const int8_t* a2 = a0 + lda * 2; + const int8_t* a3 = a0 + lda * 3; + const int8_t* a4 = a0 + lda * 4; + const int8_t* a5 = a0 + lda * 5; + const int8_t* a6 = a0 + lda * 6; + const int8_t* a7 = a0 + lda * 7; + + size_t k = CountK; + int32x4_t RowSums0 = vmovq_n_s32(0); + int32x4_t RowSums1 = vmovq_n_s32(0); + + while (k >= 16) { + int64x2_t v0 = vld1q_s64(reinterpret_cast(a0)); + a0 += 16; + int64x2_t v1 = vld1q_s64(reinterpret_cast(a1)); + a1 += 16; + int64x2_t v2 = vld1q_s64(reinterpret_cast(a2)); + a2 += 16; + int64x2_t v3 = vld1q_s64(reinterpret_cast(a3)); + a3 += 16; + int64x2_t v4 = vld1q_s64(reinterpret_cast(a4)); + a4 += 16; + int64x2_t v5 = vld1q_s64(reinterpret_cast(a5)); + a5 += 16; + int64x2_t v6 = vld1q_s64(reinterpret_cast(a6)); + a6 += 16; + int64x2_t v7 = vld1q_s64(reinterpret_cast(a7)); + a7 += 16; + + int64x2_t z0 = vzip1q_s64(v0, v1); + int64x2_t z1 = vzip2q_s64(v0, v1); + int64x2_t z2 = vzip1q_s64(v2, v3); + int64x2_t z3 = vzip2q_s64(v2, v3); + + int64x2_t z4 = vzip1q_s64(v4, v5); + int64x2_t z5 = vzip2q_s64(v4, v5); + int64x2_t z6 = vzip1q_s64(v6, v7); + int64x2_t z7 = vzip2q_s64(v6, v7); + + vst1q_s8(&D[0], vreinterpretq_s8_s64(z0)); + vst1q_s8(&D[16], vreinterpretq_s8_s64(z2)); + vst1q_s8(&D[32], vreinterpretq_s8_s64(z4)); + vst1q_s8(&D[48], vreinterpretq_s8_s64(z6)); + vst1q_s8(&D[64], vreinterpretq_s8_s64(z1)); + vst1q_s8(&D[80], vreinterpretq_s8_s64(z3)); + vst1q_s8(&D[96], vreinterpretq_s8_s64(z5)); + vst1q_s8(&D[112], vreinterpretq_s8_s64(z7)); + + int32x4_t RowSums0L_pada = vmovq_n_s32(0); + RowSums0L_pada = vpadalq_s16(RowSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z0))); + RowSums0L_pada = vpadalq_s16(RowSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z1))); + + int32x4_t RowSums0L_ext = vextq_s32(RowSums0L_pada, RowSums0L_pada, 1); + int32x4_t RowSums0L_add = vaddq_s32(RowSums0L_pada, RowSums0L_ext); + int32x2_t RowSums0L = {vdups_laneq_s32(RowSums0L_add, 0), + vdups_laneq_s32(RowSums0L_add, 2)}; + + int32x4_t RowSums0H_pada = vmovq_n_s32(0); + RowSums0H_pada = vpadalq_s16(RowSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z2))); + RowSums0H_pada = vpadalq_s16(RowSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z3))); + + int32x4_t RowSums0H_ext = vextq_s32(RowSums0H_pada, RowSums0H_pada, 1); + int32x4_t RowSums0H_add = vaddq_s32(RowSums0H_pada, RowSums0H_ext); + int32x2_t RowSums0H = {vdups_laneq_s32(RowSums0H_add, 0), + vdups_laneq_s32(RowSums0H_add, 2)}; + + RowSums0 = vaddq_s32(RowSums0, vcombine_s32(RowSums0L, RowSums0H)); + + int32x4_t RowSums1L_pada = vmovq_n_s32(0); + RowSums1L_pada = vpadalq_s16(RowSums1L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z4))); + RowSums1L_pada = vpadalq_s16(RowSums1L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z5))); + + int32x4_t RowSums1L_ext = vextq_s32(RowSums1L_pada, RowSums1L_pada, 1); + int32x4_t RowSums1L_add = vaddq_s32(RowSums1L_pada, RowSums1L_ext); + int32x2_t RowSums1L = {vdups_laneq_s32(RowSums1L_add, 0), + vdups_laneq_s32(RowSums1L_add, 2)}; + + int32x4_t RowSums1H_pada = vmovq_n_s32(0); + RowSums1H_pada = vpadalq_s16(RowSums1H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z6))); + RowSums1H_pada = vpadalq_s16(RowSums1H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z7))); + + int32x4_t RowSums1H_ext = vextq_s32(RowSums1H_pada, RowSums1H_pada, 1); + int32x4_t RowSums1H_add = vaddq_s32(RowSums1H_pada, RowSums1H_ext); + int32x2_t RowSums1H = {vdups_laneq_s32(RowSums1H_add, 0), + vdups_laneq_s32(RowSums1H_add, 2)}; + + RowSums1 = vaddq_s32(RowSums1, vcombine_s32(RowSums1L, RowSums1H)); + + D += 128; + k -= 16; + } + + while (k >= 8) { + int64x1_t v0 = *reinterpret_cast(a0); + a0 += 8; + int64x1_t v1 = *reinterpret_cast(a1); + a1 += 8; + int64x1_t v2 = *reinterpret_cast(a2); + a2 += 8; + int64x1_t v3 = *reinterpret_cast(a3); + a3 += 8; + int64x1_t v4 = *reinterpret_cast(a4); + a4 += 8; + int64x1_t v5 = *reinterpret_cast(a5); + a5 += 8; + int64x1_t v6 = *reinterpret_cast(a6); + a6 += 8; + int64x1_t v7 = *reinterpret_cast(a7); + a7 += 8; + + *reinterpret_cast(&D[0]) = v0; + *reinterpret_cast(&D[8]) = v1; + *reinterpret_cast(&D[16]) = v2; + *reinterpret_cast(&D[24]) = v3; + *reinterpret_cast(&D[32]) = v4; + *reinterpret_cast(&D[40]) = v5; + *reinterpret_cast(&D[48]) = v6; + *reinterpret_cast(&D[56]) = v7; + + int64x2_t z01 = vcombine_s64(v0, v1); + int64x2_t z23 = vcombine_s64(v2, v3); + int64x2_t z45 = vcombine_s64(v4, v5); + int64x2_t z67 = vcombine_s64(v6, v7); + + int32x4_t RowSums0L_pada = vmovq_n_s32(0); + RowSums0L_pada = vpadalq_s16(RowSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); + + int32x4_t RowSums0L_ext = vextq_s32(RowSums0L_pada, RowSums0L_pada, 1); + int32x4_t RowSums0L_add = vaddq_s32(RowSums0L_pada, RowSums0L_ext); + int32x2_t RowSums0L = {vdups_laneq_s32(RowSums0L_add, 0), + vdups_laneq_s32(RowSums0L_add, 2)}; + + int32x4_t RowSums0H_pada = vmovq_n_s32(0); + RowSums0H_pada = vpadalq_s16(RowSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z23))); + + int32x4_t RowSums0H_ext = vextq_s32(RowSums0H_pada, RowSums0H_pada, 1); + int32x4_t RowSums0H_add = vaddq_s32(RowSums0H_pada, RowSums0H_ext); + int32x2_t RowSums0H = {vdups_laneq_s32(RowSums0H_add, 0), + vdups_laneq_s32(RowSums0H_add, 2)}; + + RowSums0 = vaddq_s32(RowSums0, vcombine_s32(RowSums0L, RowSums0H)); + + int32x4_t RowSums1L_pada = vmovq_n_s32(0); + RowSums1L_pada = vpadalq_s16(RowSums1L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z45))); + + int32x4_t RowSums1L_ext = vextq_s32(RowSums1L_pada, RowSums1L_pada, 1); + int32x4_t RowSums1L_add = vaddq_s32(RowSums1L_pada, RowSums1L_ext); + int32x2_t RowSums1L = {vdups_laneq_s32(RowSums1L_add, 0), + vdups_laneq_s32(RowSums1L_add, 2)}; + + int32x4_t RowSums1H_pada = vmovq_n_s32(0); + RowSums1H_pada = vpadalq_s16(RowSums1H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z67))); + + int32x4_t RowSums1H_ext = vextq_s32(RowSums1H_pada, RowSums1H_pada, 1); + int32x4_t RowSums1H_add = vaddq_s32(RowSums1H_pada, RowSums1H_ext); + int32x2_t RowSums1H = {vdups_laneq_s32(RowSums1H_add, 0), + vdups_laneq_s32(RowSums1H_add, 2)}; + + RowSums1 = vaddq_s32(RowSums1, vcombine_s32(RowSums1L, RowSums1H)); + + D += 64; + k -= 8; + } + + if (k > 0) { + // + // zero pad the remaining columns to 8 + // + int8_t* d = D; + + vst1q_s8(d, vmovq_n_s8(0)); + vst1q_s8(&d[16], vmovq_n_s8(0)); + vst1q_s8(&d[32], vmovq_n_s8(0)); + vst1q_s8(&d[48], vmovq_n_s8(0)); + + while (k > 0) { + d[0] = *a0++; + d[8] = *a1++; + d[16] = *a2++; + d[24] = *a3++; + d[32] = *a4++; + d[40] = *a5++; + d[48] = *a6++; + d[56] = *a7++; + d += 1; + k -= 1; + } + d = D; + int64x1_t v0 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v1 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v2 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v3 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v4 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v5 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v6 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v7 = *reinterpret_cast(d); + d = d + 8; + + int64x2_t z01 = vcombine_s64(v0, v1); + int64x2_t z23 = vcombine_s64(v2, v3); + int64x2_t z45 = vcombine_s64(v4, v5); + int64x2_t z67 = vcombine_s64(v6, v7); + + int32x4_t RowSums0L_pada = vmovq_n_s32(0); + RowSums0L_pada = vpadalq_s16(RowSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); + + int32x4_t RowSums0L_ext = vextq_s32(RowSums0L_pada, RowSums0L_pada, 1); + int32x4_t RowSums0L_add = vaddq_s32(RowSums0L_pada, RowSums0L_ext); + int32x2_t RowSums0L = {vdups_laneq_s32(RowSums0L_add, 0), + vdups_laneq_s32(RowSums0L_add, 2)}; + + int32x4_t RowSums0H_pada = vmovq_n_s32(0); + RowSums0H_pada = vpadalq_s16(RowSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z23))); + + int32x4_t RowSums0H_ext = vextq_s32(RowSums0H_pada, RowSums0H_pada, 1); + int32x4_t RowSums0H_add = vaddq_s32(RowSums0H_pada, RowSums0H_ext); + int32x2_t RowSums0H = {vdups_laneq_s32(RowSums0H_add, 0), + vdups_laneq_s32(RowSums0H_add, 2)}; + + RowSums0 = vaddq_s32(RowSums0, vcombine_s32(RowSums0L, RowSums0H)); + + int32x4_t RowSums1L_pada = vmovq_n_s32(0); + RowSums1L_pada = vpadalq_s16(RowSums1L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z45))); + + int32x4_t RowSums1L_ext = vextq_s32(RowSums1L_pada, RowSums1L_pada, 1); + int32x4_t RowSums1L_add = vaddq_s32(RowSums1L_pada, RowSums1L_ext); + int32x2_t RowSums1L = {vdups_laneq_s32(RowSums1L_add, 0), + vdups_laneq_s32(RowSums1L_add, 2)}; + + int32x4_t RowSums1H_pada = vmovq_n_s32(0); + RowSums1H_pada = vpadalq_s16(RowSums1H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z67))); + + int32x4_t RowSums1H_ext = vextq_s32(RowSums1H_pada, RowSums1H_pada, 1); + int32x4_t RowSums1H_add = vaddq_s32(RowSums1H_pada, RowSums1H_ext); + int32x2_t RowSums1H = {vdups_laneq_s32(RowSums1H_add, 0), + vdups_laneq_s32(RowSums1H_add, 2)}; + + RowSums1 = vaddq_s32(RowSums1, vcombine_s32(RowSums1L, RowSums1H)); + + D += 64; + } + + vst1q_s32(RowSumBuffer, RowSums0); + vst1q_s32(&RowSumBuffer[4], RowSums1); + + RowSumBuffer += 8; + + A = A + lda * 8; + CountM -= 8; + } + + // + // Process four rows of matrix A. + // + // The buffer is packed as a series of 32 byte vectors where four rows are + // interleaved with the following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // [ B0 B1 B2 B3 B4 B5 B6 B7 ] + // [ C0 C1 C2 C3 C4 C5 C6 C7 ] + // [ D0 D1 D2 D3 D4 D5 D6 D7 ] + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of eight, then the vector is padded + // with zeroes. + // + + if (CountM >= 4) { + const int8_t* a0 = reinterpret_cast(A); + const int8_t* a1 = a0 + lda; + const int8_t* a2 = a1 + lda; + const int8_t* a3 = a2 + lda; + + size_t k = CountK; + int32x4_t RowSums = vmovq_n_s32(0); + + while (k >= 16) { + int64x2_t v0 = vld1q_s64(reinterpret_cast(a0)); + a0 += 16; + int64x2_t v1 = vld1q_s64(reinterpret_cast(a1)); + a1 += 16; + int64x2_t v2 = vld1q_s64(reinterpret_cast(a2)); + a2 += 16; + int64x2_t v3 = vld1q_s64(reinterpret_cast(a3)); + a3 += 16; + + int64x2_t z0 = vzip1q_s64(v0, v1); + int64x2_t z1 = vzip2q_s64(v0, v1); + int64x2_t z2 = vzip1q_s64(v2, v3); + int64x2_t z3 = vzip2q_s64(v2, v3); + + vst1q_s8(&D[0], vreinterpretq_s8_s64(z0)); + vst1q_s8(&D[16], vreinterpretq_s8_s64(z2)); + vst1q_s8(&D[32], vreinterpretq_s8_s64(z1)); + vst1q_s8(&D[48], vreinterpretq_s8_s64(z3)); + + int32x4_t RowSumsL_pada = vmovq_n_s32(0); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z0))); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z1))); + + int32x4_t RowSumsL_ext = vextq_s32(RowSumsL_pada, RowSumsL_pada, 1); + int32x4_t RowSumsL_add = vaddq_s32(RowSumsL_pada, RowSumsL_ext); + int32x2_t RowSumsL = {vdups_laneq_s32(RowSumsL_add, 0), + vdups_laneq_s32(RowSumsL_add, 2)}; + + int32x4_t RowSumsH_pada = vmovq_n_s32(0); + RowSumsH_pada = vpadalq_s16(RowSumsH_pada, vpaddlq_s8(vreinterpretq_s8_s64(z2))); + RowSumsH_pada = vpadalq_s16(RowSumsH_pada, vpaddlq_s8(vreinterpretq_s8_s64(z3))); + + int32x4_t RowSumsH_ext = vextq_s32(RowSumsH_pada, RowSumsH_pada, 1); + int32x4_t RowSumsH_add = vaddq_s32(RowSumsH_pada, RowSumsH_ext); + int32x2_t RowSumsH = {vdups_laneq_s32(RowSumsH_add, 0), + vdups_laneq_s32(RowSumsH_add, 2)}; + + RowSums = vaddq_s32(RowSums, vcombine_s32(RowSumsL, RowSumsH)); + + D += 64; + k -= 16; + } + + while (k >= 8) { + int64x1_t v0 = *reinterpret_cast(a0); + a0 += 8; + int64x1_t v1 = *reinterpret_cast(a1); + a1 += 8; + int64x1_t v2 = *reinterpret_cast(a2); + a2 += 8; + int64x1_t v3 = *reinterpret_cast(a3); + a3 += 8; + + *reinterpret_cast(&D[0]) = v0; + *reinterpret_cast(&D[8]) = v1; + *reinterpret_cast(&D[16]) = v2; + *reinterpret_cast(&D[24]) = v3; + + int64x2_t z01 = vcombine_s64(v0, v1); + int64x2_t z23 = vcombine_s64(v2, v3); + + int32x4_t RowSumsL_pada = vmovq_n_s32(0); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); + + int32x4_t RowSumsL_ext = vextq_s32(RowSumsL_pada, RowSumsL_pada, 1); + int32x4_t RowSumsL_add = vaddq_s32(RowSumsL_pada, RowSumsL_ext); + int32x2_t RowSumsL = {vdups_laneq_s32(RowSumsL_add, 0), + vdups_laneq_s32(RowSumsL_add, 2)}; + + int32x4_t RowSumsH_pada = vmovq_n_s32(0); + RowSumsH_pada = vpadalq_s16(RowSumsH_pada, vpaddlq_s8(vreinterpretq_s8_s64(z23))); + + int32x4_t RowSumsH_ext = vextq_s32(RowSumsH_pada, RowSumsH_pada, 1); + int32x4_t RowSumsH_add = vaddq_s32(RowSumsH_pada, RowSumsH_ext); + int32x2_t RowSumsH = {vdups_laneq_s32(RowSumsH_add, 0), + vdups_laneq_s32(RowSumsH_add, 2)}; + + RowSums = vaddq_s32(RowSums, vcombine_s32(RowSumsL, RowSumsH)); + + D += 32; + k -= 8; + } + + if (k > 0) { + // + // Copy the remaining bytes with zero padding. + // + int8_t* d = D; + + vst1q_s8(d, vmovq_n_s8(0)); + vst1q_s8(&d[16], vmovq_n_s8(0)); + + while (k > 0) { + d[0] = *a0++; + d[8] = *a1++; + d[16] = *a2++; + d[24] = *a3++; + d += 1; + k -= 1; + } + + d = D; + int64x1_t v0 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v1 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v2 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v3 = *reinterpret_cast(d); + d = d + 8; + + int64x2_t z01 = vcombine_s64(v0, v1); + int64x2_t z23 = vcombine_s64(v2, v3); + + int32x4_t RowSums0L_pada = vmovq_n_s32(0); + RowSums0L_pada = vpadalq_s16(RowSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); + + int32x4_t RowSums0L_ext = vextq_s32(RowSums0L_pada, RowSums0L_pada, 1); + int32x4_t RowSums0L_add = vaddq_s32(RowSums0L_pada, RowSums0L_ext); + int32x2_t RowSums0L = {vdups_laneq_s32(RowSums0L_add, 0), + vdups_laneq_s32(RowSums0L_add, 2)}; + + int32x4_t RowSums0H_pada = vmovq_n_s32(0); + RowSums0H_pada = vpadalq_s16(RowSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s64(z23))); + + int32x4_t RowSums0H_ext = vextq_s32(RowSums0H_pada, RowSums0H_pada, 1); + int32x4_t RowSums0H_add = vaddq_s32(RowSums0H_pada, RowSums0H_ext); + int32x2_t RowSums0H = {vdups_laneq_s32(RowSums0H_add, 0), + vdups_laneq_s32(RowSums0H_add, 2)}; + + RowSums = vaddq_s32(RowSums, vcombine_s32(RowSums0L, RowSums0H)); + + D += 32; + } + + vst1q_s32(RowSumBuffer, RowSums); + RowSumBuffer += 4; + + A = A + lda * 4; + CountM -= 4; + } + + // + // Process two rows of matrix A. + // + // The buffer is packed as a series of 16 byte vectors where two rows are + // interleaved with the following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // [ B0 B1 B2 B3 B4 B5 B6 B7 ] + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of eight, then the vector is padded + // with zeroes. + // + + if (CountM >= 2) { + const int8_t* a0 = reinterpret_cast(A); + const int8_t* a1 = a0 + lda; + + size_t k = CountK; + int32x2_t RowSums = vmov_n_s32(0); + + while (k >= 16) { + int64x2_t v0 = vld1q_s64(reinterpret_cast(a0)); + a0 += 16; + int64x2_t v1 = vld1q_s64(reinterpret_cast(a1)); + a1 += 16; + + int64x2_t z0 = vzip1q_s64(v0, v1); + int64x2_t z1 = vzip2q_s64(v0, v1); + + vst1q_s8(&D[0], vreinterpretq_s8_s64(z0)); + vst1q_s8(&D[16], vreinterpretq_s8_s64(z1)); + + int32x4_t RowSumsL_pada = vmovq_n_s32(0); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z0))); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z1))); + + int32x4_t RowSumsL_ext = vextq_s32(RowSumsL_pada, RowSumsL_pada, 1); + int32x4_t RowSumsL_add = vaddq_s32(RowSumsL_pada, RowSumsL_ext); + int32x2_t RowSumsL = {vdups_laneq_s32(RowSumsL_add, 0), + vdups_laneq_s32(RowSumsL_add, 2)}; + + RowSums = vadd_s32(RowSums, RowSumsL); + + D += 32; + k -= 16; + } + + while (k >= 8) { + int64x1_t v0 = *reinterpret_cast(a0); + a0 += 8; + int64x1_t v1 = *reinterpret_cast(a1); + a1 += 8; + + *reinterpret_cast(&D[0]) = v0; + *reinterpret_cast(&D[8]) = v1; + + int64x2_t z01 = vcombine_s64(v0, v1); + int32x4_t RowSumsL_pada = vmovq_n_s32(0); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); + + int32x4_t RowSumsL_ext = vextq_s32(RowSumsL_pada, RowSumsL_pada, 1); + int32x4_t RowSumsL_add = vaddq_s32(RowSumsL_pada, RowSumsL_ext); + int32x2_t RowSumsL = {vdups_laneq_s32(RowSumsL_add, 0), + vdups_laneq_s32(RowSumsL_add, 2)}; + + RowSums = vadd_s32(RowSums, RowSumsL); + + D += 16; + k -= 8; + } + + if (k > 0) { + // + // Zero pad the remaining elements to make 8 columns. + // + + int8_t* d = PaddedMatrixAData; + vst1q_s8(PaddedMatrixAData, vmovq_n_s8(0)); + + while (k > 0) { + d[0] = *a0++; + d[8] = *a1++; + + d += 1; + k -= 1; + } + + d = PaddedMatrixAData; + int64x1_t v0 = *reinterpret_cast(d); + d = d + 8; + int64x1_t v1 = *reinterpret_cast(d); + d = d + 8; + + int64x2_t z01 = vcombine_s64(v0, v1); + int32x4_t RowSumsL_pada = vmovq_n_s32(0); + RowSumsL_pada = vpadalq_s16(RowSumsL_pada, vpaddlq_s8(vreinterpretq_s8_s64(z01))); + + int32x4_t RowSumsL_ext = vextq_s32(RowSumsL_pada, RowSumsL_pada, 1); + int32x4_t RowSumsL_add = vaddq_s32(RowSumsL_pada, RowSumsL_ext); + int32x2_t RowSumsL = {vdups_laneq_s32(RowSumsL_add, 0), + vdups_laneq_s32(RowSumsL_add, 2)}; + + RowSums = vadd_s32(RowSums, RowSumsL); + + int8x16_t PackedVector = vld1q_s8(PaddedMatrixAData); + vst1q_s8(D, PackedVector); + + D += 16; + } + + vst1_s32(RowSumBuffer, RowSums); + RowSumBuffer += 2; + + A = A + lda * 2; + CountM -= 2; + } + + // + // Process one row of matrix A. + // + // The buffer is packed as a series of 8 byte with the following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of 8, then the vector is padded + // with zeroes. + // + + if (CountM > 0) { + // No need to pad the rows to 2, the .S takes care of zero pdding + const int8_t* a = reinterpret_cast(A); + size_t k = CountK; + int32x4_t RowSums = vmovq_n_s32(0); + + while (k >= 16) { + int8x16_t v = vld1q_s8(a); + a += 16; + + vst1q_s8(D, v); + + RowSums = vpadalq_s16(RowSums, vpaddlq_s8(v)); + + D += 16; + k -= 16; + } + + if (k > 0) { + // + // Copy the remaining bytes to the zero padded stack buffer. + // + + vst1q_s8(PaddedMatrixAData, vmovq_n_s8(0)); + + for (size_t kk = 0; kk < k; kk++) { + PaddedMatrixAData[kk] = a[kk]; + } + + int8x16_t v = vld1q_s8(PaddedMatrixAData); + vst1q_s8(D, v); + + RowSums = vpadalq_s16(RowSums, vpaddlq_s8(v)); + } + + *RowSumBuffer = int32_t(vaddvq_s32(RowSums)); + } +} + +MLAS_FORCEINLINE +void +MlasGemmS8S8CopyPackBProcessSmmla(int8_t* D, int8x8_t BytesRow[8], int32x4_t ColumnSums[2]) +{ + int8x16_t v02 = vcombine_s8(BytesRow[0], BytesRow[2]); + int8x16_t v13 = vcombine_s8(BytesRow[1], BytesRow[3]); + + int8x16_t v46 = vcombine_s8(BytesRow[4], BytesRow[6]); + int8x16_t v57 = vcombine_s8(BytesRow[5], BytesRow[7]); + + int8x16x2_t zw1 = vzipq_s8(v02, v13); + int16x8x2_t zd1 = vzipq_s16(vreinterpretq_s16_s8(zw1.val[0]), vreinterpretq_s16_s8(zw1.val[1])); + + int8x16x2_t zw2 = vzipq_s8(v46, v57); + int16x8x2_t zd2 = vzipq_s16(vreinterpretq_s16_s8(zw2.val[0]), vreinterpretq_s16_s8(zw2.val[1])); + + int32x4x2_t zd3 = + vzipq_s32(vreinterpretq_s32_s16(zd1.val[0]), vreinterpretq_s32_s16(zd2.val[0])); + int32x4x2_t zd4 = + vzipq_s32(vreinterpretq_s32_s16(zd1.val[1]), vreinterpretq_s32_s16(zd2.val[1])); + + vst1q_s8(&D[0], vreinterpretq_s8_s32(zd3.val[0])); + vst1q_s8(&D[16], vreinterpretq_s8_s32(zd3.val[1])); + vst1q_s8(&D[32], vreinterpretq_s8_s32(zd4.val[0])); + vst1q_s8(&D[48], vreinterpretq_s8_s32(zd4.val[1])); + + int32x4_t ColSums0L_pada = vmovq_n_s32(0); + ColSums0L_pada = vpadalq_s16(ColSums0L_pada, vpaddlq_s8(vreinterpretq_s8_s32(zd3.val[0]))); + int32x4_t ColSums0L_ext = vextq_s32(ColSums0L_pada, ColSums0L_pada, 1); + int32x4_t ColSums0L_add = vaddq_s32(ColSums0L_pada, ColSums0L_ext); + int32x2_t ColSums0L = {vdups_laneq_s32(ColSums0L_add, 0), vdups_laneq_s32(ColSums0L_add, 2)}; + + int32x4_t ColSums0H_pada = vmovq_n_s32(0); + ColSums0H_pada = vpadalq_s16(ColSums0H_pada, vpaddlq_s8(vreinterpretq_s8_s32(zd3.val[1]))); + int32x4_t ColSums0H_ext = vextq_s32(ColSums0H_pada, ColSums0H_pada, 1); + int32x4_t ColSums0H_add = vaddq_s32(ColSums0H_pada, ColSums0H_ext); + int32x2_t ColSums0H = {vdups_laneq_s32(ColSums0H_add, 0), vdups_laneq_s32(ColSums0H_add, 2)}; + + ColumnSums[0] = vaddq_s32(ColumnSums[0], vcombine_s32(ColSums0L, ColSums0H)); + + int32x4_t ColSums1L_pada = vmovq_n_s32(0); + ColSums1L_pada = vpadalq_s16(ColSums1L_pada, vpaddlq_s8(vreinterpretq_s8_s32(zd4.val[0]))); + int32x4_t ColSums1L_ext = vextq_s32(ColSums1L_pada, ColSums1L_pada, 1); + int32x4_t ColSums1L_add = vaddq_s32(ColSums1L_pada, ColSums1L_ext); + int32x2_t ColSums1L = {vdups_laneq_s32(ColSums1L_add, 0), vdups_laneq_s32(ColSums1L_add, 2)}; + + int32x4_t ColSums1H_pada = vmovq_n_s32(0); + ColSums1H_pada = vpadalq_s16(ColSums1H_pada, vpaddlq_s8(vreinterpretq_s8_s32(zd4.val[1]))); + int32x4_t ColSums1H_ext = vextq_s32(ColSums1H_pada, ColSums1H_pada, 1); + int32x4_t ColSums1H_add = vaddq_s32(ColSums1H_pada, ColSums1H_ext); + int32x2_t ColSums1H = {vdups_laneq_s32(ColSums1H_add, 0), vdups_laneq_s32(ColSums1H_add, 2)}; + + ColumnSums[1] = vaddq_s32(ColumnSums[1], vcombine_s32(ColSums1L, ColSums1H)); +} + +template <> +void +MlasGemmQuantCopyPackB(MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedBType* Dst, + const uint8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumBuffer, + bool BIsSigned) +{ + MLAS_UNREFERENCED_PARAMETER(BIsSigned); + int8_t* D = reinterpret_cast(Dst); + const int8x16_t ZeroVector = vmovq_n_s8(0); + int8x8_t BytesRow[8]; + + // + // Copy data from matrix B into the destination buffer 8x2 blocks at a + // time. + // + // + while (CountN >= 8) { + const int8_t* b = reinterpret_cast(B); + size_t k = CountK; + int32x4_t ColumnSums[2]; + + ColumnSums[0] = vmovq_n_s32(0); + ColumnSums[1] = vmovq_n_s32(0); + + while (k >= 8) { + BytesRow[0] = vld1_s8(&b[ldb * 0]); + BytesRow[1] = vld1_s8(&b[ldb * 1]); + BytesRow[2] = vld1_s8(&b[ldb * 2]); + BytesRow[3] = vld1_s8(&b[ldb * 3]); + BytesRow[4] = vld1_s8(&b[ldb * 4]); + BytesRow[5] = vld1_s8(&b[ldb * 5]); + BytesRow[6] = vld1_s8(&b[ldb * 6]); + BytesRow[7] = vld1_s8(&b[ldb * 7]); + + MlasGemmS8S8CopyPackBProcessSmmla(D, BytesRow, ColumnSums); + + D += 64; + b += ldb * 8; + k -= 8; + } + + if (k > 0) { + // Pad k to 8 + + BytesRow[0] = vld1_s8(&b[ldb * 0]); + BytesRow[1] = (k >= 2) ? vld1_s8(&b[ldb * 1]) : vget_low_s8(ZeroVector); + BytesRow[2] = (k >= 3) ? vld1_s8(&b[ldb * 2]) : vget_low_s8(ZeroVector); + BytesRow[3] = (k >= 4) ? vld1_s8(&b[ldb * 3]) : vget_low_s8(ZeroVector); + BytesRow[4] = (k >= 5) ? vld1_s8(&b[ldb * 4]) : vget_low_s8(ZeroVector); + BytesRow[5] = (k >= 6) ? vld1_s8(&b[ldb * 5]) : vget_low_s8(ZeroVector); + BytesRow[6] = (k >= 7) ? vld1_s8(&b[ldb * 6]) : vget_low_s8(ZeroVector); + BytesRow[7] = vget_low_s8(ZeroVector); + + MlasGemmS8S8CopyPackBProcessSmmla(D, BytesRow, ColumnSums); + + D += 64; + } + + // Zero pad the output buffer to a multiple of PackedK if the above + // processed an odd number of four row bundles. + // + vst1q_s32(&ColumnSumBuffer[0], ColumnSums[0]); + vst1q_s32(&ColumnSumBuffer[4], ColumnSums[1]); + + ColumnSumBuffer += 8; + + B += 8; + CountN -= 8; + } + + // + // Process the remaining columns of matrix B. + // + + if (CountN > 0) { + const int8_t* b = reinterpret_cast(B); + size_t k = CountK; + int8_t PaddedMatrixBData[64]; + int32x4_t ColumnSums[2]; + + vst1q_s8(&PaddedMatrixBData[0], ZeroVector); + vst1q_s8(&PaddedMatrixBData[16], ZeroVector); + vst1q_s8(&PaddedMatrixBData[32], ZeroVector); + vst1q_s8(&PaddedMatrixBData[48], ZeroVector); + + ColumnSums[0] = vmovq_n_s32(0); + ColumnSums[1] = vmovq_n_s32(0); + + // + // Interleave rows of matrix B using an intermediate zero padded stack + // buffer and write to the packed buffer. + // + + while (k > 0) { + const int8_t* bcopy0 = &b[ldb * 0]; + const int8_t* bcopy1 = &b[ldb * 1]; + const int8_t* bcopy2 = &b[ldb * 2]; + const int8_t* bcopy3 = &b[ldb * 3]; + const int8_t* bcopy4 = &b[ldb * 4]; + const int8_t* bcopy5 = &b[ldb * 5]; + const int8_t* bcopy6 = &b[ldb * 6]; + const int8_t* bcopy7 = &b[ldb * 7]; + + if (k >= 8) { + b += ldb * 8; + k -= 8; + + } else { + vst1q_s8(&PaddedMatrixBData[0], ZeroVector); + vst1q_s8(&PaddedMatrixBData[16], ZeroVector); + vst1q_s8(&PaddedMatrixBData[32], ZeroVector); + vst1q_s8(&PaddedMatrixBData[48], ZeroVector); + + bcopy1 = (k >= 2) ? bcopy1 : &PaddedMatrixBData[56]; + bcopy2 = (k >= 3) ? bcopy2 : &PaddedMatrixBData[56]; + bcopy3 = (k >= 4) ? bcopy3 : &PaddedMatrixBData[56]; + bcopy4 = (k >= 5) ? bcopy4 : &PaddedMatrixBData[56]; + bcopy5 = (k >= 6) ? bcopy5 : &PaddedMatrixBData[56]; + bcopy6 = (k >= 7) ? bcopy6 : &PaddedMatrixBData[56]; + bcopy7 = &PaddedMatrixBData[56]; + + k = 0; + } + + int8_t* padded = PaddedMatrixBData; + int8_t* padded_end = padded + CountN; + do { + padded[0] = *bcopy0++; + padded[8] = *bcopy1++; + padded[16] = *bcopy2++; + padded[24] = *bcopy3++; + padded[32] = *bcopy4++; + padded[40] = *bcopy5++; + padded[48] = *bcopy6++; + padded[56] = *bcopy7++; + + } while (++padded < padded_end); + + BytesRow[0] = vld1_s8(&PaddedMatrixBData[0]); + BytesRow[1] = vld1_s8(&PaddedMatrixBData[8]); + BytesRow[2] = vld1_s8(&PaddedMatrixBData[16]); + BytesRow[3] = vld1_s8(&PaddedMatrixBData[24]); + BytesRow[4] = vld1_s8(&PaddedMatrixBData[32]); + BytesRow[5] = vld1_s8(&PaddedMatrixBData[40]); + BytesRow[6] = vld1_s8(&PaddedMatrixBData[48]); + BytesRow[7] = vld1_s8(&PaddedMatrixBData[56]); + + MlasGemmS8S8CopyPackBProcessSmmla(D, BytesRow, ColumnSums); + + D += 64; + } + + vst1q_s32(&ColumnSumBuffer[0], ColumnSums[0]); + vst1q_s32(&ColumnSumBuffer[4], ColumnSums[1]); + } +} + +template <> +MLAS_FORCEINLINE size_t +MlasGemmQuantKernel(const MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedAType* A, + const MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode) +{ + size_t RowsHandled; + + if (ZeroMode) { + RowsHandled = MlasGemmS8S8KernelSmmlaZero(A, B, C, PackedCountK, CountM, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB); + } else { + RowsHandled = MlasGemmS8S8KernelSmmlaAdd(A, B, C, PackedCountK, CountM, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB); + } + + return RowsHandled; +} + +const MLAS_GEMM_QUANT_DISPATCH MlasGemmS8S8DispatchSmmla = { + MlasGemmQuantOperation, + MlasGemmQuantPackedOperation, + MlasGemmQuantCopyPackB, + MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedK, + MLAS_GEMM_S8S8_KERNEL_SMMLA::PackedStrides.K, + 8}; diff --git a/onnxruntime/core/mlas/lib/qgemm_kernel_ummla.cpp b/onnxruntime/core/mlas/lib/qgemm_kernel_ummla.cpp new file mode 100644 index 0000000000000..3936154432ac7 --- /dev/null +++ b/onnxruntime/core/mlas/lib/qgemm_kernel_ummla.cpp @@ -0,0 +1,967 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + qgemm_kernel_ummla.cpp + +Abstract: + + This module implements ummla QGEMM kernel. + +--*/ + +#include "mlasi.h" +#include "qgemm.h" + +// +// Define the prototypes of the NEON UMMLA routines written in assembly. +// + +extern "C" { + +size_t MLASCALL +MlasGemmU8X8KernelUmmlaZero(const uint8_t* A, + const uint8_t* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumVector, + const int32_t* ColumnSumVector, + const int32_t* ZeroPointB); + +size_t MLASCALL +MlasGemmU8X8KernelUmmlaAdd(const uint8_t* A, + const uint8_t* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumVector, + const int32_t* ColumnSumVector, + const int32_t* ZeroPointB); +} + +struct MLAS_GEMM_U8X8_KERNEL_UMMLA { + typedef uint8_t PackedAType; + typedef uint8_t PackedBType; + typedef uint8_t OffsetAType; + typedef uint8_t OffsetBType; + + static constexpr size_t PackedK = 8; + static constexpr MLAS_GEMM_QUANT_STRIDES Strides{24, 128, 256}; + static constexpr MLAS_GEMM_QUANT_STRIDES PackedStrides{24, 128, 384}; +}; + +constexpr size_t MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedK; +constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_U8X8_KERNEL_UMMLA::Strides; +constexpr MLAS_GEMM_QUANT_STRIDES MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedStrides; + +template <> +MLAS_FORCEINLINE int32_t +MlasGemmQuantFixupZeroPointB(int32_t ZeroPointB, bool BIsSigned) +{ + if (BIsSigned) { + ZeroPointB = MLAS_GEMM_U8X8_KERNEL_UMMLA::OffsetBType(ZeroPointB ^ 0x80); + } + + return ZeroPointB; +} + +template <> +void +MlasGemmQuantCopyPackA(MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedAType* D, + const uint8_t* A, + size_t lda, + size_t CountM, + size_t CountK, + int32_t* RowSumBuffer, + bool AIsSigned) +{ + MLAS_UNREFERENCED_PARAMETER(AIsSigned); + uint8_t PaddedMatrixAData[64]; + + // + // Process 8 rows of matrix A. + // + // MMLA kernels load 8x8 block of A with four vector registers. So A is packed + // a series of 64 byte vectors where eight rows are interleaved with the + // following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // [ B0 B1 B2 B3 B4 B5 B6 B7 ] + // [ C0 C1 C2 C3 C4 C5 C6 C7 ] + // [ D0 D1 D2 D3 D4 D5 D6 D7 ] + // [ E0 E1 E2 E3 E4 E5 E6 E7 ] + // [ F0 F1 F2 F3 F4 F5 F6 F7 ] + // [ G0 G1 G2 G3 G4 G5 G6 G7 ] + // [ H0 H1 H2 H3 H4 H5 H6 H7 ] + // + // ... + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of eight, then the vector is padded + // with zeroes. + // + + while (CountM >= 8) { + const uint8_t* a0 = A; + const uint8_t* a1 = a0 + lda; + const uint8_t* a2 = a0 + lda * 2; + const uint8_t* a3 = a0 + lda * 3; + const uint8_t* a4 = a0 + lda * 4; + const uint8_t* a5 = a0 + lda * 5; + const uint8_t* a6 = a0 + lda * 6; + const uint8_t* a7 = a0 + lda * 7; + + size_t k = CountK; + uint32x4_t RowSums0 = vmovq_n_u32(0); + uint32x4_t RowSums1 = vmovq_n_u32(0); + + while (k >= 16) { + uint64x2_t v0 = vld1q_u64(reinterpret_cast(a0)); + a0 += 16; + uint64x2_t v1 = vld1q_u64(reinterpret_cast(a1)); + a1 += 16; + uint64x2_t v2 = vld1q_u64(reinterpret_cast(a2)); + a2 += 16; + uint64x2_t v3 = vld1q_u64(reinterpret_cast(a3)); + a3 += 16; + uint64x2_t v4 = vld1q_u64(reinterpret_cast(a4)); + a4 += 16; + uint64x2_t v5 = vld1q_u64(reinterpret_cast(a5)); + a5 += 16; + uint64x2_t v6 = vld1q_u64(reinterpret_cast(a6)); + a6 += 16; + uint64x2_t v7 = vld1q_u64(reinterpret_cast(a7)); + a7 += 16; + + uint64x2_t z0 = vzip1q_u64(v0, v1); + uint64x2_t z1 = vzip2q_u64(v0, v1); + uint64x2_t z2 = vzip1q_u64(v2, v3); + uint64x2_t z3 = vzip2q_u64(v2, v3); + + uint64x2_t z4 = vzip1q_u64(v4, v5); + uint64x2_t z5 = vzip2q_u64(v4, v5); + uint64x2_t z6 = vzip1q_u64(v6, v7); + uint64x2_t z7 = vzip2q_u64(v6, v7); + + vst1q_u8(&D[0], vreinterpretq_u8_u64(z0)); + vst1q_u8(&D[16], vreinterpretq_u8_u64(z2)); + vst1q_u8(&D[32], vreinterpretq_u8_u64(z4)); + vst1q_u8(&D[48], vreinterpretq_u8_u64(z6)); + vst1q_u8(&D[64], vreinterpretq_u8_u64(z1)); + vst1q_u8(&D[80], vreinterpretq_u8_u64(z3)); + vst1q_u8(&D[96], vreinterpretq_u8_u64(z5)); + vst1q_u8(&D[112], vreinterpretq_u8_u64(z7)); + + uint32x4_t RowSums0L_pada = vmovq_n_u32(0); + RowSums0L_pada = vpadalq_u16(RowSums0L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z0))); + RowSums0L_pada = vpadalq_u16(RowSums0L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z1))); + + uint32x4_t RowSums0L_ext = vextq_u32(RowSums0L_pada, RowSums0L_pada, 1); + uint32x4_t RowSums0L_add = vaddq_u32(RowSums0L_pada, RowSums0L_ext); + uint32x2_t RowSums0L = {vdups_laneq_u32(RowSums0L_add, 0), + vdups_laneq_u32(RowSums0L_add, 2)}; + + uint32x4_t RowSums0H_pada = vmovq_n_u32(0); + RowSums0H_pada = vpadalq_u16(RowSums0H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z2))); + RowSums0H_pada = vpadalq_u16(RowSums0H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z3))); + + uint32x4_t RowSums0H_ext = vextq_u32(RowSums0H_pada, RowSums0H_pada, 1); + uint32x4_t RowSums0H_add = vaddq_u32(RowSums0H_pada, RowSums0H_ext); + uint32x2_t RowSums0H = {vdups_laneq_u32(RowSums0H_add, 0), + vdups_laneq_u32(RowSums0H_add, 2)}; + + RowSums0 = vaddq_u32(RowSums0, vcombine_u32(RowSums0L, RowSums0H)); + + uint32x4_t RowSums1L_pada = vmovq_n_u32(0); + RowSums1L_pada = vpadalq_u16(RowSums1L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z4))); + RowSums1L_pada = vpadalq_u16(RowSums1L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z5))); + + uint32x4_t RowSums1L_ext = vextq_u32(RowSums1L_pada, RowSums1L_pada, 1); + uint32x4_t RowSums1L_add = vaddq_u32(RowSums1L_pada, RowSums1L_ext); + uint32x2_t RowSums1L = {vdups_laneq_u32(RowSums1L_add, 0), + vdups_laneq_u32(RowSums1L_add, 2)}; + + uint32x4_t RowSums1H_pada = vmovq_n_u32(0); + RowSums1H_pada = vpadalq_u16(RowSums1H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z6))); + RowSums1H_pada = vpadalq_u16(RowSums1H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z7))); + + uint32x4_t RowSums1H_ext = vextq_u32(RowSums1H_pada, RowSums1H_pada, 1); + uint32x4_t RowSums1H_add = vaddq_u32(RowSums1H_pada, RowSums1H_ext); + uint32x2_t RowSums1H = {vdups_laneq_u32(RowSums1H_add, 0), + vdups_laneq_u32(RowSums1H_add, 2)}; + + RowSums1 = vaddq_u32(RowSums1, vcombine_u32(RowSums1L, RowSums1H)); + + D += 128; + k -= 16; + } + + while (k >= 8) { + uint64x1_t v0 = *reinterpret_cast(a0); + a0 += 8; + uint64x1_t v1 = *reinterpret_cast(a1); + a1 += 8; + uint64x1_t v2 = *reinterpret_cast(a2); + a2 += 8; + uint64x1_t v3 = *reinterpret_cast(a3); + a3 += 8; + uint64x1_t v4 = *reinterpret_cast(a4); + a4 += 8; + uint64x1_t v5 = *reinterpret_cast(a5); + a5 += 8; + uint64x1_t v6 = *reinterpret_cast(a6); + a6 += 8; + uint64x1_t v7 = *reinterpret_cast(a7); + a7 += 8; + + *reinterpret_cast(&D[0]) = v0; + *reinterpret_cast(&D[8]) = v1; + *reinterpret_cast(&D[16]) = v2; + *reinterpret_cast(&D[24]) = v3; + *reinterpret_cast(&D[32]) = v4; + *reinterpret_cast(&D[40]) = v5; + *reinterpret_cast(&D[48]) = v6; + *reinterpret_cast(&D[56]) = v7; + + uint64x2_t z01 = vcombine_u64(v0, v1); + uint64x2_t z23 = vcombine_u64(v2, v3); + uint64x2_t z45 = vcombine_u64(v4, v5); + uint64x2_t z67 = vcombine_u64(v6, v7); + + uint32x4_t RowSums0L_pada = vmovq_n_u32(0); + RowSums0L_pada = vpadalq_u16(RowSums0L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z01))); + + uint32x4_t RowSums0L_ext = vextq_u32(RowSums0L_pada, RowSums0L_pada, 1); + uint32x4_t RowSums0L_add = vaddq_u32(RowSums0L_pada, RowSums0L_ext); + uint32x2_t RowSums0L = {vdups_laneq_u32(RowSums0L_add, 0), + vdups_laneq_u32(RowSums0L_add, 2)}; + + uint32x4_t RowSums0H_pada = vmovq_n_u32(0); + RowSums0H_pada = vpadalq_u16(RowSums0H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z23))); + + uint32x4_t RowSums0H_ext = vextq_u32(RowSums0H_pada, RowSums0H_pada, 1); + uint32x4_t RowSums0H_add = vaddq_u32(RowSums0H_pada, RowSums0H_ext); + uint32x2_t RowSums0H = {vdups_laneq_u32(RowSums0H_add, 0), + vdups_laneq_u32(RowSums0H_add, 2)}; + + RowSums0 = vaddq_u32(RowSums0, vcombine_u32(RowSums0L, RowSums0H)); + + uint32x4_t RowSums1L_pada = vmovq_n_u32(0); + RowSums1L_pada = vpadalq_u16(RowSums1L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z45))); + + uint32x4_t RowSums1L_ext = vextq_u32(RowSums1L_pada, RowSums1L_pada, 1); + uint32x4_t RowSums1L_add = vaddq_u32(RowSums1L_pada, RowSums1L_ext); + uint32x2_t RowSums1L = {vdups_laneq_u32(RowSums1L_add, 0), + vdups_laneq_u32(RowSums1L_add, 2)}; + + uint32x4_t RowSums1H_pada = vmovq_n_u32(0); + RowSums1H_pada = vpadalq_u16(RowSums1H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z67))); + + uint32x4_t RowSums1H_ext = vextq_u32(RowSums1H_pada, RowSums1H_pada, 1); + uint32x4_t RowSums1H_add = vaddq_u32(RowSums1H_pada, RowSums1H_ext); + uint32x2_t RowSums1H = {vdups_laneq_u32(RowSums1H_add, 0), + vdups_laneq_u32(RowSums1H_add, 2)}; + + RowSums1 = vaddq_u32(RowSums1, vcombine_u32(RowSums1L, RowSums1H)); + + D += 64; + k -= 8; + } + + if (k > 0) { + // + // zero pad the remaining columns to 8 + // + uint8_t* d = D; + + vst1q_u8(d, vmovq_n_u8(0)); + vst1q_u8(&d[16], vmovq_n_u8(0)); + vst1q_u8(&d[32], vmovq_n_u8(0)); + vst1q_u8(&d[48], vmovq_n_u8(0)); + + while (k > 0) { + d[0] = *a0++; + d[8] = *a1++; + d[16] = *a2++; + d[24] = *a3++; + d[32] = *a4++; + d[40] = *a5++; + d[48] = *a6++; + d[56] = *a7++; + d += 1; + k -= 1; + } + d = D; + uint64x1_t v0 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v1 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v2 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v3 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v4 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v5 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v6 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v7 = *reinterpret_cast(d); + d = d + 8; + + uint64x2_t z01 = vcombine_u64(v0, v1); + uint64x2_t z23 = vcombine_u64(v2, v3); + uint64x2_t z45 = vcombine_u64(v4, v5); + uint64x2_t z67 = vcombine_u64(v6, v7); + + uint32x4_t RowSums0L_pada = vmovq_n_u32(0); + RowSums0L_pada = vpadalq_u16(RowSums0L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z01))); + + uint32x4_t RowSums0L_ext = vextq_u32(RowSums0L_pada, RowSums0L_pada, 1); + uint32x4_t RowSums0L_add = vaddq_u32(RowSums0L_pada, RowSums0L_ext); + uint32x2_t RowSums0L = {vdups_laneq_u32(RowSums0L_add, 0), + vdups_laneq_u32(RowSums0L_add, 2)}; + + uint32x4_t RowSums0H_pada = vmovq_n_u32(0); + RowSums0H_pada = vpadalq_u16(RowSums0H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z23))); + + uint32x4_t RowSums0H_ext = vextq_u32(RowSums0H_pada, RowSums0H_pada, 1); + uint32x4_t RowSums0H_add = vaddq_u32(RowSums0H_pada, RowSums0H_ext); + uint32x2_t RowSums0H = {vdups_laneq_u32(RowSums0H_add, 0), + vdups_laneq_u32(RowSums0H_add, 2)}; + + RowSums0 = vaddq_u32(RowSums0, vcombine_u32(RowSums0L, RowSums0H)); + + uint32x4_t RowSums1L_pada = vmovq_n_u32(0); + RowSums1L_pada = vpadalq_u16(RowSums1L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z45))); + + uint32x4_t RowSums1L_ext = vextq_u32(RowSums1L_pada, RowSums1L_pada, 1); + uint32x4_t RowSums1L_add = vaddq_u32(RowSums1L_pada, RowSums1L_ext); + uint32x2_t RowSums1L = {vdups_laneq_u32(RowSums1L_add, 0), + vdups_laneq_u32(RowSums1L_add, 2)}; + + uint32x4_t RowSums1H_pada = vmovq_n_u32(0); + RowSums1H_pada = vpadalq_u16(RowSums1H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z67))); + + uint32x4_t RowSums1H_ext = vextq_u32(RowSums1H_pada, RowSums1H_pada, 1); + uint32x4_t RowSums1H_add = vaddq_u32(RowSums1H_pada, RowSums1H_ext); + uint32x2_t RowSums1H = {vdups_laneq_u32(RowSums1H_add, 0), + vdups_laneq_u32(RowSums1H_add, 2)}; + + RowSums1 = vaddq_u32(RowSums1, vcombine_u32(RowSums1L, RowSums1H)); + + D += 64; + } + + vst1q_s32(RowSumBuffer, vreinterpretq_s32_u32(RowSums0)); + vst1q_s32(&RowSumBuffer[4], vreinterpretq_s32_u32(RowSums1)); + + RowSumBuffer += 8; + + A = A + lda * 8; + CountM -= 8; + } + + // + // Process four rows of matrix A. + // + // The buffer is packed as a series of 32 byte vectors where four rows are + // interleaved with the following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // [ B0 B1 B2 B3 B4 B5 B6 B7 ] + // [ C0 C1 C2 C3 C4 C5 C6 C7 ] + // [ D0 D1 D2 D3 D4 D5 D6 D7 ] + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of eight, then the vector is padded + // with zeroes. + // + + if (CountM >= 4) { + const uint8_t* a0 = A; + const uint8_t* a1 = a0 + lda; + const uint8_t* a2 = a1 + lda; + const uint8_t* a3 = a2 + lda; + + size_t k = CountK; + uint32x4_t RowSums = vmovq_n_u32(0); + + while (k >= 16) { + uint64x2_t v0 = vld1q_u64(reinterpret_cast(a0)); + a0 += 16; + uint64x2_t v1 = vld1q_u64(reinterpret_cast(a1)); + a1 += 16; + uint64x2_t v2 = vld1q_u64(reinterpret_cast(a2)); + a2 += 16; + uint64x2_t v3 = vld1q_u64(reinterpret_cast(a3)); + a3 += 16; + + uint64x2_t z0 = vzip1q_u64(v0, v1); + uint64x2_t z1 = vzip2q_u64(v0, v1); + uint64x2_t z2 = vzip1q_u64(v2, v3); + uint64x2_t z3 = vzip2q_u64(v2, v3); + + vst1q_u8(&D[0], vreinterpretq_u8_u64(z0)); + vst1q_u8(&D[16], vreinterpretq_u8_u64(z2)); + vst1q_u8(&D[32], vreinterpretq_u8_u64(z1)); + vst1q_u8(&D[48], vreinterpretq_u8_u64(z3)); + + uint32x4_t RowSumsL_pada = vmovq_n_u32(0); + RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z0))); + RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z1))); + + uint32x4_t RowSumsL_ext = vextq_u32(RowSumsL_pada, RowSumsL_pada, 1); + uint32x4_t RowSumsL_add = vaddq_u32(RowSumsL_pada, RowSumsL_ext); + uint32x2_t RowSumsL = {vdups_laneq_u32(RowSumsL_add, 0), + vdups_laneq_u32(RowSumsL_add, 2)}; + + uint32x4_t RowSumsH_pada = vmovq_n_u32(0); + RowSumsH_pada = vpadalq_u16(RowSumsH_pada, vpaddlq_u8(vreinterpretq_u8_u64(z2))); + RowSumsH_pada = vpadalq_u16(RowSumsH_pada, vpaddlq_u8(vreinterpretq_u8_u64(z3))); + + uint32x4_t RowSumsH_ext = vextq_u32(RowSumsH_pada, RowSumsH_pada, 1); + uint32x4_t RowSumsH_add = vaddq_u32(RowSumsH_pada, RowSumsH_ext); + uint32x2_t RowSumsH = {vdups_laneq_u32(RowSumsH_add, 0), + vdups_laneq_u32(RowSumsH_add, 2)}; + + RowSums = vaddq_u32(RowSums, vcombine_u32(RowSumsL, RowSumsH)); + + D += 64; + k -= 16; + } + + while (k >= 8) { + uint64x1_t v0 = *reinterpret_cast(a0); + a0 += 8; + uint64x1_t v1 = *reinterpret_cast(a1); + a1 += 8; + uint64x1_t v2 = *reinterpret_cast(a2); + a2 += 8; + uint64x1_t v3 = *reinterpret_cast(a3); + a3 += 8; + + *reinterpret_cast(&D[0]) = v0; + *reinterpret_cast(&D[8]) = v1; + *reinterpret_cast(&D[16]) = v2; + *reinterpret_cast(&D[24]) = v3; + + uint64x2_t z01 = vcombine_u64(v0, v1); + uint64x2_t z23 = vcombine_u64(v2, v3); + + uint32x4_t RowSumsL_pada = vmovq_n_u32(0); + RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z01))); + + uint32x4_t RowSumsL_ext = vextq_u32(RowSumsL_pada, RowSumsL_pada, 1); + uint32x4_t RowSumsL_add = vaddq_u32(RowSumsL_pada, RowSumsL_ext); + uint32x2_t RowSumsL = {vdups_laneq_u32(RowSumsL_add, 0), + vdups_laneq_u32(RowSumsL_add, 2)}; + + uint32x4_t RowSumsH_pada = vmovq_n_u32(0); + RowSumsH_pada = vpadalq_u16(RowSumsH_pada, vpaddlq_u8(vreinterpretq_u8_u64(z23))); + + uint32x4_t RowSumsH_ext = vextq_u32(RowSumsH_pada, RowSumsH_pada, 1); + uint32x4_t RowSumsH_add = vaddq_u32(RowSumsH_pada, RowSumsH_ext); + uint32x2_t RowSumsH = {vdups_laneq_u32(RowSumsH_add, 0), + vdups_laneq_u32(RowSumsH_add, 2)}; + + RowSums = vaddq_u32(RowSums, vcombine_u32(RowSumsL, RowSumsH)); + + D += 32; + k -= 8; + } + + if (k > 0) { + // + // Copy the remaining bytes with zero padding. + // + uint8_t* d = D; + + vst1q_u8(d, vmovq_n_u8(0)); + vst1q_u8(&d[16], vmovq_n_u8(0)); + + while (k > 0) { + d[0] = *a0++; + d[8] = *a1++; + d[16] = *a2++; + d[24] = *a3++; + d += 1; + k -= 1; + } + + d = D; + uint64x1_t v0 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v1 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v2 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v3 = *reinterpret_cast(d); + d = d + 8; + + uint64x2_t z01 = vcombine_u64(v0, v1); + uint64x2_t z23 = vcombine_u64(v2, v3); + + uint32x4_t RowSums0L_pada = vmovq_n_u32(0); + RowSums0L_pada = vpadalq_u16(RowSums0L_pada, vpaddlq_u8(vreinterpretq_u8_u64(z01))); + + uint32x4_t RowSums0L_ext = vextq_u32(RowSums0L_pada, RowSums0L_pada, 1); + uint32x4_t RowSums0L_add = vaddq_u32(RowSums0L_pada, RowSums0L_ext); + uint32x2_t RowSums0L = {vdups_laneq_u32(RowSums0L_add, 0), + vdups_laneq_u32(RowSums0L_add, 2)}; + + uint32x4_t RowSums0H_pada = vmovq_n_u32(0); + RowSums0H_pada = vpadalq_u16(RowSums0H_pada, vpaddlq_u8(vreinterpretq_u8_u64(z23))); + + uint32x4_t RowSums0H_ext = vextq_u32(RowSums0H_pada, RowSums0H_pada, 1); + uint32x4_t RowSums0H_add = vaddq_u32(RowSums0H_pada, RowSums0H_ext); + uint32x2_t RowSums0H = {vdups_laneq_u32(RowSums0H_add, 0), + vdups_laneq_u32(RowSums0H_add, 2)}; + + RowSums = vaddq_u32(RowSums, vcombine_u32(RowSums0L, RowSums0H)); + + D += 32; + } + + vst1q_s32(RowSumBuffer, vreinterpretq_s32_u32(RowSums)); + RowSumBuffer += 4; + + A = A + lda * 4; + CountM -= 4; + } + + // + // Process two rows of matrix A. + // + // The buffer is packed as a series of 16 byte vectors where two rows are + // interleaved with the following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // [ B0 B1 B2 B3 B4 B5 B6 B7 ] + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of eight, then the vector is padded + // with zeroes. + // + + if (CountM >= 2) { + const uint8_t* a0 = A; + const uint8_t* a1 = a0 + lda; + + size_t k = CountK; + uint32x2_t RowSums = vmov_n_u32(0); + + while (k >= 16) { + uint64x2_t v0 = vld1q_u64(reinterpret_cast(a0)); + a0 += 16; + uint64x2_t v1 = vld1q_u64(reinterpret_cast(a1)); + a1 += 16; + + uint64x2_t z0 = vzip1q_u64(v0, v1); + uint64x2_t z1 = vzip2q_u64(v0, v1); + + vst1q_u8(&D[0], vreinterpretq_u8_u64(z0)); + vst1q_u8(&D[16], vreinterpretq_u8_u64(z1)); + + uint32x4_t RowSumsL_pada = vmovq_n_u32(0); + RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z0))); + RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z1))); + + uint32x4_t RowSumsL_ext = vextq_u32(RowSumsL_pada, RowSumsL_pada, 1); + uint32x4_t RowSumsL_add = vaddq_u32(RowSumsL_pada, RowSumsL_ext); + uint32x2_t RowSumsL = {vdups_laneq_u32(RowSumsL_add, 0), + vdups_laneq_u32(RowSumsL_add, 2)}; + + RowSums = vadd_u32(RowSums, RowSumsL); + + D += 32; + k -= 16; + } + + while (k >= 8) { + uint64x1_t v0 = *reinterpret_cast(a0); + a0 += 8; + uint64x1_t v1 = *reinterpret_cast(a1); + a1 += 8; + + *reinterpret_cast(&D[0]) = v0; + *reinterpret_cast(&D[8]) = v1; + + uint64x2_t z01 = vcombine_u64(v0, v1); + uint32x4_t RowSumsL_pada = vmovq_n_u32(0); + RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z01))); + + uint32x4_t RowSumsL_ext = vextq_u32(RowSumsL_pada, RowSumsL_pada, 1); + uint32x4_t RowSumsL_add = vaddq_u32(RowSumsL_pada, RowSumsL_ext); + uint32x2_t RowSumsL = {vdups_laneq_u32(RowSumsL_add, 0), + vdups_laneq_u32(RowSumsL_add, 2)}; + + RowSums = vadd_u32(RowSums, RowSumsL); + + D += 16; + k -= 8; + } + + if (k > 0) { + // + // Zero pad the remaining elements to make 8 columns. + // + + uint8_t* d = PaddedMatrixAData; + vst1q_u8(PaddedMatrixAData, vmovq_n_u8(0)); + + while (k > 0) { + d[0] = *a0++; + d[8] = *a1++; + + d += 1; + k -= 1; + } + + d = PaddedMatrixAData; + uint64x1_t v0 = *reinterpret_cast(d); + d = d + 8; + uint64x1_t v1 = *reinterpret_cast(d); + d = d + 8; + + uint64x2_t z01 = vcombine_u64(v0, v1); + uint32x4_t RowSumsL_pada = vmovq_n_u32(0); + RowSumsL_pada = vpadalq_u16(RowSumsL_pada, vpaddlq_u8(vreinterpretq_u8_u64(z01))); + + uint32x4_t RowSumsL_ext = vextq_u32(RowSumsL_pada, RowSumsL_pada, 1); + uint32x4_t RowSumsL_add = vaddq_u32(RowSumsL_pada, RowSumsL_ext); + uint32x2_t RowSumsL = {vdups_laneq_u32(RowSumsL_add, 0), + vdups_laneq_u32(RowSumsL_add, 2)}; + + RowSums = vadd_u32(RowSums, RowSumsL); + + uint8x16_t PackedVector = vld1q_u8(PaddedMatrixAData); + vst1q_u8(D, PackedVector); + + D += 16; + } + + vst1_s32(RowSumBuffer, vreinterpret_s32_u32(RowSums)); + RowSumBuffer += 2; + + A = A + lda * 2; + CountM -= 2; + } + + // + // Process one row of matrix A. + // + // The buffer is packed as a series of 8 byte with the following pattern: + // + // [ A0 A1 A2 A3 A4 A5 A6 A7 ] + // + // This pattern is repeated (CountK / 8) times. + // + // If CountK is not aligned to a multiple of 8, then the vector is padded + // with zeroes. + // + + if (CountM > 0) { + // No need to pad the rows to 2, the .S takes care of zero pdding + const uint8_t* a = A; + size_t k = CountK; + uint32x4_t RowSums = vmovq_n_u32(0); + + while (k >= 16) { + uint8x16_t v = vld1q_u8(a); + a += 16; + + vst1q_u8(D, v); + + RowSums = vpadalq_u16(RowSums, vpaddlq_u8(v)); + + D += 16; + k -= 16; + } + + if (k > 0) { + // + // Copy the remaining bytes to the zero padded stack buffer. + // + + vst1q_u8(PaddedMatrixAData, vmovq_n_u8(0)); + + for (size_t kk = 0; kk < k; kk++) { + PaddedMatrixAData[kk] = a[kk]; + } + + uint8x16_t v = vld1q_u8(PaddedMatrixAData); + vst1q_u8(D, v); + + RowSums = vpadalq_u16(RowSums, vpaddlq_u8(v)); + } + + *RowSumBuffer = int32_t(vaddvq_u32(RowSums)); + } +} + +MLAS_FORCEINLINE +void +MlasGemmU8X8CopyPackBProcessUmmla(MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedBType* D, + uint8x8_t BytesRow[8], + uint8x16_t BitFlipVector, + uint32x4_t ColumnSums[2]) +{ + uint8x16_t v02 = veorq_u8(vcombine_u8(BytesRow[0], BytesRow[2]), BitFlipVector); + uint8x16_t v13 = veorq_u8(vcombine_u8(BytesRow[1], BytesRow[3]), BitFlipVector); + + uint8x16_t v46 = veorq_u8(vcombine_u8(BytesRow[4], BytesRow[6]), BitFlipVector); + uint8x16_t v57 = veorq_u8(vcombine_u8(BytesRow[5], BytesRow[7]), BitFlipVector); + + uint8x16x2_t zw1 = vzipq_u8(v02, v13); + uint16x8x2_t zd1 = + vzipq_u16(vreinterpretq_u16_u8(zw1.val[0]), vreinterpretq_u16_u8(zw1.val[1])); + + uint8x16x2_t zw2 = vzipq_u8(v46, v57); + uint16x8x2_t zd2 = + vzipq_u16(vreinterpretq_u16_u8(zw2.val[0]), vreinterpretq_u16_u8(zw2.val[1])); + + uint32x4x2_t zd3 = + vzipq_u32(vreinterpretq_u32_u16(zd1.val[0]), vreinterpretq_u32_u16(zd2.val[0])); + uint32x4x2_t zd4 = + vzipq_u32(vreinterpretq_u32_u16(zd1.val[1]), vreinterpretq_u32_u16(zd2.val[1])); + + vst1q_u8(&D[0], vreinterpretq_u8_u32(zd3.val[0])); + vst1q_u8(&D[16], vreinterpretq_u8_u32(zd3.val[1])); + vst1q_u8(&D[32], vreinterpretq_u8_u32(zd4.val[0])); + vst1q_u8(&D[48], vreinterpretq_u8_u32(zd4.val[1])); + + uint32x4_t ColSums0L_pada = vmovq_n_u32(0); + ColSums0L_pada = vpadalq_u16(ColSums0L_pada, vpaddlq_u8(vreinterpretq_u8_u32(zd3.val[0]))); + uint32x4_t ColSums0L_ext = vextq_u32(ColSums0L_pada, ColSums0L_pada, 1); + uint32x4_t ColSums0L_add = vaddq_u32(ColSums0L_pada, ColSums0L_ext); + uint32x2_t ColSums0L = {vdups_laneq_u32(ColSums0L_add, 0), vdups_laneq_u32(ColSums0L_add, 2)}; + + uint32x4_t ColSums0H_pada = vmovq_n_u32(0); + ColSums0H_pada = vpadalq_u16(ColSums0H_pada, vpaddlq_u8(vreinterpretq_u8_u32(zd3.val[1]))); + uint32x4_t ColSums0H_ext = vextq_u32(ColSums0H_pada, ColSums0H_pada, 1); + uint32x4_t ColSums0H_add = vaddq_u32(ColSums0H_pada, ColSums0H_ext); + uint32x2_t ColSums0H = {vdups_laneq_u32(ColSums0H_add, 0), vdups_laneq_u32(ColSums0H_add, 2)}; + + ColumnSums[0] = vaddq_u32(ColumnSums[0], vcombine_u32(ColSums0L, ColSums0H)); + + uint32x4_t ColSums1L_pada = vmovq_n_u32(0); + ColSums1L_pada = vpadalq_u16(ColSums1L_pada, vpaddlq_u8(vreinterpretq_u8_u32(zd4.val[0]))); + uint32x4_t ColSums1L_ext = vextq_u32(ColSums1L_pada, ColSums1L_pada, 1); + uint32x4_t ColSums1L_add = vaddq_u32(ColSums1L_pada, ColSums1L_ext); + uint32x2_t ColSums1L = {vdups_laneq_u32(ColSums1L_add, 0), vdups_laneq_u32(ColSums1L_add, 2)}; + + uint32x4_t ColSums1H_pada = vmovq_n_u32(0); + ColSums1H_pada = vpadalq_u16(ColSums1H_pada, vpaddlq_u8(vreinterpretq_u8_u32(zd4.val[1]))); + uint32x4_t ColSums1H_ext = vextq_u32(ColSums1H_pada, ColSums1H_pada, 1); + uint32x4_t ColSums1H_add = vaddq_u32(ColSums1H_pada, ColSums1H_ext); + uint32x2_t ColSums1H = {vdups_laneq_u32(ColSums1H_add, 0), vdups_laneq_u32(ColSums1H_add, 2)}; + + ColumnSums[1] = vaddq_u32(ColumnSums[1], vcombine_u32(ColSums1L, ColSums1H)); +} + +template <> +void +MlasGemmQuantCopyPackB(MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedBType* D, + const uint8_t* B, + size_t ldb, + size_t CountN, + size_t CountK, + int32_t* ColumnSumBuffer, + bool BIsSigned) +{ + const uint8x16_t BitFlipVector = vdupq_n_u8(BIsSigned ? 0x80 : 0); + uint8x8_t BytesRow[8]; + + // + // Copy data from matrix B into the destination buffer 8x2 blocks at a + // time. + // + // + while (CountN >= 8) { + const uint8_t* b = B; + size_t k = CountK; + uint32x4_t ColumnSums[2]; + ColumnSums[0] = vmovq_n_u32(0); + ColumnSums[1] = vmovq_n_u32(0); + + while (k >= 8) { + BytesRow[0] = vld1_u8(&b[ldb * 0]); + BytesRow[1] = vld1_u8(&b[ldb * 1]); + BytesRow[2] = vld1_u8(&b[ldb * 2]); + BytesRow[3] = vld1_u8(&b[ldb * 3]); + BytesRow[4] = vld1_u8(&b[ldb * 4]); + BytesRow[5] = vld1_u8(&b[ldb * 5]); + BytesRow[6] = vld1_u8(&b[ldb * 6]); + BytesRow[7] = vld1_u8(&b[ldb * 7]); + + MlasGemmU8X8CopyPackBProcessUmmla(D, BytesRow, BitFlipVector, ColumnSums); + + D += 64; + b += ldb * 8; + k -= 8; + } + + if (k > 0) { + // Pad k to 8 + + BytesRow[0] = vld1_u8(&b[ldb * 0]); + BytesRow[1] = (k >= 2) ? vld1_u8(&b[ldb * 1]) : vget_low_u8(BitFlipVector); + BytesRow[2] = (k >= 3) ? vld1_u8(&b[ldb * 2]) : vget_low_u8(BitFlipVector); + BytesRow[3] = (k >= 4) ? vld1_u8(&b[ldb * 3]) : vget_low_u8(BitFlipVector); + BytesRow[4] = (k >= 5) ? vld1_u8(&b[ldb * 4]) : vget_low_u8(BitFlipVector); + BytesRow[5] = (k >= 6) ? vld1_u8(&b[ldb * 5]) : vget_low_u8(BitFlipVector); + BytesRow[6] = (k >= 7) ? vld1_u8(&b[ldb * 6]) : vget_low_u8(BitFlipVector); + BytesRow[7] = vget_low_u8(BitFlipVector); + + MlasGemmU8X8CopyPackBProcessUmmla(D, BytesRow, BitFlipVector, ColumnSums); + + D += 64; + } + + // Zero pad the output buffer to a multiple of PackedK if the above + // processed an odd number of four row bundles. + // + vst1q_s32(&ColumnSumBuffer[0], vreinterpretq_s32_u32(ColumnSums[0])); + vst1q_s32(&ColumnSumBuffer[4], vreinterpretq_s32_u32(ColumnSums[1])); + + ColumnSumBuffer += 8; + + B += 8; + CountN -= 8; + } + + // + // Process the remaining columns of matrix B. + // + + if (CountN > 0) { + const uint8_t* b = B; + size_t k = CountK; + uint8_t PaddedMatrixBData[64]; + uint32x4_t ColumnSums[2]; + + vst1q_u8(&PaddedMatrixBData[0], BitFlipVector); + vst1q_u8(&PaddedMatrixBData[16], BitFlipVector); + vst1q_u8(&PaddedMatrixBData[32], BitFlipVector); + vst1q_u8(&PaddedMatrixBData[48], BitFlipVector); + + ColumnSums[0] = vmovq_n_u32(0); + ColumnSums[1] = vmovq_n_u32(0); + + // + // Interleave rows of matrix B using an intermediate zero padded stack + // buffer and write to the packed buffer. + // + + while (k > 0) { + const uint8_t* bcopy0 = &b[ldb * 0]; + const uint8_t* bcopy1 = &b[ldb * 1]; + const uint8_t* bcopy2 = &b[ldb * 2]; + const uint8_t* bcopy3 = &b[ldb * 3]; + const uint8_t* bcopy4 = &b[ldb * 4]; + const uint8_t* bcopy5 = &b[ldb * 5]; + const uint8_t* bcopy6 = &b[ldb * 6]; + const uint8_t* bcopy7 = &b[ldb * 7]; + + if (k >= 8) { + b += ldb * 8; + k -= 8; + + } else { + vst1q_u8(&PaddedMatrixBData[0], BitFlipVector); + vst1q_u8(&PaddedMatrixBData[16], BitFlipVector); + vst1q_u8(&PaddedMatrixBData[32], BitFlipVector); + vst1q_u8(&PaddedMatrixBData[48], BitFlipVector); + + bcopy1 = (k >= 2) ? bcopy1 : &PaddedMatrixBData[56]; + bcopy2 = (k >= 3) ? bcopy2 : &PaddedMatrixBData[56]; + bcopy3 = (k >= 4) ? bcopy3 : &PaddedMatrixBData[56]; + bcopy4 = (k >= 5) ? bcopy4 : &PaddedMatrixBData[56]; + bcopy5 = (k >= 6) ? bcopy5 : &PaddedMatrixBData[56]; + bcopy6 = (k >= 7) ? bcopy6 : &PaddedMatrixBData[56]; + bcopy7 = &PaddedMatrixBData[56]; + + k = 0; + } + + uint8_t* padded = PaddedMatrixBData; + uint8_t* padded_end = padded + CountN; + do { + padded[0] = *bcopy0++; + padded[8] = *bcopy1++; + padded[16] = *bcopy2++; + padded[24] = *bcopy3++; + padded[32] = *bcopy4++; + padded[40] = *bcopy5++; + padded[48] = *bcopy6++; + padded[56] = *bcopy7++; + + } while (++padded < padded_end); + + BytesRow[0] = vld1_u8(&PaddedMatrixBData[0]); + BytesRow[1] = vld1_u8(&PaddedMatrixBData[8]); + BytesRow[2] = vld1_u8(&PaddedMatrixBData[16]); + BytesRow[3] = vld1_u8(&PaddedMatrixBData[24]); + BytesRow[4] = vld1_u8(&PaddedMatrixBData[32]); + BytesRow[5] = vld1_u8(&PaddedMatrixBData[40]); + BytesRow[6] = vld1_u8(&PaddedMatrixBData[48]); + BytesRow[7] = vld1_u8(&PaddedMatrixBData[56]); + + MlasGemmU8X8CopyPackBProcessUmmla(D, BytesRow, BitFlipVector, ColumnSums); + + D += 64; + } + + vst1q_s32(&ColumnSumBuffer[0], vreinterpretq_s32_u32(ColumnSums[0])); + vst1q_s32(&ColumnSumBuffer[4], vreinterpretq_s32_u32(ColumnSums[1])); + } +} + +template <> +MLAS_FORCEINLINE size_t +MlasGemmQuantKernel(const MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedAType* A, + const MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedBType* B, + int32_t* C, + size_t PackedCountK, + size_t CountM, + size_t CountN, + size_t ldc, + const int32_t* RowSumBuffer, + const int32_t* ColumnSumBuffer, + const int32_t* ZeroPointB, + bool ZeroMode) +{ + size_t RowsHandled; + + if (ZeroMode) { + RowsHandled = MlasGemmU8X8KernelUmmlaZero(A, B, C, PackedCountK, CountM, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB); + } else { + RowsHandled = MlasGemmU8X8KernelUmmlaAdd(A, B, C, PackedCountK, CountM, CountN, ldc, + RowSumBuffer, ColumnSumBuffer, ZeroPointB); + } + + return RowsHandled; +} + +const MLAS_GEMM_QUANT_DISPATCH MlasGemmU8X8DispatchUmmla = { + MlasGemmQuantOperation, + MlasGemmQuantPackedOperation, + MlasGemmQuantCopyPackB, + MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedK, + MLAS_GEMM_U8X8_KERNEL_UMMLA::PackedStrides.K, + 8}; diff --git a/onnxruntime/core/optimizer/graph_transformer_utils.cc b/onnxruntime/core/optimizer/graph_transformer_utils.cc index c4416068e2457..5a441b1d1701e 100644 --- a/onnxruntime/core/optimizer/graph_transformer_utils.cc +++ b/onnxruntime/core/optimizer/graph_transformer_utils.cc @@ -50,6 +50,7 @@ #include "core/optimizer/matmul_integer_to_float.h" #include "core/optimizer/matmul_scale_fusion.h" #include "core/optimizer/matmul_transpose_fusion.h" +#include "core/optimizer/matmul_bn_fusion.h" #include "core/optimizer/nchwc_transformer.h" #include "core/optimizer/noop_elimination.h" #include "core/optimizer/not_where_fusion.h" @@ -127,6 +128,7 @@ InlinedVector> GenerateRewriteRules( rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); + rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); rules.push_back(std::make_unique()); break; diff --git a/onnxruntime/core/optimizer/initializer.cc b/onnxruntime/core/optimizer/initializer.cc index c8da15f65a6d7..9e807ddc7be59 100644 --- a/onnxruntime/core/optimizer/initializer.cc +++ b/onnxruntime/core/optimizer/initializer.cc @@ -291,7 +291,11 @@ Initializer& Initializer::sqrt() { namespace { template struct ScaleByAxis { - void operator()(Tensor& data, const Tensor& scalers, const size_t block_size, const size_t num_blocks) const { + void operator()(Tensor& data, + const Tensor& scalers, + const size_t block_size, + const size_t num_blocks, + const bool column_major) const { ToNumeric to_numeric; const auto scaler_size = scalers.Shape().Size(); T* dst = data.MutableData(); @@ -303,24 +307,32 @@ struct ScaleByAxis { } } else { for (size_t block_offset = 0, i = 0; i < num_blocks; i++) { - const auto numeric_scaler = to_numeric(scalers_data[i]); - for (size_t j = 0; j < block_size; ++j, ++block_offset) { - dst[block_offset] = T(to_numeric(dst[block_offset]) * numeric_scaler); + if (column_major) { + for (size_t j = 0; j < block_size; ++j, ++block_offset) { + const auto numeric_scaler = to_numeric(scalers_data[j]); + dst[block_offset] = T(to_numeric(dst[block_offset]) * numeric_scaler); + } + } else { + const auto numeric_scaler = to_numeric(scalers_data[i]); + for (size_t j = 0; j < block_size; ++j, ++block_offset) { + dst[block_offset] = T(to_numeric(dst[block_offset]) * numeric_scaler); + } } } } } }; - } // namespace -void Initializer::scale_by_axis(const Initializer& scalers, int axis) { +void Initializer::scale_by_axis(const Initializer& scalers, int axis, bool column_major) { ORT_ENFORCE(axis >= 0, "Axis must be non-negative"); const size_t block_size = narrow(data_.Shape().SizeFromDimension(gsl::narrow_cast(axis))); const size_t num_blocks = size() / block_size; - ORT_ENFORCE(scalers.size() == 1 || scalers.size() == num_blocks, "Invalid other(scalers) size"); + ORT_ENFORCE(scalers.size() == 1 || + (column_major ? scalers.size() == block_size : scalers.size() == num_blocks), + "Invalid other(scalers) size"); utils::MLTypeCallDispatcher t_disp(data_.GetElementType()); - t_disp.Invoke(data_, scalers.data_, block_size, num_blocks); + t_disp.Invoke(data_, scalers.data_, block_size, num_blocks, column_major); } #endif // ORT_EXTENDED_MINIMAL_BUILD } // namespace onnxruntime diff --git a/onnxruntime/core/optimizer/initializer.h b/onnxruntime/core/optimizer/initializer.h index dfe054ba1aced..78e3fd6a3d24e 100644 --- a/onnxruntime/core/optimizer/initializer.h +++ b/onnxruntime/core/optimizer/initializer.h @@ -86,7 +86,7 @@ class Initializer final { Initializer& sqrt(); - void scale_by_axis(const Initializer& other, int axis); + void scale_by_axis(const Initializer& other, int axis, bool column_major = false); #endif // ORT_EXTENDED_MINIMAL_BUILD private: std::string name_; diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.cc b/onnxruntime/core/optimizer/matmul_bn_fusion.cc new file mode 100644 index 0000000000000..e944522c9c338 --- /dev/null +++ b/onnxruntime/core/optimizer/matmul_bn_fusion.cc @@ -0,0 +1,230 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/optimizer/matmul_bn_fusion.h" +#include "core/graph/graph_utils.h" +#include "core/optimizer/initializer.h" +#include "core/optimizer/utils.h" + +namespace onnxruntime { + +namespace { +const std::vector>> ignorable_nodes{ + {"Reshape", {1, 5, 13, 14, 19}}, + {"Transpose", {1, 13}}}; +const std::pair> dest = {"BatchNormalization", {1, 6, 7, 9, 14, 15}}; +} // namespace + +bool NodeIsIgnorable(const Graph& graph, const Node& root_node, NodeIndex curr_node_index) { + const Node* curr_node = graph.GetNode(curr_node_index); + + // curr_node has different execution provider then it's parent or + // has output edge != 1 (this condition will handle the case when ignorable node + // is graph output i.e. a graph like this "MatMul->Transpose") + if (curr_node->GetExecutionProviderType() != root_node.GetExecutionProviderType() || + curr_node->GetOutputEdgesCount() != 1) { + return false; + } + + // curr_node can be any of the ignorable_nodes. + for (size_t index = 0; index < ignorable_nodes.size(); index++) { + if (graph_utils::IsSupportedOptypeVersionAndDomain(*curr_node, ignorable_nodes[index].first, ignorable_nodes[index].second)) { + return true; + } + } + + return false; +} + +std::optional MatchPath(const Graph& graph, const Node& root_node, NodeIndex curr_node_index) { + while (NodeIsIgnorable(graph, root_node, curr_node_index)) { + curr_node_index = graph.GetNode(curr_node_index)->OutputNodesBegin()->Index(); + } + + // curr_node is neither ignorable nor dest + const Node* curr_node = graph.GetNode(curr_node_index); + if (curr_node->OpType() != dest.first) { + return std::nullopt; + } + + if (curr_node->GetExecutionProviderType() == root_node.GetExecutionProviderType() && + graph_utils::IsSupportedOptypeVersionAndDomain(*curr_node, dest.first, dest.second)) { + return curr_node_index; + } + + // either curr_node has different execution provider or + // has invalid opset. + return std::nullopt; +} + +/* + * Given a MatMul node, it will verify the following pattern. + * MatMul GEMM + * | | + * Reshape ^ ---> Reshape ^ + * | | + * Transpose ^ Transpose ^ + * | + * BatchNormalization + * Note: ^ means there can be 0 or any occurrences of that node. + * Few example fusable pattern: + * - MatMul -> Reshape -> Transpose -> BatchNormalization ---> GEMM -> Reshape -> Transpose + * - MatMul -> Reshape -> BatchNormalization ---> GEMM -> Reshape + * - MatMul -> Transpose -> BatchNormalization ---> GEMM -> Transpose + * - MatMul -> Reshape -> Reshape -> BatchNormalization ---> GEMM -> Reshape -> Reshape + * - MatMul -> Reshape -> Transpose -> Reshape -> BatchNormalization ---> GEMM -> Reshape -> Transpose -> Reshape + * - MatMul -> BatchNormalization ---> GEMM + * Other Conditions: + * - B tensor of MatMul should be constant. + * - scale, B, mean, var tensors of BatchNormalization should be constant. + * - Every node in the path, except the BatchNormalization, should have only 1 output edge. + */ +bool MatmulBNFusion::SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger&) const { + if (!graph_utils::IsSupportedOptypeVersionAndDomain(node, "MatMul", {1, 9, 13}) || + node.GetOutputEdgesCount() != 1) { + return false; + } + + if (graph.NodeProducesGraphOutput(node)) { + return false; + } + + // because is not producing graph output, it means it will have a child node + NodeIndex child_node_index = node.OutputNodesBegin()->Index(); + std::optional batch_norm_index = MatchPath(graph, node, child_node_index); + if (!batch_norm_index.has_value()) { + return false; + } + + const Node* batch_norm_node = graph.GetNode(*batch_norm_index); + + // Check that the appropriate inputs to the Matmul and BN nodes are constants. + if (!graph_utils::NodeArgIsConstant(graph, *node.InputDefs()[1]) || + !graph_utils::NodeArgIsConstant(graph, *batch_norm_node->InputDefs()[1]) || + !graph_utils::NodeArgIsConstant(graph, *batch_norm_node->InputDefs()[2]) || + !graph_utils::NodeArgIsConstant(graph, *batch_norm_node->InputDefs()[3]) || + !graph_utils::NodeArgIsConstant(graph, *batch_norm_node->InputDefs()[4])) { + return false; + } + + // First output from BN is required. Others are optional. If any optional outputs exist we can't fuse. + const auto& output_defs = batch_norm_node->OutputDefs(); + if (output_defs.size() > 1) { + for (size_t i = 1, end = output_defs.size(); i < end; ++i) { + if (output_defs[i] != nullptr && output_defs[i]->Exists()) { + return false; + } + } + } + + return true; +} + +/* + * BatchNormalization: [https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_batch_normalization_operator_desc] + * Scale * ((Input - Mean) / sqrt(Variance + Epsilon)) + Bias // ignore the FusedActivation in the above definition, that's very specific to DML + * Expanding out the terms: + * Output = (Scale / sqrt(Variance + Epsilon)) * Input + (Scale / sqrt(Variance + Epsilon)) * -Mean + Bias + * Here, + * [Scale/sqrt(Variance + Epsilon)] is constant, and let's call it `alpha` + * [(Scale / sqrt(Variance + Epsilon)) * -Mean + Bias] is also constant, and let's call it `beta` + * Output = alpha * Input + beta, Input = B tensor of MatMul. + * + */ +Status MatmulBNFusion::Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& rule_effect, const logging::Logger&) const { + NodeIndex child_node_index = matmul_node.OutputNodesBegin()->Index(); + NodeIndex batch_norm_node_index = MatchPath(graph, matmul_node, child_node_index).value(); + + Node& batch_norm_node = *graph.GetNode(batch_norm_node_index); // need mutable node, that's why extracting node from graph + + // only perform fusion if epsilon is present and is of float_32 type + auto epsilon_attribute = batch_norm_node.GetAttributes().find("epsilon"); + if (epsilon_attribute == batch_norm_node.GetAttributes().end() || + epsilon_attribute->second.type() != ONNX_NAMESPACE::AttributeProto_AttributeType_FLOAT) { + return Status::OK(); + } + const float epsilon = epsilon_attribute->second.f(); + + const onnx::TensorProto* scale_tensor = graph_utils::GetConstantInitializer(graph, batch_norm_node.InputDefs()[1]->Name()); + ORT_ENFORCE(scale_tensor); + const onnx::TensorProto* bias_tensor = graph_utils::GetConstantInitializer(graph, batch_norm_node.InputDefs()[2]->Name()); + ORT_ENFORCE(bias_tensor); + const onnx::TensorProto* mean_tensor = graph_utils::GetConstantInitializer(graph, batch_norm_node.InputDefs()[3]->Name()); + ORT_ENFORCE(mean_tensor); + const onnx::TensorProto* var_tensor = graph_utils::GetConstantInitializer(graph, batch_norm_node.InputDefs()[4]->Name()); + ORT_ENFORCE(var_tensor); + const onnx::TensorProto* matmul_b_tensor = graph_utils::GetConstantInitializer(graph, matmul_node.InputDefs()[1]->Name()); + ORT_ENFORCE(matmul_b_tensor); + + if (!optimizer_utils::IsFloatingPointDataType(*matmul_b_tensor) || + !optimizer_utils::IsFloatingPointDataType(*scale_tensor) || + !optimizer_utils::IsFloatingPointDataType(*bias_tensor) || + !optimizer_utils::IsFloatingPointDataType(*mean_tensor) || + !optimizer_utils::IsFloatingPointDataType(*var_tensor) || + scale_tensor->dims_size() != 1 || + bias_tensor->dims_size() != 1 || + mean_tensor->dims_size() != 1 || + var_tensor->dims_size() != 1 || + scale_tensor->dims(0) != matmul_b_tensor->dims(1) || + bias_tensor->dims(0) != matmul_b_tensor->dims(1) || + mean_tensor->dims(0) != matmul_b_tensor->dims(1) || + var_tensor->dims(0) != matmul_b_tensor->dims(1)) { + return Status::OK(); + } + + /* + * temp = scale / sqrt(var + epsilon) + * output = (temp * Input) - ((temp * mean) + bias) + */ + Initializer scale(*scale_tensor, graph.ModelPath()); + Initializer bias(*bias_tensor, graph.ModelPath()); + Initializer mean(*mean_tensor, graph.ModelPath()); + Initializer var(*var_tensor, graph.ModelPath()); + Initializer matmul_b(*matmul_b_tensor, graph.ModelPath()); + + var.add(epsilon); + var.sqrt(); + scale.div(var); // this is the temp + matmul_b.scale_by_axis(scale, 1, true); + + mean.mul(scale); + bias.sub(mean); + + // create B tensorProto for new Gemm node from initializer. + ONNX_NAMESPACE::TensorProto new_gemm_b_tensor(*matmul_b_tensor); + matmul_b.ToProto(new_gemm_b_tensor); + const std::string new_gemm_b_name = graph.GenerateNodeArgName("MatMulBnFusion_GemmB_" + matmul_b_tensor->name()); + new_gemm_b_tensor.set_name(new_gemm_b_name); + NodeArg& new_gemm_b_node_arg = graph_utils::AddInitializer(graph, new_gemm_b_tensor); + + // create bias tensorProto for new Gemm node from initializer. + ONNX_NAMESPACE::TensorProto new_gemm_bias_tensor(*bias_tensor); + bias.ToProto(new_gemm_bias_tensor); + const std::string new_gemm_bias_name = graph.GenerateNodeArgName("MatMulBnFusion_GemmBias"); + new_gemm_bias_tensor.set_name(new_gemm_bias_name); + NodeArg& new_gemm_bias_node_arg = graph_utils::AddInitializer(graph, new_gemm_bias_tensor); + + Node& gemm_node = graph.AddNode( + graph.GenerateNodeArgName("MatMulBnFusion_Gemm"), + "Gemm", + "Generated from Matmul BatchNormalization fusion", + {matmul_node.MutableInputDefs()[0], &new_gemm_b_node_arg, &new_gemm_bias_node_arg}, + matmul_node.MutableOutputDefs(), + nullptr, + kOnnxDomain); + + // Remove MatMul node. + Node* node = graph.GetNode(matmul_node.Index()); + graph_utils::RemoveNodeOutputEdges(graph, *node); + graph.RemoveNode(matmul_node.Index()); + + // Delete optional empty output defs. + // Delete BatchNormalization node and update the input of the child of BatchNormalization + batch_norm_node.MutableOutputDefs().resize(1); + NodeIndex batch_norm_parent_index = graph.GetNode(child_node_index)->OpType() == "BatchNormalization" ? gemm_node.Index() : batch_norm_node.InputNodesBegin()->Index(); + graph_utils::FinalizeNodeFusion(graph, *graph.GetNode(batch_norm_parent_index), batch_norm_node); + + rule_effect = RewriteRuleEffect::kRemovedCurrentNode; + return Status::OK(); +} +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/optimizer/matmul_bn_fusion.h b/onnxruntime/core/optimizer/matmul_bn_fusion.h new file mode 100644 index 0000000000000..7a43483cf37d4 --- /dev/null +++ b/onnxruntime/core/optimizer/matmul_bn_fusion.h @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/optimizer/rewrite_rule.h" + +namespace onnxruntime { +/* + * This fusion submerges a BatchNormalization operator to it's super + * precedding MatMul operator, if and only if MatmulBNFusion::SatisfyCondition() + * is true. + */ +class MatmulBNFusion : public RewriteRule { + public: + MatmulBNFusion() : RewriteRule("MatMul_BatchNormalization_Fusion") {} + + std::vector TargetOpTypes() const noexcept override { + return {"MatMul"}; + } + + private: + bool SatisfyCondition(const Graph& graph, const Node& node, const logging::Logger& logger) const override; + + Status Apply(Graph& graph, Node& matmul_node, RewriteRuleEffect& rule_effect, const logging::Logger& logger) const override; +}; +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/optimizer/transformer_memcpy.cc b/onnxruntime/core/optimizer/transformer_memcpy.cc index 07f391f2ae430..ed3e35706b688 100644 --- a/onnxruntime/core/optimizer/transformer_memcpy.cc +++ b/onnxruntime/core/optimizer/transformer_memcpy.cc @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "transformer_memcpy.h" +#include "core/common/logging/logging.h" #include "core/framework/kernel_registry_manager.h" #include "core/framework/execution_providers.h" #include "core/framework/utils.h" @@ -16,12 +17,12 @@ class TransformerMemcpyImpl { TransformerMemcpyImpl(onnxruntime::Graph& graph, const std::string& provider) : graph_(graph), provider_(provider) {} - bool ModifyGraph(const KernelRegistryManager& schema_registries); + bool ModifyGraph(const KernelRegistryManager& schema_registries, const logging::Logger& logger, int& copy_node_counter); private: void ProcessDefs(onnxruntime::Node& node, const KernelRegistryManager& kernel_registries, InitializedTensorSet& initializers_consumed); void BuildDefsMapping(const onnxruntime::NodeArg* arg, const KernelRegistryManager& kernel_registries); - void AddCopyNode(onnxruntime::NodeArg* arg, bool is_input); + void AddCopyNode(onnxruntime::NodeArg* arg, bool is_input, const logging::Logger& logger); bool ProcessInitializers(const KernelRegistryManager& kernel_registries, const InitializedTensorSet& initializers_consumed); private: @@ -61,11 +62,21 @@ static const onnx::TensorProto* GetInitializer(const Graph& graph, const std::st // very simple GraphTransformer that uses TransformerMemcpyImpl for each graph // and mainly provides the subgraph recursion functionality -common::Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const { +common::Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level, + const logging::Logger& logger) const { for (auto& provider : provider_types_) { if (!utils::ProviderIsCpuBased(provider)) { TransformerMemcpyImpl copy_impl(graph, provider); - auto current_modified = copy_impl.ModifyGraph(registry_manager_); + + int copy_node_counter = 0; + auto current_modified = copy_impl.ModifyGraph(registry_manager_, logger, copy_node_counter); + if (copy_node_counter > 0 && provider == kCudaExecutionProvider) { + LOGS(logger, WARNING) << copy_node_counter << " Memcpy nodes are added to the graph " << graph.Name() + << " for " << provider + << ". It might have negative impact on performance (including unable to run CUDA graph). " + << "Set session_options.log_severity_level=1 to see the detail logs before this message."; + } + modified = modified || current_modified; break; } @@ -111,7 +122,9 @@ This transformer does not currently optimize copies between, e.g., two different */ -bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_registries) { +bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_registries, + const logging::Logger& logger, + int& copy_node_counter) { bool modified = false; InitializedTensorSet initializers_consumed; // find defs that require copy @@ -137,19 +150,22 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi // For inputs we need to create a copy node only when the input is connected to both provider // and non-provider nodes. Otherwise utils::CopyInputsAcrossDevices() will do the job. if (provider_input_defs_.count(arg) && non_provider_input_defs_.count(arg)) { - AddCopyNode(const_cast(arg), true); + AddCopyNode(const_cast(arg), true, logger); + copy_node_counter++; modified = true; } for (auto arg : non_provider_output_defs_) if (provider_input_defs_.count(arg)) { - AddCopyNode(arg, true); + AddCopyNode(arg, true, logger); + copy_node_counter++; modified = true; } for (auto arg : provider_output_defs_) if (non_provider_input_defs_.count(arg)) { - AddCopyNode(arg, false); + AddCopyNode(arg, false, logger); + copy_node_counter++; modified = true; } @@ -176,7 +192,8 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi // (the name will be the same as the parent node's implicit input) const auto* node_arg_in_current_graph_level = *provider_input_defs_.find(arg); - AddCopyNode(const_cast(node_arg_in_current_graph_level), true); + AddCopyNode(const_cast(node_arg_in_current_graph_level), true, logger); + copy_node_counter++; modified = true; } } @@ -297,7 +314,7 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, co } } -void TransformerMemcpyImpl::AddCopyNode(onnxruntime::NodeArg* arg, bool is_input) { +void TransformerMemcpyImpl::AddCopyNode(onnxruntime::NodeArg* arg, bool is_input, const logging::Logger& logger) { // create unique name for new def std::string new_def_name = graph_.GenerateNodeArgName(arg->Name() + "_" + provider_); @@ -309,6 +326,9 @@ void TransformerMemcpyImpl::AddCopyNode(onnxruntime::NodeArg* arg, bool is_input std::string new_node_name = graph_.GenerateNodeName("Memcpy"); const auto op_name = is_input ? "MemcpyFromHost" : "MemcpyToHost"; + LOGS(logger, INFO) << "Add " << op_name << (is_input ? " after " : " before ") << arg->Name() + << " for " << provider_; + auto& new_node = graph_.AddNode(new_node_name, op_name, "Copy from/to host memory", std::vector{src_arg}, std::vector{dst_arg}); diff --git a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc index 3d03abf5b7ebc..2ca3b1cdf817e 100644 --- a/onnxruntime/core/providers/cpu/cpu_execution_provider.cc +++ b/onnxruntime/core/providers/cpu/cpu_execution_provider.cc @@ -365,7 +365,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, Slice); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 11, Dropout); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 10, NonMaxSuppression); -class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, IsInf); +class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 19, IsInf); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 15, float, RoiAlign); class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, 15, double, RoiAlign); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 10, ReverseSequence); @@ -682,9 +682,9 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, Ga class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 15, ScatterND); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 15, ScatterElements); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 13, Identity); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, IsNaN); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, double, IsNaN); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, MLFloat16, IsNaN); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, float, IsNaN); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, double, IsNaN); +class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, 19, MLFloat16, IsNaN); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, bool, NonZero); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, float, NonZero); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 13, int32_t, NonZero); @@ -960,6 +960,18 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 19, Sh // Opset 20 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, ConstantOfShape); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, AffineGrid); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, AffineGrid); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, float, IsNaN); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, double, IsNaN); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, MLFloat16, IsNaN); +#if !defined(DISABLE_FLOAT8_TYPES) +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FN, IsNaN); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E4M3FNUZ, IsNaN); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E5M2, IsNaN); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, Float8E5M2FNUZ, IsNaN); +#endif +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kOnnxDomain, 20, IsInf); // !!PLEASE READ BELOW!! Following that, add new entries above this comment @@ -1492,7 +1504,7 @@ Status RegisterOnnxOperatorKernels(KernelRegistry& kernel_registry) { Dropout)>, BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, +#if !defined(DISABLE_FLOAT8_TYPES) + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, +#endif + BuildKernelCreateInfo, }; for (auto& function_table_entry : function_table) { diff --git a/onnxruntime/core/providers/cpu/tensor/affine_grid.cc b/onnxruntime/core/providers/cpu/tensor/affine_grid.cc new file mode 100644 index 0000000000000..15900ba553983 --- /dev/null +++ b/onnxruntime/core/providers/cpu/tensor/affine_grid.cc @@ -0,0 +1,151 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cpu/tensor/affine_grid.h" + +#include "core/common/common.h" +#include "core/providers/op_kernel_type_control.h" +#include "core/util/math_cpuonly.h" +#include +#include "Eigen/src/Core/Map.h" +#include +#include "core/common/eigen_common_wrapper.h" + +namespace onnxruntime { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_CPU_OPERATOR_TYPED_KERNEL( \ + AffineGrid, \ + 20, \ + T, \ + KernelDefBuilder() \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ + AffineGrid); + +REGISTER_KERNEL_TYPED(float) +REGISTER_KERNEL_TYPED(double) + +template +void generate_base_grid_2d(int64_t H, int64_t W, bool align_corners, Eigen::Matrix& base_grid) { + Eigen::VectorXf row_vec = Eigen::VectorXf::LinSpaced(static_cast(W), -1, 1); + if (!align_corners) { + row_vec = row_vec * (W - 1) / W; + } + Eigen::VectorXf col_vec = Eigen::VectorXf::LinSpaced(static_cast(H), -1, 1); + if (!align_corners) { + col_vec = col_vec * (H - 1) / H; + } + + base_grid.resize(static_cast(H * W), 2); + for (Eigen::Index j = 0; j < H; j++) { + for (Eigen::Index i = 0; i < W; i++) { + base_grid.row(j * static_cast(W) + i) << row_vec(i), col_vec(j); + } + } +} + +template +void generate_base_grid_3d(int64_t D, int64_t H, int64_t W, bool align_corners, Eigen::Matrix& base_grid) { + Eigen::VectorXf row_vec = Eigen::VectorXf::LinSpaced(static_cast(W), -1, 1); + if (!align_corners) { + row_vec = row_vec * (W - 1) / W; + } + Eigen::VectorXf col_vec = Eigen::VectorXf::LinSpaced(static_cast(H), -1, 1); + if (!align_corners) { + col_vec = col_vec * (H - 1) / H; + } + Eigen::VectorXf slice_vec = Eigen::VectorXf::LinSpaced(static_cast(D), -1, 1); + if (!align_corners) { + slice_vec = slice_vec * (D - 1) / D; + } + + base_grid.resize(static_cast(D * H * W), 3); + for (Eigen::Index k = 0; k < D; k++) { + for (Eigen::Index j = 0; j < H; j++) { + for (Eigen::Index i = 0; i < W; i++) { + base_grid.row(k * static_cast(H * W) + j * static_cast(W) + i) << row_vec(i), col_vec(j), slice_vec(k); + } + } + } +} + +template +void affine_grid_generator_2d(const Tensor* theta, const Eigen::Matrix& base_grid_transposed, int64_t batch_num, int64_t H, int64_t W, Tensor* grid) { + const Eigen::StorageOptions option = Eigen::RowMajor; + auto theta_batch_offset = batch_num * 2 * 3; + const T* theta_data = theta->Data() + theta_batch_offset; + const Eigen::Matrix theta_R{{theta_data[0], theta_data[1]}, {theta_data[3], theta_data[4]}}; + const Eigen::Array theta_T(theta_data[2], theta_data[5]); + + auto grid_batch_offset = batch_num * H * W * 2; + T* grid_data = grid->MutableData() + grid_batch_offset; + Eigen::Map> grid_matrix(grid_data, narrow(H * W), 2); + grid_matrix = ((theta_R * base_grid_transposed).array().colwise() + theta_T).matrix().transpose(); +} + +template +void affine_grid_generator_3d(const Tensor* theta, const Eigen::Matrix& base_grid_transposed, int64_t batch_num, int64_t D, int64_t H, int64_t W, Tensor* grid) { + const Eigen::StorageOptions option = Eigen::RowMajor; + auto theta_batch_offset = batch_num * 3 * 4; + const T* theta_data = theta->Data() + theta_batch_offset; + const Eigen::Matrix theta_R{ + {theta_data[0], theta_data[1], theta_data[2]}, + {theta_data[4], theta_data[5], theta_data[6]}, + {theta_data[8], theta_data[9], theta_data[10]}}; + const Eigen::Array theta_T(theta_data[3], theta_data[7], theta_data[11]); + + auto grid_batch_offset = batch_num * D * H * W * 3; + T* grid_data = grid->MutableData() + grid_batch_offset; + Eigen::Map> grid_matrix(grid_data, narrow(D * H * W), 3); + grid_matrix = ((theta_R * base_grid_transposed).array().colwise() + theta_T).matrix().transpose(); +} + +template +Status AffineGrid::Compute(OpKernelContext* context) const { + const Tensor* theta = context->Input(0); + const TensorShape& theta_shape = theta->Shape(); + if (theta_shape.NumDimensions() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "AffineGrid : Input theta tensor dimension is not 3"); + } + + const Tensor* size = context->Input(1); + const TensorShape& size_shape = size->Shape(); + const int64_t* size_data = size->Data(); + + if (size_shape.GetDims()[0] == 4 /*&& get_check_2d_grid_sample_consistency(theta_shape, size_shape, N, C, H, W)*/) { + int64_t N = size_data[0], H = size_data[2], W = size_data[3]; + + TensorShape grid_shape{N, H, W, 2}; + auto grid = context->Output(0, grid_shape); + + Eigen::Matrix base_grid; + generate_base_grid_2d(H, W, align_corners_, base_grid); + Eigen::Matrix base_grid_transposed = base_grid.transpose(); + + std::function fn = [theta, base_grid_transposed, H, W, grid](ptrdiff_t batch_num) { + affine_grid_generator_2d(theta, base_grid_transposed, batch_num, H, W, grid); + }; + + concurrency::ThreadPool::TryBatchParallelFor(context->GetOperatorThreadPool(), narrow(N), std::move(fn), 0); + } else if (size_shape.GetDims()[0] == 5 /*&& get_check_2d_grid_sample_consistency(theta_shape, size_shape, N, C, H, W)*/) { + int64_t N = size_data[0], D = size_data[2], H = size_data[3], W = size_data[4]; + + TensorShape grid_shape{N, D, H, W, 3}; + auto grid = context->Output(0, grid_shape); + + Eigen::Matrix base_grid; + generate_base_grid_3d(D, H, W, align_corners_, base_grid); + Eigen::Matrix base_grid_transposed = base_grid.transpose(); + + std::function fn = [theta, base_grid_transposed, D, H, W, grid](ptrdiff_t batch_num) { + affine_grid_generator_3d(theta, base_grid_transposed, batch_num, D, H, W, grid); + }; + + concurrency::ThreadPool::TryBatchParallelFor(context->GetOperatorThreadPool(), narrow(N), std::move(fn), 0); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "AffineGrid : Invalidate size - length of size should be 4 or 5."); + } + return Status::OK(); +} +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/tensor/affine_grid.h b/onnxruntime/core/providers/cpu/tensor/affine_grid.h new file mode 100644 index 0000000000000..5ffe660e986f2 --- /dev/null +++ b/onnxruntime/core/providers/cpu/tensor/affine_grid.h @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" + +namespace onnxruntime { + +template +class AffineGrid final : public OpKernel { + public: + AffineGrid(const OpKernelInfo& info) : OpKernel(info) { + int64_t align_corners = info.GetAttrOrDefault("align_corners", 0); + align_corners_ = (align_corners != 0); + } + + Status Compute(OpKernelContext* context) const override; + + private: + bool align_corners_; +}; + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/tensor/isinf.cc b/onnxruntime/core/providers/cpu/tensor/isinf.cc index bc99caa8036cf..1b449f46927a2 100644 --- a/onnxruntime/core/providers/cpu/tensor/isinf.cc +++ b/onnxruntime/core/providers/cpu/tensor/isinf.cc @@ -14,15 +14,38 @@ namespace onnxruntime { // https://github.com/onnx/onnx/blob/main/docs/Operators.md#IsInf namespace op_kernel_type_control { -ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPES_ALL_OPSETS( - kCpuExecutionProvider, kOnnxDomain, IsInf, Input, 0, - float, double); +using IsInfTypesOpset10 = TypeList; + +ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST( + kCpuExecutionProvider, kOnnxDomain, IsInf, 10, Input, 0, + IsInfTypesOpset10); + +using IsInfTypesOpset20 = + TypeList< + float, + double +#if !defined(DISABLE_FLOAT8_TYPES) + , + Float8E4M3FN, Float8E4M3FNUZ, Float8E5M2, Float8E5M2FNUZ +#endif + >; + +ORT_SPECIFY_OP_KERNEL_ARG_DEFAULT_TYPE_LIST( + kCpuExecutionProvider, + kOnnxDomain, + IsInf, + 20, + Input, + 0, + IsInfTypesOpset20); } // namespace op_kernel_type_control class IsInf final : public OpKernel { public: - using EnabledDataTypes = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST_ALL_OPSETS(kCpuExecutionProvider, kOnnxDomain, - IsInf, Input, 0); + using EnabledDataTypes10 = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, + IsInf, 10, Input, 0); + using EnabledDataTypes20 = ORT_OP_KERNEL_ARG_ENABLED_TYPE_LIST(kCpuExecutionProvider, kOnnxDomain, + IsInf, 20, Input, 0); explicit IsInf(const OpKernelInfo& info); Status Compute(OpKernelContext* context) const override; @@ -30,14 +53,25 @@ class IsInf final : public OpKernel { private: int64_t detect_positive_{1}; int64_t detect_negative_{1}; + int opset_; }; -ONNX_CPU_OPERATOR_KERNEL( +ONNX_CPU_OPERATOR_VERSIONED_KERNEL( IsInf, 10, + 19, KernelDefBuilder() .TypeConstraint("T1", - BuildKernelDefConstraintsFromTypeList()) + BuildKernelDefConstraintsFromTypeList()) + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), + IsInf); + +ONNX_CPU_OPERATOR_KERNEL( + IsInf, + 20, + KernelDefBuilder() + .TypeConstraint("T1", + BuildKernelDefConstraintsFromTypeList()) .TypeConstraint("T2", DataTypeImpl::GetTensorType()), IsInf); @@ -46,6 +80,7 @@ IsInf::IsInf(const OpKernelInfo& info) : OpKernel(info) { ORT_ENFORCE(status.IsOK(), "Failed to obtain detect_positive"); status = info.GetAttr("detect_negative", &detect_negative_); ORT_ENFORCE(status.IsOK(), "Failed to obtain detect_negative"); + opset_ = info.node().SinceVersion(); } namespace isinf_internal { @@ -78,6 +113,49 @@ struct ComputeDispatchTarget { } } }; + +#if !defined(DISABLE_FLOAT8_TYPES) +template <> +struct ComputeDispatchTarget { + void operator()(const Tensor&, Tensor& Y, bool, bool) const { + EigenMap(Y).array() = false; + } +}; + +template <> +struct ComputeDispatchTarget { + void operator()(const Tensor&, Tensor& Y, bool, bool) const { + EigenMap(Y).array() = false; + } +}; + +template <> +struct ComputeDispatchTarget { + void operator()(const Tensor& X, Tensor& Y, bool detect_positive, bool detect_negative) const { + auto& dims = X.Shape(); + auto input = ConstEigenVectorMap(static_cast(static_cast(X.Data())), onnxruntime::narrow(dims.Size())); + auto output = EigenMap(Y); + + // S.11111.00 + if (detect_positive && detect_negative) { + output.array() = input.array() == 0b01111100 || input.array() == 0b11111100; + } else if (detect_positive) { + output.array() = input.array() == 0b01111100; + } else if (detect_negative) { + output.array() = input.array() == 0b11111100; + } else { + output.array() = false; + } + } +}; + +template <> +struct ComputeDispatchTarget { + void operator()(const Tensor&, Tensor& Y, bool, bool) const { + EigenMap(Y).array() = false; + } +}; +#endif } // namespace isinf_internal Status IsInf::Compute(OpKernelContext* context) const { @@ -88,8 +166,13 @@ Status IsInf::Compute(OpKernelContext* context) const { using namespace isinf_internal; - utils::MLTypeCallDispatcherFromTypeList dispatcher{X.GetElementType()}; - dispatcher.Invoke(X, Y, detect_positive_ != 0, detect_negative_ != 0); + if (opset_ < 20) { + utils::MLTypeCallDispatcherFromTypeList dispatcher{X.GetElementType()}; + dispatcher.Invoke(X, Y, detect_positive_ != 0, detect_negative_ != 0); + } else { + utils::MLTypeCallDispatcherFromTypeList dispatcher{X.GetElementType()}; + dispatcher.Invoke(X, Y, detect_positive_ != 0, detect_negative_ != 0); + } return Status::OK(); } diff --git a/onnxruntime/core/providers/cpu/tensor/isnan.cc b/onnxruntime/core/providers/cpu/tensor/isnan.cc index 33d0f8eb6c1ae..34495e382278a 100644 --- a/onnxruntime/core/providers/cpu/tensor/isnan.cc +++ b/onnxruntime/core/providers/cpu/tensor/isnan.cc @@ -20,10 +20,20 @@ namespace onnxruntime { .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ IsNaN); +#define ADD_TYPED_ISNAN_OP_13(data_type) \ + ONNX_CPU_OPERATOR_VERSIONED_TYPED_KERNEL( \ + IsNaN, \ + 13, 19, \ + data_type, \ + KernelDefBuilder() \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ + IsNaN); + #define ADD_TYPED_ISNAN_OP(data_type) \ ONNX_CPU_OPERATOR_TYPED_KERNEL( \ IsNaN, \ - 13, \ + 20, \ data_type, \ KernelDefBuilder() \ .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ @@ -33,10 +43,20 @@ namespace onnxruntime { ADD_TYPED_ISNAN_OP_9(float); ADD_TYPED_ISNAN_OP_9(double); ADD_TYPED_ISNAN_OP_9(MLFloat16); +ADD_TYPED_ISNAN_OP_13(float); +ADD_TYPED_ISNAN_OP_13(double); +ADD_TYPED_ISNAN_OP_13(MLFloat16); ADD_TYPED_ISNAN_OP(float); ADD_TYPED_ISNAN_OP(double); ADD_TYPED_ISNAN_OP(MLFloat16); +#if !defined(DISABLE_FLOAT8_TYPES) +ADD_TYPED_ISNAN_OP(Float8E4M3FN); +ADD_TYPED_ISNAN_OP(Float8E4M3FNUZ); +ADD_TYPED_ISNAN_OP(Float8E5M2); +ADD_TYPED_ISNAN_OP(Float8E5M2FNUZ); +#endif + template Status IsNaN::Compute(OpKernelContext* context) const { const auto* X_ptr = context->Input(0); @@ -70,4 +90,63 @@ Status IsNaN::Compute(OpKernelContext* context) const { return Status::OK(); } + +#if !defined(DISABLE_FLOAT8_TYPES) +template <> +Status IsNaN::Compute(OpKernelContext* context) const { + const auto* X = context->Input(0); + auto& dims = X->Shape(); + auto& Y = *context->Output(0, dims); + + auto input = ConstEigenVectorMap(static_cast(static_cast(X->Data())), onnxruntime::narrow(dims.Size())); + auto output = EigenMap(Y); + + // S.1111.111 + std::transform(input.begin(), input.end(), output.begin(), [](uint8_t c) { return (c & 0x7f) == 0x7f; }); + return Status::OK(); +} + +template <> +Status IsNaN::Compute(OpKernelContext* context) const { + const auto* X = context->Input(0); + auto X_data = X->Data(); + auto& dims = X->Shape(); + auto shape_size = dims.Size(); + auto& Y = *context->Output(0, dims); + + // 1.0000.000 + EigenMap(Y) = + ConstEigenVectorMap(static_cast(static_cast(X_data)), onnxruntime::narrow(shape_size)).array() == 0x80; + + return Status::OK(); +} + +template <> +Status IsNaN::Compute(OpKernelContext* context) const { + const auto* X = context->Input(0); + auto& dims = X->Shape(); + auto& Y = *context->Output(0, dims); + + auto input = ConstEigenVectorMap(static_cast(static_cast(X->Data())), onnxruntime::narrow(dims.Size())); + auto output = EigenMap(Y); + + // S.11111.{01, 10, 11} + std::transform(input.begin(), input.end(), output.begin(), [](uint8_t c) { return ((c & 0x7c) == 0x7c) && ((c & 0x03) != 0x00); }); + return Status::OK(); +} + +template <> +Status IsNaN::Compute(OpKernelContext* context) const { + const auto* X = context->Input(0); + auto X_data = X->Data(); + auto& dims = X->Shape(); + auto shape_size = dims.Size(); + auto& Y = *context->Output(0, dims); + + // 1.0000.000 + EigenMap(Y) = ConstEigenVectorMap(static_cast(static_cast(X_data)), onnxruntime::narrow(shape_size)).array() == 0x80; + + return Status::OK(); +} +#endif } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc index e855a515f445a..5f1dbd30f6a3e 100644 --- a/onnxruntime/core/providers/cuda/cuda_stream_handle.cc +++ b/onnxruntime/core/providers/cuda/cuda_stream_handle.cc @@ -7,6 +7,25 @@ namespace onnxruntime { +DeferredCpuAllocator::DeferredCpuAllocator(CudaStream& cuda_stream) : cuda_stream_(cuda_stream) { + OrtAllocator::version = ORT_API_VERSION; + OrtAllocator::Alloc = + [](OrtAllocator* this_, size_t size) { + auto self = reinterpret_cast(this_); + return self->cuda_stream_.GetCpuAllocator()->Alloc(size); + }; + OrtAllocator::Free = + [](OrtAllocator* this_, void* p) { + auto self = reinterpret_cast(this_); + self->cuda_stream_.EnqueDeferredCPUBuffer(p); + }; + OrtAllocator::Info = + [](const OrtAllocator* this_) { + auto self = reinterpret_cast(this_); + return &self->cuda_stream_.GetCpuAllocator()->Info(); + }; +} + struct CudaNotification : public synchronize::Notification { CudaNotification(Stream& s) : Notification(s) { CUDA_CALL_THROW(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming)); @@ -46,7 +65,8 @@ CudaStream::CudaStream(cudaStream_t stream, cublasHandle_t external_cublas_handle) : Stream(stream, device), own_stream_(own_flag), cpu_allocator_(cpu_allocator), - release_cpu_buffer_on_cuda_stream_(release_cpu_buffer_on_cuda_stream) { + release_cpu_buffer_on_cuda_stream_(release_cpu_buffer_on_cuda_stream), + deferred_cpu_allocator_(*this) { if (own_flag) { CUBLAS_CALL_THROW(cublasCreate(&cublas_handle_)); CUBLAS_CALL_THROW(cublasSetStream(cublas_handle_, stream)); @@ -162,6 +182,9 @@ void* CudaStream::GetResource(int version, int id) const { case CudaResource::cublas_handle_t: return reinterpret_cast(cublas_handle_); break; + case CudaResource::deferred_cpu_allocator_t: + return const_cast(&deferred_cpu_allocator_); + break; default: break; } diff --git a/onnxruntime/core/providers/cuda/cuda_stream_handle.h b/onnxruntime/core/providers/cuda/cuda_stream_handle.h index 9c62b029b7a36..917702fae08f1 100644 --- a/onnxruntime/core/providers/cuda/cuda_stream_handle.h +++ b/onnxruntime/core/providers/cuda/cuda_stream_handle.h @@ -9,6 +9,13 @@ namespace onnxruntime { +struct CudaStream; + +struct DeferredCpuAllocator : public OrtAllocator { + DeferredCpuAllocator(CudaStream&); + CudaStream& cuda_stream_; +}; + struct CudaStream : Stream { CudaStream(cudaStream_t stream, const OrtDevice& device, @@ -36,10 +43,13 @@ struct CudaStream : Stream { void* GetResource(int version, int id) const override; + onnxruntime::IAllocator* GetCpuAllocator() const { return cpu_allocator_.get(); } + private: std::vector deferred_cpu_buffers_; AllocatorPtr cpu_allocator_; bool release_cpu_buffer_on_cuda_stream_{true}; + DeferredCpuAllocator deferred_cpu_allocator_; }; void RegisterCudaStreamHandles(IStreamCommandHandleRegistry& stream_handle_registry, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h index 52018500b134c..cdb0338157561 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h @@ -3,6 +3,9 @@ #pragma once interface IMLOperatorRegistry; +interface IDMLDevice; +interface ID3D12CommandQueue; +interface ID3D12Resource; #include "core/common/status.h" #include "core/framework/data_transfer.h" @@ -28,7 +31,8 @@ namespace Dml std::unique_ptr CreateExecutionProvider( IDMLDevice* dmlDevice, ID3D12CommandQueue* commandQueue, - bool enableMetacommands = true); + bool enableMetacommands, + bool enableDynamicGraphFusion); ID3D12Resource* GetD3D12ResourceFromAllocation(onnxruntime::IAllocator* allocator, void* ptr); void FlushContext(onnxruntime::IExecutionProvider* provider); diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h index 04381b6ce355c..074f13b309181 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h @@ -7,11 +7,14 @@ #include #include #include +#include #include "core/framework/op_kernel.h" +#include "core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h" struct AbstractOperatorDesc; interface IMLOperatorTensor; +interface IDMLOperator; struct DML_INPUT_GRAPH_EDGE_DESC; struct DML_OUTPUT_GRAPH_EDGE_DESC; struct DML_INTERMEDIATE_GRAPH_EDGE_DESC; @@ -92,6 +95,8 @@ namespace Windows::AI::MachineLearning::Adapter const onnxruntime::Node& node, MLOperatorTensorGetter& constantInputGetter, const void* executionHandle, + const EdgeShapes* inputShapesOverrides, + /*out*/ EdgeShapes* outputShapes, /*out*/ DmlGraphNodeCreateInfo* graphNodeCreateInfo )>; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp index ede3e7f2c2257..eb068087de4ad 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/AbiCustomRegistry.cpp @@ -491,6 +491,8 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( const onnxruntime::Node& node, MLOperatorTensorGetter& constantInputGetter, const void* executionHandle, + const EdgeShapes* inputShapesOverrides, + /*out*/ EdgeShapes* outputShapes, /*out*/ DmlGraphNodeCreateInfo* graphNodeCreateInfo ) { @@ -498,15 +500,15 @@ HRESULT STDMETHODCALLTYPE AbiCustomRegistry::RegisterOperatorKernel( onnxruntime::OpNodeProtoHelper protoHelper(&nodeContext); // Use the same list of required constant inputs for the shape inferrer and the kernel. - EdgeShapes outputShapes; - InferAndVerifyOutputSizes(node, &defaultAttributesCapture, shapeInferrerCapture.Get(), constantCpuInputCapture, constantInputGetter, nullptr, outputShapes); + InferAndVerifyOutputSizes(node, &defaultAttributesCapture, shapeInferrerCapture.Get(), constantCpuInputCapture, constantInputGetter, inputShapesOverrides, *outputShapes); // Create the kernel while allowing input shape and output shape queries according to options ComPtr kernelInfoWrapper = wil::MakeOrThrow( &protoHelper, executionHandle, true, - &outputShapes, + inputShapesOverrides, + outputShapes, &defaultAttributesCapture, graphNodeCreateInfo, constantCpuInputCapture, diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h new file mode 100644 index 0000000000000..5ff70493252bd --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace Windows::AI::MachineLearning::Adapter +{ + // edges and unused edges have an empty array of dimensions. + class EdgeShapes + { + public: + EdgeShapes() = default; + + EdgeShapes(size_t count) : m_shapes(count) {} + + const std::vector& GetShape(size_t edgeIndex) const + { + return m_shapes[edgeIndex]; + } + + std::vector& GetMutableShape(size_t edgeIndex) + { + return m_shapes[edgeIndex]; + } + + size_t EdgeCount() const { return m_shapes.size(); } + + void Reset(size_t edge_count) + { + m_shapes.clear(); + m_shapes.resize(edge_count); + } + + bool operator!=(const EdgeShapes& other) const noexcept + { + return (m_shapes != other.m_shapes); + } + + private: + std::vector> m_shapes; + }; +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp index 51b93efb3a646..cd74e7fa92940 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.cpp @@ -1,7 +1,7 @@ #pragma once #include "DmlGraphFusionHelper.h" - +#include "DmlRuntimeFusedGraphKernel.h" namespace Dml { @@ -501,5 +501,171 @@ namespace DmlGraphFusionHelper graph.FinalizeFuseSubGraph(indexedSubGraph, fusedNode); } + + void RegisterDynamicKernel( + onnxruntime::Graph& graph, + onnxruntime::KernelRegistry* registryForPartitionKernels, + const ExecutionProviderImpl* providerImpl, + std::unordered_map graphNodePropertyMap, + const std::unordered_set& dynamicCpuInputMap, + std::shared_ptr indexedSubGraph, + std::unordered_map>&& isInitializerTransferable) + { + struct NodeInfo + { + std::string name; + std::string opType; + std::string description; + std::string domain; + onnxruntime::NodeAttributes attributes; + std::vector inputDefPointers; + std::vector outputDefPointers; + }; + + auto partitionNodePropsMap = DmlGraphFusionHelper::CreatePartitionNodePropsMap( + graph, + *indexedSubGraph, + std::move(graphNodePropertyMap)); + + auto modelPath = graph.ModelPath(); + + const gsl::span subGraphInputArgNames = indexedSubGraph->GetMetaDef()->inputs; + const gsl::span subGraphOutputArgNames = indexedSubGraph->GetMetaDef()->outputs; + + std::vector nodesInfo; + nodesInfo.reserve(indexedSubGraph->nodes.size()); + + std::vector subgraphInputs; + subgraphInputs.reserve(subGraphInputArgNames.size()); + + std::vector subgraphOutputs; + subgraphOutputs.reserve(subGraphOutputArgNames.size()); + + std::vector nodeAttributes; + nodeAttributes.reserve(indexedSubGraph->nodes.size()); + + std::vector> intermediateNodeArgs; + + for (size_t sortedNodeIndex : indexedSubGraph->nodes) + { + auto node = graph.GetNode(sortedNodeIndex); + + nodeAttributes.push_back(node->GetAttributes()); + + NodeInfo nodeInfo{}; + nodeInfo.name = node->Name(); + nodeInfo.opType = node->OpType(); + nodeInfo.description = node->Description(); + nodeInfo.domain = node->Domain(); + nodeInfo.attributes = node->GetAttributes(); + nodeInfo.inputDefPointers.reserve(node->InputDefs().size()); + nodeInfo.outputDefPointers.reserve(node->OutputDefs().size()); + + for (const onnxruntime::NodeArg* inputDef : node->InputDefs()) + { + intermediateNodeArgs.emplace_back(std::make_shared(inputDef->Name(), inputDef->TypeAsProto())); + nodeInfo.inputDefPointers.push_back(intermediateNodeArgs.back().get()); + } + + for (const onnxruntime::NodeArg* outputDef : node->OutputDefs()) + { + intermediateNodeArgs.emplace_back(std::make_shared(outputDef->Name(), outputDef->TypeAsProto())); + nodeInfo.outputDefPointers.push_back(intermediateNodeArgs.back().get()); + } + + nodesInfo.push_back(std::move(nodeInfo)); + } + + for (const std::string& graphInputName : subGraphInputArgNames) + { + subgraphInputs.push_back(graph.GetNodeArg(graphInputName)); + } + + for (const std::string& graphOutputName : subGraphOutputArgNames) + { + subgraphOutputs.push_back(graph.GetNodeArg(graphOutputName)); + } + + // We need to keep the initializers alive since they will be freed once the nodes are removed from the graph + std::vector ownedInitializers; + ownedInitializers.reserve(isInitializerTransferable.size()); + + for (auto& kvp : isInitializerTransferable) + { + ONNX_NAMESPACE::TensorProto tensorProto; + tensorProto.set_data_type(kvp.second.first->data_type()); + tensorProto.set_raw_data(kvp.second.first->raw_data()); + tensorProto.set_name(kvp.second.first->name()); + + for (int i = 0; i < kvp.second.first->dims_size(); ++i) + { + tensorProto.add_dims(kvp.second.first->dims(i)); + } + ownedInitializers.push_back(std::move(tensorProto)); + kvp.second.first = &ownedInitializers.back(); + } + + // lamda captures for the kernel registration + auto fused_kernel_func = [ + indexedSubGraph, + &modelPath, + nodesInfo = std::move(nodesInfo), + intermediateNodeArgs = std::move(intermediateNodeArgs), + subgraphInputs = std::move(subgraphInputs), + subgraphOutputs = std::move(subgraphOutputs), + partitionNodePropsMap = std::move(partitionNodePropsMap), + ownedInitializers = std::move(ownedInitializers)] (onnxruntime::FuncManager& func_mgr, const onnxruntime::OpKernelInfo& info, std::unique_ptr& out) mutable ->onnxruntime::Status + { + std::vector> subgraphNodes; + subgraphNodes.reserve(nodesInfo.size()); + + for (const NodeInfo& nodeInfo : nodesInfo) + { + subgraphNodes.emplace_back(std::make_shared( + nodeInfo.name, + nodeInfo.opType, + nodeInfo.description, + nodeInfo.inputDefPointers, + nodeInfo.outputDefPointers, + &nodeInfo.attributes, + nodeInfo.domain)); + } + + out.reset(CreateRuntimeFusedGraphKernel( + info, + indexedSubGraph, + modelPath, + std::move(subgraphNodes), + std::move(subgraphInputs), + std::move(subgraphOutputs), + std::move(intermediateNodeArgs), + std::move(partitionNodePropsMap), + std::move(ownedInitializers))); + return Status::OK(); + }; + + // build the kernel definition on the fly, and register it to the fused_kernel_regisitry. + onnxruntime::KernelDefBuilder builder; + builder.SetName(indexedSubGraph->GetMetaDef()->name) + .SetDomain(indexedSubGraph->GetMetaDef()->domain) + .SinceVersion(indexedSubGraph->GetMetaDef()->since_version) + .Provider(onnxruntime::kDmlExecutionProvider); + + // Force the CPU inputs to be allocated on the CPU + for (int i = 0; i < subGraphInputArgNames.size(); ++i) + { + if (dynamicCpuInputMap.find(subGraphInputArgNames[i]) != dynamicCpuInputMap.end()) + { + builder.InputMemoryType(OrtMemTypeCPUInput, i); + } + } + + ORT_THROW_IF_ERROR(registryForPartitionKernels->Register(builder, fused_kernel_func)); + + auto& fusedNode = graph.BeginFuseSubGraph(*indexedSubGraph, indexedSubGraph->GetMetaDef()->name); + fusedNode.SetExecutionProviderType(onnxruntime::kDmlExecutionProvider); + + graph.FinalizeFuseSubGraph(*indexedSubGraph, fusedNode); + } } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h index 030cffc2a8794..f8f6162aaa1e0 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h @@ -80,5 +80,14 @@ namespace DmlGraphFusionHelper std::vector&& isInputsUploadedByDmlEP, const GraphDescBuilder::GraphDesc& graphDesc, Microsoft::WRL::ComPtr compiledExecutionPlanOperator); + + void RegisterDynamicKernel( + onnxruntime::Graph& graph, + onnxruntime::KernelRegistry* registryForPartitionKernels, + const ExecutionProviderImpl* providerImpl, + std::unordered_map graphNodePropertyMap, + const std::unordered_set& dynamicCpuInputMap, + std::shared_ptr indexedSubGraph, + std::unordered_map>&& isInitializerTransferable); } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp index 4813707cdf50c..679738b639ec9 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.cpp @@ -15,6 +15,18 @@ namespace Dml { + namespace + { + struct CompiledPartitionInfo + { + Microsoft::WRL::ComPtr compiledOperator; + onnxruntime::IndexedSubGraph indexedSubGraph; + std::vector isInputsUploadedByDmlEP; + GraphDescBuilder::GraphDesc graphDesc; + std::unordered_map> isInitializerTransferable; + }; + } + DmlGraphFusionTransformer::DmlGraphFusionTransformer( const std::string& name, const onnxruntime::IExecutionProvider* provider @@ -24,15 +36,6 @@ namespace Dml { } - struct CompiledPartitionInfo - { - Microsoft::WRL::ComPtr compiledOperator; - onnxruntime::IndexedSubGraph indexedSubGraph; - std::vector isInputsUploadedByDmlEP; - GraphDescBuilder::GraphDesc graphDesc; - std::unordered_map> isInitializerTransferable; - }; - onnxruntime::common::Status DmlGraphFusionTransformer::ApplyImpl( onnxruntime::Graph& graph, bool& modified, @@ -87,6 +90,7 @@ namespace Dml { // Initializers needed by any graph partition std::unordered_set requiredInitializerMap; + std::unordered_set dynamicCpuInputMap; std::unordered_map graphNodePropertyMap; onnxruntime::GraphViewer graphViewer(graph); std::vector> partitions = BuildPartitions( @@ -96,8 +100,10 @@ namespace Dml m_providerImpl->GetSupportedDeviceDataTypeMask(), graphNodePropertyMap, requiredInitializerMap, + dynamicCpuInputMap, additionalSplittingNodes, - implicitInputDefs); + implicitInputDefs, + false); // Reset the splitting nodes for the current iteration additionalSplittingNodes.clear(); @@ -190,17 +196,48 @@ namespace Dml std::move(graphNodePropertyMap)); // Convert partitionONNXGraph into DML EP GraphDesc + auto modelPath = graph.ModelPath(); + + const gsl::span subGraphInputArgNames = indexedSubGraph.GetMetaDef()->inputs; + const gsl::span subGraphOutputArgNames = indexedSubGraph.GetMetaDef()->outputs; + + std::vector subgraphNodes; + subgraphNodes.reserve(indexedSubGraph.nodes.size()); + + std::vector subgraphInputs; + subgraphInputs.reserve(subGraphInputArgNames.size()); + + std::vector subgraphOutputs; + subgraphOutputs.reserve(subGraphOutputArgNames.size()); + + for (size_t sortedNodeIndex : indexedSubGraph.nodes) + { + subgraphNodes.push_back(graph.GetNode(sortedNodeIndex)); + } + + for (const std::string& graphInputName : subGraphInputArgNames) + { + subgraphInputs.push_back(graph.GetNodeArg(graphInputName)); + } + + for (const std::string& graphOutputName : subGraphOutputArgNames) + { + subgraphOutputs.push_back(graph.GetNodeArg(graphOutputName)); + } + ComPtr device; ORT_THROW_IF_FAILED(m_providerImpl->GetDmlDevice(device.GetAddressOf())); GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc( isInputsUploadedByDmlEP.data(), isInputsUploadedByDmlEP.size(), isInitializerTransferable, - graph, - indexedSubGraph, partitionNodePropsMap, device.Get(), - m_providerImpl); + m_providerImpl, + modelPath, + subgraphNodes, + subgraphInputs, + subgraphOutputs); // Compile the operator auto compiledPartition = DmlGraphFusionHelper::TryCreateCompiledOperator( diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp new file mode 100644 index 0000000000000..1db22ac92e527 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.cpp @@ -0,0 +1,369 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "precomp.h" + +#include "core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h" +#include "core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h" +#include "core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionHelper.h" + +using namespace Windows::AI::MachineLearning::Adapter; + +namespace Dml +{ + class DmlRuntimeFusedGraphKernel : public onnxruntime::OpKernel + { + public: + DmlRuntimeFusedGraphKernel() = delete; + + DmlRuntimeFusedGraphKernel( + const onnxruntime::OpKernelInfo& kernelInfo, + std::shared_ptr indexedSubGraph, + const onnxruntime::Path& modelPath, + std::vector>&& subgraphNodes, + std::vector&& subgraphInputs, + std::vector&& subgraphOutputs, + std::vector>&& intermediateNodeArgs, + std::unordered_map&& partitionNodePropsMap, + std::vector&& ownedInitializers) + : OpKernel(kernelInfo), + m_indexedSubGraph(std::move(indexedSubGraph)), + m_modelPath(modelPath), + m_subgraphNodes(std::move(subgraphNodes)), + m_subgraphInputs(std::move(subgraphInputs)), + m_subgraphOutputs(std::move(subgraphOutputs)), + m_intermediateNodeArgs(std::move(intermediateNodeArgs)), + m_partitionNodePropsMap(std::move(partitionNodePropsMap)), + m_ownedInitializers(std::move(ownedInitializers)) + { + for (const auto& initializer : m_ownedInitializers) + { + m_isInitializerTransferable[initializer.name()] = std::make_pair(&initializer, false); + } + + // Get the execution provider interfaces + auto executionHandle = kernelInfo.GetExecutionProvider()->GetExecutionHandle(); + if (executionHandle) + { + // We assume the execution object inherits IUnknown as its first base + ComPtr providerExecutionObject = const_cast(static_cast(executionHandle)); + + // Get the WinML-specific execution provider interface from the execution object. + ORT_THROW_IF_FAILED(providerExecutionObject.As(&m_provider)); + ORT_THROW_IF_FAILED(providerExecutionObject.As(&m_winmlProvider)); + } + + m_subgraphNodePointers.reserve(m_subgraphNodes.size()); + + for (auto& subgraphNode : m_subgraphNodes) + { + m_subgraphNodePointers.push_back(subgraphNode.get()); + } + } + + void TranslateAndCompileGraph( + const onnxruntime::OpKernelInfo& kernelInfo, + std::vector>& initializeResourceRefs, + std::vector initInputBindings) const + { + // Allocate a persistent resource and initialize the operator + UINT64 persistentResourceSize = m_compiledExecutionPlanOperator->GetBindingProperties().PersistentResourceSize; + if (persistentResourceSize > 0) + { + ORT_THROW_IF_FAILED(m_provider->AllocatePooledResource( + static_cast(persistentResourceSize), + AllocatorRoundingMode::Disabled, + m_persistentResource.ReleaseAndGetAddressOf(), + m_persistentResourceAllocatorUnk.ReleaseAndGetAddressOf())); + + m_persistentResourceBinding = DML_BUFFER_BINDING { m_persistentResource.Get(), 0, persistentResourceSize }; + } + + ORT_THROW_IF_FAILED(m_provider->InitializeOperator( + m_compiledExecutionPlanOperator.Get(), + m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr, + gsl::make_span(initInputBindings))); + + std::for_each( + initializeResourceRefs.begin(), + initializeResourceRefs.end(), + [&](ComPtr& resource){ m_winmlProvider->QueueReference(WRAP_GRAPHICS_UNKNOWN(resource).Get()); } + ); + } + + onnxruntime::Status Compute(onnxruntime::OpKernelContext* kernelContext) const override + { + ORT_THROW_HR_IF(E_UNEXPECTED, m_subgraphInputs.size() != kernelContext->InputCount()); + + bool recompileNeeded = m_compiledExecutionPlanOperator == nullptr; + + for (int inputIndex = 0; inputIndex < kernelContext->InputCount(); ++inputIndex) + { + const auto& input = kernelContext->RequiredInput(inputIndex); + const std::string& inputName = m_subgraphInputs[inputIndex]->Name(); + auto shapeIter = m_inferredInputShapes.find(inputName); + + if (shapeIter == m_inferredInputShapes.end()) + { + m_inferredInputShapes[inputName] = input.Shape(); + recompileNeeded = true; + } + else if (shapeIter->second != input.Shape()) + { + shapeIter->second = input.Shape(); + recompileNeeded = true; + } + + // If we have CPU inputs that are not initializers (i.e. they were computed at runtime), add them to the initializer list + if (input.Location().device.Type() == OrtDevice::CPU) + { + auto inputProto = onnxruntime::utils::TensorToTensorProto(input, inputName); + + // We can only avoid recompiling the graph when all CPU inputs are identical + auto initializerIter = m_isInitializerTransferable.find(inputName); + + if (initializerIter != m_isInitializerTransferable.end()) + { + if (initializerIter->second.first->raw_data().length() == inputProto.raw_data().length()) + { + for (int i = 0; i < inputProto.raw_data().length(); ++i) + { + if (initializerIter->second.first->raw_data()[i] != inputProto.raw_data()[i]) + { + recompileNeeded = true; + break; + } + } + } + else + { + recompileNeeded = true; + } + } + else + { + recompileNeeded = true; + } + + m_ownedCpuInputs.push_back(std::make_unique(std::move(inputProto))); + m_isInitializerTransferable[inputName] = std::make_pair(m_ownedCpuInputs.back().get(), false); + } + } + + if (recompileNeeded) + { + // Go through all the node args and replace their shapes with the real ones + for (auto& nodeArg : m_intermediateNodeArgs) + { + auto iter = m_inferredInputShapes.find(nodeArg->Name()); + if (iter != m_inferredInputShapes.end()) + { + auto tensorShape = *nodeArg->Shape(); + ORT_THROW_HR_IF(E_UNEXPECTED, tensorShape.dim_size() != iter->second.NumDimensions()); + + for (int i = 0; i < tensorShape.dim_size(); ++i) + { + tensorShape.mutable_dim(i)->set_dim_value(iter->second.GetDims()[i]); + } + + nodeArg->SetShape(tensorShape); + } + } + + // Populate input bindings for operator initialization + const uint32_t fusedNodeInputCount = gsl::narrow_cast(m_indexedSubGraph->GetMetaDef()->inputs.size()); + std::vector> initializeResourceRefs; // For lifetime control + std::vector initInputBindings(fusedNodeInputCount); + std::vector isInputsUploadedByDmlEP(fusedNodeInputCount); + auto providerImpl = static_cast(Info().GetExecutionProvider())->GetImpl(); + + // Convert partitionONNXGraph into DML EP GraphDesc + ComPtr device; + ORT_THROW_IF_FAILED(providerImpl->GetDmlDevice(device.GetAddressOf())); + GraphDescBuilder::GraphDesc graphDesc = GraphDescBuilder::BuildGraphDesc( + isInputsUploadedByDmlEP.data(), + isInputsUploadedByDmlEP.size(), + m_isInitializerTransferable, + m_partitionNodePropsMap, + device.Get(), + providerImpl, + m_modelPath, + m_subgraphNodePointers, + m_subgraphInputs, + m_subgraphOutputs); + + m_outputShapes = graphDesc.outputShapes; + + // Walk through each graph edge and mark used inputs + m_inputsUsed.resize(fusedNodeInputCount, false); + for (const DML_INPUT_GRAPH_EDGE_DESC& edge : graphDesc.inputEdges) + { + m_inputsUsed[edge.GraphInputIndex] = true; + } + + // Compile the operator + m_compiledExecutionPlanOperator = DmlGraphFusionHelper::TryCreateCompiledOperator( + graphDesc, + *m_indexedSubGraph, + providerImpl); + + // Queue references to objects which must be kept alive until resulting GPU work completes + m_winmlProvider->QueueReference(m_compiledExecutionPlanOperator.Get()); + + TranslateAndCompileGraph( + Info(), + initializeResourceRefs, + initInputBindings); + } + + // Wrap tensors as required by Dml::IExecutionProvider::ExecuteOperator + OpKernelContextWrapper contextWrapper( + kernelContext, + Info().GetExecutionProvider(), + true, + nullptr); + + ORT_THROW_IF_FAILED(m_provider->AddUAVBarrier()); + + // Get input resources for execution, excluding those which were specified as owned by DML and provided + // at initialization instead. + std::vector> inputTensors(kernelContext->InputCount()); + std::vector inputPtrs(kernelContext->InputCount()); + + for (int i = 0; i < kernelContext->InputCount(); ++i) + { + if (!m_inputsUsed[i]) + { + continue; + } + + ORT_THROW_IF_FAILED(contextWrapper.GetInputTensor(i, inputTensors[i].GetAddressOf())); + inputPtrs[i] = m_provider->DecodeResource(MLOperatorTensor(inputTensors[i].Get()).GetDataInterface().Get()); + } + + auto outputTensors = contextWrapper.GetOutputTensors(m_outputShapes); + ExecuteOperator( + m_compiledExecutionPlanOperator.Get(), + m_persistentResourceBinding ? &*m_persistentResourceBinding : nullptr, + inputPtrs, + outputTensors); + + ORT_THROW_IF_FAILED(m_provider->AddUAVBarrier()); + + return onnxruntime::Status::OK(); + } + + void ExecuteOperator( + IDMLCompiledOperator* op, + _In_opt_ const DML_BUFFER_BINDING* persistentResourceBinding, + gsl::span inputTensors, + gsl::span outputTensors) const + { + auto FillBindingsFromTensors = [this](auto& bufferBindings, auto& bindingDescs, gsl::span& tensors) + { + for (IMLOperatorTensor* tensor : tensors) + { + if (tensor) + { + assert(tensor->IsDataInterface()); + ID3D12Resource* resource = m_provider->DecodeResource(MLOperatorTensor(tensor).GetDataInterface().Get()); + D3D12_RESOURCE_DESC resourceDesc = resource->GetDesc(); + bufferBindings.push_back({ resource, 0, resourceDesc.Width }); + bindingDescs.push_back({ DML_BINDING_TYPE_BUFFER, &bufferBindings.back() }); + } + else + { + bufferBindings.push_back({ nullptr, 0, 0 }); + bindingDescs.push_back({ DML_BINDING_TYPE_NONE, nullptr }); + } + } + }; + + auto FillBindingsFromBuffers = [](auto& bufferBindings, auto& bindingDescs, gsl::span& resources) + { + for (ID3D12Resource* resource : resources) + { + if (resource) + { + D3D12_RESOURCE_DESC resourceDesc = resource->GetDesc(); + bufferBindings.push_back({ resource, 0, resourceDesc.Width }); + bindingDescs.push_back({ DML_BINDING_TYPE_BUFFER, &bufferBindings.back() }); + } + else + { + bufferBindings.push_back({ nullptr, 0, 0 }); + bindingDescs.push_back({ DML_BINDING_TYPE_NONE, nullptr }); + } + } + }; + + std::vector inputBufferBindings; + inputBufferBindings.reserve(inputTensors.size()); + std::vector inputBindings; + inputBindings.reserve(inputTensors.size()); + FillBindingsFromBuffers(inputBufferBindings, inputBindings, inputTensors); + + std::vector outputBufferBindings; + outputBufferBindings.reserve(outputTensors.size()); + std::vector outputBindings; + outputBindings.reserve(outputTensors.size()); + FillBindingsFromTensors(outputBufferBindings, outputBindings, outputTensors); + + ORT_THROW_IF_FAILED(m_provider->ExecuteOperator( + op, + persistentResourceBinding, + inputBindings, + outputBindings)); + } + + private: + ComPtr m_winmlProvider; + ComPtr m_provider; + + mutable std::optional m_persistentResourceBinding; + std::shared_ptr m_indexedSubGraph; + const onnxruntime::Path& m_modelPath; + + std::vector> m_subgraphNodes; + std::vector m_subgraphInputs; + std::vector m_subgraphOutputs; + mutable std::vector> m_intermediateNodeArgs; + std::unordered_map m_partitionNodePropsMap; + std::vector m_ownedInitializers; + mutable std::unordered_map> m_isInitializerTransferable; + std::vector m_subgraphNodePointers; + + // Bindings from previous executions of a re-used command list + mutable std::vector> m_ownedCpuInputs; + mutable ComPtr m_compiledExecutionPlanOperator; + mutable std::vector m_inputsUsed; + mutable ComPtr m_persistentResource; + mutable ComPtr m_persistentResourceAllocatorUnk; // Controls when the persistent resource is returned to the allocator + mutable Windows::AI::MachineLearning::Adapter::EdgeShapes m_outputShapes; + mutable std::unordered_map m_inferredInputShapes; + }; + + onnxruntime::OpKernel* CreateRuntimeFusedGraphKernel( + const onnxruntime::OpKernelInfo& info, + std::shared_ptr indexedSubGraph, + const onnxruntime::Path& modelPath, + std::vector>&& subgraphNodes, + std::vector&& subgraphInputs, + std::vector&& subgraphOutputs, + std::vector>&& intermediateNodeArgs, + std::unordered_map&& partitionNodePropsMap, + std::vector&& ownedInitializers) + { + return new DmlRuntimeFusedGraphKernel( + info, + std::move(indexedSubGraph), + modelPath, + std::move(subgraphNodes), + std::move(subgraphInputs), + std::move(subgraphOutputs), + std::move(intermediateNodeArgs), + std::move(partitionNodePropsMap), + std::move(ownedInitializers) + ); + } +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h new file mode 100644 index 0000000000000..d679c5aa5667c --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeFusedGraphKernel.h @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/framework/op_kernel.h" +#include "GraphDescBuilder.h" +#include "DmlRuntimeGraphFusionTransformer.h" + +namespace Dml +{ + onnxruntime::OpKernel* CreateRuntimeFusedGraphKernel( + const onnxruntime::OpKernelInfo& info, + std::shared_ptr indexedSubGraph, + const onnxruntime::Path& modelPath, + std::vector>&& subgraphNodes, + std::vector&& subgraphInputs, + std::vector&& subgraphOutputs, + std::vector>&& intermediateNodeArgs, + std::unordered_map&& partitionNodePropsMap, + std::vector&& ownedInitializers + ); +} // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp new file mode 100644 index 0000000000000..6318b0d5e2865 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.cpp @@ -0,0 +1,161 @@ +#pragma once + +#include "precomp.h" +#include "GraphDescBuilder.h" +#include "ExecutionProvider.h" +#include "DmlRuntimeGraphFusionTransformer.h" +#include "GraphPartitioner.h" +#include "core/framework/kernel_type_str_resolver.h" +#include "core/framework/kernel_lookup.h" +#include "core/optimizer/constant_sharing.h" +#include "DmlRuntimeFusedGraphKernel.h" +#include "MLOperatorAuthorImpl.h" +#include "DmlGraphFusionHelper.h" + +namespace Dml +{ + namespace + { + struct CompiledPartitionInfo + { + std::shared_ptr indexedSubGraph; + std::unordered_map> isInitializerTransferable; + }; + } + + DmlRuntimeGraphFusionTransformer::DmlRuntimeGraphFusionTransformer( + const std::string& name, + const onnxruntime::IExecutionProvider* provider + ) + :onnxruntime::GraphTransformer(name), + m_providerImpl(static_cast(provider)->GetImpl()) + { + } + + onnxruntime::common::Status DmlRuntimeGraphFusionTransformer::ApplyImpl( + onnxruntime::Graph& graph, + bool& modified, + int graphLevel, + const onnxruntime::logging::Logger& logger) const + { + return ApplyImplHelper(graph, modified, graphLevel, logger, {}); + } + + onnxruntime::common::Status DmlRuntimeGraphFusionTransformer::ApplyImplHelper( + onnxruntime::Graph& graph, + bool& modified, + int graphLevel, + const onnxruntime::logging::Logger& logger, + const std::unordered_map& implicitInputDefs) const + { + onnxruntime::ProviderType providerType = onnxruntime::kDmlExecutionProvider; + const gsl::not_null registry = m_providerImpl->GetKernelRegistry().get(); + const auto kernelTypeStrResolver = onnxruntime::OpSchemaKernelTypeStrResolver{}; + const auto kernelLookup = onnxruntime::KernelLookup( + providerType, + gsl::make_span(®istry, 1), + kernelTypeStrResolver); + + onnxruntime::GraphViewer graphViewer(graph); + const auto& nodeTopologyList = graphViewer.GetNodesInTopologicalOrder(); + + for (auto nodeIndex : nodeTopologyList) + { + auto* node = graph.GetNode(nodeIndex); + if (!node) + { + continue; // node was removed + } + + std::unordered_map subgraphImplicitInputDefs; + for (const onnxruntime::NodeArg* inputDef : node->ImplicitInputDefs()) + { + subgraphImplicitInputDefs[inputDef->Name()] = inputDef; + } + + for (auto& entry : node->GetAttributeNameToMutableSubgraphMap()) + { + auto& subgraph = *entry.second; + ORT_RETURN_IF_ERROR(ApplyImplHelper(subgraph, modified, graphLevel + 1, logger, subgraphImplicitInputDefs)); + } + } + + // Initializers needed by any graph partition + std::vector additionalSplittingNodes; + std::unordered_map graphNodePropertyMap; + std::unordered_set requiredInitializerMap; + std::unordered_set dynamicCpuInputMap; + std::vector> partitions = BuildPartitions( + graphViewer, + *m_providerImpl->GetInternalRegistrationInfoMap(), + kernelLookup, + m_providerImpl->GetSupportedDeviceDataTypeMask(), + graphNodePropertyMap, + requiredInitializerMap, + dynamicCpuInputMap, + additionalSplittingNodes, + implicitInputDefs, + true); + + // Reset the splitting nodes for the current iteration + additionalSplittingNodes.clear(); + + // Reset the compiled operators for the current iteration + std::vector> compiledPartitionInfos(partitions.size()); + + // Create a map between each initialized tensor and the partition(s) it is part of. + auto initializerPartitionMap = DmlGraphFusionHelper::GetInitializerToPartitionMap(graphViewer, partitions); + + for (uint32_t partitionIndex = 0; partitionIndex < partitions.size(); ++partitionIndex) + { + auto& partition = partitions[partitionIndex]; + + if (partition->GetRootMergedPartition() != partition.get() || + !partition->IsDmlPartition()) + { + continue; + } + + if (partition->IsDmlGraphPartition()) + { + std::unordered_map> isInitializerTransferable; + + std::string partitionKernelPrefix = std::to_string(m_providerImpl->GetPartitionKernelPrefixVal()) + "_"; + m_providerImpl->IncreasePartitionKernelPrefixVal(); + + // populate isInitializerTransferable + for (const auto& input : partition->GetInputs()) + { + const onnx::TensorProto* tensor = nullptr; + if (graph.GetInitializedTensor(input, tensor) && requiredInitializerMap.find(input) != requiredInitializerMap.end()) + { + isInitializerTransferable[input] = {tensor, false}; + } + } + + compiledPartitionInfos[partitionIndex] = std::make_shared(); + compiledPartitionInfos[partitionIndex]->indexedSubGraph = std::make_shared( + DmlGraphFusionHelper::CreateIndexedSubGraph(partition.get(), partitionIndex, partitionKernelPrefix)); + compiledPartitionInfos[partitionIndex]->isInitializerTransferable = std::move(isInitializerTransferable); + } + } + + for (auto&& compiledPartitionInfo : compiledPartitionInfos) + { + // Null compiled operators were not DML partitions + if (compiledPartitionInfo) + { + DmlGraphFusionHelper::RegisterDynamicKernel( + graph, + m_providerImpl->GetKernelRegistry().get(), + m_providerImpl, + graphNodePropertyMap, + dynamicCpuInputMap, + std::move(compiledPartitionInfo->indexedSubGraph), + std::move(compiledPartitionInfo->isInitializerTransferable)); + } + } + + return onnxruntime::common::Status::OK(); + } +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.h new file mode 100644 index 0000000000000..cfa743e1f2b85 --- /dev/null +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.h @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include +#include +#include "core/optimizer/graph_transformer.h" +#include "core/framework/execution_providers.h" + +namespace Dml +{ +class ExecutionProviderImpl; + +class DmlRuntimeGraphFusionTransformer : public onnxruntime::GraphTransformer +{ +public: + DmlRuntimeGraphFusionTransformer( + const std::string& name, + const onnxruntime::IExecutionProvider* provider + ); + +public: + static inline const char* const DML_GRAPH_FUSION_NODE_NAME_PREFIX = "DmlRuntimeFusedNode_"; + static inline const char* const DML_GRAPH_FUSION_NODE_DOMAIN = "DmlRuntimeFusedNodeDomain"; + +private: + onnxruntime::common::Status ApplyImpl(onnxruntime::Graph& graph, + bool& modified, + int graphLevel, + const onnxruntime::logging::Logger& logger) const final; + + onnxruntime::common::Status ApplyImplHelper( + onnxruntime::Graph& graph, + bool& modified, + int graphLevel, + const onnxruntime::logging::Logger& logger, + const std::unordered_map& implicitInputDefs) const; + +private: + const ExecutionProviderImpl* m_providerImpl = nullptr; +}; +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp index 5f6bd178aaa15..8644b8d56a426 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.cpp @@ -67,7 +67,8 @@ namespace Dml ExecutionProvider::ExecutionProvider( IDMLDevice* dmlDevice, ID3D12CommandQueue* commandQueue, - bool enableMetacommands) : + bool enableMetacommands, + bool enableDynamicGraphFusion) : IExecutionProvider(onnxruntime::kDmlExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0)) { D3D12_COMMAND_LIST_TYPE queueType = commandQueue->GetDesc().Type; @@ -80,7 +81,7 @@ namespace Dml ComPtr device; GRAPHICS_THROW_IF_FAILED(commandQueue->GetDevice(IID_GRAPHICS_PPV_ARGS(device.GetAddressOf()))); - m_impl = wil::MakeOrThrow(dmlDevice, device.Get(), commandQueue, enableMetacommands); + m_impl = wil::MakeOrThrow(dmlDevice, device.Get(), commandQueue, enableMetacommands, enableDynamicGraphFusion); } std::vector> @@ -147,12 +148,12 @@ namespace Dml // Task 24384515: Update ORT AIInfra release agent pool to install 19H1 SDK on VM bootstrap #define D3D_FEATURE_LEVEL_1_0_CORE_PRIVATE ((D3D_FEATURE_LEVEL)0x1000) - ExecutionProviderImpl::ExecutionProviderImpl(IDMLDevice* dmlDevice, ID3D12Device* d3d12Device, ID3D12CommandQueue* queue, bool enableMetacommands) + ExecutionProviderImpl::ExecutionProviderImpl(IDMLDevice* dmlDevice, ID3D12Device* d3d12Device, ID3D12CommandQueue* queue, bool enableMetacommands, bool enableDynamicGraphFusion) : m_d3d12Device(d3d12Device), m_dmlDevice(dmlDevice), - m_areMetacommandsEnabled(enableMetacommands) + m_areMetacommandsEnabled(enableMetacommands), + m_dynamicGraphFusionEnabled(enableDynamicGraphFusion) { - D3D12_FEATURE_DATA_FEATURE_LEVELS featureLevels = {}; D3D_FEATURE_LEVEL featureLevelsList[] = { @@ -1093,6 +1094,11 @@ namespace Dml return m_areMetacommandsEnabled; } + bool ExecutionProviderImpl::DynamicGraphFusionEnabled() const noexcept + { + return m_dynamicGraphFusionEnabled; + } + std::shared_ptr ExecutionProviderImpl::GetInternalRegistrationInfoMap() const { @@ -1129,9 +1135,10 @@ namespace Dml std::unique_ptr CreateExecutionProvider( IDMLDevice* dmlDevice, ID3D12CommandQueue* commandQueue, - bool enableMetacommands) + bool enableMetacommands, + bool enableDynamicGraphFusion) { - return std::make_unique(dmlDevice, commandQueue, enableMetacommands); + return std::make_unique(dmlDevice, commandQueue, enableMetacommands, enableDynamicGraphFusion); } ID3D12Resource* GetD3D12ResourceFromAllocation(onnxruntime::IAllocator* allocator, void* ptr) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h index 31b893a2f25d7..3aaa11cdee479 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h @@ -5,6 +5,7 @@ #include "GraphTransformer.h" #include "core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h" +#include "core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h" #include #include @@ -34,7 +35,8 @@ namespace Dml IDMLDevice* dmlDevice, ID3D12Device* d3d12Device, ID3D12CommandQueue* queue, - bool enableMetacommands = true); + bool enableMetacommands, + bool enableDynamicGraphFusion); void ReleaseCompletedReferences(); @@ -150,6 +152,7 @@ namespace Dml STDMETHOD_(bool, IsMcdmDevice)() const noexcept final; STDMETHOD_(bool, MetacommandsEnabled)() const noexcept final; + bool DynamicGraphFusionEnabled() const noexcept; std::shared_ptr GetGpuAllocator(); std::shared_ptr GetCpuInputAllocator(); @@ -184,6 +187,7 @@ namespace Dml ComPtr m_dmlDevice; bool m_isMcdmDevice = false; bool m_areMetacommandsEnabled = true; + bool m_dynamicGraphFusionEnabled = false; bool m_native16BitShaderOpsSupported = false; std::shared_ptr m_context; std::unique_ptr m_uploadHeap; @@ -236,7 +240,8 @@ namespace Dml explicit ExecutionProvider( IDMLDevice* dmlDevice, ID3D12CommandQueue* commandQueue, - bool enableMetacommands = true + bool enableMetacommands, + bool enableDynamicGraphFusion ); std::unique_ptr GetDataTransfer() const final override @@ -299,9 +304,9 @@ namespace Dml return m_impl.Get(); } - void MetacommandsEnabled() + bool DynamicGraphFusionEnabled() const { - m_impl->MetacommandsEnabled(); + return m_impl->DynamicGraphFusionEnabled(); } virtual std::vector CreatePreferredAllocators() override diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp index 636f46428ce99..3fc8f415e5a58 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.cpp @@ -147,14 +147,14 @@ namespace Dml::GraphDescBuilder const uint8_t* isConstGpuGraphInput, const size_t isConstGpuGraphInputCount, const std::unordered_map>& isInitializerTransferable, - const onnxruntime::Graph& graph, - const onnxruntime::IndexedSubGraph& indexedSubGraph, const std::unordered_map& graphNodePropertyMap, IDMLDevice* device, - const void* executionHandle) + const void* executionHandle, + const onnxruntime::Path& modelPath, + gsl::span subgraphNodes, + gsl::span subgraphInputs, + gsl::span subgraphOutputs) { - const gsl::span subGraphInputArgNames = indexedSubGraph.GetMetaDef()->inputs; - const gsl::span subGraphOutputArgNames = indexedSubGraph.GetMetaDef()->outputs; struct NodeAndIndex { uint32_t nodeIndex; // The index of the node itself @@ -164,12 +164,14 @@ namespace Dml::GraphDescBuilder // Map from Lotus node argument names to the new node and index where it will be produced std::unordered_map nameToNodeAndIndexMap; + std::unordered_map nodeOutputShapes; + // Map from Lotus node argument names to input indices of the fused kernel node. std::unordered_map nameToDmlFusedNodeInputIndex; - for (size_t inputIndex = 0; inputIndex < subGraphInputArgNames.size(); ++inputIndex) + for (size_t inputIndex = 0; inputIndex < subgraphInputs.size(); ++inputIndex) { - const onnxruntime::NodeArg* graphInput = graph.GetNodeArg(subGraphInputArgNames[inputIndex]); + const onnxruntime::NodeArg* graphInput = subgraphInputs[inputIndex]; if (!graphInput) { @@ -196,13 +198,11 @@ namespace Dml::GraphDescBuilder const uint32_t minNodeCountToReuseCommandList = 5; bool reuseCommandList = false; - if (indexedSubGraph.nodes.size() >= minNodeCountToReuseCommandList) + if (subgraphNodes.size() >= minNodeCountToReuseCommandList) { reuseCommandList = true; } - auto modelPath = graph.ModelPath(); - auto constantCpuGraphInputGetter = [&isInitializerTransferable, &modelPath](const std::string& argName) { ComPtr tensorWrapper; @@ -219,9 +219,11 @@ namespace Dml::GraphDescBuilder // Iterate through each node and create a corresponding node in the new graph // We can iterate the nodes in any order because the edge connectivity will take care of the topological order - for (size_t sortedNodeIndex : indexedSubGraph.nodes) + std::unordered_map> inferredOutputShapes; + + for (const onnxruntime::Node* subgraphNode : subgraphNodes) { - const onnxruntime::Node& node = *graph.GetNode(sortedNodeIndex); + const onnxruntime::Node& node = *subgraphNode; const GraphNodeProperties& graphNodeProps = graphNodePropertyMap.find(GetUniqueNodeName(node))->second; const auto& requiredConstantCpuInputs = graphNodeProps.internalRegInfo->requiredConstantCpuInputs; @@ -244,14 +246,45 @@ namespace Dml::GraphDescBuilder return tensor; }; + EdgeShapes inputShapesOverrides(node.InputDefs().size()); + + // Override the input shapes with shapes that were previously inferred + for (int inputIndex = 0; inputIndex < node.InputDefs().size(); ++inputIndex) + { + auto inputDef = node.InputDefs()[inputIndex]; + + auto outputShapesIter = inferredOutputShapes.find(inputDef->Name()); + if (outputShapesIter != inferredOutputShapes.end()) + { + inputShapesOverrides.GetMutableShape(inputIndex) = outputShapesIter->second; + } + else if (inputDef->HasTensorOrScalarShape()) + { + for (int i = 0; i < inputDef->Shape()->dim_size(); ++i) + { + ORT_THROW_HR_IF(E_INVALIDARG, !inputDef->Shape()->dim(i).has_dim_value()); + inputShapesOverrides.GetMutableShape(inputIndex).push_back(gsl::narrow_cast(inputDef->Shape()->dim(i).dim_value())); + } + } + } + + EdgeShapes outputShapes; DmlGraphNodeCreateInfo graphNodeCreateInfo; graphNodeProps.internalRegInfo->graphNodeFactoryRegistration->factory( node, constantCpuNodeInputGetter, executionHandle, + &inputShapesOverrides, + /*out*/ &outputShapes, /*out*/ &graphNodeCreateInfo ); + ORT_THROW_HR_IF(E_UNEXPECTED, outputShapes.EdgeCount() != node.OutputDefs().size()); + for (int i = 0; i < node.OutputDefs().size(); ++i) + { + inferredOutputShapes[node.OutputDefs()[i]->Name()] = outputShapes.GetShape(i); + } + // Create a map between operatorGraphNodeIndex to mainGraphNodeIndex. std::unordered_map operatorGraphNodeIndexToMainGraphNodeIndexMap; uint32_t graphNodeCount = gsl::narrow_cast(graphNodes.size()); @@ -347,6 +380,8 @@ namespace Dml::GraphDescBuilder operatorGraphNodeIndexToMainGraphNodeIndexMap[operatorGraphOutputEdge.FromNodeIndex], operatorGraphOutputEdge.FromNodeOutputIndex }; + + nodeOutputShapes[arg->Name()] = outputShapes; } } @@ -367,10 +402,12 @@ namespace Dml::GraphDescBuilder } } + EdgeShapes graphOutputShapes(subgraphOutputs.size()); + // Add graph output nodes, which might be in a different order from the encapsulating node - for (size_t outputIndex = 0; outputIndex < subGraphOutputArgNames.size(); ++outputIndex) + for (size_t outputIndex = 0; outputIndex < subgraphOutputs.size(); ++outputIndex) { - const onnxruntime::NodeArg* graphOutput = graph.GetNodeArg(subGraphOutputArgNames[outputIndex]); + const onnxruntime::NodeArg* graphOutput = subgraphOutputs[outputIndex]; ORT_THROW_HR_IF_NULL_MSG(E_POINTER, graphOutput, "FusedNode's nodeArgList does not contain one of the nodeArg"); const auto& outputNodeAndIndex = nameToNodeAndIndexMap.at(graphOutput->Name()); @@ -380,6 +417,7 @@ namespace Dml::GraphDescBuilder edge.FromNodeOutputIndex = outputNodeAndIndex.targetIndex; edge.GraphOutputIndex = gsl::narrow_cast(outputIndex); graphOutputEdges.push_back(edge); + graphOutputShapes.GetMutableShape(outputIndex) = nodeOutputShapes[graphOutput->Name()].GetShape(outputNodeAndIndex.targetIndex); } RemoveUnconnectedNodes(graphNodes, graphInputEdges, graphIntermediateEdges, graphOutputEdges); @@ -390,6 +428,7 @@ namespace Dml::GraphDescBuilder graphDesc.outputEdges = std::move(graphOutputEdges); graphDesc.intermediateEdges = std::move(graphIntermediateEdges); graphDesc.reuseCommandList = reuseCommandList; + graphDesc.outputShapes = std::move(graphOutputShapes); return graphDesc; } } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h index 5c04962e55557..0039678c00e59 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphDescBuilder.h @@ -9,10 +9,10 @@ namespace Dml { struct GraphNodeProperties { - std::shared_ptr + std::shared_ptr internalRegInfo; - // These are currently passed from the partitioning step since the only DML operators current + // These are currently passed from the partitioning step since the only DML operators current // supporting graph nodes don't customize the order of edges or shapes, other than coercing // dimension count. This will change as the supported set of operators as graph nodes increases. Windows::AI::MachineLearning::Adapter::EdgeShapes inputShapes; @@ -38,16 +38,19 @@ namespace Dml std::vector outputEdges; std::vector intermediateEdges; bool reuseCommandList; + Windows::AI::MachineLearning::Adapter::EdgeShapes outputShapes; }; GraphDesc BuildGraphDesc( const uint8_t* isConstGpuGraphInput, const size_t isConstGpuGraphInputCount, const std::unordered_map>& isInitializerTransferable, - const onnxruntime::Graph& graph, - const onnxruntime::IndexedSubGraph& indexedSubGraph, const std::unordered_map& graphNodePropertyMap, IDMLDevice* device, - const void* executionHandle); + const void* executionHandle, + const onnxruntime::Path& modelPath, + gsl::span subgraphNodes, + gsl::span subgraphInputs, + gsl::span subgraphOutputs); } -} \ No newline at end of file +} diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp index 18943878ccedc..f7a4743801d81 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.cpp @@ -151,6 +151,8 @@ namespace Dml _In_opt_ const std::unordered_map* nodeNameToPartitionMap, _Inout_ std::unordered_map& dmlNodePropertyMap, _Inout_ std::unordered_set& requiredInitializerMap, + _Inout_ std::unordered_set& dynamicCpuInputMap, + bool allowDmlGraphDynamicShapes, _Out_ bool* isDmlGraphNode ) { @@ -172,36 +174,68 @@ namespace Dml if (internalRegInfo && internalRegInfo->graphNodeFactoryRegistration) { - bool requiredCpuInputsConstant = true; - for (uint32_t inputIndex : internalRegInfo->requiredConstantCpuInputs) + if (allowDmlGraphDynamicShapes) { - if (inputIndex >= node.InputDefs().size() || !node.InputDefs()[inputIndex]->Exists()) + for (uint32_t inputIndex : internalRegInfo->requiredConstantCpuInputs) { - continue; - } + if (inputIndex >= node.InputDefs().size() || !node.InputDefs()[inputIndex]->Exists()) + { + continue; + } - const onnx::TensorProto* tensor = nullptr; - const std::string& inputName = node.InputDefs()[inputIndex]->Name(); + const onnx::TensorProto* tensor = nullptr; + const std::string& inputName = node.InputDefs()[inputIndex]->Name(); - if (!graph.GetInitializedTensor(inputName, tensor)) - { - requiredCpuInputsConstant = false; - break; + if (graph.GetInitializedTensor(inputName, tensor)) + { + requiredInitializerMap.insert(inputName); + } + else + { + dynamicCpuInputMap.insert(inputName); + } } - requiredInitializerMap.insert(inputName); + std::optional requiredInputCount = internalRegInfo->graphNodeFactoryRegistration->requiredInputCount; + if (requiredInputCount == std::nullopt || *requiredInputCount == node.InputDefs().size()) + { + *isDmlGraphNode = true; + graphNodeProperty.first->second.internalRegInfo = internalRegInfo; + } } - - std::optional requiredInputCount = internalRegInfo->graphNodeFactoryRegistration->requiredInputCount; - if (requiredCpuInputsConstant && - TryGetStaticInputShapes( node, graphNodeProperty.first->second.inputShapes) && - !ContainsEmptyDimensions(graphNodeProperty.first->second.inputShapes, internalRegInfo->requiredConstantCpuInputs) && - TryGetStaticOutputShapes(node, graphNodeProperty.first->second.outputShapes) && - !ContainsEmptyDimensions(graphNodeProperty.first->second.outputShapes, internalRegInfo->requiredConstantCpuInputs) && - (requiredInputCount == std::nullopt || *requiredInputCount == node.InputDefs().size())) + else { - *isDmlGraphNode = true; - graphNodeProperty.first->second.internalRegInfo = internalRegInfo; + bool requiredCpuInputsConstant = true; + for (uint32_t inputIndex : internalRegInfo->requiredConstantCpuInputs) + { + if (inputIndex >= node.InputDefs().size() || !node.InputDefs()[inputIndex]->Exists()) + { + continue; + } + + const onnx::TensorProto* tensor = nullptr; + const std::string& inputName = node.InputDefs()[inputIndex]->Name(); + + if (!graph.GetInitializedTensor(inputName, tensor)) + { + requiredCpuInputsConstant = false; + break; + } + + requiredInitializerMap.insert(inputName); + } + + std::optional requiredInputCount = internalRegInfo->graphNodeFactoryRegistration->requiredInputCount; + if (requiredCpuInputsConstant && + TryGetStaticInputShapes( node, graphNodeProperty.first->second.inputShapes) && + !ContainsEmptyDimensions(graphNodeProperty.first->second.inputShapes, internalRegInfo->requiredConstantCpuInputs) && + TryGetStaticOutputShapes(node, graphNodeProperty.first->second.outputShapes) && + !ContainsEmptyDimensions(graphNodeProperty.first->second.outputShapes, internalRegInfo->requiredConstantCpuInputs) && + (requiredInputCount == std::nullopt || *requiredInputCount == node.InputDefs().size())) + { + *isDmlGraphNode = true; + graphNodeProperty.first->second.internalRegInfo = internalRegInfo; + } } } } @@ -379,8 +413,10 @@ namespace Dml uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE. std::unordered_map& graphNodePropertyMap, std::unordered_set& requiredInitializerMap, + std::unordered_set& dynamicCpuInputMap, gsl::span additionalSplittingNodes, - const std::unordered_map& implicitInputs) + const std::unordered_map& implicitInputs, + bool allowDmlGraphDynamicShapes) { // Nodes are uniquely identified by the name of their first output argument std::vector> partitions; @@ -443,6 +479,8 @@ namespace Dml &nodeNameToPartitionMap, graphNodePropertyMap, requiredInitializerMap, + dynamicCpuInputMap, + allowDmlGraphDynamicShapes, /*out*/ &isDmlGraphNode ); } diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h index 37d577f647fb5..3bddb5ae16086 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/GraphPartitioner.h @@ -50,6 +50,8 @@ namespace Dml uint32_t supportedDeviceDataTypeMask, // Each bit corresponds to each DML_TENSOR_DATA_TYPE. std::unordered_map& graphNodePropertyMap, std::unordered_set& requiredInitializerMap, + std::unordered_set& dynamicCpuInputMap, gsl::span additionalSplittingNodes, - const std::unordered_map& implicitInputs); + const std::unordered_map& implicitInputs, + bool allowDmlGraphDynamicShapes); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h index d7a0a607cdec9..a8a6d6745e908 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/IExecutionProvider.h @@ -2,8 +2,15 @@ // Licensed under the MIT License. #pragma once + +#include + #include "core/providers/dml/DmlExecutionProvider/inc/DmlExecutionProvider.h" +interface IDMLCompiledOperator; +struct DML_BUFFER_BINDING; +struct DML_BINDING_DESC; + namespace Dml { struct Binding diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp index 6cd10e14e08d2..4deec620fe5fb 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.cpp @@ -1356,13 +1356,14 @@ namespace Windows::AI::MachineLearning::Adapter const onnxruntime::OpNodeProtoHelper* protoHelper, const void* executionHandle, bool isInternalOperator, + const EdgeShapes* inputShapesOverrides, const EdgeShapes* inferredOutputShapes, const AttributeMap* defaultAttributes, DmlGraphNodeCreateInfo* graphNodeCreateInfo, gsl::span requiredConstantCpuInputs, MLOperatorTensorGetter& constantInputGetter ) - : OpNodeInfoWrapper(protoHelper, nullptr, defaultAttributes, requiredConstantCpuInputs, constantInputGetter, nullptr), + : OpNodeInfoWrapper(protoHelper, inputShapesOverrides, defaultAttributes, requiredConstantCpuInputs, constantInputGetter, nullptr), m_inferredOutputShapes(inferredOutputShapes), m_internalOperator(isInternalOperator), m_graphNodeCreateInfo(graphNodeCreateInfo) diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h index a7f8bebb2de78..913997ff4ad49 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/MLOperatorAuthorImpl.h @@ -4,6 +4,7 @@ #pragma once #include "core/providers/dml/DmlExecutionProvider/inc/IWinmlExecutionProvider.h" #include "core/providers/dml/OperatorAuthorHelper/MLOperatorAuthorHelper.h" +#include "core/providers/dml/DmlExecutionProvider/src/DmlEdgeShapes.h" #include "core/framework/op_kernel.h" #include "core/framework/customregistry.h" #include "core/framework/tensorprotoutils.h" @@ -93,42 +94,6 @@ struct AttributeValue using AttributeMap = std::map; -// Encapsulation of shapes across different edges of an operator. Non-tensor -// edges and unused edges have an empty array of dimensions. -class EdgeShapes -{ -public: - EdgeShapes() = default; - - EdgeShapes(size_t count) : m_shapes(count) {} - - const std::vector& GetShape(size_t edgeIndex) const - { - return m_shapes[edgeIndex]; - } - - std::vector& GetMutableShape(size_t edgeIndex) - { - return m_shapes[edgeIndex]; - } - - size_t EdgeCount() const { return m_shapes.size(); } - - void Reset(size_t edge_count) - { - m_shapes.clear(); - m_shapes.resize(edge_count); - } - - bool operator!=(const EdgeShapes& other) const noexcept - { - return (m_shapes != other.m_shapes); - } - - private: - std::vector> m_shapes; -}; - // Base class for ABI objects which may be "Closed", at which point calls will predictably // fail or return a dummy value. This is used for transient ABI context objects which // are passed to methods on kernel or inferencers, and which wrap Lotus objects whose lifetimes @@ -434,6 +399,7 @@ class DmlGraphOpKernelInfoWrapper : public OpNodeInfoWrapper< const onnxruntime::OpNodeProtoHelper * protoHelper, const void* executionHandle, bool isInternalOperator, + const EdgeShapes* inputShapesOverrides, const EdgeShapes* inferredOutputShapes, const AttributeMap* defaultAttributes, DmlGraphNodeCreateInfo* graphNodeCreateInfo, diff --git a/onnxruntime/core/providers/dml/dml_provider_factory.cc b/onnxruntime/core/providers/dml/dml_provider_factory.cc index cd8bc8fe909dc..d587424fe01f8 100644 --- a/onnxruntime/core/providers/dml/dml_provider_factory.cc +++ b/onnxruntime/core/providers/dml/dml_provider_factory.cc @@ -30,8 +30,12 @@ namespace onnxruntime { struct DMLProviderFactory : IExecutionProviderFactory { DMLProviderFactory(IDMLDevice* dml_device, - ID3D12CommandQueue* cmd_queue) : dml_device_(dml_device), - cmd_queue_(cmd_queue) {} + ID3D12CommandQueue* cmd_queue, + bool disable_metacommands, + bool enable_dynamic_graph_fusion) : dml_device_(dml_device), + cmd_queue_(cmd_queue), + metacommands_enabled_(!disable_metacommands), + dynamic_graph_fusion_enabled_(enable_dynamic_graph_fusion) {} ~DMLProviderFactory() override {} std::unique_ptr CreateProvider() override; @@ -42,10 +46,11 @@ struct DMLProviderFactory : IExecutionProviderFactory { ComPtr dml_device_{}; ComPtr cmd_queue_{}; bool metacommands_enabled_ = true; + bool dynamic_graph_fusion_enabled_ = false; }; std::unique_ptr DMLProviderFactory::CreateProvider() { - auto provider = Dml::CreateExecutionProvider(dml_device_.Get(), cmd_queue_.Get(), metacommands_enabled_); + auto provider = Dml::CreateExecutionProvider(dml_device_.Get(), cmd_queue_.Get(), metacommands_enabled_, dynamic_graph_fusion_enabled_); return provider; } @@ -54,7 +59,9 @@ void DMLProviderFactory::SetMetacommandsEnabled(bool metacommands_enabled) { } std::shared_ptr CreateExecutionProviderFactory_DML(IDMLDevice* dml_device, - ID3D12CommandQueue* cmd_queue) { + ID3D12CommandQueue* cmd_queue, + bool disable_metacommands, + bool enable_dynamic_graph_fusion) { #ifndef _GAMING_XBOX // Validate that the D3D12 devices match between DML and the command queue. This specifically asks for IUnknown in // order to be able to compare the pointers for COM object identity. @@ -73,7 +80,7 @@ std::shared_ptr CreateExecutionProviderFactory_DML(ID const Env& env = Env::Default(); auto luid = d3d12_device->GetAdapterLuid(); env.GetTelemetryProvider().LogExecutionProviderEvent(&luid); - return std::make_shared(dml_device, cmd_queue); + return std::make_shared(dml_device, cmd_queue, disable_metacommands, enable_dynamic_graph_fusion); } void DmlConfigureProviderFactoryMetacommandsEnabled(IExecutionProviderFactory* factory, bool metacommandsEnabled) { @@ -234,12 +241,10 @@ static void SortHeterogenousDXCoreAdapterList( std::sort(adapter_infos.begin(), adapter_infos.end(), policy); } -std::shared_ptr DMLProviderFactoryCreator::Create(int device_id) { - return Create(device_id, /*skip_software_device_check*/ false); -} - std::shared_ptr DMLProviderFactoryCreator::CreateFromOptions( - OrtDmlDeviceOptions* device_options) { + OrtDmlDeviceOptions* device_options, + bool disable_metacommands, + bool enable_dynamic_graph_fusion) { auto default_device_options = OrtDmlDeviceOptions { Default, Gpu }; if (device_options == nullptr) { device_options = &default_device_options; @@ -286,7 +291,7 @@ std::shared_ptr DMLProviderFactoryCreator::CreateFrom adapters.begin(), [](auto& a){ return a.Adapter; }); - return onnxruntime::DMLProviderFactoryCreator::CreateFromAdapterList(std::move(adapters)); + return onnxruntime::DMLProviderFactoryCreator::CreateFromAdapterList(std::move(adapters), disable_metacommands, enable_dynamic_graph_fusion); } static std::optional ParsePerformancePreference(const ProviderOptions& provider_options) { @@ -354,12 +359,32 @@ static std::optional ParseDeviceId(const ProviderOptions& provider_options) return {}; } +static bool ParseBoolean(const ProviderOptions& provider_options, const std::string& key) { + auto preference_it = provider_options.find(key); + if (preference_it != provider_options.end() && !preference_it->second.empty()) { + if (preference_it->second == "True" || preference_it->second == "true") { + return true; + } else if (preference_it->second == "False" || preference_it->second == "false") { + return false; + } else { + ORT_THROW("[ERROR] [DirectML] The value for the key '" + key + "' should be 'True' or 'False'. Default value is 'False'.\n"); + } + } + + return false; +} + std::shared_ptr DMLProviderFactoryCreator::CreateFromProviderOptions( - const ProviderOptions& provider_options) { + const ProviderOptions& provider_options) { + + bool disable_metacommands = ParseBoolean(provider_options, "disable_metacommands"); + bool enable_dynamic_graph_fusion = ParseBoolean(provider_options, "enable_dynamic_graph_fusion"); + bool skip_software_device_check = false; auto device_id = ParseDeviceId(provider_options); + if (device_id.has_value()) { - return onnxruntime::DMLProviderFactoryCreator::Create(device_id.value()); + return onnxruntime::DMLProviderFactoryCreator::Create(device_id.value(), skip_software_device_check, disable_metacommands, enable_dynamic_graph_fusion); } auto preference = ParsePerformancePreference(provider_options); @@ -367,7 +392,7 @@ std::shared_ptr DMLProviderFactoryCreator::CreateFrom // If no preference/filters are specified then create with default preference/filters. if (!preference.has_value() && !filter.has_value()) { - return onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(nullptr); + return onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(nullptr, disable_metacommands, enable_dynamic_graph_fusion); } if (!preference.has_value()) { @@ -381,7 +406,7 @@ std::shared_ptr DMLProviderFactoryCreator::CreateFrom OrtDmlDeviceOptions device_options; device_options.Preference = preference.value(); device_options.Filter = filter.value(); - return onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(&device_options); + return onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(&device_options, disable_metacommands, enable_dynamic_graph_fusion); } Microsoft::WRL::ComPtr DMLProviderFactoryCreator::CreateD3D12Device( @@ -441,7 +466,10 @@ Microsoft::WRL::ComPtr DMLProviderFactoryCreator::CreateDMLDevice(ID return dml_device; } -std::shared_ptr CreateDMLDeviceAndProviderFactory(ID3D12Device* d3d12_device) { +std::shared_ptr CreateDMLDeviceAndProviderFactory( + ID3D12Device* d3d12_device, + bool disable_metacommands, + bool enable_dynamic_graph_fusion) { D3D12_COMMAND_QUEUE_DESC cmd_queue_desc = {}; cmd_queue_desc.Type = D3D12_COMMAND_LIST_TYPE_DIRECT; cmd_queue_desc.Flags = D3D12_COMMAND_QUEUE_FLAG_DISABLE_GPU_TIMEOUT; @@ -450,16 +478,22 @@ std::shared_ptr CreateDMLDeviceAndProviderFactory(ID3 ORT_THROW_IF_FAILED(d3d12_device->CreateCommandQueue(&cmd_queue_desc, IID_GRAPHICS_PPV_ARGS(cmd_queue.ReleaseAndGetAddressOf()))); auto dml_device = onnxruntime::DMLProviderFactoryCreator::CreateDMLDevice(d3d12_device); - return CreateExecutionProviderFactory_DML(dml_device.Get(), cmd_queue.Get()); + return CreateExecutionProviderFactory_DML(dml_device.Get(), cmd_queue.Get(), disable_metacommands, enable_dynamic_graph_fusion); } -std::shared_ptr DMLProviderFactoryCreator::Create(int device_id, bool skip_software_device_check) { +std::shared_ptr DMLProviderFactoryCreator::Create( + int device_id, + bool skip_software_device_check, + bool disable_metacommands, + bool enable_dynamic_graph_fusion) { ComPtr d3d12_device = CreateD3D12Device(device_id, skip_software_device_check); - return CreateDMLDeviceAndProviderFactory(d3d12_device.Get()); + return CreateDMLDeviceAndProviderFactory(d3d12_device.Get(), disable_metacommands, enable_dynamic_graph_fusion); } std::shared_ptr DMLProviderFactoryCreator::CreateFromAdapterList( - std::vector>&& dxcore_devices) { + std::vector>&& dxcore_devices, + bool disable_metacommands, + bool enable_dynamic_graph_fusion) { // Choose the first device from the list since it's the highest priority auto dxcore_device = dxcore_devices[0]; @@ -467,7 +501,7 @@ std::shared_ptr DMLProviderFactoryCreator::CreateFrom ComPtr d3d12_device; ORT_THROW_IF_FAILED(D3D12CreateDevice(dxcore_device.Get(), D3D_FEATURE_LEVEL_11_0, IID_GRAPHICS_PPV_ARGS(d3d12_device.ReleaseAndGetAddressOf()))); - return CreateDMLDeviceAndProviderFactory(d3d12_device.Get()); + return CreateDMLDeviceAndProviderFactory(d3d12_device.Get(), disable_metacommands, enable_dynamic_graph_fusion); } } // namespace onnxruntime @@ -477,7 +511,7 @@ std::shared_ptr DMLProviderFactoryCreator::CreateFrom // The OrtSessionOptionsAppendExecutionProvider_DML export on the OrtDmlApi should be used instead. ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_DML, _In_ OrtSessionOptions* options, int device_id) { API_IMPL_BEGIN - options->provider_factories.push_back(onnxruntime::DMLProviderFactoryCreator::Create(device_id)); + options->provider_factories.push_back(onnxruntime::DMLProviderFactoryCreator::Create(device_id, false, false, false)); API_IMPL_END return nullptr; } @@ -489,7 +523,9 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProviderEx_DML, _In_ OrtSess _In_ IDMLDevice* dml_device, _In_ ID3D12CommandQueue* cmd_queue) { API_IMPL_BEGIN options->provider_factories.push_back(onnxruntime::CreateExecutionProviderFactory_DML(dml_device, - cmd_queue)); + cmd_queue, + false, + false)); API_IMPL_END return nullptr; } @@ -517,7 +553,7 @@ ORT_API_STATUS_IMPL(FreeGPUAllocation, _In_ void* ptr) { ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_DML2, _In_ OrtSessionOptions* options, OrtDmlDeviceOptions* device_options) { API_IMPL_BEGIN #ifdef USE_DML - auto factory = onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(device_options); + auto factory = onnxruntime::DMLProviderFactoryCreator::CreateFromOptions(device_options, false, false); // return the create function for a dxcore device options->provider_factories.push_back(factory); #endif // USE_DML diff --git a/onnxruntime/core/providers/dml/dml_provider_factory_creator.h b/onnxruntime/core/providers/dml/dml_provider_factory_creator.h index 4e13330a4cd71..0fab9fe902526 100644 --- a/onnxruntime/core/providers/dml/dml_provider_factory_creator.h +++ b/onnxruntime/core/providers/dml/dml_provider_factory_creator.h @@ -17,15 +17,24 @@ namespace onnxruntime { struct DMLProviderFactoryCreator { - static std::shared_ptr Create(int device_id); - static std::shared_ptr Create(int device_id, bool skip_software_device_check); + static std::shared_ptr Create( + int device_id, + bool skip_software_device_check, + bool disable_metacommands, + bool enable_dynamic_graph_fusion); static std::shared_ptr CreateFromProviderOptions( const ProviderOptions& provider_options_map); - static std::shared_ptr CreateFromOptions(OrtDmlDeviceOptions* device_options); + + static std::shared_ptr CreateFromOptions( + OrtDmlDeviceOptions* device_options, + bool disable_metacommands, + bool enable_dynamic_graph_fusion); static std::shared_ptr CreateFromAdapterList( - std::vector>&& dxcore_devices); + std::vector>&& dxcore_devices, + bool disable_metacommands, + bool enable_dynamic_graph_fusion); static Microsoft::WRL::ComPtr CreateD3D12Device(int device_id, bool skip_software_device_check); static Microsoft::WRL::ComPtr CreateDMLDevice(ID3D12Device* d3d12_device); diff --git a/onnxruntime/core/providers/js/operators/cast.cc b/onnxruntime/core/providers/js/operators/cast.cc index f05e1eac4329c..9b6ac6d7e253b 100644 --- a/onnxruntime/core/providers/js/operators/cast.cc +++ b/onnxruntime/core/providers/js/operators/cast.cc @@ -14,8 +14,7 @@ const std::vector& CastOpTypeConstraints() { // https://gpuweb.github.io/gpuweb/wgsl/#plain-types-section // static std::vector types{ - // TODO(fs-eire): support f16 when it's ready - // DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), DataTypeImpl::GetTensorType(), diff --git a/onnxruntime/core/providers/js/operators/resize.cc b/onnxruntime/core/providers/js/operators/resize.cc index 7619c33a477aa..5b2e385777a37 100644 --- a/onnxruntime/core/providers/js/operators/resize.cc +++ b/onnxruntime/core/providers/js/operators/resize.cc @@ -5,15 +5,15 @@ namespace onnxruntime { namespace js { -#define REGISTER_RESIZE_VERSIONED_10_10_KERNEL(domain) \ - ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ - Resize, \ - domain, \ - 10, 10, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .InputMemoryType(OrtMemTypeCPUInput, 1) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ +#define REGISTER_RESIZE_VERSIONED_10_10_KERNEL(domain) \ + ONNX_OPERATOR_VERSIONED_KERNEL_EX( \ + Resize, \ + domain, \ + 10, 10, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .InputMemoryType(OrtMemTypeCPUInput, 1) \ + .TypeConstraint("T", JsepSupportedFloatTypes()), \ Resize); #define REGISTER_RESIZE_VERSIONED_KERNEL(domain, sinceVersion, endVerion) \ @@ -26,22 +26,22 @@ namespace js { .InputMemoryType(OrtMemTypeCPUInput, 1) \ .InputMemoryType(OrtMemTypeCPUInput, 2) \ .InputMemoryType(OrtMemTypeCPUInput, 3) \ - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ + .TypeConstraint("T1", JsepSupportedFloatTypes()) \ + .TypeConstraint("T2", JsepSupportedFloatTypes()), \ Resize); -#define REGISTER_RESIZE_KERNEL(domain, sinceVersion) \ - ONNX_OPERATOR_KERNEL_EX( \ - Resize, \ - domain, \ - sinceVersion, \ - kJsExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .InputMemoryType(OrtMemTypeCPUInput, 1) \ - .InputMemoryType(OrtMemTypeCPUInput, 2) \ - .InputMemoryType(OrtMemTypeCPUInput, 3) \ - .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ - .TypeConstraint("T2", DataTypeImpl::GetTensorType()), \ +#define REGISTER_RESIZE_KERNEL(domain, sinceVersion) \ + ONNX_OPERATOR_KERNEL_EX( \ + Resize, \ + domain, \ + sinceVersion, \ + kJsExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .InputMemoryType(OrtMemTypeCPUInput, 1) \ + .InputMemoryType(OrtMemTypeCPUInput, 2) \ + .InputMemoryType(OrtMemTypeCPUInput, 3) \ + .TypeConstraint("T1", JsepSupportedFloatTypes()) \ + .TypeConstraint("T2", JsepSupportedFloatTypes()), \ Resize); #define REGISTER_RESIZE_KERNEL_DOMAIN(domain) \ diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index 74d237a62f73d..ef1f0bf9f8d0e 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1210,6 +1210,12 @@ Status TensorrtExecutionProvider::OnRunEnd(bool sync_stream) { } void TensorrtExecutionProvider::GetCustomOpDomainList(std::vector& custom_op_domain_list) const { + if (info_.custom_op_domain_list.empty()) { + common::Status status = CreateTensorRTCustomOpDomainList(info_); + if (!status.IsOK()) { + LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration."; + } + } custom_op_domain_list = info_.custom_op_domain_list; } @@ -1869,6 +1875,7 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph, } else if (number_of_trt_nodes == number_of_ort_nodes) { LOGS_DEFAULT(INFO) << "[TensorRT EP] Whole graph will run on TensorRT execution provider"; } else { + sync_stream_after_enqueue_ = true; LOGS_DEFAULT(INFO) << "[TensorRT EP] Graph is partitioned and number of subgraphs running on TensorRT execution provider is " << number_of_subgraphs; } @@ -2387,7 +2394,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vectorallocate_func, context->release_func, context->allocator_handle, context->node_name, &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], &builders_[context->node_name], &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name], - input_shape_ranges_[context->node_name], &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_, + input_shape_ranges_[context->node_name], sync_stream_after_enqueue_, &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_, dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_, runtime_.get(), profiles_[context->node_name], context_memory_sharing_enable_, &max_ctx_mem_size_, dynamic_range_map, engine_decryption_enable_, engine_decryption_, engine_encryption_, timing_cache_enable_, @@ -2415,6 +2422,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector& input_indexes = (trt_state->input_info)[0]; const std::unordered_map& output_indexes = (trt_state->output_info)[0]; const std::unordered_map& output_types = (trt_state->output_info)[1]; + bool sync_stream_after_enqueue = trt_state->sync_stream_after_enqueue; auto fused_node_name = trt_state->fused_node_name; auto& shape_ranges = trt_state->input_shape_ranges; auto trt_builder = trt_state->builder->get(); @@ -3022,6 +3030,10 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector> input_info; std::vector> output_info; std::unordered_map>>> input_shape_ranges; + bool sync_stream_after_enqueue = false; OrtMutex* tensorrt_mu_ptr = nullptr; bool fp16_enable = false; bool int8_enable = false; @@ -196,7 +197,7 @@ class TensorrtExecutionProvider : public IExecutionProvider { Status ReplayGraph() override; private: - TensorrtExecutionProviderInfo info_; + mutable TensorrtExecutionProviderInfo info_; bool external_stream_ = false; cudaStream_t stream_ = nullptr; int max_partition_iterations_ = 1000; @@ -262,6 +263,9 @@ class TensorrtExecutionProvider : public IExecutionProvider { cudnnHandle_t external_cudnn_handle_ = nullptr; cublasHandle_t external_cublas_handle_ = nullptr; + // Call cudaStreamSynchronize() after TRT enqueueV2()/enqueueV3() + mutable bool sync_stream_after_enqueue_ = false; + CUDAGraph cuda_graph_; bool is_graph_captured_ = false; int regular_run_count_before_graph_capture_ = 0; diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc index b5dbe1ac459b1..d7e13df000272 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.cc @@ -75,11 +75,6 @@ struct Tensorrt_Provider : Provider { info.device_id = device_id; info.has_trt_options = false; - common::Status status = CreateTensorRTCustomOpDomainList(info); - if (!status.IsOK()) { - LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration."; - } - return std::make_shared(info); } @@ -121,11 +116,6 @@ struct Tensorrt_Provider : Provider { info.profile_opt_shapes = options.trt_profile_opt_shapes == nullptr ? "" : options.trt_profile_opt_shapes; info.cuda_graph_enable = options.trt_cuda_graph_enable != 0; - common::Status status = CreateTensorRTCustomOpDomainList(info); - if (!status.IsOK()) { - LOGS_DEFAULT(WARNING) << "[TensorRT EP] Failed to get TRT plugins from TRT plugin registration."; - } - return std::make_shared(info); } diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index b4d47652942b7..1163be27b1685 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -52,8 +52,10 @@ #include "core/providers/cpu/cpu_execution_provider.h" #ifdef USE_DML // TODO: This is necessary for the workaround in TransformGraph #include "core/providers/dml/DmlExecutionProvider/src/DmlGraphFusionTransformer.h" +#include "core/providers/dml/DmlExecutionProvider/src/DmlRuntimeGraphFusionTransformer.h" #include "core/providers/dml/DmlExecutionProvider/src/GraphTransformer.h" #include "core/providers/dml/dml_session_options_config_keys.h" +#include "core/providers/dml/DmlExecutionProvider/src/ExecutionProvider.h" #endif #include "core/session/environment.h" #include "core/session/user_logging_sink.h" @@ -613,9 +615,35 @@ common::Status InferenceSession::RegisterExecutionProvider(const std::shared_ptr } #if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS) - // Create Custom Op if EP requests it + // Register Custom Op if EP requests it std::vector custom_op_domains; - p_exec_provider->GetCustomOpDomainList(custom_op_domains); + std::vector candidate_custom_op_domains; + p_exec_provider->GetCustomOpDomainList(candidate_custom_op_domains); + + auto registry_kernels = kernel_registry_manager_.GetKernelRegistriesByProviderType(p_exec_provider->Type()); + + // Register the custom op domain only if it has not been registered before + if (registry_kernels.empty()) { + custom_op_domains = candidate_custom_op_domains; + } else { + for (auto candidate_custom_op_domain : candidate_custom_op_domains) { + for (auto registry_kernel : registry_kernels) { + const auto& kernel_map = registry_kernel->GetKernelCreateMap(); + bool need_register = true; + // If the kernel registry is the ep's custom op registry, we only need to check the first kernel, + // because all kernels in one kernel registry should have the same domain name. + for (auto iter = kernel_map.begin(); iter != kernel_map.end(); iter++) { + if (iter->second.kernel_def->Domain() == candidate_custom_op_domain->domain_) { + need_register = false; + break; + } + } + if (need_register) { + custom_op_domains.push_back(candidate_custom_op_domain); + } + } + } + } if (!custom_op_domains.empty()) { if (AddCustomOpDomains(custom_op_domains) != Status::OK()) { @@ -984,14 +1012,25 @@ common::Status InferenceSession::Load() { common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool saving_model_in_ort_format) { // The transformer order: - // 1. ensure potential QDQ node units have unique DQ nodes (required transformer). + // 1. Ensure we inline as many functions as possible. We refer to it as Ahead Of Time (AOT) function inlining. + // 2. ensure potential QDQ node units have unique DQ nodes (required transformer). // - This is a required transformer as the ORT code has a hard requirement there are no overlapping QDQ node units. // - We run it here in case optimizers are disabled. - // 2. run level 1 optimizations. these only use ONNX operators. - // 3. partition nodes based on EP capabilities. EPs may fuse nodes during this process. - // 4. run level 2+ optimizations. level 2 and 3 optimizations use contrib ops. - // 5. insert cast nodes (required transformer). - // 6. insert copy nodes (required transformer). + // 3. run level 1 optimizations. these only use ONNX operators. + // 4. partition nodes based on EP capabilities. EPs may fuse nodes during this process. + // 5. run level 2+ optimizations. level 2 and 3 optimizations use contrib ops. + // 6. insert cast nodes (required transformer). + // 7. insert copy nodes (required transformer). + + // Run Ahead Of time function inlining + GraphPartitioner partitioner(kernel_registry_manager_, execution_providers_); + if (const bool disable_aot_function_inlining = + session_options_.config_options.GetConfigOrDefault( + kOrtSessionOptionsDisableAheadOfTimeFunctionInlining, "0") == "1"; + !disable_aot_function_inlining) { + ORT_RETURN_IF_ERROR_SESSIONID_(partitioner.InlineFunctionsAOT(*model_, + execution_providers_, kernel_registry_manager_)); + } auto apply_transformer_once = [](const GraphTransformer& transformer, const logging::Logger& logger, Graph& graph) { @@ -1075,7 +1114,6 @@ common::Status InferenceSession::TransformGraph(onnxruntime::Graph& graph, bool } // Do partitioning based on execution providers' capabilities. - GraphPartitioner partitioner(kernel_registry_manager_, execution_providers_); ORT_RETURN_IF_ERROR_SESSIONID_(partitioner.Partition(graph, session_state_->GetMutableFuncMgr(), transform_layout_fn, mode, debug_graph_fn)); @@ -1562,7 +1600,9 @@ common::Status InferenceSession::Initialize() { record_runtime_optimization_produced_op_schema)); #ifdef USE_DML - if (execution_providers_.Get(kDmlExecutionProvider)) { + const IExecutionProvider* dmlExecutionProvider = execution_providers_.Get(kDmlExecutionProvider); + + if (dmlExecutionProvider) { // DML graph fusion is an important runtime optimization that cannot be done ahead of time; it must be disabled // when running in "offline mode" and saving an optimized model to disk. To support users that want to optimize // models offline, and then disable graph optimizations when running "online", this transformer ignores the ORT @@ -1572,11 +1612,20 @@ common::Status InferenceSession::Initialize() { if (dml_graph_fusion_enabled) { std::unique_ptr dmlGraphFusionTransformer = std::make_unique("DmlGraphFusionTransformer", - execution_providers_.Get(kDmlExecutionProvider)); + dmlExecutionProvider); if (dmlGraphFusionTransformer == nullptr) { return Status(common::ONNXRUNTIME, common::FAIL, "DmlGraphFusionTransformer is nullptr"); } ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformer_mgr_.Register(std::move(dmlGraphFusionTransformer), onnxruntime::TransformerLevel::Level3)); + + if (static_cast(dmlExecutionProvider)->DynamicGraphFusionEnabled()) { + std::unique_ptr dmlRuntimeGraphFusionTransformer = std::make_unique("DmlRuntimeGraphFusionTransformer", + dmlExecutionProvider); + if (dmlRuntimeGraphFusionTransformer == nullptr) { + return Status(common::ONNXRUNTIME, common::FAIL, "DmlRuntimeGraphFusionTransformer is nullptr"); + } + ORT_RETURN_IF_ERROR_SESSIONID_(graph_transformer_mgr_.Register(std::move(dmlRuntimeGraphFusionTransformer), onnxruntime::TransformerLevel::Level3)); + } } // This transformer applies DML-specific fusions that go beyond what ORT offers by default diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index d950223f2d108..d307f79c372ed 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -1625,6 +1625,28 @@ ProviderOptions GetProviderInfo_Cuda(const OrtCUDAProviderOptionsV2* provider_op } // namespace onnxruntime +void AddTensorRTCustomOpDomainToSessionOption(OrtSessionOptions* options, std::string extra_plugin_lib_paths) { + auto is_already_in_domains = [&](std::string& domain_name, std::vector& domains) { + for (auto ptr : domains) { + if (domain_name == ptr->domain_) { + return true; + } + } + return false; + }; + + std::vector custom_op_domains; + onnxruntime::ProviderInfo_TensorRT& provider_info = onnxruntime::GetProviderInfo_TensorRT(); + provider_info.GetTensorRTCustomOpDomainList(custom_op_domains, extra_plugin_lib_paths); + for (auto ptr : custom_op_domains) { + if (!is_already_in_domains(ptr->domain_, options->custom_op_domains_)) { + options->custom_op_domains_.push_back(ptr); + } else { + LOGS_DEFAULT(WARNING) << "The custom op domain name " << ptr->domain_ << " is already in session option."; + } + } +} + ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Dnnl, _In_ OrtSessionOptions* options, int use_arena) { API_IMPL_BEGIN auto factory = onnxruntime::DnnlProviderFactoryCreator::Create(use_arena); @@ -1646,13 +1668,8 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtS options->provider_factories.push_back(factory); - std::vector custom_op_domains; std::string extra_plugin_lib_paths = onnxruntime::Env::Default().GetEnvironmentVar("trt_extra_plugin_lib_paths"); - onnxruntime::ProviderInfo_TensorRT& provider_info = onnxruntime::GetProviderInfo_TensorRT(); - provider_info.GetTensorRTCustomOpDomainList(custom_op_domains, extra_plugin_lib_paths); - for (auto ptr : custom_op_domains) { - options->custom_op_domains_.push_back(ptr); - } + AddTensorRTCustomOpDomainToSessionOption(options, extra_plugin_lib_paths); return nullptr; API_IMPL_END @@ -1679,12 +1696,7 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT, _In options->provider_factories.push_back(factory); - std::vector custom_op_domains; - onnxruntime::ProviderInfo_TensorRT& provider_info = onnxruntime::GetProviderInfo_TensorRT(); - provider_info.GetTensorRTCustomOpDomainList(custom_op_domains, ""); - for (auto ptr : custom_op_domains) { - options->custom_op_domains_.push_back(ptr); - } + AddTensorRTCustomOpDomainToSessionOption(options, ""); return nullptr; API_IMPL_END @@ -1788,13 +1800,8 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT_V2, options->provider_factories.push_back(factory); - std::vector custom_op_domains; std::string extra_plugin_lib_paths = (tensorrt_options == nullptr || tensorrt_options->trt_extra_plugin_lib_paths == nullptr) ? "" : tensorrt_options->trt_extra_plugin_lib_paths; - onnxruntime::ProviderInfo_TensorRT& provider_info = onnxruntime::GetProviderInfo_TensorRT(); - provider_info.GetTensorRTCustomOpDomainList(custom_op_domains, extra_plugin_lib_paths); - for (auto ptr : custom_op_domains) { - options->custom_op_domains_.push_back(ptr); - } + AddTensorRTCustomOpDomainToSessionOption(options, extra_plugin_lib_paths); return nullptr; API_IMPL_END diff --git a/onnxruntime/python/onnxruntime_pybind_quant.cc b/onnxruntime/python/onnxruntime_pybind_quant.cc index 52ea677d5141d..04dfa9b51e112 100644 --- a/onnxruntime/python/onnxruntime_pybind_quant.cc +++ b/onnxruntime/python/onnxruntime_pybind_quant.cc @@ -6,6 +6,7 @@ #include #include "contrib_ops/cpu/quantization/dequantize_blockwise.h" +#include "contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h" #include "core/util/thread_utils.h" namespace pybind11 { @@ -64,9 +65,39 @@ void QuantizeMatMul4BitsBlockwise( tp.get()); } +template +void QuantizeMatMulBnb4Blockwise( + py::array_t dst, + py::array_t src, + py::array_t absmax, + int32_t block_size, + int32_t quant_type, + int32_t N, + int32_t K) { + OrtThreadPoolParams to; + auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to, + concurrency::ThreadPoolType::INTRA_OP); + + py::buffer_info dst_buf = dst.request(); + py::buffer_info src_buf = src.request(); + py::buffer_info absmax_buf = absmax.request(); + + contrib::QuantizeBlockwiseBnb4( + static_cast(dst_buf.ptr), + static_cast(src_buf.ptr), + static_cast(absmax_buf.ptr), + block_size, + quant_type, + N, + K, + tp.get()); +} + void CreateQuantPybindModule(py::module& m) { m.def("quantize_matmul_4bits", &QuantizeMatMul4BitsBlockwise); m.def("quantize_matmul_4bits", &QuantizeMatMul4BitsBlockwise); + m.def("quantize_matmul_bnb4", &QuantizeMatMulBnb4Blockwise); + m.def("quantize_matmul_bnb4", &QuantizeMatMulBnb4Blockwise); } } // namespace python diff --git a/onnxruntime/python/onnxruntime_pybind_schema.cc b/onnxruntime/python/onnxruntime_pybind_schema.cc index a8c217b0ff1f6..3a977772873f3 100644 --- a/onnxruntime/python/onnxruntime_pybind_schema.cc +++ b/onnxruntime/python/onnxruntime_pybind_schema.cc @@ -59,7 +59,7 @@ void addGlobalSchemaFunctions(pybind11::module& m) { onnxruntime::ArmNNProviderFactoryCreator::Create(0), #endif #ifdef USE_DML - onnxruntime::DMLProviderFactoryCreator::Create(0, /*skip_software_device_check*/ true), + onnxruntime::DMLProviderFactoryCreator::Create(0, false, false, false), #endif #ifdef USE_NNAPI onnxruntime::NnapiProviderFactoryCreator::Create(0, std::optional()), diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 35e03bf9eacd5..a72f563601512 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -433,6 +433,15 @@ const ROCMExecutionProviderInfo GetRocmExecutionProviderInfo(ProviderInfo_ROCM* #ifdef USE_TENSORRT void RegisterTensorRTPluginsAsCustomOps(PySessionOptions& so, const ProviderOptions& options) { if (auto* tensorrt_provider_info = TryGetProviderInfo_TensorRT()) { + auto is_already_in_domains = [&](std::string& domain_name, std::vector& domains) { + for (auto ptr : domains) { + if (domain_name == ptr->domain_) { + return true; + } + } + return false; + }; + std::string trt_extra_plugin_lib_paths = ""; const auto it = options.find("trt_extra_plugin_lib_paths"); if (it != options.end()) { @@ -441,7 +450,11 @@ void RegisterTensorRTPluginsAsCustomOps(PySessionOptions& so, const ProviderOpti std::vector domain_list; tensorrt_provider_info->GetTensorRTCustomOpDomainList(domain_list, trt_extra_plugin_lib_paths); for (auto ptr : domain_list) { - so.custom_op_domains_.push_back(ptr); + if (!is_already_in_domains(ptr->domain_, so.custom_op_domains_)) { + so.custom_op_domains_.push_back(ptr); + } else { + LOGS_DEFAULT(WARNING) << "The custom op domain name " << ptr->domain_ << " is already in session option."; + } } } else { ORT_THROW("Please install TensorRT libraries as mentioned in the GPU requirements page, make sure they're in the PATH or LD_LIBRARY_PATH, and that your GPU is supported."); diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/cuda/dequant_blockwise_bnb4.cu b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/dequant_blockwise_bnb4.cu new file mode 100644 index 0000000000000..3504ce1bebe8c --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/dequant_blockwise_bnb4.cu @@ -0,0 +1,89 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// This file serve as a simple example for adding a tunable op to onnxruntime. + +#include +#include +#include + +#include + +#include "core/providers/cuda/tunable/cuda_tunable.h" +#include "python/tools/kernel_explorer/kernel_explorer_interface.h" +#include "python/tools/kernel_explorer/device_array.h" +#include "contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh" + +namespace py = pybind11; + +namespace onnxruntime { + +// Extend the OpParams so that all specializations have the same parameter passing interface +template +struct DequantizeBnb4Params : cuda::tunable::OpParams { + std::string Signature() const override { return std::to_string(n_); } + + int quant_type_; + T* output_; + const uint8_t* quant_; + const T* absmax_; + T* quant_map_buffer_; + int n_; + int k_; +}; + +template +class DequantizeBnb4 : public IKernelExplorer { + public: + DequantizeBnb4( + int quant_type, + DeviceArray& output, + DeviceArray& quant, + DeviceArray& absmax, + DeviceArray& quant_map_buffer, + int n, int k) { + params_.tuning_ctx = TuningContext(); + params_.stream = Stream(); + params_.quant_type_ = quant_type; + params_.output_ = static_cast(output.ptr()); + params_.quant_ = static_cast(quant.ptr()); + params_.absmax_ = static_cast(absmax.ptr()); + params_.quant_map_buffer_ = static_cast(quant_map_buffer.ptr()); + params_.n_ = n; + params_.k_ = k; + } + + void Run() override { + ORT_THROW_IF_ERROR(contrib::cuda::SetBnbQuantMap( + params_.quant_type_, + params_.quant_map_buffer_, + params_.StreamHandle())); + ORT_THROW_IF_ERROR(contrib::cuda::DequantizeBnb4( + params_.quant_map_buffer_, + params_.output_, + params_.quant_, + params_.absmax_, + 64, + params_.n_ * params_.k_, + params_.StreamHandle())); + } + + private: + // A VectorAddOp is a callable that can process const VectorAddParams* + using ParamsT = DequantizeBnb4Params; + ParamsT params_{}; +}; + +#define REGISTER_OP(name, type) \ + py::class_>(m, #name "_" #type) \ + .def(py::init()) \ + .def("SetRepeats", &name::SetRepeats) \ + .def("Profile", &name::Profile) \ + .def("Run", &name::Run); + +KE_REGISTER(m) { + REGISTER_OP(DequantizeBnb4, half); + REGISTER_OP(DequantizeBnb4, float); +} + +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/cuda/matmul_bnb4.cu b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/matmul_bnb4.cu new file mode 100644 index 0000000000000..e4cd83565357a --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/cuda/matmul_bnb4.cu @@ -0,0 +1,96 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// This file serve as a simple example for adding a tunable op to onnxruntime. + +#include +#include +#include + +#include + +#include "core/providers/cuda/tunable/cuda_tunable.h" +#include "python/tools/kernel_explorer/kernel_explorer_interface.h" +#include "python/tools/kernel_explorer/kernels/vector_add_kernel.cuh" +#include "contrib_ops/cuda/quantization/dequantize_blockwise_bnb4.cuh" +#include "contrib_ops/cuda/quantization/matmul_bnb4.cuh" + +namespace py = pybind11; + +namespace onnxruntime { + +// Extend the OpParams so that all specializations have the same parameter passing interface +template +struct MatrixFloatBnb4Params : cuda::tunable::OpParams { + std::string Signature() const override { return std::to_string(n_); } + + int quant_type_; + T* output_; + const T* a_; + const uint8_t* b_; + const T* absmax_; + T* quant_map_buffer_; + int m_; + int n_; + int k_; +}; + +template +class MatrixFloatBnb4 : public IKernelExplorer { + public: + MatrixFloatBnb4(DeviceArray& output, + DeviceArray& a, + DeviceArray& b, + DeviceArray& absmax, + DeviceArray& quant_map_buffer, + int quant_type, int m, int n, int k) { + params_.tuning_ctx = TuningContext(); + params_.stream = Stream(); + params_.output_ = static_cast(output.ptr()); + params_.a_ = static_cast(a.ptr()); + params_.b_ = static_cast(b.ptr()); + params_.absmax_ = static_cast(absmax.ptr()); + params_.quant_map_buffer_ = static_cast(quant_map_buffer.ptr()); + params_.quant_type_ = quant_type; + params_.m_ = m; + params_.n_ = n; + params_.k_ = k; + } + + void Run() override { + ORT_THROW_IF_ERROR(contrib::cuda::SetBnbQuantMap( + params_.quant_type_, + params_.quant_map_buffer_, + params_.StreamHandle())); + contrib::cuda::TryMatMulBnb4( + params_.quant_map_buffer_, + params_.output_, + params_.a_, + params_.b_, + params_.absmax_, + params_.m_, + params_.n_, + params_.k_, + 64, + params_.StreamHandle()); + } + + private: + // A VectorAddOp is a callable that can process const VectorAddParams* + using ParamsT = MatrixFloatBnb4Params; + ParamsT params_{}; +}; + +#define REGISTER_OP(name, type) \ + py::class_>(m, #name "_" #type) \ + .def(py::init()) \ + .def("SetRepeats", &name::SetRepeats) \ + .def("Profile", &name::Profile) \ + .def("Run", &name::Run); + +KE_REGISTER(m) { + REGISTER_OP(MatrixFloatBnb4, half); + REGISTER_OP(MatrixFloatBnb4, float); +} + +} // namespace onnxruntime diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/dequantize_blockwise_bnb4.py b/onnxruntime/python/tools/kernel_explorer/kernels/dequantize_blockwise_bnb4.py new file mode 100644 index 0000000000000..140151aadcc0f --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/dequantize_blockwise_bnb4.py @@ -0,0 +1,92 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import sys +from dataclasses import dataclass + +import kernel_explorer as ke +import numpy as np +from utils import dtype_to_bytes + + +def dtype_to_funcs(dtype): + type_map = { + "float16": list(filter(lambda x: "DequantizeBnb4_half" in x, dir(ke))), + "float32": list(filter(lambda x: "DequantizeBnb4_float" in x, dir(ke))), + } + return type_map[dtype] + + +quant_enums = {"FP4": 0, "NF4": 1} + + +dtypes = ["float16", "float32"] +quant_types = ["FP4", "NF4"] + + +@dataclass +class DequantizeBnb4Metric(ke.BandwidthMetric): + quant_type: str + n: int + k: int + + def report(self): + return ( + f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s" + f" {self.quant_type} {self.dtype} n={self.n} k={self.k} {self.name}" + ) + + +def profile_dequantize_int4_func(qt, n, k, dtype, func): + np.random.seed(0) + block_size = 64 + numel = n * k + output = np.random.rand(n, k).astype(dtype) + quant = np.random.randint(low=0, high=255, size=(numel + 1) // 2).astype("uint8") + absmax = np.random.rand((numel + block_size - 1) // block_size).astype(dtype) + quant_map_buffer = np.zeros(16).astype(dtype) + + output_d = ke.DeviceArray(output) + quant_d = ke.DeviceArray(quant) + absmax_d = ke.DeviceArray(absmax) + quant_map_buffer_d = ke.DeviceArray(quant_map_buffer) + f = getattr(ke, func) + my_op = f(quant_enums[qt], output_d, quant_d, absmax_d, quant_map_buffer_d, n, k) + duration_ms = my_op.Profile() + total_bytes = numel / 2 + (numel + numel / block_size) * dtype_to_bytes(dtype) + + ke.report(DequantizeBnb4Metric(func, dtype, duration_ms, total_bytes, qt, n, k)) + + +def profile_with_args(qt, n, k, dtype, sort): + with ke.benchmark(sort): + for func in dtype_to_funcs(dtype): + profile_dequantize_int4_func(qt, n, k, dtype, func) + + +def profile(): + for qt in quant_types: + for dt in dtypes: + for n, k in ((4096, 4096), (4096, 12288), (12288, 4096)): + profile_with_args(qt, n, k, dt, True) + print() + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + group = parser.add_argument_group("profile with args") + group.add_argument("n", type=int) + group.add_argument("k", type=int) + group.add_argument("quant_type", choices=quant_types) + group.add_argument("dtype", choices=dtypes) + group.add_argument("--sort", action="store_true") + + if len(sys.argv) == 1: + profile() + else: + args = parser.parse_args() + profile_with_args(args.quant_type, args.n, args.k, args.dtype, args.sort) diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/matmul_bnb4.py b/onnxruntime/python/tools/kernel_explorer/kernels/matmul_bnb4.py new file mode 100644 index 0000000000000..4a9489050fd61 --- /dev/null +++ b/onnxruntime/python/tools/kernel_explorer/kernels/matmul_bnb4.py @@ -0,0 +1,136 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +import sys +from dataclasses import dataclass + +import kernel_explorer as ke +import numpy as np +from utils import dtype_to_bytes + + +def dtype_to_funcs(dtype): + type_map = { + "float16": list(filter(lambda x: "MatrixFloatBnb4_half" in x, dir(ke))), + "float32": list(filter(lambda x: "MatrixFloatBnb4_float" in x, dir(ke))), + } + return type_map[dtype] + + +def dtype_to_funcs_cublas(dtype): + type_map = { + "float16": list(filter(lambda x: "GemmBenchmark_half" in x, dir(ke))), + "float32": list(filter(lambda x: "GemmBenchmark_float" in x, dir(ke))), + } + return type_map[dtype] + + +quant_enums = {"FP4": 0, "NF4": 1} + + +dtypes = ["float16", "float32"] +quant_types = ["FP4", "NF4"] + + +@dataclass +class MatrixMulMetric(ke.BandwidthMetric): + m: int + n: int + k: int + + def report(self): + return ( + f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s {self.dtype} m={self.m} n={self.n} k={self.k} {self.name}" + ) + + +@dataclass +class MatrixFpBnb4Metric(MatrixMulMetric): + quant_type: str + + def report(self): + return ( + f"{self.duration:6.2f} us {self.gbps:5.2f} GB/s" + f" {self.quant_type} {self.dtype} m={self.m} n={self.n} k={self.k} {self.name}" + ) + + +def profile_matmul_fp_bnb4_func(qt, m, n, k, dtype, func): + np.random.seed(0) + block_size = 64 + numel = n * k + output = np.random.rand(m, n).astype(dtype) + a = np.random.rand(m, k).astype(dtype) + b = np.random.randint(low=0, high=255, size=(numel + 1) // 2).astype("uint8") + absmax = np.random.rand((numel + block_size - 1) // block_size).astype(dtype) + quant_map_buffer = np.zeros(16).astype(dtype) + + output_d = ke.DeviceArray(output) + a_d = ke.DeviceArray(a) + b_d = ke.DeviceArray(b) + absmax_d = ke.DeviceArray(absmax) + quant_map_buffer_d = ke.DeviceArray(quant_map_buffer) + f = getattr(ke, func) + + my_op = f(output_d, a_d, b_d, absmax_d, quant_map_buffer_d, quant_enums[qt], m, n, k) + duration_ms = my_op.Profile() + total_bytes = (m * k + n * k + m * n) * (dtype_to_bytes(dtype)) + + ke.report(MatrixFpBnb4Metric(func, dtype, duration_ms, total_bytes, m, n, k, qt)) + + +def profile_gemm_func(m, n, k, dtype, func): + np.random.seed(0) + output = np.random.rand(m, n).astype(dtype) + a = np.random.rand(m, k).astype(dtype) + b = np.random.rand(k, n).astype(dtype) + + output_d = ke.DeviceArray(output) + a_d = ke.DeviceArray(a) + b_d = ke.DeviceArray(b) + f = getattr(ke, func) + my_op = f(output_d, a_d, b_d, m, n, k) + duration_ms = my_op.Profile() + total_bytes = (m * k + n * k + m * n) * (dtype_to_bytes(dtype)) + + ke.report(MatrixMulMetric(func, dtype, duration_ms, total_bytes, m, n, k)) + + +def profile_with_args(qt, m, n, k, dtype, sort): + with ke.benchmark(sort): + for func in dtype_to_funcs(dtype): + profile_matmul_fp_bnb4_func(qt, m, n, k, dtype, func) + + for func in dtype_to_funcs_cublas(dtype): + profile_gemm_func(m, n, k, dtype, func) + + +def profile(): + dims_m = [1] + for qt in quant_types: + for dt in dtypes: + for m in dims_m: + for n, k in ((4096, 4096), (4096, 12288), (12288, 4096)): + profile_with_args(qt, m, n, k, dt, False) + print() + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + group = parser.add_argument_group("profile with args") + group.add_argument("m", type=int) + group.add_argument("n", type=int) + group.add_argument("k", type=int) + group.add_argument("quant_type", choices=quant_types) + group.add_argument("dtype", choices=dtypes) + group.add_argument("--sort", action="store_true") + + if len(sys.argv) == 1: + profile() + else: + args = parser.parse_args() + profile_with_args(args.quant_type, args.m, args.n, args.k, args.dtype, args.sort) diff --git a/onnxruntime/python/tools/quantization/matmul_bnb4_quantizer.py b/onnxruntime/python/tools/quantization/matmul_bnb4_quantizer.py new file mode 100644 index 0000000000000..951746a089305 --- /dev/null +++ b/onnxruntime/python/tools/quantization/matmul_bnb4_quantizer.py @@ -0,0 +1,240 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import argparse +import logging +import os +from typing import List, Tuple + +import numpy as np +import numpy.typing as npt +import onnx +from onnx.onnx_pb import GraphProto, ModelProto, NodeProto, TensorProto + +from onnxruntime.capi._pybind_state import quantize_matmul_bnb4 + +from .onnx_model import ONNXModel +from .quant_utils import attribute_to_kwarg + +logger = logging.getLogger(__name__) + + +class MatMulBnb4Quantizer: + """Perform 4b quantization of constant MatMul weights using FP4 or NF4 data type""" + + ################## + # quantization types, must be consistent with native code type + # Bnb_DataType_t defined in blockwise_quant_block_bnb4.h + + # 4b floating point with bias of 3 + FP4 = 0 + + # 4b NormalFloat + NF4 = 1 + + def __init__(self, model: ModelProto, quant_type: int, block_size: int, nodes_to_exclude=None): + nodes_to_exclude = nodes_to_exclude or [] + assert quant_type in [MatMulBnb4Quantizer.FP4, MatMulBnb4Quantizer.NF4] + self.model = ONNXModel(model) + self.quant_type = quant_type + self.block_size = block_size + self.nodes_to_exclude = set(nodes_to_exclude) + + @staticmethod + def __get_initializer(name, graph_path: List[GraphProto]) -> Tuple[TensorProto, GraphProto]: + for gid in range(len(graph_path) - 1, -1, -1): + graph = graph_path[gid] + for tensor in graph.initializer: + if tensor.name == name: + return tensor, graph + return None, None + + def bnb4_block_quant(self, fpweight: npt.ArrayLike) -> np.ndarray: + """4b quantize fp32/fp16 weight""" + + if len(fpweight.shape) != 2: + raise ValueError("Current bnb4 block quantization only supports 2D tensors!") + # need to copy since the transposed weight still has the original memory layout + # Linear4bit quantizes its weight data which is the transposed weight + fpweight_t = fpweight.transpose().copy() + + rows, cols = fpweight.shape + numel = rows * cols + block_size = self.block_size + num_blocks = (numel + block_size - 1) // block_size + quantized_numel = (numel + 1) // 2 + + packed = np.zeros(quantized_numel, dtype="uint8") + absmax = np.zeros(num_blocks, dtype=fpweight.dtype) + # block wise quantization, fpweight_t is flattened and divided into blocks + quantize_matmul_bnb4(packed, fpweight_t, absmax, block_size, self.quant_type, cols, rows) + + return (packed, absmax) + + def _bnb4_matmul_node_weight(self, node: NodeProto, graph_stack: List[GraphProto]) -> NodeProto: + """If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node""" + + if node.op_type != "MatMul": + return node # only care about MatMul for now + + logger.debug(f"start to quantize {node.name} ...") + if node.name in self.nodes_to_exclude: + logger.debug(f"exclude to quantize {node.name} as specified by nodes_to_exclude...") + return node + + inputB = node.input[1] # noqa: N806 + B, Bs_graph = MatMulBnb4Quantizer.__get_initializer(inputB, graph_stack) # noqa: N806 + if B is None: + logger.debug("MatMul doesn't have const weight. Skip to quantize") + return node # only care about constant weight + + B_array = onnx.numpy_helper.to_array(B) # noqa: N806 + if len(B_array.shape) != 2: + logger.debug("MatMul weight is not 2D. Skip to quantize") + return node # can only process 2-D matrix + + packed, absmax = self.bnb4_block_quant(B_array) + B_quant = onnx.numpy_helper.from_array(packed) # noqa: N806 + B_quant.name = B.name + "_Bnb4" + for input in Bs_graph.input: + if input.name == inputB: + Bs_graph.input.remove(input) + break + + absmax_tensor = onnx.numpy_helper.from_array(absmax) + absmax_tensor.name = B.name + "_absmax" + + Bs_graph.initializer.extend([B_quant, absmax_tensor]) + + kwargs = {} + rows, cols = B_array.shape + kwargs["K"] = rows + kwargs["N"] = cols + kwargs["block_size"] = self.block_size + kwargs["quant_type"] = self.quant_type + + matmul_bnb4_node = onnx.helper.make_node( + "MatMulBnb4", + inputs=[node.input[0], B_quant.name, absmax_tensor.name], + outputs=[node.output[0]], + name=node.name + "_Bnb4" if node.name else "", + domain="com.microsoft", + **kwargs, + ) + + logger.debug(f"complete quantization of {node.name} ...") + + return matmul_bnb4_node + + def _process_subgraph(self, graph_stack: List[GraphProto]): + new_nodes = [] + graph = graph_stack[-1] + + for node in graph.node: + graph_attrs = [ + attr + for attr in node.attribute + if attr.type == onnx.AttributeProto.GRAPH or attr.type == onnx.AttributeProto.GRAPHS + ] + if len(graph_attrs): + kwargs = {} + for attr in node.attribute: + if attr.type == onnx.AttributeProto.GRAPH: + # recursive call to take care of sub-graph + graph_stack.append(attr.g) + kv = {attr.name: self._process_subgraph(graph_stack)} + elif attr.type == onnx.AttributeProto.GRAPHS: + value = [] + for subgraph in attr.graphs: + # recursive call to take care of sub-graph + graph_stack.append(subgraph) + value.extend([self._process_subgraph(graph_stack)]) + kv = {attr.name: value} + else: + kv = attribute_to_kwarg(attr) + kwargs.update(kv) + node = onnx.helper.make_node( # noqa: PLW2901 + node.op_type, node.input, node.output, name=node.name, **kwargs + ) + + new_nodes.append(self._bnb4_matmul_node_weight(node, graph_stack)) + + graph.ClearField("node") + graph.node.extend(new_nodes) + graph_stack.pop() + return graph + + def process(self): + # use a stack to keep track of sub-graphs + graph_stack = [self.model.graph()] + opset_import = self.model.opset_import() + + has_ms_domain = False + for opset in opset_import: + if opset.domain == "com.microsoft": + has_ms_domain = True + if not has_ms_domain: + opset_import.extend([onnx.helper.make_opsetid("com.microsoft", 1)]) + + self._process_subgraph(graph_stack) + self.model.clean_initializers() + + +def parse_args(): + parser = argparse.ArgumentParser( + description="""Blockwise FP4/NF4 quantization for MatMul 2D weight matrices. + +A weight matrix is partitioned into blocks, where each block is a contiguous +subset inside the flattened transposed weight matrix. Each block is quantized +into a set of 4b integers with an absolute value scaling factor. +""" + ) + + parser.add_argument("--input_model", required=True, help="Path to the input model file") + parser.add_argument("--output_model", required=True, help="Path to the output model file") + parser.add_argument( + "--quant_type", + required=False, + default=1, + options=[MatMulBnb4Quantizer.FP4, MatMulBnb4Quantizer.NF4], + help="Quantization data type. 0: FP4, 1: NF4", + ) + parser.add_argument( + "--block_size", + required=False, + default=64, + description="Block size for blockwise quantization. Note: bnb.nn.Linear4bit only uses block_size=64", + ) + parser.add_argument("-v", "--verbose", required=False, action="store_true") + parser.set_defaults(verbose=False) + parser.add_argument( + "--nodes_to_exclude", + nargs="+", + type=str, + required=False, + default=[], + help="Specify the nodes to be excluded from quantization with node names", + ) + + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + if args.verbose: + logger.setLevel(logging.DEBUG) + + input_model_path = args.input_model + output_model_path = args.output_model + + if os.path.exists(output_model_path): + logger.error(f"file {output_model_path} already exists") + raise Exception(f"file {output_model_path} already exists") + + model = onnx.load(input_model_path) + quant = MatMulBnb4Quantizer(model, args.quant_type, args.block_size, nodes_to_exclude=args.nodes_to_exclude) + quant.process() + quant.model.save_model_to_file(output_model_path, True) diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 6d954bd540718..272727a9f5375 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -206,9 +206,11 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "PackedAttention": self._infer_PackedAttention, "PackedMultiHeadAttention": self._infer_PackedMultiHeadAttention, "PythonOp": self._infer_PythonOp, + "QuickGelu": self._infer_FastGelu, "RelativePositionBias": self._infer_RelativePositionBias, "RemovePadding": self._infer_RemovePadding, "RestorePadding": self._infer_RestorePadding, + "RotaryEmbedding": self._infer_RotaryEmbedding, "SimplifiedLayerNormalization": self._infer_LayerNormalization, "SkipLayerNormalization": self._infer_SkipLayerNormalization, "SkipSimplifiedLayerNormalization": self._infer_SkipLayerNormalization, @@ -230,7 +232,6 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "upsample_nearest1d": self._infer_aten_upsample, "upsample_nearest2d": self._infer_aten_upsample, "upsample_nearest3d": self._infer_aten_upsample, - "upsample_bilinear2d": self._infer_aten_upsample, } self.run_ = True self.suggested_merge_ = {} @@ -463,6 +464,8 @@ def _onnx_infer_single_node(self, node): "BiasSplitGelu", "BiasAdd", "NhwcConv", + "QuickGelu", + "RotaryEmbedding", ] if not skip_infer: @@ -2308,6 +2311,9 @@ def _infer_FastGelu(self, node): # noqa: N802 def _infer_Gelu(self, node): # noqa: N802 self._propagate_shape_and_type(node) + def _infer_QuickGelu(self, node): # noqa: N802 + self._propagate_shape_and_type(node) + def _infer_GemmFastGelu(self, node): # noqa: N802 self._compute_matmul_shape(node) @@ -2379,6 +2385,19 @@ def _infer_BiasSplitGelu(self, node): # noqa: N802 def _infer_BiasAdd(self, node): # noqa: N802 self._propagate_shape_and_type(node) + def _infer_RotaryEmbedding(self, node): # noqa: N802 + if len(node.output) == 1: + self._propagate_shape_and_type(node) + elif len(node.output) == 2: + # Extraneous constant nodes outputted by RotaryEmbedding function made with `export_modules_as_functions` + self._propagate_shape_and_type(node, input_index=1, output_index=0) + self._propagate_shape_and_type(node, input_index=0, output_index=1) # true output + elif len(node.output) == 3: + # Extraneous constant nodes outputted by RotaryEmbedding function made with `export_modules_as_functions` + self._propagate_shape_and_type(node, input_index=1, output_index=0) + self._propagate_shape_and_type(node, input_index=1, output_index=1) + self._propagate_shape_and_type(node, input_index=0, output_index=2) # true output + def _infer_PythonOp(self, node): # noqa: N802 output_tensor_types = get_attribute(node, "output_tensor_types") assert output_tensor_types @@ -2584,12 +2603,19 @@ def get_prereq(node): self._check_merged_dims(in_dims, allow_broadcast=True) for i_o in range(len(node.output)): - # Special case: We do not care about the training related - # outputs of SkipLayerNormalization + # Special cases: + # 1) We do not care about the training related outputs of SkipLayerNormalization + # 2) We do not care about the extraneous constant outputs in RotaryEmbedding because + # the RotaryEmbedding op created during export can be replaced by the RotaryEmbedding + # contrib op if ( node.op_type == "SkipLayerNormalization" or node.op_type == "SkipSimplifiedLayerNormalization" ) and i_o in [1, 2]: continue + if node.op_type == "RotaryEmbedding" and len(node.output) > 1: + # Skip symbolic shape inference for RotaryEmbedding functions that have extraneous outputs + # generated by `export_modules_as_functions` + continue vi = self.known_vi_[node.output[i_o]] out_type = vi.type @@ -2751,13 +2777,13 @@ def get_prereq(node): if i in self.known_vi_: logger.debug(self.known_vi_[i]) else: - logger.debug(f"not in knwon_vi_ for {i}") + logger.debug(f"not in known_vi_ for {i}") logger.debug("node outputs:") for o in node.output: if o in self.known_vi_: logger.debug(self.known_vi_[o]) else: - logger.debug(f"not in knwon_vi_ for {o}") + logger.debug(f"not in known_vi_ for {o}") if self.auto_merge_ and not out_type_undefined: logger.debug("Merging: " + str(self.suggested_merge_)) return False diff --git a/onnxruntime/python/tools/transformers/benchmark_helper.py b/onnxruntime/python/tools/transformers/benchmark_helper.py index 4f898245d01bd..b6f7a44450c62 100644 --- a/onnxruntime/python/tools/transformers/benchmark_helper.py +++ b/onnxruntime/python/tools/transformers/benchmark_helper.py @@ -33,6 +33,7 @@ class Precision(Enum): FLOAT32 = "fp32" FLOAT16 = "fp16" INT8 = "int8" + INT4 = "int4" def __str__(self): return self.value @@ -610,7 +611,7 @@ def measure_memory(is_gpu, func, monitor_type="cuda", start_memory=None): return memory_before_test with ThreadPoolExecutor() as executor: - monitor = MemoryMonitor() + monitor = memory_monitor_type() mem_thread = executor.submit(monitor.measure_cpu_usage) try: fn_thread = executor.submit(func) diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py index c1c709d6d759b..4228c892d03ae 100644 --- a/onnxruntime/python/tools/transformers/convert_generation.py +++ b/onnxruntime/python/tools/transformers/convert_generation.py @@ -1272,6 +1272,38 @@ def find_past_seq_len_usage(subg: GraphProto): return tensor_names_to_rename, nodes_to_remove +def replace_mha_with_gqa(model: OnnxModel, past_seq_len_input: str, kv_num_heads: int = 0): + past_seq_len = past_seq_len_input + if past_seq_len not in model.get_graphs_input_names(): + # Replace model input for past sequence length + new_input = onnx.helper.make_tensor_value_info(past_seq_len, onnx.TensorProto.INT64, shape=[1]) + model.model.graph.input.append(new_input) + + # Replace MultiHeadAttention with GroupQueryAttention + for node in model.model.graph.node: + if node.op_type == "MultiHeadAttention": + gqa_node = onnx.helper.make_node( + "GroupQueryAttention", + inputs=[ + node.input[0], # query + node.input[1], # key + node.input[2], # value + node.input[6], # past_key + node.input[7], # past_value + past_seq_len, # past_sequence_length + ], + outputs=node.output, + name=node.name.replace("MultiHeadAttention", "GroupQueryAttention"), + domain="com.microsoft", + num_heads=node.attribute[0].i, + kv_num_heads=node.attribute[0].i if kv_num_heads == 0 else kv_num_heads, + is_past_bsnh=0, + ) + model.model.graph.node.remove(node) + model.model.graph.node.extend([gqa_node]) + return model + + def update_decoder_subgraph_output_cross_attention(subg: GraphProto): input_self_past_0 = 1 # w/wo attention mask, w/wo hidden_state diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index 1dbdf39613cdd..c1b241aa1a5ec 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -111,7 +111,7 @@ def __init__( model: OnnxModel, hidden_size: int, num_heads: int, - attention_mask: AttentionMask, + attention_mask: Optional[AttentionMask] = None, use_multi_head_attention: bool = False, disable_multi_head_attention_bias: bool = False, search_op_types: List[str] = ["SkipLayerNormalization", "LayerNormalization"], # noqa: B006 @@ -120,7 +120,7 @@ def __init__( super().__init__(model, attention_op_name, search_op_types) self.hidden_size = hidden_size self.num_heads = num_heads - self.attention_mask = attention_mask + self.attention_mask = attention_mask if attention_mask else AttentionMask(model) self.use_multi_head_attention = use_multi_head_attention self.disable_multi_head_attention_bias = disable_multi_head_attention_bias self.mask_filter_value = None @@ -219,6 +219,31 @@ def get_add_qk_str(self, add_qk: NodeProto): return add_qk.input[1] + def reshape_add_qk(self, add_qk: str): + # Convert 4D mask from (B,1,S,T) to (B,N,S,T) + # B = batch size, N = num heads, S = source sequence length, T = target sequence length + mask_output_name = add_qk + "_mask" + + # Check if concat node for (B,1,S,T) --> (B,N,S,T) already exists + concat_node = list(filter(lambda node: node.output[0] == mask_output_name, self.nodes_to_add)) + if len(concat_node) == 1: + return mask_output_name + + assert len(concat_node) == 0 + concat_node_name = self.model.create_node_name("Concat") + concat_add_qk_fp32 = helper.make_node( + "Concat", + inputs=[add_qk for _ in range(self.num_heads)], + outputs=[mask_output_name], + name=concat_node_name, + axis=1, + ) + # Add new node to graph + self.nodes_to_add.append(concat_add_qk_fp32) + self.node_name_to_graph_name[concat_node_name] = self.this_graph_name + + return mask_output_name + def concat_kv(self, past_k: str, past_v: str) -> str: """Concatenate past_k and past_v inputs to create past_kv input. @@ -875,21 +900,8 @@ def create_attention_node( past_kv = self.concat_kv(past_k, past_v) attention_inputs.append(past_kv) - if add_qk_str: - # Convert 4d mask from (B,1,M,M) to (B,N,M,M) - # B = batch size, M = max sequence length, N = num heads - concat_node_name = self.model.create_node_name("Concat") - mask_output_name = add_qk_str + "_mask" - concat_add_qk_fp32 = helper.make_node( - "Concat", - inputs=[add_qk_str for _ in range(num_heads)], - outputs=[mask_output_name], - name=concat_node_name, - axis=1, - ) - # Add new nodes to graph - self.nodes_to_add.append(concat_add_qk_fp32) - self.node_name_to_graph_name[concat_node_name] = self.this_graph_name + if add_qk_str is not None: + mask_output_name = self.reshape_add_qk(add_qk_str) # Add attention mask to attention node if not past_exists: diff --git a/onnxruntime/python/tools/transformers/fusion_base.py b/onnxruntime/python/tools/transformers/fusion_base.py index 117468be412fa..c5d7bc16d64f7 100644 --- a/onnxruntime/python/tools/transformers/fusion_base.py +++ b/onnxruntime/python/tools/transformers/fusion_base.py @@ -113,3 +113,20 @@ def add_initializer(self, name: str, data_type: int, dims: Sequence[int], vals: self.model.add_initializer(tensor, self.this_graph_name) return tensor + + def add_nodes_to_remove(self, nodes: List[NodeProto]): + # Some nodes are shared between paths (e.g. rotary embedding nodes in the Q and K paths). + # When path A is fused, its shared nodes are added to `self.nodes_to_remove`. But when path B + # is fused, its shared nodes are also added to `self.nodes_to_remove`. When the nodes are + # iteratively removed from `self.nodes_to_remove`, path A's shared nodes are removed first. + # Since path A's shared nodes are removed, path B's shared nodes are not removed because they + # were previously removed for path A. This causes an error to print in remove_node that a node + # has failed to be removed. + # + # To avoid this error, we pre-emptively check if the shared nodes are already in `self.nodes_to_remove`. + # We could alternatively convert `self.nodes_to_remove` to a set to avoid this issue, but there could + # be scenarios where the nodes need to be removed in a specific order and converting to a set would + # lose this order. + for node in nodes: + if node not in self.nodes_to_remove: + self.nodes_to_remove.append(node) diff --git a/onnxruntime/python/tools/transformers/fusion_options.py b/onnxruntime/python/tools/transformers/fusion_options.py index 69b5cd26f4525..8c80fcad0ab49 100644 --- a/onnxruntime/python/tools/transformers/fusion_options.py +++ b/onnxruntime/python/tools/transformers/fusion_options.py @@ -26,6 +26,7 @@ def __init__(self, model_type): self.enable_gelu = True self.enable_layer_norm = True self.enable_attention = True + self.enable_rotary_embeddings = True # Use MultiHeadAttention instead of Attention operator. The difference: # (1) Attention has merged weights for Q/K/V projection, which might be faster in some cases since 3 MatMul is @@ -81,6 +82,8 @@ def parse(args): options.enable_gelu = False if args.disable_layer_norm: options.enable_layer_norm = False + if args.disable_rotary_embeddings: + options.enable_rotary_embeddings = False if args.disable_attention: options.enable_attention = False if args.use_multi_head_attention: @@ -294,3 +297,10 @@ def add_arguments(parser: ArgumentParser): help="Use channels_first (NCHW) instead of channels_last (NHWC) for GroupNorm. Only works for model_type=unet or vae", ) parser.set_defaults(use_group_norm_channels_first=False) + + parser.add_argument( + "--disable_rotary_embeddings", + required=False, + action="store_true", + help="Do not fuse rotary embeddings into RotaryEmbedding op", + ) diff --git a/onnxruntime/python/tools/transformers/fusion_rotary_attention.py b/onnxruntime/python/tools/transformers/fusion_rotary_attention.py new file mode 100644 index 0000000000000..3c5029ac5752f --- /dev/null +++ b/onnxruntime/python/tools/transformers/fusion_rotary_attention.py @@ -0,0 +1,1044 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import logging +from typing import Optional, Union + +from fusion_attention import FusionAttention +from fusion_base import Fusion +from onnx import FunctionProto, NodeProto, TensorProto, helper, numpy_helper +from onnx_model import OnnxModel + +logger = logging.getLogger(__name__) + + +class FusionRotaryAttention(FusionAttention): + """ + Fuse Attention subgraph with rotary positional embeddings into one MultiHeadAttention node. + """ + + def __init__( + self, + model: OnnxModel, + hidden_size: int, + num_heads: int, + ): + super().__init__( + model, + hidden_size, + num_heads, + use_multi_head_attention=True, + search_op_types=["SimplifiedLayerNormalization", "SkipSimplifiedLayerNormalization", "Add"], + ) + + def create_mha_node( + self, + input: str, + output: str, + q_rotary: NodeProto, + k_rotary: NodeProto, + v_matmul: NodeProto, + attn_mask: str = "", + add_qk: str = "", + past_k: str = "", + past_v: str = "", + present_k: str = "", + present_v: str = "", + scale: Optional[float] = None, + ) -> Union[NodeProto, None]: + assert self.num_heads > 0 + + if self.hidden_size > 0 and (self.hidden_size % self.num_heads) != 0: + logger.debug( + f"fuse_rotary_attention: input hidden size {self.hidden_size} is not a multiple of num of heads {self.num_heads}" + ) + return None + + mha_node_name = self.model.create_node_name("MultiHeadAttention") + mha_inputs = [ + q_rotary.output[0], + k_rotary.output[0], + v_matmul.output[0], + "", # bias + attn_mask, # key_padding_mask + add_qk, # relative_position_bias + past_k, + past_v, + ] + + mha_outputs = [output] + if present_k and present_v: + mha_outputs.extend([present_k, present_v]) + + mha_node = helper.make_node( + "MultiHeadAttention", + inputs=mha_inputs, + outputs=mha_outputs, + name=mha_node_name, + ) + + mha_node.domain = "com.microsoft" + mha_node.attribute.extend([helper.make_attribute("num_heads", self.num_heads)]) + if scale is not None: + mha_node.attribute.extend([helper.make_attribute("scale", scale)]) + if self.mask_filter_value is not None: + mha_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))]) + + self.increase_counter("MultiHeadAttention") + return mha_node + + def check_runtime_shape_paths_for_function( + self, + reshape_qkv_2, # Reshape after Transpose + reshape_qkv_1, # Reshape before Transpose + reshape_q_2, # Reshape after RotaryEmbedding + reshape_k_2, # Reshape after RotaryEmbedding + reshape_v_2, # Reshape after Transpose + reshape_v_1, # Reshape before Transpose + add_qk, # Add before Softmax + root_input, # Root input to attention subgraph + ): + # Check #1: check paths for qkv nodes + concat_qkv_2_path = self.model.match_parent_path(reshape_qkv_2, ["Concat"], [1]) + concat_qkv_1_path = self.model.match_parent_path(reshape_qkv_1, ["Concat"], [1]) + if concat_qkv_2_path is None or concat_qkv_1_path is None: + return False + concat_qkv_2, concat_qkv_1 = concat_qkv_2_path[0], concat_qkv_1_path[0] + + reshape_qkv_2_path_1 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0]) + reshape_qkv_2_path_2 = self.model.match_parent_path(concat_qkv_2, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0]) + reshape_qkv_1_path_1 = self.model.match_parent_path(concat_qkv_1, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0]) + reshape_qkv_1_path_2 = self.model.match_parent_path(concat_qkv_1, ["Unsqueeze", "Gather", "Shape"], [2, 0, 0]) + if ( + reshape_qkv_2_path_1 is None + or reshape_qkv_2_path_2 is None + or reshape_qkv_1_path_1 is None + or reshape_qkv_1_path_2 is None + ): + return False + + _, gather_1, shape_1 = reshape_qkv_2_path_1 + _, gather_2, shape_2 = reshape_qkv_2_path_2 + + # Check root_input --> Shape --> Gather connection + if shape_1.input[0] != root_input or shape_2.input[0] != root_input: + return False + + # Check Gather --> Unsqueeze --> Concat --> Reshape connection for reshape_qkv_1_path_1 and reshape_qkv_1_path_2 + if reshape_qkv_1_path_1[1].name != gather_1.name or reshape_qkv_1_path_2[1].name != gather_2.name: + return False + + # Check #2: check paths for v nodes + concat_v_2_path = self.model.match_parent_path(reshape_v_2, ["Concat"], [1]) + concat_v_1_path = self.model.match_parent_path(reshape_v_1, ["Concat"], [1]) + if concat_v_2_path is None or concat_v_1_path is None: + return False + concat_v_2, concat_v_1 = concat_v_2_path[0], concat_v_1_path[0] + + reshape_v_2_path_1 = self.model.match_parent_path( + concat_v_2, ["Unsqueeze", "Mul", "Gather", "Shape"], [0, 0, 0, 0] + ) + reshape_v_2_path_2 = self.model.match_parent_path( + concat_v_2, ["Unsqueeze", "Add", "Gather", "Shape"], [1, 0, 0, 0] + ) + reshape_v_1_path_1 = self.model.match_parent_path(concat_v_1, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0]) + reshape_v_1_path_2 = self.model.match_parent_path(concat_v_1, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0]) + if ( + reshape_v_2_path_1 is None + or reshape_v_2_path_2 is None + or reshape_v_1_path_1 is None + or reshape_v_1_path_2 is None + ): + return False + + # Check Gather --> Mul --> Unsqueeze --> Concat --> Reshape connection for reshape_v_2_path_1 + # Check Gather --> Add --> Unsqueeze --> Concat --> Reshape connection for reshape_v_2_path_2 + # Check Gather --> Unsqueeze --> Concat --> Reshape connection for reshape_v_1_path_1 and reshape_v_1_path_2 + if ( + reshape_v_2_path_1[2].name != gather_1.name + or reshape_v_2_path_2[2].name != gather_2.name + or reshape_v_1_path_1[1].name != gather_1.name + or reshape_v_1_path_2[1].name != gather_2.name + ): + return False + + # Check #3: check paths for k nodes + concat_k_2_path = self.model.match_parent_path(reshape_k_2, ["Concat"], [1]) + if concat_k_2_path is None: + return False + concat_k_2 = concat_k_2_path[0] + + reshape_k_2_path_1 = self.model.match_parent_path( + concat_k_2, ["Unsqueeze", "Mul", "Gather", "Shape"], [0, 0, 0, 0] + ) + reshape_k_2_path_2 = self.model.match_parent_path( + concat_k_2, ["Unsqueeze", "Add", "Gather", "Shape"], [2, 0, 0, 0] + ) + if reshape_k_2_path_1 is None or reshape_k_2_path_2 is None: + return False + + # Check Gather --> Mul --> Unsqueeze --> Concat --> Reshape connection for reshape_k_2_path_1 + # Check Gather --> Add --> Unsqueeze --> Concat --> Reshape connection for reshape_k_2_path_2 + if reshape_k_2_path_1[2].name != gather_1.name or reshape_k_2_path_2[2].name != gather_2.name: + return False + + # Check #4: check paths for q nodes + concat_q_2_path = self.model.match_parent_path(reshape_q_2, ["Concat"], [1]) + if concat_q_2_path is None: + return False + concat_q_2 = concat_q_2_path[0] + + reshape_q_2_path_1 = self.model.match_parent_path( + concat_q_2, ["Unsqueeze", "Mul", "Gather", "Shape"], [0, 0, 0, 0] + ) + reshape_q_2_path_2 = self.model.match_parent_path(concat_q_2, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0]) + if reshape_q_2_path_1 is None or reshape_q_2_path_2 is None: + return False + + # Check Gather --> Mul --> Unsqueeze --> Concat --> Reshape connection for reshape_q_2_path_1 + # Check Gather --> Unsqueeze --> Concat --> Reshape connection for reshape_q_2_path_2 + if reshape_q_2_path_1[2].name != gather_1.name or reshape_q_2_path_2[1].name != gather_2.name: + return False + + # Check #5: check Mul nodes are the same for q, k, v + mul_q = reshape_q_2_path_1[1] + mul_k = reshape_k_2_path_1[1] + mul_v = reshape_v_2_path_1[1] + gather_1_out = gather_1.output[0] + if mul_q.input[0] != gather_1_out or mul_k.input[0] != gather_1_out or mul_v.input[0] != gather_1_out: + return False + + # Check #6: check paths for attention mask nodes + attn_mask_path_1 = self.model.match_parent_path(add_qk, ["Concat", "Slice", "Slice"], [1, 0, 0]) + attn_mask_path_2 = self.model.match_parent_path(add_qk, ["Cast", "Concat", "Slice", "Slice"], [1, 0, 0, 0]) + if attn_mask_path_1 is not None: + _, slice_qk_2, slice_qk_1 = attn_mask_path_1 + elif attn_mask_path_2 is not None: + _, _, slice_qk_2, slice_qk_1 = attn_mask_path_2 + else: + return False + # Check first input to Slice #1 is 3D attention mask of shape (B,S,T) + if slice_qk_1.input[0] not in {"attn_mask", "attention_mask"}: + return False + + slice_qk_2_path = self.model.match_parent_path( + slice_qk_2, ["Unsqueeze", "Add", "Gather", "Shape"], [2, 0, 1, 0] + ) + slice_qk_1_path_1 = self.model.match_parent_path( + slice_qk_1, ["Unsqueeze", "Add", "Gather", "Shape"], [2, 0, 1, 0] + ) + slice_qk_1_path_2 = self.model.match_parent_path(slice_qk_1, ["Unsqueeze"], [1]) + if slice_qk_2_path is None or slice_qk_1_path_1 is None or slice_qk_1_path_2 is None: + return False + + # Check Gather --> Add --> Unsqueeze #3 --> Slice #2 connection for slice_qk_2_path + # Check Gather --> Add --> Unsqueeze #2 --> Slice #1 connection for slice_qk_1_path_1 + if slice_qk_2_path[1].name != slice_qk_1_path_1[1].name or slice_qk_2_path[2].name != slice_qk_1_path_1[2].name: + return False + + # Check Unsqueeze #1 --> Slice #1 connection for slice_qk_1_path_2 + # Check if first input to Add and Unsqueeze #1 is position ids + if slice_qk_1_path_1[1].input[0] != slice_qk_1_path_2[0].input[0]: + return False + + return True + + def check_runtime_shape_paths_for_nodes( + self, + reshape_qkv, # Final reshape before o_proj MatMul + reshape_q, # Reshape before q_proj MatMul + reshape_k, # Reshape before k_proj MatMul + reshape_v, # Reshape before v_proj MatMul + root_input, # Root input to attention subgraph + ): + # Check #1: check paths for qkv nodes + concat_qkv_path = self.model.match_parent_path(reshape_qkv, ["Concat"], [1]) + if concat_qkv_path is None: + return False + concat_qkv = concat_qkv_path[0] + + reshape_qkv_path_1 = self.model.match_parent_path(concat_qkv, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0]) + reshape_qkv_path_2 = self.model.match_parent_path(concat_qkv, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0]) + if reshape_qkv_path_1 is None or reshape_qkv_path_2 is None: + return False + + _, gather_1, shape_1 = reshape_qkv_path_1 + _, gather_2, shape_2 = reshape_qkv_path_2 + + # Check root_input --> Shape --> Gather connection + if shape_1.input[0] != root_input or shape_2.input[0] != root_input: + return False + + # Check #2: check paths for v nodes + concat_v_path = self.model.match_parent_path(reshape_v, ["Concat"], [1]) + if concat_v_path is None: + return False + concat_v = concat_v_path[0] + + reshape_v_path_1 = self.model.match_parent_path(concat_v, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0]) + reshape_v_path_2 = self.model.match_parent_path(concat_v, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0]) + if reshape_v_path_1 is None or reshape_v_path_2 is None: + return False + + # Check Gather --> Unsqueeze --> Concat --> Reshape connection + if reshape_v_path_1[1].name != gather_1.name or reshape_v_path_2[1].name != gather_2.name: + return False + + # Check #3: check paths for k nodes + concat_k_path = self.model.match_parent_path(reshape_k, ["Concat"], [1]) + if concat_k_path is None: + return False + concat_k = concat_k_path[0] + + reshape_k_path_1 = self.model.match_parent_path(concat_k, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0]) + reshape_k_path_2 = self.model.match_parent_path(concat_k, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0]) + if reshape_k_path_1 is None or reshape_k_path_2 is None: + return False + + # Check Gather --> Unsqueeze --> Concat --> Reshape connection + if reshape_k_path_1[1].name != gather_1.name or reshape_k_path_2[1].name != gather_2.name: + return False + + # Check #4: check paths for q nodes + concat_q_path = self.model.match_parent_path(reshape_q, ["Concat"], [1]) + if concat_q_path is None: + return False + concat_q = concat_q_path[0] + + reshape_q_path_1 = self.model.match_parent_path(concat_q, ["Unsqueeze", "Gather", "Shape"], [0, 0, 0]) + reshape_q_path_2 = self.model.match_parent_path(concat_q, ["Unsqueeze", "Gather", "Shape"], [1, 0, 0]) + if reshape_q_path_1 is None or reshape_q_path_2 is None: + return False + + # Check Gather --> Unsqueeze --> Concat --> Reshape connection + if reshape_q_path_1[1].name != gather_1.name or reshape_q_path_2[1].name != gather_2.name: + return False + + return True + + def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): + if normalize_node.op_type != "SkipSimplifiedLayerNormalization" and normalize_node.op_type != "Add": + return + + # qkv_nodes_1 is for LLaMA-2 Microsoft + # qkv_nodes_2 is for LLaMA-2 Hugging Face + qkv_nodes = None + qkv_nodes_1 = self.model.match_parent_path( + normalize_node, + ["MatMul", "Reshape", "Transpose", "Reshape", "MatMul"], + [1, 0, 0, 0, 0], + ) + qkv_nodes_2 = self.model.match_parent_path( + normalize_node, + ["MatMul", "Reshape", "Transpose", "MatMul"], + [1, 0, 0, 0], + ) + if qkv_nodes_1 is not None: + _, reshape_qkv_2, _, reshape_qkv_1, matmul_qkv = qkv_nodes_1 + qkv_nodes = qkv_nodes_1 + elif qkv_nodes_2 is not None: + _, reshape_qkv, _, matmul_qkv = qkv_nodes_2 + qkv_nodes = qkv_nodes_2 + else: + logger.debug("fuse_rotary_attention: failed to match qkv nodes") + return + + # v_nodes_1 is for LLaMA-2 Microsoft + # v_nodes_3 is for LLaMA-2 Hugging Face + past_v, present_v, past_seq_len = "", "", "" + v_nodes = None + v_nodes_1 = self.model.match_parent_path( + matmul_qkv, + ["Reshape", "Transpose", "Concat", "Transpose", "Reshape", "MatMul"], + [1, 0, 0, 1, 0, 0], + ) + v_nodes_2 = self.model.match_parent_path( + matmul_qkv, + ["Concat", "Transpose", "Reshape", "MatMul"], + [1, 1, 0, 0], + ) + v_nodes_3 = self.model.match_parent_path( + matmul_qkv, + ["Transpose", "Reshape", "MatMul"], + [1, 0, 0], + ) + if v_nodes_1 is not None: + reshape_v_2, _, concat_v, _, reshape_v_1, matmul_v = v_nodes_1 + v_nodes = v_nodes_1 + + concat_v_path = self.model.match_parent_path( + concat_v, + ["Slice", "Unsqueeze"], + [0, 2], + ) + if concat_v_path is None: + logger.debug("fuse_rotary_attention: failed to match past/present concat in v path") + return + + past_v = concat_v_path[0].input[0] + past_seq_len = concat_v_path[-1].input[0] + present_v = concat_v.output[0] + elif v_nodes_2 is not None: + concat_v, transpose_v, reshape_v, matmul_v = v_nodes_2 + v_nodes = v_nodes_2 + past_v = concat_v.input[0] + present_v = concat_v.output[0] + elif v_nodes_3 is not None: + transpose_v, reshape_v, matmul_v = v_nodes_3 + v_nodes = v_nodes_3 + present_v = transpose_v.output[0] + else: + logger.debug("fuse_rotary_attention: failed to match v path") + return + + qk_nodes = self.model.match_parent_path( + matmul_qkv, + ["Softmax", "Add", "Div", "MatMul"], + [0, 0, 0, 0], + ) + add_qk, matmul_qk = None, None + if qk_nodes is not None: + _, add_qk, _, matmul_qk = qk_nodes + else: + logger.debug("fuse_rotary_attention: failed to match qk nodes") + return + + # attn_mask_nodes_1, attn_mask_nodes_2 are for LLaMA-2 Microsoft's 3D attention mask + # attn_mask_nodes_3, attn_mask_nodes_4 are for LLaMA-2 Hugging Face's 2D attention mask + attn_mask, add_qk_str = "", "" + attn_mask_nodes_1 = self.model.match_parent_path( + add_qk, + ["Concat", "Slice", "Slice"], + [1, 0, 0], + ) + attn_mask_nodes_2 = self.model.match_parent_path( + add_qk, + ["Cast", "Concat", "Slice", "Slice"], + [1, 0, 0, 0], + ) + attn_mask_nodes_3 = self.model.match_parent_path( + add_qk, + ["Add", "Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"], + [1, 0, 2, 1, 0, 0, 0], + ) + attn_mask_nodes_4 = self.model.match_parent_path( + add_qk, + ["Where", "Sub", "Cast", "Expand", "Unsqueeze", "Unsqueeze"], + [1, 2, 1, 0, 0, 0], + ) + if attn_mask_nodes_1 is not None: + _, slice_mask_1, slice_mask_2 = attn_mask_nodes_1 + attn_mask = slice_mask_1.output[0] + elif attn_mask_nodes_2 is not None: + _, _, slice_mask_1, slice_mask_2 = attn_mask_nodes_2 + attn_mask = slice_mask_1.output[0] + elif attn_mask_nodes_3 is not None: + # Reshape from (B,1,S,T) to (B,N,S,T) + add_qk_str = self.reshape_add_qk(attn_mask_nodes_3[0].output[0]) + elif attn_mask_nodes_4 is not None: + # Reshape from (B,1,S,T) to (B,N,S,T) + add_qk_str = self.reshape_add_qk(attn_mask_nodes_4[0].output[0]) + else: + logger.debug("fuse_rotary_attention: failed to match attention mask nodes") + return + + # k_nodes_1 is for LLaMA-2 Microsoft + # k_nodes_2 is for LLaMA-2 Hugging Face + past_k, present_k = "", "" + k_nodes = None + k_nodes_1 = self.model.match_parent_path( + matmul_qk, + ["Reshape", "Transpose", "Concat", "Transpose", "RotaryEmbedding", "MatMul"], + [1, 0, 0, 1, 0, 0], + ) + k_nodes_2 = self.model.match_parent_path( + matmul_qk, + ["Transpose", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"], + [1, 0, 0, 0, 0], + ) + k_nodes_3 = self.model.match_parent_path( + matmul_qk, + ["Transpose", "Concat", "RotaryEmbedding", "Transpose", "Reshape", "MatMul"], + [1, 0, 1, 0, 0, 0], + ) + if k_nodes_1 is not None: + reshape_k_2, _, concat_k, _, rotary_k, matmul_k = k_nodes_1 + k_nodes = k_nodes_1 + + concat_k_path = self.model.match_parent_path( + concat_k, + ["Slice", "Unsqueeze"], + [0, 2], + ) + if concat_k_path is None: + logger.debug("fuse_rotary_attention: failed to match past/present concat in k path") + return + + past_k = concat_k_path[0].input[0] + shared_past_seq_len = concat_k_path[-1].input[0] + present_k = concat_k.output[0] + + assert past_seq_len == shared_past_seq_len + elif k_nodes_2 is not None: + _, rotary_k, _, reshape_k, matmul_k = k_nodes_2 + k_nodes = k_nodes_2 + present_k = rotary_k.output[0] + elif k_nodes_3 is not None: + _, concat_k, rotary_k, _, reshape_k, matmul_k = k_nodes_3 + k_nodes = k_nodes_3 + past_k = concat_k.input[0] + present_k = concat_k.output[0] + else: + logger.debug("fuse_rotary_attention: failed to match k nodes") + return + + # q_nodes_1 is for LLaMA-2 Microsoft + # q_nodes_2 is for LLaMA-2 Hugging Face + q_nodes = None + q_nodes_1 = self.model.match_parent_path( + matmul_qk, + ["Reshape", "Transpose", "RotaryEmbedding", "MatMul"], + [0, 0, 0, 0], + ) + q_nodes_2 = self.model.match_parent_path( + matmul_qk, + ["RotaryEmbedding", "Transpose", "Reshape", "MatMul"], + [0, 0, 0, 0], + ) + if q_nodes_1 is not None: + reshape_q_2, _, rotary_q, matmul_q = q_nodes_1 + q_nodes = q_nodes_1 + elif q_nodes_2 is not None: + rotary_q, _, reshape_q, matmul_q = q_nodes_2 + q_nodes = q_nodes_2 + else: + logger.debug("fuse_rotary_attention: failed to match q nodes") + return + + if matmul_q.input[0] != matmul_k.input[0] and matmul_k.input[0] != matmul_v.input[0]: + logger.debug("fuse_rotary_attention: failed to find the same root_input for q, k, v paths") + return + + root_output = "" + if qkv_nodes == qkv_nodes_1: + if not self.check_runtime_shape_paths_for_function( + reshape_qkv_2, + reshape_qkv_1, + reshape_q_2, + reshape_k_2, + reshape_v_2, + reshape_v_1, + add_qk, + matmul_q.input[0], + ): + logger.debug("fuse_rotary_attention: failed to verify runtime shape paths") + return + root_output = reshape_qkv_2.output[0] + + elif qkv_nodes == qkv_nodes_2: + if not self.check_runtime_shape_paths_for_nodes( + reshape_qkv, + reshape_q, + reshape_k, + reshape_v, + matmul_q.input[0], + ): + logger.debug("fuse_rotary_attention: failed to verify runtime shape paths") + return + root_output = reshape_qkv.output[0] + + # Rename inputs of rotary_q/k so it connects with output of matmul_q/k + # Before: MatMul --> Reshape --> Transpose --> RotaryEmbedding + # After: MatMul --> RotaryEmbedding + rotary_q.input[0] = matmul_q.output[0] + rotary_k.input[0] = matmul_k.output[0] + + # Rename current output of rotary_k (present_key) so it doesn't match output of MHA (present_key) + rotary_k.output[0] = rotary_k.name + "_output_0" + + new_node = self.create_mha_node( + matmul_q.input[0], + root_output, + rotary_q, + rotary_k, + matmul_v, + attn_mask, + add_qk_str, + past_k, + past_v, + present_k, + present_v, + ) + if new_node is None: + logger.debug("fuse_rotary_attention: failed to create multi-head attention with rotary embeddings") + return + + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + + self.nodes_to_remove.extend(qkv_nodes[1:]) + self.nodes_to_remove.extend(v_nodes[:-1]) + self.nodes_to_remove.extend(qk_nodes) + + if k_nodes == k_nodes_1: + self.nodes_to_remove.extend(k_nodes[:-2]) + elif k_nodes == k_nodes_2: + self.nodes_to_remove.append(k_nodes[0]) + self.nodes_to_remove.append(k_nodes[2]) + self.nodes_to_remove.append(k_nodes[3]) + elif k_nodes == k_nodes_3: + self.nodes_to_remove.append(k_nodes[0]) + self.nodes_to_remove.append(k_nodes[1]) + self.nodes_to_remove.append(k_nodes[3]) + self.nodes_to_remove.append(k_nodes[4]) + + if q_nodes == q_nodes_1: + self.nodes_to_remove.extend(q_nodes[:-2]) + elif q_nodes == q_nodes_2: + self.nodes_to_remove.append(q_nodes[1]) + self.nodes_to_remove.append(q_nodes[2]) + + self.prune_graph = True + + +class FusionRotaryEmbeddings(Fusion): + def __init__(self, model: OnnxModel): + self.base_name = "RotaryEmbedding" + super().__init__(model, self.base_name, [self.base_name, self.base_name + ".1", "Add"]) + + # The RotaryEmbedding function can have multiple extraneous constant outputs even though the function is supposed to produce only one output. + # This is a byproduct of a potential CSE bug when using `export_modules_as_functions` in the TorchScript exporter. + # To work around this issue, we set the extraneous constant values from the RotaryEmbedding function as initializers in the locations where they are actually used. + def reassign_extra_outputs(self, rot_emb_node: NodeProto, function: FunctionProto): + # Find extra outputs and Constant nodes attached to those outputs + extra_constants, extra_outputs = [], [] + for fn_node in function.node: + if fn_node.op_type == "Constant" and fn_node.input == [] and fn_node.output[0] in function.output: + extra_constants.append(fn_node) + output_index = list(function.output).index(fn_node.output[0]) + extra_outputs.append(rot_emb_node.output[output_index]) + + # Set extra Constant node outputs as initializers + extra_initializers = [] + for extra_constant in extra_constants: + constant_tensorproto = extra_constant.attribute[0].t + constant_tensorproto.name = self.model.create_node_name("Constant") + self.model.add_initializer(constant_tensorproto) + extra_initializers.append(constant_tensorproto.name) + + # Update references of Constant node outputs to initializer references + for extra_output, extra_initializer in zip(extra_outputs, extra_initializers): + nodes_to_update = list(filter(lambda entry: extra_output in entry.input, self.model.model.graph.node)) + for node_to_update in nodes_to_update: + OnnxModel.replace_node_input(node_to_update, extra_output, extra_initializer) + + return extra_outputs + + def create_rotary_embeddings_from_function(self, node: NodeProto): + rotary_emb_node_name = self.model.create_node_name(self.base_name) + + matmul_path = self.model.match_parent_path( + node, + ["Reshape", "MatMul"], + [0, 0], + ) + if matmul_path is not None: + reshape_node, matmul_node = matmul_path + else: + logger.debug("fuse_rotary_embeddings: failed to match MatMul") + return + + rotary_emb_inputs = [ + matmul_node.output[0], # x is of shape (B,S,D) instead of (B,S,N,H) + node.input[1], # position_ids + ] + + # Convert cos_cache and sin_cache from node attributes to model initializers + cos_cache_node = list(filter(lambda constant: constant.output[0] == node.input[2], self.model.model.graph.node)) + sin_cache_node = list(filter(lambda constant: constant.output[0] == node.input[3], self.model.model.graph.node)) + cos_cache_name, sin_cache_name = "cos_cache", "sin_cache" + + if ( + len(cos_cache_node) == 1 + and len(sin_cache_node) == 1 + and self.model.get_initializer(cos_cache_name) is None + and self.model.get_initializer(sin_cache_name) is None + ): + cos_cache = numpy_helper.to_array(cos_cache_node[0].attribute[0].t).squeeze() + sin_cache = numpy_helper.to_array(sin_cache_node[0].attribute[0].t).squeeze() + + cos_cache_tensor = helper.make_tensor( + name=cos_cache_name, + data_type=TensorProto.FLOAT, + dims=list(cos_cache.shape), + vals=cos_cache.flatten().tolist(), + ) + self.model.add_initializer(cos_cache_tensor, self.this_graph_name) + sin_cache_tensor = helper.make_tensor( + name=sin_cache_name, + data_type=TensorProto.FLOAT, + dims=list(sin_cache.shape), + vals=sin_cache.flatten().tolist(), + ) + self.model.add_initializer(sin_cache_tensor, self.this_graph_name) + + self.nodes_to_remove.extend([cos_cache_node[0], sin_cache_node[0]]) + + rotary_emb_inputs.extend([cos_cache_name, sin_cache_name]) + + rotary_emb_outputs = node.output + if len(rotary_emb_outputs) > 1: + # Re-assign extraneous constant outputs in RotaryEmbedding functions as initializers + func = list(filter(lambda fn: fn.name == node.op_type, self.model.model.functions)) + assert len(func) == 1 + extra_outputs = self.reassign_extra_outputs(node, func[0]) + rotary_emb_outputs = list(filter(lambda output_name: output_name not in extra_outputs, rotary_emb_outputs)) + assert len(rotary_emb_outputs) == 1 + + rotary_emb_node = helper.make_node( + self.base_name, + inputs=rotary_emb_inputs, + outputs=rotary_emb_outputs, + name=rotary_emb_node_name, + interleaved=1, + ) + rotary_emb_node.domain = "com.microsoft" + + self.nodes_to_remove.append(reshape_node) + + return rotary_emb_node + + def create_rotary_embeddings_from_nodes( + self, + root_input: str, + position_ids: str, + cos_slice: str, + sin_slice: str, + output: str, + ): + rotary_emb_node_name = self.model.create_node_name(self.base_name) + + # Convert cos_cache and sin_cache from node attributes to model initializers + cos_cache_node = list(filter(lambda constant: constant.output[0] == cos_slice, self.model.model.graph.node)) + sin_cache_node = list(filter(lambda constant: constant.output[0] == sin_slice, self.model.model.graph.node)) + cos_cache_name, sin_cache_name = "cos_cache", "sin_cache" + + if ( + len(cos_cache_node) == 1 + and len(sin_cache_node) == 1 + and self.model.get_initializer(cos_cache_name) is None + and self.model.get_initializer(sin_cache_name) is None + ): + cos_cache = numpy_helper.to_array(cos_cache_node[0].attribute[0].t).squeeze() + sin_cache = numpy_helper.to_array(sin_cache_node[0].attribute[0].t).squeeze() + + # Reshape cos/sin cache from (M, H) to (M, H/2) + head_size = cos_cache.shape[1] + cos_cache = cos_cache[:, : (head_size // 2)] + sin_cache = sin_cache[:, : (head_size // 2)] + + cos_cache_tensor = helper.make_tensor( + name=cos_cache_name, + data_type=TensorProto.FLOAT, + dims=list(cos_cache.shape), + vals=cos_cache.flatten().tolist(), + ) + self.model.add_initializer(cos_cache_tensor, self.this_graph_name) + sin_cache_tensor = helper.make_tensor( + name=sin_cache_name, + data_type=TensorProto.FLOAT, + dims=list(sin_cache.shape), + vals=sin_cache.flatten().tolist(), + ) + self.model.add_initializer(sin_cache_tensor, self.this_graph_name) + + self.nodes_to_remove.extend([cos_cache_node[0], sin_cache_node[0]]) + + rotary_emb_node = helper.make_node( + self.base_name, + inputs=[root_input, position_ids, cos_cache_name, sin_cache_name], + outputs=[output], + name=rotary_emb_node_name, + interleaved=0, + ) + rotary_emb_node.domain = "com.microsoft" + return rotary_emb_node + + def fuse(self, node, input_name_to_nodes, output_name_to_node): + # Node is either RotaryEmbedding function or Add + if self.base_name not in node.op_type and node.op_type != "Add": + return + + # Check if node is "RotaryEmbedding nn.Module" exported as a function + # (e.g. export_modules_as_functions={RotaryEmbedding} in torch.onnx.export) + rotary_emb_node = None + if node.op_type != "Add": + # Verify that function has the correct inputs + if len(node.input) not in {4, 5} or node.input[1] not in { + "pos", + "pos_id", + "position_id", + "pos_ids", + "position_ids", + }: + logger.debug("fuse_rotary_embeddings: failed to verify inputs for RotaryEmbedding function") + return + + rotary_emb_node = self.create_rotary_embeddings_from_function(node) + if rotary_emb_node is None: + logger.debug("fuse_rotary_embeddings: failed to create RotaryEmbedding node") + return + + # Remove RotaryEmbedding function + self.nodes_to_remove.append(node) + + # Remove RotaryEmbedding function's shape inference stored in value_info + # The new shape will be calculated during symbolic shape inference + old_shape_infer = list( + filter(lambda node: node.name == rotary_emb_node.output[0], self.model.model.graph.value_info) + ) + assert len(old_shape_infer) == 1 + self.model.model.graph.value_info.remove(old_shape_infer[0]) + + else: + # Rotary embeddings are defined using the below functions: + # + # def rotate_half(x): + # """Rotates half the hidden dims of the input.""" + # x1 = x[..., : x.shape[-1] // 2] + # x2 = x[..., x.shape[-1] // 2 :] + # return torch.cat((-x2, x1), dim=-1) + # + # def apply_rope(x, cos, sin, position_ids): + # cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + # sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + # cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + # sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + # x_embed = (x * cos) + (rotate_half(x) * sin) + # return x_embed + + # Check paths for rotate_half(x) + rotate_half_x2_path_1 = self.model.match_parent_path( + node, + ["Mul", "Concat", "Neg", "Slice", "Transpose"], + [1, 0, 0, 0, 0], + ) + rotate_half_x2_path_2 = self.model.match_parent_path( + node, + ["Mul", "Concat", "Neg", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Transpose"], + [1, 0, 0, 0, 1, 0, 0, 0, 0], + ) + if rotate_half_x2_path_1 is None or rotate_half_x2_path_2 is None: + logger.debug("fuse_rotary_embeddings: failed to match x2 in rotate_half") + return + + rotate_half_x1_path_1 = self.model.match_parent_path( + node, + ["Mul", "Concat", "Slice", "Transpose"], + [1, 0, 1, 0], + ) + rotate_half_x1_path_2 = self.model.match_parent_path( + node, + ["Mul", "Concat", "Slice", "Unsqueeze", "Div", "Gather", "Shape", "Transpose"], + [1, 0, 1, 2, 0, 0, 0, 0], + ) + if rotate_half_x1_path_1 is None or rotate_half_x1_path_2 is None: + logger.debug("fuse_rotary_embeddings: failed to match x1 in rotate_half") + return + + if ( + rotate_half_x1_path_1[-1].name != rotate_half_x1_path_2[-1].name + or rotate_half_x2_path_1[-1].name != rotate_half_x2_path_2[-1].name + or rotate_half_x1_path_1[-1].name != rotate_half_x2_path_1[-1].name + or rotate_half_x1_path_2[-1].name != rotate_half_x2_path_2[-1].name + ): + logger.debug("fuse_rotary_embeddings: failed to match common input in rotate_half") + return + + # Check path for x + x_path = self.model.match_parent_path( + node, + ["Mul", "Transpose"], + [0, 0], + ) + if x_path is None: + logger.debug("fuse_rotary_embeddings: failed to match x in rotate_half") + return + + # Check path for sin + sin_path, sin_cache, position_ids = None, "", "" + sin_path_1 = self.model.match_parent_path( + node, + ["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Gather", "Shape"], + [1, 1, 0, 0, 0, 0, 2, 0, 0], + ) + sin_path_2 = self.model.match_parent_path( + node, + ["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Add"], + [1, 1, 0, 0, 0, 0, 2, 0], + ) + sin_path_3 = self.model.match_parent_path( + node, + ["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Gather", "Shape"], + [1, 1, 0, 0, 2, 0, 0], + ) + sin_path_4 = self.model.match_parent_path( + node, + ["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Add"], + [1, 1, 0, 0, 2, 0], + ) + if sin_path_1 is not None: + sin_path = sin_path_1 + sin_cache = sin_path[-4].input[0] + elif sin_path_2 is not None: + sin_path = sin_path_2 + sin_cache = sin_path[-3].input[0] + elif sin_path_3 is not None: + sin_path = sin_path_3 + sin_cache = sin_path[-4].input[0] + position_ids = sin_path[2].input[1] + elif sin_path_4 is not None: + sin_path = sin_path_4 + sin_cache = sin_path[-3].input[0] + position_ids = sin_path[2].input[1] + else: + logger.debug("fuse_rotary_embeddings: failed to match sin path in apply_rope") + return + + # Check path for cos + cos_path, cos_cache = None, "" + cos_path_1 = self.model.match_parent_path( + node, + ["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Gather", "Shape"], + [0, 1, 0, 0, 0, 0, 2, 0, 0], + ) + cos_path_2 = self.model.match_parent_path( + node, + ["Mul", "Unsqueeze", "Gather", "Squeeze", "Squeeze", "Slice", "Unsqueeze", "Add"], + [0, 1, 0, 0, 0, 0, 2, 0], + ) + cos_path_3 = self.model.match_parent_path( + node, + ["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Gather", "Shape"], + [0, 1, 0, 0, 2, 0, 0], + ) + cos_path_4 = self.model.match_parent_path( + node, + ["Mul", "Unsqueeze", "Gather", "Slice", "Unsqueeze", "Add"], + [0, 1, 0, 0, 2, 0], + ) + if cos_path_1 is not None: + cos_path = cos_path_1 + cos_cache = cos_path[-4].input[0] + elif cos_path_2 is not None: + cos_path = cos_path_2 + cos_cache = cos_path[-3].input[0] + elif cos_path_3 is not None: + cos_path = cos_path_3 + cos_cache = cos_path[-4].input[0] + position_ids = cos_path[2].input[1] + elif cos_path_4 is not None: + cos_path = cos_path_4 + cos_cache = cos_path[-3].input[0] + position_ids = cos_path[2].input[1] + else: + logger.debug("fuse_rotary_embeddings: failed to match sin path in apply_rope") + return + + # Check path for position ids + if position_ids == "": + position_ids_from_sin_path = self.model.match_parent_path( + sin_path[2], + ["Reshape"], + [1], + ) + position_ids_from_cos_path = self.model.match_parent_path( + cos_path[2], + ["Reshape"], + [1], + ) + if ( + position_ids_from_sin_path is None + or position_ids_from_cos_path is None + or position_ids_from_sin_path[0].name != position_ids_from_cos_path[0].name + ): + logger.debug("fuse_rotary_embeddings: failed to match position ids path in apply_rope") + return + position_ids = position_ids_from_cos_path[0].input[0] + else: + position_ids_from_sin_path = [] + position_ids_from_cos_path = [] + + past_seq_len_path, curr_seq_len_path = None, None + if (sin_path == sin_path_1 and cos_path == cos_path_1) or ( + sin_path == sin_path_3 and cos_path == cos_path_3 + ): + if sin_path[-2].name != cos_path[-2].name or sin_path[-1].name != cos_path[-1].name: + logger.debug( + "fuse_rotary_embeddings: failed to match common Gather node and Shape node in sin cache and cos cache" + ) + return + elif (sin_path == sin_path_2 and cos_path == cos_path_2) or ( + sin_path == sin_path_4 and cos_path == cos_path_4 + ): + if sin_path[-1].name != cos_path[-1].name: + logger.debug("fuse_rotary_embeddings: failed to match common Add node in sin cache and cos cache") + return + # Match past sequence length path: past_key --> Shape --> Gather --> Add + past_seq_len_path = self.model.match_parent_path( + sin_path[-1], + ["Gather", "Shape"], + [1, 0], + ) + # Match current sequence length path: transpose_k --> Shape --> Gather --> Add + curr_seq_len_path = self.model.match_parent_path( + sin_path[-1], + ["Gather", "Shape", "Transpose"], + [0, 0, 0], + ) + if ( + past_seq_len_path is None + or curr_seq_len_path is None + or self.model.find_graph_input(past_seq_len_path[-1].input[0]) is None + or curr_seq_len_path[-1].op_type != "Transpose" + ): + logger.debug("fuse_rotary_embeddings: failed to match past_seq_len and curr_seq_len paths") + return + else: + logger.debug("fuse_rotary_embeddings: failed to match common cache paths") + + rotary_emb_node = self.create_rotary_embeddings_from_nodes( + rotate_half_x1_path_1[-1].output[0], + position_ids, + cos_cache, + sin_cache, + node.output[0], + ) + if rotary_emb_node is None: + logger.debug("fuse_rotary_embeddings: failed to create RotaryEmbedding node") + return + + # Remove rotary embedding nodes + self.add_nodes_to_remove([node]) + self.add_nodes_to_remove(rotate_half_x1_path_1[:-1]) + self.add_nodes_to_remove(rotate_half_x1_path_2[:-1]) + self.add_nodes_to_remove(rotate_half_x2_path_1[:-1]) + self.add_nodes_to_remove(rotate_half_x2_path_2[:-1]) + self.add_nodes_to_remove(x_path[:-1]) + self.add_nodes_to_remove(sin_path) + self.add_nodes_to_remove(cos_path) + self.add_nodes_to_remove(position_ids_from_sin_path[:-1]) + self.add_nodes_to_remove(position_ids_from_cos_path[:-1]) + + if past_seq_len_path is not None and len(self.model.get_children(past_seq_len_path[0])) == 1: + # In merged HF model, output of Gather in past_seq_len_path is used twice + # for past_key_values.0.key and once for other past_key_values + self.add_nodes_to_remove(past_seq_len_path) + if curr_seq_len_path is not None: + self.add_nodes_to_remove(curr_seq_len_path[:-1]) + + self.increase_counter(self.base_name) + self.node_name_to_graph_name[rotary_emb_node.name] = self.this_graph_name + self.nodes_to_add.append(rotary_emb_node) + self.prune_graph = True diff --git a/onnxruntime/python/tools/transformers/fusion_shape.py b/onnxruntime/python/tools/transformers/fusion_shape.py index 11d6b7a8d3cf4..bc32d78eda66c 100644 --- a/onnxruntime/python/tools/transformers/fusion_shape.py +++ b/onnxruntime/python/tools/transformers/fusion_shape.py @@ -48,22 +48,22 @@ def fuse( input_name_to_nodes: Dict[str, List[NodeProto]], output_name_to_node: Dict[str, NodeProto], ): - """ - Smplify subgraph like - - (2d_input) - / \ - Shape shape - / \ - Gather(indices=0) Gather(indices=1) - | | - Unsqueeze(axes=0) Unsqueeze(axes=0) - \\ / - Concat - | - - into (2d_input) --> Shape --> - """ + # + # Simplify subgraph like + # + # (2d_input) + # / \ + # Shape shape + # / \ + # Gather(indices=0) Gather(indices=1) + # | | + # Unsqueeze(axes=0) Unsqueeze(axes=0) + # \ / + # Concat + # | + # + # into (2d_input) --> Shape --> + # opset_version = self.model.get_opset_version() inputs = len(concat_node.input) diff --git a/onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py b/onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py new file mode 100644 index 0000000000000..6f35fa5617a39 --- /dev/null +++ b/onnxruntime/python/tools/transformers/fusion_simplified_layernorm.py @@ -0,0 +1,141 @@ +import logging +from typing import Dict + +from fusion_base import Fusion +from fusion_skiplayernorm import FusionSkipLayerNormalization +from onnx import helper +from onnx_model import OnnxModel + +logger = logging.getLogger(__name__) + + +class FusionSimplifiedLayerNormalization(Fusion): + def __init__(self, model: OnnxModel): + super().__init__(model, "SimplifiedLayerNormalization", "Mul") + + def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): + if node.op_type != "Mul": + return + + sim_ln_nodes = None + # SimplifiedLayerNorm calculation (notation from https://onnx.ai/onnx/operators/onnx__LayerNormalization.html#summary): + # DD = Pow(D, 2) + # Var = ReduceMean(DD) + # VarEps = Add(Var, epsilon) + # StdDev = Sqrt(VarEps) + # InvStdDev = Div(1, StdDev) + # Normalized = Mul(D, InvStdDev) + # NormalizedScaled = Mul(Normalized, Scale) + + # SimplifiedLayerNorm + # +-------------------------------------------------------+ + # | | + # Add --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Mul + # | + # node + sim_ln_nodes_1 = self.model.match_parent_path( + node, + ["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Add"], + [1, 1, 1, 0, 0, 0, 0], + ) + # SimplifiedLayerNorm + # +-------------------------------------------------------+ + # | | + # Gather --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Mul + # | + # node + sim_ln_nodes_2 = self.model.match_parent_path( + node, + ["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Gather"], + [1, 1, 1, 0, 0, 0, 0], + ) + + # For LLaMA from Microsoft custom export: + # sim_ln_nodes_3 uses a different start parent index than sim_ln_nodes_1 + # + # SimplifiedLayerNorm + # +-------------------------------------------------------+ + # | | + # Add --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Mul + # | + # node + sim_ln_nodes_3 = self.model.match_parent_path( + node, + ["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Add"], + [0, 1, 1, 0, 0, 0, 0], + ) + + # sim_ln_nodes_4 starts with a graph input instead of an Add node like sim_ln_nodes_3 + # + # SimplifiedLayerNorm + # +-----------------------------------------------+ + # | | + # graph_input --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul + # | + # node + sim_ln_nodes_4 = self.model.match_parent_path( + node, + ["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow"], + [0, 1, 1, 0, 0, 0], + ) + + add_node, pow_node = None, None + if sim_ln_nodes_1 is not None: + sim_ln_nodes = sim_ln_nodes_1 + add_node = sim_ln_nodes[3] + pow_node = sim_ln_nodes[-2] + elif sim_ln_nodes_2 is not None: + sim_ln_nodes = sim_ln_nodes_2 + add_node = sim_ln_nodes[3] + pow_node = sim_ln_nodes[-2] + elif sim_ln_nodes_3 is not None: + sim_ln_nodes = sim_ln_nodes_3 + add_node = sim_ln_nodes[3] + pow_node = sim_ln_nodes[-2] + elif sim_ln_nodes_4 is not None: + sim_ln_nodes = sim_ln_nodes_4 + add_node = sim_ln_nodes[3] + pow_node = sim_ln_nodes[-1] + # Verify that parent input to Pow node is graph_input + if pow_node.input[0] not in self.model.get_graphs_input_names(): + return + else: + return + + layernorm_weight_index = 1 if sim_ln_nodes in (sim_ln_nodes_3, sim_ln_nodes_4) else 0 + starts_with_graph_input = sim_ln_nodes == sim_ln_nodes_4 + + if self.model.find_constant_input(pow_node, 2.0) != 1: + return + + root_input = pow_node.input[0] + if root_input != sim_ln_nodes[0].input[0]: + return + + i, add_weight = self.model.get_constant_input(add_node) + if add_weight is None or add_weight <= 0 or add_weight > 1.0e-4: + logger.warning(f"epsilon value is not expected: {add_weight}") + return + + self.nodes_to_remove.extend(sim_ln_nodes[:-1] if not starts_with_graph_input else sim_ln_nodes) + self.nodes_to_remove.append(node) + + normalize_node = helper.make_node( + "SimplifiedLayerNormalization", + inputs=[root_input, node.input[layernorm_weight_index]], + outputs=[node.output[0]], + name=self.model.create_node_name("SimplifiedLayerNormalization", name_prefix="LayerNorm"), + ) + normalize_node.attribute.extend([helper.make_attribute("epsilon", float(add_weight))]) + normalize_node.attribute.extend([helper.make_attribute("axis", -1)]) + normalize_node.attribute.extend([helper.make_attribute("stash_type", 1)]) + self.nodes_to_add.append(normalize_node) + self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name + + +class FusionSkipSimplifiedLayerNormalization(FusionSkipLayerNormalization): + def __init__(self, model: OnnxModel): + super().__init__(model, "SkipSimplifiedLayerNormalization", "SimplifiedLayerNormalization") + + def fuse(self, node, input_name_to_nodes, output_name_to_node): + super().fuse(node, input_name_to_nodes, output_name_to_node) diff --git a/onnxruntime/python/tools/transformers/large_model_exporter.py b/onnxruntime/python/tools/transformers/large_model_exporter.py new file mode 100644 index 0000000000000..3b344d6dc9342 --- /dev/null +++ b/onnxruntime/python/tools/transformers/large_model_exporter.py @@ -0,0 +1,385 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- + +""" +Export LLM to onnx +""" +import argparse +import inspect +import math +import os +import tempfile +from pathlib import Path +from typing import Optional + +import onnx +import torch +import transformers +from torch import nn + + +def disable_huggingface_init(): + """do not init model twice as it slow initialization""" + + torch.nn.init.kaiming_uniform_ = lambda x, *args, **kwargs: x + torch.nn.init.uniform_ = lambda x, *args, **kwargs: x + torch.nn.init.normal_ = lambda x, *args, **kwargs: x + torch.nn.init.constant_ = lambda x, *args, **kwargs: x + torch.nn.init.xavier_uniform_ = lambda x, *args, **kwargs: x + torch.nn.init.xavier_normal_ = lambda x, *args, **kwargs: x + torch.nn.init.kaiming_normal_ = lambda x, *args, **kwargs: x + torch.nn.init.orthogonal_ = lambda x, *args, **kwargs: x + + +def get_model_parameter_size(model: nn.Module): + """to calculate how much memory this model needs""" + param_size = 0 + param_sum = 0 + for param in model.parameters(): + param_size += param.nelement() * param.element_size() + param_sum += param.nelement() + buffer_size = 0 + buffer_sum = 0 + for buffer in model.buffers(): + buffer_size += buffer.nelement() * buffer.element_size() + buffer_sum += buffer.nelement() + all_size = (param_size + buffer_size) / 1024 / 1024 + return all_size + + +def initialize_model_and_sample_inputs(hf_model: str, cache_dir: Optional[str], tokenizer=None): + """ + get the pretrained torch model from hugginface, + and sample model-inputs + """ + + disable_huggingface_init() + + model = transformers.AutoModelForCausalLM.from_pretrained( # type: ignore + hf_model, torch_dtype=torch.float16, cache_dir=cache_dir, trust_remote_code=True + ) + if tokenizer is None: + tokenizer = hf_model + tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer) # type: ignore + + sample_inputs = tuple(tokenizer("Hello, my dog is cute", return_tensors="pt").values()) + return model, sample_inputs + + +def auto_pipeline_parallel(model: nn.Module, gpulist: list, sample_inputs: tuple): + """Make the model executable across multiple GPUs.""" + + def input_gpu_device_hook(mod, inputs, kwargs): + modifyed_inputs = [] + first_dev = None + for layer_input in inputs: + if type(layer_input) is not torch.Tensor: + modifyed_inputs.append(layer_input) + elif hasattr(mod, "weight"): + modifyed_inputs.append(layer_input.to(mod.weight.device)) + elif hasattr(mod, "parameters"): + device = next(mod.parameters(), layer_input).device + modifyed_inputs.append(layer_input.to(device)) + elif hasattr(next(mod.children(), None), "weight"): + modifyed_inputs.append(layer_input.to(next(mod.children()).weight.device)) + elif first_dev is not None and layer_input.device != first_dev: + modifyed_inputs.append(layer_input.to(first_dev)) + else: + modifyed_inputs.append(layer_input) + if first_dev is None: + first_dev = modifyed_inputs[0].device + for key, value in kwargs.items(): + if type(value) is torch.Tensor: + kwargs[key] = value.to(first_dev) + + return (tuple(modifyed_inputs), kwargs) + + def move_layer_to_device_rurc(mod, dev): + mod.to(dev) + for layer in mod.named_children(): + move_layer_to_device_rurc(layer[1], dev) + + model = model.half() + all_hooks = [] + all_hooks.append(model.register_forward_pre_hook(input_gpu_device_hook, with_kwargs=True)) + pre_fix = next(iter(model.named_children()))[0] + for top_name, top_module in model.named_children(): + for name, module in top_module.named_children(): + all_hooks.append(module.register_forward_pre_hook(input_gpu_device_hook, with_kwargs=True)) + if type(module) in [torch.nn.ModuleList]: + num_layers_on_each_gpu = math.floor(len(module) / len(gpulist)) + for idx, attn_layer in enumerate(module): + all_hooks.append(attn_layer.register_forward_pre_hook(input_gpu_device_hook, with_kwargs=True)) + + to_dev = gpulist[min(idx // num_layers_on_each_gpu, len(gpulist))] + attn_layer.to(to_dev) + move_layer_to_device_rurc(attn_layer, to_dev) + print(f"move {pre_fix}.{name}.{idx} to {to_dev}") + else: + module.to(gpulist[0]) + print(f"move {pre_fix}.{name} to {gpulist[0]}") + if len(list(top_module.named_children())) == 0: + top_module.to(gpulist[0]) + print(f"move {top_name} to {gpulist[0]}") + + with torch.no_grad(): + model(sample_inputs[0], attention_mask=sample_inputs[1]) + return model + + +def retrieve_onnx_inputs(model: nn.Module, sample_inputs: tuple, with_past: bool): + """ + auto retrieve onnx inputs from torch model as we can't enumlate all possibilities + for all models + """ + user_inputs = [] + + def hook_for_inputs(_, inputs, kwargs): + user_inputs.append((inputs, kwargs)) + return user_inputs[0] + + hook_handle = model.register_forward_pre_hook(hook_for_inputs, with_kwargs=True) + + forward_params = inspect.signature(model.forward).parameters + input_keys = list(forward_params.keys()) + default_values = [forward_params.get(key).default for key in input_keys] + out = model(sample_inputs[0], attention_mask=sample_inputs[1]) + hook_handle.remove() + user_inputs = user_inputs[0] + onnx_inputs = default_values + for idx, _val in enumerate(user_inputs[0]): + onnx_inputs[idx] = user_inputs[0][idx] + for key, value in user_inputs[1].items(): + idx = input_keys.index(key) + onnx_inputs[idx] = value + for idx, (key, value) in enumerate(zip(input_keys, onnx_inputs)): + if type(value) is torch.Tensor: + value.to(model.device) + # Didn't touch past_key_value now, please change it if you want + if "use_cache" in key: + onnx_inputs[idx] = with_past + + return input_keys, onnx_inputs, out.past_key_values + + +def move_to_approprate_device(model: nn.Module, sample_inputs_tp: tuple) -> nn.Module: + """ + According to the model size, we will upload it to + CPU if has no GPU or enough GPU memory, + Single GPU if has only one GPU in local or model size is enough to fit one GPU + Multiple GPU if there is more than one gpu in local and model is too large + """ + total_mem_per_cpu = torch.cuda.get_device_properties(0).total_memory / 1024 / 1024 + + print(f"Model_Size = {get_model_parameter_size(model)/1024} GB") + print(f"total_mem_per_cpu = {total_mem_per_cpu/1024} GB") + if get_model_parameter_size(model) > total_mem_per_cpu * 0.45: + device_collection = [torch.device(i) for i in range(torch.cuda.device_count())] + if len(device_collection) > 1: + print( + f"{len(device_collection)} GPUs are used to export onnx, \ + Please set CUDA_VISIBLE_DEVICES to use specific GPU group" + ) + model = auto_pipeline_parallel(model, device_collection, sample_inputs_tp) + else: + print("!!!! convert model to float and export onnx using CPU") + model = model.cpu().float() + else: + print("Export model on a single GPU") + model = model.cuda().half() + return model + + +def adapt_inputs_to_device(sample_inputs: tuple, device: torch.device) -> tuple: + """move inputs to device""" + sample_inputs_ = [] + for sample_int in sample_inputs: + if isinstance(sample_int, torch.Tensor): + sample_inputs_.append(sample_int.to(device)) + else: + sample_inputs_.append(sample_int) + return tuple(sample_inputs_) + + +def fetch_onnx_inputs_outputs_name( + model: nn.Module, + onnx_inputs: list, + torch_input_names: tuple, + past_key_values: tuple, + with_past: bool, + input_with_past: bool, +): + """fetch onnx inputs and outputs name""" + num_of_past_key = 0 + kv_cache_axis = {0: "batch_size"} + # try get num_of_past_key and shape of past_key_value + if past_key_values is not None: + num_of_past_key = len(past_key_values) + seq_index = (torch.tensor(past_key_values[0][0].shape) == onnx_inputs[0].shape[-1]).nonzero().view(-1) + assert seq_index.numel() == 1 + kv_cache_axis = {0: "batch_size", seq_index.item(): "seq_len"} + + if not num_of_past_key: + num_of_past_key = model.config.num_hidden_layers + + onnx_inp_names = ("input_ids", "attention_mask") + onnx_out_names = ("logits",) + onnx_dynamic_axes = { + "input_ids": {0: "batch_size", 1: "seq_len"}, + "attention_mask": {0: "batch_size", 1: "seq_len"}, + } + if input_with_past: + for i in range(num_of_past_key): + onnx_inp_names += (f"present_key.{i}",) + onnx_inp_names += (f"present_values.{i}",) + + onnx_dynamic_axes[onnx_inp_names[-1]] = kv_cache_axis + onnx_dynamic_axes[onnx_inp_names[-2]] = kv_cache_axis + + if with_past or input_with_past: + for i in range(num_of_past_key): + onnx_out_names += (f"past_key.{i}",) + onnx_out_names += (f"past_values.{i}",) + onnx_dynamic_axes[onnx_out_names[-1]] = kv_cache_axis + onnx_dynamic_axes[onnx_out_names[-2]] = kv_cache_axis + + for idx, name in enumerate(torch_input_names): + if input_with_past: + if name == "past_key_values": + onnx_inputs[idx] = past_key_values + elif name == "attention_mask": + attn_mask = onnx_inputs[idx] + onnx_inputs[idx] = torch.cat( + (attn_mask, torch.ones((attn_mask.shape[0], 1), device=attn_mask.device)), dim=1 + ) + elif name == "input_ids": + input_ids = onnx_inputs[idx] + onnx_inputs[idx] = input_ids[:, -1:] + + return onnx_inp_names, onnx_out_names, onnx_dynamic_axes + + +def do_export_internal(model: nn.Module, onnx_io_tuple: tuple, onnx_inputs: tuple, onnx_path: Path, opset: int): + """do export with torch.onnx.export""" + onnx_model_name = onnx_path.name + onnx_inp_names, onnx_out_names, onnx_dynamic_axes = onnx_io_tuple + # two step to export onnx + # 1. export onnx with lots of pieces of weights + # 2. save all weights to external data + with tempfile.TemporaryDirectory() as tmpdirname: + tmp_onnx = os.path.join(tmpdirname, "tmp.onnx") + + torch.onnx.export( + model=model, + args=tuple(onnx_inputs), + f=tmp_onnx, + verbose=False, + opset_version=opset, + input_names=onnx_inp_names, + output_names=onnx_out_names, + dynamic_axes=onnx_dynamic_axes, + ) + + onnx_path.unlink(missing_ok=True) + (onnx_path.parent / f"{onnx_model_name}_ext.data").unlink(missing_ok=True) + + onnx_model = onnx.load(str(tmp_onnx)) + onnx.save_model( + onnx_model, + str(onnx_path), + save_as_external_data=(len(os.listdir(tmpdirname)) > 1), + all_tensors_to_one_file=True, + location=f"{onnx_model_name}_ext.data", + size_threshold=1024, + convert_attribute=False, + ) + + +@torch.no_grad() +def export_onnx(hf_model: str, cache_dir: Optional[str], onnx_path_str: str, with_past: bool, opset: int): + """ + do export + model: torch model + onnx_path: where the onnx model saved to + sample_inputs_tp: inputs for torch model + """ + model, sample_inputs_tp = initialize_model_and_sample_inputs(hf_model, cache_dir) + + model = move_to_approprate_device(model, sample_inputs_tp) + + sample_inputs = adapt_inputs_to_device(sample_inputs_tp, next(model.parameters()).device) + + # input_keys would be usesful if the model has some special inputs + input_keys, onnx_inputs, past_key_value = retrieve_onnx_inputs(model, sample_inputs, with_past) + + onnx_io_tuple = fetch_onnx_inputs_outputs_name(model, onnx_inputs, input_keys, past_key_value, with_past, False) + + onnx_model_name = "model.onnx" + onnx_path: Path = Path(onnx_path_str).absolute() + if onnx_path.suffix != ".onnx": + onnx_path = onnx_path / onnx_model_name + + do_export_internal(model, onnx_io_tuple, onnx_inputs, onnx_path, opset) + if not with_past: + return + + onnx_io_tuple = fetch_onnx_inputs_outputs_name(model, onnx_inputs, input_keys, past_key_value, with_past, True) + + onnx_model_name = "model_with_past.onnx" + onnx_path = onnx_path.parent / onnx_model_name + + do_export_internal(model, onnx_io_tuple, onnx_inputs, onnx_path, opset) + + +def parse_arguments(): + """arguments parsing.""" + parser = argparse.ArgumentParser() + + parser.add_argument( + "-m", + "--model", + required=True, + type=str, + default=["meta-llama/Llama-2-70b-hf"], + help="Pre-trained models in huggingface model hub", + ) + parser.add_argument( + "-s", + "--saved_path", + required=False, + type=str, + default="./onnx_models/", + help="where the onnx model will be saved", + ) + parser.add_argument( + "--cache_dir", + required=False, + type=str, + default=None, + help=("cache directy of huggingface, by setting this to avoid useless downloading if you have one"), + ) + parser.add_argument( + "--with_past", + action="store_true", + default=False, + help=("The tool will export onnx without past-key-value by default"), + ) + parser.add_argument( + "--opset", + required=False, + type=int, + default=17, + help=( + "the opset to save onnx model, \ + try to increase it if this opset doens't have new features you want" + ), + ) + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_arguments() + + export_onnx(args.model, args.cache_dir, args.saved_path, args.with_past, args.opset) diff --git a/onnxruntime/python/tools/transformers/models/llama/README.md b/onnxruntime/python/tools/transformers/models/llama/README.md index b4461a2eadb8c..6057b46667fe6 100644 --- a/onnxruntime/python/tools/transformers/models/llama/README.md +++ b/onnxruntime/python/tools/transformers/models/llama/README.md @@ -17,12 +17,31 @@ $ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama To make this option compatible with [Hugging Face's Optimum](https://github.com/huggingface/optimum), you will need to create `config.json` and `generation_config.json` for your model and store them in the same directory as your ONNX models. For example, you can find those JSON files for LLaMA-2 7B on Hugging Face [here](https://huggingface.co/meta-llama/Llama-2-7b-hf). +As indicated in `requirements.txt`, you will also need to install Optimum from source. Once installed, you will need to modify `ORTModelForCausalLM.forward` in `optimum/optimum/onnxruntime/modeling_decoder.py` as follows: + +``` +# Before +if self.use_cache: + if past_key_values is not None: + input_ids = input_ids[:, -1:] + # Flatten the past_key_values (no need to flatten for models using multi-query attn) + + +# After +if self.use_cache: + if past_key_values is not None: + input_ids = input_ids[:, -1:] if past_key_values[0][0].shape[2] != 0 else input_ids + # Flatten the past_key_values (no need to flatten for models using multi-query attn) +``` + ### Option 2: from [Microsoft's custom export](https://github.com/microsoft/Llama-2-Onnx) Please follow the [README instructions](https://github.com/microsoft/Llama-2-Onnx#before-you-start) in the custom export of LLaMA-2. ### Option 3: from [Hugging Face Optimum](https://github.com/huggingface/optimum) +Note that this will produce two ONNX models whereas the above two options produce one ONNX model. + First, log into the Hugging Face CLI in your terminal: ``` @@ -56,38 +75,81 @@ $ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input ./ $ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --input ./Llama-2-7b-hf --output ./llama2-7b ``` -Export for FP16 +Export for FP32 CUDA +``` +# From source: +$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32-gpu --precision fp32 --execution_provider cuda + +# From wheel: +$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32 --precision fp32 --execution_provider cuda +``` + +Export for FP32 CPU ``` # From source: -$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp16 --precision fp16 +$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32-cpu --precision fp32 --execution_provider cpu # From wheel: -$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp16 --precision fp16 +$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp32 --precision fp32 --execution_provider cpu ``` -Export for INT8 +Export for FP16 CUDA ``` # From source: -$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method smooth_quant +$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp16 --precision fp16 --execution_provider cuda # From wheel: -$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method smooth_quant +$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-fp16 --precision fp16 --execution_provider cuda +``` + +Export for INT8 CPU (SmoothQuant) +``` +# From source: +$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method smooth_quant --execution_provider cpu + +# From wheel: +$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method smooth_quant --execution_provider cpu ``` Note: [Intel's Neural Compressor](https://github.com/intel/neural-compressor) takes time to run the SmoothQuant quantization algorithm on LLMs. On an [Azure Standard_NC24s_v3 VM](https://learn.microsoft.com/en-us/azure/virtual-machines/ncv3-series), it takes about ~30-45 min for each of the exported ONNX models. +Export for INT8 CPU (DynamicQuant) +``` +# From source: +$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method quantize_dynamic --execution_provider cpu + +# From wheel: +$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int8 --precision int8 --quantization_method quantize_dynamic --execution_provider cpu +``` + +Export for INT4 CUDA +``` +# From source: +$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-gpu --precision int4 --quantization_method blockwise --execution_provider cuda + +# From wheel: +$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4 --precision int4 --quantization_method blockwise --execution_provider cuda +``` + +Export for INT4 CPU +``` +# From source: +$ python3 -m models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4-cpu --precision int4 --quantization_method blockwise --execution_provider cpu + +# From wheel: +$ python3 -m onnxruntime.transformers.models.llama.convert_to_onnx -m meta-llama/Llama-2-7b-hf --output llama2-7b-int4 --precision int4 --quantization_method blockwise --execution_provider cpu +``` + ## Benchmark LLaMA-2 Here are some examples of how you can benchmark LLaMA-2. -Note: In the below examples, `PyTorch` refers to running in PyTorch without `torch.compile` and `PyTorch 2.0` refers to running in PyTorch with `torch.compile`. - ### Variants -1. PyTorch (without `torch.compile`), FP32 +1. PyTorch without `torch.compile`, FP32 ``` python3 -m models.llama.benchmark \ - --benchmark-type hf-pt \ + --benchmark-type hf-pt-eager \ --model-name meta-llama/Llama-2-7b-hf \ --precision fp32 \ --batch-sizes "1 2" \ @@ -96,10 +158,10 @@ python3 -m models.llama.benchmark \ --auth ``` -2. PyTorch 2.0 (with `torch.compile`), FP16 +2. PyTorch with `torch.compile`, FP16 ``` python3 -m models.llama.benchmark \ - --benchmark-type hf-pt2 \ + --benchmark-type hf-pt-compile \ --model-name meta-llama/Llama-2-7b-hf \ --precision fp16 \ --batch-sizes "1 2" \ @@ -112,7 +174,7 @@ python3 -m models.llama.benchmark \ ``` python3 -m models.llama.benchmark \ --benchmark-type hf-ort \ - --hf-ort-model-path ./Llama-2-7b-hf-onnx/ \ + --hf-ort-dir-path ./Llama-2-7b-hf-onnx/ \ --model-name meta-llama/Llama-2-7b-hf \ --precision fp32 \ --batch-sizes "1 2" \ @@ -125,7 +187,7 @@ python3 -m models.llama.benchmark \ ``` python3 -m models.llama.benchmark \ --benchmark-type hf-ort \ - --hf-ort-model-path ./llama2-7b-fp16/ \ + --hf-ort-dir-path ./llama2-7b-fp16/ \ --model-name meta-llama/Llama-2-7b-hf \ --precision fp16 \ --batch-sizes "1 2" \ @@ -134,24 +196,35 @@ python3 -m models.llama.benchmark \ --auth ``` -5. Optimum + ONNX Runtime, INT8, export via convert_to_onnx +5. ONNX Runtime, FP32, Microsoft custom export ``` python3 -m models.llama.benchmark \ - --benchmark-type hf-ort \ - --hf-ort-model-path ./llama2-7b-int8/ \ + --benchmark-type ort-msft \ + --ort-model-path ./llama-2-onnx/7B_float32/ONNX/LlamaV2_7B_float32.onnx \ --model-name meta-llama/Llama-2-7b-hf \ - --precision int8 \ + --precision fp32 \ --batch-sizes "1 2" \ --sequence-lengths "8 16" \ - --device cpu \ - --auth + --device cpu ``` -6. ONNX Runtime, FP32, Microsoft custom export +6. ONNX Runtime, FP16, Microsoft custom export ``` python3 -m models.llama.benchmark \ - --benchmark-type ort \ - --ort-model-path llama-2-onnx/7B_float32/ONNX/LlamaV2_7B_float32.onnx \ + --benchmark-type ort-msft \ + --ort-model-path ./llama-2-onnx/7B_float16/ONNX/LlamaV2_7B_float16.onnx \ + --model-name meta-llama/Llama-2-7b-hf \ + --precision fp16 \ + --batch-sizes "1 2" \ + --sequence-lengths "8 16" \ + --device cuda +``` + +7. ONNX Runtime, FP32, convert_to_onnx +``` +python3 -m models.llama.benchmark \ + --benchmark-type ort-convert-to-onnx \ + --ort-model-path ./llama2-7b/Llama-2-7b-hf_decoder_merged_model_fp32.onnx \ --model-name meta-llama/Llama-2-7b-hf \ --precision fp32 \ --batch-sizes "1 2" \ @@ -159,11 +232,11 @@ python3 -m models.llama.benchmark \ --device cpu ``` -7. ONNX Runtime, FP16, Microsoft custom export +8. ONNX Runtime, FP16, convert_to_onnx ``` python3 -m models.llama.benchmark \ - --benchmark-type ort \ - --ort-model-path ./llama-2-onnx/7B_float16/ONNX/LlamaV2_7B_float16.onnx \ + --benchmark-type ort-convert-to-onnx \ + --ort-model-path ./llama2-7b/Llama-2-7b-hf_decoder_merged_model_fp16.onnx \ --model-name meta-llama/Llama-2-7b-hf \ --precision fp16 \ --batch-sizes "1 2" \ @@ -174,11 +247,14 @@ python3 -m models.llama.benchmark \ You can profile a variant by adding the `--profile` flag and providing one batch size and sequence length combination. ### Benchmark All -You can use `benchmark_all.py` to benchmark across various platforms and automatically store the results in a CSV file. Here is an example. +You can use `benchmark_all.py` to benchmark across various options and automatically store the results in a CSV file. Here is an example. ``` python3 -m models.llama.benchmark_all \ - --hf-ort-model-path ./llama2-7b-fp16/ \ - --ort-model-path ./llama-2-onnx/7B_float16/ONNX/LlamaV2_7B_float16.onnx \ + --hf-pt-eager \ + --hf-pt-compile \ + --hf-ort-dir-path ./llama2-7b-fp16/ \ + --ort-convert-to-onnx-model-path ./llama2-7b-fp16/Llama-2-7b-hf_decoder_merged_model_fp16.onnx \ + --ort-msft-model-path ./llama-2-onnx/7B_float16/ONNX/LlamaV2_7B_float16.onnx \ --model-name meta-llama/Llama-2-7b-hf \ --precision fp16 \ --batch-sizes "1 2" \ diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark.py b/onnxruntime/python/tools/transformers/models/llama/benchmark.py index d19ed5cc28fed..976de2abc7c57 100644 --- a/onnxruntime/python/tools/transformers/models/llama/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/llama/benchmark.py @@ -8,10 +8,17 @@ import time import numpy as np +import onnx import psutil import torch from benchmark_helper import setup_logger -from llama_inputs import get_msft_sample_inputs, get_sample_inputs, get_sample_with_past_kv_inputs +from llama_inputs import ( + convert_inputs_for_ort, + get_merged_sample_with_past_kv_inputs, + get_msft_sample_inputs, + get_sample_inputs, + get_sample_with_past_kv_inputs, +) from optimum.onnxruntime import ORTModelForCausalLM from torch.profiler import ProfilerActivity, profile, record_function from tqdm import trange @@ -23,8 +30,29 @@ logger = logging.getLogger(__name__) -def get_inputs(args: argparse.Namespace): - if args.benchmark_type in {"hf-pt", "hf-pt2", "hf-ort"}: +# For determining whether the ONNX model can do both prompt generation and token generation or only one of the two +def get_ort_model_inputs_len(args, model): + if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}: + return 0 + if args.benchmark_type == "hf-ort": + try: + # New Optimum export (https://github.com/huggingface/optimum/blob/888332364c2e0091da1fc974737c7e277af168bf/optimum/onnxruntime/modeling_ort.py#L268) + return len(model.inputs_names) + except Exception: + # Old Optimum export (https://github.com/huggingface/optimum/blob/c5ad7f971cb0a494eac03dc0909f146725f999c5/optimum/onnxruntime/base.py#L54) + return len(model.decoder.input_names) + return len(model.get_inputs()) + + +def get_inputs(args: argparse.Namespace, ort_model_inputs_len: int): + init_inputs, iter_inputs = None, None + + # For past_present_share_buffer: + # Set max_seq_len to 2048 for Hugging Face model since that is the default value + # Set max_seq_len to 2048 for Microsoft model since that is the max value currently supported + max_seq_len = 2048 + + if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}: init_inputs = get_sample_inputs( args.config, args.target_device, @@ -41,14 +69,95 @@ def get_inputs(args: argparse.Namespace): return_dict=True, ) - elif args.benchmark_type == "ort": + elif args.benchmark_type == "hf-ort": + if ort_model_inputs_len == 3: # [input_ids, attention_mask, position_ids] + # Using split models in Optimum (e.g. created by Optimum export) + init_inputs = get_sample_inputs( + args.config, + args.target_device, + args.batch_size, + args.sequence_length, + return_dict=True, + ) + iter_inputs = get_sample_with_past_kv_inputs( + args.config, + args.target_device, + args.batch_size, + args.sequence_length, + use_fp16=args.use_fp16, + return_dict=True, + ) + else: + # Using merged model in Optimum (e.g. created by convert_to_onnx export) + init_inputs = get_merged_sample_with_past_kv_inputs( + args.config, + args.target_device, + args.batch_size, + seq_len=args.sequence_length, + past_seq_len=0, + use_fp16=args.use_fp16, + return_dict=True, + ) + iter_inputs = get_merged_sample_with_past_kv_inputs( + args.config, + args.target_device, + args.batch_size, + seq_len=1, + past_seq_len=args.sequence_length, + use_fp16=args.use_fp16, + return_dict=True, + ) + + elif args.benchmark_type == "ort-convert-to-onnx": + # Microsoft export from convert_to_onnx + init_inputs = get_merged_sample_with_past_kv_inputs( + args.config, + args.target_device, + args.batch_size, + seq_len=args.sequence_length, + past_seq_len=0, + use_fp16=args.use_fp16, + return_dict=True, + ) + iter_inputs = get_merged_sample_with_past_kv_inputs( + args.config, + args.target_device, + args.batch_size, + seq_len=1, + past_seq_len=args.sequence_length, + use_fp16=args.use_fp16, + return_dict=True, + ) + init_inputs = convert_inputs_for_ort( + init_inputs, + use_fp16=args.use_fp16, + use_buffer_share=args.past_present_share_buffer, + past_seq_len=0, + max_seq_len=max_seq_len, + device=args.device, + device_id=args.device_id, + ) + iter_inputs = convert_inputs_for_ort( + iter_inputs, + use_fp16=args.use_fp16, + use_buffer_share=args.past_present_share_buffer, + past_seq_len=args.sequence_length, + max_seq_len=max_seq_len, + device=args.device, + device_id=args.device_id, + ) + + elif args.benchmark_type == "ort-msft": # Microsoft export from https://github.com/microsoft/Llama-2-Onnx + split_kv = ort_model_inputs_len > 5 # original inputs: [x, attn_mask, k_cache, v_cache, pos] + init_inputs = get_msft_sample_inputs( args.config, args.batch_size, past_seq_len=0, seq_len=args.sequence_length, use_fp16=args.use_fp16, + split_kv=split_kv, ) iter_inputs = get_msft_sample_inputs( args.config, @@ -56,6 +165,25 @@ def get_inputs(args: argparse.Namespace): past_seq_len=args.sequence_length, seq_len=1, use_fp16=args.use_fp16, + split_kv=split_kv, + ) + init_inputs = convert_inputs_for_ort( + init_inputs, + use_fp16=args.use_fp16, + use_buffer_share=args.past_present_share_buffer, + past_seq_len=0, + max_seq_len=max_seq_len, + device=args.device, + device_id=args.device_id, + ) + iter_inputs = convert_inputs_for_ort( + iter_inputs, + use_fp16=args.use_fp16, + use_buffer_share=args.past_present_share_buffer, + past_seq_len=args.sequence_length, + max_seq_len=max_seq_len, + device=args.device, + device_id=args.device_id, ) else: @@ -69,12 +197,14 @@ def get_model(args: argparse.Namespace): start_time, end_time = None, None # There are multiple sources that the model could come from: - # 1) Benchmark LLaMA from unofficial source on Hugging Face - # 2) Benchmark LLaMA from official source on Hugging Face, which requires an authentication token - # 3) Benchmark LLaMA from local download of model - - if args.benchmark_type in {"hf-pt", "hf-pt2"}: - source = args.hf_pt_model_path if args.hf_pt_model_path else args.model_name + # 1) Benchmark LLaMA-2 from unofficial source on Hugging Face + # 2) Benchmark LLaMA-2 from official source on Hugging Face, which requires an authentication token + # 3) Benchmark LLaMA-2 from local download of model + # 4) Benchmark LLaMA-2 from Microsoft (already optimized, available at https://github.com/microsoft/Llama-2-Onnx) + # 5) Benchmark LLaMA-2 from convert_to_onnx + + if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}: + source = args.hf_pt_dir_path if args.hf_pt_dir_path else args.model_name start_time = time.time() model = LlamaForCausalLM.from_pretrained( source, @@ -84,10 +214,10 @@ def get_model(args: argparse.Namespace): ).to(args.target_device) end_time = time.time() - if args.benchmark_type == "hf-pt2": + if args.benchmark_type == "hf-pt-compile": model = torch.compile(model) - elif args.benchmark_type in {"hf-ort", "ort"}: + elif args.benchmark_type in {"hf-ort", "ort-msft", "ort-convert-to-onnx"}: sess_options = ort.SessionOptions() sess_options.enable_profiling = args.profile if args.verbose: @@ -104,32 +234,33 @@ def get_model(args: argparse.Namespace): decoder_file_name = None decoder_with_past_file_name = None - for filename in os.listdir(args.hf_ort_model_path): + for filename in os.listdir(args.hf_ort_dir_path): if ".onnx" not in filename or ".onnx_data" in filename or ".onnx.data" in filename: continue - if "decoder_model.onnx" in filename or f"decoder_model_{args.precision}.onnx" in filename: + if "decoder_model" in filename or filename == "model.onnx": + decoder_file_name = filename + if "decoder_with_past_model" in filename: + decoder_with_past_file_name = filename + if "decoder_merged_model" in filename: decoder_file_name = filename - if ( - "decoder_with_past_model.onnx" in filename - or f"decoder_with_past_model_{args.precision}.onnx" in filename - ): decoder_with_past_file_name = filename start_time = time.time() model = ORTModelForCausalLM.from_pretrained( - args.hf_ort_model_path, + args.hf_ort_dir_path, decoder_file_name=decoder_file_name, decoder_with_past_file_name=decoder_with_past_file_name, use_auth_token=args.auth, use_io_binding=(args.device != "cpu"), + use_merged=(True if decoder_file_name == "model.onnx" else None), provider=provider, provider_options=provider_options, session_options=sess_options, ) end_time = time.time() - if args.benchmark_type == "ort": - # Microsoft export from https://github.com/microsoft/Llama-2-Onnx + if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}: + # Ex: Microsoft export from https://github.com/microsoft/Llama-2-Onnx logger.info(f"Loading model from {args.ort_model_path}") start_time = time.time() model = ort.InferenceSession( @@ -140,7 +271,6 @@ def get_model(args: argparse.Namespace): end_time = time.time() logger.info(f"Loaded model in {end_time - start_time} s") - return model @@ -148,7 +278,7 @@ def time_fn(args, fn, inputs): # Warm up warmup_range = ( range(args.warmup_runs) - if args.benchmark_type == "ort" + if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"} else trange(args.warmup_runs, file=sys.stdout, desc="Warm up") ) @@ -166,7 +296,7 @@ def time_fn(args, fn, inputs): bench_range = ( range(args.num_runs) - if args.benchmark_type == "ort" + if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"} else trange(args.num_runs, file=sys.stdout, desc="Benchmark") ) for _ in bench_range: @@ -177,7 +307,7 @@ def time_fn(args, fn, inputs): end_time = time.time() # Newline print after trange in order to print metrics on new lines without progress bar on same line - if args.benchmark_type != "ort": + if args.benchmark_type not in {"ort-msft", "ort-convert-to-onnx"}: logger.info("") latency = (end_time - start_time) / args.num_runs @@ -186,7 +316,7 @@ def time_fn(args, fn, inputs): logger.info(f"Batch Size: {args.batch_size}") logger.info(f"Sequence Length: {args.sequence_length}") logger.info(f"Latency: {latency} s") - logger.info(f"Throughput: {throughput} qps") + logger.info(f"Throughput: {throughput} tps") return @@ -196,7 +326,7 @@ def profile_fn(args, fn, inputs, inputs_type): prefix = f"b{args.batch_size}_s{args.sequence_length}_{args.benchmark_type.lower()}-{args.precision}-{args.device}_{fn.__name__.replace('_', '-')}_{inputs_type}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}" filename = None - if args.benchmark_type in {"hf-pt", "hf-pt2"}: + if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}: # Profile PyTorch kernels with profile( # noqa: SIM117 activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True @@ -267,7 +397,7 @@ def get_logits(inputs): generate_fn = get_logits - if args.benchmark_type == "hf-pt2": + if args.benchmark_type == "hf-pt-compile": # Run forward pass once with each set of inputs to process through Dynamo generate_fn(init_inputs) generate_fn(iter_inputs) @@ -280,7 +410,7 @@ def get_logits(inputs): logger.warning(f"Renaming {old_logname} to {new_logname}") os.rename(old_logname, os.path.join(args.log_folder, new_logname)) - new_logname = profile_fn(args, generate_fn, iter_inputs, "per-token") + new_logname = profile_fn(args, generate_fn, iter_inputs, "token") if args.benchmark_type == "hf-ort": # Turn profiling off to stop appending to log old_logname = model.decoder_with_past.session.end_profiling() @@ -319,10 +449,24 @@ def prepare_ort_inputs(inputs): # Add IO bindings for non-CPU execution providers if args.device != "cpu": io_binding = model.io_binding() + for k, v in inputs.items(): - io_binding.bind_cpu_input(k, v) + if args.past_present_share_buffer: + # Bind all OrtValue inputs to device + io_binding.bind_ortvalue_input(k, v) + else: + io_binding.bind_cpu_input(k, v) + for output in model.get_outputs(): - io_binding.bind_output(output.name) + name = output.name + if args.past_present_share_buffer and ("out" in name or "present" in name): + # Bind present KV cache outputs to OrtValue with buffer sharing + io_binding.bind_ortvalue_output( + name, inputs[name.replace("out", "cache").replace("present", "past_key_values")] + ) + else: + io_binding.bind_output(name, device_type=args.device, device_id=args.device_id) + return io_binding return inputs @@ -350,7 +494,7 @@ def without_io_binding(inputs): # Re-initialize model for new log file instead of appending to old log file model = get_model(args) ort_iter_inputs = prepare_ort_inputs(iter_inputs) - new_logname = profile_fn(args, generate_fn, ort_iter_inputs, "per-token") + new_logname = profile_fn(args, generate_fn, ort_iter_inputs, "token") # Turn profiling off to stop appending to log old_logname = model.end_profiling() @@ -371,9 +515,9 @@ def without_io_binding(inputs): def run_inference(args, init_inputs, iter_inputs, model): - if args.benchmark_type in {"hf-pt", "hf-pt2", "hf-ort"}: + if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile", "hf-ort"}: run_hf_inference(args, init_inputs, iter_inputs, model) - elif args.benchmark_type == "ort": + elif args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}: run_ort_inference(args, init_inputs, iter_inputs, model) else: raise Exception(f"Cannot recognize {args.benchmark_type}") @@ -382,7 +526,11 @@ def run_inference(args, init_inputs, iter_inputs, model): def get_args(): parser = argparse.ArgumentParser() parser.add_argument( - "-bt", "--benchmark-type", type=str, required=True, choices=["hf-pt", "hf-pt2", "hf-ort", "ort"] + "-bt", + "--benchmark-type", + type=str, + required=True, + choices=["hf-pt-eager", "hf-pt-compile", "hf-ort", "ort-msft", "ort-convert-to-onnx"], ) parser.add_argument( "-m", @@ -402,20 +550,20 @@ def get_args(): required=True, type=str, default="fp32", - choices=["int8", "fp16", "fp32"], + choices=["int4", "int8", "fp16", "fp32"], help="Precision for model. For ONNX models, the model's precision should be set before running this script.", ) parser.add_argument( - "--hf-pt-model-path", + "--hf-pt-dir-path", type=str, default="", help="Path to directory containing all PyTorch files (e.g. tokenizer, PyTorch model)", ) parser.add_argument( - "--hf-ort-model-path", + "--hf-ort-dir-path", type=str, default="", - help="Path to directory containing all ONNX files (e.g. tokenizer, encoder, decoder, decoder_with_past)", + help="Path to directory containing all ONNX files (e.g. tokenizer, decoder_merged, decoder, decoder_with_past)", ) parser.add_argument( "--ort-model-path", @@ -475,15 +623,20 @@ def get_args(): args.execution_provider = (args.execution_provider, {"device_id": args.device_id}) args.device = "cuda" - # Check that model paths have been specified for any benchmarking with ORT + # Check that paths have been specified for any benchmarking with ORT if args.benchmark_type == "hf-ort": - assert args.hf_ort_model_path, "Please specify a path to `--hf-ort-model-path`" - if args.benchmark_type == "ort": + assert args.hf_ort_dir_path, "Please specify a path to `--hf-ort-dir-path`" + if args.benchmark_type in {"ort-msft", "ort-convert-to-onnx"}: assert args.ort_model_path, "Please specify a path to `--ort-model-path`" args.batch_sizes = args.batch_sizes.split(" ") args.sequence_lengths = args.sequence_lengths.split(" ") + # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models + args.precision = ( + "fp32" if args.precision in {"int8", "fp32"} or (args.precision == "int4" and args.device == "cpu") else "fp16" + ) + # Check that only one (batch_size, sequence_length) combination is set for profiling if args.profile: assert ( @@ -509,14 +662,27 @@ def main(): setattr(args, "target_device", target_device) # noqa: B010 setattr(args, "use_fp16", use_fp16) # noqa: B010 - # Measure prompt cost (init_inputs) and generated token cost (iter_inputs) + # Get model and model info model = get_model(args) + ort_model_inputs_len = get_ort_model_inputs_len(args, model) + + # Check if past_present_share_buffer can be enabled (only for FP16 models with GQA) + if args.benchmark_type in {"ort-convert-to-onnx", "ort-msft"}: + onnx_model = onnx.load_model(args.ort_model_path, load_external_data=False) + gqa_nodes = list(filter(lambda node: node.op_type == "GroupQueryAttention", onnx_model.graph.node)) + + use_buffer_share = use_fp16 and len(gqa_nodes) > 0 and args.device != "cpu" + setattr(args, "past_present_share_buffer", use_buffer_share) # noqa: B010 + else: + setattr(args, "past_present_share_buffer", False) # noqa: B010 + + # Measure prompt cost (init_inputs) and generated token cost (iter_inputs) for batch_size, sequence_length in itertools.product(args.batch_sizes, args.sequence_lengths): logger.info(f"\nBatch size = {batch_size} and sequence length = {sequence_length}...") setattr(args, "batch_size", int(batch_size)) # noqa: B010 setattr(args, "sequence_length", int(sequence_length)) # noqa: B010 - init_inputs, iter_inputs = get_inputs(args) + init_inputs, iter_inputs = get_inputs(args, ort_model_inputs_len) run_inference(args, init_inputs, iter_inputs, model) diff --git a/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py b/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py index 7199c945fe6ba..951b2549368f7 100644 --- a/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py +++ b/onnxruntime/python/tools/transformers/models/llama/benchmark_all.py @@ -43,15 +43,38 @@ def get_args(): ) parser.add_argument( - "--hf-ort-model-path", + "--hf-pt-eager", + default=False, + action="store_true", + help="Benchmark in PyTorch without `torch.compile`", + ) + + parser.add_argument( + "--hf-pt-compile", + default=False, + action="store_true", + help="Benchmark in PyTorch with `torch.compile`", + ) + + parser.add_argument( + "--hf-ort-dir-path", type=str, + default="", help="Path to folder containing ONNX models for Optimum + ORT benchmarking", ) parser.add_argument( - "--ort-model-path", + "--ort-msft-model-path", + type=str, + default="", + help="Path to ONNX model from https://github.com/microsoft/Llama-2-Onnx", + ) + + parser.add_argument( + "--ort-convert-to-onnx-model-path", type=str, - help="Path to ONNX model for ORT benchmarking", + default="", + help="Path to ONNX model from convert_to_onnx", ) parser.add_argument( @@ -65,7 +88,7 @@ def get_args(): "--precision", type=str, required=True, - choices=["int8", "fp16", "fp32"], + choices=["int4", "int8", "fp16", "fp32"], help="Precision to run model", ) @@ -138,8 +161,6 @@ def process_log_file(device_id, log_file, base_results): step = "per-token" elif latency_pattern in line: latency_s = float(line[len(latency_pattern) : line.rfind(" ")]) - if step == "prompt": - latency_s /= sequence_length latency_ms = latency_s * 1000 elif throughput_pattern in line: throughput = float(line[len(throughput_pattern) : line.rfind(" ")]) @@ -184,7 +205,7 @@ def save_results(results, filename): "Step", "Latency (s)", "Latency (ms)", - "Throughput (qps)", + "Throughput (tps)", "Memory (GB)", ], ) @@ -194,7 +215,7 @@ def save_results(results, filename): df["Sequence Length"] = df["Sequence Length"].astype("int") df["Latency (s)"] = df["Latency (s)"].astype("float") df["Latency (ms)"] = df["Latency (ms)"].astype("float") - df["Throughput (qps)"] = df["Throughput (qps)"].astype("float") + df["Throughput (tps)"] = df["Throughput (tps)"].astype("float") df["Memory (GB)"] = df["Memory (GB)"].astype("float") df.to_csv(filename, index=False) @@ -226,75 +247,81 @@ def main(): torch.backends.cudnn.benchmark = True all_results = [] + # Benchmark PyTorch without torch.compile - benchmark_cmd = [ - "python3", - "benchmark.py", - "--benchmark-type", - "hf-pt", - "--model-name", - args.model_name, - "--precision", - args.precision, - "--batch-sizes", - args.batch_sizes, - "--sequence-lengths", - args.sequence_lengths, - "--device", - args.device, - "--device-id", - str(args.device_id), - "--warmup-runs", - str(args.warmup_runs), - "--num-runs", - str(args.num_runs), - "--log-folder", - args.log_folder, - "--auth", - ] - logger.info("Benchmark PyTorch without torch.compile") - results = benchmark(args, benchmark_cmd, "pytorch") - all_results.extend(results) + if args.hf_pt_eager: + benchmark_cmd = [ + "python", + "-m", + "models.llama.benchmark", + "--benchmark-type", + "hf-pt-eager", + "--model-name", + args.model_name, + "--precision", + args.precision, + "--batch-sizes", + args.batch_sizes, + "--sequence-lengths", + args.sequence_lengths, + "--device", + args.device, + "--device-id", + str(args.device_id), + "--warmup-runs", + str(args.warmup_runs), + "--num-runs", + str(args.num_runs), + "--log-folder", + args.log_folder, + "--auth", + ] + logger.info("Benchmark PyTorch without torch.compile") + results = benchmark(args, benchmark_cmd, "pytorch-eager") + all_results.extend(results) # Benchmark PyTorch with torch.compile - benchmark_cmd = [ - "python3", - "benchmark.py", - "--benchmark-type", - "hf-pt2", - "--model-name", - args.model_name, - "--precision", - args.precision, - "--batch-sizes", - args.batch_sizes, - "--sequence-lengths", - args.sequence_lengths, - "--device", - args.device, - "--device-id", - str(args.device_id), - "--warmup-runs", - str(args.warmup_runs), - "--num-runs", - str(args.num_runs), - "--log-folder", - args.log_folder, - "--auth", - ] - logger.info("Benchmark PyTorch with torch.compile") - results = benchmark(args, benchmark_cmd, "pytorch-2") - all_results.extend(results) + if args.hf_pt_compile: + benchmark_cmd = [ + "python", + "-m", + "models.llama.benchmark", + "--benchmark-type", + "hf-pt-compile", + "--model-name", + args.model_name, + "--precision", + args.precision, + "--batch-sizes", + args.batch_sizes, + "--sequence-lengths", + args.sequence_lengths, + "--device", + args.device, + "--device-id", + str(args.device_id), + "--warmup-runs", + str(args.warmup_runs), + "--num-runs", + str(args.num_runs), + "--log-folder", + args.log_folder, + "--auth", + ] + logger.info("Benchmark PyTorch with torch.compile") + results = benchmark(args, benchmark_cmd, "pytorch-compile") + all_results.extend(results) # Benchmark Optimum + ONNX Runtime - if args.hf_ort_model_path: + if args.hf_ort_dir_path: benchmark_cmd = [ - "python3", - "benchmark.py", + "python", + "-m", + "models.llama.benchmark", "--benchmark-type", "hf-ort", - "--hf-ort-model-path", - args.hf_ort_model_path, + "--hf-ort-dir-path", + args.hf_ort_dir_path, "--model-name", args.model_name, "--precision", @@ -316,18 +343,52 @@ def main(): "--auth", ] logger.info("Benchmark Optimum + ONNX Runtime") - results = benchmark(args, benchmark_cmd, "pytorch-ort") + results = benchmark(args, benchmark_cmd, "optimum-ort") + all_results.extend(results) + + # Benchmark Microsoft model in ONNX Runtime + if args.ort_msft_model_path: + benchmark_cmd = [ + "python", + "-m", + "models.llama.benchmark", + "--benchmark-type", + "ort-msft", + "--ort-model-path", + args.ort_msft_model_path, + "--model-name", + args.model_name, + "--precision", + args.precision, + "--batch-sizes", + args.batch_sizes, + "--sequence-lengths", + args.sequence_lengths, + "--device", + args.device, + "--device-id", + str(args.device_id), + "--warmup-runs", + str(args.warmup_runs), + "--num-runs", + str(args.num_runs), + "--log-folder", + args.log_folder, + ] + logger.info("Benchmark Microsoft model in ONNX Runtime") + results = benchmark(args, benchmark_cmd, "ort-msft") all_results.extend(results) - # Benchmark ONNX Runtime - if args.ort_model_path: + # Benchmark convert_to_onnx model in ONNX Runtime + if args.ort_convert_to_onnx_model_path: benchmark_cmd = [ - "python3", - "benchmark.py", + "python", + "-m", + "models.llama.benchmark", "--benchmark-type", - "ort", + "ort-convert-to-onnx", "--ort-model-path", - args.ort_model_path, + args.ort_convert_to_onnx_model_path, "--model-name", args.model_name, "--precision", @@ -347,7 +408,7 @@ def main(): "--log-folder", args.log_folder, ] - logger.info("Benchmark ONNX Runtime") + logger.info("Benchmark convert_to_onnx model in ONNX Runtime") results = benchmark(args, benchmark_cmd, "onnxruntime") all_results.extend(results) diff --git a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py index f96347ba67aa6..61d71bc38f4e9 100644 --- a/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py @@ -8,12 +8,16 @@ import onnx import torch from benchmark_helper import Precision, prepare_environment, setup_logger -from llama_inputs import get_sample_inputs, get_sample_with_past_kv_inputs +from convert_generation import replace_mha_with_gqa +from llama_inputs import get_merged_sample_with_past_kv_inputs, get_sample_inputs, get_sample_with_past_kv_inputs from llama_parity import main as parity_check from onnx_model import OnnxModel +from optimizer import optimize_model +from packaging import version from transformers import LlamaConfig, LlamaForCausalLM from onnxruntime import quantization as ort_quantization +from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer logger = logging.getLogger("") @@ -58,6 +62,33 @@ def get_model_with_past_kv_dynamic_axes(input_names: List[str], output_names: Li return dynamic_axes +def get_merged_model_dynamic_axes(input_names: List[str], output_names: List[str]): + dynamic_axes = {} + for name in input_names + output_names: + if name in {"input_ids", "position_ids"}: + # shape is (batch_size, sequence_length) + dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"} + elif name == "attention_mask": + # shape is (batch_size, past_sequence_length + sequence_length) = (batch_size, total_sequence_length) + # for prompt generation, past_sequence_length = 0 + # for token generation, sequence_length = 1 + dynamic_axes[name] = {0: "batch_size", 1: "total_sequence_length"} + elif "past" in name: + # shape is (batch_size, num_heads, past_sequence_length, head_size) + dynamic_axes[name] = {0: "batch_size", 2: "past_sequence_length"} + elif name == "logits": + # shape is (batch_size, sequence_length, vocab_size) + dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"} + elif "present" in name: + # shape is (batch_size, num_heads, past_sequence_length + sequence_length, head_size) = (batch_size, num_heads, total_sequence_length, head_size) + # for prompt generation, past_sequence_length = 0 + # for token generation, sequence_length = 1 + dynamic_axes[name] = {0: "batch_size", 2: "total_sequence_length"} + else: + raise Exception("Unknown input or output name found") + return dynamic_axes + + def save_onnx_model(onnx_model: onnx.ModelProto, output_path: str, data_path: str): onnx.save( onnx_model, @@ -152,7 +183,7 @@ def run_dynamo_export(args: argparse.Namespace, l_config: LlamaConfig, llama: Ll logger.info(f"The {args.model_name} ONNX model has been successfully created with the Dynamo exporter!") -def run_torchscript_export(args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM): +def run_torchscript_separate_export(args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM): # Dummy values for export batch_size, sequence_length = 2, 8 device = torch.device("cpu") @@ -248,12 +279,206 @@ def run_torchscript_export(args: argparse.Namespace, l_config: LlamaConfig, llam logger.info(f"The {args.model_name} ONNX model has been successfully created with the TorchScript exporter!") +def run_torchscript_merged_export(args: argparse.Namespace, l_config: LlamaConfig, llama: LlamaForCausalLM): + # Dummy values for export + batch_size, sequence_length, past_sequence_length = 2, 8, 0 + device = torch.device("cpu") + + # Export decoder_merged_model.onnx + decoder_merged_inputs = get_merged_sample_with_past_kv_inputs( + l_config, device, batch_size, sequence_length, past_sequence_length + ) + input_names = [ + "input_ids", + "attention_mask", + "position_ids", + *list( + chain.from_iterable( + (f"past_key_values.{i}.key", f"past_key_values.{i}.value") for i in range(l_config.num_hidden_layers) + ) + ), + ] + output_names = [ + "logits", + *list( + chain.from_iterable((f"present.{i}.key", f"present.{i}.value") for i in range(l_config.num_hidden_layers)) + ), + ] + dynamic_axes = get_merged_model_dynamic_axes(input_names, output_names) + temp_dir = tempfile.TemporaryDirectory() + temp_path = os.path.join(temp_dir.name, "temp.onnx") + torch.onnx.export( + llama, + args=decoder_merged_inputs, + f=temp_path, + export_params=True, + input_names=input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + opset_version=13, + do_constant_folding=True, + verbose=args.verbose, + ) + + # Check decoder_merged_model.onnx and save all external data to one file + onnx.checker.check_model(temp_path) + onnx.shape_inference.infer_shapes_path(temp_path) + + output_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_fp32.onnx") + onnx_model = onnx.load_model(temp_path, load_external_data=True) + save_onnx_model( + onnx_model, + output_path, + f"{args.model_name}_decoder_merged_model_fp32.onnx.data", + ) + del onnx_model + temp_dir.cleanup() + + logger.info(f"The {args.model_name} ONNX model has been successfully created with the TorchScript exporter!") + + +# Optimize the model as FP32 +def optimize_export(config: LlamaConfig, input_path: str, output_path: str): + from fusion_options import FusionOptions + + optimization_options = FusionOptions("gpt2") + + model_opt = optimize_model( + input_path, + model_type="gpt2", + num_heads=config.num_attention_heads, + hidden_size=config.hidden_size, + opt_level=0, + optimization_options=optimization_options, + only_onnxruntime=False, + ) + model_opt.save_model_to_file(output_path, use_external_data_format=True) + logger.info(f"The ONNX model at {input_path} has been successfully optimized and saved at {output_path}!") + remove_existing_model(input_path) + + +def convert_to_float16(args: argparse.Namespace, config: LlamaConfig, old_paths: List[str]): + decoder_model_fp16_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp16.onnx") + decoder_with_past_model_fp16_path = os.path.join( + args.output, f"{args.model_name}_decoder_with_past_model_fp16.onnx" + ) + decoder_merged_model_fp16_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_fp16.onnx") + new_paths = [decoder_model_fp16_path, decoder_with_past_model_fp16_path, decoder_merged_model_fp16_path] + + logger.info("Converting to float16...") + for fp32_path, fp16_path in zip(old_paths, new_paths): + if os.path.exists(fp32_path): + model = OnnxModel(onnx.load_model(fp32_path, load_external_data=True)) + model.convert_float_to_float16(keep_io_types=False) + model = use_group_query_attention(config, model) + model.save_model_to_file(fp16_path, use_external_data_format=True) + del model + logger.info(f"The ONNX model at {fp32_path} has been converted to float16 and saved at {fp16_path}!") + remove_existing_model(fp32_path) + + logger.info(f"The {args.model_name} ONNX model has been successfully converted to float16!") + return new_paths + + +def use_group_query_attention(config: LlamaConfig, fp16_model_opt: OnnxModel): + # Replace MultiHeadAttention with GroupQueryAttention and remove attention mask nodes + fp16_model_opt = replace_mha_with_gqa(fp16_model_opt, "past_sequence_length", config.num_key_value_heads) + fp16_model_opt.prune_graph() + fp16_model_opt.update_graph(allow_remove_graph_inputs=True) + return fp16_model_opt + + +def smooth_quant( + args: argparse.Namespace, + decoder_model_fp32_path: str, + decoder_with_past_model_fp32_path: str, + decoder_model_int8_path: str, + decoder_with_past_model_int8_path: str, +): + from neural_compressor import PostTrainingQuantConfig + from neural_compressor import quantization as intel_quantization + from neural_compressor import set_workspace + from onnx.external_data_helper import load_external_data_for_model + from quant_kv_dataloader import QuantKVDataLoader + + set_workspace(args.nc_workspace) + quantization_config = PostTrainingQuantConfig( + calibration_sampling_size=[args.calibration_sampling_size], + recipes={ + "optypes_to_exclude_output_quant": ["MatMul"], + "smooth_quant": args.smooth_quant, + "smooth_quant_args": {"alpha": args.smooth_quant_alpha}, + }, + op_type_dict={ + "^((?!(MatMul|Gather|Conv)).)*$": { + "weight": {"dtype": ["fp32"]}, + "activation": {"dtype": ["fp32"]}, + } + }, + ) + + # Convert decoder_model.onnx to INT8 + decoder_model_int8 = intel_quantization.fit( + decoder_model_fp32_path, + quantization_config, + calib_dataloader=QuantKVDataLoader(args), + ) + load_external_data_for_model( + decoder_model_int8._model, + os.path.split(decoder_model_int8._model_path)[0], + ) + save_onnx_model( + decoder_model_int8._model, + decoder_model_int8_path, + f"{args.model_name}_decoder_model_int8.onnx.data", + ) + del decoder_model_int8 + logger.info( + f"The ONNX model at {decoder_model_fp32_path} has been quantized to int8 and saved at {decoder_model_int8_path}!" + ) + remove_existing_model(decoder_model_fp32_path) + + # Convert decoder_with_past_model.onnx to INT8 + decoder_with_past_model_int8 = intel_quantization.fit( + decoder_with_past_model_fp32_path, + quantization_config, + calib_dataloader=QuantKVDataLoader(args, onnx_model_path=decoder_model_fp32_path), + ) + load_external_data_for_model( + decoder_with_past_model_int8._model, + os.path.split(decoder_with_past_model_int8._model_path)[0], + ) + save_onnx_model( + decoder_with_past_model_int8._model, + decoder_with_past_model_int8_path, + f"{args.model_name}_decoder_with_past_model_int8.onnx.data", + ) + del decoder_with_past_model_int8 + logger.info( + f"The ONNX model at {decoder_with_past_model_fp32_path} has been quantized to int8 and saved at {decoder_with_past_model_int8_path}!" + ) + remove_existing_model(decoder_with_past_model_fp32_path) + + logger.info(f"The {args.model_name} ONNX model has been successfully quantized to int8!") + + logger.info(f"Removing {args.nc_workspace}") + os.system(f"rm -R {args.nc_workspace}") + + +def remove_existing_model(model_path: str): + # Remove ONNX model and its external data + data_path = os.path.join(model_path + ".data") + os.remove(model_path) + os.remove(data_path) + logger.warning(f"Removed {model_path} and {data_path}") + + def remove_existing_files(output_path: str): for filename in os.listdir(output_path): filepath = os.path.join(output_path, filename) if ".onnx" in filename or ".onnx.data" in filename: os.remove(filepath) - logger.warning(f"Removing {filepath}") + logger.warning(f"Removed {filepath}") def get_args(): @@ -288,7 +513,7 @@ def get_args(): required=False, type=Precision, default=Precision.FLOAT32, - choices=[Precision.FLOAT32, Precision.FLOAT16, Precision.INT8], + choices=[Precision.FLOAT32, Precision.FLOAT16, Precision.INT8, Precision.INT4], help="Precision to export model in", ) @@ -301,15 +526,51 @@ def get_args(): help="Execution provider to verify parity with", ) + parser.add_argument( + "-id", + "--device-id", + required=False, + type=str, + default="0", + help="Device ID for GPUs", + ) + + parser.add_argument( + "-r", + "--reexport", + required=False, + action="store_true", + help="Re-export models and overwrite existing models in output folder", + ) + parser.set_defaults(reexport=False) + + parser.add_argument( + "--no_merged", + required=False, + action="store_true", + help="Export models into 2 ONNX files instead of 1. Deprecated in favor of exporting into 1 ONNX file.", + ) + parser.set_defaults(no_merged=False) + parser.add_argument( "-q", "--quantization_method", default="", - choices=["smooth_quant", "quantize_dynamic"], - help="Run a specific quantization algorithm. Need to install extra packages in `requirements-quant.txt` for SmoothQuant.", + choices=["blockwise", "smooth_quant", "quantize_dynamic"], + help="Run a specific quantization algorithm (blockwise for int4, smooth_quant for int8, quantize_dynamic for int8). Blockwise is recommended. Need to install extra packages in `requirements-quant.txt` for SmoothQuant.", ) - smooth_quant_group = parser.add_argument_group("smooth_quant") + blockwise_group = parser.add_argument_group("4-bit quantization") + + blockwise_group.add_argument( + "--block_size", + required=False, + default=32, + type=int, + help="Block size to quantize with. See https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py for details.", + ) + + smooth_quant_group = parser.add_argument_group("smooth_quant (8-bit quantization)") smooth_quant_group.add_argument( "--smooth_quant_alpha", @@ -352,7 +613,7 @@ def get_args(): help="Workspace to save intermediate files generated by Intel's Neural Compressor package.", ) - quantize_dynamic_group = parser.add_argument_group("quantize_dynamic") + quantize_dynamic_group = parser.add_argument_group("quantize_dynamic (8-bit quantization)") quantize_dynamic_group.add_argument( "--quantize_embedding_layer", @@ -399,177 +660,193 @@ def get_args(): def main(): + if version.parse(torch.__version__) < version.parse("2.2.0") and "2.2.0.dev" not in torch.__version__: + # Second predicate is for comparing nightly (ex: 2.2.0.dev20230920 vs 2.2.0) since first predicate is false + # in that scenario. It can be removed when torch v2.2.0 is released in stable. + logger.error(f"Detected PyTorch version {torch.__version__}. Please upgrade and use v2.2.0 or newer.") + return + args = get_args() setup_logger(args.verbose) prepare_environment(args.input, args.output, args.execution_provider != "cpu") - remove_existing_files(args.output) + if args.reexport: + remove_existing_files(args.output) logger.info(f"Arguments: {args}") # Load model and config use_auth_token = args.input == os.path.join(".") setattr(args, "use_auth_token", use_auth_token) # noqa: B010 - l_config = LlamaConfig.from_pretrained( - args.model_name if use_auth_token else args.input, use_auth_token=use_auth_token - ) - llama = LlamaForCausalLM.from_pretrained( - args.model_name if use_auth_token else args.input, use_auth_token=use_auth_token, use_cache=True - ) + + location = args.model_name if use_auth_token else args.input + l_config = LlamaConfig.from_pretrained(location, use_auth_token=use_auth_token) + llama = LlamaForCausalLM.from_pretrained(location, use_auth_token=use_auth_token, use_cache=True) original_model_name = args.model_name setattr(args, "original_model_name", original_model_name) # noqa: B010 args.model_name = args.model_name.split("/")[-1] - # Export to ONNX - if args.use_dynamo_export: - logger.warning("Please ensure you have installed PyTorch, ONNX, and ONNX Script as follows.") - logger.warning("Step 1 - PyTorch nightly: https://pytorch.org/get-started/locally/") - logger.warning("Step 2 - ONNX weekly: https://pypi.org/project/onnx-weekly/") - logger.warning( - "Step 3 - ONNX Script from source: https://github.com/microsoft/onnxscript#installing-onnx-script" - ) - logger.warning( - "Note: After you install ONNX weekly, omit `onnx` when running the first line for installing ONNX Script. This is because you already installed `onnx-weekly` in the previous step." - ) - run_dynamo_export(args, l_config, llama) - else: - run_torchscript_export(args, l_config, llama) - - # Change precision of exported models if not FP32 + # Set model paths for FP32 model decoder_model_fp32_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32.onnx") decoder_with_past_model_fp32_path = os.path.join( args.output, f"{args.model_name}_decoder_with_past_model_fp32.onnx" ) + decoder_merged_model_fp32_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_fp32.onnx") + old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path] + + missing_separate_exports = ( + args.no_merged + and not os.path.exists(decoder_model_fp32_path) + and not os.path.exists(decoder_with_past_model_fp32_path) + ) + missing_merged_export = not args.no_merged and not os.path.exists(decoder_merged_model_fp32_path) + + # Export to ONNX + if missing_separate_exports or missing_merged_export: + if args.use_dynamo_export and missing_separate_exports: + logger.warning("Please ensure you have installed PyTorch, ONNX, and ONNX Script as follows.") + logger.warning("Step 1 - PyTorch nightly: https://pytorch.org/get-started/locally/") + logger.warning("Step 2 - ONNX weekly: https://pypi.org/project/onnx-weekly/") + logger.warning( + "Step 3 - ONNX Script from source: https://github.com/microsoft/onnxscript#installing-onnx-script" + ) + logger.warning( + "Note: After you install ONNX weekly, omit `onnx` when running the first line for installing ONNX Script. This is because you already installed `onnx-weekly` in the previous step." + ) + run_dynamo_export(args, l_config, llama) + elif args.no_merged: + run_torchscript_separate_export(args, l_config, llama) + else: + run_torchscript_merged_export(args, l_config, llama) + + # Set model paths to store FP32 optimized model + decoder_model_fp32_opt_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp32_opt.onnx") + decoder_with_past_model_fp32_opt_path = os.path.join( + args.output, f"{args.model_name}_decoder_with_past_model_fp32_opt.onnx" + ) + decoder_merged_model_fp32_opt_path = os.path.join( + args.output, f"{args.model_name}_decoder_merged_model_fp32_opt.onnx" + ) + new_paths = [decoder_model_fp32_opt_path, decoder_with_past_model_fp32_opt_path, decoder_merged_model_fp32_opt_path] + + # Run the optimizer script + logger.info("Optimizing models...") + for orig_path, opt_path in zip(old_paths, new_paths): + if os.path.exists(orig_path): + optimize_export(l_config, input_path=orig_path, output_path=opt_path) + + # Re-assign default FP32 model paths as their optimized versions + decoder_model_fp32_path = decoder_model_fp32_opt_path + decoder_with_past_model_fp32_path = decoder_with_past_model_fp32_opt_path + decoder_merged_model_fp32_path = decoder_merged_model_fp32_opt_path + old_paths = [decoder_model_fp32_path, decoder_with_past_model_fp32_path, decoder_merged_model_fp32_path] + + logger.info( + f"The {args.model_name} ONNX model has been successfully optimized with the ORT transformer optimizer script!" + ) + # Change precision of exported models from FP32 if args.precision == Precision.FLOAT16: - # Convert decoder_model.onnx to FP16 - decoder_model_fp16_path = os.path.join(args.output, f"{args.model_name}_decoder_model_fp16.onnx") - model = OnnxModel(onnx.load_model(decoder_model_fp32_path, load_external_data=True)) - model.convert_float_to_float16(keep_io_types=False, op_block_list=["If"]) - model.save_model_to_file(decoder_model_fp16_path, use_external_data_format=True, all_tensors_to_one_file=True) - del model - - # Convert decoder_with_past_model.onnx to FP16 - decoder_with_past_model_fp16_path = os.path.join( - args.output, f"{args.model_name}_decoder_with_past_model_fp16.onnx" - ) - model = OnnxModel(onnx.load_model(decoder_with_past_model_fp32_path, load_external_data=True)) - model.convert_float_to_float16(keep_io_types=False, op_block_list=["If"]) - model.save_model_to_file( - decoder_with_past_model_fp16_path, use_external_data_format=True, all_tensors_to_one_file=True - ) - del model + new_paths = convert_to_float16(args, l_config, old_paths) elif args.precision == Precision.INT8: decoder_model_int8_path = os.path.join(args.output, f"{args.model_name}_decoder_model_int8.onnx") decoder_with_past_model_int8_path = os.path.join( args.output, f"{args.model_name}_decoder_with_past_model_int8.onnx" ) + decoder_merged_model_int8_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_int8.onnx") + new_paths = [decoder_model_int8_path, decoder_with_past_model_int8_path, decoder_merged_model_int8_path] if args.quantization_method == "smooth_quant": - from neural_compressor import PostTrainingQuantConfig - from neural_compressor import quantization as intel_quantization - from neural_compressor import set_workspace - from onnx.external_data_helper import load_external_data_for_model - from quant_kv_dataloader import QuantKVDataLoader - - set_workspace(args.nc_workspace) - quantization_config = PostTrainingQuantConfig( - calibration_sampling_size=[args.calibration_sampling_size], - recipes={ - "optypes_to_exclude_output_quant": ["MatMul"], - "smooth_quant": args.smooth_quant, - "smooth_quant_args": {"alpha": args.smooth_quant_alpha}, - }, - op_type_dict={ - "^((?!(MatMul|Gather|Conv)).)*$": { - "weight": {"dtype": ["fp32"]}, - "activation": {"dtype": ["fp32"]}, - } - }, - ) - - # Convert decoder_model.onnx to INT8 - decoder_model_int8 = intel_quantization.fit( - decoder_model_fp32_path, - quantization_config, - calib_dataloader=QuantKVDataLoader(args), - ) - load_external_data_for_model( - decoder_model_int8._model, - os.path.split(decoder_model_int8._model_path)[0], - ) - save_onnx_model( - decoder_model_int8._model, - decoder_model_int8_path, - f"{args.model_name}_decoder_model_int8.onnx.data", - ) - del decoder_model_int8 - - # Convert decoder_with_past_model.onnx to INT8 - decoder_with_past_model_int8 = intel_quantization.fit( - decoder_with_past_model_fp32_path, - quantization_config, - calib_dataloader=QuantKVDataLoader(args, onnx_model_path=decoder_model_fp32_path), - ) - load_external_data_for_model( - decoder_with_past_model_int8._model, - os.path.split(decoder_with_past_model_int8._model_path)[0], - ) - save_onnx_model( - decoder_with_past_model_int8._model, - decoder_with_past_model_int8_path, - f"{args.model_name}_decoder_with_past_model_int8.onnx.data", - ) - del decoder_with_past_model_int8 - - logger.info(f"Removing {args.nc_workspace}") - os.system(f"rm -R {args.nc_workspace}") + if not args.no_merged: + logger.error("SmoothQuant must be used on separately exported models") + else: + logger.info(f"Quantizing {decoder_model_fp32_path} and {decoder_with_past_model_fp32_path} to int8") + smooth_quant(args, old_paths[0], old_paths[1], new_paths[0], new_paths[1]) elif args.quantization_method == "quantize_dynamic": logger.warning( "The `quantize_dynamic` method is deprecated in favor of `smooth_quant` instead. Precision loss may be high with `quantize_dynamic`." ) - # Convert decoder_model.onnx to INT8 - ort_quantization.quantize_dynamic( - decoder_model_fp32_path, - decoder_model_int8_path, - op_types_to_quantize=["MatMul", "Gemm", "Gather"] - if args.quantize_embedding_layer - else ["MatMul", "Gemm"], - per_channel=args.quantize_per_channel, - reduce_range=args.quantize_reduce_range, - use_external_data_format=True, - extra_options={"MatMulConstBOnly": True}, - ) - - # Convert decoder_with_past_model.onnx to INT8 - ort_quantization.quantize_dynamic( - decoder_with_past_model_fp32_path, - decoder_with_past_model_int8_path, - op_types_to_quantize=["MatMul", "Gemm", "Gather"] - if args.quantize_embedding_layer - else ["MatMul", "Gemm"], - per_channel=args.quantize_per_channel, - reduce_range=args.quantize_reduce_range, - use_external_data_format=True, - extra_options={"MatMulConstBOnly": True}, - ) + logger.info("Quantizing to int8...") + for fp32_path, int8_path in zip(old_paths, new_paths): + if os.path.exists(fp32_path): + ort_quantization.quantize_dynamic( + fp32_path, + int8_path, + op_types_to_quantize=["MatMul", "Gemm", "Gather"] + if args.quantize_embedding_layer + else ["MatMul", "Gemm"], + per_channel=args.quantize_per_channel, + reduce_range=args.quantize_reduce_range, + use_external_data_format=True, + extra_options={"MatMulConstBOnly": True}, + ) + logger.info(f"The ONNX model at {fp32_path} has been quantized to int8 and saved at {int8_path}!") + remove_existing_model(decoder_model_fp32_path) + + logger.info(f"The {args.model_name} ONNX model has been successfully quantized to int8!") else: raise Exception(f"Could not recognize {args.quantization_method} as a quantization method") - # Verify parity on all saved ONNX models + elif args.precision == Precision.INT4: + if args.execution_provider != "cpu": + old_paths = convert_to_float16(args, l_config, old_paths) + + decoder_model_int4_path = os.path.join(args.output, f"{args.model_name}_decoder_model_int4.onnx") + decoder_with_past_model_int4_path = os.path.join( + args.output, f"{args.model_name}_decoder_with_past_model_int4.onnx" + ) + decoder_merged_model_int4_path = os.path.join(args.output, f"{args.model_name}_decoder_merged_model_int4.onnx") + new_paths = [decoder_model_int4_path, decoder_with_past_model_int4_path, decoder_merged_model_int4_path] + + for fp_path, int4_path in zip(old_paths, new_paths): + if os.path.exists(fp_path): + model = onnx.load_model(fp_path, load_external_data=True) + quant = MatMul4BitsQuantizer(model, args.block_size, is_symmetric=True, nodes_to_exclude=[]) + quant.process() + quant.model.save_model_to_file(int4_path, use_external_data_format=True) + del model + del quant + logger.info(f"The ONNX model at {fp_path} has been quantized to int4 and saved at {int4_path}!") + remove_existing_model(fp_path) + del llama # Delete LLaMA model from memory since it will be loaded again during parity check logger.info("Verifying parity on all ONNX models created") + + # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models + args.precision = ( + "fp32" + if args.precision in {"int8", "fp32"} or (args.precision == Precision.INT4 and args.execution_provider == "cpu") + else "fp16" + ) + + # Verify parity on all saved ONNX models for filename in os.listdir(args.output): if ".data" in filename or ".onnx" not in filename: continue - precision = filename[filename.rfind("_") + 1 : filename.find(".onnx")] - parity_cmd = ["-m", f"{original_model_name}", "-o", f"{os.path.join(args.output, filename)}", "-fp", precision] + parity_cmd = [ + "-m", + original_model_name, + "-o", + os.path.join(args.output, filename), + "-ep", + args.execution_provider, + "-id", + args.device_id, + "-fp", + args.precision, + ] if "with_past" in filename: parity_cmd.append("--use_past_kv") - parity_check(parity_cmd) + if "merged" in filename: + parity_cmd.append("--merged") + + try: + parity_check(parity_cmd) + except Exception as e: + logger.warning(f"An error occurred while verifying parity: {e}", exc_info=True) if __name__ == "__main__": diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py index 6a28498a9ffc9..2652e9f0ca64e 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_inputs.py @@ -4,10 +4,13 @@ import torch from transformers import LlamaConfig +from onnxruntime import OrtValue + # Get position_ids from attention_mask def get_position_ids(attention_mask: torch.Tensor, use_past_kv: bool): position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) if use_past_kv: position_ids = position_ids[:, -1].unsqueeze(-1) return position_ids @@ -62,11 +65,41 @@ def get_sample_with_past_kv_inputs( return inputs +# Inputs for all passes with past_key_values +def get_merged_sample_with_past_kv_inputs( + config: LlamaConfig, + device: torch.device, + batch_size: int, + seq_len: int, + past_seq_len: int, + use_fp16: bool = False, + return_dict: bool = False, +): + input_ids = torch.randint( + low=0, high=config.vocab_size, size=(batch_size, seq_len), device=device, dtype=torch.int64 + ) + attention_mask = torch.ones(batch_size, past_seq_len + seq_len, device=device, dtype=torch.int64) + # position_ids is of shape (batch_size, seq_len) for prompt generation, (batch_size, 1) for token generation + position_ids = get_position_ids(attention_mask, use_past_kv=(past_seq_len != 0)) + past_kv = get_sample_past_kv_inputs(config, device, batch_size, past_seq_len, use_fp16) + + if not return_dict: + return (input_ids, attention_mask, position_ids, past_kv) + + inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "position_ids": position_ids, + "past_key_values": past_kv, + } + return inputs + + # Create past_key_values def get_sample_past_kv_inputs( config: LlamaConfig, device: torch.device, batch_size: int, past_seq_len: int, use_fp16: bool ): - num_heads, head_size = config.num_attention_heads, config.hidden_size // config.num_attention_heads + num_heads, head_size = config.num_key_value_heads, config.hidden_size // config.num_key_value_heads torch_dtype = torch.float16 if use_fp16 else torch.float32 past_kv = [ ( @@ -89,31 +122,83 @@ def flatten_past_kv_inputs(past_key_values: List[Tuple[torch.Tensor, torch.Tenso # Format PyTorch inputs to ONNX Runtime inputs -def convert_inputs_for_ort(pt_inputs: dict, use_fp16: bool): +def convert_inputs_for_ort( + pt_inputs: dict, + use_fp16: bool, + use_buffer_share: bool = False, + past_seq_len: int = 0, + max_seq_len: int = 2048, + device: str = "", + device_id: int = -1, +): ort_inputs = {} for k, v in pt_inputs.items(): - if k == "past_key_values": + if isinstance(v, np.ndarray): + ort_inputs[k] = v + elif k == "past_key_values": ort_inputs.update(flatten_past_kv_inputs(v, use_fp16)) + elif k == "attention_mask" and use_fp16 and use_buffer_share: + # Skip because FP16 model has GroupQueryAttention, uses buffer sharing, + # and GQA supports a causal mask by default + + # Instead, add the past sequence length input for GQA + ort_inputs["past_sequence_length"] = np.array([past_seq_len], dtype=np.int64) else: ort_inputs[k] = v.detach().cpu().numpy() + + # Enable past-present-share-buffer by using device memory directly + if use_buffer_share and device != "" and device != "cpu" and device_id > -1: + for k, v in ort_inputs.items(): + new_v = v + # Allocate new buffers with max_sequence_length for GQA + if "cache" in k or "past_key_values" in k: + # Copy v (BxSxPxH) into new_v (BxSxMxH) + batch_size, num_heads, _, head_size = v.shape + new_v = np.zeros((batch_size, num_heads, max_seq_len, head_size), dtype=v.dtype) + new_v[:batch_size, :num_heads, :past_seq_len, :head_size] = v + ort_inputs[k] = OrtValue.ortvalue_from_numpy(new_v, device_type=device, device_id=device_id) + return ort_inputs # Inputs for Microsoft export from https://github.com/microsoft/Llama-2-Onnx -def get_msft_sample_inputs(config: LlamaConfig, batch_size: int, past_seq_len: int, seq_len: int, use_fp16: bool): +def get_msft_sample_inputs( + config: LlamaConfig, batch_size: int, past_seq_len: int, seq_len: int, use_fp16: bool, split_kv: bool +): np_dtype = np.float16 if use_fp16 else np.float32 head_size = config.hidden_size // config.num_attention_heads max_seq_len = 2048 - ort_inputs = { - "x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype), - "attn_mask": (-10000.0 * np.triu(np.ones((batch_size, max_seq_len, max_seq_len)), k=1)).astype(np_dtype), - "k_cache": np.random.rand( - batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size - ).astype(np_dtype), - "v_cache": np.random.rand( - batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size - ).astype(np_dtype), - "pos": np.array(past_seq_len, dtype=np.int64), - } + if not split_kv: + ort_inputs = { + "x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype), + "attn_mask": (-10000.0 * np.triu(np.ones((batch_size, max_seq_len, max_seq_len)), k=1)).astype(np_dtype), + "k_cache": np.random.rand( + batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size + ).astype(np_dtype), + "v_cache": np.random.rand( + batch_size, config.num_hidden_layers, past_seq_len, config.num_attention_heads, head_size + ).astype(np_dtype), + "pos": np.array(past_seq_len, dtype=np.int64), + } + else: + ort_inputs = { + "x": np.random.rand(batch_size, seq_len, config.hidden_size).astype(np_dtype), + "attn_mask": (np.triu(np.ones((batch_size, max_seq_len, max_seq_len), dtype=np.int32), k=1) - 1).astype( + np.int32 + ), + "pos": np.array(past_seq_len, dtype=np.int64), + } + for i in range(config.num_hidden_layers): + ort_inputs.update( + { + f"k_{i}_cache": np.random.rand( + batch_size, config.num_attention_heads, past_seq_len, head_size + ).astype(np_dtype), + f"v_{i}_cache": np.random.rand( + batch_size, config.num_attention_heads, past_seq_len, head_size + ).astype(np_dtype), + } + ) + return ort_inputs diff --git a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py index dadf394440c9a..6bfcb9b4f290d 100644 --- a/onnxruntime/python/tools/transformers/models/llama/llama_parity.py +++ b/onnxruntime/python/tools/transformers/models/llama/llama_parity.py @@ -1,44 +1,143 @@ import argparse import logging import os +import time from typing import List import numpy as np import torch -from benchmark_helper import create_onnxruntime_session, setup_logger -from llama_inputs import convert_inputs_for_ort, get_sample_inputs, get_sample_with_past_kv_inputs +from benchmark_helper import setup_logger +from llama_inputs import ( + convert_inputs_for_ort, + get_merged_sample_with_past_kv_inputs, + get_sample_inputs, + get_sample_with_past_kv_inputs, +) from transformers import LlamaConfig, LlamaForCausalLM +import onnxruntime as ort + logger = logging.getLogger("") -def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: LlamaForCausalLM): +def get_sequence_lengths(args: argparse.Namespace): + past_sequence_length, curr_sequence_length = (8, 1) if args.use_past_kv else (0, 8) + max_sequence_length = 2048 + return past_sequence_length, curr_sequence_length, max_sequence_length + + +def get_inputs(args: argparse.Namespace, config: LlamaConfig): # Dummy values for parity - batch_size, sequence_length = 2, 8 - device = torch.device("cpu") + batch_size = 2 + past_sequence_length, sequence_length, _ = get_sequence_lengths(args) - # Run inference with PyTorch - inputs = ( - get_sample_inputs(config, device, batch_size, sequence_length, return_dict=True) - if not args.use_past_kv - else get_sample_with_past_kv_inputs( - config, device, batch_size, sequence_length, use_fp16=(args.precision == "fp16"), return_dict=True + if args.merged: + inputs = get_merged_sample_with_past_kv_inputs( + config, + args.device, + batch_size, + sequence_length, + past_sequence_length, + use_fp16=args.use_fp16, + return_dict=True, ) - ) + elif args.use_past_kv: + inputs = get_sample_with_past_kv_inputs( + config, args.device, batch_size, sequence_length, use_fp16=args.use_fp16, return_dict=True + ) + else: + inputs = get_sample_inputs(config, args.device, batch_size, sequence_length, return_dict=True) + + return inputs + + +def add_io_bindings(args: argparse.Namespace, model: ort.InferenceSession, inputs: dict): + # Add IO bindings for non-CPU execution providers + io_binding = model.io_binding() + + for k, v in inputs.items(): + if args.use_fp16: + # Bind all OrtValue inputs to device + io_binding.bind_ortvalue_input(k, v) + else: + io_binding.bind_cpu_input(k, v) + + for output in model.get_outputs(): + name = output.name + if args.use_fp16 and ("out" in name or "present" in name): + # Bind present KV cache outputs to OrtValue with buffer sharing + io_binding.bind_ortvalue_output( + name, inputs[name.replace("out", "cache").replace("present", "past_key_values")] + ) + else: + io_binding.bind_output(name, device_type=args.execution_provider, device_id=int(args.device_id)) + + return io_binding + + +def verify_parity(args: argparse.Namespace, config: LlamaConfig, pt_model: LlamaForCausalLM): + inputs = get_inputs(args, config) + + # Run inference with PyTorch + if args.execution_provider != "cpu": + torch.cuda.synchronize() + start_time = time.time() pt_outputs = pt_model(**inputs).logits.detach().cpu().numpy() + if args.execution_provider != "cpu": + torch.cuda.synchronize() + end_time = time.time() + logger.info(f"PyTorch took {end_time - start_time} s") # Run inference with ORT - inputs = convert_inputs_for_ort(inputs, use_fp16=(args.precision == "fp16")) - ort_model = create_onnxruntime_session( + past_sequence_length, _, max_sequence_length = get_sequence_lengths(args) + inputs = convert_inputs_for_ort( + inputs, + use_fp16=args.use_fp16, + use_buffer_share=args.use_fp16, + past_seq_len=past_sequence_length, + max_seq_len=max_sequence_length, + device=args.execution_provider, + device_id=int(args.device_id), + ) + + ep = f"{args.execution_provider.upper()}ExecutionProvider" + if ep == "CUDAExecutionProvider": + ep = (ep, {"device_id": args.device_id}) + ort_model = ort.InferenceSession( args.onnx_model_path, - args.execution_provider != "cpu", # use_gpu - provider=args.execution_provider, - verbose=args.verbose, + sess_options=ort.SessionOptions(), + providers=[ep], ) - ort_outputs = ort_model.run(None, inputs)[0] + + # Add IO bindings for non-CPU execution providers + if args.execution_provider != "cpu": + io_binding = add_io_bindings(args, ort_model, inputs) + + torch.cuda.synchronize() + start_time = time.time() + ort_model.run_with_iobinding(io_binding) + torch.cuda.synchronize() + end_time = time.time() + + ort_outputs = io_binding.copy_outputs_to_cpu()[0] # Get logits + + else: + start_time = time.time() + ort_outputs = ort_model.run(None, inputs) + end_time = time.time() + + ort_outputs = ort_outputs[0] # Get logits + + logger.info(f"ONNX Runtime took {end_time - start_time} s") # Compare PyTorch and ONNX Runtime accuracy - tol = 1e-3 if args.precision == "fp32" else 1e-2 if args.precision == "fp16" else 1e2 + tol = ( + 2e1 + if "int4" in args.onnx_model_path or "int8" in args.onnx_model_path + else 1e-3 + if args.precision == "fp32" + else 5e-1 + ) parity = np.allclose(pt_outputs, ort_outputs, rtol=tol, atol=tol) logger.warning(f"Are PyTorch and ONNX Runtime results close? {parity}") if not parity: @@ -80,6 +179,15 @@ def get_args(argv: List[str]): help="Execution provider to verify parity with", ) + parser.add_argument( + "-id", + "--device-id", + required=False, + type=str, + default="0", + help="Device ID for GPUs", + ) + parser.add_argument( "-v", "--verbose", @@ -96,15 +204,29 @@ def get_args(argv: List[str]): ) parser.set_defaults(use_past_kv=False) + parser.add_argument( + "--merged", + action="store_true", + help="Use merged model (i.e. decoder_merged_model.onnx).", + ) + parser.set_defaults(merged=False) + parser.add_argument( "-fp", "--precision", required=True, - choices=["int8", "fp16", "fp32"], + choices=["int4", "int8", "fp16", "fp32"], help="Precision of model", ) args = parser.parse_args() if argv == [] else parser.parse_args(argv) + + # Use FP32 precision for FP32, INT8, INT4 CPU models, use FP16 precision for FP16 and INT4 GPU models + args.precision = ( + "fp32" + if args.precision in {"int8", "fp32"} or (args.precision == "int4" and args.execution_provider == "cpu") + else "fp16" + ) return args @@ -114,19 +236,34 @@ def main(argv: List[str] = []): # noqa: B006 logger.info(f"Arguments: {args}") # Load model and config + setattr(args, "use_fp16", args.precision == "fp16") # noqa: B010 + setattr(args, "device_name", "cpu" if args.execution_provider == "cpu" else f"cuda:{args.device_id}") # noqa: B010 + setattr(args, "device", torch.device(args.device_name)) # noqa: B010 use_auth_token = args.torch_model_directory == os.path.join(".") location = args.model_name if use_auth_token else args.torch_model_directory config = LlamaConfig.from_pretrained(location, use_auth_token=use_auth_token) llama = LlamaForCausalLM.from_pretrained( location, - torch_dtype=(torch.float16 if args.precision == "fp16" else torch.float32), + torch_dtype=(torch.float16 if args.use_fp16 else torch.float32), use_auth_token=use_auth_token, use_cache=True, - ) + ).to(args.device) + + if not args.merged: + verify_parity(args, config, llama) + else: + # Verify prompt generation in merged model (decoder_model.onnx) + args.use_past_kv = False + verify_parity(args, config, llama) - verify_parity(args, config, llama) + # Verify token generation in merged model (decoder_with_past_model.onnx) + args.use_past_kv = True + verify_parity(args, config, llama) if __name__ == "__main__": + seed = 2 + np.random.seed(seed) + torch.manual_seed(seed) main() diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements-cpu.txt b/onnxruntime/python/tools/transformers/models/llama/requirements-cpu.txt index e9ad937cf14e7..e06c3ada834b0 100644 --- a/onnxruntime/python/tools/transformers/models/llama/requirements-cpu.txt +++ b/onnxruntime/python/tools/transformers/models/llama/requirements-cpu.txt @@ -1,3 +1,2 @@ -r requirements.txt -torch>=2.0.1 -onnxruntime>=1.16.0 \ No newline at end of file +onnxruntime>=1.17.0 \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt b/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt index 5544abcaa1228..773680937bd21 100644 --- a/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt +++ b/onnxruntime/python/tools/transformers/models/llama/requirements-cuda.txt @@ -1,4 +1,4 @@ -r requirements.txt -# Please manually install torch>=2.0.1 with CUDA enabled for the CUDA version installed in your system. +# Please manually install torch>=2.2.0.dev20230920 with CUDA enabled for the CUDA version installed in your system. # Instructions can be found here: https://pytorch.org/get-started/locally/ -onnxruntime-gpu>=1.16.0 \ No newline at end of file +onnxruntime-gpu>=1.17.0 \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/llama/requirements.txt b/onnxruntime/python/tools/transformers/models/llama/requirements.txt index f843ef4dc5568..4210f36982aef 100644 --- a/onnxruntime/python/tools/transformers/models/llama/requirements.txt +++ b/onnxruntime/python/tools/transformers/models/llama/requirements.txt @@ -1,5 +1,6 @@ -git+https://github.com/kunal-vaishnavi/optimum.git@kvaishnavi/llama-add-position-ids -transformers>=4.28.1 +git+https://github.com/huggingface/optimum.git +transformers>=4.33.2 +torch>=2.2.0.dev20230920 onnx>=1.14.0 datasets>=2.8.0 protobuf==3.20.2 \ No newline at end of file diff --git a/onnxruntime/python/tools/transformers/models/whisper/README.md b/onnxruntime/python/tools/transformers/models/whisper/README.md index e9365becd2cd1..8ff5c8a6e1de0 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/README.md +++ b/onnxruntime/python/tools/transformers/models/whisper/README.md @@ -79,24 +79,22 @@ $ python3 -m onnxruntime.transformers.models.whisper.convert_to_onnx -m openai/w Here are some examples of how you can benchmark Whisper across various end-to-end (E2E) implementations. -Note: In the below examples, `PyTorch` refers to running in PyTorch without `torch.compile` and `PyTorch 2.0` refers to running in PyTorch with `torch.compile`. - ### Variants -1. PyTorch (without `torch.compile`), FP32 +1. PyTorch without `torch.compile`, FP32 ``` python3 -m models.whisper.benchmark \ - --benchmark-type hf-pt \ + --benchmark-type hf-pt-eager \ --audio-path 1272-141231-0002.mp3 \ --model-name openai/whisper-large-v2 \ --precision fp32 \ --device cpu ``` -2. PyTorch 2.0 (with `torch.compile`), FP16 +2. PyTorch with `torch.compile`, FP16 ``` python3 -m models.whisper.benchmark \ - --benchmark-type hf-pt2 \ + --benchmark-type hf-pt-compile \ --audio-path 1272-141231-0002.mp3 \ --model-name openai/whisper-large-v2 \ --precision fp16 \ @@ -109,7 +107,7 @@ python3 -m models.whisper.benchmark \ --benchmark-type hf-ort \ --audio-path 1272-141231-0002.mp3 \ --model-name openai/whisper-large-v2 \ - --hf-ort-model-path ./whisper-large-v2-onnx/ \ + --hf-ort-dir-path ./whisper-large-v2-onnx/ \ --precision fp32 \ --device cpu ``` @@ -156,7 +154,9 @@ You can use `benchmark_all.py` to benchmark across various platforms and automat ``` python3 -m models.whisper.benchmark_all \ --audio-path ./whisper-test-audios/ \ - --hf-ort-model-path ./whisper-large-v2-onnx/ \ + --hf-pt-eager \ + --hf-pt-compile \ + --hf-ort-dir-path ./whisper-large-v2-onnx/ \ --ort-model-path ./wlarge-fp32/whisper-large-v2_all.onnx \ --model-name openai/whisper-large-v2 \ --precision fp32 \ @@ -169,28 +169,28 @@ Here is a benchmark for an MP3 file with 20.7s of audio. #### FP16 -| Engine | Size | Per-Token Latency | Real-Time Factor | -| ------------- | -------- | ----------------- | ---------------- | -| PyTorch | Tiny | 4.697 ms/token | 0.004697 | -| PyTorch 2.0 | Tiny | 3.406 ms/token | 0.003406 | -| ONNX Runtime | Tiny | 0.746 ms/token | 0.000746 | -| PyTorch | Medium | 17.837 ms/token | 0.017387 | -| PyTorch 2.0 | Medium | 18.124 ms/token | 0.018124 | -| ONNX Runtime | Medium | 3.894 ms/token | 0.003894 | -| PyTorch | Large v2 | 23.470 ms/token | 0.023470 | -| PyTorch 2.0 | Large v2 | 23.146 ms/token | 0.023146 | -| ONNX Runtime | Large v2 | 6.262 ms/token | 0.006262 | +| Engine | Size | Per-Token Latency | Real-Time Factor | +| --------------- | -------- | ----------------- | ---------------- | +| PyTorch eager | Tiny | 4.697 ms/token | 0.004697 | +| PyTorch compile | Tiny | 3.406 ms/token | 0.003406 | +| ONNX Runtime | Tiny | 0.746 ms/token | 0.000746 | +| PyTorch eager | Medium | 17.837 ms/token | 0.017387 | +| PyTorch compile | Medium | 18.124 ms/token | 0.018124 | +| ONNX Runtime | Medium | 3.894 ms/token | 0.003894 | +| PyTorch eager | Large v2 | 23.470 ms/token | 0.023470 | +| PyTorch compile | Large v2 | 23.146 ms/token | 0.023146 | +| ONNX Runtime | Large v2 | 6.262 ms/token | 0.006262 | #### FP32 -| Engine | Size | Per-Token Latency | Real-Time Factor | -| ------------- | -------- | ----------------- | ---------------- | -| PyTorch | Tiny | 6.220 ms/token | 0.006220 | -| PyTorch 2.0 | Tiny | 3.944 ms/token | 0.003944 | -| ONNX Runtime | Tiny | 1.545 ms/token | 0.001545 | -| PyTorch | Medium | 19.093 ms/token | 0.019093 | -| PyTorch 2.0 | Medium | 20.459 ms/token | 0.020459 | -| ONNX Runtime | Medium | 9.440 ms/token | 0.009440 | -| PyTorch | Large v2 | 25.844 ms/token | 0.025844 | -| PyTorch 2.0 | Large v2 | 26.397 ms/token | 0.026397 | -| ONNX Runtime | Large v2 | 7.492 ms/token | 0.007492 | +| Engine | Size | Per-Token Latency | Real-Time Factor | +| --------------- | -------- | ----------------- | ---------------- | +| PyTorch eager | Tiny | 6.220 ms/token | 0.006220 | +| PyTorch compile | Tiny | 3.944 ms/token | 0.003944 | +| ONNX Runtime | Tiny | 1.545 ms/token | 0.001545 | +| PyTorch eager | Medium | 19.093 ms/token | 0.019093 | +| PyTorch compile | Medium | 20.459 ms/token | 0.020459 | +| ONNX Runtime | Medium | 9.440 ms/token | 0.009440 | +| PyTorch eager | Large v2 | 25.844 ms/token | 0.025844 | +| PyTorch compile | Large v2 | 26.397 ms/token | 0.026397 | +| ONNX Runtime | Large v2 | 7.492 ms/token | 0.007492 | diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py index 283528bea7465..759ae6d14f184 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/benchmark.py +++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark.py @@ -24,7 +24,7 @@ def get_inputs(args: argparse.Namespace): - if args.benchmark_type not in {"hf-pt", "hf-pt2", "hf-ort", "ort"}: + if args.benchmark_type not in {"hf-pt-eager", "hf-pt-compile", "hf-ort", "ort"}: raise Exception("Unable to auto-detect inputs for provided model") def load_via_ffmpeg(): @@ -102,7 +102,7 @@ def get_model(args: argparse.Namespace): # 2) Benchmark Whisper ONNX model from Optimum export (without pre/post processing) # 3) Benchmark Whisper ONNX E2E model from Olive (with pre/post processing) - if args.benchmark_type in {"hf-pt", "hf-pt2"}: + if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}: source = args.hf_pt_model_path if args.hf_pt_model_path else args.model_name start_time = time.time() model = AutoModelForSpeechSeq2Seq.from_pretrained( @@ -112,7 +112,7 @@ def get_model(args: argparse.Namespace): ).to(args.target_device) end_time = time.time() - if args.benchmark_type == "hf-pt2": + if args.benchmark_type == "hf-pt-compile": model = torch.compile(model) elif args.benchmark_type in {"hf-ort", "ort"}: @@ -136,7 +136,7 @@ def get_model(args: argparse.Namespace): start_time = time.time() model = ORTModelForSpeechSeq2Seq.from_pretrained( - args.hf_ort_model_path, + args.hf_ort_dir_path, use_io_binding=(args.device != "cpu"), provider=provider, provider_options=provider_options, @@ -214,7 +214,7 @@ def profile_fn(args, fn, inputs, inputs_type): prefix = f"{args.benchmark_type.lower()}-{args.precision}-{args.device}_{fn.__name__.replace('_', '-')}_{inputs_type}_{datetime.datetime.now():%Y-%m-%d_%H:%M:%S}" filename = None - if args.benchmark_type in {"hf-pt", "hf-pt2"}: + if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile"}: # Profile PyTorch kernels with profile( # noqa: SIM117 activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, profile_memory=True @@ -280,7 +280,7 @@ def gen_and_dec(inputs): generate_fn = gen_and_dec - if args.benchmark_type == "hf-pt2": + if args.benchmark_type == "hf-pt-compile": # Run forward pass once with each set of inputs to process through Dynamo generate_fn(inputs) @@ -345,7 +345,7 @@ def prepare_ort_inputs(inputs, warmup=False): for k, v in inputs.items(): io_binding.bind_cpu_input(k, v) for output in model.get_outputs(): - io_binding.bind_output(output.name) + io_binding.bind_output(output.name, device_type=args.device, device_id=args.device_id) return io_binding return inputs @@ -407,7 +407,7 @@ def handle_output(output): def run_inference(args, inputs, model): - if args.benchmark_type in {"hf-pt", "hf-pt2", "hf-ort"}: + if args.benchmark_type in {"hf-pt-eager", "hf-pt-compile", "hf-ort"}: run_hf_inference(args, inputs, model) elif args.benchmark_type == "ort": run_ort_inference(args, inputs, model) @@ -419,8 +419,13 @@ def parse_args(): parser = argparse.ArgumentParser() parser.add_argument( - "-bt", "--benchmark-type", type=str, required=True, choices=["hf-pt", "hf-pt2", "hf-ort", "ort"] + "-bt", + "--benchmark-type", + type=str, + required=True, + choices=["hf-pt-eager", "hf-pt-compile", "hf-ort", "ort"], ) + parser.add_argument( "-m", "--model-name", @@ -445,7 +450,7 @@ def parse_args(): help="Path to directory containing all PyTorch files (e.g. tokenizer, PyTorch model)", ) parser.add_argument( - "--hf-ort-model-path", + "--hf-ort-dir-path", type=str, default="", help="Path to directory containing all ONNX files (e.g. tokenizer, encoder, decoder, decoder_with_past)", @@ -538,7 +543,7 @@ def parse_args(): # Check that model paths have been specified for any benchmarking with ORT if args.benchmark_type == "hf-ort": - assert args.hf_ort_model_path, "Please specify a path to `--hf-ort-model-path`" + assert args.hf_ort_dir_path, "Please specify a path to `--hf-ort-dir-path`" if args.benchmark_type == "ort": assert args.ort_model_path, "Please specify a path to `--ort-model-path`" diff --git a/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py b/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py index 08d7befec3cfd..071b539ac1899 100644 --- a/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py +++ b/onnxruntime/python/tools/transformers/models/whisper/benchmark_all.py @@ -54,7 +54,21 @@ def get_args(): ) parser.add_argument( - "--hf-ort-model-path", + "--hf-pt-eager", + default=False, + action="store_true", + help="Benchmark in PyTorch without `torch.compile`", + ) + + parser.add_argument( + "--hf-pt-compile", + default=False, + action="store_true", + help="Benchmark in PyTorch with `torch.compile`", + ) + + parser.add_argument( + "--hf-ort-dir-path", type=str, help="Path to folder containing ONNX models for Optimum + ORT benchmarking", ) @@ -136,7 +150,7 @@ def process_log_file(device_id, log_file, base_results): load_audio_latency_s, load_audio_throughput_s = None, None feat_ext_latency_s, feat_ext_throughput_s = None, None - latency_s, per_token_latency_s, per_token_latency_ms = None, None, None + token_length, latency_s, per_token_latency_s, per_token_latency_ms = None, None, None, None throughput, memory = None, None # Detect metrics @@ -310,73 +324,75 @@ def main(): logger.info(f"Testing {audio_path}...") # Benchmark PyTorch without torch.compile - benchmark_cmd = [ # noqa: RUF005 - "python3", - "-m", - "models.whisper.benchmark", - "--audio-path", - audio_path, - "--benchmark-type", - "hf-pt", - "--model-name", - args.model_name, - "--precision", - args.precision, - "--device", - args.device, - "--device-id", - str(args.device_id), - "--warmup-runs", - str(args.warmup_runs), - "--num-runs", - str(args.num_runs), - "--log-folder", - args.log_folder, - ] + hf_decoder_input_ids_cmd - logger.info("Benchmark PyTorch without torch.compile") - results = benchmark(args, benchmark_cmd, "pytorch", audio_file, duration) - all_results.extend(results) + if args.hf_pt_eager: + benchmark_cmd = [ # noqa: RUF005 + "python", + "-m", + "models.whisper.benchmark", + "--audio-path", + audio_path, + "--benchmark-type", + "hf-pt-eager", + "--model-name", + args.model_name, + "--precision", + args.precision, + "--device", + args.device, + "--device-id", + str(args.device_id), + "--warmup-runs", + str(args.warmup_runs), + "--num-runs", + str(args.num_runs), + "--log-folder", + args.log_folder, + ] + hf_decoder_input_ids_cmd + logger.info("Benchmark PyTorch without torch.compile") + results = benchmark(args, benchmark_cmd, "pytorch-eager", audio_file, duration) + all_results.extend(results) # Benchmark PyTorch with torch.compile - benchmark_cmd = [ # noqa: RUF005 - "python3", - "-m", - "models.whisper.benchmark", - "--audio-path", - audio_path, - "--benchmark-type", - "hf-pt2", - "--model-name", - args.model_name, - "--precision", - args.precision, - "--device", - args.device, - "--device-id", - str(args.device_id), - "--warmup-runs", - str(args.warmup_runs), - "--num-runs", - str(args.num_runs), - "--log-folder", - args.log_folder, - ] + hf_decoder_input_ids_cmd - logger.info("Benchmark PyTorch with torch.compile") - results = benchmark(args, benchmark_cmd, "pytorch-2", audio_file, duration) - all_results.extend(results) + if args.hf_pt_compile: + benchmark_cmd = [ # noqa: RUF005 + "python", + "-m", + "models.whisper.benchmark", + "--audio-path", + audio_path, + "--benchmark-type", + "hf-pt-compile", + "--model-name", + args.model_name, + "--precision", + args.precision, + "--device", + args.device, + "--device-id", + str(args.device_id), + "--warmup-runs", + str(args.warmup_runs), + "--num-runs", + str(args.num_runs), + "--log-folder", + args.log_folder, + ] + hf_decoder_input_ids_cmd + logger.info("Benchmark PyTorch with torch.compile") + results = benchmark(args, benchmark_cmd, "pytorch-compile", audio_file, duration) + all_results.extend(results) # Benchmark Optimum + ONNX Runtime - if args.hf_ort_model_path: + if args.hf_ort_dir_path: benchmark_cmd = [ # noqa: RUF005 - "python3", + "python", "-m", "models.whisper.benchmark", "--audio-path", audio_path, "--benchmark-type", "hf-ort", - "--hf-ort-model-path", - args.hf_ort_model_path, + "--hf-ort-dir-path", + args.hf_ort_dir_path, "--model-name", args.model_name, "--precision", @@ -393,14 +409,14 @@ def main(): args.log_folder, ] + hf_decoder_input_ids_cmd logger.info("Benchmark Optimum + ONNX Runtime") - results = benchmark(args, benchmark_cmd, "pytorch-ort", audio_file, duration) + results = benchmark(args, benchmark_cmd, "optimum-ort", audio_file, duration) all_results.extend(results) # Benchmark ONNX Runtime if args.ort_model_path: benchmark_cmd = ( [ # noqa: RUF005 - "python3", + "python", "-m", "models.whisper.benchmark", "--audio-path", diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert.py b/onnxruntime/python/tools/transformers/onnx_model_bert.py index 995f8c6541b4c..7a69922e67072 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bert.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bert.py @@ -22,7 +22,9 @@ from fusion_qordered_layernorm import FusionQOrderedLayerNormalization from fusion_qordered_matmul import FusionQOrderedMatMul from fusion_reshape import FusionReshape +from fusion_rotary_attention import FusionRotaryEmbeddings from fusion_shape import FusionShape +from fusion_simplified_layernorm import FusionSimplifiedLayerNormalization, FusionSkipSimplifiedLayerNormalization from fusion_skiplayernorm import FusionBiasSkipLayerNormalization, FusionSkipLayerNormalization from fusion_utils import FusionUtils from onnx import GraphProto, ModelProto, TensorProto, ValueInfoProto, helper @@ -106,10 +108,36 @@ def fuse_layer_norm(self): fusion = FusionQOrderedLayerNormalization(self) fusion.apply() + def fuse_simplified_layer_norm(self): + fusion = FusionSimplifiedLayerNormalization(self) + fusion.apply() + def fuse_skip_layer_norm(self): fusion = FusionSkipLayerNormalization(self) fusion.apply() + def fuse_skip_simplified_layer_norm(self): + fusion = FusionSkipSimplifiedLayerNormalization(self) + fusion.apply() + + def fuse_rotary_embeddings(self): + fusion = FusionRotaryEmbeddings(self) + fusion.apply() + # Remove non-MS domain functions + rot_emb_nodes = list( + filter( + lambda node: node.op_type == "RotaryEmbedding" and node.domain != "com.microsoft", self.model.graph.node + ) + ) + non_ms_domains_to_keep = set(map(lambda node: node.domain, rot_emb_nodes)) + i = 0 + while i < len(self.model.functions): + fn = self.model.functions[i] + if "RotaryEmbedding" in fn.name and fn.domain not in non_ms_domains_to_keep: + self.model.functions.remove(fn) + else: + i += 1 + # Only relevant in models with Q-DQ nodes def fuse_qordered_mamtul(self): fusion = FusionQOrderedMatMul(self) @@ -367,6 +395,7 @@ def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bo if (options is None) or options.enable_layer_norm: self.fuse_layer_norm() + self.fuse_simplified_layer_norm() if (options is None) or options.enable_gelu: self.fuse_gelu() @@ -377,6 +406,10 @@ def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bo if (options is None) or options.enable_skip_layer_norm: self.fuse_skip_layer_norm() + self.fuse_skip_simplified_layer_norm() + + if (options is None) or options.enable_rotary_embeddings: + self.fuse_rotary_embeddings() if options is not None: self.attention_mask.set_mask_format(options.attention_mask_format) @@ -442,14 +475,17 @@ def get_fused_operator_statistics(self): "BiasGelu", "GemmFastGelu", "LayerNormalization", + "SimplifiedLayerNormalization", "SkipLayerNormalization", + "SkipSimplifiedLayerNormalization", + "RotaryEmbedding", ] q_ops = ["QOrderedAttention", "QOrderedGelu", "QOrderedLayerNormalization", "QOrderedMatMul"] for op in ops + q_ops: nodes = self.get_nodes_by_op_type(op) op_count[op] = len(nodes) - logger.info(f"Optimized operators:{op_count}") + logger.info(f"Optimized operators: {op_count}") return op_count def is_fully_optimized(self): @@ -461,11 +497,20 @@ def is_fully_optimized(self): attention = op_count["Attention"] + op_count["MultiHeadAttention"] + op_count["QOrderedAttention"] gelu = op_count["Gelu"] + op_count["BiasGelu"] + op_count["FastGelu"] layer_norm = op_count["LayerNormalization"] + op_count["SkipLayerNormalization"] - is_perfect = (embed > 0) and (attention > 0) and (attention == gelu) and (layer_norm >= 2 * attention) + simple_layer_norm = op_count["SimplifiedLayerNormalization"] + op_count["SkipSimplifiedLayerNormalization"] + is_perfect = ( + (embed > 0) + and (attention > 0) + and (attention == gelu) + and ((layer_norm >= 2 * attention) or (simple_layer_norm >= 2 * attention)) + ) if layer_norm == 0: logger.debug("Layer Normalization not fused") + if simple_layer_norm == 0: + logger.debug("Simple Layer Normalization not fused") + if gelu == 0: logger.debug("Gelu/FastGelu not fused") diff --git a/onnxruntime/python/tools/transformers/onnx_model_gpt2.py b/onnxruntime/python/tools/transformers/onnx_model_gpt2.py index 263857ffbc130..6545bb08cdd5e 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_gpt2.py +++ b/onnxruntime/python/tools/transformers/onnx_model_gpt2.py @@ -8,6 +8,7 @@ from fusion_gpt_attention import FusionGptAttention from fusion_gpt_attention_megatron import FusionGptAttentionMegatron from fusion_gpt_attention_no_past import FusionGptAttentionNoPast +from fusion_rotary_attention import FusionRotaryAttention from onnx_model_bert import BertOnnxModel logger = logging.getLogger(__name__) @@ -27,6 +28,9 @@ def fuse_attention(self): fusion = FusionGptAttentionMegatron(self, self.num_heads) fusion.apply() + fusion = FusionRotaryAttention(self, self.hidden_size, self.num_heads) + fusion.apply() + def postprocess(self): """ Remove extra reshape nodes. @@ -94,4 +98,4 @@ def postprocess(self): reshape_count += 2 self.prune_graph() - logger.info(f"postprocess: remove Reshape count:{reshape_count}") + logger.info(f"postprocess: remove Reshape count: {reshape_count}") diff --git a/onnxruntime/python/tools/transformers/onnx_model_t5.py b/onnxruntime/python/tools/transformers/onnx_model_t5.py index e9f98e956b760..95f40af3fd746 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_t5.py +++ b/onnxruntime/python/tools/transformers/onnx_model_t5.py @@ -3,12 +3,12 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- import logging -from typing import Dict, Optional, Union +from typing import Optional, Union import numpy as np from fusion_attention import AttentionMask, FusionAttention from fusion_base import Fusion -from fusion_skiplayernorm import FusionSkipLayerNormalization +from fusion_simplified_layernorm import FusionSimplifiedLayerNormalization, FusionSkipSimplifiedLayerNormalization from fusion_utils import NumpyHelper from onnx import NodeProto, TensorProto, helper from onnx_model import OnnxModel @@ -56,8 +56,8 @@ def create_attention_node( Args: mask_index (str): mask input q_matmul (NodeProto): MatMul node in fully connection for Q - k_matmul (NodeProto): MatMul node in fully connection for K - v_matmul (NodeProto): MatMul node in fully connection for V + k_matmul (NodeProto): MatMul node in fully connection for K + v_matmul (NodeProto): MatMul node in fully connection for V num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning. hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning. input (str): input name @@ -687,67 +687,6 @@ def fuse(self, node, input_name_to_nodes, output_name_to_node): self.node_name_to_graph_name[rpb_node.name] = self.this_graph_name -class FusionSimplifiedLayerNormalization(Fusion): - def __init__(self, model: OnnxModel): - super().__init__(model, "SimplifiedLayerNormalization", "Mul") - - def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict): - if node.op_type != "Mul": - return - - sim_ln_nodes = self.model.match_parent_path( - node, - ["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Add"], - [1, 1, 1, 0, 0, 0, 0], - ) - if sim_ln_nodes is None: - sim_ln_nodes = self.model.match_parent_path( - node, - ["Mul", "Div", "Sqrt", "Add", "ReduceMean", "Pow", "Gather"], - [1, 1, 1, 0, 0, 0, 0], - ) - if sim_ln_nodes is None: - return - - pow_node = sim_ln_nodes[-2] - if self.model.find_constant_input(pow_node, 2.0) != 1: - return - - root_input = pow_node.input[0] - - mul_node_1 = sim_ln_nodes[0] - if root_input != mul_node_1.input[0]: - return - - second_add_node = sim_ln_nodes[3] - i, add_weight = self.model.get_constant_input(second_add_node) - if add_weight is None or add_weight <= 0 or add_weight > 1.0e-4: - logger.warning(f"epsilon value is not expeced: {add_weight}") - return - - self.nodes_to_remove.extend(sim_ln_nodes[:-1]) - - normalize_node = helper.make_node( - "SimplifiedLayerNormalization", - inputs=[root_input, node.input[0]], - outputs=[node.output[0]], - name=self.model.create_node_name("SimplifiedLayerNormalization", name_prefix="LayerNorm"), - ) - normalize_node.attribute.extend([helper.make_attribute("epsilon", float(add_weight))]) - normalize_node.attribute.extend([helper.make_attribute("axis", int(-1))]) - normalize_node.attribute.extend([helper.make_attribute("stash_type", 1)]) - self.nodes_to_add.append(normalize_node) - self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name - - -class FusionSkipSimplifiedLayerNormalization(FusionSkipLayerNormalization): - def __init__(self, model: OnnxModel): - super().__init__(model, "SkipSimplifiedLayerNormalization", "SimplifiedLayerNormalization") - - def fuse(self, node, input_name_to_nodes, output_name_to_node): - super().fuse(node, input_name_to_nodes, output_name_to_node) - - class T5OnnxModel(BertOnnxModel): def __init__(self, model, num_heads, hidden_size): super().__init__(model, num_heads, hidden_size) diff --git a/onnxruntime/python/tools/transformers/optimizer.py b/onnxruntime/python/tools/transformers/optimizer.py index 5ded027b36f74..00b26c019d4b5 100644 --- a/onnxruntime/python/tools/transformers/optimizer.py +++ b/onnxruntime/python/tools/transformers/optimizer.py @@ -103,7 +103,7 @@ def optimize_by_onnxruntime( logger.error("There is no gpu for onnxruntime to do optimization.") return onnx_model_path - model = OnnxModel(load_model(onnx_model_path, format=None, load_external_data=False)) + model = OnnxModel(load_model(onnx_model_path, load_external_data=False)) if model.use_float16() and not use_gpu: logger.warning( "This model uses float16 in the graph, use_gpu=False might cause extra Cast nodes. " @@ -546,7 +546,7 @@ def main(): if args.input_int32: optimizer.change_graph_inputs_to_int32() - if args.model_type in ["bert", "gpt2"]: + if args.model_type in set(MODEL_TYPES.keys()): if optimizer.is_fully_optimized(): logger.info("The model has been fully optimized.") else: diff --git a/onnxruntime/python/tools/transformers/shape_infer_helper.py b/onnxruntime/python/tools/transformers/shape_infer_helper.py index f8a5464d8af78..f1fc0c952e8e4 100644 --- a/onnxruntime/python/tools/transformers/shape_infer_helper.py +++ b/onnxruntime/python/tools/transformers/shape_infer_helper.py @@ -28,12 +28,12 @@ def __init__(self, model, verbose=0, int_max=2**31 - 1, auto_merge=True, guess_o self.is_inferred_: bool = False self.dynamic_axis_mapping_: Dict[str, int] = {} - def infer(self, dynamic_axis_mapping: Dict[str, int], max_runs: int = 128): + def infer(self, dynamic_axis_mapping: Dict[str, int], max_runs: int = 200): """Run shape inference, and try replace dynamic axis from string to integer when mapping is provided. Args: dynamic_axis_mapping (_type_): a dictionary with name of dynamic axis as key, like {"batch_size" : 4} - max_runs (int, optional): limit maximum number of runs to avoid infinite loop. Defaults to 32. + max_runs (int, optional): limit maximum number of runs to avoid infinite loop. Defaults to 200. Returns: bool: whether all shapes has been inferred or not. diff --git a/onnxruntime/test/contrib_ops/matmul_bnb4_test.cc b/onnxruntime/test/contrib_ops/matmul_bnb4_test.cc new file mode 100644 index 0000000000000..e739b17d5885f --- /dev/null +++ b/onnxruntime/test/contrib_ops/matmul_bnb4_test.cc @@ -0,0 +1,151 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#ifndef ORT_MINIMAL_BUILD + +#include "core/common/span_utils.h" +#include "core/framework/tensor.h" +#include "core/mlas/inc/mlas_q4.h" +#include "core/mlas/inc/mlas.h" +#include "core/session/inference_session.h" +#include "test/common/tensor_op_test_utils.h" +#include "test/framework/test_utils.h" +#include "test/optimizer/graph_transform_test_builder.h" +#include "test/providers/provider_test_utils.h" +#include "test/util/include/default_providers.h" +#include "core/util/qmath.h" +#include "contrib_ops/cpu/quantization/dequantize_blockwise_bnb4.h" + +#include +#include + +#include "gtest/gtest.h" +#include "gmock/gmock.h" + +namespace onnxruntime { +namespace test { + +void QuantizeDequantizeBnb4(std::vector& raw_vals, // N X K + std::vector& quant_vals, + std::vector& absmax, + int32_t quant_type, + int32_t N, + int32_t K, + int32_t block_size) { + OrtThreadPoolParams to; + auto tp = concurrency::CreateThreadPool(&onnxruntime::Env::Default(), to, + concurrency::ThreadPoolType::INTRA_OP); + + contrib::QuantizeBlockwiseBnb4( + quant_vals.data(), + raw_vals.data(), + absmax.data(), + block_size, + quant_type, + N, + K, + tp.get()); + + contrib::DequantizeBlockwiseBnb4( + raw_vals.data(), + quant_vals.data(), + absmax.data(), + block_size, + quant_type, + N, + K, + tp.get()); +} + +void RunTest(int64_t quant_type, int64_t M, int64_t N, int64_t K, int64_t block_size, bool use_float16) { + RandomValueGenerator random{1234}; + std::vector input0_vals(random.Gaussian(std::vector({M, K}), 0.0f, 0.25f)); + // quantizer expects transposed weights, N X K + std::vector input1_f_vals(random.Gaussian(std::vector({N, K}), 0.0f, 0.25f)); + + int64_t numel = N * K; + int64_t quantized_numel = (numel + 1) / 2; + int64_t total_block_count = (numel + block_size - 1) / block_size; + std::vector input1_vals(quantized_numel); + std::vector absmax(total_block_count); + + QuantizeDequantizeBnb4(input1_f_vals, + input1_vals, + absmax, + static_cast(quant_type), + static_cast(N), + static_cast(K), + static_cast(block_size)); + + std::vector expected_vals(M * N); + for (int64_t m = 0; m < M; m++) { + for (int64_t n = 0; n < N; n++) { + float sum = 0.0f; + for (int64_t k = 0; k < K; k++) { + sum += input0_vals[m * K + k] * input1_f_vals[n * K + k]; + } + expected_vals[m * N + n] = sum; + } + } + + OpTester test("MatMulBnb4", 1, kMSDomain); + test.AddAttribute("K", K); + test.AddAttribute("N", N); + test.AddAttribute("block_size", block_size); + test.AddAttribute("quant_type", quant_type); + if (use_float16) { + test.AddInput("A", {M, K}, ToFloat16(input0_vals), false); + test.AddInput("B", {quantized_numel}, input1_vals, true); + test.AddInput("absmax", {total_block_count}, ToFloat16(absmax), true); + + test.AddOutput("Y", {M, N}, ToFloat16(expected_vals)); + test.SetOutputAbsErr("Y", 0.02f); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } else { + test.AddInput("A", {M, K}, input0_vals, false); + test.AddInput("B", {quantized_numel}, input1_vals, true); + test.AddInput("absmax", {total_block_count}, absmax, true); + + test.AddOutput("Y", {M, N}, expected_vals); + + test.Run(); + } +} + +TEST(MatMulBnb4, Float32) { + for (auto qt : {0, 1}) { + for (auto M : {1, 2, 100}) { + for (auto N : {1, 2, 32, 288}) { + for (auto K : {16, 32, 64, 128, 256, 1024, 93, 1234}) { + for (auto block_size : {16, 32, 64, 128}) { + RunTest(qt, M, N, K, block_size, false); + } + } + } + } + } +} + +#if defined(USE_CUDA) +TEST(MatMulBnb4, Float16) { + for (auto qt : {0, 1}) { + for (auto M : {1, 2, 100}) { + for (auto N : {1, 2, 32, 288}) { + for (auto K : {16, 32, 64, 128, 256, 1024, 93, 1234}) { + for (auto block_size : {16, 32, 64, 128}) { + RunTest(qt, M, N, K, block_size, true); + } + } + } + } + } +} + +#endif +} // namespace test +} // namespace onnxruntime + +#endif // ORT_MINIMAL_BUILD diff --git a/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc new file mode 100644 index 0000000000000..29d8219c162a5 --- /dev/null +++ b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc @@ -0,0 +1,632 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "gtest/gtest.h" +#include "core/session/onnxruntime_cxx_api.h" +#include "test/common/tensor_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +static void RunTest( + const std::vector& input_data, + const std::vector& position_ids, + const std::vector& cos_cache, + const std::vector& sin_cache, + const std::vector& output_data, + int batch_size, + int sequence_length, + int head_size, + int num_heads, + int max_sequence_length, + int64_t interleaved, + bool use_float16, + bool disable_cpu, + bool disable_cuda) { + // input : (batch_size, sequence_length, hidden_size) + // position ids : (1) or (batch_size, sequence_length) + // cos cache : (max_sequence_length, head_size / 2) + // sin cache : (max_sequence_length, head_size / 2) + // interleaved : 0 = false, 1 = true + + int hidden_size = num_heads * head_size; + std::vector input_dims = {batch_size, sequence_length, hidden_size}; + std::vector pos_dims; + std::vector cache_dims = {max_sequence_length, head_size / 2}; + + assert(hidden_size != 0 && head_size != 0 && num_heads != 0 && max_sequence_length != 0); + assert(max_sequence_length >= sequence_length); + if (position_ids.size() == 1) { + pos_dims = {1}; + } else { + pos_dims = {batch_size, sequence_length}; + } + + std::string op_type = "RotaryEmbedding"; + std::vector> execution_providers; + + int min_cuda_architecture = use_float16 ? 530 : 0; + bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); + if (enable_cuda && !disable_cuda) { + execution_providers.push_back(DefaultCudaExecutionProvider()); + } + if (!use_float16 && !disable_cpu) { + execution_providers.push_back(DefaultCpuExecutionProvider()); + } + if (execution_providers.size() == 0) { + // Return early if CI pipeline does not support EP (e.g. CUDA EP for CPU CI pipeline) + return; + } + + OpTester test(op_type.c_str(), 1, onnxruntime::kMSDomain); + test.AddAttribute("interleaved", interleaved); + + if (!use_float16) { + test.AddInput("input", input_dims, input_data); + test.AddInput("position_ids", pos_dims, position_ids); + test.AddInput("cos_cache", cache_dims, cos_cache); + test.AddInput("sin_cache", cache_dims, sin_cache); + test.AddOutput("output", input_dims, output_data); + } else { + test.AddInput("input", input_dims, ToFloat16(input_data)); + test.AddInput("position_ids", pos_dims, position_ids); + test.AddInput("cos_cache", cache_dims, ToFloat16(cos_cache)); + test.AddInput("sin_cache", cache_dims, ToFloat16(sin_cache)); + test.AddOutput("output", input_dims, ToFloat16(output_data)); + } + test.SetOutputAbsErr("output", 0.002f); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); +} + +static void RunTests(const std::vector& input_data, + const std::vector& position_ids, + const std::vector& cos_cache, + const std::vector& sin_cache, + const std::vector& output_data, + int batch_size, + int sequence_length, + int head_size = 0, + int num_heads = 0, + int max_sequence_length = 0, + int64_t interleaved = 0, + bool use_float16 = true) { + // FP32 test for CPU + RunTest(input_data, + position_ids, + cos_cache, + sin_cache, + output_data, + batch_size, + sequence_length, + head_size, + num_heads, + max_sequence_length, + interleaved, + false, /* use_fp16 */ + false, /* disable_cpu */ + true /* disable_cuda */); + + // FP32 test for CUDA + RunTest(input_data, + position_ids, + cos_cache, + sin_cache, + output_data, + batch_size, + sequence_length, + head_size, + num_heads, + max_sequence_length, + interleaved, + false, /* use_fp16 */ + false, /* disable_cpu */ + false /* disable_cuda */); + + // FP16 test for CUDA + if (use_float16) { + RunTest(input_data, + position_ids, + cos_cache, + sin_cache, + output_data, + batch_size, + sequence_length, + head_size, + num_heads, + max_sequence_length, + interleaved, + true, /* use_fp16 */ + true, /* disable_cpu */ + false /* disable_cuda*/); + } +} + +// Interleaved = true, pos ids shape = (1) +TEST(RotaryEmbeddingTest, RotaryEmbedding_Interleaved_SmallData_LlamaMSFT) { + int batch_size = 1; + int sequence_length = 3; + int num_heads = 2; + int head_size = 4; + int max_sequence_length = 8; + int64_t interleaved = 1; // true + + std::vector input_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -0.1320f, -0.2751f, -0.2350f, 0.0937f, + -1.2188f, 1.1676f, -1.0574f, -0.1188f, -0.7396f, -1.2425f, -0.1752f, 0.6990f, + -0.8110f, 0.6737f, -1.1233f, -0.0919f, -0.6861f, 0.7202f, 0.1963f, 0.6142f}; + + std::vector position_ids = {0}; + + std::vector cos_cache = { + 1.0000f, 1.0000f, 0.5403f, 0.9999f, -0.4161f, 0.9998f, -0.9900f, 0.9996f, + -0.6536f, 0.9992f, 0.2837f, 0.9988f, 0.9602f, 0.9982f, 0.7539f, 0.9976f}; + + std::vector sin_cache = { + 0.0000f, 0.0000f, 0.8415f, 0.0100f, 0.9093f, 0.0200f, 0.1411f, 0.0300f, + -0.7568f, 0.0400f, -0.9589f, 0.0500f, -0.2794f, 0.0600f, 0.6570f, 0.0699f}; + + std::vector output_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -0.1320f, -0.2751f, -0.2350f, 0.0937f, + -1.6411f, -0.3948f, -1.0561f, -0.1294f, 0.6460f, -1.2937f, -0.1822f, 0.6972f, + -0.2751f, -1.0178f, -1.1212f, -0.1143f, -0.3694f, -0.9235f, 0.1840f, 0.6180f}; + + RunTests(input_data, + position_ids, + cos_cache, + sin_cache, + output_data, + batch_size, + sequence_length, + head_size, + num_heads, + max_sequence_length, + interleaved); +} + +// Interleaved = true, pos ids shape = (1) +TEST(RotaryEmbeddingTest, RotaryEmbedding_Interleaved_LargeData_LlamaMSFT) { + int batch_size = 2; + int sequence_length = 8; + int num_heads = 4; + int head_size = 6; + int max_sequence_length = 16; + int64_t interleaved = 1; // true + + std::vector input_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, + 1.1676f, -1.0190f, 0.3157f, -1.6036f, 1.8493f, + 0.0447f, 1.5853f, 0.1036f, -0.3514f, 0.2421f, + 0.6463f, 0.8730f, -0.9276f, 1.0311f, -1.9557f, + -0.1482f, 1.7376f, 2.2039f, -0.6589f, -1.0574f, + -0.1188f, -0.9078f, 0.3452f, -0.5713f, -0.2351f, + -0.5912f, 1.1312f, 0.7562f, -1.2023f, -0.5833f, + -0.4407f, 0.1766f, 1.0224f, -0.4826f, -0.5421f, + -0.5342f, -0.6413f, 1.3314f, -0.4498f, 0.5493f, + 0.0539f, 0.2601f, 0.8570f, 1.0076f, -0.7529f, + -0.2250f, -0.4327f, -1.5071f, -0.4586f, -1.9791f, + 0.7787f, -0.7749f, -0.1398f, 1.1414f, -0.6354f, + 0.0352f, -0.4765f, -0.0409f, 1.1993f, 0.5374f, + -0.1930f, 2.5211f, -0.0452f, -0.3105f, -0.9407f, + -0.0034f, 1.5199f, -0.8480f, 0.5266f, 0.0299f, + -0.0498f, 1.0651f, 0.8860f, -1.4702f, -0.2134f, + -0.8707f, 1.6159f, -0.2356f, 0.9444f, 0.5937f, + 0.7203f, 0.5061f, 1.5192f, -0.4897f, 0.9231f, + 0.2654f, -0.1441f, 0.5407f, -1.5476f, 0.6455f, + -1.1382f, 0.4640f, -0.4986f, 0.1289f, 2.7631f, + 0.1405f, 1.1191f, 2.1134f, -0.9754f, 0.1757f, + -0.1319f, -0.2735f, 0.3355f, -0.6008f, -1.1164f, + 0.2577f, -0.7226f, -0.9244f, 1.8737f, 0.6052f, + 1.1904f, 1.2195f, -0.0470f, -1.0914f, 1.0223f, + 0.3152f, 1.7528f, -0.7650f, 1.8299f, -0.2784f, + -0.2719f, 0.1885f, 2.1432f, 0.8527f, 0.0965f, + -0.0625f, 0.8269f, 1.0122f, -1.4482f, -0.0644f, + 0.3215f, 0.5908f, -1.4197f, 0.2113f, 0.0306f, + 0.3604f, 0.3166f, -0.8975f, -0.6393f, -1.2944f, + -0.0243f, -0.2354f, -0.7087f, 1.1566f, 0.4296f, + 0.5599f, -0.7776f, 0.3339f, 0.1759f, 2.1108f, + 1.0702f, 0.8279f, -0.2969f, 0.7120f, -0.2068f, + -0.1548f, 0.1553f, 0.6207f, -0.1690f, -0.5816f, + 1.2632f, 0.0695f, 1.1862f, -1.1874f, -0.7468f, + -0.9320f, -0.8579f, -0.9647f, -0.0991f, 0.0195f, + 1.1213f, -1.4873f, -0.2043f, -1.0466f, -1.5772f, + -0.0489f, 0.3430f, 0.1264f, 0.1519f, -1.3639f, + -1.6593f, 1.8127f, -1.4459f, -0.2158f, -0.9792f, + -1.4392f, 0.6508f, 0.8964f, 0.5717f, -0.2390f, + 0.6983f, -1.3416f, 0.2715f, -0.2852f, 0.6051f, + 0.2167f, -0.2181f, -1.6306f, 1.4788f, 0.2754f, + -0.0261f, -0.4618f, -0.5646f, -1.0389f, 0.5819f, + 1.3697f, 0.0002f, 1.5333f, -1.0556f, -0.1254f, + 0.1527f, -0.5996f, -1.0962f, 1.6327f, 1.3951f, + 0.8784f, 0.3389f, 1.2907f, 0.3124f, 0.7299f, + 1.4220f, 0.3375f, 0.0438f, 1.8698f, -0.2635f, + -2.0799f, -0.6313f, 0.4090f, -1.1458f, 0.0784f, + -1.8848f, -1.6165f, 0.6179f, 0.9905f, -0.0729f, + 0.5054f, -0.6681f, -1.4382f, 1.7547f, -0.9605f, + -0.4558f, -1.6105f, 0.2979f, 1.1537f, -1.5604f, + 1.2779f, -1.2514f, 0.6056f, 0.5763f, -3.3558f, + 0.2836f, 0.6909f, -0.7631f, 2.4451f, -0.3500f, + 1.3289f, -0.6494f, 0.3478f, 1.0038f, -0.2937f, + 0.9238f, -1.2185f, 0.4138f, 0.5033f, 0.9174f, + 1.8131f, 1.4436f, -0.4207f, 0.0220f, -0.6807f, + -1.3306f, 1.5646f, 0.3338f, 0.7105f, 0.4683f, + -0.6179f, 0.0818f, -0.0488f, -0.9810f, -1.3632f, + 0.0929f, -1.7926f, -0.2921f, -0.4792f, 0.6756f, + -0.3413f, -0.2242f, -0.2111f, 0.6282f, 0.1667f, + -1.4055f, 1.5895f, 1.0838f, -0.9077f, -0.8060f, + 0.7967f, -2.9351f, 2.4179f, -0.4026f, 0.6451f, + 1.6845f, -0.0901f, 0.6106f, 2.3603f, 1.3908f, + -0.7917f, -0.6734f, -0.1213f, -1.1116f, -0.7401f, + -0.7879f, 0.0606f, -2.3337f, -1.2603f, -1.7245f, + -0.3533f, -0.9421f, -0.1776f, 0.3992f, -1.7142f, + -0.5319f, -0.8848f, 0.6513f, 1.0002f, -1.4699f, + -1.4254f, 0.7013f, 0.2414f, 0.2551f, -0.7457f, + 0.3133f, -1.0941f, -0.3682f, -0.0163f, -0.0645f, + -0.8101f, 0.1415f, 0.0551f, 0.5873f, -0.5887f, + -1.4733f, -0.8565f, 0.7400f, -0.5033f, 0.0553f, + 0.9265f, -0.8652f, -0.0288f, -0.2209f, 0.0610f, + 0.6776f, 0.4361f, -0.8052f, 0.3955f, 0.8988f, + 0.8238f, 0.2262f, 1.2912f, 0.6488f, 1.2114f, + 1.3569f, 0.2983f, 0.4718f, -1.1936f, 0.7928f, + -0.8665f, 0.9468f, 1.1629f, 0.0616f, -1.3136f, + -0.2764f, 0.0277f, -0.1126f, 0.2342f, -0.5866f, + -1.8219f, 1.1079f, 0.5795f, -1.4249f}; + + std::vector position_ids = {0}; + + std::vector cos_cache = { + 1.0000f, 1.0000f, 1.0000f, 0.5403f, 0.9989f, 1.0000f, -0.4161f, 0.9957f, + 1.0000f, -0.9900f, 0.9903f, 1.0000f, -0.6536f, 0.9828f, 1.0000f, 0.2837f, + 0.9732f, 0.9999f, 0.9602f, 0.9615f, 0.9999f, 0.7539f, 0.9477f, 0.9999f, + -0.1455f, 0.9318f, 0.9999f, -0.9111f, 0.9140f, 0.9998f, -0.8391f, 0.8942f, + 0.9998f, 0.0044f, 0.8725f, 0.9997f, 0.8439f, 0.8488f, 0.9997f, 0.9074f, + 0.8234f, 0.9996f, 0.1367f, 0.7962f, 0.9995f, -0.7597f, 0.7673f, 0.9995f}; + + std::vector sin_cache = { + 0.0000f, 0.0000f, 0.0000f, 0.8415f, 0.0464f, 0.0022f, 0.9093f, 0.0927f, + 0.0043f, 0.1411f, 0.1388f, 0.0065f, -0.7568f, 0.1846f, 0.0086f, -0.9589f, + 0.2300f, 0.0108f, -0.2794f, 0.2749f, 0.0129f, 0.6570f, 0.3192f, 0.0151f, + 0.9894f, 0.3629f, 0.0172f, 0.4121f, 0.4057f, 0.0194f, -0.5440f, 0.4477f, + 0.0215f, -1.0000f, 0.4887f, 0.0237f, -0.5366f, 0.5286f, 0.0259f, 0.4202f, + 0.5675f, 0.0280f, 0.9906f, 0.6050f, 0.0302f, 0.6503f, 0.6413f, 0.0323f}; + + std::vector output_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, + 1.1676f, -1.0190f, 0.3157f, -1.6036f, 1.8493f, + 0.0447f, 1.5853f, 0.1036f, -0.3514f, 0.2421f, + 0.6463f, 0.8730f, -0.9276f, 1.0311f, -1.9557f, + -0.1482f, 1.7376f, 2.2039f, -0.6589f, -0.4713f, + -0.9540f, -0.9229f, 0.3027f, -0.5708f, -0.2363f, + -1.2713f, 0.1137f, 0.8112f, -1.1659f, -0.5824f, + -0.4419f, -0.7649f, 0.7011f, -0.4569f, -0.5639f, + -0.5328f, -0.6424f, 1.0979f, 0.8773f, 0.5462f, + 0.0793f, 0.2582f, 0.8576f, 0.2653f, 1.2295f, + -0.1839f, -0.4517f, -1.5052f, -0.4651f, 0.1155f, + -2.1237f, -0.7586f, -0.2110f, 1.1441f, -0.6304f, + 0.4186f, 0.2303f, -0.1519f, 1.1903f, 0.5382f, + -0.1906f, -1.0080f, 2.3112f, -0.2220f, -0.9655f, + -0.0099f, 1.5198f, 0.7652f, -0.6410f, 0.0365f, + -0.0452f, 1.0593f, 0.8929f, 1.4856f, 0.0038f, + -1.0865f, 1.4794f, -0.2417f, 0.9428f, -0.6894f, + -0.6293f, 0.2904f, 1.5747f, -0.4956f, 0.9199f, + -0.2424f, 0.1801f, 0.7503f, -1.4576f, 0.6529f, + -1.1340f, -0.6807f, -0.0252f, -0.3834f, 2.7394f, + 0.1308f, 1.1203f, -2.1196f, -0.9618f, 0.1970f, + -0.0972f, -0.2764f, 0.3332f, -0.4522f, 1.1844f, + 0.3867f, -0.6626f, -0.9405f, 1.8656f, 0.5053f, + -1.2361f, 1.2072f, 0.1789f, -1.1002f, 1.0129f, + 1.7702f, 0.1949f, -1.1653f, 1.6049f, -0.2755f, + -0.2749f, 2.1087f, 0.4272f, 0.8076f, 0.2900f, + -0.0714f, 0.8261f, -1.1016f, -1.3814f, -0.1366f, + 0.2981f, 0.6060f, -1.4132f, 0.0893f, -0.1939f, + 0.2779f, 0.3910f, -0.8906f, -0.6489f, -1.2496f, + 0.3383f, -0.0315f, -0.7461f, 1.1510f, 0.4445f, + 0.3203f, -0.9031f, 0.2727f, 0.2609f, 2.0968f, + 1.0974f, 0.7120f, -0.5164f, 0.7415f, -0.0031f, + -0.1568f, 0.1533f, 0.5487f, -0.3357f, -0.9064f, + 1.0546f, 0.0542f, 1.1870f, -0.4045f, -1.3431f, + -0.6094f, -1.1105f, -0.9631f, -0.1137f, -0.7219f, + 0.8582f, -1.3443f, -0.6684f, -1.0227f, -1.5929f, + -0.2622f, 0.2264f, 0.0713f, 0.1843f, -1.3387f, + -1.6797f, 2.3165f, 0.1009f, 0.1081f, -0.9969f, + -1.4488f, 0.6291f, 0.8964f, 0.5717f, -0.2390f, + 0.6983f, -1.3416f, 0.2715f, -0.2852f, 0.6051f, + 0.2167f, -0.2181f, -1.6306f, 1.4788f, 0.2754f, + -0.0261f, -0.4618f, -0.5646f, -1.0389f, 0.5819f, + 1.3697f, 0.0002f, 1.5333f, -1.0556f, -0.1254f, + 0.1527f, 0.5985f, -1.0968f, 1.5662f, 1.4693f, + 0.8776f, 0.3408f, 0.4345f, 1.2549f, 0.6631f, + 1.4543f, 0.3374f, 0.0445f, 1.2320f, 1.4311f, + -2.0483f, -0.7272f, 0.4114f, -1.1449f, 1.6283f, + -0.9524f, -1.6435f, 0.5422f, 0.9907f, -0.0708f, + 0.3972f, 0.7376f, -1.5947f, 1.6138f, -0.9586f, + -0.4600f, 0.3993f, -1.5884f, 1.2934f, -1.4467f, + 1.2833f, -1.2459f, -0.7760f, 0.3108f, -3.3677f, + -0.0287f, 0.6942f, -0.7601f, -0.6993f, 2.3690f, + 1.3834f, -0.5234f, 0.3435f, 1.0053f, 0.1604f, + -0.9560f, -1.2641f, 0.2406f, 0.4973f, 0.9206f, + -1.9987f, -1.1733f, -0.4197f, -0.0366f, -0.6720f, + -1.3350f, -1.5960f, -0.1097f, 0.6386f, 0.5624f, + -0.6184f, 0.0778f, 0.1867f, 0.9643f, -1.3629f, + -0.0972f, -1.7907f, -0.3037f, 0.8245f, -0.0789f, + -0.2940f, -0.2833f, -0.2165f, 0.6264f, -1.1726f, + 0.7926f, 1.3621f, 1.3586f, -0.9007f, -0.8138f, + -2.7421f, 1.3155f, 2.4507f, 0.0507f, 0.6305f, + 1.6900f, 0.5210f, -0.3309f, 2.0630f, 1.8026f, + -0.7859f, -0.6802f, -1.1003f, -0.1990f, -0.5391f, + -0.9370f, 0.0857f, -2.3330f, -2.0112f, 0.7193f, + -0.1272f, -0.9981f, -0.1818f, 0.3973f, -0.9963f, + 1.4929f, -1.0109f, 0.4304f, 1.0160f, -1.4590f, + 0.2682f, 1.5658f, 0.1762f, 0.3038f, -0.7491f, + 0.3052f, -1.1534f, -0.0478f, 0.0021f, -0.0665f, + -0.8118f, 0.1310f, 0.2171f, 0.5485f, -0.1610f, + -1.5784f, -0.8660f, 0.7289f, -0.4678f, 0.1937f, + 1.1287f, -0.5772f, -0.0259f, -0.2212f, 0.2479f, + 0.6336f, 0.6407f, -0.6543f, 0.3838f, 0.9039f, + 0.4724f, 0.7117f, 1.0165f, 1.0270f, 1.1908f, + 1.3750f, -0.0850f, 0.5517f, -1.3842f, 0.3703f, + -0.8806f, 0.9336f, 0.8362f, 0.8105f, -1.1566f, + -0.6813f, 0.0294f, -0.1122f, 0.5620f, -0.2884f, + -2.0803f, 0.4684f, 0.6009f, -1.4160f}; + + RunTests(input_data, + position_ids, + cos_cache, + sin_cache, + output_data, + batch_size, + sequence_length, + head_size, + num_heads, + max_sequence_length, + interleaved); +} + +// Interleaved = false, pos ids shape = (1) +TEST(RotaryEmbeddingTest, RotaryEmbedding_NotInterleaved_LargeData_LlamaMSFT) { + int batch_size = 2; + int sequence_length = 8; + int num_heads = 4; + int head_size = 6; + int max_sequence_length = 16; + int64_t interleaved = 0; // false + + std::vector input_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, + 1.1676f, -1.0190f, 0.3157f, -1.6036f, 1.8493f, + 0.0447f, 1.5853f, 0.1036f, -0.3514f, 0.2421f, + 0.6463f, 0.8730f, -0.9276f, 1.0311f, -1.9557f, + -0.1482f, 1.7376f, 2.2039f, -0.6589f, -1.0574f, + -0.1188f, -0.9078f, 0.3452f, -0.5713f, -0.2351f, + -0.5912f, 1.1312f, 0.7562f, -1.2023f, -0.5833f, + -0.4407f, 0.1766f, 1.0224f, -0.4826f, -0.5421f, + -0.5342f, -0.6413f, 1.3314f, -0.4498f, 0.5493f, + 0.0539f, 0.2601f, 0.8570f, 1.0076f, -0.7529f, + -0.2250f, -0.4327f, -1.5071f, -0.4586f, -1.9791f, + 0.7787f, -0.7749f, -0.1398f, 1.1414f, -0.6354f, + 0.0352f, -0.4765f, -0.0409f, 1.1993f, 0.5374f, + -0.1930f, 2.5211f, -0.0452f, -0.3105f, -0.9407f, + -0.0034f, 1.5199f, -0.8480f, 0.5266f, 0.0299f, + -0.0498f, 1.0651f, 0.8860f, -1.4702f, -0.2134f, + -0.8707f, 1.6159f, -0.2356f, 0.9444f, 0.5937f, + 0.7203f, 0.5061f, 1.5192f, -0.4897f, 0.9231f, + 0.2654f, -0.1441f, 0.5407f, -1.5476f, 0.6455f, + -1.1382f, 0.4640f, -0.4986f, 0.1289f, 2.7631f, + 0.1405f, 1.1191f, 2.1134f, -0.9754f, 0.1757f, + -0.1319f, -0.2735f, 0.3355f, -0.6008f, -1.1164f, + 0.2577f, -0.7226f, -0.9244f, 1.8737f, 0.6052f, + 1.1904f, 1.2195f, -0.0470f, -1.0914f, 1.0223f, + 0.3152f, 1.7528f, -0.7650f, 1.8299f, -0.2784f, + -0.2719f, 0.1885f, 2.1432f, 0.8527f, 0.0965f, + -0.0625f, 0.8269f, 1.0122f, -1.4482f, -0.0644f, + 0.3215f, 0.5908f, -1.4197f, 0.2113f, 0.0306f, + 0.3604f, 0.3166f, -0.8975f, -0.6393f, -1.2944f, + -0.0243f, -0.2354f, -0.7087f, 1.1566f, 0.4296f, + 0.5599f, -0.7776f, 0.3339f, 0.1759f, 2.1108f, + 1.0702f, 0.8279f, -0.2969f, 0.7120f, -0.2068f, + -0.1548f, 0.1553f, 0.6207f, -0.1690f, -0.5816f, + 1.2632f, 0.0695f, 1.1862f, -1.1874f, -0.7468f, + -0.9320f, -0.8579f, -0.9647f, -0.0991f, 0.0195f, + 1.1213f, -1.4873f, -0.2043f, -1.0466f, -1.5772f, + -0.0489f, 0.3430f, 0.1264f, 0.1519f, -1.3639f, + -1.6593f, 1.8127f, -1.4459f, -0.2158f, -0.9792f, + -1.4392f, 0.6508f, 0.8964f, 0.5717f, -0.2390f, + 0.6983f, -1.3416f, 0.2715f, -0.2852f, 0.6051f, + 0.2167f, -0.2181f, -1.6306f, 1.4788f, 0.2754f, + -0.0261f, -0.4618f, -0.5646f, -1.0389f, 0.5819f, + 1.3697f, 0.0002f, 1.5333f, -1.0556f, -0.1254f, + 0.1527f, -0.5996f, -1.0962f, 1.6327f, 1.3951f, + 0.8784f, 0.3389f, 1.2907f, 0.3124f, 0.7299f, + 1.4220f, 0.3375f, 0.0438f, 1.8698f, -0.2635f, + -2.0799f, -0.6313f, 0.4090f, -1.1458f, 0.0784f, + -1.8848f, -1.6165f, 0.6179f, 0.9905f, -0.0729f, + 0.5054f, -0.6681f, -1.4382f, 1.7547f, -0.9605f, + -0.4558f, -1.6105f, 0.2979f, 1.1537f, -1.5604f, + 1.2779f, -1.2514f, 0.6056f, 0.5763f, -3.3558f, + 0.2836f, 0.6909f, -0.7631f, 2.4451f, -0.3500f, + 1.3289f, -0.6494f, 0.3478f, 1.0038f, -0.2937f, + 0.9238f, -1.2185f, 0.4138f, 0.5033f, 0.9174f, + 1.8131f, 1.4436f, -0.4207f, 0.0220f, -0.6807f, + -1.3306f, 1.5646f, 0.3338f, 0.7105f, 0.4683f, + -0.6179f, 0.0818f, -0.0488f, -0.9810f, -1.3632f, + 0.0929f, -1.7926f, -0.2921f, -0.4792f, 0.6756f, + -0.3413f, -0.2242f, -0.2111f, 0.6282f, 0.1667f, + -1.4055f, 1.5895f, 1.0838f, -0.9077f, -0.8060f, + 0.7967f, -2.9351f, 2.4179f, -0.4026f, 0.6451f, + 1.6845f, -0.0901f, 0.6106f, 2.3603f, 1.3908f, + -0.7917f, -0.6734f, -0.1213f, -1.1116f, -0.7401f, + -0.7879f, 0.0606f, -2.3337f, -1.2603f, -1.7245f, + -0.3533f, -0.9421f, -0.1776f, 0.3992f, -1.7142f, + -0.5319f, -0.8848f, 0.6513f, 1.0002f, -1.4699f, + -1.4254f, 0.7013f, 0.2414f, 0.2551f, -0.7457f, + 0.3133f, -1.0941f, -0.3682f, -0.0163f, -0.0645f, + -0.8101f, 0.1415f, 0.0551f, 0.5873f, -0.5887f, + -1.4733f, -0.8565f, 0.7400f, -0.5033f, 0.0553f, + 0.9265f, -0.8652f, -0.0288f, -0.2209f, 0.0610f, + 0.6776f, 0.4361f, -0.8052f, 0.3955f, 0.8988f, + 0.8238f, 0.2262f, 1.2912f, 0.6488f, 1.2114f, + 1.3569f, 0.2983f, 0.4718f, -1.1936f, 0.7928f, + -0.8665f, 0.9468f, 1.1629f, 0.0616f, -1.3136f, + -0.2764f, 0.0277f, -0.1126f, 0.2342f, -0.5866f, + -1.8219f, 1.1079f, 0.5795f, -1.4249f}; + + std::vector position_ids = {0}; + + std::vector cos_cache = { + 1.0000f, 1.0000f, 1.0000f, 0.5403f, 0.9989f, 1.0000f, -0.4161f, 0.9957f, + 1.0000f, -0.9900f, 0.9903f, 1.0000f, -0.6536f, 0.9828f, 1.0000f, 0.2837f, + 0.9732f, 0.9999f, 0.9602f, 0.9615f, 0.9999f, 0.7539f, 0.9477f, 0.9999f, + -0.1455f, 0.9318f, 0.9999f, -0.9111f, 0.9140f, 0.9998f, -0.8391f, 0.8942f, + 0.9998f, 0.0044f, 0.8725f, 0.9997f, 0.8439f, 0.8488f, 0.9997f, 0.9074f, + 0.8234f, 0.9996f, 0.1367f, 0.7962f, 0.9995f, -0.7597f, 0.7673f, 0.9995f}; + + std::vector sin_cache = { + 0.0000f, 0.0000f, 0.0000f, 0.8415f, 0.0464f, 0.0022f, 0.9093f, 0.0927f, + 0.0043f, 0.1411f, 0.1388f, 0.0065f, -0.7568f, 0.1846f, 0.0086f, -0.9589f, + 0.2300f, 0.0108f, -0.2794f, 0.2749f, 0.0129f, 0.6570f, 0.3192f, 0.0151f, + 0.9894f, 0.3629f, 0.0172f, 0.4121f, 0.4057f, 0.0194f, -0.5440f, 0.4477f, + 0.0215f, -1.0000f, 0.4887f, 0.0237f, -0.5366f, 0.5286f, 0.0259f, 0.4202f, + 0.5675f, 0.0280f, 0.9906f, 0.6050f, 0.0302f, 0.6503f, 0.6413f, 0.0323f}; + + std::vector output_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, + 1.1676f, -1.0190f, 0.3157f, -1.6036f, 1.8493f, + 0.0447f, 1.5853f, 0.1036f, -0.3514f, 0.2421f, + 0.6463f, 0.8730f, -0.9276f, 1.0311f, -1.9557f, + -0.1482f, 1.7376f, 2.2039f, -0.6589f, -0.8618f, + -0.0922f, -0.9073f, -0.7032f, -0.5762f, -0.2371f, + 0.6923f, 1.1571f, 0.7572f, -1.1471f, -0.5302f, + -0.4391f, 0.5516f, 1.0461f, -0.4812f, -0.1443f, + -0.4862f, -0.6423f, 0.6740f, -0.4614f, 0.5475f, + 1.1495f, 0.2389f, 0.8582f, -0.0259f, -0.6099f, + -0.2230f, 1.0963f, -1.5704f, -0.4595f, 0.9507f, + 0.6696f, -0.7721f, -1.7415f, 1.2087f, -0.6387f, + -1.1052f, -0.5243f, -0.0400f, -0.4671f, 0.4909f, + -0.1931f, -0.1937f, -0.0447f, -0.3171f, 2.6839f, + -0.0076f, 1.5185f, 0.8465f, 0.3737f, 0.0242f, + -0.0703f, 1.1279f, 0.8862f, 1.2275f, -0.1786f, + -0.8767f, -1.8072f, -0.2630f, 0.9387f, -0.8021f, + 0.7813f, 0.5001f, -1.4202f, -0.3850f, 0.9263f, + -0.0443f, -0.2323f, 0.5480f, 1.5696f, 0.6193f, + -1.1346f, 1.7878f, -0.5160f, 0.1192f, -2.1572f, + 0.0460f, 1.1202f, -1.4812f, -0.9082f, 0.1728f, + -1.5132f, -0.4489f, 0.3370f, -0.1541f, -0.9266f, + 0.2416f, 0.9270f, -1.1146f, 1.8758f, -0.4312f, + 1.3714f, 1.2106f, -0.4272f, -0.8529f, 1.0328f, + 1.8441f, 1.7698f, -0.7620f, 0.2168f, 0.1322f, + -0.2802f, 0.1460f, 2.1002f, 0.8437f, -0.1534f, + 0.4321f, 0.8360f, 0.5955f, -1.5452f, -0.0491f, + -0.8794f, 0.2418f, -1.4203f, 0.3635f, 0.2362f, + 0.3672f, -0.1128f, -0.8664f, -0.6354f, -1.4409f, + -0.3413f, -0.2409f, -0.3188f, 1.1054f, 0.4265f, + 0.5867f, -1.3279f, 0.3201f, 0.0125f, 1.8157f, + 1.0745f, 0.7372f, -0.2429f, 0.7100f, -0.4299f, + -0.2304f, 0.1645f, 0.9489f, -0.1816f, -0.5968f, + 1.0394f, 0.0204f, 1.1786f, -0.3315f, -0.3997f, + -0.9304f, -1.4268f, -1.1526f, -0.1132f, 0.1490f, + 1.3967f, -1.4634f, -0.1412f, -0.6339f, -1.5995f, + -0.1366f, 0.7604f, 0.1514f, 0.0824f, -1.1830f, + -1.6572f, 2.0099f, -0.9108f, -0.2256f, 0.4527f, + -1.8254f, 0.6475f, 0.8964f, 0.5717f, -0.2390f, + 0.6983f, -1.3416f, 0.2715f, -0.2852f, 0.6051f, + 0.2167f, -0.2181f, -1.6306f, 1.4788f, 0.2754f, + -0.0261f, -0.4618f, -0.5646f, -1.0389f, 0.5819f, + 1.3697f, 0.0002f, 1.5333f, -1.0556f, -0.1254f, + 0.1527f, -1.4979f, -1.1358f, 1.6320f, 0.2493f, + 0.8266f, 0.3424f, -0.4992f, 0.2964f, 0.7298f, + 1.8544f, 0.3516f, 0.0454f, 1.5415f, -0.2822f, + -2.0774f, 1.2323f, 0.3963f, -1.1503f, -0.4775f, + -1.9287f, -1.6164f, 0.3998f, 0.9020f, -0.0764f, + -1.8059f, -0.5762f, -1.4362f, -0.2706f, -1.0183f, + -0.4620f, 2.0891f, 0.1782f, 1.1591f, -0.8151f, + 1.3000f, -1.2464f, -0.5099f, 0.5098f, -3.3525f, + 0.4326f, 0.7414f, -0.7775f, -0.4271f, -0.3807f, + 1.3245f, 2.4936f, 0.3139f, 1.0095f, 0.2323f, + 0.8450f, -1.2244f, -0.4511f, 0.6266f, 0.9095f, + -1.7981f, 1.5241f, -0.4121f, 0.2341f, -0.4737f, + -1.3333f, -1.6150f, 0.4164f, 0.7100f, -0.2429f, + -0.5656f, 0.0863f, 0.0352f, -0.7227f, -1.3613f, + -0.0988f, -1.9114f, -0.3009f, 0.1435f, 0.7029f, + -0.3467f, 0.5092f, -0.0828f, 0.6253f, 0.7113f, + -1.2138f, 1.5964f, -0.8346f, -1.1515f, -0.7923f, + -0.8254f, -3.0038f, 2.4033f, -0.3398f, 0.0922f, + 1.7053f, 1.1114f, 0.7462f, 2.3660f, -0.8409f, + -0.6654f, -0.6530f, -0.7899f, -1.0957f, -0.7149f, + -0.1072f, -0.1967f, -2.3416f, -1.2609f, -1.6375f, + -0.3576f, 0.9413f, -0.5694f, 0.3954f, 0.1383f, + -0.7477f, -0.8689f, 1.8286f, 0.8510f, -1.4793f, + -0.1597f, 0.8541f, 0.2380f, 1.4392f, -0.5644f, + 0.3158f, -1.0686f, -0.1313f, -0.0181f, 0.2438f, + -0.8801f, 0.1413f, -0.3587f, 0.8002f, -0.5982f, + -1.4301f, -0.6620f, 0.7324f, -0.7250f, 0.0610f, + 0.9293f, -0.6902f, -0.0125f, -0.2089f, -0.1664f, + 0.5428f, 0.4245f, -0.7901f, 0.5665f, 0.9044f, + 0.1948f, -0.1723f, 1.2705f, 1.0303f, 1.2202f, + 1.3762f, -0.2959f, 0.7237f, -1.2077f, 0.7937f, + -0.6705f, 0.9287f, 1.0583f, 0.0496f, -1.3118f, + 0.5556f, 0.0459f, -0.1324f, -0.5513f, -0.7409f, + -1.8002f, 0.9892f, 0.3619f, -1.4522f}; + + RunTests(input_data, + position_ids, + cos_cache, + sin_cache, + output_data, + batch_size, + sequence_length, + head_size, + num_heads, + max_sequence_length, + interleaved); +} + +// Interleaved = false, pos ids shape = (batch_size, sequence_length) +TEST(RotaryEmbeddingTest, RotaryEmbedding_NotInterleaved_SmallData_LlamaMSFT) { + int batch_size = 1; + int sequence_length = 2; + int num_heads = 3; + int head_size = 6; + int max_sequence_length = 4; + int64_t interleaved = 0; // false + + std::vector input_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, 1.1676f, 1.0076f, -0.7529f, + -0.2250f, -0.4327f, -1.5071f, -0.4586f, -0.8663f, -0.2656f, 0.1665f, 0.7911f, + -0.9320f, -0.8579f, -1.0574f, -0.1188f, -0.9078f, 0.3452f, -0.5713f, -0.2351f, + -0.8480f, 0.5266f, -1.2944f, -0.0243f, -0.2354f, -0.7087f, -0.9647f, -0.0991f, + -0.2994f, -0.0650f, -1.5720f, -1.3211f}; + + std::vector position_ids = {0, 1}; + + std::vector cos_cache = { + 1.0000f, 1.0000f, 1.0000f, 0.5403f, 0.9989f, 1.0000f, -0.4161f, 0.9957f, + 1.0000f, -0.9900f, 0.9903f, 1.0000f}; + + std::vector sin_cache = { + 0.0000f, 0.0000f, 0.0000f, 0.8415f, 0.0464f, 0.0022f, 0.9093f, 0.0927f, 0.0043f, + 0.1411f, 0.1388f, 0.0065f}; + + std::vector output_data = { + -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, 1.1676f, 1.0076f, -0.7529f, + -0.2250f, -0.4327f, -1.5071f, -0.4586f, -0.8663f, -0.2656f, 0.1665f, 0.7911f, + -0.9320f, -0.8579f, -0.8618f, -0.0922f, -0.9073f, -0.7032f, -0.5762f, -0.2371f, + -0.4377f, 0.5370f, -1.2929f, -0.7267f, -0.2107f, -0.7115f, -0.4666f, -0.0261f, + -0.2965f, -0.8469f, -1.5749f, -1.3217f}; + + RunTests(input_data, + position_ids, + cos_cache, + sin_cache, + output_data, + batch_size, + sequence_length, + head_size, + num_heads, + max_sequence_length, + interleaved); +} + +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/framework/function_test.cc b/onnxruntime/test/framework/function_test.cc index 6e745776ab6b0..41274ee0dedfa 100644 --- a/onnxruntime/test/framework/function_test.cc +++ b/onnxruntime/test/framework/function_test.cc @@ -6,36 +6,47 @@ #include "onnx/defs/parser.h" #include "core/common/span_utils.h" -#include "core/framework/float8.h" +#include "core/framework/customregistry.h" +#include "core/framework/op_kernel.h" #include "core/graph/model.h" #include "core/providers/cpu/cpu_execution_provider.h" #include "core/session/inference_session.h" #include "test/test_environment.h" #include "test/framework/test_utils.h" +#include "inference_session_wrapper.h" #include "test/common/tensor_op_test_utils.h" #include "test/util/include/asserts.h" +#include "test/providers/internal_testing/internal_testing_execution_provider.h" + // Unit tests to check the implementation of functions, model-local functions, // function-inlining etc. namespace onnxruntime { namespace test { -static void Check(const char* source, - const char* input_name, std::vector input_values, - const char* output_name, std::vector output_values) { - // Convert source-representation of model to ModelProto: +// Convert source-representation of model to ModelProto: +static void ParseOnnxSource(const char* source, std::string& result) { ONNX_NAMESPACE::OnnxParser parser(source); ONNX_NAMESPACE::ModelProto model; auto parse_status = parser.Parse(model); ASSERT_TRUE(parse_status.IsOK()) << parse_status.ErrorMessage(); ASSERT_TRUE(parser.EndOfInput()) << "Extra unparsed input unexpected."; - // Serialize and then load model: + // Serialize std::string serialized_model; const bool serialization_status = model.SerializeToString(&serialized_model); ASSERT_TRUE(serialization_status) << "Failed to serialize proto to string"; + result = std::move(serialized_model); +} + +static void Check(const char* source, + const char* input_name, std::vector input_values, + const char* output_name, std::vector output_values) { + // Serialize and then load model: + std::string serialized_model; + ParseOnnxSource(source, serialized_model); SessionOptions session_options; InferenceSession session_object{session_options, GetEnvironment()}; @@ -76,8 +87,8 @@ static void Check(const char* source, } } -TEST(FunctionTest, Basic) { - const char* code = R"( +namespace { +const char* basic_code = R"( < ir_version: 8, opset_import: [ "" : 16, "local" : 1 ] @@ -96,8 +107,10 @@ TEST(FunctionTest, Basic) { ly = Mul (lx, two) } )"; +} - Check(code, "x", {1.0, 2.0, 3.0}, "y", {2.0, 4.0, 6.0}); +TEST(FunctionTest, Basic) { + Check(basic_code, "x", {1.0, 2.0, 3.0}, "y", {2.0, 4.0, 6.0}); } // Check that variables are renamed to avoid conflicts when multiple @@ -521,5 +534,56 @@ TEST(FunctionTest, ConstantFoldingInSubGraph) { Check(code, "X", {1.0, 2.0, 3.0}, "Y", {3.0, 4.0, 5.0, 3.0, 4.0, 5.0, 3.0, 4.0, 5.0}); } +TEST(FunctionTest, TestInlinedLocalFunctionRemoved) { + std::string serialized_model; + ParseOnnxSource(basic_code, serialized_model); + + // Default is to do AOT Function inlining + SessionOptions session_options; + InferenceSessionWrapper session_object{session_options, GetEnvironment()}; + + std::stringstream sstr(serialized_model); + auto status = session_object.Load(sstr); + ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); + + auto model_proto = session_object.GetModel().ToProto(); + ASSERT_EQ(1, model_proto.functions_size()); + + status = session_object.Initialize(); + ASSERT_TRUE(status.IsOK()) << status.ErrorMessage(); + + // All functions removed + model_proto = session_object.GetModel().ToProto(); + ASSERT_EQ(0, model_proto.functions_size()); +} + +TEST(FunctionTest, TestInlinedLocalFunctionNotRemoved) { + std::string serialized_model; + ParseOnnxSource(basic_code, serialized_model); + + // Default is to do AOT Function inlining + SessionOptions session_options; + InferenceSessionWrapper session_object{session_options, GetEnvironment()}; + + using InternalTestingEP = onnxruntime::internal_testing_ep::InternalTestingExecutionProvider; + const std::unordered_set empty_set; + auto internal_testing_ep = std::make_unique(empty_set, empty_set, DataLayout::NCHW); + internal_testing_ep->EnableStaticKernels().TakeAllNodes(); + + ASSERT_STATUS_OK(session_object.RegisterExecutionProvider(std::move(internal_testing_ep))); + + std::stringstream sstr(serialized_model); + ASSERT_STATUS_OK(session_object.Load(sstr)); + + auto model_proto = session_object.GetModel().ToProto(); + ASSERT_EQ(1, model_proto.functions_size()); + + ASSERT_STATUS_OK(session_object.Initialize()); + + // myfun is not removed because it was claimed by InternalTestingEP + model_proto = session_object.GetModel().ToProto(); + ASSERT_EQ(1, model_proto.functions_size()); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/onnx/TestCase.cc b/onnxruntime/test/onnx/TestCase.cc index bc88f69fa990f..47c3798721679 100644 --- a/onnxruntime/test/onnx/TestCase.cc +++ b/onnxruntime/test/onnx/TestCase.cc @@ -944,6 +944,20 @@ std::unique_ptr> GetBrokenTests(const std::string& provider {"simple_rnn_batchwise", "type error", {}}, {"mod_float_mixed_sign_example", "fmod attribute must be true for floating point types", {}}, {"col2im_pads", "result mismatch", {"opset18"}}, + {"gridsample_volumetric_nearest_align_corners_0", "result differs", {}}, + {"gridsample_volumetric_nearest_align_corners_1", "result differs", {}}, + {"reduce_l1_empty_set", "unknown version", {}}, + {"reduce_l1_empty_set_expanded", "unknown version", {}}, + {"reduce_l2_empty_set", "unknown version", {}}, + {"reduce_l2_empty_set_expanded", "unknown version", {}}, + {"reduce_log_sum_empty_set", "unknown version", {}}, + {"reduce_log_sum_empty_set_expanded", "unknown version", {}}, + {"reduce_log_sum_exp_empty_set", "unknown version", {}}, + {"reduce_log_sum_exp_empty_set_expanded", "unknown version", {}}, + {"reduce_prod_empty_set", "unknown version", {}}, + {"reduce_sum_empty_set", "unknown version", {}}, + {"reduce_sum_square_empty_set", "unknown version", {}}, + {"reduce_sum_square_empty_set_expanded", "unknown version", {}}, #ifdef ENABLE_TRAINING_CORE {"adagrad", "not a registered function/op", {}}, // Op not registered. {"adagrad_multiple", "not a registered function/op", {}}, // Op not registered. @@ -1339,6 +1353,7 @@ std::unique_ptr> GetBrokenTests(const std::string& provider broken_tests->insert({"gridsample_reflection_padding", "result differs"}); broken_tests->insert({"spacetodepth", "result differs"}); } + #ifdef DISABLE_CONTRIB_OPS broken_tests->insert({"coreml_SqueezeNet_ImageNet", "This model uses contrib ops."}); broken_tests->insert({"keras2coreml_Permute_ImageNet", "This model uses contrib ops."}); diff --git a/onnxruntime/test/optimizer/graph_transform_test.cc b/onnxruntime/test/optimizer/graph_transform_test.cc index 6acf631d53cd9..46b95a127b75c 100755 --- a/onnxruntime/test/optimizer/graph_transform_test.cc +++ b/onnxruntime/test/optimizer/graph_transform_test.cc @@ -31,6 +31,7 @@ #include "core/optimizer/conv_add_act_fusion.h" #include "core/optimizer/conv_add_fusion.h" #include "core/optimizer/conv_bn_fusion.h" +#include "core/optimizer/matmul_bn_fusion.h" #include "core/optimizer/conv_mul_fusion.h" #include "core/optimizer/div_mul_fusion.h" #include "core/optimizer/dropout_elimination.h" @@ -1079,6 +1080,268 @@ TEST_F(GraphTransformationTests, FuseConvBNNoBias) { } } +TEST_F(GraphTransformationTests, FuseMatmulBNWithInBetweenNodes) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-with-reshape.onnx"; + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + std::string expected_output_name; + GraphViewer graphViewer(graph); + for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) { + auto& node = *graph.GetNode(node_index); + if (node.OpType() == "MatMul") { + expected_output_name = node.OutputDefs()[0]->Name(); + } + } + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = std::make_unique("RuleTransformerL1"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["BatchNormalization"], 0); + ASSERT_EQ(op_to_count["MatMul"], 0); + ASSERT_EQ(op_to_count["Gemm"], 1); + + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Gemm") { + ASSERT_EQ(node.OutputDefs()[0]->Name(), expected_output_name) + << "fusion should produce the same output name as the MatMul node"; + } + } +} + +TEST_F(GraphTransformationTests, FuseMatmulBNWithEmptyOptionalOutputWithInBetweenNodes) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-with-reshape.onnx"; + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + std::string expected_output_name; + GraphViewer graphViewer(graph); + for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) { + auto& node = *graph.GetNode(node_index); + if (node.OpType() == "MatMul") { + expected_output_name = node.OutputDefs()[0]->Name(); + } else if (node.OpType() == "BatchNormalization") { + node.MutableOutputDefs().push_back(&graph.GetOrCreateNodeArg("", nullptr)); + } + } + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = std::make_unique("RuleTransformerL1"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["BatchNormalization"], 0); + ASSERT_EQ(op_to_count["MatMul"], 0); + ASSERT_EQ(op_to_count["Gemm"], 1); + + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Gemm") { + ASSERT_EQ(node.OutputDefs()[0]->Name(), expected_output_name) + << "fusion should produce the same output name as the MatMul node"; + } + } +} + +// should not fuse +TEST_F(GraphTransformationTests, FuseMatmulBNWithOptionalOutputWithInBetweenNodes) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-with-reshape.onnx"; + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + GraphViewer graphViewer(graph); + for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) { + auto& node = *graph.GetNode(node_index); + if (node.OpType() == "BatchNormalization") { + // additional non-empty output to batchNormalization + ONNX_NAMESPACE::TypeProto optional_output_tensor_type; + optional_output_tensor_type.mutable_tensor_type()->set_elem_type(ONNX_NAMESPACE::TypeProto::kTensorType); + auto& arg = graph.GetOrCreateNodeArg("bn_optional_output", &optional_output_tensor_type); + node.MutableOutputDefs().push_back(&arg); + } + } + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = std::make_unique("RuleTransformerL1"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["BatchNormalization"], 1); + ASSERT_EQ(op_to_count["MatMul"], 1); + ASSERT_EQ(op_to_count["Gemm"], 0); +} + +TEST_F(GraphTransformationTests, FuseMatmulBNDirectly) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-directly.onnx"; + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + std::string expected_output_name; + GraphViewer graphViewer(graph); + for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) { + auto& node = *graph.GetNode(node_index); + if (node.OpType() == "BatchNormalization") { + expected_output_name = node.OutputDefs()[0]->Name(); + } + } + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = std::make_unique("RuleTransformerL1"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["BatchNormalization"], 0); + ASSERT_EQ(op_to_count["MatMul"], 0); + ASSERT_EQ(op_to_count["Gemm"], 1); + + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Gemm") { + ASSERT_EQ(node.OutputDefs()[0]->Name(), expected_output_name) + << "fusion should produce the same output name as the last node"; + } + } +} + +TEST_F(GraphTransformationTests, FuseMatmulBNWithOnlyReshape) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-only-reshape.onnx"; + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + std::string expected_output_name; + GraphViewer graphViewer(graph); + for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) { + auto& node = *graph.GetNode(node_index); + if (node.OpType() == "MatMul") { + expected_output_name = node.OutputDefs()[0]->Name(); + } + } + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = std::make_unique("RuleTransformerL1"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["BatchNormalization"], 0); + ASSERT_EQ(op_to_count["MatMul"], 0); + ASSERT_EQ(op_to_count["Gemm"], 1); + + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Gemm") { + ASSERT_EQ(node.OutputDefs()[0]->Name(), expected_output_name) + << "fusion should produce the same output name as the MatMul node"; + } + } +} + +TEST_F(GraphTransformationTests, FuseMatmulBNWithOnlyTranspose) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-only-transpose.onnx"; + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + std::string expected_output_name; + GraphViewer graphViewer(graph); + for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) { + auto& node = *graph.GetNode(node_index); + if (node.OpType() == "MatMul") { + expected_output_name = node.OutputDefs()[0]->Name(); + } + } + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = std::make_unique("RuleTransformerL1"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["BatchNormalization"], 0); + ASSERT_EQ(op_to_count["MatMul"], 0); + ASSERT_EQ(op_to_count["Gemm"], 1); + + for (auto& node : graph.Nodes()) { + if (node.OpType() == "Gemm") { + ASSERT_EQ(node.OutputDefs()[0]->Name(), expected_output_name) + << "fusion should produce the same output name as the MatMul node"; + } + } +} + +TEST_F(GraphTransformationTests, FuseMatmulBNWithoutBatchNormalization) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-only-transpose.onnx"; + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + GraphViewer graphViewer(graph); + for (auto& node_index : graphViewer.GetNodesInTopologicalOrder()) { + auto& node = *graph.GetNode(node_index); + if (node.OpType() == "BatchNormalization") { + graph_utils::RemoveNode(graph, node); + } + } + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = std::make_unique("RuleTransformerL1"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["MatMul"], 1); +} + +// should not fuse +TEST_F(GraphTransformationTests, FuseMatmulBNWithNonIgnorableNode) { + constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-matmul-bn-non-ignorable-node.onnx"; + + std::shared_ptr p_model; + ASSERT_STATUS_OK(Model::Load(model_uri, p_model, nullptr, *logger_)); + Graph& graph = p_model->MainGraph(); + + onnxruntime::GraphTransformerManager graph_transformation_mgr{5}; + auto rule_transformer_L1 = std::make_unique("RuleTransformerL1"); + ASSERT_STATUS_OK(rule_transformer_L1->Register(std::make_unique())); + ASSERT_STATUS_OK(graph_transformation_mgr.Register(std::move(rule_transformer_L1), TransformerLevel::Level1)); + + ASSERT_STATUS_OK(graph_transformation_mgr.ApplyTransformers(graph, TransformerLevel::Level1, *logger_)); + + std::map op_to_count = CountOpsInGraph(graph); + ASSERT_EQ(op_to_count["BatchNormalization"], 1); + ASSERT_EQ(op_to_count["MatMul"], 1); + ASSERT_EQ(op_to_count["Gemm"], 0); +} + TEST_F(GraphTransformationTests, DontFuseConvWithBNWithOptionalOutputs) { constexpr const ORTCHAR_T* model_uri = MODEL_FOLDER "fusion/fuse-conv-bn-no-bias.onnx"; diff --git a/onnxruntime/test/providers/cpu/tensor/affine_grid_test.cc b/onnxruntime/test/providers/cpu/tensor/affine_grid_test.cc new file mode 100644 index 0000000000000..e37e784f28930 --- /dev/null +++ b/onnxruntime/test/providers/cpu/tensor/affine_grid_test.cc @@ -0,0 +1,165 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/util/math.h" +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { +TEST(AffineGridTest, 2d) { + OpTester test("AffineGrid", 20); + test.AddInput("theta", {1, 2, 3}, {1.0f, 0.0, 0.0f, 0.0f, 1.0, 0.0f}); + test.AddInput("size", {4}, {1, 1, 2, 3}); + test.AddOutput("grid", {1, 2, 3, 2}, + {-0.6667f, -0.5000f, 0.0000f, -0.5000f, 0.6667f, -0.5000f, -0.6667f, 0.5000f, 0.0000f, 0.5000f, 0.6667f, 0.5000f}); + test.Run(); +} + +// following tests code is generated with: +// python onnxruntime/test/providers/cpu/tensor/affine_grid_test_gen.py +TEST(AffineGridTest, test_2d_0) { + OpTester test("AffineGrid", 20); + test.AddAttribute("align_corners", (int64_t)0); + test.AddInput("theta", {1, 2, 3}, {1.477212f, -0.173648f, 0.300000f, 0.173648f, 0.492404f, -0.500000f}); + test.AddInput("size", {4}, {1, 1, 3, 2}); + test.AddOutput("grid", {1, 3, 2, 2}, {-0.3228f, -0.9151f, 1.1544f, -0.7414f, -0.4386f, -0.5868f, 1.0386f, -0.4132f, -0.5544f, -0.2586f, 0.9228f, -0.0849f}); + test.Run(); +} + +TEST(AffineGridTest, test_2d_1) { + OpTester test("AffineGrid", 20); + test.AddAttribute("align_corners", (int64_t)0); + test.AddInput("theta", {2, 2, 3}, {1.477212f, -0.173648f, 0.300000f, 0.173648f, 0.492404f, -0.500000f, 1.477212f, -0.173648f, 0.300000f, 0.173648f, 0.492404f, -0.500000f}); + test.AddInput("size", {4}, {2, 10, 2, 3}); + test.AddOutput("grid", {2, 2, 3, 2}, {-0.5980f, -0.8620f, 0.3868f, -0.7462f, 1.3716f, -0.6304f, -0.7716f, -0.3696f, 0.2132f, -0.2538f, 1.1980f, -0.1380f, -0.5980f, -0.8620f, 0.3868f, -0.7462f, 1.3716f, -0.6304f, -0.7716f, -0.3696f, 0.2132f, -0.2538f, 1.1980f, -0.1380f}); + test.Run(); +} + +TEST(AffineGridTest, test_2d_2) { + OpTester test("AffineGrid", 20); + test.AddAttribute("align_corners", (int64_t)0); + test.AddInput("theta", {1, 2, 3}, {1.500000f, -0.866025f, -0.500000f, 0.866025f, 2.750000f, -0.500000f}); + test.AddInput("size", {4}, {1, 1, 3, 2}); + test.AddOutput("grid", {1, 3, 2, 2}, {-0.6726f, -2.7663f, 0.8274f, -1.9003f, -1.2500f, -0.9330f, 0.2500f, -0.0670f, -1.8274f, 0.9003f, -0.3274f, 1.7663f}); + test.Run(); +} + +TEST(AffineGridTest, test_2d_3) { + OpTester test("AffineGrid", 20); + test.AddAttribute("align_corners", (int64_t)0); + test.AddInput("theta", {2, 2, 3}, {1.500000f, -0.866025f, -0.500000f, 0.866025f, 2.750000f, -0.500000f, 1.500000f, -0.866025f, -0.500000f, 0.866025f, 2.750000f, -0.500000f}); + test.AddInput("size", {4}, {2, 10, 2, 3}); + test.AddOutput("grid", {2, 2, 3, 2}, {-1.0670f, -2.4524f, -0.0670f, -1.8750f, 0.9330f, -1.2976f, -1.9330f, 0.2976f, -0.9330f, 0.8750f, 0.0670f, 1.4524f, -1.0670f, -2.4524f, -0.0670f, -1.8750f, 0.9330f, -1.2976f, -1.9330f, 0.2976f, -0.9330f, 0.8750f, 0.0670f, 1.4524f}); + test.Run(); +} + +TEST(AffineGridTest, test_2d_4) { + OpTester test("AffineGrid", 20); + test.AddAttribute("align_corners", (int64_t)1); + test.AddInput("theta", {1, 2, 3}, {1.477212f, -0.173648f, 0.300000f, 0.173648f, 0.492404f, -0.500000f}); + test.AddInput("size", {4}, {1, 1, 3, 2}); + test.AddOutput("grid", {1, 3, 2, 2}, {-1.0036f, -1.1661f, 1.9509f, -0.8188f, -1.1772f, -0.6736f, 1.7772f, -0.3264f, -1.3509f, -0.1812f, 1.6036f, 0.1661f}); + test.Run(); +} + +TEST(AffineGridTest, test_2d_5) { + OpTester test("AffineGrid", 20); + test.AddAttribute("align_corners", (int64_t)1); + test.AddInput("theta", {2, 2, 3}, {1.477212f, -0.173648f, 0.300000f, 0.173648f, 0.492404f, -0.500000f, 1.477212f, -0.173648f, 0.300000f, 0.173648f, 0.492404f, -0.500000f}); + test.AddInput("size", {4}, {2, 10, 2, 3}); + test.AddOutput("grid", {2, 2, 3, 2}, {-1.0036f, -1.1661f, 0.4736f, -0.9924f, 1.9509f, -0.8188f, -1.3509f, -0.1812f, 0.1264f, -0.0076f, 1.6036f, 0.1661f, -1.0036f, -1.1661f, 0.4736f, -0.9924f, 1.9509f, -0.8188f, -1.3509f, -0.1812f, 0.1264f, -0.0076f, 1.6036f, 0.1661f}); + test.Run(); +} + +TEST(AffineGridTest, test_2d_6) { + OpTester test("AffineGrid", 20); + test.AddAttribute("align_corners", (int64_t)1); + test.AddInput("theta", {1, 2, 3}, {1.500000f, -0.866025f, -0.500000f, 0.866025f, 2.750000f, -0.500000f}); + test.AddInput("size", {4}, {1, 1, 3, 2}); + test.AddOutput("grid", {1, 3, 2, 2}, {-1.1340f, -4.1160f, 1.8660f, -2.3840f, -2.0000f, -1.3660f, 1.0000f, 0.3660f, -2.8660f, 1.3840f, 0.1340f, 3.1160f}); + test.Run(); +} + +TEST(AffineGridTest, test_2d_7) { + OpTester test("AffineGrid", 20); + test.AddAttribute("align_corners", (int64_t)1); + test.AddInput("theta", {2, 2, 3}, {1.500000f, -0.866025f, -0.500000f, 0.866025f, 2.750000f, -0.500000f, 1.500000f, -0.866025f, -0.500000f, 0.866025f, 2.750000f, -0.500000f}); + test.AddInput("size", {4}, {2, 10, 2, 3}); + test.AddOutput("grid", {2, 2, 3, 2}, {-1.1340f, -4.1160f, 0.3660f, -3.2500f, 1.8660f, -2.3840f, -2.8660f, 1.3840f, -1.3660f, 2.2500f, 0.1340f, 3.1160f, -1.1340f, -4.1160f, 0.3660f, -3.2500f, 1.8660f, -2.3840f, -2.8660f, 1.3840f, -1.3660f, 2.2500f, 0.1340f, 3.1160f}); + test.Run(); +} + +TEST(AffineGridTest, test_3d_0) { + OpTester test("AffineGrid", 20); + test.AddAttribute("align_corners", (int64_t)0); + test.AddInput("theta", {1, 3, 4}, {1.409539f, 0.000000f, 0.513030f, 0.300000f, 0.118782f, 1.969615f, -0.326352f, -0.500000f, -0.168412f, 0.086824f, 0.462708f, 1.800000f}); + test.AddInput("size", {5}, {1, 1, 3, 2, 2}); + test.AddOutput("grid", {1, 3, 2, 2, 3}, {-0.7468f, -1.3266f, 1.5323f, 0.6627f, -1.2078f, 1.3639f, -0.7468f, 0.6430f, 1.6191f, 0.6627f, 0.7618f, 1.4507f, -0.4048f, -1.5442f, 1.8408f, 1.0048f, -1.4254f, 1.6724f, -0.4048f, 0.4254f, 1.9276f, 1.0048f, 0.5442f, 1.7592f, -0.0627f, -1.7618f, 2.1493f, 1.3468f, -1.6430f, 1.9809f, -0.0627f, 0.2078f, 2.2361f, 1.3468f, 0.3266f, 2.0677f}); + test.Run(); +} + +TEST(AffineGridTest, test_3d_1) { + OpTester test("AffineGrid", 20); + test.AddAttribute("align_corners", (int64_t)0); + test.AddInput("theta", {2, 3, 4}, {1.409539f, 0.000000f, 0.513030f, 0.300000f, 0.118782f, 1.969615f, -0.326352f, -0.500000f, -0.168412f, 0.086824f, 0.462708f, 1.800000f, 1.409539f, 0.000000f, 0.513030f, 0.300000f, 0.118782f, 1.969615f, -0.326352f, -0.500000f, -0.168412f, 0.086824f, 0.462708f, 1.800000f}); + test.AddInput("size", {5}, {2, 10, 2, 2, 3}); + test.AddOutput("grid", {2, 2, 2, 3, 3}, {-0.8962f, -1.4008f, 1.6375f, 0.0435f, -1.3216f, 1.5252f, 0.9832f, -1.2424f, 1.4130f, -0.8962f, 0.5688f, 1.7243f, 0.0435f, 0.6480f, 1.6121f, 0.9832f, 0.7272f, 1.4998f, -0.3832f, -1.7272f, 2.1002f, 0.5565f, -1.6480f, 1.9879f, 1.4962f, -1.5688f, 1.8757f, -0.3832f, 0.2424f, 2.1870f, 0.5565f, 0.3216f, 2.0748f, 1.4962f, 0.4008f, 1.9625f, -0.8962f, -1.4008f, 1.6375f, 0.0435f, -1.3216f, 1.5252f, 0.9832f, -1.2424f, 1.4130f, -0.8962f, 0.5688f, 1.7243f, 0.0435f, 0.6480f, 1.6121f, 0.9832f, 0.7272f, 1.4998f, -0.3832f, -1.7272f, 2.1002f, 0.5565f, -1.6480f, 1.9879f, 1.4962f, -1.5688f, 1.8757f, -0.3832f, 0.2424f, 2.1870f, 0.5565f, 0.3216f, 2.0748f, 1.4962f, 0.4008f, 1.9625f}); + test.Run(); +} + +TEST(AffineGridTest, test_3d_2) { + OpTester test("AffineGrid", 20); + test.AddAttribute("align_corners", (int64_t)0); + test.AddInput("theta", {1, 3, 4}, {0.259808f, 0.000000f, -0.150000f, -0.500000f, -1.299038f, 1.500000f, -2.250000f, -0.500000f, 1.375000f, 4.763140f, 2.381570f, 0.300000f}); + test.AddInput("size", {5}, {1, 1, 3, 2, 2}); + test.AddOutput("grid", {1, 3, 2, 2, 3}, {-0.5299f, 0.8995f, -4.3568f, -0.2701f, -0.3995f, -2.9818f, -0.5299f, 2.3995f, 0.4064f, -0.2701f, 1.1005f, 1.7814f, -0.6299f, -0.6005f, -2.7691f, -0.3701f, -1.8995f, -1.3941f, -0.6299f, 0.8995f, 1.9941f, -0.3701f, -0.3995f, 3.3691f, -0.7299f, -2.1005f, -1.1814f, -0.4701f, -3.3995f, 0.1936f, -0.7299f, -0.6005f, 3.5818f, -0.4701f, -1.8995f, 4.9568f}); + test.Run(); +} + +TEST(AffineGridTest, test_3d_3) { + OpTester test("AffineGrid", 20); + test.AddAttribute("align_corners", (int64_t)0); + test.AddInput("theta", {2, 3, 4}, {0.259808f, 0.000000f, -0.150000f, -0.500000f, -1.299038f, 1.500000f, -2.250000f, -0.500000f, 1.375000f, 4.763140f, 2.381570f, 0.300000f, 0.259808f, 0.000000f, -0.150000f, -0.500000f, -1.299038f, 1.500000f, -2.250000f, -0.500000f, 1.375000f, 4.763140f, 2.381570f, 0.300000f}); + test.AddInput("size", {5}, {2, 10, 2, 2, 3}); + test.AddOutput("grid", {2, 2, 2, 3, 3}, {-0.5982f, 0.7410f, -4.1890f, -0.4250f, -0.1250f, -3.2724f, -0.2518f, -0.9910f, -2.3557f, -0.5982f, 2.2410f, 0.5741f, -0.4250f, 1.3750f, 1.4908f, -0.2518f, 0.5090f, 2.4075f, -0.7482f, -1.5090f, -1.8075f, -0.5750f, -2.3750f, -0.8908f, -0.4018f, -3.2410f, 0.0259f, -0.7482f, -0.0090f, 2.9557f, -0.5750f, -0.8750f, 3.8724f, -0.4018f, -1.7410f, 4.7890f, -0.5982f, 0.7410f, -4.1890f, -0.4250f, -0.1250f, -3.2724f, -0.2518f, -0.9910f, -2.3557f, -0.5982f, 2.2410f, 0.5741f, -0.4250f, 1.3750f, 1.4908f, -0.2518f, 0.5090f, 2.4075f, -0.7482f, -1.5090f, -1.8075f, -0.5750f, -2.3750f, -0.8908f, -0.4018f, -3.2410f, 0.0259f, -0.7482f, -0.0090f, 2.9557f, -0.5750f, -0.8750f, 3.8724f, -0.4018f, -1.7410f, 4.7890f}); + test.Run(); +} + +TEST(AffineGridTest, test_3d_4) { + OpTester test("AffineGrid", 20); + test.AddAttribute("align_corners", (int64_t)1); + test.AddInput("theta", {1, 3, 4}, {1.409539f, 0.000000f, 0.513030f, 0.300000f, 0.118782f, 1.969615f, -0.326352f, -0.500000f, -0.168412f, 0.086824f, 0.462708f, 1.800000f}); + test.AddInput("size", {5}, {1, 1, 3, 2, 2}); + test.AddOutput("grid", {1, 3, 2, 2, 3}, {-1.6226f, -2.2620f, 1.4189f, 1.1965f, -2.0245f, 1.0821f, -1.6226f, 1.6772f, 1.5925f, 1.1965f, 1.9147f, 1.2557f, -1.1095f, -2.5884f, 1.8816f, 1.7095f, -2.3508f, 1.5448f, -1.1095f, 1.3508f, 2.0552f, 1.7095f, 1.5884f, 1.7184f, -0.5965f, -2.9147f, 2.3443f, 2.2226f, -2.6772f, 2.0075f, -0.5965f, 1.0245f, 2.5179f, 2.2226f, 1.2620f, 2.1811f}); + test.Run(); +} + +TEST(AffineGridTest, test_3d_5) { + OpTester test("AffineGrid", 20); + test.AddAttribute("align_corners", (int64_t)1); + test.AddInput("theta", {2, 3, 4}, {1.409539f, 0.000000f, 0.513030f, 0.300000f, 0.118782f, 1.969615f, -0.326352f, -0.500000f, -0.168412f, 0.086824f, 0.462708f, 1.800000f, 1.409539f, 0.000000f, 0.513030f, 0.300000f, 0.118782f, 1.969615f, -0.326352f, -0.500000f, -0.168412f, 0.086824f, 0.462708f, 1.800000f}); + test.AddInput("size", {5}, {2, 10, 2, 2, 3}); + test.AddOutput("grid", {2, 2, 2, 3, 3}, {-1.6226f, -2.2620f, 1.4189f, -0.2130f, -2.1433f, 1.2505f, 1.1965f, -2.0245f, 1.0821f, -1.6226f, 1.6772f, 1.5925f, -0.2130f, 1.7960f, 1.4241f, 1.1965f, 1.9147f, 1.2557f, -0.5965f, -2.9147f, 2.3443f, 0.8130f, -2.7960f, 2.1759f, 2.2226f, -2.6772f, 2.0075f, -0.5965f, 1.0245f, 2.5179f, 0.8130f, 1.1433f, 2.3495f, 2.2226f, 1.2620f, 2.1811f, -1.6226f, -2.2620f, 1.4189f, -0.2130f, -2.1433f, 1.2505f, 1.1965f, -2.0245f, 1.0821f, -1.6226f, 1.6772f, 1.5925f, -0.2130f, 1.7960f, 1.4241f, 1.1965f, 1.9147f, 1.2557f, -0.5965f, -2.9147f, 2.3443f, 0.8130f, -2.7960f, 2.1759f, 2.2226f, -2.6772f, 2.0075f, -0.5965f, 1.0245f, 2.5179f, 0.8130f, 1.1433f, 2.3495f, 2.2226f, 1.2620f, 2.1811f}); + test.Run(); +} + +TEST(AffineGridTest, test_3d_6) { + OpTester test("AffineGrid", 20); + test.AddAttribute("align_corners", (int64_t)1); + test.AddInput("theta", {1, 3, 4}, {0.259808f, 0.000000f, -0.150000f, -0.500000f, -1.299038f, 1.500000f, -2.250000f, -0.500000f, 1.375000f, 4.763140f, 2.381570f, 0.300000f}); + test.AddInput("size", {5}, {1, 1, 3, 2, 2}); + test.AddOutput("grid", {1, 3, 2, 2, 3}, {-0.6098f, 1.5490f, -8.2197f, -0.0902f, -1.0490f, -5.4697f, -0.6098f, 4.5490f, 1.3066f, -0.0902f, 1.9510f, 4.0566f, -0.7598f, -0.7010f, -5.8381f, -0.2402f, -3.2990f, -3.0881f, -0.7598f, 2.2990f, 3.6881f, -0.2402f, -0.2990f, 6.4381f, -0.9098f, -2.9510f, -3.4566f, -0.3902f, -5.5490f, -0.7066f, -0.9098f, 0.0490f, 6.0697f, -0.3902f, -2.5490f, 8.8197f}); + test.Run(); +} + +TEST(AffineGridTest, test_3d_7) { + OpTester test("AffineGrid", 20); + test.AddAttribute("align_corners", (int64_t)1); + test.AddInput("theta", {2, 3, 4}, {0.259808f, 0.000000f, -0.150000f, -0.500000f, -1.299038f, 1.500000f, -2.250000f, -0.500000f, 1.375000f, 4.763140f, 2.381570f, 0.300000f, 0.259808f, 0.000000f, -0.150000f, -0.500000f, -1.299038f, 1.500000f, -2.250000f, -0.500000f, 1.375000f, 4.763140f, 2.381570f, 0.300000f}); + test.AddInput("size", {5}, {2, 10, 2, 2, 3}); + test.AddOutput("grid", {2, 2, 2, 3, 3}, {-0.6098f, 1.5490f, -8.2197f, -0.3500f, 0.2500f, -6.8447f, -0.0902f, -1.0490f, -5.4697f, -0.6098f, 4.5490f, 1.3066f, -0.3500f, 3.2500f, 2.6816f, -0.0902f, 1.9510f, 4.0566f, -0.9098f, -2.9510f, -3.4566f, -0.6500f, -4.2500f, -2.0816f, -0.3902f, -5.5490f, -0.7066f, -0.9098f, 0.0490f, 6.0697f, -0.6500f, -1.2500f, 7.4447f, -0.3902f, -2.5490f, 8.8197f, -0.6098f, 1.5490f, -8.2197f, -0.3500f, 0.2500f, -6.8447f, -0.0902f, -1.0490f, -5.4697f, -0.6098f, 4.5490f, 1.3066f, -0.3500f, 3.2500f, 2.6816f, -0.0902f, 1.9510f, 4.0566f, -0.9098f, -2.9510f, -3.4566f, -0.6500f, -4.2500f, -2.0816f, -0.3902f, -5.5490f, -0.7066f, -0.9098f, 0.0490f, 6.0697f, -0.6500f, -1.2500f, 7.4447f, -0.3902f, -2.5490f, 8.8197f}); + test.Run(); +} +} // namespace test +} // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/tensor/affine_grid_test_gen.py b/onnxruntime/test/providers/cpu/tensor/affine_grid_test_gen.py new file mode 100644 index 0000000000000..22bad6f1be534 --- /dev/null +++ b/onnxruntime/test/providers/cpu/tensor/affine_grid_test_gen.py @@ -0,0 +1,111 @@ +import argparse + +import numpy as np +import torch +from torch.nn.functional import affine_grid + +opset_version = 20 +parser = argparse.ArgumentParser(description="Generate test cases for the AffineGrid operator.") +parser.add_argument("--dim", type=int, choices=[2, 3], help="Dimension of the test cases (2 or 3)") +args = parser.parse_args() + +if args.dim is None or args.dim == 2: + align_corners_options = [False, True] + angles = [10, 60] + translations = [np.array([0.3, -0.5]), np.array([-0.5, -0.5])] + scales = [np.array([1.5, 0.5]), np.array([3.0, 5.5])] + sizes = [[1, 1, 3, 2], [2, 10, 2, 3]] + test_count = 0 + + for align_corners in align_corners_options: + for angle, translation, scale in zip(angles, translations, scales): + for size in sizes: + theta = np.array([], dtype=np.float32) + for _ in range(size[0]): + angle_radian = (angle / 180.0) * np.pi + theta = np.append( + theta, + [ + np.cos(angle_radian) * scale[0], + -np.sin(angle_radian), + translation[0], + np.sin(angle_radian), + np.cos(angle_radian) * scale[1], + translation[1], + ], + ) + theta = theta.reshape(size[0], 2, 3) + theta = torch.Tensor(theta) + grid = affine_grid(theta, size, align_corners=align_corners) + + # Print the C++ code for the test case + print(f"TEST(AffineGridTest, test_2d_{test_count}) {{") + print(f' OpTester test("AffineGrid", {opset_version});') + print(f' test.AddAttribute("align_corners", (int64_t){1 if align_corners else 0});') + print( + f" test.AddInput(\"theta\", {{{theta.shape[0]}, {theta.shape[1]}, {theta.shape[2]}}}, {{{', '.join([f'{x:.6f}f' for x in theta.flatten()])}}});" + ) + print( + f' test.AddInput("size", {{{len(size)}}}, {{{size[0]}, {size[1]}, {size[2]}, {size[3]}}});' + ) + print( + f" test.AddOutput(\"grid\", {{{size[0]}, {size[2]}, {size[3]}, 2}}, {{{', '.join([f'{x:.4f}f' for x in grid.flatten()])}}});" + ) + print(" test.Run();") + print("}\n") + test_count += 1 + + +if args.dim is None or args.dim == 3: + align_corners_options = [False, True] + angles = [[10, 20], [60, -30]] + translations = [np.array([0.3, -0.5, 1.8]), np.array([-0.5, -0.5, 0.3])] + scales = [np.array([1.5, 2.0, 0.5]), np.array([0.3, 3.0, 5.5])] + sizes = [[1, 1, 3, 2, 2], [2, 10, 2, 2, 3]] + test_count = 0 + + for align_corners in align_corners_options: + for angle, translation, scale in zip(angles, translations, scales): + for size in sizes: + theta = np.array([], dtype=np.float32) + for _ in range(size[0]): + angle_radian_x = (angle[0] / 180.0) * np.pi + angle_radian_y = (angle[1] / 180.0) * np.pi + rot_matrix_x = np.array( + [ + [1, 0, 0], + [0, np.cos(angle_radian_x), -np.sin(angle_radian_x)], + [0, np.sin(angle_radian_x), np.cos(angle_radian_x)], + ] + ) + rot_matrix_y = np.array( + [ + [np.cos(angle_radian_y), 0, np.sin(angle_radian_y)], + [0, 1, 0], + [-np.sin(angle_radian_y), 0, np.cos(angle_radian_y)], + ] + ) + rot_matrix = np.matmul(rot_matrix_x, rot_matrix_y) + rot_matrix = rot_matrix * scale.reshape(3, 1) + rot_matrix = np.append(rot_matrix, np.reshape(translation, (3, 1)), axis=1) + theta = np.append(theta, rot_matrix.flatten()) + theta = theta.reshape(size[0], 3, 4) + theta = torch.Tensor(theta) + grid = affine_grid(theta, size, align_corners=align_corners) + + # Print the C++ code for the test case + print(f"TEST(AffineGridTest, test_3d_{test_count}) {{") + print(f' OpTester test("AffineGrid", {opset_version});') + print(f' test.AddAttribute("align_corners", (int64_t){1 if align_corners else 0});') + print( + f" test.AddInput(\"theta\", {{{theta.shape[0]}, {theta.shape[1]}, {theta.shape[2]}}}, {{{', '.join([f'{x:.6f}f' for x in theta.flatten()])}}});" + ) + print( + f' test.AddInput("size", {{{len(size)}}}, {{{size[0]}, {size[1]}, {size[2]}, {size[3]}, {size[4]}}});' + ) + print( + f" test.AddOutput(\"grid\", {{{size[0]}, {size[2]}, {size[3]}, {size[4]}, 3}}, {{{', '.join([f'{x:.4f}f' for x in grid.flatten()])}}});" + ) + print(" test.Run();") + print("}\n") + test_count += 1 diff --git a/onnxruntime/test/providers/cpu/tensor/isinf_test.cc b/onnxruntime/test/providers/cpu/tensor/isinf_test.cc index ddb392eb82e13..2e583c5d2547b 100644 --- a/onnxruntime/test/providers/cpu/tensor/isinf_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/isinf_test.cc @@ -17,85 +17,137 @@ constexpr double DOUBLE_INF = std::numeric_limits::infinity(); constexpr double DOUBLE_NINF = -std::numeric_limits::infinity(); constexpr double DOUBLE_NAN = std::numeric_limits::quiet_NaN(); -TEST(IsInfTest, test_isinf_float) { - // Defaults for detect_negative = 1 - // detect_positive = 1 - OpTester test("IsInf", 10); +template +void run_is_inf_test(int opset, int64_t detect_positive, int64_t detect_negative, const std::initializer_list& input, const std::initializer_list& output) { + OpTester test("IsInf", opset); + test.AddAttribute("detect_positive", detect_positive); + test.AddAttribute("detect_negative", detect_negative); + test.AddInput("X", {onnxruntime::narrow(input.size())}, input); + test.AddOutput("Y", {onnxruntime::narrow(output.size())}, output); + test.Run(); +} - std::vector input_dim{6}; - std::vector input = {-1.2f, FLOAT_NAN, FLOAT_INF, 2.8f, FLOAT_NINF, FLOAT_INF}; - test.AddInput("X", input_dim, input); +TEST(IsInfTest, test_isinf_float10) { + std::initializer_list input = {-1.2f, FLOAT_NAN, FLOAT_INF, 2.8f, FLOAT_NINF, FLOAT_INF}; + std::initializer_list output = {false, false, true, false, true, true}; + run_is_inf_test(10, 1, 1, input, output); +} - std::vector output_dim(input_dim); - test.AddOutput("Y", output_dim, {false, false, true, false, true, true}); - test.Run(); +TEST(IsInfTest, test_isinf_float20) { + std::initializer_list input = {-1.2f, FLOAT_NAN, FLOAT_INF, 2.8f, FLOAT_NINF, FLOAT_INF}; + std::initializer_list output = {false, false, true, false, true, true}; + run_is_inf_test(20, 1, 1, input, output); } -TEST(IsInfTest, test_isinf_double) { - // Defaults for detect_negative = 1 - // detect_positive = 1 - OpTester test("IsInf", 10); +TEST(IsInfTest, test_isinf_double10) { + std::initializer_list input = {-1.2, DOUBLE_NAN, DOUBLE_INF, 2.8, DOUBLE_NINF, DOUBLE_INF}; + std::initializer_list output = {false, false, true, false, true, true}; + run_is_inf_test(10, 1, 1, input, output); +} - std::vector input_dim{6}; - std::vector input = {-1.2, DOUBLE_NAN, DOUBLE_INF, 2.8, DOUBLE_NINF, DOUBLE_INF}; - test.AddInput("X", input_dim, input); +TEST(IsInfTest, test_isinf_double20) { + std::initializer_list input = {-1.2, DOUBLE_NAN, DOUBLE_INF, 2.8, DOUBLE_NINF, DOUBLE_INF}; + std::initializer_list output = {false, false, true, false, true, true}; + run_is_inf_test(20, 1, 1, input, output); +} - std::vector output_dim(input_dim); - test.AddOutput("Y", output_dim, {false, false, true, false, true, true}); - test.Run(); +TEST(IsInfTest, test_isinf_positive_float10) { + std::initializer_list input = {-1.7f, FLOAT_NAN, FLOAT_INF, 3.6f, FLOAT_NINF, FLOAT_INF}; + std::initializer_list output = {false, false, true, false, false, true}; + run_is_inf_test(10, 1, 0, input, output); } -TEST(IsInfTest, test_isinf_positive_float) { - OpTester test("IsInf", 10); - test.AddAttribute("detect_negative", 0); +TEST(IsInfTest, test_isinf_positive_float20) { + std::initializer_list input = {-1.7f, FLOAT_NAN, FLOAT_INF, 3.6f, FLOAT_NINF, FLOAT_INF}; + std::initializer_list output = {false, false, true, false, false, true}; + run_is_inf_test(20, 1, 0, input, output); +} - std::vector input_dim{6}; - std::vector input = {-1.7f, FLOAT_NAN, FLOAT_INF, 3.6f, FLOAT_NINF, FLOAT_INF}; - test.AddInput("X", input_dim, input); +TEST(IsInfTest, test_isinf_positive_double10) { + std::initializer_list input = {-1.7, DOUBLE_NAN, DOUBLE_INF, 3.6, DOUBLE_NINF, DOUBLE_INF}; + std::initializer_list output = {false, false, true, false, false, true}; + run_is_inf_test(10, 1, 0, input, output); +} - std::vector output_dim(input_dim); - test.AddOutput("Y", output_dim, {false, false, true, false, false, true}); - test.Run(); +TEST(IsInfTest, test_isinf_positive_double20) { + std::initializer_list input = {-1.7, DOUBLE_NAN, DOUBLE_INF, 3.6, DOUBLE_NINF, DOUBLE_INF}; + std::initializer_list output = {false, false, true, false, false, true}; + run_is_inf_test(20, 1, 0, input, output); +} + +TEST(IsInfTest, test_isinf_negative_float10) { + std::initializer_list input = {-1.7f, FLOAT_NAN, FLOAT_INF, 3.6f, FLOAT_NINF, FLOAT_INF}; + std::initializer_list output = {false, false, false, false, true, false}; + run_is_inf_test(10, 0, 1, input, output); } -TEST(IsInfTest, test_isinf_positive_double) { - OpTester test("IsInf", 10); - test.AddAttribute("detect_negative", 0); +TEST(IsInfTest, test_isinf_negative_float20) { + std::initializer_list input = {-1.7f, FLOAT_NAN, FLOAT_INF, 3.6f, FLOAT_NINF, FLOAT_INF}; + std::initializer_list output = {false, false, false, false, true, false}; + run_is_inf_test(20, 0, 1, input, output); +} - std::vector input_dim{6}; - std::vector input = {-1.7, DOUBLE_NAN, DOUBLE_INF, 3.6, DOUBLE_NINF, DOUBLE_INF}; - test.AddInput("X", input_dim, input); +TEST(IsInfTest, test_isinf_negative_double10) { + std::initializer_list input = {-1.7, DOUBLE_NAN, DOUBLE_INF, 3.6, DOUBLE_NINF, DOUBLE_INF}; + std::initializer_list output = {false, false, false, false, true, false}; + run_is_inf_test(10, 0, 1, input, output); +} - std::vector output_dim(input_dim); - test.AddOutput("Y", output_dim, {false, false, true, false, false, true}); - test.Run(); +TEST(IsInfTest, test_isinf_negative_double20) { + std::initializer_list input = {-1.7, DOUBLE_NAN, DOUBLE_INF, 3.6, DOUBLE_NINF, DOUBLE_INF}; + std::initializer_list output = {false, false, false, false, true, false}; + run_is_inf_test(20, 0, 1, input, output); } -TEST(IsInfTest, test_isinf_negative_float) { - OpTester test("IsInf", 10); - test.AddAttribute("detect_positive", 0); +#if !defined(DISABLE_FLOAT8_TYPES) +TEST(IsInfTest, test_Float8E4M3FN) { + std::initializer_list input = { + Float8E4M3FN(-1.0f), Float8E4M3FN(FLOAT_NAN, false), Float8E4M3FN(1.0f), Float8E4M3FN(FLOAT_NINF, false), Float8E4M3FN(FLOAT_NINF, false), Float8E4M3FN(FLOAT_INF, false)}; + std::initializer_list output = {false, false, false, false, false, false}; + run_is_inf_test(20, 1, 1, input, output); +} - std::vector input_dim{6}; - std::vector input = {-1.7f, FLOAT_NAN, FLOAT_INF, 3.6f, FLOAT_NINF, FLOAT_INF}; - test.AddInput("X", input_dim, input); +TEST(IsInfTest, test_Float8E4M3FNUZ) { + std::initializer_list input = { + Float8E4M3FNUZ(-1.0f), Float8E4M3FNUZ(FLOAT_NAN, false), Float8E4M3FNUZ(1.0f), Float8E4M3FNUZ(FLOAT_NINF, false), Float8E4M3FNUZ(FLOAT_NINF, false), Float8E4M3FNUZ(FLOAT_INF, false)}; + std::initializer_list output = {false, false, false, false, false, false}; + run_is_inf_test(20, 1, 1, input, output); +} - std::vector output_dim(input_dim); - test.AddOutput("Y", output_dim, {false, false, false, false, true, false}); - test.Run(); +TEST(IsInfTest, test_Float8E5M2_detect_both) { + std::initializer_list input = { + Float8E5M2(-1.0f), Float8E5M2(FLOAT_NINF, false), Float8E5M2(1.0f), Float8E5M2(FLOAT_NINF, false), Float8E5M2(FLOAT_NAN, false), Float8E5M2(FLOAT_INF, false)}; + std::initializer_list output = {false, true, false, true, false, true}; + run_is_inf_test(20, 1, 1, input, output); } -TEST(IsInfTest, test_isinf_negative_double) { - OpTester test("IsInf", 10); - test.AddAttribute("detect_positive", 0); +TEST(IsInfTest, test_Float8E5M2_detect_positive) { + std::initializer_list input = { + Float8E5M2(-1.0f), Float8E5M2(FLOAT_NINF, false), Float8E5M2(1.0f), Float8E5M2(FLOAT_NINF, false), Float8E5M2(FLOAT_NAN, false), Float8E5M2(FLOAT_INF, false)}; + std::initializer_list output = {false, false, false, false, false, true}; + run_is_inf_test(20, 1, 0, input, output); +} - std::vector input_dim{6}; - std::vector input = {-1.7, DOUBLE_NAN, DOUBLE_INF, 3.6, DOUBLE_NINF, DOUBLE_INF}; - test.AddInput("X", input_dim, input); +TEST(IsInfTest, test_Float8E5M2_detect_negative) { + std::initializer_list input = { + Float8E5M2(-1.0f), Float8E5M2(FLOAT_NINF, false), Float8E5M2(1.0f), Float8E5M2(FLOAT_NINF, false), Float8E5M2(FLOAT_NAN, false), Float8E5M2(FLOAT_INF, false)}; + std::initializer_list output = {false, true, false, true, false, false}; + run_is_inf_test(20, 0, 1, input, output); +} - std::vector output_dim(input_dim); - test.AddOutput("Y", output_dim, {false, false, false, false, true, false}); - test.Run(); +TEST(IsInfTest, test_Float8E5M2_none) { + std::initializer_list input = { + Float8E5M2(-1.0f), Float8E5M2(FLOAT_NINF, false), Float8E5M2(1.0f), Float8E5M2(FLOAT_NINF, false), Float8E5M2(FLOAT_NAN, false), Float8E5M2(FLOAT_INF, false)}; + std::initializer_list output = {false, false, false, false, false, false}; + run_is_inf_test(20, 0, 0, input, output); } +TEST(IsInfTest, test_Float8E5M2FNUZ) { + std::initializer_list input = { + Float8E5M2FNUZ(-1.0f), Float8E5M2FNUZ(FLOAT_NINF, false), Float8E5M2FNUZ(1.0f), Float8E5M2FNUZ(FLOAT_NINF, false), Float8E5M2FNUZ(FLOAT_NAN, false), Float8E5M2FNUZ(FLOAT_INF, false)}; + std::initializer_list output = {false, false, false, false, false, false}; + run_is_inf_test(20, 1, 1, input, output); +} +#endif } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/tensor/isnan_test.cc b/onnxruntime/test/providers/cpu/tensor/isnan_test.cc index 0dffc452b519d..0f1e5c07cdd9b 100644 --- a/onnxruntime/test/providers/cpu/tensor/isnan_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/isnan_test.cc @@ -9,29 +9,84 @@ namespace onnxruntime { namespace test { -TEST(IsNaNOpTest, IsNaNFloat) { - OpTester test("IsNaN", 9, kOnnxDomain); - std::vector dims{2, 2}; - test.AddInput("X", dims, {1.0f, NAN, 2.0f, NAN}); - test.AddOutput("Y", dims, {false, true, false, true}); +template +void run_is_nan_test(int opset, const std::vector& dims, const std::initializer_list& input, const std::initializer_list& output) { + OpTester test("IsNaN", opset, kOnnxDomain); + test.AddInput("X", dims, input); + test.AddOutput("Y", dims, output); test.Run(); } -TEST(IsNaNOpTest, IsNaNFloat16) { - OpTester test("IsNaN", 9, kOnnxDomain); +TEST(IsNaNOpTest, IsNaNFloat9) { std::vector dims{2, 2}; - test.AddInput("X", dims, std::initializer_list({MLFloat16(1.0f), MLFloat16::NaN, MLFloat16(2.0f), MLFloat16::NaN})); - test.AddOutput("Y", dims, {false, true, false, true}); - test.Run(); + std::initializer_list input = {1.0f, NAN, 2.0f, NAN}; + std::initializer_list output = {false, true, false, true}; + run_is_nan_test(9, dims, input, output); } -TEST(IsNaNOpTest, IsNaNDouble) { - OpTester test("IsNaN", 9, kOnnxDomain); +TEST(IsNaNOpTest, IsNaNFloat20) { std::vector dims{2, 2}; - test.AddInput("X", dims, {1.0, NAN, 2.0, NAN}); - test.AddOutput("Y", dims, {false, true, false, true}); - test.Run(); + std::initializer_list input = {1.0f, NAN, 2.0f, NAN}; + std::initializer_list output = {false, true, false, true}; + run_is_nan_test(20, dims, input, output); +} + +TEST(IsNaNOpTest, IsNaNFloat16_9) { + std::vector dims{2, 2}; + std::initializer_list input = {MLFloat16(1.0f), MLFloat16::NaN, MLFloat16(2.0f), MLFloat16::NaN}; + std::initializer_list output = {false, true, false, true}; + run_is_nan_test(9, dims, input, output); +} + +TEST(IsNaNOpTest, IsNaNFloat16_20) { + std::vector dims{2, 2}; + std::initializer_list input = {MLFloat16(1.0f), MLFloat16::NaN, MLFloat16(2.0f), MLFloat16::NaN}; + std::initializer_list output = {false, true, false, true}; + run_is_nan_test(20, dims, input, output); +} + +TEST(IsNaNOpTest, IsNaNDouble9) { + std::vector dims{2, 2}; + std::initializer_list input = {1.0, NAN, 2.0, NAN}; + std::initializer_list output = {false, true, false, true}; + run_is_nan_test(9, dims, input, output); +} + +TEST(IsNaNOpTest, IsNaNDouble20) { + std::vector dims{2, 2}; + std::initializer_list input = {1.0, NAN, 2.0, NAN}; + std::initializer_list output = {false, true, false, true}; + run_is_nan_test(20, dims, input, output); } +#if !defined(DISABLE_FLOAT8_TYPES) +TEST(IsNaNOpTest, IsNaNFloat8E4M3FN) { + std::vector dims{2, 2}; + std::initializer_list input = {Float8E4M3FN(1.0f), Float8E4M3FN(-NAN), Float8E4M3FN(2.0f), Float8E4M3FN(NAN)}; + std::initializer_list output = {false, true, false, true}; + run_is_nan_test(20, dims, input, output); +} + +TEST(IsNaNOpTest, IsNaN_Float8E4M3FNUZ) { + std::vector dims{2, 2}; + std::initializer_list input = {Float8E4M3FNUZ(1.0f), Float8E4M3FNUZ(-NAN), Float8E4M3FNUZ(2.0f), Float8E4M3FNUZ(-NAN)}; + std::initializer_list output = {false, true, false, true}; + run_is_nan_test(20, dims, input, output); +} + +TEST(IsNaNOpTest, IsNaNFloat8E5M2) { + std::vector dims{2, 2}; + std::initializer_list input = {Float8E5M2(1.0f), Float8E5M2(-NAN), Float8E5M2(2.0f), Float8E5M2(NAN)}; + std::initializer_list output = {false, true, false, true}; + run_is_nan_test(20, dims, input, output); +} + +TEST(IsNaNOpTest, IsNaN_Float8E5M2FNUZ) { + std::vector dims{2, 2}; + std::initializer_list input = {Float8E5M2FNUZ(1.0f), Float8E5M2FNUZ(-NAN), Float8E5M2FNUZ(2.0f), Float8E5M2FNUZ(NAN)}; + std::initializer_list output = {false, true, false, true}; + run_is_nan_test(20, dims, input, output); +} +#endif } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/qnn/matmul_test.cpp b/onnxruntime/test/providers/qnn/matmul_test.cpp index e721ccbcb45a9..3073dde9d8e4c 100644 --- a/onnxruntime/test/providers/qnn/matmul_test.cpp +++ b/onnxruntime/test/providers/qnn/matmul_test.cpp @@ -112,12 +112,13 @@ TEST_F(QnnCPUBackendTests, MatMulOp) { } // Test MatMul broadcasting -// Note slight inaccuracy in CPU backend: +// Failed randomly on Linux +// Value of: expected_tensor.DataAsSpan() // Expected: contains 896 values, where each value and its corresponding value in 16-byte object -// <80-03 00-00 00-00 00-00 40-00 34-DD F7-01 00-00> are an almost-equal pair -// Actual: 16-byte object <80-03 00-00 00-00 00-00 40-00 23-DD F7-01 00-00>, -// where the value pair (73.68116, 73.680809) at index #80 don't match, which is -0.000350952 from 73.6812 -TEST_F(QnnCPUBackendTests, MatMulOp_Broadcast) { +// <80-03 00-00 00-00 00-00 40-B8 53-08 CC-7F 00-00> are an almost-equal pair +// Actual: 16-byte object <80-03 00-00 00-00 00-00 C0-B7 43-08 CC-7F 00-00>, where the value pair +// (-5.19657087, 0) at index #29 don't match, which is 5.19657 from -5.19657 +TEST_F(QnnCPUBackendTests, DISABLED_MatMulOp_Broadcast) { // Create two matrices with element values in the range [-10.0, 10.0]. std::vector input_a = GetFloatDataInRange(-10.0f, 10.0f, 28 * 64); std::vector input_b = GetFloatDataInRange(-10.0f, 10.0f, 64 * 32); diff --git a/onnxruntime/test/python/onnxruntime_test_python.py b/onnxruntime/test/python/onnxruntime_test_python.py index 1d954fe4370ad..d8628c4288206 100644 --- a/onnxruntime/test/python/onnxruntime_test_python.py +++ b/onnxruntime/test/python/onnxruntime_test_python.py @@ -298,6 +298,20 @@ def test_set_providers_with_options(self): self.assertEqual(option["trt_engine_cache_path"], str(engine_cache_path)) self.assertEqual(option["trt_force_sequential_engine_build"], "1") + from onnxruntime.capi import _pybind_state as C + + session_options = C.get_default_session_options() + + # TRT plugins registered as custom op domain should only be added once in session option regardless of number of session creation + sess1 = onnxrt.InferenceSession( + get_name("mul_1.onnx"), session_options, providers=["TensorrtExecutionProvider"] + ) + sess2 = onnxrt.InferenceSession( + get_name("mul_1.onnx"), session_options, providers=["TensorrtExecutionProvider"] + ) + self.assertIn("TensorrtExecutionProvider", sess1.get_providers()) + self.assertIn("TensorrtExecutionProvider", sess2.get_providers()) + # We currently disable following test code since that not all test machines/GPUs have nvidia int8 capability """ diff --git a/onnxruntime/test/python/quantization/test_op_matmul_bnb4.py b/onnxruntime/test/python/quantization/test_op_matmul_bnb4.py new file mode 100644 index 0000000000000..88432d75c653e --- /dev/null +++ b/onnxruntime/test/python/quantization/test_op_matmul_bnb4.py @@ -0,0 +1,186 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import tempfile +import unittest +from importlib.util import find_spec +from pathlib import Path +from typing import Dict, Tuple, Union + +import numpy as np +import onnx +from onnx import TensorProto, helper +from op_test_utils import TestDataFeeds, check_model_correctness, check_op_type_count + +from onnxruntime.quantization import quant_utils + +quant_maps = { + 0: [ + 0.00000000, + 5.208333333e-03, + 0.66666667, + 1.00000000, + 0.33333333, + 0.50000000, + 0.16666667, + 0.25000000, + -0.00000000, + -5.208333333e-03, + -0.66666667, + -1.00000000, + -0.33333333, + -0.50000000, + -0.16666667, + -0.25000000, + ], + 1: [ + -1.0, + -0.6961928009986877, + -0.5250730514526367, + -0.39491748809814453, + -0.28444138169288635, + -0.18477343022823334, + -0.09105003625154495, + 0.0, + 0.07958029955625534, + 0.16093020141124725, + 0.24611230194568634, + 0.33791524171829224, + 0.44070982933044434, + 0.5626170039176941, + 0.7229568362236023, + 1.0, + ], +} + + +class TestOpMatMulBnb4(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls._tmp_model_dir = tempfile.TemporaryDirectory(prefix="test_matmulbnb4.") + + @classmethod + def tearDownClass(cls): + cls._tmp_model_dir.cleanup() + + def fill_bnb4_data(self, shape: Tuple[int, int], quant_type: int) -> np.ndarray: + rows, cols = shape + line = np.zeros(shape) + line = line.reshape(-1) + quant_map = np.array(quant_maps[quant_type], dtype=np.float32) + + v = 0 + for i in range(line.shape[0]): + line[i] = quant_map[v] + v += 1 + if v >= 16: + v = 0 + + # bnb quantization quantizes weight.T after flattening + line = line.reshape(cols, rows).transpose() + return line.reshape(shape) + + def input_feeds(self, n: int, name2shape: Dict[str, Union[int, Tuple[int, ...]]]) -> TestDataFeeds: + input_data_list = [] + for _i in range(n): + inputs = {} + for name, shape in name2shape.items(): + inputs.update({name: np.random.randint(-1, 2, shape).astype(np.float32)}) + input_data_list.extend([inputs]) + dr = TestDataFeeds(input_data_list) + return dr + + def construct_model_matmul(self, output_model_path: str, quant_type: int) -> None: + # (input) + # | + # MatMul + # | + # (output) + input_name = "input" + output_name = "output" + initializers = [] + + def make_matmul(input_name, weight_shape: Union[int, Tuple[int, ...]], weight_name: str, output_name: str): + weight_data = self.fill_bnb4_data(weight_shape, quant_type).astype(np.float32) + initializers.append(onnx.numpy_helper.from_array(weight_data, name=weight_name)) + return onnx.helper.make_node( + "MatMul", + [input_name, weight_name], + [output_name], + ) + + # for this to work (in_features * out_features) % block_size == 0 + in_features = 52 + out_features = 288 + # make MatMul node + matmul_node = make_matmul( + input_name, + [in_features, out_features], + "linear1.weight", + output_name, + ) + + # make graph + input_tensor = helper.make_tensor_value_info(input_name, TensorProto.FLOAT, [-1, in_features]) + output_tensor = helper.make_tensor_value_info(output_name, TensorProto.FLOAT, [-1, out_features]) + graph_name = "matmul_bnb4_test" + graph = helper.make_graph( + [matmul_node], + graph_name, + [input_tensor], + [output_tensor], + initializer=initializers, + ) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + model.ir_version = 7 # use stable onnx ir version + + onnx.save(model, output_model_path) + + def quant_test(self, quant_type: int, block_size: int): + model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath(f"matmul_fp32_{quant_type}.onnx").absolute()) + self.construct_model_matmul(model_fp32_path, quant_type) + data_reader = self.input_feeds(1, {"input": [100, 52]}) + + model_bnb4_path = str( + Path(self._tmp_model_dir.name).joinpath(f"MatMulBnb4_{quant_type}_{block_size}.onnx").absolute() + ) + + # Quantize fp32 model to bnb4 model + from onnxruntime.quantization import matmul_bnb4_quantizer + + model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path)) + quant = matmul_bnb4_quantizer.MatMulBnb4Quantizer(model, quant_type, block_size) + quant.process() + quant.model.save_model_to_file(model_bnb4_path, False) + + quant_nodes = {"MatMulBnb4": 1} + check_op_type_count(self, model_bnb4_path, **quant_nodes) + + data_reader.rewind() + + try: + check_model_correctness(self, model_fp32_path, model_bnb4_path, data_reader.get_next()) + except Exception as exception: + raise exception + + @unittest.skipIf( + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_bnb4" + ) + def test_quantize_matmul_bnb4_fp4(self): + np.random.seed(13) + self.quant_test(0, 64) + + @unittest.skipIf( + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_bnb4" + ) + def test_quantize_matmul_bnb4_nf4(self): + np.random.seed(13) + self.quant_test(1, 64) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/quantization/test_quantizeblockwise_bnb4.py b/onnxruntime/test/python/quantization/test_quantizeblockwise_bnb4.py new file mode 100644 index 0000000000000..9e9d05fae027d --- /dev/null +++ b/onnxruntime/test/python/quantization/test_quantizeblockwise_bnb4.py @@ -0,0 +1,139 @@ +#!/usr/bin/env python +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import unittest +from importlib.util import find_spec + +import numpy as np +import numpy.typing as npt + +quant_enums = {"FP4": 0, "NF4": 1} + + +def quantize_block_fp4(block: npt.ArrayLike): + # quantize a block of float32 values to uint8 by simulating a binary search using pivots + # could have used (block[:,None] - quant_map).argmin(axis=1) but there are some mismatches due to + # floating point precision + # block: 1-D array of normalized [-1,1] float32 values, len(block) % 2 == 0 + + # pivots to find the quantization index + # only half of the pivots are needed since the other half is symmetric + pivots = np.array( + [0.00260417, 0.0859375, 0.20833333, 0.29166667, 0.4166667, 0.583333, 0.8333333, 1], dtype=np.float32 + ) + # indices are not 0,1,2,3,4,5,6,7 because it is a floating point data type + pivot_indices = np.array([0, 1, 6, 7, 4, 5, 2, 3], dtype=np.uint8) + + # signs of the block + signs = (block < 0).astype(np.uint8) * 8 + + # find the uint8 quantization index + # argmax finds the first occurrence of True + quant_indices = pivot_indices[(np.abs(block)[:, None] <= pivots).argmax(axis=1)] + signs + + return np.bitwise_or(np.left_shift(quant_indices[::2], 4), quant_indices[1::2]) + + +def quantize_block_nf4(block: npt.ArrayLike): + pivots = np.array( + [ + -0.8480964004993439, + -0.6106329262256622, + -0.4599952697753906, + -0.33967943489551544, + -0.23460740596055984, + -0.13791173323988914, + -0.045525018125772476, + 0.03979014977812767, + 0.1202552504837513, + 0.2035212516784668, + 0.2920137718319893, + 0.3893125355243683, + 0.5016634166240692, + 0.6427869200706482, + 0.8614784181118011, + 1.0, + ], + dtype=np.float32, + ) + + quant_indices = (block[:, None] <= pivots).argmax(axis=1).astype(np.uint8) + + return np.bitwise_or(np.left_shift(quant_indices[::2], 4), quant_indices[1::2]) + + +def quantize_blockwise_bnb4_ref(matrix_float: npt.ArrayLike, block_size: int, quant_type: str, target=None): + if len(matrix_float.shape) != 2: + raise ValueError("Current bnb4 block quantization only supports 2D tensors!") + + numel = matrix_float.size + num_blocks = (numel + block_size - 1) // block_size + quantized_numel = (numel + 1) // 2 + + packed = np.zeros(quantized_numel, dtype=np.uint8) + absmax = np.zeros(num_blocks, dtype=matrix_float.dtype) + + flattened_matrix_float = matrix_float.flatten() + for block_idx in range(num_blocks): + block_len = min(block_size, numel - block_idx * block_size) + block = np.float32(flattened_matrix_float[block_idx * block_size : block_idx * block_size + block_len]) + + block_absmax = np.max(np.abs(block)) + reciprocal_absmax = 1.0 / block_absmax if block_absmax != 0 else 0.0 + absmax[block_idx] = block_absmax + + if block_len % 2 != 0: + block = np.append(block, 0.0) + block_len += 1 + + block *= reciprocal_absmax + start = block_idx * block_size // 2 + end = start + block_len // 2 + if quant_type == "FP4": + packed[start:end] = quantize_block_fp4(block) + else: + packed[start:end] = quantize_block_nf4(block) + + return (packed, absmax) + + +def quantize_blockwise_bnb4_target(matrix_float: npt.ArrayLike, block_size: int, quant_type: str): + if len(matrix_float.shape) != 2: + raise ValueError("Current int4 block quantization only supports 2D tensors!") + quant_type_enum = quant_enums[quant_type] + + n, k = matrix_float.shape # already transposed + numel = n * k + num_blocks = (numel + block_size - 1) // block_size + quantized_numel = (numel + 1) // 2 + + packed = np.zeros(quantized_numel, dtype="uint8") + absmax = np.zeros(num_blocks, dtype=matrix_float.dtype) + from onnxruntime.capi._pybind_state import quantize_matmul_bnb4 + + quantize_matmul_bnb4(packed, matrix_float, absmax, block_size, quant_type_enum, n, k) + return (packed, absmax) + + +class TestQuantizeBlockwiseBnb4(unittest.TestCase): + @unittest.skipIf( + find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_bnb4" + ) + def test_quantize_blockwise_bnb4(self): + for quant_type in ["FP4", "NF4"]: + for k, n in [(128, 128), (32, 128), (128, 32), (52, 128), (128, 52), (73, 123)]: + for block_size in [16, 32, 64, 128]: + for type in [np.float32, np.float16]: + matrix_float = np.random.uniform(-1, 1, (k, n)).astype(type) + quant_value_ref, absmax_ref = quantize_blockwise_bnb4_ref(matrix_float, block_size, quant_type) + quant_value, absmax = quantize_blockwise_bnb4_target(matrix_float, block_size, quant_type) + assert np.allclose(quant_value_ref, quant_value) + assert np.allclose(absmax_ref, absmax) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/transformers/test_parity_rotary_embedding.py b/onnxruntime/test/python/transformers/test_parity_rotary_embedding.py new file mode 100644 index 0000000000000..b17ae5f69aff5 --- /dev/null +++ b/onnxruntime/test/python/transformers/test_parity_rotary_embedding.py @@ -0,0 +1,450 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + + +# Notes +# 1) The test cases in this file are for the following LLaMA-2 scenarios: +# - Microsoft rotary embeddings with interleaved = True +# - Prompt generation +# - Token generation +# - Hugging Face rotary embeddings (equal to Microsoft rotary embeddings with interleaved = False) +# - Prompt generation +# - Token generation +# +# 2) Shapes of position ids in ORT and `interleaved` for LLaMA-2 scenarios: +# - Microsoft model: When shape of position ids == (1), interleaved = True +# - Hugging Face model: When shape of position ids == (batch_size, sequence_length), interleaved = False + + +import unittest +from copy import deepcopy + +import numpy as np +import torch +import torch.nn as nn +from onnx import TensorProto, helper + +import onnxruntime as ort + + +class SampleInputConfig: + def __init__( + self, + batch_size=2, + sequence_length=8, + num_heads=4, + head_size=6, + max_sequence_length=16, + ): + self.batch_size = batch_size + self.sequence_length = sequence_length + self.num_heads = num_heads + self.head_size = head_size + self.hidden_size = self.num_heads * self.head_size + self.max_sequence_length = max_sequence_length + + +# LLaMA Hugging Face model +class LlamaHFRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device="cpu"): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype) + + freqs = torch.einsum("i,j->ij", t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False) + + def get_cos_sin_cache(self, seq_len=None, device=torch.device("cpu"), dtype=torch.float32): # noqa: B008 + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=device, dtype=dtype) + + return ( + self.cos_cached[:, :, :seq_len, ...].to(dtype=dtype), + self.sin_cached[:, :, :seq_len, ...].to(dtype=dtype), + ) + + def rotate_half(self, x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def apply_rope_bnsh(self, x, cos, sin, position_ids): + # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] + sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] + cos = cos[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + sin = sin[position_ids].unsqueeze(1) # [bs, 1, seq_len, dim] + x_embed = (x * cos) + (self.rotate_half(x) * sin) + return x_embed + + def apply_rope_bsnh(self, x, cos, sin, position_ids): + # Two dimensions of cos and sin are always 1, so we can `squeeze` them. + cos = cos.squeeze() # [seq_len, dim] + sin = sin.squeeze() # [seq_len, dim] + cos = cos[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] + sin = sin[position_ids].unsqueeze(2) # [bs, seq_len, 1, dim] + x_embed = (x * cos) + (self.rotate_half(x) * sin) + return x_embed + + def forward(self, x, cos, sin, pos_ids, x_format="bnsh"): + if x_format == "bnsh": + return self.apply_rope_bnsh(x, cos, sin, pos_ids) + return self.apply_rope_bsnh(x, cos, sin, pos_ids) + + +# LLaMA Microsoft model +class LlamaMSRotaryEmbedding(nn.Module): + def __init__(self, hidden_size, num_heads, max_sequence_length): + super().__init__() + + self.hidden_size = hidden_size + self.num_heads = num_heads + self.max_sequence_length = max_sequence_length + + def get_cos_sin_cache(self, theta=10000.0, head_scale=1.0, device="cpu", dtype=torch.float32): + hidden_size = self.hidden_size + n_heads = self.num_heads + max_seq_len = self.max_sequence_length + + # Precalculate rotary matrices for the sequence + # According to "Attention Is All You Need", theta_i = 10000 ^ (2 * (i - 1)/dim), i in [1, 2, ..., dim//2] + head_dim = head_scale * hidden_size / n_heads + + pos = torch.arange(0, 2 * (head_dim // 2), step=2, device=device, dtype=dtype) + freqs = 1.0 / (theta ** (pos / head_dim)) + + idx = torch.arange(max_seq_len, device=freqs.device) + freqs = torch.outer(idx, freqs) + + cos = torch.reshape(torch.cos(freqs), [1, max_seq_len, 1, -1]) + sin = torch.reshape(torch.sin(freqs), [1, max_seq_len, 1, -1]) + dtype = torch.get_default_dtype() + + return cos.to(dtype), sin.to(dtype) + + def rotate_tensor( + self, + x: torch.Tensor, # BxSxNxH + cos: torch.Tensor, # 1xSx1x(H/2) + sin: torch.Tensor, # 1xSx1x(H/2) + pos: int, + interleaved: bool, + ): + # Dimension of x is [batch_size, seq_len, n_heads, head_dim] + rot_dim = 2 * cos.shape[3] + + # Dolly requires partial rotation + x_rot = x[:, :, :, :rot_dim] + + if interleaved: + x1 = x_rot[:, :, :, 0::2] + x2 = x_rot[:, :, :, 1::2] + else: + half = x_rot.shape[-1] // 2 + x1 = x[:, :, :, 0:half] + x2 = x[:, :, :, half : 2 * half] + + seq_len = x.shape[1] + cos_x = cos[:, pos : pos + seq_len, :, :] + sin_x = sin[:, pos : pos + seq_len, :, :] + + # cos_x: (1, S, 1, H/2) + # sin_x: (1, S, 1, H/2) + # x1: (B, S, N, H/2) + # x2: (B, S, N, H/2) + real = cos_x * x1 - sin_x * x2 + imag = sin_x * x1 + cos_x * x2 + + if interleaved: + x_rot[:, :, :, 0::2] = real + x_rot[:, :, :, 1::2] = imag + else: + x_rot = torch.cat((real, imag), dim=-1) + + return torch.cat((x_rot, x[:, :, :, rot_dim:]), dim=-1) + + def forward(self, x, cos, sin, pos, interleaved): + return self.rotate_tensor(x, cos, sin, pos, interleaved) + + +class TestLlamaRotaryEmbedding(unittest.TestCase): + def setUp(self): + self.config = SampleInputConfig() + self.llama_hf = LlamaHFRotaryEmbedding(self.config.head_size, self.config.max_sequence_length) + self.llama_ms = LlamaMSRotaryEmbedding( + self.config.hidden_size, self.config.num_heads, self.config.max_sequence_length + ) + + seed = 2 + np.random.seed(seed) + torch.manual_seed(seed) + torch.set_printoptions(sci_mode=False) + + def create_onnx_graph(self, x_shape, pos_shape, cos, sin, interleaved): + inputs = [ + helper.make_tensor_value_info( + name="input", + elem_type=TensorProto.FLOAT, + shape=list(x_shape), + ), + helper.make_tensor_value_info( + name="position_ids", + elem_type=TensorProto.INT64, + shape=list(pos_shape), + ), + ] + outputs = [ + helper.make_tensor_value_info( + name="output", + elem_type=TensorProto.FLOAT, + shape=list(x_shape), + ), + ] + + initializers = [ + helper.make_tensor( + name="cos_cache", + data_type=TensorProto.FLOAT, + dims=list(torch.squeeze(cos).shape), + vals=cos.flatten().tolist(), + ), + helper.make_tensor( + name="sin_cache", + data_type=TensorProto.FLOAT, + dims=list(torch.squeeze(sin).shape), + vals=sin.flatten().tolist(), + ), + ] + nodes = [ + helper.make_node( + op_type="RotaryEmbedding", + inputs=["input", "position_ids", "cos_cache", "sin_cache"], + outputs=["output"], + interleaved=interleaved, + name="RotaryEmbedding_0", + domain="com.microsoft", + ), + ] + + graph = helper.make_graph( + nodes=nodes, + name="RotaryEmbedding_Graph", + inputs=inputs, + outputs=outputs, + initializer=initializers, + ) + opset_import = helper.make_opsetid(domain="com.microsoft", version=1) + model = helper.make_model(graph, opset_imports=[opset_import]) + return model.SerializeToString() + + def get_eps(self): + eps = ["CPUExecutionProvider", "CUDAExecutionProvider"] + return list(filter(lambda ep: ep in ort.get_available_providers(), eps)) + + def run_ort_ep_tests(self, onnx_graph, inputs_ort, expected_output_bsnh): + eps = self.get_eps() + for ep in eps: + sess = ort.InferenceSession(onnx_graph, providers=[ep]) + output_ort = sess.run(None, inputs_ort)[0] + output_ort = output_ort.reshape( + (self.config.batch_size, inputs_ort["input"].shape[1], self.config.num_heads, self.config.head_size) + ) + + # Compare outputs as BxSxNxH + self.assertTrue(np.allclose(expected_output_bsnh, output_ort)) + + # apply_rope(x_bnsh) == apply_rope(x_bsnh).transpose(1,2) + def test_hf_bnsh_and_hf_bsnh(self): + x_bnsh = torch.randn( + self.config.batch_size, self.config.num_heads, self.config.sequence_length, self.config.head_size + ) + cos_hf, sin_hf = self.llama_hf.get_cos_sin_cache(self.config.sequence_length) + pos_hf = torch.stack([torch.arange(0, self.config.sequence_length) for _ in range(self.config.batch_size)]) + + x_bnsh_after_rope = self.llama_hf(x_bnsh, cos_hf, sin_hf, pos_hf) # output is BxNxSxH + x_bsnh_after_rope = self.llama_hf( + x_bnsh.transpose(1, 2), cos_hf.transpose(1, 2), sin_hf.transpose(1, 2), pos_hf, "bsnh" + ) # output is BxSxNxH + + self.assertTrue(torch.allclose(x_bnsh_after_rope, x_bsnh_after_rope.transpose(1, 2))) + + # HF rotary == MSFT rotary non-interleaved + def test_hf_rotary_and_msft_rotary_noninterleaved(self): + x_bnsh = torch.randn( + self.config.batch_size, self.config.num_heads, self.config.sequence_length, self.config.head_size + ) + cos_hf, sin_hf = self.llama_hf.get_cos_sin_cache(self.config.sequence_length) + pos_hf = torch.stack([torch.arange(0, self.config.sequence_length) for _ in range(self.config.batch_size)]) + output_hf = self.llama_hf(x_bnsh, cos_hf, sin_hf, pos_hf) # output is BxNxSxH + + x_bsnh = x_bnsh.transpose(1, 2) + x_bsd = deepcopy(x_bsnh) # deepcopy to avoid changes made by self.llama_ms forward pass + cos_ms, sin_ms = self.llama_ms.get_cos_sin_cache() + pos_ms = 0 + output_ms = ( + self.llama_ms(x_bsd, cos_ms, sin_ms, pos_ms, interleaved=False).detach().cpu().numpy() # output is BxSxNxH + ) + + # Compare caches as Mx(H/2) + self.assertTrue( + torch.allclose(self.llama_hf.cos_cached.squeeze()[:, : (self.config.head_size // 2)], cos_ms.squeeze()) + ) + self.assertTrue( + torch.allclose(self.llama_hf.sin_cached.squeeze()[:, : (self.config.head_size // 2)], sin_ms.squeeze()) + ) + + # Compare outputs as BxSxNxH + self.assertTrue(np.allclose(output_hf.transpose(1, 2).detach().cpu().numpy(), output_ms)) + + # Prompt step, interleaved = true, pos ids shape = (1) + def test_msft_prompt_rotary_interleaved(self): + # Calculated this way to match the data in rotary_embedding_op_test.cc + x_bnsh = torch.randn( + self.config.batch_size, self.config.num_heads, self.config.sequence_length, self.config.head_size + ) + x_bsnh = x_bnsh.transpose(1, 2) + x_bsd = deepcopy(x_bsnh) # deepcopy to avoid changes made by self.llama_ms forward pass + cos_ms, sin_ms = self.llama_ms.get_cos_sin_cache() + pos_ms = 0 + output_ms = self.llama_ms(deepcopy(x_bsnh), cos_ms, sin_ms, pos_ms, interleaved=True).detach().cpu().numpy() + + x_bsd = x_bsd.reshape(self.config.batch_size, self.config.sequence_length, self.config.hidden_size) + pos_ms = torch.tensor([pos_ms]) + onnx_graph = self.create_onnx_graph(x_bsd.shape, pos_ms.shape, cos_ms, sin_ms, interleaved=True) + inputs_ort = { + "input": x_bsd.detach().cpu().numpy(), + "position_ids": pos_ms.detach().cpu().numpy(), + } + + # Compare inputs/outputs as BxSxNxH + self.assertTrue(np.allclose(x_bsnh.flatten(), x_bsd.flatten())) + self.run_ort_ep_tests(onnx_graph, inputs_ort, output_ms) + + # Token generation step, interleaved = true, pos ids shape = (1) + def test_msft_token_rotary_interleaved(self): + # Calculated this way to match the data in rotary_embedding_op_test.cc + x_bnsh = torch.randn( + self.config.batch_size, self.config.num_heads, self.config.sequence_length, self.config.head_size + ) + x_bsnh = x_bnsh.transpose(1, 2) + x_bsd = deepcopy(x_bsnh) # deepcopy to avoid changes made by self.llama_ms forward pass + cos_ms, sin_ms = self.llama_ms.get_cos_sin_cache() + pos_ms = 2 + output_ms = self.llama_ms(deepcopy(x_bsnh), cos_ms, sin_ms, pos_ms, interleaved=True).detach().cpu().numpy() + + x_bsd = x_bsd.reshape(self.config.batch_size, self.config.sequence_length, self.config.hidden_size) + pos_ms = torch.tensor([pos_ms]) + onnx_graph = self.create_onnx_graph(x_bsd.shape, pos_ms.shape, cos_ms, sin_ms, interleaved=True) + inputs_ort = { + "input": x_bsd.detach().cpu().numpy(), + "position_ids": pos_ms.detach().cpu().numpy(), + } + + # Compare inputs/outputs as BxSxNxH + self.assertTrue(np.allclose(x_bsnh.flatten(), x_bsd.flatten())) + self.run_ort_ep_tests(onnx_graph, inputs_ort, output_ms) + + # Prompt step, interleaved = false, pos ids shape = (batch_size, sequence_length) + def test_hf_prompt_rotary_batched_pos_ids(self): + x_bnsh = torch.randn( + self.config.batch_size, self.config.num_heads, self.config.sequence_length, self.config.head_size + ) + cos_hf, sin_hf = self.llama_hf.get_cos_sin_cache(self.config.sequence_length) + pos_ids = torch.stack([torch.arange(0, self.config.sequence_length) for _ in range(self.config.batch_size)]) + output_hf = self.llama_hf(x_bnsh, cos_hf, sin_hf, pos_ids) # output is BxNxSxH + + x_bsnh = x_bnsh.transpose(1, 2) + x_bsd = x_bsnh.reshape(self.config.batch_size, self.config.sequence_length, self.config.hidden_size) + cos_ms, sin_ms = self.llama_ms.get_cos_sin_cache() + onnx_graph = self.create_onnx_graph(x_bsd.shape, pos_ids.shape, cos_ms, sin_ms, interleaved=False) + inputs_ort = { + "input": x_bsd.detach().cpu().numpy(), + "position_ids": pos_ids.detach().cpu().numpy(), + } + + self.run_ort_ep_tests(onnx_graph, inputs_ort, output_hf.transpose(1, 2).detach().cpu().numpy()) + + # Token generation step, interleaved = false, pos ids shape = (batch_size, sequence_length) + def test_hf_token_rotary_batched_pos_ids(self): + x_bnsh = torch.randn(self.config.batch_size, self.config.num_heads, 1, self.config.head_size) + cos_hf, sin_hf = self.llama_hf.get_cos_sin_cache(self.config.sequence_length) + pos_ids = torch.stack([torch.tensor([2]) for _ in range(self.config.batch_size)]) + output_hf = self.llama_hf(x_bnsh, cos_hf, sin_hf, pos_ids) # output is BxNxSxH + + x_bsnh = x_bnsh.transpose(1, 2) + x_bsd = x_bsnh.reshape(self.config.batch_size, 1, self.config.hidden_size) + cos_ms, sin_ms = self.llama_ms.get_cos_sin_cache() + onnx_graph = self.create_onnx_graph(x_bsd.shape, pos_ids.shape, cos_ms, sin_ms, interleaved=False) + inputs_ort = { + "input": x_bsd.detach().cpu().numpy(), + "position_ids": pos_ids.detach().cpu().numpy(), + } + + # Compare outputs as BxSxNxH + self.run_ort_ep_tests(onnx_graph, inputs_ort, output_hf.transpose(1, 2).detach().cpu().numpy()) + + # Bonus test: Prompt step, interleaved = false, pos ids shape = (1) + def test_hf_prompt_rotary_one_pos_id(self): + x_bnsh = torch.randn( + self.config.batch_size, self.config.num_heads, self.config.sequence_length, self.config.head_size + ) + cos_hf, sin_hf = self.llama_hf.get_cos_sin_cache(self.config.sequence_length) + pos_hf = torch.stack([torch.arange(0, self.config.sequence_length) for _ in range(self.config.batch_size)]) + output_hf = self.llama_hf(x_bnsh, cos_hf, sin_hf, pos_hf) # output is BxNxSxH + + x_bsnh = x_bnsh.transpose(1, 2) + x_bsd = x_bsnh.reshape(self.config.batch_size, self.config.sequence_length, self.config.hidden_size) + cos_ms, sin_ms = self.llama_ms.get_cos_sin_cache() + pos_ms = torch.tensor([0]) + onnx_graph = self.create_onnx_graph(x_bsd.shape, pos_ms.shape, cos_ms, sin_ms, interleaved=False) + inputs_ort = { + "input": x_bsd.detach().cpu().numpy(), + "position_ids": pos_ms.detach().cpu().numpy(), + } + + # Compare outputs as BxSxNxH + self.run_ort_ep_tests(onnx_graph, inputs_ort, output_hf.transpose(1, 2).detach().cpu().numpy()) + + # Bonus test: Token generation step, interleaved = false, pos ids shape = (1) + def test_hf_token_rotary_one_pos_id(self): + x_bnsh = torch.randn(self.config.batch_size, self.config.num_heads, 1, self.config.head_size) + cos_hf, sin_hf = self.llama_hf.get_cos_sin_cache(self.config.sequence_length) + pos_ids = torch.stack([torch.tensor([2]) for _ in range(self.config.batch_size)]) + output_hf = self.llama_hf(x_bnsh, cos_hf, sin_hf, pos_ids) # output is BxNxSxH + + x_bsnh = x_bnsh.transpose(1, 2) + x_bsd = x_bsnh.reshape(self.config.batch_size, 1, self.config.hidden_size) + cos_ms, sin_ms = self.llama_ms.get_cos_sin_cache() + pos_ms = torch.tensor([2]) + onnx_graph = self.create_onnx_graph(x_bsd.shape, pos_ms.shape, cos_ms, sin_ms, interleaved=False) + inputs_ort = { + "input": x_bsd.detach().cpu().numpy(), + "position_ids": pos_ms.detach().cpu().numpy(), + } + + # Compare outputs as BxSxNxH + self.run_ort_ep_tests(onnx_graph, inputs_ort, output_hf.transpose(1, 2).detach().cpu().numpy()) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/transformers/test_rotary_embedding_fusion.py b/onnxruntime/test/python/transformers/test_rotary_embedding_fusion.py new file mode 100644 index 0000000000000..7bca48c29019e --- /dev/null +++ b/onnxruntime/test/python/transformers/test_rotary_embedding_fusion.py @@ -0,0 +1,447 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import os +import sys +import unittest +from typing import List + +import numpy as np +import onnx +from onnx import TensorProto, helper +from parity_utilities import find_transformers_source + +if find_transformers_source(): + from fusion_options import FusionOptions + from onnx_model import OnnxModel + from optimizer import optimize_model +else: + from onnxruntime.transformers.fusion_options import FusionOptions + from onnxruntime.transformers.onnx_model import OnnxModel + from onnxruntime.transformers.optimizer import optimize_model + + +def float_tensor(name: str, shape: List[int], random=False): + low = 0.0 + high = 1.0 + total_elements = 1 + for x in shape: + total_elements *= x + weights = [np.random.uniform(low, high) for _ in range(total_elements)] if random else [1.0] * total_elements + return helper.make_tensor(name, TensorProto.FLOAT, shape, weights) + + +class TestRotaryEmbeddingFusion(unittest.TestCase): + def setUp(self): + self.batch_size = 2 + self.sequence_length = 8 + self.num_heads = 4 + self.head_size = 6 + self.hidden_size = self.num_heads * self.head_size + + self.past_sequence_length = 2 + self.max_sequence_length = 12 + + def verify_fusion(self, expected_model_path, original_model_path): + expected_model = OnnxModel(onnx.load(expected_model_path)) + expected_model.topological_sort(is_deterministic=True) + + options = FusionOptions("gpt2") + optimized_model = optimize_model(original_model_path, optimization_options=options, opt_level=0) + optimized_model.topological_sort(is_deterministic=True) + + self.assertTrue(str(expected_model.model.graph), str(optimized_model.model.graph)) + + def create_initializers(self): + initializers = [ + float_tensor("cos_cache", [self.max_sequence_length, self.head_size]), + float_tensor("sin_cache", [self.max_sequence_length, self.head_size]), + helper.make_tensor( + "pos_ids_new_shape", + TensorProto.FLOAT, + [2], + np.array([self.batch_size, self.sequence_length], dtype=np.int64), + ), + helper.make_tensor("zero", TensorProto.FLOAT, [1], np.array([0], dtype=np.int64)), + helper.make_tensor("one", TensorProto.FLOAT, [1], np.array([1], dtype=np.int64)), + helper.make_tensor("two", TensorProto.FLOAT, [1], np.array([2], dtype=np.int64)), + helper.make_tensor("three", TensorProto.FLOAT, [1], np.array([3], dtype=np.int64)), + helper.make_tensor("int_max", TensorProto.FLOAT, [1], np.array([sys.maxsize], dtype=np.int64)), + ] + return initializers + + def create_inputs_and_outputs(self, model_type: str = ""): + inputs = [ + helper.make_tensor_value_info( + "input_0", + TensorProto.FLOAT, + [self.batch_size, self.sequence_length, self.num_heads, self.head_size], + ), + helper.make_tensor_value_info("position_ids", TensorProto.INT64, [self.batch_size, self.sequence_length]), + ] + if model_type in {"past", "merged"}: + # Input will be removed in fused model since it's not used in RotaryEmbedding. + # We create this input so that we can check the `past_seq_len` path during + # RotaryEmbedding fusion. + inputs.append( + helper.make_tensor_value_info( + "past_key", + TensorProto.FLOAT, + [self.batch_size, self.num_heads, self.past_sequence_length, self.head_size], + ) + ) + # Dummy input to test nodes for `curr_seq_len` path + if model_type != "": + inputs.append( + helper.make_tensor_value_info( + "curr_key", + TensorProto.FLOAT, + [self.batch_size, self.sequence_length, self.num_heads, self.head_size], + ) + ) + outputs = [ + helper.make_tensor_value_info( + "output_0", + TensorProto.FLOAT, + [self.batch_size, self.num_heads, self.sequence_length, self.head_size], + ) + ] + if model_type in {"merged"}: + # Dummy output to test that nodes for `past_seq_len` path are not removed for merged model + outputs.append(helper.make_tensor_value_info("past_seq_len_plus_zero", TensorProto.FLOAT, [1])) + return inputs, outputs + + def create_fused_model(self, interleaved: bool, initializers: List[TensorProto]): + inputs, outputs = self.create_inputs_and_outputs() + + rope_node = helper.make_node( + "RotaryEmbedding", + inputs=[inputs[0].name, inputs[1].name, initializers[0].name, initializers[1].name], + outputs=[outputs[0].name], + name="RotaryEmbedding_0", + interleaved=int(interleaved), + ) + + graph = helper.make_graph( + nodes=[rope_node], + name="RotaryEmbedding_Graph", + inputs=inputs, + outputs=outputs, + initializer=initializers, + ) + opset_import = helper.make_opsetid(domain="com.microsoft", version=1) + model = helper.make_model(graph, opset_imports=[opset_import]) + return model + + def create_cache_path(self, model_type: str, use_redundant_squeeze_ops: bool): + # Create position ids path + reshape_node = helper.make_node( + "Reshape", + inputs=["position_ids", "pos_ids_new_shape"], + outputs=["pos_ids_reshaped"], + name="Reshape_0", + ) + pos_ids_nodes = [reshape_node] + + # Create cos path + cos_init_unsqueeze_node = helper.make_node( + "Unsqueeze", + inputs=["new_seq_len", "zero"], + outputs=["cos_unsqueeze"], + name="Unsqueeze_2", + ) + cos_slice_node = helper.make_node( + "Slice", + inputs=["cos_cache", "zero", "cos_unsqueeze", "two", "one"], + outputs=["cos_sliced"], + name="Slice_2", + ) + cos_nodes = [cos_init_unsqueeze_node, cos_slice_node] + + if use_redundant_squeeze_ops: + # These two nodes are eliminated by this transformers PR: https://github.com/huggingface/transformers/pull/26162 + cos_squeeze_1_node = helper.make_node( + "Squeeze", + inputs=["cos_sliced", "zero"], + outputs=["cos_squeeze_1"], + name="Squeeze_0", + ) + cos_squeeze_2_node = helper.make_node( + "Squeeze", + inputs=["cos_squeeze_1", "zero"], + outputs=["cos_squeeze_2"], + name="Squeeze_1", + ) + cos_nodes.extend([cos_squeeze_1_node, cos_squeeze_2_node]) + + cos_gather_node = helper.make_node( + "Gather", + inputs=["cos_squeeze_2" if use_redundant_squeeze_ops else "cos_sliced", "pos_ids_reshaped"], + outputs=["cos_indexed"], + name="Gather_1", + ) + cos_end_unsqueeze_node = helper.make_node( + "Unsqueeze", + inputs=["cos_indexed", "one"], + outputs=["cos"], + name="Unsqueeze_3", + ) + cos_nodes.extend([cos_gather_node, cos_end_unsqueeze_node]) + + # Create sin path + sin_init_unsqueeze_node = helper.make_node( + "Unsqueeze", + inputs=["new_seq_len", "zero"], + outputs=["sin_unsqueeze"], + name="Unsqueeze_4", + ) + sin_slice_node = helper.make_node( + "Slice", + inputs=["sin_cache", "zero", "sin_unsqueeze", "two", "one"], + outputs=["sin_sliced"], + name="Slice_3", + ) + sin_nodes = [sin_init_unsqueeze_node, sin_slice_node] + + if use_redundant_squeeze_ops: + sin_squeeze_1_node = helper.make_node( + "Squeeze", + inputs=["sin_sliced", "zero"], + outputs=["sin_squeeze_1"], + name="Squeeze_2", + ) + sin_squeeze_2_node = helper.make_node( + "Squeeze", + inputs=["sin_squeeze_1", "zero"], + outputs=["sin_squeeze_2"], + name="Squeeze_3", + ) + sin_nodes.extend([sin_squeeze_1_node, sin_squeeze_2_node]) + + sin_gather_node = helper.make_node( + "Gather", + inputs=["sin_squeeze_2" if use_redundant_squeeze_ops else "sin_sliced", "pos_ids_reshaped"], + outputs=["sin_indexed"], + name="Gather_2", + ) + sin_end_unsqueeze_node = helper.make_node( + "Unsqueeze", + inputs=["sin_indexed", "one"], + outputs=["sin"], + name="Unsqueeze_5", + ) + sin_nodes.extend([sin_gather_node, sin_end_unsqueeze_node]) + + # Create beginning nodes before cos and sin paths + + # Create curr seq len path + curr_transpose_node = helper.make_node( + "Transpose", + inputs=["curr_key"], + outputs=["curr_key_transposed"], + name="Transpose_curr", + perm=[0, 2, 1, 3], + ) + curr_shape_node = helper.make_node( + "Shape", + inputs=["curr_key_transposed"], + outputs=["curr_shape"], + name="Shape_curr", + ) + curr_gather_node = helper.make_node( + "Gather", + inputs=["curr_shape", "two"], + outputs=["curr_seq_len" if model_type in {"past", "merged"} else "new_seq_len"], + name="Gather_curr", + ) + beginning_nodes = [curr_transpose_node, curr_shape_node, curr_gather_node] + + if model_type in {"past", "merged"}: + # Create past seq len path + past_shape_node = helper.make_node( + "Shape", + inputs=["past_key"], + outputs=["past_shape"], + name="Shape_past", + ) + past_gather_node = helper.make_node( + "Gather", + inputs=["past_shape", "two"], + outputs=["past_seq_len"], + name="Gather_past", + ) + add_node = helper.make_node( + "Add", + inputs=["curr_seq_len", "past_seq_len"], + outputs=["new_seq_len"], + name="Add_1", + ) + beginning_nodes.extend([past_shape_node, past_gather_node, add_node]) + + if model_type == "merged": + dummy_node = helper.make_node( + "Add", + inputs=["past_seq_len", "zero"], + outputs=["past_seq_len_plus_zero"], + name="Add_dummy_node", + ) + beginning_nodes.append(dummy_node) + + return pos_ids_nodes + cos_nodes + sin_nodes + beginning_nodes + + def create_apply_rope_path(self): + start_node = helper.make_node( + "Transpose", + inputs=["input_0"], + outputs=["x"], + name="Transpose_0", + perm=[0, 2, 1, 3], + ) + + # Calculate x_half_shape + shape_node = helper.make_node( + "Shape", + inputs=["x"], + outputs=["x_shape"], + name="Shape_0", + ) + gather_node = helper.make_node( + "Gather", + inputs=["x_shape", "three"], + outputs=["x_last_idx_shape"], + name="Gather_0", + axis=0, + ) + div_node = helper.make_node( + "Div", + inputs=["x_last_idx_shape", "two"], + outputs=["x_half_shape"], + name="Div_0", + ) + unsqueeze_0_node = helper.make_node( + "Unsqueeze", + inputs=["x_half_shape", "zero"], + outputs=["x_half_shape_0"], + name="Unsqueeze_0", + ) + unsqueeze_1_node = helper.make_node( + "Unsqueeze", + inputs=["x_half_shape", "zero"], + outputs=["x_half_shape_1"], + name="Unsqueeze_1", + ) + x_half_shape_nodes = [shape_node, gather_node, div_node, unsqueeze_0_node, unsqueeze_1_node] + + # Calculate rotate_half + x1_node = helper.make_node( + "Slice", + inputs=["x", "zero", "x_half_shape_0", "three", "one"], + outputs=["x1"], + name="Slice_0", + ) + x2_node = helper.make_node( + "Slice", + inputs=["x", "x_half_shape_1", "int_max", "three", "one"], + outputs=["x2"], + name="Slice_1", + ) + neg_node = helper.make_node( + "Neg", + inputs=["x2"], + outputs=["x2_neg"], + name="Neg_0", + ) + x_rotate_half_node = helper.make_node( + "Concat", + inputs=["x2_neg", "x1"], + outputs=["x_rotate_half"], + name="Concat_0", + axis=-1, + ) + rotate_half_nodes = [x1_node, x2_node, neg_node, x_rotate_half_node] + + # Calculate x_embed + x_cos_node = helper.make_node( + "Mul", + inputs=["x", "cos"], + outputs=["x_cos"], + name="Mul_0", + ) + x_sin_node = helper.make_node( + "Mul", + inputs=["x_rotate_half", "sin"], + outputs=["x_rotate_half_sin"], + name="Mul_1", + ) + end_node = helper.make_node( + "Add", + inputs=["x_cos", "x_rotate_half_sin"], + outputs=["output_0"], + name="Add_0", + ) + x_embed_nodes = [start_node, x_cos_node, x_sin_node, end_node] + + return x_half_shape_nodes + rotate_half_nodes + x_embed_nodes + + def create_test_model(self, model_type: str, use_redundant_squeeze_ops: bool, initializers: List[TensorProto]): + apply_rope_nodes = self.create_apply_rope_path() + cache_nodes = self.create_cache_path(model_type, use_redundant_squeeze_ops) + inputs, outputs = self.create_inputs_and_outputs(model_type) + + graph = helper.make_graph( + nodes=apply_rope_nodes + cache_nodes, + name="RotaryEmbedding_Graph", + inputs=inputs, + outputs=outputs, + initializer=initializers, + ) + opset_import = helper.make_opsetid(domain="ai.onnx", version=13) + model = helper.make_model(graph, opset_imports=[opset_import]) + return model + + def check_models(self, interleaved: bool, model_type: str): + initializers = self.create_initializers() + + expected_model_filename = "expected_model.onnx" + expected_model = self.create_fused_model(interleaved, initializers) + onnx.save(expected_model, expected_model_filename) + + original_model_filename = "original_model.onnx" + use_redundant_squeeze_ops = True + original_model = self.create_test_model(model_type, use_redundant_squeeze_ops, initializers) + onnx.save(original_model, original_model_filename) + + self.verify_fusion(expected_model_filename, original_model_filename) + os.remove(original_model_filename) + + use_redundant_squeeze_ops = False + original_model = self.create_test_model(model_type, use_redundant_squeeze_ops, initializers) + onnx.save(original_model, original_model_filename) + + self.verify_fusion(expected_model_filename, original_model_filename) + os.remove(expected_model_filename) + os.remove(original_model_filename) + + # Hugging Face's `decoder_model.onnx` + def test_hf_decoder_model(self): + interleaved = False # HF model does not use interleaving + model_type = "no_past" + self.check_models(interleaved, model_type) + + # Hugging Face's `decoder_with_past_model.onnx` + def test_hf_decoder_with_past_model(self): + interleaved = False # HF model does not use interleaving + model_type = "past" + self.check_models(interleaved, model_type) + + # Hugging Face's `decoder_merged.onnx` + def test_hf_decoder_merged_model(self): + interleaved = False # HF model does not use interleaving + model_type = "merged" + self.check_models(interleaved, model_type) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py b/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py new file mode 100644 index 0000000000000..fedba2a25dfc2 --- /dev/null +++ b/onnxruntime/test/python/transformers/test_rotary_mha_fusion.py @@ -0,0 +1,795 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import os +import sys +import unittest +from typing import List + +import numpy as np +import onnx +from onnx import NodeProto, TensorProto, helper +from parity_utilities import find_transformers_source + +if find_transformers_source(): + from fusion_options import FusionOptions + from onnx_model import OnnxModel + from optimizer import optimize_model +else: + from onnxruntime.transformers.fusion_options import FusionOptions + from onnxruntime.transformers.onnx_model import OnnxModel + from onnxruntime.transformers.optimizer import optimize_model + + +def float_tensor(name: str, shape: List[int], random=False): + low = 0.0 + high = 1.0 + total_elements = 1 + for x in shape: + total_elements *= x + weights = [np.random.uniform(low, high) for _ in range(total_elements)] if random else [1.0] * total_elements + return helper.make_tensor(name, TensorProto.FLOAT, shape, weights) + + +class TestRotaryAttentionFusion(unittest.TestCase): + def setUp(self): + self.batch_size = 2 + self.sequence_length = 8 + self.num_heads = 4 + self.head_size = 6 + self.hidden_size = self.num_heads * self.head_size + + self.past_sequence_length = 2 + self.max_sequence_length = 12 + + def verify_fusion(self, expected_model_path, original_model_path): + expected_model = OnnxModel(onnx.load(expected_model_path)) + expected_model.topological_sort(is_deterministic=True) + + model_type = "gpt2" + options = FusionOptions(model_type) + optimized_model = optimize_model( + original_model_path, + model_type, + self.num_heads, + self.hidden_size, + optimization_options=options, + opt_level=0, + ) + optimized_model.topological_sort(is_deterministic=True) + + self.assertTrue(str(expected_model.model.graph), str(optimized_model.model.graph)) + + def create_initializers(self, fused_model: bool = False): + initializers = [ + float_tensor("cos_cache", [self.max_sequence_length, self.head_size // 2]), + float_tensor("sin_cache", [self.max_sequence_length, self.head_size // 2]), + float_tensor("q_weight", [self.hidden_size, self.hidden_size]), + float_tensor("k_weight", [self.hidden_size, self.hidden_size]), + float_tensor("v_weight", [self.hidden_size, self.hidden_size]), + float_tensor("o_weight", [self.hidden_size, self.hidden_size]), + helper.make_tensor( + "sqrt_head_size", TensorProto.FLOAT, [1], np.array([np.sqrt(self.head_size)], dtype=np.float32) + ), + helper.make_tensor("neg_int_max", TensorProto.FLOAT, [1], np.array([-sys.maxsize - 1], dtype=np.int64)), + helper.make_tensor("num_heads", TensorProto.FLOAT, [1], np.array([self.num_heads], dtype=np.float32)), + helper.make_tensor("head_size", TensorProto.FLOAT, [1], np.array([self.head_size], dtype=np.float32)), + helper.make_tensor("hidden_size", TensorProto.FLOAT, [1], np.array([self.hidden_size], dtype=np.float32)), + helper.make_tensor("zero", TensorProto.FLOAT, [1], np.array([0], dtype=np.int64)), + helper.make_tensor("one", TensorProto.FLOAT, [1], np.array([1], dtype=np.int64)), + helper.make_tensor("two", TensorProto.FLOAT, [1], np.array([2], dtype=np.int64)), + helper.make_tensor("three", TensorProto.FLOAT, [1], np.array([3], dtype=np.int64)), + ] + return initializers + + def create_inputs_and_outputs(self, model_type: str): + attn_mask_size = [self.batch_size, self.sequence_length] + if model_type == "llama2_msft": + attn_mask_size.append(self.sequence_length) + + inputs = [ + helper.make_tensor_value_info( + "input_0", TensorProto.FLOAT, [self.batch_size, self.sequence_length, self.hidden_size] + ), + helper.make_tensor_value_info("position_ids", TensorProto.INT64, [self.batch_size, self.sequence_length]), + helper.make_tensor_value_info("attn_mask", TensorProto.INT64, attn_mask_size), + ] + if model_type in {"past", "merged", "llama2_msft"}: + inputs.extend( + [ + helper.make_tensor_value_info( + "past_key", + TensorProto.FLOAT, + [self.batch_size, self.num_heads, self.past_sequence_length, self.head_size], + ), + helper.make_tensor_value_info( + "past_value", + TensorProto.FLOAT, + [self.batch_size, self.num_heads, self.past_sequence_length, self.head_size], + ), + ] + ) + outputs = [ + helper.make_tensor_value_info( + "output_0", TensorProto.FLOAT, [self.batch_size, self.sequence_length, self.hidden_size] + ), + helper.make_tensor_value_info( + "present_key", + TensorProto.FLOAT, + [self.batch_size, self.num_heads, self.past_sequence_length + 1, self.head_size], + ), + helper.make_tensor_value_info( + "present_value", + TensorProto.FLOAT, + [self.batch_size, self.num_heads, self.past_sequence_length + 1, self.head_size], + ), + ] + return inputs, outputs + + def create_matmul_nodes(self, is_fused: bool, model_type: str): + q_matmul_node = helper.make_node( + "MatMul", + inputs=["input_0", "q_weight"], + outputs=["q_out" if is_fused or model_type == "llama2_msft" else "q_matmul_out"], + name="Q_MatMul", + ) + + k_matmul_node = helper.make_node( + "MatMul", + inputs=["input_0", "k_weight"], + outputs=["k_out" if is_fused or model_type == "llama2_msft" else "k_matmul_out"], + name="K_MatMul", + ) + + v_matmul_node = helper.make_node( + "MatMul", + inputs=["input_0", "v_weight"], + outputs=["v_out"], + name="V_MatMul", + ) + + return [q_matmul_node, k_matmul_node, v_matmul_node] + + def create_rotary_embeddings( + self, + is_fused: bool, + model_type: str, + interleaved: bool, + inputs: List[TensorProto], + initializers: List[TensorProto], + ): + def get_first_rope_input(node_type: str): + if is_fused or model_type == "llama2_msft": + # q_out/k_out + return f"{node_type}_out" + if model_type in {"no_past", "past", "merged"}: + if node_type == "k": + return "k_before_rope" + return "q_before_rope" + return "" + + def get_first_rope_output(node_type: str): + if is_fused or model_type in {"llama2_msft", "past", "merged"}: + if node_type == "q": + return "q_rope" + return "k_rope" + if model_type in {"no_past"}: + if node_type == "k": + return "present_key" + return "q_rope" + return "" + + q_rope_node = helper.make_node( + "RotaryEmbedding", + inputs=[get_first_rope_input("q"), inputs[1].name, initializers[0].name, initializers[1].name], + outputs=[get_first_rope_output("q")], + name="Q_RotaryEmbedding", + interleaved=int(interleaved), + ) + + k_rope_node = helper.make_node( + "RotaryEmbedding", + inputs=[get_first_rope_input("k"), inputs[1].name, initializers[0].name, initializers[1].name], + outputs=[get_first_rope_output("k")], + name="K_RotaryEmbedding", + interleaved=int(interleaved), + ) + + return [q_rope_node, k_rope_node] + + def create_q_path(self, model_type: str): + if model_type == "llama2_msft": + transpose_q_node = helper.make_node( + "Transpose", + inputs=["q_rope"], + outputs=["q_transposed"], + name="Transpose_q", + perm=[0, 2, 1, 3], + ) + reshape_q_node = helper.make_node( + "Reshape", + inputs=["q_transposed", "concat_q_extra_out"], + outputs=["q"], + name="Reshape_q", + ) + return [transpose_q_node, reshape_q_node] + + reshape_q_node = helper.make_node( + "Reshape", + inputs=["q_matmul_out", "concat_q_extra_out"], + outputs=["q_reshaped"], + name="Reshape_q", + ) + transpose_q_node = helper.make_node( + "Transpose", + inputs=["q_reshaped"], + outputs=["q_before_rope"], + name="Transpose_q", + ) + return [reshape_q_node, transpose_q_node] + + def create_k_path_llama2_msft(self): + # Create k cache slicing path + k_cache_unsqueeze_node = helper.make_node( + "Unsqueeze", + inputs=["position_ids", "zero"], + outputs=["k_pos_id"], + ) + k_cache_slice_node = helper.make_node( + "Slice", + inputs=["past_key", "zero", "k_pos_id", "two", "one"], + outputs=["k_cache_sliced"], + ) + # Create k path + transpose_k_1_node = helper.make_node( + "Transpose", + inputs=["k_rope"], + outputs=["k_rope_transposed"], + name="Transpose_k_1", + perm=[0, 2, 1, 3], + ) + concat_k_node = helper.make_node( + "Concat", + inputs=["k_cache_sliced", "k_rope_transposed"], + outputs=["present_key"], + name="Concat_k", + axis=2, + ) + transpose_k_2_node = helper.make_node( + "Transpose", + inputs=["present_key"], + outputs=["present_key_transposed"], + name="Transpose_k_2", + perm=[0, 2, 3, 1], + ) + reshape_k_node = helper.make_node( + "Reshape", + inputs=["present_key_transposed", "concat_k_extra_out"], + outputs=["k"], + name="Reshape_k", + ) + return [ + k_cache_unsqueeze_node, + k_cache_slice_node, + transpose_k_1_node, + concat_k_node, + transpose_k_2_node, + reshape_k_node, + ] + + def create_k_path_hf(self, model_type: str): + reshape_k_node = helper.make_node( + "Reshape", + inputs=["k_matmul_out", "concat_k_extra_out"], + outputs=["k_reshaped"], + name="Reshape_k", + ) + transpose_k_1_node = helper.make_node( + "Transpose", + inputs=["k_reshaped"], + outputs=["k_before_rope"], + name="Transpose_k_1", + perm=[0, 2, 1, 3], + ) + k_nodes = [reshape_k_node, transpose_k_1_node] + + if model_type in {"past", "merged"}: + concat_k_node = helper.make_node( + "Concat", + inputs=["past_key", "k_rope"], + outputs=["present_key"], + axis=2, + ) + k_nodes.append(concat_k_node) + + transpose_k_2_node = helper.make_node( + "Transpose", + inputs=["present_key"], + outputs=["k"], + name="Transpose_k_2", + perm=[0, 1, 3, 2], + ) + return k_nodes + [transpose_k_2_node] # noqa: RUF005 + + def create_k_path(self, model_type: str): + if model_type == "llama2_msft": + return self.create_k_path_llama2_msft() + return self.create_k_path_hf(model_type) + + def create_attn_mask_path_llama2_msft(self): + x_shape_node = helper.make_node( + "Shape", + inputs=["input_0"], + outputs=["input_0_shape"], + name="Shape_input", + ) + x_get_seq_len_node = helper.make_node( + "Gather", + inputs=["input_0_shape", "one"], + outputs=["input_0_seq_len"], + name="Gather_input", + axis=0, + ) + x_new_seq_len_node = helper.make_node( + "Add", + inputs=["position_ids", "input_0_seq_len"], + outputs=["new_seq_len"], + name="Add_mask", + ) + unsqueeze_0_node = helper.make_node( + "Unsqueeze", + inputs=["position_ids", "zero"], + outputs=["unsqueeze_mask_0_out"], + name="Unsqueeze_mask_0", + ) + unsqueeze_1_node = helper.make_node( + "Unsqueeze", + inputs=["new_seq_len", "zero"], + outputs=["unsqueeze_mask_1_out"], + name="Unsqueeze_mask_1", + ) + unsqueeze_2_node = helper.make_node( + "Unsqueeze", + inputs=["new_seq_len", "zero"], + outputs=["unsqueeze_mask_2_out"], + name="Unsqueeze_mask_2", + ) + slice_mask_1_node = helper.make_node( + "Slice", + inputs=["attn_mask", "unsqueeze_mask_0_out", "unsqueeze_mask_1_out", "one", "one"], + outputs=["slice_mask_1_out"], + name="Slice_mask_1", + ) + slice_mask_2_node = helper.make_node( + "Slice", + inputs=["slice_mask_1_out", "zero", "unsqueeze_mask_2_out", "two", "one"], + outputs=["slice_mask_2_out"], + name="Slice_mask_2", + ) + concat_mask_node = helper.make_node( + "Concat", + inputs=["slice_mask_2_out" for _ in range(self.num_heads)], + outputs=["attn_mask_out"], + name="Concat_mask", + axis=0, + ) + return [ + x_shape_node, + x_get_seq_len_node, + x_new_seq_len_node, + unsqueeze_0_node, + unsqueeze_1_node, + unsqueeze_2_node, + slice_mask_1_node, + slice_mask_2_node, + concat_mask_node, + ] + + def create_attn_mask_path_hf(self, model_type: str): + unsqueeze_1_node = helper.make_node( + "Unsqueeze", + inputs=["attn_mask", "one"], + outputs=["unsqueeze_1_mask_out"], + name="Unsqueeze_1_mask", + ) + unsqueeze_2_node = helper.make_node( + "Unsqueeze", + inputs=["unsqueeze_1_mask_out", "two"], + outputs=["unsqueeze_2_mask_out"], + name="Unsqueeze_2_mask", + ) + expand_node = helper.make_node( + "Expand", + inputs=["unsqueeze_2_mask_out", "zero"], + outputs=["expand_out"], + name="Expand_mask", + ) + cast_node = helper.make_node( + "Cast", + inputs=["expand_out"], + outputs=["cast_out"], + name="Cast_mask", + to=TensorProto.FLOAT, + ) + sub_node = helper.make_node( + "Sub", + inputs=["one", "cast_out"], + outputs=["sub_out"], + name="Sub_mask", + ) + where_node = helper.make_node( + "Where", + inputs=["zero", "neg_int_max", "sub_out"], + outputs=["where_out" if model_type != "past" else "attn_mask_out"], + name="Where_mask", + ) + attn_mask_nodes = [unsqueeze_1_node, unsqueeze_2_node, expand_node, cast_node, sub_node, where_node] + + if model_type == "past": + return attn_mask_nodes + + add_node = helper.make_node( + "Add", + inputs=["where_out", "zero"], + outputs=["attn_mask_out"], + name="Add_mask", + ) + return attn_mask_nodes + [add_node] # noqa: RUF005 + + def create_attn_mask_path(self, is_fused: bool, model_type: str): + if model_type == "llama2_msft": + attn_mask_nodes = self.create_attn_mask_path_llama2_msft() + if is_fused: + attn_mask_nodes.pop() + attn_mask_nodes[-1].output[0] = "attn_mask_out" + return attn_mask_nodes + + attn_mask_nodes = self.create_attn_mask_path_hf(model_type) + if is_fused: + new_output_name = "attn_mask_out_mask" + attn_mask_nodes[-1].output[0] = new_output_name + concat_mask_node = helper.make_node( + "Concat", + inputs=[new_output_name for _ in range(self.num_heads)], + outputs=["attn_mask_out"], + name="Concat_mask", + axis=0, + ) + attn_mask_nodes.append(concat_mask_node) + return attn_mask_nodes + + def create_qk_path(self, model_type: str): + matmul_qk_node = helper.make_node( + "MatMul", + inputs=["q" if model_type == "llama2_msft" else "q_rope", "k"], + outputs=["qk"], + name="MatMul_q_k", + ) + div_node = helper.make_node( + "Div", + inputs=["qk", "sqrt_head_size"], + outputs=["qk_div"], + name="Div_0", + ) + add_node = helper.make_node( + "Add", + inputs=["qk_div", "attn_mask_out"], + outputs=["qk_plus_mask"], + name="Add_0", + ) + softmax_node = helper.make_node( + "Softmax", + inputs=["qk_plus_mask"], + outputs=["softmax_out"], + name="Softmax_0", + ) + return [matmul_qk_node, div_node, add_node, softmax_node] + + def create_v_path(self, model_type: str): + reshape_v_1_node = helper.make_node( + "Reshape", + inputs=["v_out", "concat_v_1_extra_out"], + outputs=["reshape_v_1_out"], + name="Reshape_v_1", + ) + transpose_v_1_node = helper.make_node( + "Transpose", + inputs=["reshape_v_1_out"], + outputs=["transpose_v_1_out" if model_type != "no_past" else "present_value"], + name="Transpose_v_1", + ) + v_nodes = [reshape_v_1_node, transpose_v_1_node] + + if model_type == "no_past": + return v_nodes + + if model_type in {"past", "merged"}: + concat_v_node = helper.make_node( + "Concat", + inputs=["past_value", "transpose_v_1_out"], + outputs=["present_value"], + name="Concat_v", + axis=2, + ) + return v_nodes + [concat_v_node] # noqa: RUF005 + + # Create extra nodes for `position_ids` + unsqueeze_v_node = helper.make_node( + "Unsqueeze", + inputs=["position_ids", "zero"], + outputs=["unsqueeze_v_out"], + name="Unsqueeze_v", + ) + slice_v_node = helper.make_node( + "Slice", + inputs=["past_value", "zero", "unsqueeze_v_out", "two", "one"], + outputs=["v_cache_sliced_out"], + name="Slice_v", + ) + concat_v_node = helper.make_node( + "Concat", + inputs=["v_cache_sliced_out", "transpose_v_1_out"], + outputs=["present_value"], + name="Concat_v", + axis=2, + ) + v_nodes.extend([unsqueeze_v_node, slice_v_node, concat_v_node]) + + # Create remaining nodes for v path + transpose_v_2_node = helper.make_node( + "Transpose", + inputs=["present_value"], + outputs=["transpose_v_2_out"], + name="Transpose_v_2", + ) + reshape_v_2_node = helper.make_node( + "Reshape", + inputs=["transpose_v_2_out", "concat_v_2_extra_out"], + outputs=["v"], + name="Reshape_v_2", + ) + return v_nodes + [transpose_v_2_node, reshape_v_2_node] # noqa: RUF005 + + def create_qkv_path(self, model_type: str): + matmul_qkv_node = helper.make_node( + "MatMul", + inputs=["softmax_out", "v" if model_type == "llama2_msft" else "present_value"], + outputs=["softmax_v_out"], + name="MatMul_softmax_v", + ) + qkv_nodes = [matmul_qkv_node] + + if model_type == "llama2_msft": + reshape_qkv_1_node = helper.make_node( + "Reshape", + inputs=["softmax_v_out", "concat_qkv_1_extra_out"], + outputs=["reshape_qkv_1_out"], + name="Reshape_qkv_1", + ) + qkv_nodes.append(reshape_qkv_1_node) + + transpose_qkv_node = helper.make_node( + "Transpose", + inputs=["reshape_qkv_1_out" if model_type == "llama2_msft" else "softmax_v_out"], + outputs=["transpose_qkv_out"], + name="Transpose_qkv", + ) + reshape_qkv_2_node = helper.make_node( + "Reshape", + inputs=["transpose_qkv_out", "concat_qkv_2_extra_out"], + outputs=["attn_output"], + name="Reshape_qkv_2", + ) + + return qkv_nodes + [transpose_qkv_node, reshape_qkv_2_node] # noqa: RUF005 + + def create_concat_unsqueeze_paths(self, model_type: str, reshape_nodes: List[NodeProto]): + # Create initial shape paths + shape_0_node = helper.make_node( + "Shape", + inputs=["input_0"], + outputs=["input_0_shape_0"], + name="Shape_0", + ) + gather_0_node = helper.make_node( + "Gather", + inputs=["input_0_shape_0", "zero"], + outputs=["input_0_shape_0_indexed"], + name="Gather_0", + axis=0, + ) + shape_1_node = helper.make_node( + "Shape", + inputs=["input_0"], + outputs=["input_0_shape_1"], + name="Shape_1", + ) + gather_1_node = helper.make_node( + "Gather", + inputs=["input_0_shape_1", "one"], + outputs=["input_0_shape_1_indexed"], + name="Gather_1", + axis=0, + ) + extra_nodes = [shape_0_node, gather_0_node, shape_1_node, gather_1_node] + + if model_type == "llama2_msft": + mul_node = helper.make_node( + "Mul", + inputs=[gather_0_node.output[0], "num_heads"], + outputs=["mul_extra_out"], + name="Mul_extra_0", + ) + add_node = helper.make_node( + "Add", + inputs=[gather_1_node.output[0], "position_ids"], + outputs=["add_extra_out"], + name="Add_extra_0", + ) + extra_nodes.extend([mul_node, add_node]) + + for i, reshape_node in enumerate(reshape_nodes): + use_mul_and_add_nodes_0 = model_type == "llama2_msft" and reshape_node.output[0] in {"q", "k", "v"} + use_mul_and_add_nodes_1 = model_type == "llama2_msft" and reshape_node.output[0] in {"k", "v"} + + unsqueeze_0_node = helper.make_node( + "Unsqueeze", + inputs=[gather_0_node.output[0] if not use_mul_and_add_nodes_0 else "mul_extra_out", "zero"], + outputs=[f"unsqueeze_extra_{2*i}"], + name=f"Unsqueeze_extra_{2*i}", + ) + unsqueeze_1_node = helper.make_node( + "Unsqueeze", + inputs=[gather_1_node.output[0] if not use_mul_and_add_nodes_1 else "add_extra_out", "zero"], + outputs=[f"unsqueeze_extra_{2*i + 1}"], + name=f"Unsqueeze_extra_{2*i + 1}", + ) + + reshape_name = reshape_node.name + if reshape_name == "Reshape_qkv_2": + concat_node_inputs = [unsqueeze_0_node.output[0], unsqueeze_1_node.output[0], "hidden_size"] + elif reshape_name == "Reshape_qkv_1": + concat_node_inputs = [unsqueeze_0_node.output[0], "num_heads", unsqueeze_1_node.output[0], "head_size"] + elif reshape_name == "Reshape_v_2": + concat_node_inputs = [unsqueeze_0_node.output[0], unsqueeze_1_node.output[0], "head_size"] + elif reshape_name == "Reshape_v_1": + concat_node_inputs = [unsqueeze_0_node.output[0], unsqueeze_1_node.output[0], "num_heads", "head_size"] + elif reshape_name == "Reshape_k": + concat_node_inputs = [unsqueeze_0_node.output[0], "head_size", unsqueeze_1_node.output[0]] + elif reshape_name == "Reshape_q": + concat_node_inputs = [unsqueeze_0_node.output[0], unsqueeze_1_node.output[0], "head_size"] + + concat_node = helper.make_node( + "Concat", + inputs=concat_node_inputs, + outputs=[reshape_nodes[i].input[1]], + name=f"Concat_extra_{i}", + axis=0, + ) + extra_nodes.extend([unsqueeze_0_node, unsqueeze_1_node, concat_node]) + + return extra_nodes + + def create_end_nodes(self): + matmul_o_node = helper.make_node( + "MatMul", + inputs=["attn_output", "o_weight"], + outputs=["output_proj"], + name="MatMul_o_proj", + ) + end_node = helper.make_node( + "Add", + inputs=["zero", "output_proj"], + outputs=["output_0"], + name="Add_normalize_node", + ) + return [matmul_o_node, end_node] + + def create_fused_model(self, model_type: str, interleaved: bool, initializers: List[TensorProto]): + inputs, outputs = self.create_inputs_and_outputs(model_type) + matmul_nodes = self.create_matmul_nodes(True, model_type=model_type) + rope_nodes = self.create_rotary_embeddings(True, model_type, interleaved, inputs, initializers) + attn_mask_nodes = self.create_attn_mask_path(True, model_type) + + mha_inputs = [ + rope_nodes[0].output[0], # q + rope_nodes[1].output[0], # k + matmul_nodes[-1].output[0], # v + "", # bias + "attn_mask_out" if model_type == "llama2_msft" else "", # attn_mask + "attn_mask_out" if model_type != "llama2_msft" else "", # add_qk + "past_key" if model_type != "no_past" else "", # past_key + "past_value" if model_type != "no_past" else "", # past_value + ] + mha_node = helper.make_node( + "MultiHeadAttention", + inputs=mha_inputs, + outputs=["attn_output", "present_key", "present_value"], + name="MultiHeadAttention_0", + num_heads=self.num_heads, + ) + + end_nodes = self.create_end_nodes() + + graph = helper.make_graph( + nodes=matmul_nodes + rope_nodes + attn_mask_nodes + [mha_node] + end_nodes, + name="RotaryAttention_Graph", + inputs=inputs, + outputs=outputs, + initializer=initializers, + ) + opset_import = helper.make_opsetid(domain="com.microsoft", version=1) + model = helper.make_model(graph, opset_imports=[opset_import]) + return model + + def create_test_model(self, model_type: str, interleaved: bool, initializers: List[TensorProto]): + inputs, outputs = self.create_inputs_and_outputs(model_type) + matmul_nodes = self.create_matmul_nodes(False, model_type) + rope_nodes = self.create_rotary_embeddings(False, model_type, interleaved, inputs, initializers) + + # Create main paths + q_nodes = self.create_q_path(model_type) + k_nodes = self.create_k_path(model_type) + attn_mask_nodes = self.create_attn_mask_path(False, model_type) + qk_nodes = self.create_qk_path(model_type) + v_nodes = self.create_v_path(model_type) + qkv_nodes = self.create_qkv_path(model_type) + + reshape_nodes = list(filter(lambda node: node.op_type == "Reshape", q_nodes + k_nodes + v_nodes + qkv_nodes)) + extra_nodes = self.create_concat_unsqueeze_paths(model_type, reshape_nodes) + + end_nodes = self.create_end_nodes() + + first_set_of_nodes = matmul_nodes + rope_nodes + q_nodes + k_nodes + attn_mask_nodes + second_set_of_nodes = qk_nodes + v_nodes + qkv_nodes + extra_nodes + end_nodes + graph = helper.make_graph( + nodes=first_set_of_nodes + second_set_of_nodes, + name="RotaryAttention_Graph", + inputs=inputs, + outputs=outputs, + initializer=initializers, + ) + opset_import = helper.make_opsetid(domain="ai.onnx", version=17) + model = helper.make_model(graph, opset_imports=[opset_import]) + return model + + def check_models(self, model_type: str, interleaved: bool): + initializers = self.create_initializers() + + expected_model_filename = "expected_model.onnx" + expected_model = self.create_fused_model(model_type, interleaved, initializers) + onnx.save(expected_model, expected_model_filename) + + original_model_filename = "original_model.onnx" + original_model = self.create_test_model(model_type, interleaved, initializers) + onnx.save(original_model, original_model_filename) + + self.verify_fusion(expected_model_filename, original_model_filename) + os.remove(expected_model_filename) + os.remove(original_model_filename) + + def test_llama2_msft_model(self): + model_type = "llama2_msft" + interleaved = True + self.check_models(model_type, interleaved) + + def test_hf_decoder_model(self): + model_type = "no_past" + interleaved = False + self.check_models(model_type, interleaved) + + def test_hf_decoder_with_past_model(self): + model_type = "past" + interleaved = False + self.check_models(model_type, interleaved) + + def test_hf_decoder_merged_model(self): + model_type = "merged" + interleaved = False + self.check_models(model_type, interleaved) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/transformers/test_simplified_layernorm_fusion.py b/onnxruntime/test/python/transformers/test_simplified_layernorm_fusion.py new file mode 100644 index 0000000000000..e86bdda7baffb --- /dev/null +++ b/onnxruntime/test/python/transformers/test_simplified_layernorm_fusion.py @@ -0,0 +1,243 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import os +import unittest +from typing import List + +import numpy as np +import onnx +from onnx import TensorProto, helper +from parity_utilities import find_transformers_source + +if find_transformers_source(): + from fusion_options import FusionOptions + from onnx_model import OnnxModel + from optimizer import optimize_model +else: + from onnxruntime.transformers.fusion_options import FusionOptions + from onnxruntime.transformers.onnx_model import OnnxModel + from onnxruntime.transformers.optimizer import optimize_model + + +def float_tensor(name: str, shape: List[int], random=False): + low = 0.0 + high = 1.0 + total_elements = 1 + for x in shape: + total_elements *= x + weights = [np.random.uniform(low, high) for _ in range(total_elements)] if random else [1.0] * total_elements + return helper.make_tensor(name, TensorProto.FLOAT, shape, weights) + + +class TestSimplifiedLayerNormFusion(unittest.TestCase): + def setUp(self): + self.vocab_size = 5 + self.batch_size = 2 + self.sequence_length = 8 + self.hidden_size = 16 + self.epsilon = 0.000009999999747378752 + + def verify_fusion(self, expected_model_path, original_model_path): + expected_model = OnnxModel(onnx.load(expected_model_path)) + expected_model.topological_sort(is_deterministic=True) + + options = FusionOptions("gpt2") + optimized_model = optimize_model(original_model_path, optimization_options=options) + optimized_model.topological_sort(is_deterministic=True) + + self.assertTrue(str(expected_model.model.graph), str(optimized_model.model.graph)) + + def create_initializers(self, use_embed_weight: bool = False): + initializers = [ + helper.make_tensor("Two", TensorProto.FLOAT, [1], np.array([2], dtype=np.float32)), + helper.make_tensor("epsilon", TensorProto.FLOAT, [1], np.array([self.epsilon], dtype=np.float32)), + helper.make_tensor("One", TensorProto.FLOAT, [1], np.array([1], dtype=np.float32)), + float_tensor("scale", [self.hidden_size]), + ] + if use_embed_weight: + initializers = [ # noqa: RUF005 + float_tensor("embed_weight", [self.vocab_size, self.hidden_size]) + ] + initializers + return initializers + + def create_inputs_and_outputs(self, start_node_type: str): + inputs, start_node = None, None + if start_node_type == "Add": + start_node = helper.make_node( + "Add", + inputs=["input_0", "input_1"], + outputs=["D"], + name="Add_0", + ) + input_0 = helper.make_tensor_value_info( + "input_0", + TensorProto.FLOAT, + [self.batch_size, self.sequence_length, self.hidden_size], + ) + input_1 = helper.make_tensor_value_info( + "input_1", + TensorProto.FLOAT, + [self.batch_size, self.sequence_length, self.hidden_size], + ) + inputs = [input_0, input_1] + elif start_node_type == "Gather": + start_node = helper.make_node( + "Gather", + inputs=["embed_weight", "input_0"], + outputs=["D"], + name="Gather_0", + ) + input_0 = helper.make_tensor_value_info( + "input_0", + TensorProto.INT64, + [self.batch_size, self.sequence_length], + ) + inputs = [input_0] + else: + # start_node_type is a graph input + assert start_node_type == "GraphInput" + input_0 = helper.make_tensor_value_info( + "D", + TensorProto.FLOAT, + [self.batch_size, self.sequence_length, self.hidden_size], + ) + inputs = [input_0] + + outputs = [ + helper.make_tensor_value_info( + "output_0", + TensorProto.FLOAT, + [self.batch_size, self.sequence_length, self.hidden_size], + ) + ] + return inputs, outputs, start_node + + def create_fused_model(self, start_node_type: str, initializers: List[TensorProto]): + inputs, outputs, start_node = self.create_inputs_and_outputs(start_node_type) + + sln_node = helper.make_node( + "SimplifiedLayerNormalization", + inputs=[start_node.output[0] if start_node is not None else "D", initializers[0].name], + outputs=[outputs[0].name], + axis=-1, + epsilon=initializers[2].float_data[0], + stash_type=1, + ) + + graph = helper.make_graph( + nodes=[sln_node] + ([] if start_node is None else [start_node]), + name="SimplifiedLayerNorm_Graph", + inputs=inputs, + outputs=outputs, + initializer=initializers, + ) + opset_import = helper.make_opsetid(domain="com.microsoft", version=1) + model = helper.make_model(graph, opset_imports=[opset_import]) + return model + + # Notation follows https://onnx.ai/onnx/operators/onnx__LayerNormalization.html#summary + def create_test_model(self, start_node_type: str, first_parent_idx: int, initializers: List[TensorProto]): + end_node = helper.make_node( + "Mul", + inputs=["scale", "Normalized"] if first_parent_idx == 1 else ["Normalized", "scale"], + outputs=["output_0"], + name="Mul_1", + ) + mul_node = helper.make_node( + "Mul", + inputs=["D", "InvStdDev"], + outputs=["Normalized"], + name="Mul_0", + ) + div_node = helper.make_node( + "Div", + inputs=["One", "StdDev"], + outputs=["InvStdDev"], + name="Div_0", + ) + sqrt_node = helper.make_node( + "Sqrt", + inputs=["VarEps"], + outputs=["StdDev"], + name="Sqrt_0", + ) + add_node = helper.make_node( + "Add", + inputs=["Var", "epsilon"], + outputs=["VarEps"], + name="Add_1", + ) + reducemean_node = helper.make_node( + "ReduceMean", + inputs=["DD"], + outputs=["Var"], + name="ReduceMean_0", + ) + pow_node = helper.make_node( + "Pow", + inputs=["D", "Two"], + outputs=["DD"], + name="Pow_0", + ) + + inputs, outputs, start_node = self.create_inputs_and_outputs(start_node_type) + + main_nodes = [pow_node, reducemean_node, add_node, sqrt_node, div_node, mul_node, end_node] + graph = helper.make_graph( + nodes=main_nodes + ([] if start_node is None else [start_node]), + name="SimplifiedLayerNorm_Graph", + inputs=inputs, + outputs=outputs, + initializer=initializers, + ) + opset_import = helper.make_opsetid(domain="com.microsoft", version=1) + model = helper.make_model(graph, opset_imports=[opset_import]) + return model + + def check_models(self, start_node_type: str, first_parent_idx: int, initializers: List[TensorProto]): + expected_model_filename = "expected_model.onnx" + expected_model = self.create_fused_model(start_node_type, initializers) + onnx.save(expected_model, expected_model_filename) + + original_model_filename = "original_model.onnx" + original_model = self.create_test_model(start_node_type, first_parent_idx, initializers) + onnx.save(original_model, original_model_filename) + + self.verify_fusion(expected_model_filename, original_model_filename) + os.remove(expected_model_filename) + os.remove(original_model_filename) + + # sim_ln_nodes_1 + def test_simplified_layernorm_add_idx1(self): + start_node_type = "Add" + first_parent_idx = 1 + initializers = self.create_initializers() + self.check_models(start_node_type, first_parent_idx, initializers) + + # sim_ln_nodes_2 + def test_simplified_layernorm_gather_idx1(self): + start_node_type = "Gather" + first_parent_idx = 1 + initializers = self.create_initializers(use_embed_weight=True) + self.check_models(start_node_type, first_parent_idx, initializers) + + # sim_ln_nodes_3 + def test_simplified_layernorm_add_idx0(self): + start_node_type = "Add" + first_parent_idx = 0 + initializers = self.create_initializers() + self.check_models(start_node_type, first_parent_idx, initializers) + + # sim_ln_nodes_4 + def test_simplified_layernorm_gather_graph_input(self): + start_node_type = "GraphInput" + first_parent_idx = 0 + initializers = self.create_initializers() + self.check_models(start_node_type, first_parent_idx, initializers) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxruntime/test/python/transformers/test_whisper.py b/onnxruntime/test/python/transformers/test_whisper.py index ebda0bccaadcf..ceda5a88c3925 100644 --- a/onnxruntime/test/python/transformers/test_whisper.py +++ b/onnxruntime/test/python/transformers/test_whisper.py @@ -50,7 +50,7 @@ def verify_fusion(self, optimized_model, expected_model_filename): ) ) - # Attention type #1 in onnx_model_bart.py + # Attention type #1 in fusion_bart_attention.py def test_encoder_attention_fusion_with_skiplayernorm(self): num_heads = 4 hidden_size = 64 @@ -67,7 +67,7 @@ def test_encoder_attention_fusion_with_skiplayernorm(self): os.remove(model_path) self.verify_fusion(optimized_model, "encoder_attention_with_sln_fused.onnx") - # Attention type #2 in onnx_model_bart.py + # Attention type #2 in fusion_bart_attention.py def test_decoder_attention_fusion_with_skiplayernorm(self): num_heads = 4 hidden_size = 64 @@ -84,7 +84,7 @@ def test_decoder_attention_fusion_with_skiplayernorm(self): os.remove(model_path) self.verify_fusion(optimized_model, "decoder_attention_with_sln_fused.onnx") - # Attention type #4 in onnx_model_bart.py + # Attention type #4 in fusion_bart_attention.py def test_decoder_multihead_attention_fusion(self): num_heads = 4 hidden_size = 64 @@ -100,7 +100,7 @@ def test_decoder_multihead_attention_fusion(self): os.remove(model_path) self.verify_fusion(optimized_model, "decoder_mha_fused.onnx") - # Attention type #3 in onnx_model_bart.py + # Attention type #3 in fusion_bart_attention.py def test_decoder_with_past_multihead_self_attention_fusion_with_skiplayernorm(self): num_heads = 4 hidden_size = 64 @@ -118,7 +118,7 @@ def test_decoder_with_past_multihead_self_attention_fusion_with_skiplayernorm(se os.remove(model_path) self.verify_fusion(optimized_model, "decoder_with_past_self_mha_fused.onnx") - # Attention type #5 in onnx_model_bart.py + # Attention type #5 in fusion_bart_attention.py def test_decoder_with_past_multihead_cross_attention_fusion(self): num_heads = 4 hidden_size = 64 @@ -134,7 +134,7 @@ def test_decoder_with_past_multihead_cross_attention_fusion(self): os.remove(model_path) self.verify_fusion(optimized_model, "decoder_with_past_cross_mha_fused.onnx") - # Attention type #4 in onnx_model_bart.py + # Attention type #4 in fusion_bart_attention.py def test_decoder_multihead_attention_split_bias_fusion(self): num_heads = 4 hidden_size = 64 @@ -151,7 +151,7 @@ def test_decoder_multihead_attention_split_bias_fusion(self): os.remove(model_path) self.verify_fusion(optimized_model, "decoder_mha_split_bias_fused.onnx") - # Attention type #3 in onnx_model_bart.py + # Attention type #3 in fusion_bart_attention.py def test_decoder_with_past_multihead_self_attention_split_bias_fusion_with_skiplayernorm(self): num_heads = 4 hidden_size = 64 @@ -171,7 +171,7 @@ def test_decoder_with_past_multihead_self_attention_split_bias_fusion_with_skipl os.remove(model_path) self.verify_fusion(optimized_model, "decoder_with_past_self_mha_split_bias_fused.onnx") - # Attention type #5 in onnx_model_bart.py + # Attention type #5 in fusion_bart_attention.py def test_decoder_with_past_multihead_cross_attention_split_bias_fusion(self): num_heads = 4 hidden_size = 64 diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 565bf068e7abd..ba282193c5ca6 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -2832,6 +2832,132 @@ TEST(CApiTest, ConfigureCudaArenaAndDemonstrateMemoryArenaShrinkage) { #endif #ifdef USE_TENSORRT +TEST(CApiTest, TestExternalCUDAStreamWithIOBinding) { + const auto& api = Ort::GetApi(); + Ort::SessionOptions session_options; + + OrtTensorRTProviderOptionsV2* trt_options; + ASSERT_TRUE(api.CreateTensorRTProviderOptions(&trt_options) == nullptr); + std::unique_ptr + rel_trt_options(trt_options, api.ReleaseTensorRTProviderOptions); + + // updating provider option with user provided compute stream + cudaStream_t compute_stream = nullptr; + void* user_compute_stream = nullptr; + cudaStreamCreate(&compute_stream); + ASSERT_TRUE(api.UpdateTensorRTProviderOptionsWithValue(rel_trt_options.get(), "user_compute_stream", compute_stream) == nullptr); + ASSERT_TRUE(api.GetTensorRTProviderOptionsByName(rel_trt_options.get(), "user_compute_stream", &user_compute_stream) == nullptr); + ASSERT_TRUE(user_compute_stream == (void*)compute_stream); + + ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_TensorRT_V2( + static_cast(session_options), + rel_trt_options.get()) == nullptr); + + Ort::Session session(*ort_env, MODEL_URI, session_options); + Ort::MemoryInfo info_cuda("Cuda", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); + + const std::array x_shape = {3, 2}; + std::array x_values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; + + /* + * Use cudaMallocHost() (pinned memory allocation) to create input/output tensors + */ + float* input_data; + cudaMallocHost(&input_data, 3 * 2 * sizeof(float)); + ASSERT_NE(input_data, nullptr); + cudaMemcpy(input_data, x_values.data(), sizeof(float) * x_values.size(), cudaMemcpyHostToDevice); + + std::cout << "pinned memory allocation" << std::endl; + std::cout << "input tesnor:" << std::endl; + for (int i = 0; i < 6; i++) { + std::cout << input_data[i] << std::endl; + } + + // Create an OrtValue tensor backed by data on CUDA memory + Ort::Value bound_x = Ort::Value::CreateTensor(info_cuda, reinterpret_cast(input_data), x_values.size(), + x_shape.data(), x_shape.size()); + + const std::array expected_y_shape = {3, 2}; + std::array expected_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f}; + + float* output_data; + cudaMallocHost(&output_data, 3 * 2 * sizeof(float)); + ASSERT_NE(output_data, nullptr); + + // Create an OrtValue tensor backed by data on CUDA memory + Ort::Value bound_y = Ort::Value::CreateTensor(info_cuda, reinterpret_cast(output_data), + expected_y.size(), expected_y_shape.data(), expected_y_shape.size()); + + // Create IoBinding for inputs and outputs. + Ort::IoBinding binding(session); + binding.BindInput("X", bound_x); + binding.BindOutput("Y", bound_y); + + /* + * Use cudaMalloc() (pageable memory allocation first and then implicit pinned memory allocation) to create input/output tensors + */ + float* input_data_2; + cudaMalloc(&input_data_2, 3 * 2 * sizeof(float)); + ASSERT_NE(input_data_2, nullptr); + cudaMemcpy(input_data_2, x_values.data(), sizeof(float) * x_values.size(), cudaMemcpyHostToDevice); + + // Create an OrtValue tensor backed by data on CUDA memory + Ort::Value bound_x_2 = Ort::Value::CreateTensor(info_cuda, reinterpret_cast(input_data_2), x_values.size(), + x_shape.data(), x_shape.size()); + + float* output_data_2; + cudaMalloc(&output_data_2, 3 * 2 * sizeof(float)); + ASSERT_NE(output_data_2, nullptr); + + // Create an OrtValue tensor backed by data on CUDA memory + Ort::Value bound_y_2 = Ort::Value::CreateTensor(info_cuda, reinterpret_cast(output_data_2), + expected_y.size(), expected_y_shape.data(), expected_y_shape.size()); + + // Create IoBinding for inputs and outputs. + Ort::IoBinding binding_2(session); + binding_2.BindInput("X", bound_x_2); + binding_2.BindOutput("Y", bound_y_2); + + // Run with first iobindings + session.Run(Ort::RunOptions(), binding); + + // Check the values against the bound raw memory (needs copying from device to host first) + std::array y_values; + cudaMemcpy(y_values.data(), output_data, sizeof(float) * y_values.size(), cudaMemcpyDeviceToHost); + + std::cout << "pinned memory allocation" << std::endl; + std::cout << "output: " << std::endl; + for (auto y : y_values) { + std::cout << y << std::endl; + } + ASSERT_THAT(y_values, ::testing::ContainerEq(expected_y)); + + // Run with second iobindings + session.Run(Ort::RunOptions(), binding_2); + + // Check the values against the bound raw memory (needs copying from device to host first) + cudaMemcpy(y_values.data(), output_data_2, sizeof(float) * y_values.size(), cudaMemcpyDeviceToHost); + + std::cout << "pageable memory allocation" << std::endl; + std::cout << "output: " << std::endl; + for (auto y : y_values) { + std::cout << y << std::endl; + } + ASSERT_THAT(y_values, ::testing::ContainerEq(expected_y)); + + // Clean up + binding.ClearBoundInputs(); + binding.ClearBoundOutputs(); + binding_2.ClearBoundInputs(); + binding_2.ClearBoundOutputs(); + + cudaFreeHost(input_data); + cudaFreeHost(output_data); + cudaFree(input_data_2); + cudaFree(output_data_2); + cudaStreamDestroy(compute_stream); +} + class CApiTensorRTTest : public testing::Test, public ::testing::WithParamInterface {}; // This test uses CreateTensorRTProviderOptions/UpdateTensorRTProviderOptions APIs to configure and create a TensorRT Execution Provider @@ -2849,15 +2975,6 @@ TEST_P(CApiTensorRTTest, TestConfigureTensorRTProviderOptions) { ASSERT_TRUE(api.CreateTensorRTProviderOptions(&trt_options) == nullptr); std::unique_ptr rel_trt_options(trt_options, api.ReleaseTensorRTProviderOptions); - // Only test updating provider option with user provided compute stream - cudaStream_t compute_stream = nullptr; - void* user_compute_stream = nullptr; - cudaStreamCreateWithFlags(&compute_stream, cudaStreamNonBlocking); - ASSERT_TRUE(api.UpdateTensorRTProviderOptionsWithValue(rel_trt_options.get(), "user_compute_stream", compute_stream) == nullptr); - ASSERT_TRUE(api.GetTensorRTProviderOptionsByName(rel_trt_options.get(), "user_compute_stream", &user_compute_stream) == nullptr); - ASSERT_TRUE(user_compute_stream == (void*)compute_stream); - cudaStreamDestroy(compute_stream); - const char* engine_cache_path = "./trt_engine_folder"; std::vector keys{"device_id", "has_user_compute_stream", "trt_fp16_enable", "trt_int8_enable", "trt_engine_cache_enable", diff --git a/onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.cc b/onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.cc index aba35b33b75c6..3d561d378cb8c 100644 --- a/onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.cc +++ b/onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.cc @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#ifdef USE_CUDA +#if defined(USE_CUDA) && !defined(ENABLE_TRAINING) #define ORT_API_MANUAL_INIT #include "onnxruntime_cxx_api.h" @@ -32,6 +32,9 @@ void KernelOne(const Ort::Custom::CudaContext& cuda_ctx, CUSTOM_ENFORCE(cuda_ctx.cuda_stream, "failed to fetch cuda stream"); CUSTOM_ENFORCE(cuda_ctx.cudnn_handle, "failed to fetch cudnn handle"); CUSTOM_ENFORCE(cuda_ctx.cublas_handle, "failed to fetch cublas handle"); + void* deferred_cpu_mem = cuda_ctx.AllocDeferredCpuMem(sizeof(int32_t)); + CUSTOM_ENFORCE(deferred_cpu_mem, "failed to allocate deferred cpu allocator"); + cuda_ctx.FreeDeferredCpuMem(deferred_cpu_mem); auto z_raw = Z.Allocate(input_shape); cuda_add(Z.NumberOfElement(), z_raw, X.Data(), Y.Data(), cuda_ctx.cuda_stream); } @@ -43,8 +46,4 @@ void RegisterOps(Ort::CustomOpDomain& domain) { } // namespace Cuda -#else - -void Cuda::RegisterOps(Ort::CustomOpDomain& domain) {} - #endif \ No newline at end of file diff --git a/onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.h b/onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.h index c0287c4932c98..35cd36fcd4cb7 100644 --- a/onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.h +++ b/onnxruntime/test/testdata/custom_op_library/cuda/cuda_ops.h @@ -5,6 +5,14 @@ namespace Cuda { +#if defined(USE_CUDA) && !defined(ENABLE_TRAINING) + void RegisterOps(Ort::CustomOpDomain& domain); -} \ No newline at end of file +#else + +void RegisterOps(Ort::CustomOpDomain&) {} + +#endif + +} // namespace Cuda \ No newline at end of file diff --git a/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc b/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc index 40fb127eb0b8f..2d5ffc3c81b0f 100644 --- a/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc +++ b/onnxruntime/test/testdata/custom_op_library/custom_op_library.cc @@ -13,6 +13,8 @@ #include "core/framework/ortdevice.h" #include "core/framework/ortmemoryinfo.h" #include "cpu/cpu_ops.h" +#include "cuda/cuda_ops.h" +#include "rocm/rocm_ops.h" #include "onnxruntime_lite_custom_op.h" static const char* c_OpDomain = "test.customop"; @@ -31,10 +33,15 @@ OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtA ORT_TRY { Ort::CustomOpDomain domain{c_OpDomain}; Cpu::RegisterOps(domain); - Ort::CustomOpDomain domain_v2{"v2"}; Cpu::RegisterOps(domain_v2); + Cuda::RegisterOps(domain); + Cuda::RegisterOps(domain_v2); + + Rocm::RegisterOps(domain); + Rocm::RegisterOps(domain_v2); + Ort::UnownedSessionOptions session_options(options); session_options.Add(domain); session_options.Add(domain_v2); diff --git a/onnxruntime/test/testdata/custom_op_library/rocm/rocm_ops.cc b/onnxruntime/test/testdata/custom_op_library/rocm/rocm_ops.cc index 113bfb85454a2..069246b4201e7 100644 --- a/onnxruntime/test/testdata/custom_op_library/rocm/rocm_ops.cc +++ b/onnxruntime/test/testdata/custom_op_library/rocm/rocm_ops.cc @@ -19,7 +19,7 @@ using namespace Ort::Custom; throw std::runtime_error(msg); \ } -namespace Cuda { +namespace Rocm { void KernelOne(const Ort::Custom::RocmContext& rocm_ctx, const Ort::Custom::Tensor& X, @@ -38,10 +38,6 @@ void RegisterOps(Ort::CustomOpDomain& domain) { domain.Add(c_CustomOpOne.get()); } -} // namespace Cuda - -#else - -void Cuda::RegisterOps(Ort::CustomOpDomain& domain) {} +} // namespace Rocm #endif \ No newline at end of file diff --git a/onnxruntime/test/testdata/custom_op_library/rocm/rocm_ops.h b/onnxruntime/test/testdata/custom_op_library/rocm/rocm_ops.h index 4e8958cd9dae0..d3e9e4040a5c3 100644 --- a/onnxruntime/test/testdata/custom_op_library/rocm/rocm_ops.h +++ b/onnxruntime/test/testdata/custom_op_library/rocm/rocm_ops.h @@ -5,6 +5,14 @@ namespace Rocm { +#ifdef USE_ROCM + void RegisterOps(Ort::CustomOpDomain& domain); -} \ No newline at end of file +#else + +inline void RegisterOps(Ort::CustomOpDomain&) {} + +#endif + +} // namespace Rocm diff --git a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc index c142106ed506c..44db7c0078cfc 100644 --- a/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc +++ b/onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc @@ -283,14 +283,22 @@ "^test_dft_axis", "^test_dft", "^test_dft_inverse", - "^test_isinf", - "^test_isinf_float16", - "^test_isinf_negative", - "^test_isinf_positive", - "^test_isnan", - "^test_isnan_float16", "^test_reduce_max_bool_inputs", - "^test_reduce_min_bool_inputs" + "^test_reduce_min_bool_inputs", + "^test_reduce_min_empty_set", + "^test_reduce_l1_empty_set", + "^test_reduce_l1_empty_set_expanded", + "^test_reduce_l2_empty_set", + "^test_reduce_l2_empty_set_expanded", + "^test_reduce_log_sum_empty_set", + "^test_reduce_log_sum_empty_set_expanded", + "^test_reduce_log_sum_exp_empty_set", + "^test_reduce_log_sum_exp_empty_set_expanded", + "^test_reduce_prod_empty_set", + "^test_reduce_sum_empty_set", + "^test_reduce_sum_empty_set_non_reduced_axis_zero", + "^test_reduce_sum_square_empty_set", + "^test_reduce_sum_square_empty_set_expanded" ], "current_failing_tests_x86": [ "^test_vgg19", diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-directly.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-directly.onnx new file mode 100644 index 0000000000000..fa11adaac8d95 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-directly.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-non-ignorable-node.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-non-ignorable-node.onnx new file mode 100644 index 0000000000000..1050a7285b4a6 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-non-ignorable-node.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-reshape.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-reshape.onnx new file mode 100644 index 0000000000000..c361a42700a30 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-reshape.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-transpose.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-transpose.onnx new file mode 100644 index 0000000000000..f70ae2e6229e7 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-only-transpose.onnx differ diff --git a/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-with-reshape.onnx b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-with-reshape.onnx new file mode 100644 index 0000000000000..8e4bc49514548 Binary files /dev/null and b/onnxruntime/test/testdata/transform/fusion/fuse-matmul-bn-with-reshape.onnx differ diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 28af61e15b2b5..e224507bc740e 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -268,7 +268,7 @@ std::unique_ptr DefaultCannExecutionProvider() { std::unique_ptr DefaultDmlExecutionProvider() { #ifdef USE_DML - if (auto factory = DMLProviderFactoryCreator::Create(0)) + if (auto factory = DMLProviderFactoryCreator::Create(0, false, false, false)) return factory->CreateProvider(); #endif return nullptr; diff --git a/onnxruntime/test/util/include/inference_session_wrapper.h b/onnxruntime/test/util/include/inference_session_wrapper.h index eab83c26b681f..757caf7987d35 100644 --- a/onnxruntime/test/util/include/inference_session_wrapper.h +++ b/onnxruntime/test/util/include/inference_session_wrapper.h @@ -12,9 +12,8 @@ namespace test { // InferenceSession wrapper class for use in tests where we need access to the Graph and SessionState class InferenceSessionWrapper : public InferenceSession { public: - explicit InferenceSessionWrapper(const SessionOptions& session_options, - const Environment& env) : InferenceSession(session_options, env) { - } + // Expose the constructors from InferenceSession + using InferenceSession::InferenceSession; const Graph& GetGraph() const { return model_->MainGraph(); diff --git a/onnxruntime/wasm/api.cc b/onnxruntime/wasm/api.cc index 968eece361724..0e58bb4f93f7f 100644 --- a/onnxruntime/wasm/api.cc +++ b/onnxruntime/wasm/api.cc @@ -493,6 +493,14 @@ char* OrtEndProfiling(ort_session_handle_t session) { #define CHECK_TRAINING_STATUS(ORT_API_NAME, ...) \ CheckStatus(Ort::GetTrainingApi().ORT_API_NAME(__VA_ARGS__)) +#define RETURN_TRAINING_ERROR_CODE_IF_ERROR(ORT_API_NAME, ...) \ + do { \ + int error_code = CHECK_TRAINING_STATUS(ORT_API_NAME, __VA_ARGS__); \ + if (error_code != ORT_OK) { \ + return error_code; \ + } \ + } while (false) + ort_training_checkpoint_handle_t EMSCRIPTEN_KEEPALIVE OrtTrainingLoadCheckpoint(void* checkpoint_data_buffer, size_t checkpoint_size) { OrtCheckpointState* checkpoint_state = nullptr; @@ -571,6 +579,57 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingCopyParametersFromBuffer(ort_training_sessio return CHECK_TRAINING_STATUS(CopyBufferToParameters, training_handle, parameters_buffer, trainable_only); } +int EMSCRIPTEN_KEEPALIVE OrtTrainingGetModelInputOutputCount(ort_training_session_handle_t training_handle, + size_t* input_count, + size_t* output_count, + bool isEvalModel) { + if (isEvalModel) { + RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetEvalModelInputCount, training_handle, input_count); + RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetEvalModelOutputCount, training_handle, output_count); + return ORT_OK; + } else { + RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetTrainingModelInputCount, training_handle, input_count); + RETURN_TRAINING_ERROR_CODE_IF_ERROR(TrainingSessionGetTrainingModelOutputCount, training_handle, output_count); + return ORT_OK; + } +} + +char* EMSCRIPTEN_KEEPALIVE OrtTrainingGetModelInputOutputName(ort_training_session_handle_t training_handle, + size_t index, + bool isInput, + bool isEvalModel) { + OrtAllocator* allocator = nullptr; + RETURN_NULLPTR_IF_ERROR(GetAllocatorWithDefaultOptions, &allocator); + + char* name = nullptr; + + if (isEvalModel) { + if (isInput) { + return (CHECK_TRAINING_STATUS(TrainingSessionGetEvalModelInputName, training_handle, index, + allocator, &name) == ORT_OK) + ? name + : nullptr; + } else { + return (CHECK_TRAINING_STATUS(TrainingSessionGetEvalModelOutputName, training_handle, index, + allocator, &name) == ORT_OK) + ? name + : nullptr; + } + } else { + if (isInput) { + return (CHECK_TRAINING_STATUS(TrainingSessionGetTrainingModelInputName, training_handle, index, + allocator, &name) == ORT_OK) + ? name + : nullptr; + } else { + return (CHECK_TRAINING_STATUS(TrainingSessionGetTrainingModelOutputName, training_handle, index, + allocator, &name) == ORT_OK) + ? name + : nullptr; + } + } +} + void EMSCRIPTEN_KEEPALIVE OrtTrainingReleaseSession(ort_training_session_handle_t training_handle) { Ort::GetTrainingApi().ReleaseTrainingSession(training_handle); } diff --git a/onnxruntime/wasm/api.h b/onnxruntime/wasm/api.h index 9a0664697f0ff..2cd1515d191c8 100644 --- a/onnxruntime/wasm/api.h +++ b/onnxruntime/wasm/api.h @@ -432,6 +432,35 @@ int EMSCRIPTEN_KEEPALIVE OrtTrainingCopyParametersFromBuffer(ort_training_sessio size_t parameter_count, bool trainable_only); +/** + * Gets the input count and output count of the training or eval model associated with the given training handle. + * @param traning_handle handle of the traning session + * @param input_count [out] a pointer to a size_t variable to accept input_count + * @param output_count [out] a pointer to a size_t variable to accept output_count + * @param isEvalModel when false, returns input & output count of the training model. When true, returns input & output + * count of the eval model. + * @returns ORT error code. If not zero, call OrtGetLastError() to get a detailed error message. + */ +int EMSCRIPTEN_KEEPALIVE OrtTrainingGetModelInputOutputCount(ort_training_session_handle_t training_handle, + size_t* input_count, + size_t* output_count, + bool isEvalModel); + +/** + * Gets the input or output name at the specified index associated with the training or eval model from the + * given training session. + * @param traning_handle handle of the traning session + * @param index the input or output index + * @param isInput if true, this method retrieves an input name. If false, this method retrieves an output name. + * @param isEvalModel when false, returns input & output names of the training model. When true, returns input & output + * names of the eval model. + * @returns a pointer to a buffer which contains C-style string. Caller must release the C style string after use by + */ +char* EMSCRIPTEN_KEEPALIVE OrtTrainingGetModelInputOutputName(ort_training_session_handle_t training_handle, + size_t index, + bool isInput, + bool isEvalModel); + /** * @brief Release the specified ORT training session. * diff --git a/orttraining/orttraining/core/graph/gradient_builder.cc b/orttraining/orttraining/core/graph/gradient_builder.cc index 133cab71f2b1c..6547f53a3c2ae 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.cc +++ b/orttraining/orttraining/core/graph/gradient_builder.cc @@ -2147,5 +2147,13 @@ IMPLEMENT_GRADIENT_BUILDER(GetScaledSumGradient) { ORT_THROW("ScaledSum gradient builder does not support ", input_count, " inputs"); } +IMPLEMENT_GRADIENT_BUILDER(GetResizeGradient) { + return std::vector{ + NodeDef(OpDef{"ResizeGrad", kMSDomain, 1}, + {GO(0), I(0), I(1), I(2)}, + {GI(0)}, + SrcNodeAttributes())}; +} + } // namespace training } // namespace onnxruntime diff --git a/orttraining/orttraining/core/graph/gradient_builder.h b/orttraining/orttraining/core/graph/gradient_builder.h index a517e8af13fcc..28a316261e2f6 100755 --- a/orttraining/orttraining/core/graph/gradient_builder.h +++ b/orttraining/orttraining/core/graph/gradient_builder.h @@ -90,6 +90,7 @@ DECLARE_GRADIENT_BUILDER(GetGRUGradient) DECLARE_GRADIENT_BUILDER(GetReciprocalGradient) DECLARE_GRADIENT_BUILDER(GetLeakyReluGradient) DECLARE_GRADIENT_BUILDER(GetConvTransposeGradient) +DECLARE_GRADIENT_BUILDER(GetResizeGradient) DECLARE_GRADIENT_BUILDER(GetExternalGradient) diff --git a/orttraining/orttraining/core/graph/gradient_builder_registry.cc b/orttraining/orttraining/core/graph/gradient_builder_registry.cc index 4062b5d097394..4b8c68aef078a 100755 --- a/orttraining/orttraining/core/graph/gradient_builder_registry.cc +++ b/orttraining/orttraining/core/graph/gradient_builder_registry.cc @@ -122,6 +122,7 @@ void GradientBuilderRegistry::RegisterGradientBuilders() { REGISTER_GRADIENT_BUILDER("Reciprocal", GetReciprocalGradient); REGISTER_GRADIENT_BUILDER("LeakyRelu", GetLeakyReluGradient); REGISTER_GRADIENT_BUILDER("ConvTranspose", GetConvTransposeGradient); + REGISTER_GRADIENT_BUILDER("Resize", GetResizeGradient); REGISTER_GRADIENT_BUILDER("ExternalGradient", GetExternalGradient); }; diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index cfc79455c43ed..c90acfdb7bb78 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -5001,6 +5001,26 @@ Return true if all elements are true and false otherwise. "T", {"tensor(float16)", "tensor(float)", "tensor(double)"}, "Constrain input and output types to float tensors."); + + ONNX_CONTRIB_OPERATOR_SCHEMA(ResizeGrad) + .SetDomain(kMSDomain) + .SinceVersion(1) + .Input(0, "dY", "Gradient of output Y.", "T") + .Input(1, "X", "Input tensor to the Resize operator.", "T") + .Input(2, "roi", "The roi input to the Resize operator.", "T", OpSchema::Optional) + .Input(3, "scales", "The scales input to the Resize operator.", "tensor(float)", OpSchema::Optional) + .Output(0, "dX", "Gradient of the input X.", "T") + .AllowUncheckedAttributes() + .TypeConstraint( + "T", + {"tensor(float16)", "tensor(float)", "tensor(double)"}, + "Constrain input and output types to float tensors.") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromInputToOutput(ctx, 1, 0); + if (hasInputShape(ctx, 1)) { + propagateShapeFromInputToOutput(ctx, 1, 0); + } + }); } } // namespace training diff --git a/orttraining/orttraining/python/training/optim/_ds_code_store.py b/orttraining/orttraining/python/training/optim/_ds_code_store.py new file mode 100644 index 0000000000000..dc1e20bc3dcff --- /dev/null +++ b/orttraining/orttraining/python/training/optim/_ds_code_store.py @@ -0,0 +1,81 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# +# Copyright 2020 The Microsoft DeepSpeed Team +# +# !!!IMPORTANT: This file is a copy of the original one in DeepSpeed repo at given version, +# It is used to compare with the source code of current installed DeepSpeed during runtime. +# Please don't modify it or do any code formatting for it. +# 'orttraining/orttraining/python/training/optim/_ds_code_store.py' is removed from lintrunner config by intention. +# -------------------------------------------------------------------------- + +# Wrap code in this to make sure the indentation is correct compared with raw DeepSpeed. + +class Stage1And2_DeepSpeedZeroOptimizer_0_9_2: + + def has_overflow_serial(self, params, is_grad_list=False): + for p in params: + if p.grad is not None and self._has_inf_or_nan(p.grad.data): + return True + + return False + + + def get_grad_norm_direct(self, gradients, params, norm_type=2): + """Clips gradient norm of an iterable of parameters. + + This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and + added functionality to handle model parallel parameters. Note that + the gradients are modified in place. + + Arguments: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + + Returns: + Total norm of the parameters (viewed as a single vector). + """ + norm_type = float(norm_type) + if norm_type == inf: + total_norm = max(g.data.abs().max() for g in gradients) + total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) + dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=self.dp_process_group) + + # Take max across all GPUs. + self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.MAX) + total_norm = total_norm_cuda[0].item() + else: + total_norm = 0.0 + # if dist.get_rank() == 0: + # logger.info(f"Total Norm beginning {total_norm}") + for g, p in zip(gradients, params): + # Pipeline parallelism may replicate parameters. Avoid multi-counting. + if hasattr(p, PIPE_REPLICATED) and p.ds_pipe_replicated: + continue + if is_model_parallel_parameter(p) or (self.model_parallel_rank == 0): + param_norm = g.data.double().norm(2) + total_norm += param_norm.item()**2 + # Sum across all model parallel GPUs. + total_norm_cuda = get_accelerator().FloatTensor([float(total_norm)]) + dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.SUM, group=self.dp_process_group) + + self._model_parallel_all_reduce(tensor=total_norm_cuda, op=dist.ReduceOp.SUM) + + total_norm = total_norm_cuda[0].item()**(1. / norm_type) + + if total_norm == float('inf') or total_norm == -float('inf') or total_norm != total_norm: + total_norm = -1 + + return total_norm + + + def has_overflow_partitioned_grads_serial(self): + for i in range(len(self.bit16_groups)): + for j, grad in enumerate(self.averaged_gradients[i]): + if grad is not None and self._has_inf_or_nan(grad.data, j): + return True + return False diff --git a/orttraining/orttraining/python/training/optim/_ds_modifier.py b/orttraining/orttraining/python/training/optim/_ds_modifier.py index 6b1c98cc02a52..20f4f814e5476 100644 --- a/orttraining/orttraining/python/training/optim/_ds_modifier.py +++ b/orttraining/orttraining/python/training/optim/_ds_modifier.py @@ -10,6 +10,9 @@ # - has_overflow_partitioned_grads_serial : https://github.com/microsoft/DeepSpeed/blob/d8e9ef6f99e27bb95e10bd146d145b3372b4cfda/deepspeed/runtime/zero/stage2.py#L1799 # -------------------------------------------------------------------------- +from __future__ import annotations + +import inspect import types import warnings @@ -17,12 +20,69 @@ from numpy import inf from packaging.version import Version +from ._ds_code_store import Stage1And2_DeepSpeedZeroOptimizer_0_9_2 from ._modifier import FP16OptimizerModifier, check_overflow, check_overflow_for_grads from ._multi_tensor_apply import MultiTensorApply multi_tensor_applier = MultiTensorApply(2048 * 32) +def _get_normalized_str(function) -> str: + return inspect.getsource(function) + + +def _dynamic_checks(cur_ds_version: Version, optimizer) -> bool: + _functions_to_override = ["has_overflow_serial", "get_grad_norm_direct", "has_overflow_partitioned_grads_serial"] + + _version_to_source_code_map = {"0.9.2": Stage1And2_DeepSpeedZeroOptimizer_0_9_2} + + # Try to find the biggest version that is smaller than or equal to cur_ds_version. + # then compare the source code (in case the found version is the latest version supported); + # If current code does not match the found version, return False, and raise a warning to + # add the new version to the list. + versions = [Version(v) for v in _version_to_source_code_map] + sorted_versions = sorted(versions, reverse=True) + version_to_compare = None + for sv in sorted_versions: + if cur_ds_version >= sv: + version_to_compare = sv + break + + if version_to_compare is None: + warnings.warn( + "Unable to find a DeepSpeed version that is smaller than or equal to the current version " + f"{cur_ds_version}. Skip modifying optimizer.", + UserWarning, + ) + return False + + v_optimizer_cls = _version_to_source_code_map[str(version_to_compare)] + all_match = True + for func_name in _functions_to_override: + if not getattr(optimizer, func_name): + warnings.warn( + f"DeepSpeed function {func_name} is not found in optimizer. Skip modifying optimizer.", UserWarning + ) + all_match = False + cur_code_str = _get_normalized_str(getattr(optimizer, func_name)) + v_code_str = _get_normalized_str(getattr(v_optimizer_cls, func_name)) + if cur_code_str != v_code_str: + warnings.warn( + f"DeepSpeed function {func_name} has changed after version {version_to_compare}. " + f"Please append new version {cur_ds_version} in _version_to_source_code_map and _ds_code_store.py.\n" + f"---[{func_name}] Old Source Code Start----\n" + f"{v_code_str}\n" + f"---{func_name} Old Source Code End----\n" + f"---[{func_name}] New Source Code Start----\n" + f"{cur_code_str}\n" + f"---{func_name} New Source Code End----", + UserWarning, + ) + all_match = False + + return all_match + + class DeepSpeedZeROModifier(FP16OptimizerModifier): def __init__(self, optimizer, **kwargs) -> None: super().__init__(optimizer) @@ -30,19 +90,32 @@ def __init__(self, optimizer, **kwargs) -> None: def can_be_modified(self): import deepspeed + # Note 1: # This modifier relies on the implementation of has_overflow_serial, get_grad_norm_direct, # and has_overflow_partitioned_grads_serial # in https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/runtime/zero/stage_1_and_2.py. - # Everytime if we want to update this version supporting list to a newer version, - # we need to check if the implementation of these functions are changed. - # An easy way to check is to check the history of this file, if there is no change during the update, + # The minimum version supported is 0.4.0, all versions in between [0.4.0, 0.9.1] + # are manually checked to make sure the implementation of these functions are "logically" not changed. + # The way we did the check is to check the history of this file, if there is no change during the update, # it's safe to update the version supporting list. Otherwise, or the file is moved or renamed, # we need to check the implementation of these functions in detail. + # + # Note 2: + # Since version 0.9.2, we added dynamic source code check, by comparing installed version of code with + # the source code in our code store. If the source code is changed, we will raise a warning to ask user + # to add the new version to the code store. Otherwise, we will override the functions. + ds_version = Version(deepspeed.__version__) - if ds_version > Version("0.9.1") or ds_version < Version("0.4.0"): + if ds_version < Version("0.4.0"): + warnings.warn( + f"Skip modifying optimizer because of unsupported DeepSpeed version {ds_version}, " + "minimum supported version: 0.4.0, current version", + UserWarning, + ) + return False + if ds_version > Version("0.9.1") and not _dynamic_checks(ds_version, self._optimizer): warnings.warn( - "Skip modifying optimizer because of unsupported DeepSpeed version {}, " - "supported version: 0.4.0 - 0.9.1.".format(deepspeed.__version__), + f"Skip modifying optimizer because of unsupported DeepSpeed version {ds_version}.", UserWarning, ) return False diff --git a/orttraining/orttraining/python/training/optim/_modifier_registry.py b/orttraining/orttraining/python/training/optim/_modifier_registry.py index 4a3a33ecc0513..a88740dac60b7 100644 --- a/orttraining/orttraining/python/training/optim/_modifier_registry.py +++ b/orttraining/orttraining/python/training/optim/_modifier_registry.py @@ -3,13 +3,59 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- +from __future__ import annotations + +import warnings +from typing import ClassVar + from ._apex_amp_modifier import ApexAMPModifier from ._ds_modifier import DeepSpeedZeROModifier from ._megatron_modifier import LegacyMegatronLMModifier +from ._modifier import FP16OptimizerModifier + + +class _AccelerateDeepSpeedZeROModifier(DeepSpeedZeROModifier): + """ + Modifier for wrapper of DeepSpeed Optimizer in accelerator. + https://github.com/huggingface/accelerate/blob/7843286f2e1c50735d259fbc0084a7f1c85e00e3/src/accelerate/utils/deepspeed.py#L182C19-L182C19 + """ + + def __init__(self, accelerator_optimizer, **kwargs) -> None: + super().__init__(accelerator_optimizer.optimizer) + + +def get_full_qualified_type_name(o): + klass = o.__class__ + module = klass.__module__ + if module == "builtins": + return klass.__qualname__ + return module + "." + klass.__qualname__ + + +class OptimizerModifierTypeRegistry: + _MAP: ClassVar[dict[str, FP16OptimizerModifier]] = { + "megatron.fp16.fp16.FP16_Optimizer": LegacyMegatronLMModifier, + "deepspeed.runtime.zero.stage2.FP16_DeepSpeedZeroOptimizer": DeepSpeedZeROModifier, + "deepspeed.runtime.zero.stage_1_and_2.DeepSpeedZeroOptimizer": DeepSpeedZeROModifier, + "apex.amp.optimizer.unique_name_as_id": ApexAMPModifier, + } + + @staticmethod + def create_modifier(optimizer_full_qualified_name: str, optimizer, **kwargs) -> FP16OptimizerModifier | None: + """Create modifier for optimizer.""" + if optimizer_full_qualified_name in OptimizerModifierTypeRegistry._MAP: + return OptimizerModifierTypeRegistry._MAP[optimizer_full_qualified_name](optimizer, **kwargs) + + if optimizer_full_qualified_name == "accelerate.utils.deepspeed.DeepSpeedOptimizerWrapper": + if ( + hasattr(optimizer, "optimizer") + and get_full_qualified_type_name(optimizer.optimizer) in OptimizerModifierTypeRegistry._MAP + ): + return _AccelerateDeepSpeedZeROModifier(optimizer, **kwargs) -OptimizerModifierTypeRegistry = { - "megatron.fp16.fp16.FP16_Optimizer": LegacyMegatronLMModifier, - "deepspeed.runtime.zero.stage2.FP16_DeepSpeedZeroOptimizer": DeepSpeedZeROModifier, - "deepspeed.runtime.zero.stage_1_and_2.DeepSpeedZeroOptimizer": DeepSpeedZeROModifier, - "apex.amp.optimizer.unique_name_as_id": ApexAMPModifier, -} + warnings.warn( + "Skip modifying optimizer because of optimizer name not found in the registry: " + f"{optimizer_full_qualified_name}", + UserWarning, + ) + return None diff --git a/orttraining/orttraining/python/training/optim/fp16_optimizer.py b/orttraining/orttraining/python/training/optim/fp16_optimizer.py index 2a5dfbc2189d3..fc93eadc32112 100644 --- a/orttraining/orttraining/python/training/optim/fp16_optimizer.py +++ b/orttraining/orttraining/python/training/optim/fp16_optimizer.py @@ -3,9 +3,8 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -import warnings -from ._modifier_registry import OptimizerModifierTypeRegistry +from ._modifier_registry import OptimizerModifierTypeRegistry, get_full_qualified_type_name def FP16_Optimizer(optimizer, **kwargs): # noqa: N802 @@ -80,22 +79,13 @@ def FP16_Optimizer(optimizer, **kwargs): # noqa: N802 """ - def get_full_qualified_type_name(o): - if hasattr(optimizer, "_amp_stash"): - return "apex.amp.optimizer.unique_name_as_id" - - klass = o.__class__ - module = klass.__module__ - if module == "builtins": - return klass.__qualname__ - return module + "." + klass.__qualname__ - - optimizer_full_qualified_name = get_full_qualified_type_name(optimizer) - if optimizer_full_qualified_name not in OptimizerModifierTypeRegistry: - warnings.warn("Skip modifying optimizer because of optimizer name not found in registry.", UserWarning) - return optimizer - - modifier = OptimizerModifierTypeRegistry[optimizer_full_qualified_name](optimizer, **kwargs) - modifier.apply() + optimizer_full_qualified_name = ( + "apex.amp.optimizer.unique_name_as_id" + if hasattr(optimizer, "_amp_stash") + else get_full_qualified_type_name(optimizer) + ) + modifier = OptimizerModifierTypeRegistry.create_modifier(optimizer_full_qualified_name, optimizer, **kwargs) + if modifier is not None: + modifier.apply() return optimizer diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py index b9318033a3d53..dd32e2aced561 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_autograd_function_runner.py @@ -245,6 +245,8 @@ def _process_inplace_outputs( if not copied: # Only need a copy once. + # Inplace copy only happens for non-leaf variables, so we have to set requires_grad to False. + raw_input_tensor.requires_grad = False raw_input_tensor.copy_(all_outputs_of_kernel_run[output_index]) _log_warning( f"{log_prefix}Copy output tensor {output_index} to raw input tensor {raw_tensor_input_index}. " @@ -449,7 +451,8 @@ def call_python_forward_function( try: func_name = func_name.decode("utf-8") if isinstance(func_name, bytes) else func_name # If this is the first time run, collect runtime tensor reuse mapping. - if kernel_invoke_id not in _GlobalOpKernelInfoMap: + is_first_time_run = kernel_invoke_id not in _GlobalOpKernelInfoMap + if is_first_time_run: kernel_info = CustomFuncOpKernelInfo(kernel_invoke_id) _GlobalOpKernelInfoMap[kernel_invoke_id] = kernel_info @@ -473,6 +476,11 @@ def call_python_forward_function( if tensor_input_index in inplace_map: raw_input_tensors_used_inplace[tensor_input_index] = wrapped_arg + # Only requires gradient when running under training mode + # and the associated tensor has grad_flag=True (i.e., + # "requires_grad=True" in the original PyTorch script). + wrapped_arg.requires_grad = is_training_mode and grad_flag + # Note1: # If it's first-time kernel invocation, tensor_input_indices_to_save_in_ctx is None, we do the # copy for all tensors. Otherwise, we only copy the tensors whose indices are in @@ -480,29 +488,30 @@ def call_python_forward_function( # Note2: # For inference mode, we don't need to do the copy because ctx will be None, # so nothing will be saved for ctx. - if is_training_mode and ( - tensor_input_indices_to_save_in_ctx is None - or tensor_input_index in tensor_input_indices_to_save_in_ctx - ): - wrapped_arg = wrapped_arg.detach().clone() - - # Only requires gradient when running under training mode - # and the associated tensor has grad_flag=True (i.e., - # "requires_grad=True" in the original PyTorch script). - wrapped_arg.requires_grad = is_training_mode and grad_flag - # Note3: - # If it's not first-time kernel invocation, tensor_input_indices_for_mark_dirty is None, we do the - # mul for all tensors. Otherwise, we only mul by one for the tensors whose indices are in - # tensor_input_indices_for_mark_dirty. - if is_training_mode and ( - tensor_input_indices_for_mark_dirty is None - or tensor_input_index in tensor_input_indices_for_mark_dirty - ): - # To fix this issue: - # "a leaf Variable that requires grad has been used in an in-place operation." - with torch.set_grad_enabled(True): - wrapped_arg = wrapped_arg.clone() + # To fix this issue: + # "a leaf Variable that requires grad has been used in an in-place operation." + # If it's first-time kernel invocation, tensor_input_indices_for_mark_dirty is None, we do the + # copy for all tensors to generate grad for it. Otherwise, we only clone (to generate grad) for + # the tensors whose indices are in tensor_input_indices_for_mark_dirty. + if is_training_mode: + if is_first_time_run: + with torch.set_grad_enabled(True): + wrapped_arg = wrapped_arg.clone() + else: + is_input_index_saved_in_ctx = ( + tensor_input_indices_to_save_in_ctx is None + or tensor_input_index in tensor_input_indices_to_save_in_ctx + ) + is_input_index_marked_dirty = ( + tensor_input_indices_for_mark_dirty is None + or tensor_input_index in tensor_input_indices_for_mark_dirty + ) + if is_input_index_saved_in_ctx or is_input_index_marked_dirty: + # when with grad, the leaf tensor after clone will not be leaf. + with torch.set_grad_enabled(is_input_index_marked_dirty): + wrapped_arg = wrapped_arg.clone() + wrapped_arg.requires_grad = is_training_mode and grad_flag wrapped_args.append(wrapped_arg) input_tensors_used_for_fw_run[tensor_input_index] = wrapped_arg diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py index 156c3e001d88f..77317242727b4 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_gradient_registry.py @@ -271,8 +271,3 @@ def upsample_nearest2d_gradient(): @register_gradient("org.pytorch.aten", "ATen", "upsample_nearest3d", "vec") def upsample_nearest3d_gradient(): return _upsample_gradient("upsample_nearest3d_backward", 3) - - -@register_gradient("org.pytorch.aten", "ATen", "upsample_bilinear2d", "vec") -def upsample_bilinear2d_gradient(): - return _upsample_gradient("upsample_bilinear2d_backward", 2) diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index 64c7abe1c9386..6e694dcdf2e39 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -808,16 +808,3 @@ def upsample_nearest2d(g, input, output_size, scale_factors): @register_symbolic("upsample_nearest3d") def upsample_nearest3d(g, input, output_size, scale_factors): return _upsample_nearest(g, input, output_size, scale_factors, "upsample_nearest3d") - - -@register_symbolic("upsample_bilinear2d") -def upsample_bilinear2d(g, input, output_size, align_corners, scale_factors): - return g.op( - "org.pytorch.aten::ATen", - input, - output_size, - align_corners, - scale_factors, - operator_s="upsample_bilinear2d", - overload_name_s="vec", - ) diff --git a/orttraining/orttraining/test/gradient/gradient_ops_test.cc b/orttraining/orttraining/test/gradient/gradient_ops_test.cc index 597801f4030c1..890a1bbccbc92 100644 --- a/orttraining/orttraining/test/gradient/gradient_ops_test.cc +++ b/orttraining/orttraining/test/gradient/gradient_ops_test.cc @@ -3298,6 +3298,41 @@ TEST(GradientCheckerTest, ConvTransposeGrad) { execution_providers.push_back(DefaultCudaExecutionProvider()); ConvTransposeGradientCheckerTest(&execution_providers); } + +// TODO: Enable test for ROCM +TEST(GradientCheckerTest, ResizeGrad) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + const std::vector attributes = { + MakeAttribute("coordinate_transformation_mode", "half_pixel"), + MakeAttribute("cubic_coeff_a", -0.75f), + MakeAttribute("exclude_outside", static_cast(0)), + MakeAttribute("extrapolation_value", 0.0f), + MakeAttribute("mode", "linear"), + MakeAttribute("nearest_mode", "floor")}; + + float max_error; + GradientChecker gradient_checker; + OpDef op_def{"Resize", kOnnxDomain, 18}; + + TensorInfo x_info({1, 2, 4, 4}, true); + TensorInfo roi_info({4}, false, nullptr, DataTypeImpl::GetTensorType()); + TensorInfo scales_info({4}, false, nullptr, DataTypeImpl::GetTensorType()); + + TensorInfo y_info({1, 2, 8, 8}, true); + + std::vector> x_datas = {{0.2f, 0.4f, 0.6f, 0.8f, 0.2f, 0.4f, 0.6f, 0.8f, + 0.2f, 0.4f, 0.6f, 0.8f, 0.2f, 0.4f, 0.6f, 0.8f, + 0.2f, 0.4f, 0.6f, 0.8f, 0.2f, 0.4f, 0.6f, 0.8f, + 0.2f, 0.4f, 0.6f, 0.8f, 0.2f, 0.4f, 0.6f, 0.8f}, + {1.0f, 1.0f, 1.0f, 1.0f}, + {1.0f, 1.0f, 2.0f, 2.0f}}; + + ASSERT_STATUS_OK(gradient_checker.ComputeGradientError(op_def, {x_info, roi_info, scales_info}, + {y_info}, &max_error, x_datas, attributes, true, false, &execution_providers)); + EXPECT_IS_TINY(max_error); +} + #endif // USE_CUDA } // namespace test diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 643d47b0d043e..c8ec2e52f3078 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -1773,13 +1773,17 @@ def run_step(model, input): _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad) -def test_aten_upsample_bilinear(): +@pytest.mark.parametrize("interpolate_size_scale", ({"size": (8, 12)}, {"scale_factor": 4.7})) +@pytest.mark.parametrize("align_corners", (True, False)) +def test_resize_grad_correctness_bilinear_2d(interpolate_size_scale, align_corners): class _NeuralNetUpsampleBilinear(torch.nn.Module): def __init__(self): super().__init__() def forward(self, input): - return torch.nn.functional.interpolate(input, size=(8, 12), mode="bilinear") + return torch.nn.functional.interpolate( + input, align_corners=align_corners, mode="bilinear", **interpolate_size_scale + ) device = "cuda" pt_model = _NeuralNetUpsampleBilinear().to(device) diff --git a/orttraining/orttraining/test/training_ops/cuda/resize_grad_test.cc b/orttraining/orttraining/test/training_ops/cuda/resize_grad_test.cc new file mode 100644 index 0000000000000..8fc13af8816be --- /dev/null +++ b/orttraining/orttraining/test/training_ops/cuda/resize_grad_test.cc @@ -0,0 +1,227 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "test/providers/compare_provider_test_utils.h" +#include "test/providers/provider_test_utils.h" +#include "test/util/include/default_providers.h" + +namespace onnxruntime::test { + +#if defined(USE_CUDA) || defined(USE_ROCM) + +namespace { + +void AddResizeGradAttributes(OpTester& test, const std::string& coordinate_transformation_mode) { + test.AddAttribute("mode", "linear"); + test.AddAttribute("coordinate_transformation_mode", coordinate_transformation_mode); +} + +} // namespace + +TEST(ResizeGradTest, ResizeGradWithSizes) { + std::vector> providers; +#ifdef USE_CUDA + providers.emplace_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + providers.emplace_back(DefaultRocmExecutionProvider()); +#endif + + OpTester test("ResizeGrad", 1, onnxruntime::kMSDomain); + + AddResizeGradAttributes(test, "half_pixel"); + + std::vector dY(128, 1.0f); + std::vector dY_shape = {1, 2, 8, 8}; + + std::vector X(32, 1.0f); + std::vector X_shape = {1, 2, 4, 4}; + + std::vector dX(32, 4.0f); + std::vector dX_shape = X_shape; + + test.AddInput("dY", dY_shape, dY); + test.AddInput("X", X_shape, X); + + test.AddOutput("dX", dX_shape, dX); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers); +} + +TEST(ResizeGradTest, ResizeGradWithSizesHalf) { + std::vector> providers; +#ifdef USE_CUDA + providers.emplace_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + providers.emplace_back(DefaultRocmExecutionProvider()); +#endif + + OpTester test("ResizeGrad", 1, onnxruntime::kMSDomain); + + AddResizeGradAttributes(test, "half_pixel"); + + std::vector dY(128, 1.0f); + std::vector dY_half(dY.size()); + ConvertFloatToMLFloat16(dY.data(), dY_half.data(), static_cast(dY.size())); + std::vector dY_shape = {1, 2, 8, 8}; + + std::vector X(32, 1.0f); + std::vector X_half(X.size()); + ConvertFloatToMLFloat16(X.data(), X_half.data(), static_cast(X.size())); + std::vector X_shape = {1, 2, 4, 4}; + + std::vector dX(32, 4.0f); + std::vector dX_half(dX.size()); + ConvertFloatToMLFloat16(dX.data(), dX_half.data(), static_cast(dX.size())); + std::vector dX_shape = X_shape; + + test.AddInput("dY", dY_shape, dY_half); + test.AddInput("X", X_shape, X_half); + + test.AddOutput("dX", dX_shape, dX_half); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers); +} + +TEST(ResizeGradTest, ResizeGradWithSizesAndAlignCorners) { + std::vector> providers; +#ifdef USE_CUDA + providers.emplace_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + providers.emplace_back(DefaultRocmExecutionProvider()); +#endif + + OpTester test("ResizeGrad", 1, onnxruntime::kMSDomain); + + AddResizeGradAttributes(test, "align_corners"); + + std::vector dY(128, 1.0f); + std::vector dY_shape = {1, 2, 8, 8}; + + std::vector X(32, 1.0f); + std::vector X_shape = {1, 2, 4, 4}; + + std::vector dX({2.9388f, 3.9184f, 3.9184f, 2.9388f, 3.9184f, 5.2245f, 5.2245f, 3.9184f, + 3.9184f, 5.2245f, 5.2245f, 3.9184f, 2.9388f, 3.9184f, 3.9184f, 2.9388f, + 2.9388f, 3.9184f, 3.9184f, 2.9388f, 3.9184f, 5.2245f, 5.2245f, 3.9184f, + 3.9184f, 5.2245f, 5.2245f, 3.9184f, 2.9388f, 3.9184f, 3.9184f, 2.9388f}); + std::vector dX_shape = X_shape; + + test.AddInput("dY", dY_shape, dY); + test.AddInput("X", X_shape, X); + + test.AddOutput("dX", dX_shape, dX); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers); +} + +TEST(ResizeGradTest, ResizeGradWithScales) { + std::vector> providers; +#ifdef USE_CUDA + providers.emplace_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + providers.emplace_back(DefaultRocmExecutionProvider()); +#endif + + OpTester test("ResizeGrad", 1, onnxruntime::kMSDomain); + + AddResizeGradAttributes(test, "half_pixel"); + + std::vector dY(72, 1.0f); + std::vector dY_shape = {1, 2, 6, 6}; + + std::vector X(32, 1.0f); + std::vector X_shape = {1, 2, 4, 4}; + + std::vector dX({2.7128f, 2.9550f, 2.7612f, 1.4533f, 2.9550f, 3.2189f, 3.0078f, 1.5830f, + 2.7612f, 3.0078f, 2.8106f, 1.4792f, 1.4533f, 1.5830f, 1.4792f, 0.7785f, + 2.7128f, 2.9550f, 2.7612f, 1.4533f, 2.9550f, 3.2189f, 3.0078f, 1.5830f, + 2.7612f, 3.0078f, 2.8106f, 1.4792f, 1.4533f, 1.5830f, 1.4792f, 0.7785f}); + std::vector dX_shape = X_shape; + + test.AddInput("dY", dY_shape, dY); + test.AddInput("X", X_shape, X); + test.AddInput("", {0}, {}); + test.AddInput("scales", {4}, {1.0f, 1.0f, 1.7f, 1.7f}); + + test.AddOutput("dX", dX_shape, dX); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers); +} + +TEST(ResizeGradTest, ResizeGradWithScalesHalf) { + std::vector> providers; +#ifdef USE_CUDA + providers.emplace_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + providers.emplace_back(DefaultRocmExecutionProvider()); +#endif + + OpTester test("ResizeGrad", 1, onnxruntime::kMSDomain); + + AddResizeGradAttributes(test, "half_pixel"); + + std::vector dY(72, 1.0f); + std::vector dY_half(dY.size()); + ConvertFloatToMLFloat16(dY.data(), dY_half.data(), static_cast(dY.size())); + std::vector dY_shape = {1, 2, 6, 6}; + + std::vector X(32, 1.0f); + std::vector X_half(X.size()); + ConvertFloatToMLFloat16(X.data(), X_half.data(), static_cast(X.size())); + std::vector X_shape = {1, 2, 4, 4}; + + std::vector dX({2.7128f, 2.9550f, 2.7612f, 1.4533f, 2.9550f, 3.2189f, 3.0078f, 1.5830f, + 2.7612f, 3.0078f, 2.8106f, 1.4792f, 1.4533f, 1.5830f, 1.4792f, 0.7785f, + 2.7128f, 2.9550f, 2.7612f, 1.4533f, 2.9550f, 3.2189f, 3.0078f, 1.5830f, + 2.7612f, 3.0078f, 2.8106f, 1.4792f, 1.4533f, 1.5830f, 1.4792f, 0.7785f}); + std::vector dX_half(dX.size()); + ConvertFloatToMLFloat16(dX.data(), dX_half.data(), static_cast(dX.size())); + std::vector dX_shape = X_shape; + + test.AddInput("dY", dY_shape, dY_half); + test.AddInput("X", X_shape, X_half); + test.AddInput("", {0}, {}); + test.AddInput("scales", {4}, {1.0f, 1.0f, 1.7f, 1.7f}); + + test.AddOutput("dX", dX_shape, dX_half); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers); +} + +TEST(ResizeGradTest, ResizeGradWithScalesAndAlignCorners) { + std::vector> providers; +#ifdef USE_CUDA + providers.emplace_back(DefaultCudaExecutionProvider()); +#elif USE_ROCM + providers.emplace_back(DefaultRocmExecutionProvider()); +#endif + + OpTester test("ResizeGrad", 1, onnxruntime::kMSDomain); + + AddResizeGradAttributes(test, "align_corners"); + + std::vector dY(72, 1.0f); + std::vector dY_shape = {1, 2, 6, 6}; + + std::vector X(32, 1.0f); + std::vector X_shape = {1, 2, 4, 4}; + + std::vector dX({1.9600f, 2.2400f, 2.2400f, 1.9600f, 2.2400f, 2.5600f, 2.5600f, 2.2400f, + 2.2400f, 2.5600f, 2.5600f, 2.2400f, 1.9600f, 2.2400f, 2.2400f, 1.9600f, + 1.9600f, 2.2400f, 2.2400f, 1.9600f, 2.2400f, 2.5600f, 2.5600f, 2.2400f, + 2.2400f, 2.5600f, 2.5600f, 2.2400f, 1.9600f, 2.2400f, 2.2400f, 1.9600f}); + std::vector dX_shape = X_shape; + + test.AddInput("dY", dY_shape, dY); + test.AddInput("X", X_shape, X); + test.AddInput("", {0}, {}); + test.AddInput("scales", {4}, {1.0f, 1.0f, 1.7f, 1.7f}); + + test.AddOutput("dX", dX_shape, dX); + + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &providers); +} + +#endif // defined(USE_CUDA) || defined(USE_ROCM) + +} // namespace onnxruntime::test diff --git a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc index 8e61dbee506f2..ae4f48b6b49a2 100644 --- a/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc +++ b/orttraining/orttraining/training_ops/cuda/cuda_training_kernels.cc @@ -207,6 +207,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BatchScale); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, PadAndUnflatten); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, ScaledSum); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, ResizeGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, ResizeGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, double, ResizeGrad); // the kernels within the following ifdef are not included in a build with // --enable_training_ops but without --enable_training @@ -453,13 +456,14 @@ Status RegisterCudaTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, - BuildKernelCreateInfo, - BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // the kernels within the following ifdef are not included in a build with // --enable_training_ops but without --enable_training #ifdef ENABLE_TRAINING diff --git a/orttraining/orttraining/training_ops/cuda/tensor/resize_grad.cc b/orttraining/orttraining/training_ops/cuda/tensor/resize_grad.cc new file mode 100644 index 0000000000000..a5e8f7cd35d88 --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/tensor/resize_grad.cc @@ -0,0 +1,81 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include + +#include "orttraining/training_ops/cuda/tensor/resize_grad.h" +#include "orttraining/training_ops/cuda/tensor/resize_grad_impl.h" + +namespace onnxruntime::cuda { + +#define REGISTER_RESIZEGRAD_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + ResizeGrad, \ + kMSDomain, \ + 1, \ + T, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .InputMemoryType(OrtMemTypeCPUInput, 2) /* Keep roi on CPU */ \ + .InputMemoryType(OrtMemTypeCPUInput, 3) /* Keep scales on CPU */ \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + ResizeGrad); + +REGISTER_RESIZEGRAD_KERNEL_TYPED(MLFloat16) +REGISTER_RESIZEGRAD_KERNEL_TYPED(float) +REGISTER_RESIZEGRAD_KERNEL_TYPED(double) + +template +Status ResizeGrad::ComputeInternal(OpKernelContext* context) const { + typedef typename ToCudaType::MappedType CudaT; + + const Tensor* dY = context->Input(0); + const Tensor* X = context->Input(1); + const Tensor* scales = context->Input(3); + + ORT_ENFORCE(X->Shape().NumDimensions() == 4, "Expected input tensor to have 4 dimensions. Actual: ", + X->Shape().NumDimensions()); + + const auto get_scales_from_input = [](const Tensor* scales) { + if (nullptr == scales) { + return std::make_pair(std::optional{}, std::optional{}); + } + + ORT_ENFORCE(scales->Shape().Size() == 4, "There must be a scale for each dimension."); + + const auto* scales_data = scales->Data(); + return std::make_pair(std::optional{scales_data[2]}, std::optional{scales_data[3]}); + }; + + std::pair, std::optional> scale_factors = get_scales_from_input(scales); + + Tensor* dX = context->Output(0, X->Shape()); + + const int64_t batch_size = X->Shape()[0]; + const int64_t num_channels = X->Shape()[1]; + const int64_t output_height = dY->Shape()[2]; + const int64_t output_width = dY->Shape()[3]; + const int64_t input_height = X->Shape()[2]; + const int64_t input_width = X->Shape()[3]; + + if (dX->Shape() == dY->Shape()) { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(dX->MutableDataRaw(), dY->DataRaw(), dY->SizeInBytes(), cudaMemcpyDeviceToDevice)); + return Status::OK(); + } + + CUDA_RETURN_IF_ERROR(cudaMemsetAsync(dX->MutableDataRaw(), 0, dX->SizeInBytes(), Stream(context))); + + const bool align_corners = coordinate_transform_mode_ == ResizeCoordinateTransformationMode::ALIGN_CORNERS; + const CudaT* dy_data = reinterpret_cast(dY->Data()); + CudaT* dx_data = reinterpret_cast(dX->MutableData()); + + ResizeGradImpl(Stream(context), input_height, input_width, output_height, + output_width, batch_size, num_channels, align_corners, + scale_factors.first, scale_factors.second, + dy_data, dx_data); + + return Status::OK(); +} + +} // namespace onnxruntime::cuda diff --git a/orttraining/orttraining/training_ops/cuda/tensor/resize_grad.h b/orttraining/orttraining/training_ops/cuda/tensor/resize_grad.h new file mode 100644 index 0000000000000..53f8d5f0d71f5 --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/tensor/resize_grad.h @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include "core/common/common.h" +#include "core/providers/cuda/cuda_kernel.h" +#include "core/providers/cpu/tensor/upsamplebase.h" + +namespace onnxruntime::cuda { + +template +class ResizeGrad final : public UpsampleBase, public CudaKernel { + public: + ResizeGrad(const OpKernelInfo& info) : UpsampleBase(info), CudaKernel(info) { + ORT_ENFORCE(!antialias_, "Antialiasing is not supported in ResizeGrad yet."); + + ORT_ENFORCE(axes_.empty(), "ReizeGrad does not support the `axes` attribute yet."); + + std::string coordinate_transform_mode = + info.GetAttrOrDefault("coordinate_transformation_mode", "half_pixel"); + coordinate_transform_mode_ = StringToCoordinateTransformationMode(coordinate_transform_mode); + ORT_ENFORCE(coordinate_transform_mode_ == ResizeCoordinateTransformationMode::HALF_PIXEL || + coordinate_transform_mode_ == ResizeCoordinateTransformationMode::ALIGN_CORNERS, + "ReizeGrad only supports the `HALF_PIXEL` and `ALIGN_CORNERS` coordinate_transform_mode ", + coordinate_transform_mode, " is not supported yet."); + + ORT_ENFORCE(keep_aspect_ratio_policy_ == AspectRatioPolicy::STRETCH, + "ReizeGrad only supports the `STRETCH` policy."); + + std::string mode; + ORT_ENFORCE(info.GetAttr("mode", &mode).IsOK()); + ORT_ENFORCE((UpsampleMode::LINEAR == mode_), + "ReizeGrad only supports the `LINEAR` mode. ", mode, " mode is not supported yet."); + } + + Status ComputeInternal(OpKernelContext* context) const override; +}; + +} // namespace onnxruntime::cuda diff --git a/orttraining/orttraining/training_ops/cuda/tensor/resize_grad_impl.cu b/orttraining/orttraining/training_ops/cuda/tensor/resize_grad_impl.cu new file mode 100644 index 0000000000000..0507cda62390b --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/tensor/resize_grad_impl.cu @@ -0,0 +1,151 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Contents of this file are derived from the pytorch cuda implementation of +// the upsample_bilinear2d_backward implementation at: +// https://github.com/pytorch/pytorch/blob/ce50132748f652ed6079c3db8008a6817594dbae/aten/src/ATen/native/cuda/UpSampleBilinear2d.cu + +#include "orttraining/training_ops/cuda/tensor/resize_grad_impl.h" +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/atomic/common.cuh" + +namespace onnxruntime::cuda { + +namespace { + +constexpr int NumThreadsPerBlock = GridDim::maxThreadsPerBlock; + +} // namespace + +__device__ __forceinline__ size_t +idx(const size_t nc, + const size_t height, + const size_t width, + const size_t h, + const size_t w) { + return (nc * height + h) * width + w; +} + +template +__device__ __forceinline__ static T AreaPixelComputeSourceIndex( + T scale, + int dst_index, + bool align_corners, + bool cubic) { + if (align_corners) { + return scale * dst_index; + } else { + T src_idx = scale * (dst_index + static_cast(0.5)) - + static_cast(0.5); + return (!cubic && src_idx < static_cast(0)) + ? static_cast(0) + : src_idx; + } +} + +template +__global__ void UpsampleGrad(const int64_t nc, const int64_t input_height, + const int64_t input_width, const int64_t output_height, + const int64_t output_width, const AccT rheight, + const AccT rwidth, const bool align_corners, + const T* dY_data, T* dX_data) { + const size_t dy_numel = nc * output_width * output_height; + const size_t dx_numel = nc * input_width * input_height; + for (size_t index = blockDim.x * blockIdx.x + threadIdx.x; + index < dy_numel; + index += blockDim.x * gridDim.x) { + size_t index_temp = index; + const int w2 = index_temp % output_width; // 0:width2-1 + index_temp /= output_width; + const int h2 = index_temp % output_height; // 0:height2-1 + const size_t nc = index_temp / output_height; + + const AccT h1r = AreaPixelComputeSourceIndex( + rheight, h2, align_corners, /*cubic=*/false); + const int h1 = h1r; + const int h1p = (h1 < input_height - 1) ? 1 : 0; + const AccT h1lambda = h1r - h1; + const AccT h0lambda = static_cast(1) - h1lambda; + + const AccT w1r = AreaPixelComputeSourceIndex( + rwidth, w2, align_corners, /*cubic=*/false); + const int w1 = w1r; + const int w1p = (w1 < input_width - 1) ? 1 : 0; + const AccT w1lambda = w1r - w1; + const AccT w0lambda = static_cast(1) - w1lambda; + + const T d2val = dY_data[index]; + AtomicAdd( + dX_data, + idx(nc, input_height, input_width, h1, w1), + dx_numel, + static_cast(h0lambda * w0lambda) * d2val); + AtomicAdd( + dX_data, + idx(nc, input_height, input_width, h1, w1 + w1p), + dx_numel, + static_cast(h0lambda * w1lambda) * d2val); + AtomicAdd( + dX_data, + idx(nc, input_height, input_width, h1 + h1p, w1), + dx_numel, + static_cast(h1lambda * w0lambda) * d2val); + AtomicAdd( + dX_data, + idx(nc, input_height, input_width, h1 + h1p, w1 + w1p), + dx_numel, + static_cast(h1lambda * w1lambda) * d2val); + } +} + +template +T AreaPixelComputeScale(int64_t input_size, int64_t output_size, bool align_corners, + const std::optional& scale) { + if (align_corners) { + if (output_size <= 1) { + return T{0}; + } + return static_cast(input_size - 1) / static_cast(output_size - 1); + } else { + if (scale.has_value()) { + return static_cast(T{1.0} / *scale); + } else { + return static_cast(input_size) / static_cast(output_size); + } + } +} + +template +void ResizeGradImpl(cudaStream_t stream, int64_t input_height, + int64_t input_width, int64_t output_height, + int64_t output_width, int64_t batch_size, + int64_t channels, bool align_corners, + const std::optional& scale_height, + const std::optional& scale_width, + const T* dY_data, T* dX_data) { + float rheight = AreaPixelComputeScale(input_height, output_height, align_corners, scale_height); + float rwidth = AreaPixelComputeScale(input_width, output_width, align_corners, scale_width); + + const size_t output_numel = batch_size * channels * output_height * output_width; + int blocks_per_grid = (int)(ceil(static_cast(output_numel) / NumThreadsPerBlock)); + UpsampleGrad<<>>( + batch_size * channels, input_height, input_width, output_height, output_width, + rheight, rwidth, align_corners, dY_data, dX_data); +} + +#define SPECIALIZED_RESIZEGRAD_IMPL(T) \ + template void ResizeGradImpl(cudaStream_t stream, int64_t input_height, \ + int64_t input_width, int64_t output_height, \ + int64_t output_width, int64_t batch_size, \ + int64_t channels, bool align_corners, \ + const std::optional& scale_height, \ + const std::optional& scale_width, \ + const T* dY_data, T* dX_data); + +SPECIALIZED_RESIZEGRAD_IMPL(half) +SPECIALIZED_RESIZEGRAD_IMPL(float) +SPECIALIZED_RESIZEGRAD_IMPL(double) + +#undef SPECIALIZED_RESIZEGRAD_IMPL + +} // namespace onnxruntime::cuda diff --git a/orttraining/orttraining/training_ops/cuda/tensor/resize_grad_impl.h b/orttraining/orttraining/training_ops/cuda/tensor/resize_grad_impl.h new file mode 100644 index 0000000000000..3e917f9071e30 --- /dev/null +++ b/orttraining/orttraining/training_ops/cuda/tensor/resize_grad_impl.h @@ -0,0 +1,20 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include + +namespace onnxruntime::cuda { + +template +void ResizeGradImpl(cudaStream_t stream, int64_t input_height, + int64_t input_width, int64_t output_height, + int64_t output_width, int64_t batch_size, + int64_t channels, bool align_corners, + const std::optional& scale_height, + const std::optional& scale_width, + const T* dY_data, T* dX_data); + +} // namespace onnxruntime::cuda diff --git a/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc b/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc index 2321aa23dd6eb..e0749c2fb4d0d 100644 --- a/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc +++ b/orttraining/orttraining/training_ops/rocm/rocm_training_kernels.cc @@ -187,6 +187,9 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float_BFloat16, ReduceAllL2); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, BFloat16_BFloat16, ReduceAllL2); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, PadAndUnflatten); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, MLFloat16, ResizeGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, float, ResizeGrad); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kMSDomain, 1, double, ResizeGrad); #if defined(ORT_USE_NCCL) || defined(USE_MPI) // P2P communication operators. @@ -387,6 +390,9 @@ Status RegisterRocmTrainingKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo, // P2P communication operators. #if defined(ORT_USE_NCCL) || defined(USE_MPI) diff --git a/setup.py b/setup.py index 2eb8f212d730c..b71836e0ee6e4 100644 --- a/setup.py +++ b/setup.py @@ -192,11 +192,15 @@ def run(self): cuda_dependencies = [ "libcublas.so.11", + "libcublas.so.12", "libcublasLt.so.11", - "libcudnn.so.8", + "libcublasLt.so.12", "libcudart.so.11.0", - "libcurand.so.10", + "libcudart.so.12.0", + "libcudnn.so.8", "libcufft.so.10", + "libcufft.so.11", + "libcurand.so.10", ] rocm_dependencies = [ "librccl.so.1", diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index fdd8c09333737..3696c41c196de 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -193,7 +193,7 @@ stages: DoCompliance: ${{ parameters.DoCompliance }} DoEsrp: ${{ parameters.DoEsrp }} stage_name_suffix: gpu - EnvSetupScript: setup_env_cuda_11.bat + EnvSetupScript: setup_env_cuda.bat buildArch: x64 msbuildPlatform: x64 packageName: x64-cuda @@ -376,7 +376,7 @@ stages: - task: BatchScript@1 displayName: 'setup env' inputs: - filename: '$(Build.SourcesDirectory)\tools\ci_build\github\windows\setup_env_cuda_11.bat' + filename: '$(Build.SourcesDirectory)\tools\ci_build\github\windows\setup_env_cuda.bat' modifyEnvironment: true workingFolder: '$(Build.BinariesDirectory)' @@ -488,13 +488,13 @@ stages: Steps: - script: | tools/ci_build/get_docker_image.py \ - --dockerfile tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_8_tensorrt8_6 \ + --dockerfile tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda \ --context tools/ci_build/github/linux/docker \ - --docker-build-args "--network=host --build-arg POLICY=manylinux_2_28 --build-arg PLATFORM=x86_64 --build-arg PREPEND_PATH=/usr/local/cuda/bin --build-arg LD_LIBRARY_PATH_ARG=/usr/local/lib64 --build-arg DEVTOOLSET_ROOTPATH=/usr --build-arg BUILD_UID=$( id -u ) --build-arg BUILD_UID=$( id -u )" \ + --docker-build-args "--network=host --build-arg BASEIMAGE=nvidia/cuda:11.8.0-devel-ubi8 --build-arg INSTALL_CUDNN=true --build-arg TRT_VERSION=8.6.1.6-1.cuda11.8 --build-arg BUILD_UID=$( id -u )" \ --container-registry onnxruntimebuildcache \ --multiple_repos \ --repository onnxruntimecuda118xtrt86build - displayName: "Get onnxruntimecuda118xtrt86build image for tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_8_tensorrt8_6" + displayName: "Get onnxruntimecuda118xtrt86build image for tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda" workingDirectory: $(Build.SourcesDirectory)/onnxruntime ContainerRegistry: onnxruntimebuildcache diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml index 71a580f348f6f..1d4681d064387 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml @@ -58,9 +58,15 @@ jobs: - template: templates/get-docker-image-steps.yml parameters: - Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11 + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda Context: tools/ci_build/github/linux/docker - DockerBuildArgs: "--network=host --build-arg POLICY=manylinux_2_28 --build-arg PLATFORM=x86_64 --build-arg PREPEND_PATH=/usr/local/cuda/bin --build-arg LD_LIBRARY_PATH_ARG=/usr/local/lib64 --build-arg DEVTOOLSET_ROOTPATH=/usr --build-arg BUILD_UID=$( id -u )" + DockerBuildArgs: " + --network=host + --build-arg BASEIMAGE=nvidia/cuda:11.8.0-devel-ubi8 + --build-arg TRT_VERSION=8.6.1.6-1.cuda11.8 + --build-arg INSTALL_CUDNN=true + --build-arg BUILD_UID=$( id -u ) + " Repository: onnxruntimecuda11build - task: Cache@2 @@ -154,9 +160,15 @@ jobs: - template: templates/get-docker-image-steps.yml parameters: - Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11 + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda Context: tools/ci_build/github/linux/docker - DockerBuildArgs: "--network=host --build-arg POLICY=manylinux_2_28 --build-arg PLATFORM=x86_64 --build-arg PREPEND_PATH=/usr/local/cuda/bin --build-arg LD_LIBRARY_PATH_ARG=/usr/local/lib64 --build-arg DEVTOOLSET_ROOTPATH=/usr --build-arg BUILD_UID=$( id -u )" + DockerBuildArgs: " + --network=host + --build-arg BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 + --build-arg TRT_VERSION=8.6.1.6-1.cuda11.8 + --build-arg INSTALL_CUDNN=true + --build-arg BUILD_UID=$( id -u ) + " Repository: onnxruntimecuda11build - task: CmdLine@2 diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml index 9450395f3cf79..16d4457c45eb6 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-tensorrt-ci-pipeline.yml @@ -57,9 +57,15 @@ jobs: - template: templates/get-docker-image-steps.yml parameters: - Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_8_tensorrt8_6 + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda Context: tools/ci_build/github/linux/docker - DockerBuildArgs: "--network=host --build-arg POLICY=manylinux_2_28 --build-arg PLATFORM=x86_64 --build-arg PREPEND_PATH=/usr/local/cuda/bin --build-arg LD_LIBRARY_PATH_ARG=/usr/local/lib64 --build-arg DEVTOOLSET_ROOTPATH=/usr --build-arg BUILD_UID=$( id -u )" + DockerBuildArgs: " + --network=host + --build-arg BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 + --build-arg TRT_VERSION=8.6.1.6-1.cuda11.8 + --build-arg INSTALL_CUDNN=true + --build-arg BUILD_UID=$( id -u ) + " Repository: onnxruntimetensorrt86gpubuild - template: templates/linux-build-step-with-cache.yml diff --git a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml index d24b0e0539631..2a94499c7a268 100644 --- a/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml +++ b/tools/ci_build/github/azure-pipelines/post-merge-jobs.yml @@ -67,7 +67,7 @@ stages: - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env_cuda_11.bat + EnvSetupScript: setup_env_cuda.bat buildArch: x64 additionalBuildFlags: --enable_pybind --build_java --build_nodejs --use_cuda --cuda_home="$(Agent.TempDirectory)\v11.8" --enable_cuda_profiling --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 msbuildPlatform: x64 diff --git a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml index 2161a9205f22d..c8aac6e8b130d 100644 --- a/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-package-test-pipeline.yml @@ -30,7 +30,7 @@ stages: - template: templates/py-packaging-linux-test-cpu.yml parameters: arch: 'aarch64' - machine_pool: 'aiinfra-linux-ARM64-CPU-2019' + machine_pool: 'onnxruntime-linux-ARM64-CPU-2019' base_image: 'arm64v8/almalinux:8' devtoolset_rootpath: /opt/rh/gcc-toolset-12/root ld_library_path_arg: /opt/rh/gcc-toolset-12/root/usr/lib64:/opt/rh/gcc-toolset-12/root/usr/lib:/opt/rh/gcc-toolset-12/root/usr/lib64/dyninst:/opt/rh/gcc-toolset-12/root/usr/lib/dyninst:/usr/local/lib64 diff --git a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml index cf73691a5eecc..9ca4a45ffcec4 100644 --- a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml @@ -11,7 +11,7 @@ steps: packageType: upack feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0' - version: 1.0.90 + version: 1.0.95 downloadPath: $(Build.BinariesDirectory)/deps # The private ADO project @@ -22,7 +22,7 @@ steps: packageType: upack feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325' definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a' - version: 1.0.90 + version: 1.0.95 downloadPath: $(Build.BinariesDirectory)/deps # You can add more ADO accounts at here. diff --git a/tools/ci_build/github/azure-pipelines/templates/jobs/set-winenv.yml b/tools/ci_build/github/azure-pipelines/templates/jobs/set-winenv.yml index ca5a52fa61ed3..0c8fb91a24a31 100644 --- a/tools/ci_build/github/azure-pipelines/templates/jobs/set-winenv.yml +++ b/tools/ci_build/github/azure-pipelines/templates/jobs/set-winenv.yml @@ -1,19 +1,27 @@ parameters: -- name: EnvSetupScript - type: string - -- name: DownloadCUDA - type: boolean - default: false + - name: EnvSetupScript + type: string + - name: DownloadCUDA + type: boolean + default: false + - name: PrimaryCUDAVersion + type: string + default: '11.8' + - name: SecondaryCUDAVersion + type: string + default: '12.2' steps: -- ${{ if eq(parameters.DownloadCUDA, 'true') }}: - - powershell: | - azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/cuda_sdk/v11.8" $(Agent.TempDirectory) - -- task: BatchScript@1 - displayName: 'setup env' - inputs: - filename: '$(Build.SourcesDirectory)\tools\ci_build\github\windows\${{ parameters.EnvSetupScript }}' - modifyEnvironment: true - workingFolder: '$(Build.BinariesDirectory)' + - ${{ if eq(parameters.DownloadCUDA, 'true') }}: + - powershell: | + azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/cuda_sdk/v${{ parameters.PrimaryCUDAVersion }}" $(Agent.TempDirectory) + displayName: 'Download Primary CUDA SDK v${{ parameters.PrimaryCUDAVersion }}' + - powershell: | + azcopy.exe cp --recursive "https://lotusscus.blob.core.windows.net/models/cuda_sdk/v${{ parameters.SecondaryCUDAVersion }}" $(Agent.TempDirectory) + displayName: 'Download Secondary CUDA SDK v${{ parameters.SecondaryCUDAVersion }}' + - task: BatchScript@1 + displayName: 'setup env' + inputs: + filename: '$(Build.SourcesDirectory)\tools\ci_build\github\windows\${{ parameters.EnvSetupScript }}' + modifyEnvironment: true + workingFolder: '$(Build.BinariesDirectory)' diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml index 51d3a9ebc2187..1cc5c48c5513c 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-cpu-packaging-pipeline.yml @@ -47,7 +47,7 @@ stages: OnnxruntimeCFlags: '-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -O3 -Wl,--strip-all' OnnxruntimeCXXFlags: '-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -O3 -Wl,--strip-all' OnnxruntimeNodejsBindingArch: 'arm64' - PoolName: 'aiinfra-linux-ARM64-CPU-2019' + PoolName: 'onnxruntime-linux-ARM64-CPU-2019' ArtifactNamePrefix: ${{ parameters.ArtifactNamePrefix }} PackageJava: ${{ parameters.PackageJava }} PackageNodeJS: ${{ parameters.PackageNodeJS }} diff --git a/tools/ci_build/github/azure-pipelines/templates/linux-gpu-tensorrt-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/templates/linux-gpu-tensorrt-packaging-pipeline.yml index 445f739e81c45..0d58f6cee4003 100644 --- a/tools/ci_build/github/azure-pipelines/templates/linux-gpu-tensorrt-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/templates/linux-gpu-tensorrt-packaging-pipeline.yml @@ -44,9 +44,15 @@ stages: submodules: recursive - template: get-docker-image-steps.yml parameters: - Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_8_tensorrt8_6 + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda Context: tools/ci_build/github/linux/docker - DockerBuildArgs: "--build-arg POLICY=manylinux_2_28 --build-arg PLATFORM=x86_64 --build-arg PREPEND_PATH=/usr/local/cuda/bin --build-arg LD_LIBRARY_PATH_ARG=/usr/local/lib64 --build-arg DEVTOOLSET_ROOTPATH=/usr --build-arg BUILD_UID=$( id -u )" + DockerBuildArgs: " + --network=host + --build-arg BASEIMAGE=nvidia/cuda:11.8.0-devel-ubi8 + --build-arg TRT_VERSION=8.6.1.6-1.cuda11.8 + --build-arg INSTALL_CUDNN=true + --build-arg BUILD_UID=$( id -u ) + " Repository: onnxruntimecuda118xtrt86build - template: set-version-number-variables-step.yml diff --git a/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml b/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml index 3d5a71284fa6f..33c82b5e8965a 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-linux-gpu.yml @@ -36,9 +36,16 @@ jobs: - template: get-docker-image-steps.yml parameters: - Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_8_tensorrt8_6 + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda Context: tools/ci_build/github/linux/docker - DockerBuildArgs: "--network=host --build-arg POLICY=manylinux_2_28 --build-arg PLATFORM=x86_64 --build-arg PREPEND_PATH=/usr/local/cuda/bin --build-arg LD_LIBRARY_PATH_ARG=/usr/local/lib64 --build-arg DEVTOOLSET_ROOTPATH=/usr --build-arg BUILD_UID=$( id -u ) --build-arg PLATFORM=${{ parameters.arch }}" + DockerBuildArgs: " + --network=host + --build-arg BASEIMAGE=nvidia/cuda:11.8.0-devel-ubi8 + --build-arg TRT_VERSION=8.6.1.6-1.cuda11.8 + --build-arg INSTALL_CUDNN=true + --build-arg BUILD_UID=$( id -u ) + --build-arg PLATFORM=${{ parameters.arch }} + " Repository: onnxruntimecuda118xtrt86build${{ parameters.arch }} diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml index 43ed0172825bc..a70e0c01e52f1 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-linux-test-cuda.yml @@ -81,9 +81,16 @@ jobs: - template: get-docker-image-steps.yml parameters: - Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_8_tensorrt8_6 + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda Context: tools/ci_build/github/linux/docker - DockerBuildArgs: "--network=host --build-arg POLICY=manylinux_2_28 --build-arg PLATFORM=x86_64 --build-arg PREPEND_PATH=/usr/local/cuda/bin --build-arg LD_LIBRARY_PATH_ARG=/usr/local/lib64 --build-arg DEVTOOLSET_ROOTPATH=/usr --build-arg BUILD_UID=$( id -u ) --build-arg PLATFORM=${{ parameters.arch }}" + DockerBuildArgs: " + --network=host + --build-arg BASEIMAGE=nvidia/cuda:11.8.0-devel-ubi8 + --build-arg TRT_VERSION=8.6.1.6-1.cuda11.8 + --build-arg INSTALL_CUDNN=true + --build-arg BUILD_UID=$( id -u ) + --build-arg PLATFORM=${{ parameters.arch }} + " Repository: onnxruntimecuda118xtrt86build${{ parameters.arch }} - task: Bash@3 diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml index 1e28ad08a5bdc..1a67ace5e85fa 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml @@ -484,7 +484,7 @@ stages: - template: py-linux.yml parameters: arch: 'aarch64' - machine_pool: 'aiinfra-linux-ARM64-CPU-2019' + machine_pool: 'onnxruntime-linux-ARM64-CPU-2019' base_image: 'arm64v8/almalinux:8' devtoolset_rootpath: /opt/rh/gcc-toolset-12/root ld_library_path_arg: /opt/rh/gcc-toolset-12/root/usr/lib64:/opt/rh/gcc-toolset-12/root/usr/lib:/opt/rh/gcc-toolset-12/root/usr/lib64/dyninst:/opt/rh/gcc-toolset-12/root/usr/lib/dyninst:/usr/local/lib64 diff --git a/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml b/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml index 84c910ba58787..a5925d16564fe 100644 --- a/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml +++ b/tools/ci_build/github/azure-pipelines/templates/win-wasm-ci.yml @@ -20,7 +20,7 @@ parameters: default: false - name: TimeoutInMinutes - default: 180 + default: 240 - name: BuildJsep type: boolean diff --git a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml index 2a5622faf2905..ed010b5619db5 100644 --- a/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-ci-pipeline.yml @@ -132,7 +132,7 @@ stages: isTraining: false ORT_EP_NAME: CPU GenerateDocumentation: false - WITH_CACHE: true + WITH_CACHE: false MachinePool: 'onnxruntime-Win-CPU-2022' - stage: x86_release diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml index 07b5388ea5cd2..ae2a4b4cead3d 100644 --- a/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-gpu-ci-pipeline.yml @@ -40,7 +40,7 @@ stages: - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env_cuda_11.bat + EnvSetupScript: setup_env_cuda.bat buildArch: x64 additionalBuildFlags: --enable_pybind --build_java --build_nodejs --use_cuda --cuda_home="$(Agent.TempDirectory)\v11.8" --enable_cuda_profiling --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 msbuildPlatform: x64 @@ -57,7 +57,7 @@ stages: - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env_cuda_11.bat + EnvSetupScript: setup_env_cuda.bat buildArch: x64 additionalBuildFlags: --enable_pybind --enable_training --use_cuda --cuda_home="$(Agent.TempDirectory)\v11.8" --skip_onnx_tests --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=75 msbuildPlatform: x64 @@ -76,7 +76,7 @@ stages: - template: templates/jobs/win-ci-vs-2022-job.yml parameters: BuildConfig: 'RelWithDebInfo' - EnvSetupScript: setup_env_cuda_11.bat + EnvSetupScript: setup_env_cuda.bat buildArch: x64 # note: need to specify `--gen_doc` when creating the build config so it has to be in additionalBuildFlags additionalBuildFlags: --gen_doc validate --skip_tests --enable_pybind --use_dml --use_cuda --cuda_home="$(Agent.TempDirectory)\v11.8" --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86 --cmake_extra_defines onnxruntime_BUILD_UNIT_TESTS=OFF diff --git a/tools/ci_build/github/azure-pipelines/win-gpu-reduce-op-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/win-gpu-reduce-op-ci-pipeline.yml index b5db8a5201405..d0f9772da7adc 100644 --- a/tools/ci_build/github/azure-pipelines/win-gpu-reduce-op-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/win-gpu-reduce-op-ci-pipeline.yml @@ -10,7 +10,7 @@ jobs: BuildConfig: 'MinSizeRel' variables: MsbuildArguments: '-detailedsummary -maxcpucount -consoleloggerparameters:PerformanceSummary' - EnvSetupScript: setup_env_cuda_11.bat + EnvSetupScript: setup_env_cuda.bat buildArch: x64 TODAY: $[format('{0:dd}{0:MM}{0:yyyy}', pipeline.startTime)] timeoutInMinutes: 120 diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda index 4d9c676674a09..7b2cada736488 100644 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda +++ b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda @@ -5,11 +5,10 @@ ARG POLICY=manylinux_2_28 ARG PLATFORM=x86_64 ARG BASEIMAGE=nvidia/cuda:12.2.0-devel-ubi8 -ARG TRT_VERSION=8.6.1.6-1.cuda12.0 ARG DEVTOOLSET_ROOTPATH=/usr ARG LD_LIBRARY_PATH_ARG=/usr/local/lib64 ARG PREPEND_PATH=/usr/local/cuda/binet - +ARG INSTALL_CUDNN=false #Build manylinux docker image begin FROM $BASEIMAGE AS runtime_base @@ -118,7 +117,7 @@ RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.10.5 FROM build_cpython AS build_cpython311 COPY build_scripts/cpython-pubkey-310-311.txt /build_scripts/cpython-pubkeys.txt -RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.11.0b5 +RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.11.2 FROM build_cpython AS all_python COPY build_scripts/install-pypy.sh \ @@ -155,23 +154,35 @@ CMD ["/bin/bash"] #Build manylinux docker image end -#Install TensorRT 8.6.1.6 -RUN CUDA_VERSION=$(nvcc --version | sed -n 's/^.*release \([0-9]\+\.[0-9]\+\).*$/\1/p') &&\ - dnf -y install\ - libcudnn8-devel-*cuda${CUDA_VERSION}*\ - libcudnn8-*cuda${CUDA_VERSION}*\ - libnvinfer8-${TRT_VERSION}\ - libnvparsers8-${TRT_VERSION}\ - libnvonnxparsers8-${TRT_VERSION}\ - libnvinfer-plugin8-${TRT_VERSION}\ - libnvinfer-vc-plugin8-${TRT_VERSION}\ - libnvinfer-devel-${TRT_VERSION}\ - libnvparsers-devel-${TRT_VERSION}\ - libnvonnxparsers-devel-${TRT_VERSION}\ - libnvinfer-plugin-devel-${TRT_VERSION}\ - libnvinfer-vc-plugin-devel-${TRT_VERSION}\ - libnvinfer-headers-devel-${TRT_VERSION}\ - libnvinfer-headers-plugin-devel-${TRT_VERSION} + +#Install optinal Cudnn +RUN if [ "$INSTALL_CUDNN" = true ]; then \ + CUDA_VERSION=$(nvcc --version | sed -n 's/^.*release \([0-9]\+\.[0-9]\+\).*$/\1/p') && \ + dnf -y install \ + libcudnn8-devel-*cuda${CUDA_VERSION}* \ + libcudnn8-*cuda${CUDA_VERSION}* ; \ +fi + +#Install TensorRT only if TRT_VERSION is not empty +RUN if [ -n "$TRT_VERSION" ]; then \ + echo "TRT_VERSION is $TRT_VERSION" && \ + dnf -y install \ + libnvinfer8-${TRT_VERSION} \ + libnvparsers8-${TRT_VERSION} \ + libnvonnxparsers8-${TRT_VERSION} \ + libnvinfer-plugin8-${TRT_VERSION} \ + libnvinfer-vc-plugin8-${TRT_VERSION} \ + libnvinfer-devel-${TRT_VERSION} \ + libnvparsers-devel-${TRT_VERSION} \ + libnvonnxparsers-devel-${TRT_VERSION} \ + libnvinfer-plugin-devel-${TRT_VERSION} \ + libnvinfer-vc-plugin-devel-${TRT_VERSION} \ + libnvinfer-headers-devel-${TRT_VERSION} \ + libnvinfer-headers-plugin-devel-${TRT_VERSION}; \ +else \ + echo "TRT_VERSION is none skipping Tensor RT Installation" ; \ +fi + ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 #Add our own dependencies ADD scripts /tmp/scripts diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11 deleted file mode 100644 index 933b0211b0e6c..0000000000000 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11 +++ /dev/null @@ -1,166 +0,0 @@ -ARG BASEIMAGE=nvidia/cuda:11.8.0-cudnn8-devel-ubi8 -ARG POLICY=manylinux_2_28 -ARG PLATFORM=x86_64 -ARG DEVTOOLSET_ROOTPATH= -ARG LD_LIBRARY_PATH_ARG= -ARG PREPEND_PATH= - -#We need both CUDA and manylinux. But the CUDA Toolkit End User License Agreement says NVIDIA CUDA Driver Libraries(libcuda.so, libnvidia-ptxjitcompiler.so) are only distributable in applications that meet this criteria: -#1. The application was developed starting from a NVIDIA CUDA container obtained from Docker Hub or the NVIDIA GPU Cloud, and -#2. The resulting application is packaged as a Docker container and distributed to users on Docker Hub or the NVIDIA GPU Cloud only. -#So we use CUDA as the base image then add manylinux on top of it. - -#Build manylinux2014 docker image begin -FROM $BASEIMAGE AS runtime_base -ARG POLICY -ARG PLATFORM -ARG DEVTOOLSET_ROOTPATH -ARG LD_LIBRARY_PATH_ARG -ARG PREPEND_PATH -LABEL maintainer="The ManyLinux project" - -ENV AUDITWHEEL_POLICY=${POLICY} AUDITWHEEL_ARCH=${PLATFORM} AUDITWHEEL_PLAT=${POLICY}_${PLATFORM} -ENV LC_ALL=en_US.UTF-8 LANG=en_US.UTF-8 LANGUAGE=en_US.UTF-8 -ENV DEVTOOLSET_ROOTPATH=${DEVTOOLSET_ROOTPATH} -ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH_ARG} -ENV PATH=${PREPEND_PATH}${PATH} -ENV PKG_CONFIG_PATH=/usr/local/lib/pkgconfig - -# first copy the fixup mirrors script, keep the script around -COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors - -# setup entrypoint, this will wrap commands with `linux32` with i686 images -COPY build_scripts/install-entrypoint.sh \ - build_scripts/build_utils.sh \ - /build_scripts/ - -RUN /build_scripts/install-entrypoint.sh && rm -rf /build_scripts -COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint -ENTRYPOINT ["manylinux-entrypoint"] - -COPY build_scripts/install-runtime-packages.sh \ - build_scripts/build_utils.sh \ - /build_scripts/ -RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ - -COPY build_scripts/build_utils.sh /build_scripts/ - -COPY build_scripts/install-autoconf.sh /build_scripts/ -RUN export AUTOCONF_ROOT=autoconf-2.71 && \ - export AUTOCONF_HASH=431075ad0bf529ef13cb41e9042c542381103e80015686222b8a9d4abef42a1c && \ - export AUTOCONF_DOWNLOAD_URL=http://ftp.gnu.org/gnu/autoconf && \ - manylinux-entrypoint /build_scripts/install-autoconf.sh - -COPY build_scripts/install-automake.sh /build_scripts/ -RUN export AUTOMAKE_ROOT=automake-1.16.5 && \ - export AUTOMAKE_HASH=07bd24ad08a64bc17250ce09ec56e921d6343903943e99ccf63bbf0705e34605 && \ - export AUTOMAKE_DOWNLOAD_URL=http://ftp.gnu.org/gnu/automake && \ - manylinux-entrypoint /build_scripts/install-automake.sh - -COPY build_scripts/install-libtool.sh /build_scripts/ -RUN export LIBTOOL_ROOT=libtool-2.4.7 && \ - export LIBTOOL_HASH=04e96c2404ea70c590c546eba4202a4e12722c640016c12b9b2f1ce3d481e9a8 && \ - export LIBTOOL_DOWNLOAD_URL=http://ftp.gnu.org/gnu/libtool && \ - manylinux-entrypoint /build_scripts/install-libtool.sh - -COPY build_scripts/install-libxcrypt.sh /build_scripts/ -RUN export LIBXCRYPT_VERSION=4.4.28 && \ - export LIBXCRYPT_HASH=db7e37901969cb1d1e8020cb73a991ef81e48e31ea5b76a101862c806426b457 && \ - export LIBXCRYPT_DOWNLOAD_URL=https://github.com/besser82/libxcrypt/archive && \ - export PERL_ROOT=perl-5.34.0 && \ - export PERL_HASH=551efc818b968b05216024fb0b727ef2ad4c100f8cb6b43fab615fa78ae5be9a && \ - export PERL_DOWNLOAD_URL=https://www.cpan.org/src/5.0 && \ - manylinux-entrypoint /build_scripts/install-libxcrypt.sh - -FROM runtime_base AS build_base -COPY build_scripts/install-build-packages.sh /build_scripts/ -RUN manylinux-entrypoint /build_scripts/install-build-packages.sh - - -FROM build_base AS build_git -COPY build_scripts/build-git.sh /build_scripts/ -RUN export GIT_ROOT=git-2.36.2 && \ - export GIT_HASH=6dc2cdea5fb23d823ba4871cc23222c1db31dfbb6d6c6ff74c4128700df57c68 && \ - export GIT_DOWNLOAD_URL=https://www.kernel.org/pub/software/scm/git && \ - manylinux-entrypoint /build_scripts/build-git.sh - - -FROM build_base AS build_cpython -COPY build_scripts/build-sqlite3.sh /build_scripts/ -RUN export SQLITE_AUTOCONF_ROOT=sqlite-autoconf-3390200 && \ - export SQLITE_AUTOCONF_HASH=852be8a6183a17ba47cee0bbff7400b7aa5affd283bf3beefc34fcd088a239de && \ - export SQLITE_AUTOCONF_DOWNLOAD_URL=https://www.sqlite.org/2022 && \ - manylinux-entrypoint /build_scripts/build-sqlite3.sh - -COPY build_scripts/build-openssl.sh /build_scripts/ -RUN export OPENSSL_ROOT=openssl-1.1.1q && \ - export OPENSSL_HASH=d7939ce614029cdff0b6c20f0e2e5703158a489a72b2507b8bd51bf8c8fd10ca && \ - export OPENSSL_DOWNLOAD_URL=https://www.openssl.org/source && \ - manylinux-entrypoint /build_scripts/build-openssl.sh - -COPY build_scripts/build-cpython.sh /build_scripts/ - - -FROM build_cpython AS build_cpython38 -COPY build_scripts/ambv-pubkey.txt /build_scripts/cpython-pubkeys.txt -RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.8.13 - - -FROM build_cpython AS build_cpython39 -COPY build_scripts/ambv-pubkey.txt /build_scripts/cpython-pubkeys.txt -RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.9.13 - - -FROM build_cpython AS build_cpython310 -COPY build_scripts/cpython-pubkey-310-311.txt /build_scripts/cpython-pubkeys.txt -RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.10.5 - -FROM build_cpython AS build_cpython311 -COPY build_scripts/cpython-pubkey-310-311.txt /build_scripts/cpython-pubkeys.txt -RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.11.2 - -FROM build_cpython AS all_python -COPY build_scripts/install-pypy.sh \ - build_scripts/pypy.sha256 \ - build_scripts/finalize-python.sh \ - /build_scripts/ -RUN manylinux-entrypoint /build_scripts/install-pypy.sh 3.8 7.3.9 -RUN manylinux-entrypoint /build_scripts/install-pypy.sh 3.9 7.3.9 -COPY --from=build_cpython38 /opt/_internal /opt/_internal/ -COPY --from=build_cpython39 /opt/_internal /opt/_internal/ -COPY --from=build_cpython310 /opt/_internal /opt/_internal/ -COPY --from=build_cpython311 /opt/_internal /opt/_internal/ -RUN manylinux-entrypoint /build_scripts/finalize-python.sh - - -FROM runtime_base -COPY --from=build_git /manylinux-rootfs / -COPY --from=build_cpython /manylinux-rootfs / -COPY --from=all_python /opt/_internal /opt/_internal/ -COPY build_scripts/finalize.sh \ - build_scripts/python-tag-abi-tag.py \ - build_scripts/requirements3.8.txt \ - build_scripts/requirements3.9.txt \ - build_scripts/requirements3.10.txt \ - build_scripts/requirements3.11.txt \ - build_scripts/requirements-base-tools.txt \ - /build_scripts/ -COPY build_scripts/requirements-tools/* /build_scripts/requirements-tools/ -RUN manylinux-entrypoint /build_scripts/finalize.sh && rm -rf /build_scripts - -ENV SSL_CERT_FILE=/opt/_internal/certs.pem - -CMD ["/bin/bash"] - -#Build manylinux2014 docker image end -ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 -#Add our own dependencies -ADD scripts /tmp/scripts -RUN cd /tmp/scripts && /tmp/scripts/manylinux/install_centos.sh && /tmp/scripts/manylinux/install_deps.sh && rm -rf /tmp/scripts - -ARG BUILD_UID=1001 -ARG BUILD_USER=onnxruntimedev -RUN adduser --uid $BUILD_UID $BUILD_USER -WORKDIR /home/$BUILD_USER -USER $BUILD_USER -ENV PATH /usr/local/dotnet:$PATH diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_6_tensorrt8_4 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_6_tensorrt8_4 deleted file mode 100644 index 003bb2324c049..0000000000000 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_6_tensorrt8_4 +++ /dev/null @@ -1,173 +0,0 @@ -ARG BASEIMAGE=nvidia/cuda:11.6.1-cudnn8-devel-centos7 -ARG POLICY=manylinux2014 -ARG PLATFORM=x86_64 -ARG DEVTOOLSET_ROOTPATH= -ARG LD_LIBRARY_PATH_ARG= -ARG PREPEND_PATH= - -#We need CUDA, TensorRT and manylinux. But the CUDA Toolkit End User License Agreement says NVIDIA CUDA Driver Libraries(libcuda.so, libnvidia-ptxjitcompiler.so) are only distributable in applications that meet this criteria: -#1. The application was developed starting from a NVIDIA CUDA container obtained from Docker Hub or the NVIDIA GPU Cloud, and -#2. The resulting application is packaged as a Docker container and distributed to users on Docker Hub or the NVIDIA GPU Cloud only. -#So we use CUDA as the base image then add manylinux and TensorRT on top of it. - -#Build manylinux2014 docker image begin -FROM $BASEIMAGE AS runtime_base -ARG POLICY -ARG PLATFORM -ARG DEVTOOLSET_ROOTPATH -ARG LD_LIBRARY_PATH_ARG -ARG PREPEND_PATH -LABEL maintainer="The ManyLinux project" - -ENV AUDITWHEEL_POLICY=${POLICY} AUDITWHEEL_ARCH=${PLATFORM} AUDITWHEEL_PLAT=${POLICY}_${PLATFORM} -ENV LC_ALL=en_US.UTF-8 LANG=en_US.UTF-8 LANGUAGE=en_US.UTF-8 -ENV DEVTOOLSET_ROOTPATH=${DEVTOOLSET_ROOTPATH} -ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH_ARG} -ENV PATH=${PREPEND_PATH}${PATH} -ENV PKG_CONFIG_PATH=/usr/local/lib/pkgconfig - -# first copy the fixup mirrors script, keep the script around -COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors - -# setup entrypoint, this will wrap commands with `linux32` with i686 images -COPY build_scripts/install-entrypoint.sh \ - build_scripts/build_utils.sh \ - /build_scripts/ - -RUN /build_scripts/install-entrypoint.sh && rm -rf /build_scripts -COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint -ENTRYPOINT ["manylinux-entrypoint"] - -COPY build_scripts/install-runtime-packages.sh \ - build_scripts/build_utils.sh \ - /build_scripts/ -RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ - -COPY build_scripts/build_utils.sh /build_scripts/ - -COPY build_scripts/install-autoconf.sh /build_scripts/ -RUN export AUTOCONF_ROOT=autoconf-2.71 && \ - export AUTOCONF_HASH=431075ad0bf529ef13cb41e9042c542381103e80015686222b8a9d4abef42a1c && \ - export AUTOCONF_DOWNLOAD_URL=http://ftp.gnu.org/gnu/autoconf && \ - manylinux-entrypoint /build_scripts/install-autoconf.sh - -COPY build_scripts/install-automake.sh /build_scripts/ -RUN export AUTOMAKE_ROOT=automake-1.16.5 && \ - export AUTOMAKE_HASH=07bd24ad08a64bc17250ce09ec56e921d6343903943e99ccf63bbf0705e34605 && \ - export AUTOMAKE_DOWNLOAD_URL=http://ftp.gnu.org/gnu/automake && \ - manylinux-entrypoint /build_scripts/install-automake.sh - -COPY build_scripts/install-libtool.sh /build_scripts/ -RUN export LIBTOOL_ROOT=libtool-2.4.7 && \ - export LIBTOOL_HASH=04e96c2404ea70c590c546eba4202a4e12722c640016c12b9b2f1ce3d481e9a8 && \ - export LIBTOOL_DOWNLOAD_URL=http://ftp.gnu.org/gnu/libtool && \ - manylinux-entrypoint /build_scripts/install-libtool.sh - -COPY build_scripts/install-libxcrypt.sh /build_scripts/ -RUN export LIBXCRYPT_VERSION=4.4.28 && \ - export LIBXCRYPT_HASH=db7e37901969cb1d1e8020cb73a991ef81e48e31ea5b76a101862c806426b457 && \ - export LIBXCRYPT_DOWNLOAD_URL=https://github.com/besser82/libxcrypt/archive && \ - export PERL_ROOT=perl-5.34.0 && \ - export PERL_HASH=551efc818b968b05216024fb0b727ef2ad4c100f8cb6b43fab615fa78ae5be9a && \ - export PERL_DOWNLOAD_URL=https://www.cpan.org/src/5.0 && \ - manylinux-entrypoint /build_scripts/install-libxcrypt.sh - -FROM runtime_base AS build_base -COPY build_scripts/install-build-packages.sh /build_scripts/ -RUN manylinux-entrypoint /build_scripts/install-build-packages.sh - - -FROM build_base AS build_git -COPY build_scripts/build-git.sh /build_scripts/ -RUN export GIT_ROOT=git-2.36.2 && \ - export GIT_HASH=6dc2cdea5fb23d823ba4871cc23222c1db31dfbb6d6c6ff74c4128700df57c68 && \ - export GIT_DOWNLOAD_URL=https://www.kernel.org/pub/software/scm/git && \ - manylinux-entrypoint /build_scripts/build-git.sh - - -FROM build_base AS build_cpython -COPY build_scripts/build-sqlite3.sh /build_scripts/ -RUN export SQLITE_AUTOCONF_ROOT=sqlite-autoconf-3390200 && \ - export SQLITE_AUTOCONF_HASH=852be8a6183a17ba47cee0bbff7400b7aa5affd283bf3beefc34fcd088a239de && \ - export SQLITE_AUTOCONF_DOWNLOAD_URL=https://www.sqlite.org/2022 && \ - manylinux-entrypoint /build_scripts/build-sqlite3.sh - -COPY build_scripts/build-openssl.sh /build_scripts/ -RUN export OPENSSL_ROOT=openssl-1.1.1q && \ - export OPENSSL_HASH=d7939ce614029cdff0b6c20f0e2e5703158a489a72b2507b8bd51bf8c8fd10ca && \ - export OPENSSL_DOWNLOAD_URL=https://www.openssl.org/source && \ - manylinux-entrypoint /build_scripts/build-openssl.sh - -COPY build_scripts/build-cpython.sh /build_scripts/ - - -FROM build_cpython AS build_cpython38 -COPY build_scripts/ambv-pubkey.txt /build_scripts/cpython-pubkeys.txt -RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.8.13 - - -FROM build_cpython AS build_cpython39 -COPY build_scripts/ambv-pubkey.txt /build_scripts/cpython-pubkeys.txt -RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.9.13 - - -FROM build_cpython AS build_cpython310 -COPY build_scripts/cpython-pubkey-310-311.txt /build_scripts/cpython-pubkeys.txt -RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.10.5 - -FROM build_cpython AS build_cpython311 -COPY build_scripts/cpython-pubkey-310-311.txt /build_scripts/cpython-pubkeys.txt -RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.11.2 - -FROM build_cpython AS all_python -COPY build_scripts/install-pypy.sh \ - build_scripts/pypy.sha256 \ - build_scripts/finalize-python.sh \ - /build_scripts/ -RUN manylinux-entrypoint /build_scripts/install-pypy.sh 3.8 7.3.9 -RUN manylinux-entrypoint /build_scripts/install-pypy.sh 3.9 7.3.9 -COPY --from=build_cpython38 /opt/_internal /opt/_internal/ -COPY --from=build_cpython39 /opt/_internal /opt/_internal/ -COPY --from=build_cpython310 /opt/_internal /opt/_internal/ -COPY --from=build_cpython311 /opt/_internal /opt/_internal/ -RUN manylinux-entrypoint /build_scripts/finalize-python.sh - - -FROM runtime_base -COPY --from=build_git /manylinux-rootfs / -COPY --from=build_cpython /manylinux-rootfs / -COPY --from=all_python /opt/_internal /opt/_internal/ -COPY build_scripts/finalize.sh \ - build_scripts/python-tag-abi-tag.py \ - build_scripts/requirements3.8.txt \ - build_scripts/requirements3.9.txt \ - build_scripts/requirements3.10.txt \ - build_scripts/requirements3.11.txt \ - build_scripts/requirements-base-tools.txt \ - /build_scripts/ -COPY build_scripts/requirements-tools/* /build_scripts/requirements-tools/ -RUN manylinux-entrypoint /build_scripts/finalize.sh && rm -rf /build_scripts - -ENV SSL_CERT_FILE=/opt/_internal/certs.pem - -CMD ["/bin/bash"] - -#Build manylinux2014 docker image end - -#Install TensorRT 8.4.1.5 -#RUN yum install -y wget -RUN v="8.4.1-1.cuda11.6" &&\ - yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo &&\ - yum -y install libnvinfer8-${v} libnvparsers8-${v} libnvonnxparsers8-${v} libnvinfer-plugin8-${v} \ - libnvinfer-devel-${v} libnvparsers-devel-${v} libnvonnxparsers-devel-${v} libnvinfer-plugin-devel-${v} -ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 -#Add our own dependencies -ADD scripts /tmp/scripts -RUN cd /tmp/scripts && /tmp/scripts/manylinux/install_centos.sh && /tmp/scripts/manylinux/install_deps.sh && rm -rf /tmp/scripts - -ARG BUILD_UID=1001 -ARG BUILD_USER=onnxruntimedev -RUN adduser --uid $BUILD_UID $BUILD_USER -WORKDIR /home/$BUILD_USER -USER $BUILD_USER -ENV PATH /usr/local/dotnet:$PATH diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_6_tensorrt8_5 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_6_tensorrt8_5 deleted file mode 100644 index 0337ffc5e00a0..0000000000000 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_6_tensorrt8_5 +++ /dev/null @@ -1,173 +0,0 @@ -ARG BASEIMAGE=nvidia/cuda:11.6.1-cudnn8-devel-centos7 -ARG POLICY=manylinux2014 -ARG PLATFORM=x86_64 -ARG DEVTOOLSET_ROOTPATH= -ARG LD_LIBRARY_PATH_ARG= -ARG PREPEND_PATH= - -#We need CUDA, TensorRT and manylinux. But the CUDA Toolkit End User License Agreement says NVIDIA CUDA Driver Libraries(libcuda.so, libnvidia-ptxjitcompiler.so) are only distributable in applications that meet this criteria: -#1. The application was developed starting from a NVIDIA CUDA container obtained from Docker Hub or the NVIDIA GPU Cloud, and -#2. The resulting application is packaged as a Docker container and distributed to users on Docker Hub or the NVIDIA GPU Cloud only. -#So we use CUDA as the base image then add manylinux and TensorRT on top of it. - -#Build manylinux2014 docker image begin -FROM $BASEIMAGE AS runtime_base -ARG POLICY -ARG PLATFORM -ARG DEVTOOLSET_ROOTPATH -ARG LD_LIBRARY_PATH_ARG -ARG PREPEND_PATH -LABEL maintainer="The ManyLinux project" - -ENV AUDITWHEEL_POLICY=${POLICY} AUDITWHEEL_ARCH=${PLATFORM} AUDITWHEEL_PLAT=${POLICY}_${PLATFORM} -ENV LC_ALL=en_US.UTF-8 LANG=en_US.UTF-8 LANGUAGE=en_US.UTF-8 -ENV DEVTOOLSET_ROOTPATH=${DEVTOOLSET_ROOTPATH} -ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH_ARG} -ENV PATH=${PREPEND_PATH}${PATH} -ENV PKG_CONFIG_PATH=/usr/local/lib/pkgconfig - -# first copy the fixup mirrors script, keep the script around -COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors - -# setup entrypoint, this will wrap commands with `linux32` with i686 images -COPY build_scripts/install-entrypoint.sh \ - build_scripts/build_utils.sh \ - /build_scripts/ - -RUN /build_scripts/install-entrypoint.sh && rm -rf /build_scripts -COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint -ENTRYPOINT ["manylinux-entrypoint"] - -COPY build_scripts/install-runtime-packages.sh \ - build_scripts/build_utils.sh \ - /build_scripts/ -RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ - -COPY build_scripts/build_utils.sh /build_scripts/ - -COPY build_scripts/install-autoconf.sh /build_scripts/ -RUN export AUTOCONF_ROOT=autoconf-2.71 && \ - export AUTOCONF_HASH=431075ad0bf529ef13cb41e9042c542381103e80015686222b8a9d4abef42a1c && \ - export AUTOCONF_DOWNLOAD_URL=http://ftp.gnu.org/gnu/autoconf && \ - manylinux-entrypoint /build_scripts/install-autoconf.sh - -COPY build_scripts/install-automake.sh /build_scripts/ -RUN export AUTOMAKE_ROOT=automake-1.16.5 && \ - export AUTOMAKE_HASH=07bd24ad08a64bc17250ce09ec56e921d6343903943e99ccf63bbf0705e34605 && \ - export AUTOMAKE_DOWNLOAD_URL=http://ftp.gnu.org/gnu/automake && \ - manylinux-entrypoint /build_scripts/install-automake.sh - -COPY build_scripts/install-libtool.sh /build_scripts/ -RUN export LIBTOOL_ROOT=libtool-2.4.7 && \ - export LIBTOOL_HASH=04e96c2404ea70c590c546eba4202a4e12722c640016c12b9b2f1ce3d481e9a8 && \ - export LIBTOOL_DOWNLOAD_URL=http://ftp.gnu.org/gnu/libtool && \ - manylinux-entrypoint /build_scripts/install-libtool.sh - -COPY build_scripts/install-libxcrypt.sh /build_scripts/ -RUN export LIBXCRYPT_VERSION=4.4.28 && \ - export LIBXCRYPT_HASH=db7e37901969cb1d1e8020cb73a991ef81e48e31ea5b76a101862c806426b457 && \ - export LIBXCRYPT_DOWNLOAD_URL=https://github.com/besser82/libxcrypt/archive && \ - export PERL_ROOT=perl-5.34.0 && \ - export PERL_HASH=551efc818b968b05216024fb0b727ef2ad4c100f8cb6b43fab615fa78ae5be9a && \ - export PERL_DOWNLOAD_URL=https://www.cpan.org/src/5.0 && \ - manylinux-entrypoint /build_scripts/install-libxcrypt.sh - -FROM runtime_base AS build_base -COPY build_scripts/install-build-packages.sh /build_scripts/ -RUN manylinux-entrypoint /build_scripts/install-build-packages.sh - - -FROM build_base AS build_git -COPY build_scripts/build-git.sh /build_scripts/ -RUN export GIT_ROOT=git-2.36.2 && \ - export GIT_HASH=6dc2cdea5fb23d823ba4871cc23222c1db31dfbb6d6c6ff74c4128700df57c68 && \ - export GIT_DOWNLOAD_URL=https://www.kernel.org/pub/software/scm/git && \ - manylinux-entrypoint /build_scripts/build-git.sh - - -FROM build_base AS build_cpython -COPY build_scripts/build-sqlite3.sh /build_scripts/ -RUN export SQLITE_AUTOCONF_ROOT=sqlite-autoconf-3390200 && \ - export SQLITE_AUTOCONF_HASH=852be8a6183a17ba47cee0bbff7400b7aa5affd283bf3beefc34fcd088a239de && \ - export SQLITE_AUTOCONF_DOWNLOAD_URL=https://www.sqlite.org/2022 && \ - manylinux-entrypoint /build_scripts/build-sqlite3.sh - -COPY build_scripts/build-openssl.sh /build_scripts/ -RUN export OPENSSL_ROOT=openssl-1.1.1q && \ - export OPENSSL_HASH=d7939ce614029cdff0b6c20f0e2e5703158a489a72b2507b8bd51bf8c8fd10ca && \ - export OPENSSL_DOWNLOAD_URL=https://www.openssl.org/source && \ - manylinux-entrypoint /build_scripts/build-openssl.sh - -COPY build_scripts/build-cpython.sh /build_scripts/ - - -FROM build_cpython AS build_cpython38 -COPY build_scripts/ambv-pubkey.txt /build_scripts/cpython-pubkeys.txt -RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.8.13 - - -FROM build_cpython AS build_cpython39 -COPY build_scripts/ambv-pubkey.txt /build_scripts/cpython-pubkeys.txt -RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.9.13 - - -FROM build_cpython AS build_cpython310 -COPY build_scripts/cpython-pubkey-310-311.txt /build_scripts/cpython-pubkeys.txt -RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.10.5 - -FROM build_cpython AS build_cpython311 -COPY build_scripts/cpython-pubkey-310-311.txt /build_scripts/cpython-pubkeys.txt -RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.11.2 - -FROM build_cpython AS all_python -COPY build_scripts/install-pypy.sh \ - build_scripts/pypy.sha256 \ - build_scripts/finalize-python.sh \ - /build_scripts/ -RUN manylinux-entrypoint /build_scripts/install-pypy.sh 3.8 7.3.9 -RUN manylinux-entrypoint /build_scripts/install-pypy.sh 3.9 7.3.9 -COPY --from=build_cpython38 /opt/_internal /opt/_internal/ -COPY --from=build_cpython39 /opt/_internal /opt/_internal/ -COPY --from=build_cpython310 /opt/_internal /opt/_internal/ -COPY --from=build_cpython311 /opt/_internal /opt/_internal/ -RUN manylinux-entrypoint /build_scripts/finalize-python.sh - - -FROM runtime_base -COPY --from=build_git /manylinux-rootfs / -COPY --from=build_cpython /manylinux-rootfs / -COPY --from=all_python /opt/_internal /opt/_internal/ -COPY build_scripts/finalize.sh \ - build_scripts/python-tag-abi-tag.py \ - build_scripts/requirements3.8.txt \ - build_scripts/requirements3.9.txt \ - build_scripts/requirements3.10.txt \ - build_scripts/requirements3.11.txt \ - build_scripts/requirements-base-tools.txt \ - /build_scripts/ -COPY build_scripts/requirements-tools/* /build_scripts/requirements-tools/ -RUN manylinux-entrypoint /build_scripts/finalize.sh && rm -rf /build_scripts - -ENV SSL_CERT_FILE=/opt/_internal/certs.pem - -CMD ["/bin/bash"] - -#Build manylinux2014 docker image end - -#Install TensorRT 8.5.1.7 -#RUN yum install -y wget -RUN v="8.5.1-1.cuda11.8" &&\ - yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo &&\ - yum -y install libnvinfer8-${v} libnvparsers8-${v} libnvonnxparsers8-${v} libnvinfer-plugin8-${v} \ - libnvinfer-devel-${v} libnvparsers-devel-${v} libnvonnxparsers-devel-${v} libnvinfer-plugin-devel-${v} -ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 -#Add our own dependencies -ADD scripts /tmp/scripts -RUN cd /tmp/scripts && /tmp/scripts/manylinux/install_centos.sh && /tmp/scripts/manylinux/install_deps.sh && rm -rf /tmp/scripts - -ARG BUILD_UID=1001 -ARG BUILD_USER=onnxruntimedev -RUN adduser --uid $BUILD_UID $BUILD_USER -WORKDIR /home/$BUILD_USER -USER $BUILD_USER -ENV PATH /usr/local/dotnet:$PATH diff --git a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_8_tensorrt8_6 b/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_8_tensorrt8_6 deleted file mode 100644 index 70765c667ab8e..0000000000000 --- a/tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda11_8_tensorrt8_6 +++ /dev/null @@ -1,181 +0,0 @@ -# This file is deprecated and will be replaced by tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda -ARG BASEIMAGE=nvidia/cuda:11.8.0-devel-ubi8 -ARG POLICY=manylinux_2_28 -ARG PLATFORM=x86_64 -ARG DEVTOOLSET_ROOTPATH= -ARG LD_LIBRARY_PATH_ARG= -ARG PREPEND_PATH= - -#We need CUDA, TensorRT and manylinux. But the CUDA Toolkit End User License Agreement says NVIDIA CUDA Driver Libraries(libcuda.so, libnvidia-ptxjitcompiler.so) are only distributable in applications that meet this criteria: -#1. The application was developed starting from a NVIDIA CUDA container obtained from Docker Hub or the NVIDIA GPU Cloud, and -#2. The resulting application is packaged as a Docker container and distributed to users on Docker Hub or the NVIDIA GPU Cloud only. -#So we use CUDA as the base image then add manylinux and TensorRT on top of it. - -#Build manylinux2014 docker image begin -FROM $BASEIMAGE AS runtime_base -ARG POLICY -ARG PLATFORM -ARG DEVTOOLSET_ROOTPATH -ARG LD_LIBRARY_PATH_ARG -ARG PREPEND_PATH -LABEL maintainer="The ManyLinux project" - -ENV AUDITWHEEL_POLICY=${POLICY} AUDITWHEEL_ARCH=${PLATFORM} AUDITWHEEL_PLAT=${POLICY}_${PLATFORM} -ENV LC_ALL=en_US.UTF-8 LANG=en_US.UTF-8 LANGUAGE=en_US.UTF-8 -ENV DEVTOOLSET_ROOTPATH=${DEVTOOLSET_ROOTPATH} -ENV LD_LIBRARY_PATH=${LD_LIBRARY_PATH_ARG} -ENV PATH=${PREPEND_PATH}${PATH} -ENV PKG_CONFIG_PATH=/usr/local/lib/pkgconfig - -# first copy the fixup mirrors script, keep the script around -COPY build_scripts/fixup-mirrors.sh /usr/local/sbin/fixup-mirrors - -# setup entrypoint, this will wrap commands with `linux32` with i686 images -COPY build_scripts/install-entrypoint.sh \ - build_scripts/build_utils.sh \ - /build_scripts/ - -RUN /build_scripts/install-entrypoint.sh && rm -rf /build_scripts -COPY manylinux-entrypoint /usr/local/bin/manylinux-entrypoint -ENTRYPOINT ["manylinux-entrypoint"] - -COPY build_scripts/install-runtime-packages.sh \ - build_scripts/build_utils.sh \ - /build_scripts/ -RUN manylinux-entrypoint /build_scripts/install-runtime-packages.sh && rm -rf /build_scripts/ - -COPY build_scripts/build_utils.sh /build_scripts/ - -COPY build_scripts/install-autoconf.sh /build_scripts/ -RUN export AUTOCONF_ROOT=autoconf-2.71 && \ - export AUTOCONF_HASH=431075ad0bf529ef13cb41e9042c542381103e80015686222b8a9d4abef42a1c && \ - export AUTOCONF_DOWNLOAD_URL=http://ftp.gnu.org/gnu/autoconf && \ - manylinux-entrypoint /build_scripts/install-autoconf.sh - -COPY build_scripts/install-automake.sh /build_scripts/ -RUN export AUTOMAKE_ROOT=automake-1.16.5 && \ - export AUTOMAKE_HASH=07bd24ad08a64bc17250ce09ec56e921d6343903943e99ccf63bbf0705e34605 && \ - export AUTOMAKE_DOWNLOAD_URL=http://ftp.gnu.org/gnu/automake && \ - manylinux-entrypoint /build_scripts/install-automake.sh - -COPY build_scripts/install-libtool.sh /build_scripts/ -RUN export LIBTOOL_ROOT=libtool-2.4.7 && \ - export LIBTOOL_HASH=04e96c2404ea70c590c546eba4202a4e12722c640016c12b9b2f1ce3d481e9a8 && \ - export LIBTOOL_DOWNLOAD_URL=http://ftp.gnu.org/gnu/libtool && \ - manylinux-entrypoint /build_scripts/install-libtool.sh - -COPY build_scripts/install-libxcrypt.sh /build_scripts/ -RUN export LIBXCRYPT_VERSION=4.4.28 && \ - export LIBXCRYPT_HASH=db7e37901969cb1d1e8020cb73a991ef81e48e31ea5b76a101862c806426b457 && \ - export LIBXCRYPT_DOWNLOAD_URL=https://github.com/besser82/libxcrypt/archive && \ - export PERL_ROOT=perl-5.34.0 && \ - export PERL_HASH=551efc818b968b05216024fb0b727ef2ad4c100f8cb6b43fab615fa78ae5be9a && \ - export PERL_DOWNLOAD_URL=https://www.cpan.org/src/5.0 && \ - manylinux-entrypoint /build_scripts/install-libxcrypt.sh - -FROM runtime_base AS build_base -COPY build_scripts/install-build-packages.sh /build_scripts/ -RUN manylinux-entrypoint /build_scripts/install-build-packages.sh - - -FROM build_base AS build_git -COPY build_scripts/build-git.sh /build_scripts/ -RUN export GIT_ROOT=git-2.36.2 && \ - export GIT_HASH=6dc2cdea5fb23d823ba4871cc23222c1db31dfbb6d6c6ff74c4128700df57c68 && \ - export GIT_DOWNLOAD_URL=https://www.kernel.org/pub/software/scm/git && \ - manylinux-entrypoint /build_scripts/build-git.sh - - -FROM build_base AS build_cpython -COPY build_scripts/build-sqlite3.sh /build_scripts/ -RUN export SQLITE_AUTOCONF_ROOT=sqlite-autoconf-3390200 && \ - export SQLITE_AUTOCONF_HASH=852be8a6183a17ba47cee0bbff7400b7aa5affd283bf3beefc34fcd088a239de && \ - export SQLITE_AUTOCONF_DOWNLOAD_URL=https://www.sqlite.org/2022 && \ - manylinux-entrypoint /build_scripts/build-sqlite3.sh - -COPY build_scripts/build-openssl.sh /build_scripts/ -RUN export OPENSSL_ROOT=openssl-1.1.1q && \ - export OPENSSL_HASH=d7939ce614029cdff0b6c20f0e2e5703158a489a72b2507b8bd51bf8c8fd10ca && \ - export OPENSSL_DOWNLOAD_URL=https://www.openssl.org/source && \ - manylinux-entrypoint /build_scripts/build-openssl.sh - -COPY build_scripts/build-cpython.sh /build_scripts/ - - -FROM build_cpython AS build_cpython37 -COPY build_scripts/cpython-pubkeys.txt /build_scripts/cpython-pubkeys.txt -RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.7.13 - - -FROM build_cpython AS build_cpython38 -COPY build_scripts/ambv-pubkey.txt /build_scripts/cpython-pubkeys.txt -RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.8.13 - - -FROM build_cpython AS build_cpython39 -COPY build_scripts/ambv-pubkey.txt /build_scripts/cpython-pubkeys.txt -RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.9.13 - - -FROM build_cpython AS build_cpython310 -COPY build_scripts/cpython-pubkey-310-311.txt /build_scripts/cpython-pubkeys.txt -RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.10.5 - -FROM build_cpython AS build_cpython311 -COPY build_scripts/cpython-pubkey-310-311.txt /build_scripts/cpython-pubkeys.txt -RUN manylinux-entrypoint /build_scripts/build-cpython.sh 3.11.0b5 - -FROM build_cpython AS all_python -COPY build_scripts/install-pypy.sh \ - build_scripts/pypy.sha256 \ - build_scripts/finalize-python.sh \ - /build_scripts/ -RUN manylinux-entrypoint /build_scripts/install-pypy.sh 3.7 7.3.9 -RUN manylinux-entrypoint /build_scripts/install-pypy.sh 3.8 7.3.9 -RUN manylinux-entrypoint /build_scripts/install-pypy.sh 3.9 7.3.9 -COPY --from=build_cpython37 /opt/_internal /opt/_internal/ -COPY --from=build_cpython38 /opt/_internal /opt/_internal/ -COPY --from=build_cpython39 /opt/_internal /opt/_internal/ -COPY --from=build_cpython310 /opt/_internal /opt/_internal/ -COPY --from=build_cpython311 /opt/_internal /opt/_internal/ -RUN manylinux-entrypoint /build_scripts/finalize-python.sh - - -FROM runtime_base -COPY --from=build_git /manylinux-rootfs / -COPY --from=build_cpython /manylinux-rootfs / -COPY --from=all_python /opt/_internal /opt/_internal/ -COPY build_scripts/finalize.sh \ - build_scripts/python-tag-abi-tag.py \ - build_scripts/requirements3.7.txt \ - build_scripts/requirements3.8.txt \ - build_scripts/requirements3.9.txt \ - build_scripts/requirements3.10.txt \ - build_scripts/requirements3.11.txt \ - build_scripts/requirements-base-tools.txt \ - /build_scripts/ -COPY build_scripts/requirements-tools/* /build_scripts/requirements-tools/ -RUN manylinux-entrypoint /build_scripts/finalize.sh && rm -rf /build_scripts - -ENV SSL_CERT_FILE=/opt/_internal/certs.pem - -CMD ["/bin/bash"] - -#Build manylinux2014 docker image end - -#Install TensorRT 8.6.1.6 -RUN v="8.6.1.6-1.cuda11.8" && CUDA_VERSION=$(nvcc --version | sed -n 's/^.*release \([0-9]\+\.[0-9]\+\).*$/\1/p') \ - && dnf -y install libcudnn8-devel-*cuda$CUDA_VERSION* libcudnn8-*cuda$CUDA_VERSION* libnvinfer8-${v} libnvparsers8-${v} libnvonnxparsers8-${v} libnvinfer-plugin8-${v} libnvinfer-vc-plugin8-${v}\ - libnvinfer-devel-${v} libnvparsers-devel-${v} libnvonnxparsers-devel-${v} libnvinfer-plugin-devel-${v} libnvinfer-vc-plugin-devel-${v} libnvinfer-headers-devel-${v} libnvinfer-headers-plugin-devel-${v} -ENV JAVA_HOME=/usr/lib/jvm/msopenjdk-11 -#Add our own dependencies -ADD scripts /tmp/scripts -RUN cd /tmp/scripts && /tmp/scripts/manylinux/install_centos.sh && /tmp/scripts/manylinux/install_deps.sh && rm -rf /tmp/scripts - -ARG BUILD_UID=1001 -ARG BUILD_USER=onnxruntimedev -RUN adduser --uid $BUILD_UID $BUILD_USER -WORKDIR /home/$BUILD_USER -USER $BUILD_USER -ENV PATH /usr/local/dotnet:$PATH -ENV CUDA_MODULE_LOADING "LAZY" diff --git a/tools/ci_build/github/windows/setup_env_cuda_11.bat b/tools/ci_build/github/windows/setup_env_cuda.bat similarity index 53% rename from tools/ci_build/github/windows/setup_env_cuda_11.bat rename to tools/ci_build/github/windows/setup_env_cuda.bat index 1308e43a4f6db..96569cbe0f648 100644 --- a/tools/ci_build/github/windows/setup_env_cuda_11.bat +++ b/tools/ci_build/github/windows/setup_env_cuda.bat @@ -6,4 +6,10 @@ if exist PATH=%AGENT_TEMPDIRECTORY%\v11.8\ { } else { set PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.8\extras\CUPTI\lib64;%PATH% } +@REM The default version is still cuda v11.8, because set cuda v12.2 after it +if exist PATH=%AGENT_TEMPDIRECTORY%\v12.2\ { + set PATH=%PATH%;%AGENT_TEMPDIRECTORY%\v12.2\bin;%AGENT_TEMPDIRECTORY%\v12.2\extras\CUPTI\lib64 +} else { + set PATH=%PATH%;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.2\bin;C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.2\extras\CUPTI\lib64 +} set GRADLE_OPTS=-Dorg.gradle.daemon=false