diff --git a/cgmanifests/generated/cgmanifest.json b/cgmanifests/generated/cgmanifest.json index bcd0b2a92a5c3..03e3f84547a68 100644 --- a/cgmanifests/generated/cgmanifest.json +++ b/cgmanifests/generated/cgmanifest.json @@ -36,7 +36,7 @@ "component": { "type": "git", "git": { - "commitHash": "dcd5bd5fd593e31465af3d9ef291d26c646b0a4f", + "commitHash": "4a2c63365eff8823a5221db86ef490e828306f9d", "repositoryUrl": "https://github.com/abseil/abseil-cpp.git" }, "comments": "abseil_cpp" @@ -192,6 +192,16 @@ "comments": "mp11" } }, + { + "component": { + "type": "git", + "git": { + "commitHash": "c11386eb632eec7c1c2aa323142f73519f946e2a", + "repositoryUrl": "https://github.com/intel/neural-speed.git" + }, + "comments": "neural_speed" + } + }, { "component": { "type": "git", diff --git a/cmake/deps.txt b/cmake/deps.txt index fda27e5e93797..ba9c2bb73cf7a 100644 --- a/cmake/deps.txt +++ b/cmake/deps.txt @@ -12,7 +12,7 @@ # NOTE: You must run deps_update_and_upload.py and generate_cgmanifest.py when ready to test your changes in a CI. # See https://microsoft.sharepoint.com/teams/ONNX2/_layouts/OneNote.aspx?id=%2Fteams%2FONNX2%2FShared%20Documents%2FNotebooks%2FONNX%20Ecosystem%20Team%20Notebook&wd=target%28Development.one%7C63D3AB47-51D1-4A62-9965-66882234BD44%2FAdd%20or%20update%20a%20dependency%20in%20deps.txt%7C0E9ED71D-89D5-40FA-B05F-C0123289C591%2F%29 # -abseil_cpp;https://github.com/abseil/abseil-cpp/archive/dcd5bd5fd593e31465af3d9ef291d26c646b0a4f.zip;6cc204586014e189f5c0fe3274f83162fa7c700c +abseil_cpp;https://github.com/abseil/abseil-cpp/archive/refs/tags/20240116.0.zip;bc2cec6baaad67fcb6c0c38972b687d4797927e9 cxxopts;https://github.com/jarro2783/cxxopts/archive/3c73d91c0b04e2b59462f0a741be8c07024c1bc0.zip;6c6ca7f8480b26c8d00476e0e24b7184717fe4f0 date;https://github.com/HowardHinnant/date/archive/refs/tags/v3.0.1.zip;2dac0c81dc54ebdd8f8d073a75c053b04b56e159 dlpack;https://github.com/dmlc/dlpack/archive/refs/tags/v0.6.zip;4d565dd2e5b31321e5549591d78aa7f377173445 @@ -34,6 +34,7 @@ microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf36 microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5 mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee7d34223d0567892db5179849939c8769dc41 mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063 +neural_speed;https://github.com/intel/neural-speed/archive/refs/tags/bestlav0.1.1.zip;65b0f7a0d04f72f0d5a8d48af70f0366f2ab3939 onnx;https://github.com/onnx/onnx/archive/refs/tags/v1.15.0.zip;54c3f960a0541c5d8d3e60c2933e11f5d3688a11 #use the commit of supporting all the plugins and TRT 8.6-GA (https://github.com/onnx/onnx-tensorrt/commit/0462dc31ae78f48744b6141ae376df1f96d3f459) onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/a43ce67187bab219520fd80f21af8bbd4354bc8c.zip;572535aefef477050f86744dfab1fef840198035 diff --git a/cmake/external/abseil-cpp.cmake b/cmake/external/abseil-cpp.cmake index 3bcd4109e2888..57cfbee4644ef 100644 --- a/cmake/external/abseil-cpp.cmake +++ b/cmake/external/abseil-cpp.cmake @@ -19,7 +19,7 @@ if(WIN32 AND NOT Patch_FOUND) set(ABSL_ENABLE_INSTALL ON) endif() # NB! Advancing Abseil version changes its internal namespace, -# currently absl::lts_20230125 which affects abseil-cpp.natvis debugger +# currently absl::lts_20240116 which affects abseil-cpp.natvis debugger # visualization file, that must be adjusted accordingly, unless we eliminate # that namespace at build time. FetchContent_Declare( diff --git a/cmake/external/abseil-cpp.natvis b/cmake/external/abseil-cpp.natvis index 1e5a36fb9efb9..a4fb63b6a8377 100644 --- a/cmake/external/abseil-cpp.natvis +++ b/cmake/external/abseil-cpp.natvis @@ -1,6 +1,6 @@ - + @@ -24,7 +24,7 @@ - + @@ -51,7 +51,7 @@ - + *($T1 *){value} (*($T1 *){value}) @@ -60,7 +60,7 @@ - + *($T1 *)this (*($T1 *)this) @@ -68,7 +68,7 @@ - + {value.first}, {value.second} ({value.first}, {value.second}) diff --git a/cmake/external/neural_speed.cmake b/cmake/external/neural_speed.cmake index e66e2acfb209a..ed711351403a7 100644 --- a/cmake/external/neural_speed.cmake +++ b/cmake/external/neural_speed.cmake @@ -7,12 +7,9 @@ endif() if(USE_NEURAL_SPEED) FetchContent_Declare( neural_speed - URL https://github.com/intel/neural-speed/archive/refs/tags/bestlav0.1.1.zip - URL_HASH SHA1=65b0f7a0d04f72f0d5a8d48af70f0366f2ab3939 + URL ${DEP_URL_neural_speed} + URL_HASH SHA1=${DEP_SHA1_neural_speed} ) set(BTLA_USE_OPENMP OFF) - FetchContent_MakeAvailable(neural_speed) - if(NOT neural_speed_POPULATED) - FetchContent_Populate(neural_speed) - endif() + onnxruntime_fetchcontent_makeavailable(neural_speed) endif() diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake index f89d2150a6830..17de2aa4aaea6 100644 --- a/cmake/onnxruntime_mlas.cmake +++ b/cmake/onnxruntime_mlas.cmake @@ -355,19 +355,23 @@ else() ${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S ${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S ${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S + ${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S ${MLAS_SRC_DIR}/activate_fp16.cpp ${MLAS_SRC_DIR}/dwconv.cpp ${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp ${MLAS_SRC_DIR}/pooling_fp16.cpp ${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp ${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp + ${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp ) set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ") + set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ") + set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ") endif() if(ONNXRUNTIME_MLAS_MULTI_ARCH) diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index fa395802d95ff..0987d6d164dbd 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -1277,6 +1277,9 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) if (onnxruntime_USE_CUDA) list(APPEND onnxruntime_shared_lib_test_LIBS cudart) endif() + if (onnxruntime_USE_ROCM) + list(APPEND onnxruntime_shared_lib_test_LIBS hip::host) + endif() if (onnxruntime_USE_TENSORRT) list(APPEND onnxruntime_shared_lib_test_LIBS ${TENSORRT_LIBRARY_INFER}) endif() @@ -1294,6 +1297,10 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP) target_include_directories(onnxruntime_shared_lib_test PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) target_sources(onnxruntime_shared_lib_test PRIVATE ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_ops.cu) endif() + if (onnxruntime_USE_ROCM) + target_include_directories(onnxruntime_shared_lib_test PRIVATE ${onnxruntime_ROCM_HOME}/include) + target_compile_definitions(onnxruntime_shared_lib_test PRIVATE __HIP_PLATFORM_AMD__) + endif() if (CMAKE_SYSTEM_NAME STREQUAL "Android") target_sources(onnxruntime_shared_lib_test PRIVATE "${ONNXRUNTIME_ROOT}/core/platform/android/cxa_demangle.cc" diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 22e82443167f6..e7b537d6894c8 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -2398,24 +2398,28 @@ This version of the operator has been available since version 1 of the 'com.micr #### Attributes
+
do_rotary : int
+
Whether to use rotary position embedding. Default value is 0.
kv_num_heads : int (required)
Number of attention heads for k and v
local_window_size : int
left_window_size for local attention (like Mistral). Default value is -1 meaning unused.
num_heads : int (required)
Number of attention heads for q
+
rotary_interleaved : int
+
Rotate using interleaved pattern. Default value is 0 (False).
scale : float
Custom scale will be used if specified. Default value is 1/sqrt(head_size)
-#### Inputs +#### Inputs (7 - 9)
query : T
-
Query with shape (batch_size, sequence_length, hidden_size)
-
key : T
+
Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape(batch_size, sequence_length, d) where d is (num_heads * head_size + 2 * kv_num_heads * head_size).
+
key (optional) : T
Key with shape (batch_size, kv_sequence_length, kv_hidden_size)
-
value : T
+
value (optional) : T
Value with shape (batch_size, kv_sequence_length, kv_hidden_size)
past_key (optional) : T
past state key with support for format BNSH. When past_key uses same tensor as present_key(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.
@@ -2425,6 +2429,10 @@ This version of the operator has been available since version 1 of the 'com.micr
1d Tensor of shape (batch_size). Indicates past sequence lengths for token generation case.
total_sequence_length : M
Scalar tensor of total sequence length (past + new).
+
cos_cache (optional) : T
+
2D tensor with shape (max_sequence_length, head_size / 2).
+
sin_cache (optional) : T
+
2D tensor with shape (max_sequence_length, head_size / 2).
#### Outputs @@ -5761,7 +5769,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape
-#### Inputs (5 - 14) +#### Inputs (5 - 15)
input_ids : F
@@ -5792,6 +5800,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default collect allits shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]
extra_decoding_ids (optional) : I
Part of the decoder_input_ids that we need cross qk for it. it is of shape (batch_size, extra_decoding_ids_len).In such case, we should remove this from the tail of the decoder_input_ids, and put it here. ids < 0 in it (for multiple batch) are treated as stop of the extra_decoding_ids for corresponding batch.
+
temperature (optional) : T
+
Temperature value to apply to logits processing during this execution's decoding. Shape is (1)
#### Outputs (1 - 5) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 9ecc58bee0725..31cca232fde34 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -499,7 +499,7 @@ Do not modify directly.* |TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |Trilu|*in* X:**T**
*in* k:**tensor(int64)**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int64)| |Unique|*in* x:**T**
*out* y:**T**
*out* idx:**tensor(int64)**
*out* counts:**tensor(int64)**|1+|**T** = tensor(float)| -|WhisperBeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*in* cross_qk_layer_head:**I**
*in* extra_decoding_ids:**I**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**
*out* cross_qk:**V**
*out* non_speech_probs:**T**|1+|**T** = tensor(float)| +|WhisperBeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*in* cross_qk_layer_head:**I**
*in* extra_decoding_ids:**I**
*in* temperature:**T**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**
*out* cross_qk:**V**
*out* non_speech_probs:**T**|1+|**T** = tensor(float)| |WordConvEmbedding|*in* Sequence:**T**
*in* W:**T1**
*in* B:**T1**
*in* C:**T1**
*out* Y:**T1**|1+|**T** = tensor(int32)
**T1** = tensor(float)| | | | | @@ -843,7 +843,7 @@ Do not modify directly.* |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| -|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)| +|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)| |Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| |LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| @@ -876,7 +876,7 @@ Do not modify directly.* |TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |Trilu|*in* X:**T**
*in* k:**tensor(int64)**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |UnfoldTensor|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|WhisperBeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*in* cross_qk_layer_head:**I**
*in* extra_decoding_ids:**I**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**
*out* cross_qk:**V**
*out* non_speech_probs:**T**|1+|**T** = tensor(float), tensor(float16)| +|WhisperBeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*in* cross_qk_layer_head:**I**
*in* extra_decoding_ids:**I**
*in* temperature:**T**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**
*out* cross_qk:**V**
*out* non_speech_probs:**T**|1+|**T** = tensor(float), tensor(float16)| | | | | @@ -922,10 +922,12 @@ Do not modify directly.* |BitwiseNot|*in* X:**T**
*out* Y:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |BitwiseOr|*in* A:**T**
*in* B:**T**
*out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |BitwiseXor|*in* A:**T**
*in* B:**T**
*out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|Cast|*in* input:**T1**
*out* output:**T2**|13+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Cast|*in* input:**T1**
*out* output:**T2**|19+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||13+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||9+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||6+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|CastLike|*in* input:**T1**
*in* target_type:**T2**
*out* output:**T2**|15+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|CastLike|*in* input:**T1**
*in* target_type:**T2**
*out* output:**T2**|19+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||15+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Ceil|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(float), tensor(float16)| |||6+|**T** = tensor(float), tensor(float16)| |Celu|*in* X:**T**
*out* Y:**T**|12+|**T** = tensor(float), tensor(float16)| @@ -952,7 +954,8 @@ Do not modify directly.* |DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|DequantizeLinear|*in* x:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*out* y:**tensor(float)**

or

*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|13+|**T** = tensor(int32), tensor(int8), tensor(uint8)| +|DequantizeLinear|*in* x:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*out* y:**tensor(float)**

or

*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|19+|**T1** = tensor(int32), tensor(int8), tensor(uint8)
**T2** = tensor(float), tensor(float16)| +|||13+|**T** = tensor(int32), tensor(int8), tensor(uint8)| |||10+|**T** = tensor(int32), tensor(int8), tensor(uint8)| |Div|*in* A:**T**
*in* B:**T**
*out* C:**T**|14+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -961,7 +964,8 @@ Do not modify directly.* |DynamicQuantizeLinear|*in* x:**T1**
*out* y:**T2**
*out* y_scale:**tensor(float)**
*out* y_zero_point:**T2**|11+|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)| |Einsum|*in* Inputs:**T**
*out* Output:**T**|12+|**T** = tensor(float), tensor(float16)| |Elu|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(float), tensor(float16)| -|Equal|*in* A:**T**
*in* B:**T**
*out* C:**T1**|13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| +|Equal|*in* A:**T**
*in* B:**T**
*out* C:**T1**|19+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| +|||13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| |||11+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)| |||7+|**T** = tensor(float), tensor(float16)
**T1** = tensor(bool)| |Erf|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(float), tensor(float16)| @@ -1004,7 +1008,8 @@ Do not modify directly.* |Hardmax|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(float), tensor(float16)| |||11+|**T** = tensor(float), tensor(float16)| |||1+|**T** = tensor(float), tensor(float16)| -|Identity|*in* input:**T**
*out* output:**T**

or

*in* input:**V**
*out* output:**V**|16+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Identity|*in* input:**T**
*out* output:**T**

or

*in* input:**V**
*out* output:**V**|19+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||16+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||14+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| @@ -1099,7 +1104,8 @@ Do not modify directly.* |||7+|**T** = tensor(float), tensor(float16)| |QLinearConv|*in* x:**T1**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T1**
*in* w:**T2**
*in* w_scale:**tensor(float)**
*in* w_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*in* B:**T4**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)
**T4** = tensor(int32)| |QLinearMatMul|*in* a:**T1**
*in* a_scale:**tensor(float)**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**tensor(float)**
*in* b_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)| -|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**

or

*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|13+|**T1** = tensor(float), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| +|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**

or

*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|19+|**T1** = tensor(float), tensor(float16), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| +|||13+|**T1** = tensor(float), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| |||10+|**T1** = tensor(float), tensor(int32)
**T2** = tensor(int8), tensor(uint8)| |RNN|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*out* Y:**T**
*out* Y_h:**T**|14+|**T** = tensor(float), tensor(float16)| |||7+|**T** = tensor(float), tensor(float16)| @@ -1150,7 +1156,8 @@ Do not modify directly.* |Relu|*in* X:**T**
*out* Y:**T**|14+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8)| |||13+|**T** = tensor(float), tensor(float16)| |||6+|**T** = tensor(float), tensor(float16)| -|Reshape|*in* data:**T**
*in* shape:**tensor(int64)**
*out* reshaped:**T**

or

*in* data:**T**
*out* reshaped:**T**|14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|Reshape|*in* data:**T**
*in* shape:**tensor(int64)**
*out* reshaped:**T**

or

*in* data:**T**
*out* reshaped:**T**|19+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||5+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Resize|*in* X:**T**
*in* scales:**tensor(float)**
*out* Y:**T**

or

*in* X:**T1**
*in* roi:**T2**
*in* scales:**tensor(float)**
*in* sizes:**tensor(int64)**
*out* Y:**T1**|13+|**T1** = tensor(float), tensor(float16), tensor(int8), tensor(uint8)
**T2** = tensor(float), tensor(float16)| @@ -1178,7 +1185,8 @@ Do not modify directly.* |SequenceErase|*in* input_sequence:**S**
*in* position:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| |SequenceInsert|*in* input_sequence:**S**
*in* tensor:**T**
*in* position:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| |SequenceLength|*in* input_sequence:**S**
*out* length:**I**|11+|**I** = tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| -|Shape|*in* data:**T**
*out* shape:**T1**|15+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|Shape|*in* data:**T**
*out* shape:**T1**|19+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|||15+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |||13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |||1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |Shrink|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint8)| @@ -1188,7 +1196,8 @@ Do not modify directly.* |||9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Sin|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(float), tensor(float16)| |Sinh|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(float), tensor(float16)| -|Size|*in* data:**T**
*out* size:**T1**|13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|Size|*in* data:**T**
*out* size:**T1**|19+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| +|||13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |||1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)| |Slice|*in* data:**T**
*in* starts:**Tind**
*in* ends:**Tind**
*in* axes:**Tind**
*in* steps:**Tind**
*out* output:**T**

or

*in* data:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| |||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)| diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index aca9f4896fbdb..2ce9d361e8e56 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -496,6 +496,7 @@ typedef struct OrtROCMProviderOptions { has_user_compute_stream{}, user_compute_stream{}, default_memory_arena_cfg{}, + enable_hip_graph{false}, tunable_op_enable{false}, tunable_op_tuning_enable{false}, tunable_op_max_tuning_duration_ms{} {} @@ -548,6 +549,8 @@ typedef struct OrtROCMProviderOptions { */ OrtArenaCfg* default_memory_arena_cfg; + int enable_hip_graph; + /** \brief Enable TunableOp for using. * Set it to 1/0 to enable/disable TunableOp. Otherwise, it is disabled by default. * This option can be overriden by environment variable ORT_ROCM_TUNABLE_OP_ENABLE. @@ -3608,6 +3611,14 @@ struct OrtApi { * - "1": Faster preparation time, less optimal graph. * - "2": Longer preparation time, more optimal graph. * - "3": Longest preparation time, most likely even more optimal graph. See QNN SDK documentation for specific details. + * "soc_model": The SoC model number. Refer to the QNN SDK documentation for valid values. Defaults to "0" (unknown). + * "htp_arch": The minimum HTP architecture the driver will use to select compatible QNN operators. Available options: + * - "0": Default (none). + * - "68" + * - "69" + * - "73" + * - "75" + * "device_id": The ID of the device to use when setting 'htp_arch'. Defaults to "0" (for single device). * * SNPE supported keys: * "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16", diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h index 8fd51962bf087..b282438795eb5 100644 --- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h @@ -249,4 +249,10 @@ static const char* const kOrtSessionOptionEpContextFilePath = "ep.context_file_p // Flag to specify whether to dump the EP context into the Onnx model. // "0": dump the EP context into separate file, keep the file name in the Onnx model. // "1": dump the EP context into the Onnx model. (default). -static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode"; \ No newline at end of file +static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode"; + +// Gemm fastmath mode provides fp32 gemm acceleration with bfloat16 based matmul. +// Option values: +// - "0": Gemm FastMath mode is not enabled. [DEFAULT] +// - "1": Gemm FastMath mode is enabled. +static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16"; diff --git a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c index 3a1c0d1bb8fa1..4a5e2b7ef3b1e 100644 --- a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c +++ b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c @@ -8,7 +8,7 @@ #include "onnxruntime/core/session/onnxruntime_c_api.h" #include "OrtJniUtil.h" #include "ai_onnxruntime_OrtSession_SessionOptions.h" -#ifdef WIN32 +#ifdef _WIN32 #include #else #include @@ -318,7 +318,7 @@ JNIEXPORT void JNICALL Java_ai_onnxruntime_OrtSession_00024SessionOptions_closeC // Iterate the handles, calling the appropriate close function for (jint i = 0; i < numHandles; i++) { -#ifdef WIN32 +#ifdef _WIN32 FreeLibrary((void*)handles[i]); #else dlclose((void*)handles[i]); diff --git a/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c b/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c index 9f7b8d3a3dcfc..464234c34798a 100644 --- a/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c +++ b/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c @@ -66,7 +66,7 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtTrainingSession_createTrainingSes } } wchar_t* optimizerStr = NULL; - if (optimizerPath == NULL) { + if (optimizerPath != NULL) { optimizerStr = copyAndPad(jniEnv, optimizerPath); if (optimizerStr == NULL) { // exception has been thrown in Java, go to cleanup and return null. diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index 2f510308d9306..2557971eb4ded 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -52,6 +52,7 @@ Do not modify directly.* | GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | | | Greater | ai.onnx(7-8,9-12,13+) | | | GreaterOrEqual | ai.onnx(12-15,16+) | | +| HardSigmoid | ai.onnx(6+) | | | If | ai.onnx(1-10,11-12,13-18,19+) | | | InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | | | LayerNormalization | ai.onnx(17+) | | diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts index 9d4d5875310b7..68054210e79a7 100644 --- a/js/web/lib/wasm/binding/ort-wasm.d.ts +++ b/js/web/lib/wasm/binding/ort-wasm.d.ts @@ -182,6 +182,10 @@ export interface OrtWasmModule extends EmscriptenModule { jsepCreateDownloader: (gpuBuffer: GPUBuffer, size: number, type: Tensor.GpuBufferDataTypes) => () => Promise; + /** + * [exported from js_internal_api.js] Called when InferenceSession.run started. + */ + jsepOnRunStart: () => void; // #endregion } diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts index 2956ec1cad4da..8ca025d66550c 100644 --- a/js/web/lib/wasm/jsep/backend-webgpu.ts +++ b/js/web/lib/wasm/jsep/backend-webgpu.ts @@ -208,7 +208,7 @@ export class WebGpuBackend { Object.defineProperty(this.env.webgpu, 'device', {value: this.device}); - // init queryType, which is necessary for createKernel + // init queryType, which is necessary for InferenceSession.create this.setQueryType(); } @@ -222,18 +222,6 @@ export class WebGpuBackend { getCommandEncoder(): GPUCommandEncoder { if (!this.commandEncoder) { this.commandEncoder = this.device.createCommandEncoder(); - - // refresh queryType, as sometimes we only need to enable query for a specific run - this.setQueryType(); - if (this.queryType !== 'none' && typeof this.querySet === 'undefined') { - this.querySet = this.device.createQuerySet({ - type: 'timestamp', - count: this.maxDispatchNumber * 2, - }); - this.queryResolveBuffer = this.device.createBuffer( - // eslint-disable-next-line no-bitwise - {size: this.maxDispatchNumber * 2 * 8, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.QUERY_RESOLVE}); - } } return this.commandEncoder; } @@ -639,6 +627,7 @@ export class WebGpuBackend { return createView(data.buffer, type); }; } + // #endregion writeTimestamp(index: number): void { if (this.queryType !== 'inside-passes') { return; @@ -655,7 +644,19 @@ export class WebGpuBackend { } else if (this.device.features.has('timestamp-query')) { this.queryType = 'at-passes'; } + + if (this.queryType !== 'none' && typeof this.querySet === 'undefined') { + this.querySet = this.device.createQuerySet({ + type: 'timestamp', + count: this.maxDispatchNumber * 2, + }); + this.queryResolveBuffer = this.device.createBuffer( + // eslint-disable-next-line no-bitwise + {size: this.maxDispatchNumber * 2 * 8, usage: GPUBufferUsage.COPY_SRC | GPUBufferUsage.QUERY_RESOLVE}); + } } } - // #endregion + onRunStart(): void { + this.setQueryType(); + } } diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts index 90e02da986b8f..d737a28654220 100644 --- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts +++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts @@ -25,7 +25,7 @@ import * as pool from './ops/pool'; import {range} from './ops/range'; import {reduceL1, reduceL2, reduceLogSum, reduceLogSumExp, reduceMax, reduceMean, reduceMin, reduceProd, reduceSum, reduceSumSquare} from './ops/reduce'; import {parseResizeAttributes, resize} from './ops/resize'; -import {parseSkipLayerNormAttributes, skipLayerNorm} from './ops/skip-layer-norm'; +import {skipLayerNorm} from './ops/skip-layer-norm'; import {parseSliceAttributes, slice} from './ops/slice'; import {parseSoftmaxAttributes, softmax} from './ops/softmax'; import {parseSplitAttributes, split} from './ops/split'; @@ -82,6 +82,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]], ['Greater', [binaryOps.greater]], ['GreaterOrEqual', [binaryOps.greaterOrEqual]], + ['HardSigmoid', [unaryOps.hardSigmoid, unaryOps.parseHardSigmoidAttributes]], ['InstanceNormalization', [instanceNorm]], ['LayerNormalization', [layerNorm]], ['LeakyRelu', [unaryOps.leakyRelu, unaryOps.parseAlphaAttributes]], @@ -115,7 +116,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new ['Sin', [unaryOps.sin]], ['Sinh', [unaryOps.sinh]], ['Slice', [slice, parseSliceAttributes]], - ['SkipLayerNormalization', [skipLayerNorm, parseSkipLayerNormAttributes]], + ['SkipLayerNormalization', [skipLayerNorm]], ['Split', [split, parseSplitAttributes]], ['Sqrt', [unaryOps.sqrt]], ['Softmax', [softmax, parseSoftmaxAttributes]], diff --git a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts index a2fda9f07d09f..509a722f4b52a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/skip-layer-norm.ts @@ -4,10 +4,10 @@ import {DataType} from '../../../wasm-common'; import {TensorView} from '../../tensor-view'; import {ShapeUtil} from '../../util'; -import {AttributeWithCacheKey, createAttributeWithCacheKey} from '../attribute-with-cache-key'; -import {ComputeContext, ProgramInfo} from '../types'; +import {AttributeWithCacheKey} from '../attribute-with-cache-key'; +import {ComputeContext, ProgramInfo, ProgramUniform} from '../types'; -import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType,} from './common'; +import {castToF32, fillVector, getMaxComponents, inputVariable, outputVariable, ShaderHelper, sumVector, tensorTypeToWsglStorageType, UniformsArrayType} from './common'; export interface SkipLayerNormAttributes extends AttributeWithCacheKey { epsilon: number; @@ -86,60 +86,74 @@ const createSkipLayerNormProgramInfo = const hasInputSkipBiasSumOutput = outputCount > 3; const components = getMaxComponents(hiddenSize); - const variables = [ - inputVariable('x', inputs[0].dataType, inputs[0].dims, components), - inputVariable('skip', inputs[1].dataType, inputs[1].dims, components), - inputVariable('gamma', inputs[2].dataType, inputs[2].dims, components), - ]; - if (hasBetaInput) { - variables.push(inputVariable('beta', inputs[3].dataType, inputs[3].dims, components)); - } - if (hasBiasInput) { - variables.push(inputVariable('bias', inputs[4].dataType, inputs[4].dims, components)); - } - variables.push(outputVariable('output', inputs[0].dataType, outputShape, components)); - if (hasMeanOutput) { - variables.push(outputVariable('meanOutput', DataType.float, meanInvStdDevDim)); - } - if (hasInvStdDevOutput) { - variables.push(outputVariable('invStdOutput', DataType.float, meanInvStdDevDim)); - } - if (hasInputSkipBiasSumOutput) { - variables.push(outputVariable('inputSkipBiasSum', inputs[0].dataType, outputShape, components)); - } - const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); - const getShaderSource = (shaderHelper: ShaderHelper) => ` - const hiddenSize: f32 = ${hiddenSize}; - const hiddenSizeVectorized: u32 = ${hiddenSize / components}; - const epsilon: f32 = ${attributes.epsilon}; - ${shaderHelper.declareVariables(...variables)} + const programUniforms: ProgramUniform[] = [ + {type: 'uint32', data: outputSize}, + {type: 'uint32', data: components}, + {type: 'uint32', data: hiddenSize}, + {type: 'float32', data: attributes.epsilon}, + ]; + const getShaderSource = (shaderHelper: ShaderHelper) => { + const uniformsArray: UniformsArrayType = [ + {name: 'output_size', type: 'u32'}, + {name: 'components', type: 'u32'}, + {name: 'hidden_size', type: 'u32'}, + {name: 'epsilon', type: 'f32'}, + ]; + const variables = [ + inputVariable('x', inputs[0].dataType, inputs[0].dims, components), + inputVariable('skip', inputs[1].dataType, inputs[1].dims, components), + inputVariable('gamma', inputs[2].dataType, inputs[2].dims, components), + ]; + if (hasBetaInput) { + variables.push(inputVariable('beta', inputs[3].dataType, inputs[3].dims, components)); + } + if (hasBiasInput) { + variables.push(inputVariable('bias', inputs[4].dataType, inputs[4].dims, components)); + } + variables.push(outputVariable('output', inputs[0].dataType, outputShape, components)); + if (hasMeanOutput) { + variables.push(outputVariable('mean_output', DataType.float, meanInvStdDevDim)); + } + if (hasInvStdDevOutput) { + variables.push(outputVariable('inv_std_output', DataType.float, meanInvStdDevDim)); + } + if (hasInputSkipBiasSumOutput) { + variables.push(outputVariable('input_skip_bias_sum', inputs[0].dataType, outputShape, components)); + } + const dataType = tensorTypeToWsglStorageType(inputs[0].dataType); + return ` + + ${shaderHelper.registerUniforms(uniformsArray).declareVariables(...variables)} ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize / hiddenSize)} - let offset = global_idx * hiddenSizeVectorized; + ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes('uniforms.output_size / uniforms.hidden_size')} + let hidden_size_vectorized: u32 = uniforms.hidden_size / uniforms.components; + let offset = global_idx * hidden_size_vectorized; var sum = ${fillVector('f32', components)}; var squareSum = ${fillVector('f32', components)}; - for (var i: u32 = 0; i < hiddenSizeVectorized; i++) { - let skipValue = skip[offset + i]; - let biasValue = ${hasBiasInput ? 'bias[i]' : '0.0'}; - let inputValue = x[offset + i]; - let value = inputValue + skipValue + biasValue; - ${hasInputSkipBiasSumOutput ? 'inputSkipBiasSum[offset + i] = value;' : ''} + for (var i: u32 = 0; i < hidden_size_vectorized; i++) { + let skip_value = skip[offset + i]; + let bias_value = ${hasBiasInput ? 'bias[i]' : '0.0'}; + let input_value = x[offset + i]; + let value = input_value + skip_value + bias_value; + ${hasInputSkipBiasSumOutput ? 'input_skip_bias_sum[offset + i] = value;' : ''} output[offset + i] = value; - let f32Value = ${castToF32(dataType, components, 'value')}; - sum += f32Value; - squareSum += f32Value * f32Value; + let f32_value = ${castToF32(dataType, components, 'value')}; + sum += f32_value; + squareSum += f32_value * f32_value; } - let mean = ${sumVector('sum', components)} / hiddenSize; - let invStdDev = inverseSqrt(${sumVector('squareSum', components)} / hiddenSize - mean * mean + epsilon); - ${hasMeanOutput ? 'meanOutput[global_idx] = mean;' : ''} - ${hasInvStdDevOutput ? 'invStdOutput[global_idx] = invStdDev;' : ''} - for (var i: u32 = 0; i < hiddenSizeVectorized; i++) { - output[offset + i] = (output[offset + i] - ${dataType}(mean)) * ${dataType}(invStdDev) * gamma[i] - + ${hasBetaInput ? 'beta[i]' : '0.0'}; + let mean = ${sumVector('sum', components)} / f32(uniforms.hidden_size); + let inv_std_dev = inverseSqrt(${ + sumVector('squareSum', components)} / f32(uniforms.hidden_size) - mean * mean + uniforms.epsilon); + ${hasMeanOutput ? 'mean_output[global_idx] = mean;' : ''} + ${hasInvStdDevOutput ? 'inv_std_output[global_idx] = inv_std_dev;' : ''} + for (var i: u32 = 0; i < hidden_size_vectorized; i++) { + output[offset + i] = (output[offset + i] - ${dataType}(mean)) * ${dataType}(inv_std_dev) * gamma[i] + ${ + hasBetaInput ? 'beta[i]' : '0.0'}; } }`; + }; const outputs = [{dims: outputShape, dataType: inputs[0].dataType}]; if (outputCount > 1) { outputs.push({dims: meanInvStdDevDim, dataType: DataType.float}); @@ -150,12 +164,14 @@ const createSkipLayerNormProgramInfo = if (outputCount > 3) { outputs.push({dims: inputShape, dataType: inputs[0].dataType}); } - return { name: 'SkipLayerNormalization', - shaderCache: {hint: attributes.cacheKey}, + shaderCache: { + hint: `${components};${hasMeanOutput};${hasInvStdDevOutput};${hasInputSkipBiasSumOutput}`, + inputDependencies: inputs.map((_input, _index) => 'type') + }, getShaderSource, - getRunData: () => ({outputs, dispatchGroup: {x: Math.ceil(outputSize / hiddenSize / 64)}}), + getRunData: () => ({outputs, dispatchGroup: {x: Math.ceil(outputSize / hiddenSize / 64)}, programUniforms}), }; }; @@ -178,8 +194,3 @@ export const skipLayerNorm = (context: ComputeContext, attributes: SkipLayerNorm context.compute( createSkipLayerNormProgramInfo(context.inputs, attributes, context.outputCount, isTraining), {outputs}); }; - -export const parseSkipLayerNormAttributes = (attributes: Record): SkipLayerNormAttributes => { - const epsilon = attributes.epsilon as number; - return createAttributeWithCacheKey({epsilon}); -}; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts index a25e7fe4229b4..82311d72e58b9 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts @@ -242,6 +242,26 @@ export const sigmoid = (context: ComputeContext): void => { context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sigmoid', a => `(1.0 / (1.0 + exp(-${a})))`)); }; +export interface HardSigmoidAttributes extends AttributeWithCacheKey { + readonly alpha: number; + readonly beta: number; +} + +export const parseHardSigmoidAttributes = (attributes: Record): HardSigmoidAttributes => + createAttributeWithCacheKey(attributes as { + alpha: number; + beta: number; + }); + +export const hardSigmoid = (context: ComputeContext, attributes: HardSigmoidAttributes): void => { + const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType); + context.compute(createElementwiseProgramInfo( + context.inputs[0], 'HardSigmoid', + a => `max(vec4<${dataType}>(0.0), min(vec4<${dataType}>(1.0), ${attributes.alpha} * ${a} + vec4<${dataType}>(${ + attributes.beta})))`, + undefined, attributes.cacheKey)); +}; + export const sin = (context: ComputeContext): void => { context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sin', 'sin')); }; diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts index 5821fac3c468f..8768643fa7257 100644 --- a/js/web/lib/wasm/wasm-core-impl.ts +++ b/js/web/lib/wasm/wasm-core-impl.ts @@ -488,8 +488,8 @@ export const run = async( } } + wasm.jsepOnRunStart?.(); let errorCode: number; - if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) { errorCode = await wasm._OrtRunWithBinding( sessionHandle, ioBindingState.handle, outputCount, outputValuesOffset, runOptionsHandle); diff --git a/js/web/test/data/ops/fused-conv.jsonc b/js/web/test/data/ops/fused-conv.jsonc index 812e9d7c2def0..ad1c0a72c11d3 100644 --- a/js/web/test/data/ops/fused-conv.jsonc +++ b/js/web/test/data/ops/fused-conv.jsonc @@ -108,5 +108,39 @@ ] } ] + }, + { + "name": "fused conv with clip", + "operator": "FusedConv", + "attributes": [ + { "name": "activation", "data": "Clip", "type": "string" }, + { "name": "kernel_shape", "data": [2, 2], "type": "ints" }, + { "name": "activation_params", "data": [400.0, 600.0], "type": "floats" } + ], + "opset": { "domain": "com.microsoft", "version": 1 }, + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [10, 20, 30, 40, 50, 60, 70, 80, 90], + "dims": [1, 1, 3, 3], + "type": "float32" + }, + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ], + "outputs": [ + { + "data": [400, 470, 600, 600], + "dims": [1, 1, 2, 2], + "type": "float32" + } + ] + } + ] } ] diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc index 033b3b3f4b0f5..373b3c645df57 100644 --- a/js/web/test/suite-test-list.jsonc +++ b/js/web/test/suite-test-list.jsonc @@ -597,9 +597,9 @@ // // "test_hardmax_example", // // "test_hardmax_negative_axis", // // "test_hardmax_one_hot", - // // "test_hardsigmoid_default", - // // "test_hardsigmoid_example", - // // "test_hardsigmoid", + "test_hardsigmoid_default", + "test_hardsigmoid_example", + "test_hardsigmoid", // // "test_hardswish_expanded", // // "test_hardswish", "test_if", diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index da489a6901512..8afeb874750b4 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -99,10 +99,15 @@ struct GroupQueryAttentionParameters { bool is_unidirectional; // causal int local_window_size; bool kv_share_buffer; + bool is_packed_qkv; bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor + bool do_rotary; + bool rotary_interleaved; float scale; AttentionQkvFormat qkv_format; AttentionQkvFormat past_kv_format; + int zeros_count; + int* zero_ptr; }; namespace attention { diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc index 3962486d5b5eb..bb6885c3216bc 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc @@ -123,8 +123,20 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) { logits_processor = logits_processor_tensor ? static_cast(*logits_processor_tensor->Data()) : 0; ORT_ENFORCE(logits_processor >= 0, "logits_processor shall be a non-negative integer, got ", logits_processor); -} + if (this->model_type == IGenerationParameters::kModelTypeWhisper) { + auto* temperature_tensor = context->Input(14); + if (temperature_tensor) { + if (temperature_tensor->IsDataType()) { + temperature = *temperature_tensor->Data(); + } else { + temperature = static_cast(*temperature_tensor->Data()); + } + } else { + temperature = 1.0f; + } + } +} void BeamSearchParameters::SetSubgraphParameters(int vocabulary_size, int heads, int hidden_size_per_head, int layers) { // Override vocab_size using the inferred shape from the decoder subgraph ONLY IF // the vocab_size hasn't been explicitly specified by the user (as an attribute of BeamSearch) diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc index d6eb87228bb4a..2c296bf4f8483 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc @@ -355,13 +355,15 @@ bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, in Status mha_fwd_kvcache(const cudaDeviceProp& dprops, cudaStream_t stream, void* q, // batch_size x seqlen_q x num_heads x head_size - void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x head_size - void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x head_size - void* k, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size - void* v, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size + void* kcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size + void* vcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size + void* k_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size + void* v_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size void* out, // batch_size x seqlen_q x num_heads x head_size void* softmax_lse, // batch_size x num_heads x seqlen_q void* seqlens_k_, // batch_size + void* rotary_cos, // seqlen_ro x (rotary_dim / 2) + void* rotary_sin, // seqlen_ro x (rotary_dim / 2) int batch_size, int num_heads, int num_heads_k, @@ -376,16 +378,15 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, int num_splits, void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded - int local_window_size) { - // if (seqlen_q == 1) { - // is_causal = false; - // } // causal=true is the same as causal=false in this case - + int local_window_size, + bool is_rotary_interleaved, + bool is_packed_qkv) { auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; const int head_size_rounded = round_multiple(head_size, 32); const int seqlen_q_rounded = round_multiple(seqlen_q, 128); const int seqlen_k_rounded = round_multiple(seqlen_k, 128); + // In kv-cache case, seqlen_k_max as kv sequence length Flash_fwd_params params; set_params_fprop(params, batch_size, @@ -406,15 +407,24 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, is_causal ? 0 : -1); params.dprops = &dprops; - if (k != nullptr && v != nullptr) { + if (k_new != nullptr && v_new != nullptr) { params.seqlen_knew = seqlen_k_new; - params.knew_ptr = k; - params.vnew_ptr = v; + params.knew_ptr = k_new; + params.vnew_ptr = v_new; // All stride are in elements, not bytes. - params.knew_batch_stride = seqlen_k_new * num_heads_k * head_size; - params.vnew_batch_stride = seqlen_k_new * num_heads_k * head_size; - params.knew_row_stride = num_heads_k * head_size; - params.vnew_row_stride = num_heads_k * head_size; + if (is_packed_qkv) { + params.q_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size); + params.q_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size); + params.knew_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size); + params.vnew_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size); + params.knew_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size); + params.vnew_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size); + } else { + params.knew_batch_stride = seqlen_k_new * num_heads_k * head_size; + params.vnew_batch_stride = seqlen_k_new * num_heads_k * head_size; + params.knew_row_stride = num_heads_k * head_size; + params.vnew_row_stride = num_heads_k * head_size; + } params.knew_head_stride = head_size; params.vnew_head_stride = head_size; } else { @@ -434,6 +444,13 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, params.cu_seqlens_k = static_cast(seqlens_k_); } + if (rotary_cos != nullptr) { + params.rotary_cos_ptr = rotary_cos; + params.rotary_sin_ptr = rotary_sin; + params.is_rotary_interleaved = is_rotary_interleaved; + params.rotary_dim = (head_size / 16) * 16; + } + params.num_splits = num_splits; if (params.num_splits > 1 && softmax_lse_accum != nullptr && out_accum != nullptr) { params.softmax_lseaccum_ptr = softmax_lse_accum; @@ -444,7 +461,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, } // Only split kernel supports appending to KV cache - run_mha_fwd(params, stream, /*force_split_kernel=*/k != nullptr); + run_mha_fwd(params, stream, /*force_split_kernel=*/k_new != nullptr); return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h index 3d75d6834b8e0..387d1cf9d84fe 100644 --- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h +++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h @@ -87,6 +87,8 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, void* out, // batch_size x seqlen_q x num_heads x head_size void* softmax_lse, // batch_size x num_heads x seqlen_q void* seqlens_k_, // batch_size + void* rotary_sin, // seqlen_ro x (rotary_dim / 2) + void* rotary_cos, // seqlen_ro x (rotary_dim / 2) int batch_size, int num_heads, int num_heads_k, @@ -101,7 +103,9 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops, int num_splits = 0, void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded - int local_window_size = -1); + int local_window_size = -1, + bool is_rotary_interleaved = false, + bool is_packed_qkv = false); size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc index fd6fb79742cac..fe56f84f0a886 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc @@ -47,6 +47,8 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) kv_num_heads_ = static_cast(kv_num_heads); is_past_bsnh_ = false; // info.GetAttrOrDefault("is_past_bsnh", 1) == 1; local_window_size_ = static_cast(info.GetAttrOrDefault("local_window_size", -1)); + do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1; + rotary_interleaved_ = info.GetAttrOrDefault("rotary_interleaved", 0) == 1; scale_ = info.GetAttrOrDefault("scale", 0.0f); #if USE_FLASH_ATTENTION @@ -62,6 +64,9 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info) #else disable_memory_efficient_attention_ = true; #endif + if (!disable_flash_attention_) { + zeros_ = this->GetScratchBuffer(kZerosCount, nullptr); + } } template @@ -73,6 +78,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { const Tensor* past_value = context->Input(4); const Tensor* seqlens_k = context->Input(5); const Tensor* total_seqlen = context->Input(6); + const Tensor* cos_cache = context->Input(7); + const Tensor* sin_cache = context->Input(8); auto& device_prop = GetDeviceProp(); GroupQueryAttentionParameters parameters; @@ -84,6 +91,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { value, past_key, past_value, + cos_cache, + sin_cache, ¶meters, num_heads_, kv_num_heads_, @@ -93,7 +102,13 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { scale_, device_prop.maxThreadsPerBlock)); parameters.local_window_size = local_window_size_; + parameters.is_unidirectional = is_unidirectional_; + parameters.zeros_count = kZerosCount; + parameters.zero_ptr = zeros_.get(); + // parameters.left_padding = left_padding_; int sequence_length = parameters.sequence_length; + parameters.do_rotary = do_rotary_; + parameters.rotary_interleaved = rotary_interleaved_; TensorShapeVector output_shape(3); output_shape[0] = static_cast(parameters.batch_size); @@ -139,6 +154,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { !use_flash_attention && !disable_memory_efficient_attention_ && local_window_size_ == -1 && + do_rotary_ == false && + key != nullptr && (parameters.head_size & 7) == 0 && parameters.sequence_length <= parameters.seqlen_past_kv_cache + parameters.sequence_length && (sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) && @@ -182,8 +199,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { Tensor* present_value = context->Output(2, present_shape); data.query = reinterpret_cast(query->Data()); - data.key = reinterpret_cast(key->Data()); - data.value = reinterpret_cast(value->Data()); + data.key = key == nullptr ? nullptr : reinterpret_cast(key->Data()); + data.value = value == nullptr ? nullptr : reinterpret_cast(value->Data()); data.past_key = (nullptr == past_key) ? nullptr : reinterpret_cast(past_key->Data()); data.past_value = (nullptr == past_value) ? nullptr : reinterpret_cast(past_value->Data()); data.output = reinterpret_cast(output->MutableData()); @@ -229,6 +246,11 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const { if (fmha_buffer != nullptr) { data.fmha_buffer = reinterpret_cast(fmha_buffer.get()); } + // Rotary + if (parameters.do_rotary) { + data.cos_cache = reinterpret_cast(cos_cache->Data()); + data.sin_cache = reinterpret_cast(sin_cache->Data()); + } cublasHandle_t cublas = GetCublasHandle(context); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h index 54a8127e29e7b..15573ece166fc 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h @@ -23,10 +23,15 @@ class GroupQueryAttention final : public CudaKernel { int num_heads_; // number of attention heads int kv_num_heads_; // different for k and v for group query attention int local_window_size_; + bool is_unidirectional_; bool is_past_bsnh_; + bool do_rotary_; + bool rotary_interleaved_; float scale_; bool disable_flash_attention_; bool disable_memory_efficient_attention_; + static constexpr int kZerosCount = 256; // In prompt case we create a zero buffer of size 256 for seqlen (assume batch_size <= 256) + IAllocatorUniquePtr zeros_; }; } // namespace cuda diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h index 2cb9955807f26..853e1a710cb24 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h @@ -16,6 +16,8 @@ Status CheckInputs(const Tensor* query, const Tensor* value, const Tensor* past_key, const Tensor* past_value, + const Tensor* cos_cache, + const Tensor* sin_cache, void* parameters, int num_heads, int kv_num_heads, @@ -24,19 +26,18 @@ Status CheckInputs(const Tensor* query, bool is_past_bsnh, float scale) { // Note: Here S* is past_cache_sequence_length, S- is past_sequence_length, S+ is sequence_length - // past_key : (B, N_k, S*, H) or (B, N_k, S-, H) - // past_value : (B, N_k, S*, H) or (B, N_k, S-, H) + // past_key : (B, N_k, S*, H) or (B, N_k, S-, H) or nullptr + // past_value : (B, N_k, S*, H) or (B, N_k, S-, H) or nullptr // no packing for q/k/v: - // query (Q) : (B, S, D) - // key (K) : (B, S, D_kv) - // value (V) : (B, S, D_kv) + // query (Q) : (B, S, D) or (B, S, (D_q + 2 D_kv)) + // key (K) : (B, S, D_kv) or nullptr + // value (V) : (B, S, D_kv) or nullptr ORT_UNUSED_PARAMETER(value); AttentionQkvFormat qkv_format = Q_K_V_BSNH; AttentionQkvFormat past_kv_format = is_past_bsnh ? Q_K_V_BSNH : Q_K_V_BNSH; - + const bool is_packed_qkv = key == nullptr; const auto& query_dims = query->Shape().GetDims(); - const auto& key_dims = key->Shape().GetDims(); if (query_dims.size() != 3) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ", @@ -46,10 +47,69 @@ Status CheckInputs(const Tensor* query, int batch_size = static_cast(query_dims[0]); int sequence_length = static_cast(query_dims[1]); int q_hidden_size = static_cast(query_dims[2]); - int head_size = static_cast(q_hidden_size) / num_heads; + int head_size = 0; + + if (num_heads % kv_num_heads != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ", + num_heads % kv_num_heads); + } - int kv_hidden_size = static_cast(key_dims[2]); + int kv_hidden_size = 0; + // Check key and value when not packed + if (!is_packed_qkv) { + head_size = static_cast(q_hidden_size) / num_heads; + if (head_size % 8 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "head_size must be a multiple of 8. Got head_size % 8 == ", + head_size % 8); + } + if (value == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv."); + } + const auto& key_dims = key->Shape().GetDims(); + if (key_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ", + key_dims.size()); + } else if (query_dims[0] != key_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'key' shall have same dim 0 (batch size)"); + } else if (query_dims[1] != key_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'key' shall have same dim 1 (sequence length)"); + } + kv_hidden_size = static_cast(key_dims[2]); + const auto& value_dims = value->Shape().GetDims(); + if (value_dims.size() != 3) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ", + value_dims.size()); + } else if (query_dims[0] != value_dims[0]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'value' shall have same dim 0 (batch size)"); + } else if (query_dims[1] != value_dims[1]) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'query' and 'value' shall have same dim 1 (sequence length)"); + } else if (value_dims[2] != kv_hidden_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key."); + } + } else { + // Check packed qkv + head_size = static_cast(q_hidden_size) / (num_heads + 2 * kv_num_heads); + if (head_size % 8 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "head_size must be a multiple of 8. Got head_size % 8 == ", + head_size % 8); + } + if (value != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv."); + } + q_hidden_size = head_size * num_heads; + kv_hidden_size = head_size * kv_num_heads; + } + // Check past-present KV int32_t past_sequence_length = 0; if (past_key != nullptr && past_value != nullptr) { const auto& past_key_dims = past_key->Shape().GetDims(); @@ -130,41 +190,6 @@ Status CheckInputs(const Tensor* query, "Input 'past_key' and 'past_value' shall be both present or both absent."); } - if (key_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ", - key_dims.size()); - } - if (query_dims[0] != key_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'key' shall have same dim 0 (batch size)"); - } - - if (num_heads % kv_num_heads != 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ", - num_heads % kv_num_heads); - } - - const auto& value_dims = value->Shape().GetDims(); - if (value_dims.size() != 3) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ", - value_dims.size()); - } - - if (query_dims[0] != value_dims[0]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query' and 'value' shall have same dim 0 (batch_size)"); - } - - if (static_cast(sequence_length) != value_dims[1]) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Input 'query,' 'key,' and 'value' shall have the same dim 1 (sequence_length)"); - } - - if (value_dims[2] != kv_hidden_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key."); - } - // Check seqlens_k tensor (holding past seqlen for token gen) const auto& seqlens_dim = seqlens_k->Shape().GetDims(); if (seqlens_dim.size() != 1 && seqlens_dim[0] != batch_size) { @@ -180,6 +205,36 @@ Status CheckInputs(const Tensor* query, int total_sequence_length = *((*total_seqlen).template Data()); int present_sequence_length = std::max(total_sequence_length, past_sequence_length); + if (cos_cache != nullptr && sin_cache != nullptr) { + const auto& cos_dims = cos_cache->Shape().GetDims(); + const auto& sin_dims = sin_cache->Shape().GetDims(); + + if (head_size % 16 != 0) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "head_size shall be a multiple of 16. Got head_size % 16 == ", + head_size % 16); + } + if (cos_dims[0] != present_sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "cos_cache dimension 0 must be of present_sequence_length."); + } + if (sin_dims[0] != present_sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "sin_cache dimension 0 must be of present_sequence_length."); + } + if (cos_dims[1] != (head_size / 16) * 8) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "cos_cache dimension 1 must be <= head_size / 2 and a multiple of 8."); + } + if (sin_dims[1] != (head_size / 16) * 8) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "sin_cache dimension 1 must be <= head_size / 2 and a multiple of 8."); + } + } else if (cos_cache != nullptr || sin_cache != nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'cos_cache' and 'sin_cache' shall be both present or both absent."); + } + bool is_prompt = sequence_length != 1; if (parameters != nullptr) { @@ -190,9 +245,10 @@ Status CheckInputs(const Tensor* query, output_parameters->seqlen_present_kv_cache = present_sequence_length; // max sequence length of present kv tensors output_parameters->hidden_size = q_hidden_size; output_parameters->num_heads = num_heads; - output_parameters->head_size = q_hidden_size / num_heads; + output_parameters->head_size = head_size; output_parameters->kv_hidden_size = kv_hidden_size; output_parameters->kv_num_heads = kv_num_heads; + output_parameters->is_packed_qkv = is_packed_qkv; output_parameters->is_unidirectional = true; output_parameters->is_prompt = is_prompt; output_parameters->scale = scale; @@ -208,6 +264,8 @@ Status CheckInputs(const Tensor* query, const Tensor* value, const Tensor* past_key, const Tensor* past_value, + const Tensor* cos_cache, + const Tensor* sin_cache, void* parameters, int num_heads, int kv_num_heads, @@ -220,7 +278,7 @@ Status CheckInputs(const Tensor* query, return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block); } - return CheckInputs(query, key, value, past_key, past_value, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, is_past_bsnh, scale); + return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, is_past_bsnh, scale); } } // namespace group_query_attention_helper diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu index 5b0f5d0cfe601..d88e9a49fb5ee 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu @@ -151,9 +151,10 @@ template Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameters, GroupQueryAttentionData& data, cudaStream_t stream, - const int max_threads_per_block) { + const int max_threads_per_block, + const bool past_only = false) { const int batch_size = parameters.batch_size; - const int kv_sequence_length = parameters.sequence_length; + const int kv_sequence_length = past_only ? 0 : parameters.sequence_length; const int past_sequence_length = parameters.seqlen_past_kv_cache; const int present_sequence_length = parameters.seqlen_present_kv_cache; const int kv_num_heads = parameters.kv_num_heads; @@ -441,7 +442,6 @@ Status LaunchUngroup(contrib::GroupQueryAttentionParameters& parameters, return CUDA_CALL(cudaGetLastError()); } - __global__ void PastToTotalSeqlen(int32_t* seqlens_k, int32_t* seqlens_k_buff, const int add_seqlen) { @@ -451,7 +451,7 @@ __global__ void PastToTotalSeqlen(int32_t* seqlens_k, // Convert Past to Total sequence length tensor Status LaunchGetSeqlenBuff(contrib::GroupQueryAttentionParameters& parameters, int32_t* seqlens_k, int32_t* seqlens_k_buff, bool is_total, cudaStream_t stream, - const int threads_per_block) { + const int threads_per_block) { if (parameters.is_prompt) { return Status::OK(); } @@ -482,91 +482,63 @@ Status FlashAttention( const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int kv_sequence_length = parameters.sequence_length; - const int present_sequence_length = parameters.seqlen_present_kv_cache; const int num_heads = parameters.num_heads; const int kv_num_heads = parameters.kv_num_heads; const int head_size = parameters.head_size; AttentionQkvFormat past_kv_format = parameters.past_kv_format; - - void* query = reinterpret_cast(const_cast(data.query)); - void* key = reinterpret_cast(const_cast(data.key)); - void* value = reinterpret_cast(const_cast(data.value)); - bool is_causal = true; - bool is_bf16 = std::is_same::value; - // Note: seqlens_k is past sequence length for flash - if (parameters.is_prompt) { - // Launch kernel to copy seqlen - constexpr int thr_per_blk = 256; - int blk_in_grid = (batch_size + thr_per_blk -1) / thr_per_blk; - repeat_seqlen<<>>(data.seqlens_k_total, parameters.sequence_length, batch_size); - } - - void* seqlens_k = reinterpret_cast(data.seqlens_k); - - if (parameters.kv_share_buffer) { - // Share buffer case - if (data.past_key == nullptr || data.past_key != data.present_key) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Past and present kv shall share the same tensor when kv_share_buffer is on."); - } - - if (parameters.is_prompt) { - ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(parameters, data, stream, max_threads_per_block)); - key = nullptr; - value = nullptr; - seqlens_k = reinterpret_cast(data.seqlens_k_total); - } - - void* present_key = reinterpret_cast(const_cast(data.present_key)); - void* present_value = reinterpret_cast(const_cast(data.present_value)); - - DUMP_TENSOR_INIT(); - DUMP_TENSOR("seqlens_k", reinterpret_cast(seqlens_k), batch_size, 1); + void* query = reinterpret_cast(const_cast(data.query)); + void* key; + void* value; - bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; - ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( - device_prop, stream, query, present_key, present_value, key, value, data.output, reinterpret_cast(data.softmax_lse), - seqlens_k, batch_size, num_heads, kv_num_heads, - head_size, sequence_length, present_sequence_length, kv_sequence_length, - scale, is_causal, is_bf16, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), - reinterpret_cast(data.out_accum), parameters.local_window_size)); + if (!parameters.is_packed_qkv) { + key = reinterpret_cast(const_cast(data.key)); + value = reinterpret_cast(const_cast(data.value)); } else { - // Not share buffer case - // Note that Flash Attention kv-caching operates in place on a buffer... therefore this path is inneficient - if (data.past_key != nullptr && data.past_key == data.present_key) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, - "Past and present kv share the same tensor but kv_share_buffer is not on."); - } - - ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block)); + const size_t key_offset = static_cast(num_heads * head_size); + const size_t value_offset = static_cast(kv_num_heads * head_size); + key = reinterpret_cast(query) + key_offset; + value = reinterpret_cast(key) + value_offset; + } - if (!parameters.is_prompt) { - ORT_RETURN_IF_ERROR(LaunchGetSeqlenBuff(parameters, data.seqlens_k, data.seqlens_k_total, true, stream, 256)); + void* seqlens_k = reinterpret_cast(data.seqlens_k); + if (parameters.is_prompt) { + // set seqlens_k to zeros... flash api uses seqlens_k to indicate where to append key and value + // user should use seqlens_k to index into output to get new tokens + if (batch_size <= parameters.zeros_count) { + seqlens_k = parameters.zero_ptr; + } else { + // Launch kernel to create larger seqlen tensor when batch_size > 256 + constexpr int thr_per_blk = 256; + int blk_in_grid = (batch_size + thr_per_blk - 1) / thr_per_blk; + repeat_seqlen<<>>(data.seqlens_k_total, 0, batch_size); + seqlens_k = data.seqlens_k_total; } - - seqlens_k = reinterpret_cast(data.seqlens_k_total); - - void* present_key = reinterpret_cast(const_cast(data.present_key)); - void* present_value = reinterpret_cast(const_cast(data.present_value)); - - DUMP_TENSOR_INIT(); - DUMP_TENSOR("seqlens_k", reinterpret_cast(seqlens_k), batch_size, 1); - DUMP_TENSOR("Q", data.query, batch_size, sequence_length, num_heads, head_size); - DUMP_TENSOR("K", data.present_key, batch_size, kv_num_heads, present_sequence_length, head_size); - DUMP_TENSOR("V", data.present_value, batch_size, kv_num_heads, present_sequence_length, head_size); - - bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; - ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( - device_prop, stream, query, present_key, present_value, nullptr, nullptr, data.output, reinterpret_cast(data.softmax_lse), - seqlens_k, batch_size, num_heads, kv_num_heads, - head_size, sequence_length, present_sequence_length, 0, - scale, is_causal, is_bf16, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), - reinterpret_cast(data.out_accum), parameters.local_window_size)); + } else if (!parameters.kv_share_buffer) { // copy past kv to present kv + ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block, true)); } + void* present_key = reinterpret_cast(const_cast(data.present_key)); + void* present_value = reinterpret_cast(const_cast(data.present_value)); + void* cos_cache = reinterpret_cast(const_cast(data.cos_cache)); + void* sin_cache = reinterpret_cast(const_cast(data.sin_cache)); + + bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH; + ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache( + device_prop, stream, query, present_key, present_value, key, value, data.output, + reinterpret_cast(data.softmax_lse), seqlens_k, cos_cache, sin_cache, + batch_size, num_heads, kv_num_heads, head_size, sequence_length, + parameters.seqlen_present_kv_cache, kv_sequence_length, + scale, is_causal, is_bf16, past_bsnh, parameters.num_splits, reinterpret_cast(data.softmax_lse_accum), + reinterpret_cast(data.out_accum), parameters.local_window_size, parameters.rotary_interleaved, + parameters.is_packed_qkv)); + + // if (parameters.left_padding && parameters.is_prompt) { + // ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock)); + // } + DUMP_TENSOR_INIT(); DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size); @@ -672,7 +644,6 @@ Status EfficientAttention( p.has_custom_right_padding = true; run_memory_efficient_attention(p); - DUMP_TENSOR_INIT(); DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size); return Status::OK(); diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h index de32d7ea93163..1bf91f9c875eb 100644 --- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h @@ -21,6 +21,8 @@ struct GroupQueryAttentionData { const T* past_key = nullptr; const T* past_value = nullptr; int* seqlens_k = nullptr; + const T* cos_cache = nullptr; + const T* sin_cache = nullptr; // Flash buffers T* softmax_lse = nullptr; T* softmax_lse_accum = nullptr; diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc index 2a90e4911f286..08cbb145a6f65 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc @@ -49,6 +49,7 @@ ONNX_OPERATOR_KERNEL_EX( .InputMemoryType(OrtMemTypeCPUInput, 9) // 'attention_mask' needs to be on CPU .InputMemoryType(OrtMemTypeCPUInput, 10) // 'decoder_input_ids' needs to be on CPU .InputMemoryType(OrtMemTypeCPUInput, 11) // 'logits_processor' needs to be on CPU + .InputMemoryType(OrtMemTypeCPUInput, 14) // 'temperature' needs to be on CPU .OutputMemoryType(OrtMemTypeCPUOutput, 0) // 'sequences' output on CPU .OutputMemoryType(OrtMemTypeCPUOutput, 1) // 'sequences_scores' output on CPU .TypeConstraint("T", {DataTypeImpl::GetTensorType(), diff --git a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu index 6f98312e4067d..09e7d61b71db9 100644 --- a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu @@ -68,6 +68,7 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) scale_ = info.GetAttrOrDefault("scale", 0.0f); past_present_share_buffer_ = info.GetAttrOrDefault("past_present_share_buffer", 0LL) != 0LL; + is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; using HipT = typename ToHipType::MappedType; using AttentionTunableOp = GemmSoftmaxGemmPermuteTunableOp; @@ -121,8 +122,8 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { query, key, value, bias, key_padding_mask, relative_position_bias, past_key, past_value, past_seq_len, - &attn, - num_heads_, mask_filter_value_, scale_, + &attn, num_heads_, + mask_filter_value_, scale_, false, /*is_unidirectional_*/ past_present_share_buffer_, false, device_prop.maxThreadsPerBlock)); if (attn_type_ == kDecoderMaskedMultiHeadAttention && attn.sequence_length != 1) { diff --git a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h index 84d8b76bbfebe..1d676d7a7bcac 100644 --- a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h @@ -25,6 +25,7 @@ class MultiHeadAttention final : public RocmKernel { float mask_filter_value_; float scale_; bool past_present_share_buffer_{false}; + bool is_unidirectional_{false}; // type-erased GemmSoftmaxGemmPermuteTunableOp, the reason for this is: // 1. We don't want to include the cuh file where GemmSoftmaxGemmPermuteTunableOp is defined. diff --git a/onnxruntime/core/common/cpuid_info.cc b/onnxruntime/core/common/cpuid_info.cc index fcf9c2b03dea5..711fd595e90fd 100644 --- a/onnxruntime/core/common/cpuid_info.cc +++ b/onnxruntime/core/common/cpuid_info.cc @@ -30,6 +30,10 @@ #define HWCAP2_SVEI8MM (1 << 9) #endif +#ifndef HWCAP2_BF16 +#define HWCAP2_BF16 (1 << 14) +#endif + #endif // ARM #endif // Linux @@ -148,6 +152,7 @@ void CPUIDInfo::ArmLinuxInit() { has_fp16_ = cpuinfo_has_arm_neon_fp16_arith(); has_arm_neon_i8mm_ = cpuinfo_has_arm_i8mm(); has_arm_sve_i8mm_ = cpuinfo_has_arm_sve() && cpuinfo_has_arm_i8mm(); + has_arm_neon_bf16_ = cpuinfo_has_arm_neon_bf16(); const uint32_t core_cnt = cpuinfo_get_cores_count(); core_uarchs_.resize(core_cnt, cpuinfo_uarch_unknown); @@ -177,6 +182,7 @@ void CPUIDInfo::ArmLinuxInit() { has_arm_neon_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_I8MM) != 0); has_arm_sve_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_SVEI8MM) != 0); + has_arm_neon_bf16_ = ((getauxval(AT_HWCAP2) & HWCAP2_BF16) != 0); #endif } @@ -278,6 +284,7 @@ void CPUIDInfo::ArmWindowsInit() { /* TODO: implement them when hw+sw is available for testing these features */ has_arm_neon_i8mm_ = false; has_arm_sve_i8mm_ = false; + has_arm_neon_bf16_ = false; } #endif /* (arm or arm64) and windows */ diff --git a/onnxruntime/core/common/cpuid_info.h b/onnxruntime/core/common/cpuid_info.h index a15c75104b83a..2f8041e39f680 100644 --- a/onnxruntime/core/common/cpuid_info.h +++ b/onnxruntime/core/common/cpuid_info.h @@ -30,6 +30,7 @@ class CPUIDInfo { bool HasArmNeonDot() const { return has_arm_neon_dot_; } bool HasArmNeon_I8MM() const { return has_arm_neon_i8mm_; } bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; } + bool HasArmNeon_BF16() const { return has_arm_neon_bf16_; } uint32_t GetCurrentCoreIdx() const; @@ -125,6 +126,7 @@ class CPUIDInfo { bool has_fp16_{false}; bool has_arm_neon_i8mm_{false}; bool has_arm_sve_i8mm_{false}; + bool has_arm_neon_bf16_{false}; #ifdef CPUIDINFO_ARCH_X86 diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 7f34647f1faef..8583474a1e391 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -259,13 +259,13 @@ void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& *output_shape.add_dim() = query_dims[1]; *output_shape.add_dim() = query_dims[2]; updateOutputShape(ctx, 0, output_shape); - } else { - fail_shape_inference("Missing input 2 (value)"); } } if (ctx.getNumOutputs() > 1) { // has present output if (hasInputShape(ctx, past_key_index)) { + // auto& query_shape = getInputShape(ctx, 0); + // auto& query_dims = query_shape.dim(); auto& past_shape = getInputShape(ctx, past_key_index); auto& past_dims = past_shape.dim(); if (past_dims.size() != 4) { @@ -273,8 +273,7 @@ void GroupQueryAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& } ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, past_key_index, 1); ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, static_cast(past_key_index) + 1, 2); - ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, past_key_index, 1); - ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, static_cast(past_key_index) + 1, 2); + // TODO(aciddelgado): propagate output shapes depending if kv-share buffer is on or not } } } @@ -1015,18 +1014,29 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "left_window_size for local attention (like Mistral). Default value is -1 meaning unused.", AttributeProto::INT, static_cast(-1)) + .Attr("do_rotary", + "Whether to use rotary position embedding. Default value is 0.", + AttributeProto::INT, + OPTIONAL_VALUE) + .Attr("rotary_interleaved", + "Rotate using interleaved pattern. Default value is 0 (False).", + AttributeProto::INT, + OPTIONAL_VALUE) .Input(0, "query", - "Query with shape (batch_size, sequence_length, hidden_size)", + "Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape" + "(batch_size, sequence_length, d) where d is (num_heads * head_size + 2 * kv_num_heads * head_size).", "T") .Input(1, "key", "Key with shape (batch_size, kv_sequence_length, kv_hidden_size) ", - "T") + "T", + OpSchema::Optional) .Input(2, "value", "Value with shape (batch_size, kv_sequence_length, kv_hidden_size)", - "T") + "T", + OpSchema::Optional) .Input(3, "past_key", "past state key with support for format BNSH. When past_key uses same tensor as present_key" @@ -1047,6 +1057,16 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "total_sequence_length", "Scalar tensor of total sequence length (past + new).", "M") + .Input(7, + "cos_cache", + "2D tensor with shape (max_sequence_length, head_size / 2).", + "T", + OpSchema::Optional) + .Input(8, + "sin_cache", + "2D tensor with shape (max_sequence_length, head_size / 2).", + "T", + OpSchema::Optional) .Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 982e8fd834b76..27c968a59eb91 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -1231,6 +1231,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(WhisperBeamSearch, 1, "In such case, we should remove this from the tail of the decoder_input_ids, and put it here. ids < 0 in it (for multiple batch) " "are treated as stop of the extra_decoding_ids for corresponding batch.", "I", OpSchema::Optional) + .Input(14, "temperature", "Temperature value to apply to logits processing during this execution's decoding. Shape is (1)", "T", OpSchema::Optional) .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)", "I") .Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional) .Output(2, "scores", diff --git a/onnxruntime/core/mlas/inc/mlas.h b/onnxruntime/core/mlas/inc/mlas.h index bdd4dba521eba..ce7838556fbf0 100644 --- a/onnxruntime/core/mlas/inc/mlas.h +++ b/onnxruntime/core/mlas/inc/mlas.h @@ -1614,6 +1614,119 @@ MlasHalfGemmConvertPackB( void* PackedB ); +#if defined(__aarch64__) && defined(__linux__) +/** + * @brief Whether current CPU supports Bfloat16(bf16) acceleration. + */ +bool MLASCALL +MlasBf16AccelerationSupported(); + +/** + * @brief Interface for bf16 gemm post processors. + * + * Example implementation of this interface includes activations, + * conversion from single precision to precision, etc. + * + * SBGEMM is computed tile by tile. When a tile of result matrix + * is produced, the method Process() is called to process this tile. + * Parameters of this method describe the location and shape of the + * tile. + */ +class MLAS_SBGEMM_POSTPROCESSOR +{ + public: + virtual void Process(float*, /**< the address of matrix to process */ + size_t, /**< the start row index of matrix */ + size_t, /**< the start col index of matrix */ + size_t, /**< the element count per row to process */ + size_t, /**< the element count per col to process */ + size_t /**< the leading dimension of matrix */ + ) const = 0; + + virtual ~MLAS_SBGEMM_POSTPROCESSOR() {} +}; + +/** + * @brief bfloat16 precision activation functions, with optional sum tensor. + * Supplied sum tensor must be the same layout as the GEMM output tensor. + * And the supplied sum tensor will be added to the tensor before activation. + */ +class MLAS_SBGEMM_ACTIVATION_PROCESSOR : public MLAS_SBGEMM_POSTPROCESSOR +{ + public: + MLAS_SBGEMM_ACTIVATION_PROCESSOR(const MLAS_ACTIVATION& Activation, const float* SumBuf = nullptr) + : Activation_(Activation), SumBuf_(SumBuf) + { + } + + void Process(float* C, size_t StartM, size_t StartN, size_t CountM, size_t CountN, size_t ldc) + const override; + + private: + const MLAS_ACTIVATION& Activation_; + const float* SumBuf_; +}; + +/** + * @brief Data parameters for bfloat16 precision GEMM routine + * All except C are [in] parameters + */ +struct MLAS_SBGEMM_DATA_PARAMS { + const void* A = nullptr; /**< address of A */ + const void* B = nullptr; /**< address of B */ + const float* Bias = nullptr; /**< address of Bias, vector size N */ + float* C = nullptr; /**< address of result matrix */ + size_t lda = 0; /**< leading dimension of A */ + size_t ldb = 0; /**< leading dimension of B, 0 when B is pre-packed*/ + size_t ldc = 0; /**< leading dimension of C*/ + const MLAS_SBGEMM_POSTPROCESSOR* OutputProcessor = nullptr; + bool AIsfp32 = false; /**< matrix A is fp32, needs to be converted to bf16*/ + bool BIsfp32 = false; /**< matrix B is fp32, needs to be converted to bf16*/ +}; + +/** + * @brief Bfloat16 precision Batched GEMM: C = A * B + Bias + * Either B can be either fp32 or bf16 + * + * Note: We only support uniform batching, so shapes and types of the + * input must be same across all parameter blocks. + * + * @param[in] M row size of matrix A and C + * @param[in] N column size of matrix B and C + * @param[in] K column size of matrix A and row size of matrix B + * @param[in] BatchN number of batches + * @param[inout] DataParams An array (size BatchN) of parameter blocks + * @param[in] ThreadPool + * @return + */ +void MLASCALL +MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t BatchN, const MLAS_SBGEMM_DATA_PARAMS* DataParams, MLAS_THREADPOOL* ThreadPool = nullptr); + +/** + * @brief For bfloat16 precision GEMM, returns size of the + * packing buffer needed for right hand side + * @param[in] N Number of columns + * @param[in] K Number of rows + * @return size of the packing buffer, + * 0 if operation not supported + */ +size_t MLASCALL +MlasSBGemmPackBSize(size_t N, size_t K); + +/** + * @brief For bfloat16 precision GEMM, convert the float matrix B + * to blfoat16 precision and pack it into a packing buffer + * + * @param[in] N Number of columns + * @param[in] K Number of rows + * @param[in] B Address of matrix B + * @param[in] ldb leading dimension of input matrix B + * @param[out] PackedB Address of the packed matrix + */ +void MLASCALL +MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* PackedB); +#endif + /** * @brief Indirect Depthwise convolution for fp16 * @param Input Supplies the indirect buffer for NHWC input diff --git a/onnxruntime/core/mlas/lib/aarch64/SbgemmKernelNeon.S b/onnxruntime/core/mlas/lib/aarch64/SbgemmKernelNeon.S new file mode 100644 index 0000000000000..e424c30515e9f --- /dev/null +++ b/onnxruntime/core/mlas/lib/aarch64/SbgemmKernelNeon.S @@ -0,0 +1,907 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + SbgemmKernelNeon.s + +Abstract: + + This module implements the kernels for the bfloat16 half precision matrix/matrix + multiply operation (SBGEMM). + +--*/ + +#include "asmmacro.h" + + .text + +// +// Stack frame layout for the sbgemm kernel. d8-d15, x19-x30 need save +// + .equ .LMlasSbgemmKernel_backup_x19_x20, 0 + .equ .LMlasSbgemmKernel_backup_x21_x22, 16 + .equ .LMlasSbgemmKernel_backup_x23_x24, 32 + .equ .LMlasSbgemmKernel_backup_x25_x26, 48 + .equ .LMlasSbgemmKernel_backup_x27_x28, 64 + .equ .LMlasSbgemmKernel_backup_d8_d9, 80 + .equ .LMlasSbgemmKernel_backup_d10_d11, 96 + .equ .LMlasSbgemmKernel_backup_d12_d13, 112 + .equ .LMlasSbgemmKernel_backup_d14_d15, 128 + .equ .LMlasSbgemmKernel_SavedRegisters, 144 + .equ .LMlasSbgemmKernel_SavedRegisters_Neg, -144 + + +// +// ClearRowAccumulators +// +// Generates the code to clear the accumulators for a single row of the output +// block. +// + + .macro InitRowAccumulators Columns, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg + + mov v\Vec1Reg\().16b,v0.16b +.if \Columns\() > 2 + mov v\Vec2Reg\().16b,v1.16b +.endif +.if \Columns\() > 4 + mov v\Vec3Reg\().16b,v2.16b +.endif +.if \Columns\() > 6 + mov v\Vec4Reg\().16b,v3.16b +.endif + + .endm + +// +// InitBlockAccumulators +// +// Generates the code to init the accumulators for a single row of the output +// block. +// + + .macro InitBlockAccumulators Mode, Columns, Rows + + //check if the Bias != nullptr + cbz x8,.L\Mode\().InitBlock\Columns\().x\Rows\().SkipBiasAdd + + ld1 {v14.4s},[x8],#16 // load Bias[0] + // v4~v7 will be set to matrixB after this, so, they can used now + dup v4.4s,v14.s[0] // broadcast Bias + dup v5.4s,v14.s[1] + dup v6.4s,v14.s[2] + dup v7.4s,v14.s[3] + + zip1 v0.4s, v4.4s, v5.4s + zip2 v1.4s, v6.4s, v7.4s +.if \Columns\() > 4 + ld1 {v15.4s},[x8],#16 // load Bias[4] + dup v4.4s,v15.s[0] // broadcast Bias + dup v5.4s,v15.s[1] + dup v6.4s,v15.s[2] + dup v7.4s,v15.s[3] + + zip1 v2.4s, v4.4s, v5.4s + zip2 v3.4s, v6.4s, v7.4s +.endif + + b .L\Mode\().PopulateAccumulators\Columns\().x\Rows\() + +.L\Mode\().InitBlock\Columns\().x\Rows\().SkipBiasAdd: + eor v0.16b,v0.16b,v0.16b // No bias, reset regs + eor v1.16b,v1.16b,v1.16b + eor v2.16b,v2.16b,v2.16b + eor v3.16b,v3.16b,v3.16b + +.L\Mode\().PopulateAccumulators\Columns\().x\Rows\(): + InitRowAccumulators \Columns\(),16,17,18,19 +.if \Rows\() > 2 + InitRowAccumulators \Columns\(),20,21,22,23 +.endif +.if \Rows\() > 4 + InitRowAccumulators \Columns\(),24,25,26,27 +.endif +.if \Rows\() > 6 + InitRowAccumulators \Columns\(),28,29,30,31 +.endif + + .endm + +// LoadMatrixAElementsBy8 +// +// Generates the code to load 4 or 8 elements from matrix A. +// + .macro LoadMatrixAElementsBy8 Rows + + ldr q8,[x0],#16 + bfcvtn v8.4h, v8.4s +.if \Rows\() > 1 + ldr q1,[x10],#16 + bfcvtn2 v8.8h, v1.4s +.endif + +.if \Rows\() > 2 + ldr q9,[x11],#16 + bfcvtn v9.4h, v9.4s +.endif +.if \Rows\() > 3 + ldr q1,[x12],#16 + bfcvtn2 v9.8h, v1.4s +.endif + +.if \Rows\() > 4 + ldr q10,[x20],#16 + bfcvtn v10.4h, v10.4s +.endif +.if \Rows\() > 5 + ldr q1,[x21],#16 + bfcvtn2 v10.8h, v1.4s +.endif + +.if \Rows\() > 6 + ldr q11,[x22],#16 + bfcvtn v11.4h, v11.4s +.endif +.if \Rows\() > 7 + ldr q1,[x23],#16 + bfcvtn2 v11.8h, v1.4s +.endif + + .endm + + +// +// MultiplyAccumulateRow +// +// Generates the code to multiply and accumulate a single row of the output +// block. +// + + .macro MultiplyAccumulateRow Columns, MatrixAReg, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg + + bfmmla v\Vec1Reg\().4s, \MatrixAReg\().8h, v4.8h +.if \Columns\() > 2 + bfmmla v\Vec2Reg\().4s, \MatrixAReg\().8h, v5.8h +.endif +.if \Columns\() > 4 + bfmmla v\Vec3Reg\().4s, \MatrixAReg\().8h, v6.8h +.endif +.if \Columns\() > 6 + bfmmla v\Vec4Reg\().4s, \MatrixAReg\().8h, v7.8h +.endif + + .endm + +// +// MultiplyAccumulateBlock +// +// Generates the code to multiply and accumulate into the output block. +// + + .macro MultiplyAccumulateBlock Columns, Rows + + MultiplyAccumulateRow \Columns\(),v8,16,17,18,19 +.if \Rows\() > 2 + MultiplyAccumulateRow \Columns\(),v9,20,21,22,23 +.endif +.if \Rows\() > 4 + MultiplyAccumulateRow \Columns\(),v10,24,25,26,27 +.endif +.if \Rows\() > 6 + MultiplyAccumulateRow \Columns\(),v11,28,29,30,31 +.endif + + .endm + +// +// ComputeBlockLoop +// +// Generates the code to loop over K entries of the input matrices to produce +// the output block. +// + + .macro ComputeBlockLoop Mode, Columns, Rows + + InitBlockAccumulators \Mode\(),\Columns\(),\Rows\() + + add x10,x0,x6,lsl #2 // compute matrix A plus 1 row +.if \Rows\() > 2 + add x11,x10,x6,lsl #2 // compute matrix A plus 2 rows + add x12,x11,x6,lsl #2 // compute matrix A plus 3 rows +.endif +.if \Rows\() > 4 + add x20,x12,x6,lsl #2 // compute matrix A plus 4 rows + add x21,x20,x6,lsl #2 // compute matrix A plus 5 rows +.endif +.if \Rows\() > 6 + add x22,x21,x6,lsl #2 // compute matrix A plus 6 rows + add x23,x22,x6,lsl #2 // compute matrix A plus 7 rows +.endif + sub x9,x3,#4 // block count to process + tbnz x9,#63,.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks + +.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop: + + LoadMatrixAElementsBy8 \Rows\() + ldr q4, [x1],#16 +.if \Columns\() > 2 + ldr q5,[x1],#16 +.endif +.if \Columns\() > 4 + ldr q6,[x1],#16 +.endif +.if \Columns\() > 6 + ldr q7,[x1],#16 +.endif + MultiplyAccumulateBlock \Columns\(),\Rows\() + + sub x9,x9,#4 + tbz x9,#63,.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4Loop +.L\Mode\().ProcessRemaining\Columns\().x\Rows\().Blocks: + add x9,x9,#4 // correct for over-subtract above + cbz x9,.L\Mode\().Output\Columns\().x\Rows\().Block + +.L\Mode\().Compute\Columns\().x\Rows\().BlockBy4PaddedLoop: + LoadMatrixAElementsBy8 \Rows\() + ldr q4, [x1],#16 +.if \Columns\() > 2 + ldr q5,[x1],#16 +.endif +.if \Columns\() > 4 + ldr q6,[x1],#16 +.endif +.if \Columns\() > 6 + ldr q7,[x1],#16 +.endif + MultiplyAccumulateBlock \Columns\(),\Rows\() + +.L\Mode\().Output\Columns\().x\Rows\().Block: + + .endm + + +// +// OutputRow2Element +// OutputRow4Element +// OutputRow6Element +// OutputRow8Element +// OutputRow10Element +// OutputRow12Element +// OutputRow14Element +// OutputRow16Element +// +// Generates the code to store elements to the output block. +// + + .macro OutputRow2Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr s8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr s9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + mov v8.S[2], v9.S[0] + + fadd v8.4s,v8.4s,v\Vec1Reg\().4s + + mov w27, v8.S[0] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov w27, v8.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + mov w27, v\Vec1Reg\().S[0] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov w27, v\Vec1Reg\().S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.endif + + .endm + + + .macro OutputRow4Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr d8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr d9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + + mov v8.D[1], v9.D[0] + + fadd v8.4s,v8.4s,v\Vec1Reg\().4s + + mov x27, v8.D[0] + mov x28, v8.D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.else + mov x27, v\Vec1Reg\().D[0] + mov x28, v\Vec1Reg\().D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.endif + + .endm + + + .macro OutputRow6Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr d8,[\AddrReg1\()],#8 + ldr w28,[\AddrReg1\()],#-8 + mov v8.S[2], w28 +.if \last_row\() == 0 + ldr d9,[\AddrReg2\()],#8 + ldr w27,[\AddrReg2\()],#-8 + mov v9.S[2], w27 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + + mov x27, v8.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v8.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov x27, v9.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v9.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + mov x27, v4.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v4.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + mov x27, v5.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v5.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.endif + + .endm + + + .macro OutputRow8Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + +.endif + + .endm + + + .macro OutputRow10Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr w28, [\AddrReg1\()],#-16 + +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr w27,[\AddrReg2\()],#-16 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + mov v8.S[0], w28 + mov v8.S[2], w27 + + fadd v8.4s,v8.4s,v\Vec3Reg\().4s + + mov w27, v8.S[0] + mov w28, v8.S[2] + + str w27, [\AddrReg1\()],#4 +.if \last_row\() == 0 + str w28, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + mov w27, v\Vec3Reg\().S[0] + mov w28, v\Vec3Reg\().S[2] + + str w27, [\AddrReg1\()],#4 +.if \last_row\() == 0 + str w28, [\AddrReg2\()],#4 +.endif +.endif + +.endm + + + .macro OutputRow12Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr d10,[\AddrReg1\()],#-16 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr d11,[\AddrReg2\()],#-16 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + mov v11.D[0],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + + str q8,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 +.endif + + mov v10.D[1], v11.D[0] + + fadd v10.4s,v10.4s,v\Vec3Reg\().4s + + mov x27, v10.D[0] + mov x28, v10.D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + + str q4,[\AddrReg1\()],#16 +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 +.endif + mov x27, v\Vec3Reg\().D[0] + mov x28, v\Vec3Reg\().D[1] + + str x27, [\AddrReg1\()],#8 +.if \last_row\() == 0 + str x28, [\AddrReg2\()],#8 +.endif +.endif + + .endm + + .macro OutputRow14Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldr q8,[\AddrReg1\()],#16 + ldr d10,[\AddrReg1\()],#8 + ldr w28, [\AddrReg1\()],#-24 + mov v10.S[2], w28 +.if \last_row\() == 0 + ldr q9,[\AddrReg2\()],#16 + ldr d11,[\AddrReg2\()],#8 + ldr w27,[\AddrReg2\()],#-24 + mov v11.S[2], w27 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + + mov v11.D[0],x27 + mov v11.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + fadd v10.4s,v10.4s,v6.4s + fadd v11.4s,v11.4s,v7.4s + + str q8,[\AddrReg1\()],#16 + + mov x27, v10.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v10.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + str q9,[\AddrReg2\()],#16 + mov x27, v11.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v11.S[2] + str w27, [\AddrReg2\()],#4 +.endif + +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + str q4,[\AddrReg1\()],#16 + mov x27, v6.D[0] + str x27, [\AddrReg1\()],#8 + mov w27, v6.S[2] + str w27, [\AddrReg1\()],#4 + +.if \last_row\() == 0 + str q5,[\AddrReg2\()],#16 + mov x27, v7.D[0] + str x27, [\AddrReg2\()],#8 + mov w27, v7.S[2] + str w27, [\AddrReg2\()],#4 +.endif +.endif + + .endm + + + .macro OutputRow16Element Mode, AddrReg1, AddrReg2, Vec1Reg, Vec2Reg, Vec3Reg, Vec4Reg, last_row + +.ifeqs "\Mode\()","Add" + ldp q8,q10,[\AddrReg1\()],#0 +.if \last_row\() == 0 + ldp q9,q11,[\AddrReg2\()],#0 +.else + mov x27,#0 + mov v9.D[0],x27 + mov v9.D[1],x27 + + mov v11.D[0],x27 + mov v11.D[1],x27 +.endif + uzp1 v4.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d,v\Vec1Reg\().2d,v\Vec2Reg\().2d + + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + fadd v8.4s,v8.4s,v4.4s + fadd v9.4s,v9.4s,v5.4s + fadd v10.4s,v10.4s,v6.4s + fadd v11.4s,v11.4s,v7.4s + + stp q8,q10,[\AddrReg1\()],#32 +.if \last_row\() == 0 + stp q9,q11,[\AddrReg2\()],#32 +.endif +.else + uzp1 v4.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp2 v5.2d, v\Vec1Reg\().2d,v\Vec2Reg\().2d + uzp1 v6.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + uzp2 v7.2d, v\Vec3Reg\().2d,v\Vec4Reg\().2d + + stp q4,q6,[\AddrReg1\()],#32 +.if \last_row\() == 0 + stp q5,q7,[\AddrReg2\()],#32 +.endif +.endif + + .endm + +// +// OutputBlock +// +// Generates the code to store the output block. +// + + .macro OutputBlock Mode, Columns, Rows + + OutputRow\Columns\()Element \Mode\(),x2,x13,16,17,18,19,(\Rows\() == 1) + +.if \Rows\() > 2 + OutputRow\Columns\()Element \Mode\(),x14,x15,20,21,22,23,(\Rows\() == 3) +.endif + +.if \Rows\() > 4 + OutputRow\Columns\()Element \Mode\(),x16,x17,24,25,26,27,(\Rows\() == 5) +.endif + +.if \Rows\() > 6 + OutputRow\Columns\()Element \Mode\(),x18,x19,28,29,30,31,(\Rows\() == 7) +.endif + + .endm +// +// ProcessRows +// +// Generates the code to process a compute and store the output block for a +// fixed number of rows. +// + + .macro ProcessRows Mode, Rows + mov x4,#\Rows\() // return number of rows handled + cmp x5,#6 + ble .L\Mode\().ProcessNextColumnLoop6x\Rows\() + +.L\Mode\().ProcessNextColumnLoop8x\Rows\(): + ComputeBlockLoop \Mode\(),8,\Rows\() + + sub x5,x5,#8 + cmp x5,#0 + blt .L\Mode\().Output14ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),16,\Rows\() + mov x0,x26 // reload matrix A + cmp x5,#6 + bgt .L\Mode\().ProcessNextColumnLoop8x\Rows\() + cbz x5,.L\Mode\().ExitKernel + + +.L\Mode\().ProcessNextColumnLoop6x\Rows\(): + + cmp x5,#4 + ble .L\Mode\().ProcessNextColumnLoop4x\Rows\() + ComputeBlockLoop \Mode\(),6,\Rows\() + sub x5,x5,#6 + cmp x5,#0 + blt .L\Mode\().Output10ElementsOnlyFor\Rows\() + OutputBlock \Mode\(),12,\Rows\() + + mov x0,x26 // reload matrix A + cmp x5,#4 + bgt .L\Mode\().ProcessNextColumnLoop6x\Rows\() + b .L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop4x\Rows\(): + cmp x5,#2 + ble .L\Mode\().ProcessNextColumnLoop2x\Rows\() + ComputeBlockLoop \Mode\(),4,\Rows\() + sub x5,x5,#4 + cmp x5,#0 + blt .L\Mode\().Output6ElementsOnlyFor\Rows\() + + OutputBlock \Mode\(),8,\Rows\() + + mov x0,x26 // reload matrix A + cmp x5,#2 + bgt .L\Mode\().ProcessNextColumnLoop4x\Rows\() + b .L\Mode\().ExitKernel + +.L\Mode\().ProcessNextColumnLoop2x\Rows\(): + ComputeBlockLoop \Mode\(),2,\Rows\() + sub x5,x5,#2 + cmp x5,#0 + blt .L\Mode\().Output2ElementsOnlyFor\Rows\() + + OutputBlock \Mode\(),4,\Rows\() + + mov x0,x26 // reload matrix A + cmp x5,#2 + b .L\Mode\().ExitKernel + +.L\Mode\().Output14ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),14,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output10ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),10,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output6ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),6,\Rows\() + b .L\Mode\().ExitKernel + + +.L\Mode\().Output2ElementsOnlyFor\Rows\(): + OutputBlock \Mode\(),2,\Rows\() + b .L\Mode\().ExitKernel + + .endm + + +/*++ + +Routine Description: + + This routine is an inner kernel to compute matrix multiplication for a + set of rows. + +Arguments: + + A (x0) - Supplies the address of matrix A. + + B (x1) - Supplies the address of matrix B. The matrix data has been packed + using MlasSbgemmCopyPackB or MlasSbgemmTransposePackB. + + C (x2) - Supplies the address of matrix C. + + CountK (x3) - Supplies the number of columns from matrix A and the number + of rows from matrix B to iterate over. + + CountM (x4) - Supplies the maximum number of rows that can be processed for + matrix A and matrix C. The actual number of rows handled for this + invocation depends on the kernel implementation. + + CountN (x5) - Supplies the number of columns from matrix B and matrix C to + iterate over. + + lda (x6) - Supplies the first dimension of matrix A. + + ldc (x7) - Supplies the first dimension of matrix C. + + Bias - Supplies the address of Bias Vector [1xn] + + +Return Value: + + Returns the number of rows handled. + +--*/ + .macro SbgemmKernelNeonFunction Mode + + FUNCTION_ENTRY MlasSbgemmKernel\Mode\() + + ldr x8, [sp, #0] //Bias vector + + stp x19, x20, [sp, #.LMlasSbgemmKernel_SavedRegisters_Neg]! + stp x21, x22, [sp, #.LMlasSbgemmKernel_backup_x21_x22] + stp x23, x24, [sp, #.LMlasSbgemmKernel_backup_x23_x24] + stp x25, x26, [sp, #.LMlasSbgemmKernel_backup_x25_x26] + stp x27, x28, [sp, #.LMlasSbgemmKernel_backup_x27_x28] + stp d8, d9, [sp, #.LMlasSbgemmKernel_backup_d8_d9] + stp d10, d11, [sp, #.LMlasSbgemmKernel_backup_d10_d11] + stp d12, d13, [sp, #.LMlasSbgemmKernel_backup_d12_d13] + stp d14, d15, [sp, #.LMlasSbgemmKernel_backup_d14_d15] + + add x13,x2,x7,lsl #2 // compute matrix C plus 1 row + add x14,x13,x7,lsl #2 // compute matrix C plus 2 rows + add x15,x14,x7,lsl #2 // compute matrix C plus 3 rows + add x16,x15,x7,lsl #2 // compute matrix C plus 4 rows + add x17,x16,x7,lsl #2 // compute matrix C plus 5 rows + add x18,x17,x7,lsl #2 // compute matrix C plus 6 rows + add x19,x18,x7,lsl #2 // compute matrix C plus 7 rows + + mov x26,x0 // save matrix A +// +// Process 8 rows of the matrices. +// + cmp x4,#8 + blt .L\Mode\().ProcessCountMLessThan8 + ProcessRows \Mode\(),8 + +// +// Restore non-volatile registers and return. +// + +.L\Mode\().ExitKernel: + mov x0,x4 + + ldp d14, d15, [sp, #.LMlasSbgemmKernel_backup_d14_d15] + ldp d12, d13, [sp, #.LMlasSbgemmKernel_backup_d12_d13] + ldp d10, d11, [sp, #.LMlasSbgemmKernel_backup_d10_d11] + ldp d8, d9, [sp, #.LMlasSbgemmKernel_backup_d8_d9] + ldp x27, x28, [sp, #.LMlasSbgemmKernel_backup_x27_x28] + ldp x25, x26, [sp, #.LMlasSbgemmKernel_backup_x25_x26] + ldp x23, x24, [sp, #.LMlasSbgemmKernel_backup_x23_x24] + ldp x21, x22, [sp, #.LMlasSbgemmKernel_backup_x21_x22] + ldp x19, x20, [sp], #.LMlasSbgemmKernel_SavedRegisters + + ret + +// +// Process 4 rows of the matrix. +// + +.L\Mode\().ProcessCountMLessThan8: + cmp x4,#4 + blt .L\Mode\().ProcessCountMLessThan4 + ProcessRows \Mode\(),4 + b .L\Mode\().ExitKernel + +// +// Process 2 row of the matrix. +// + +.L\Mode\().ProcessCountMLessThan4: + cmp x4,#2 + blt .L\Mode\().ProcessCountMLessThan2 + + ProcessRows \Mode\(),2 + b .L\Mode\().ExitKernel + + +// +// Process the last row of the matrix. +// + +.L\Mode\().ProcessCountMLessThan2: + ProcessRows \Mode\(),1 + b .L\Mode\().ExitKernel + + + .endm + + SbgemmKernelNeonFunction Zero + SbgemmKernelNeonFunction Add diff --git a/onnxruntime/core/mlas/lib/amx_common.h b/onnxruntime/core/mlas/lib/amx_common.h index 3eb0700932faa..caf94af02362d 100644 --- a/onnxruntime/core/mlas/lib/amx_common.h +++ b/onnxruntime/core/mlas/lib/amx_common.h @@ -18,7 +18,7 @@ Module Name: #include "mlasi.h" -#ifdef WIN32 +#ifdef _WIN32 #define tile_dpbssd(dst, src1, src2) _tile_dpbssd(dst, src1, src2) #define tile_dpbsud(dst, src1, src2) _tile_dpbsud(dst, src1, src2) diff --git a/onnxruntime/core/mlas/lib/mlasi.h b/onnxruntime/core/mlas/lib/mlasi.h index 7bb8b17031a84..624eb913d5c9e 100644 --- a/onnxruntime/core/mlas/lib/mlasi.h +++ b/onnxruntime/core/mlas/lib/mlasi.h @@ -193,6 +193,8 @@ class MLASCPUIDInfo bool HasArmSVE_I8MM() const { return has_arm_sve_i8mm_; } + bool HasArmNeon_BF16() const { return has_arm_neon_bf16_; } + private: MLASCPUIDInfo(); @@ -200,6 +202,7 @@ class MLASCPUIDInfo bool has_fp16_{false}; bool has_arm_neon_i8mm_{false}; bool has_arm_sve_i8mm_{false}; + bool has_arm_neon_bf16_{false}; }; using MLAS_CPUIDINFO = MLASCPUIDInfo; @@ -357,6 +360,20 @@ size_t #else +#if defined(__aarch64__) && defined(__linux__) +typedef size_t(MLASCALL MLAS_SBGEMM_FLOAT_KERNEL)( + const float* A, + const bfloat16_t* B, + float* C, + size_t CountK, + size_t CountM, + size_t CountN, + size_t lda, + size_t ldc, + const float* Bias +); +#endif + typedef size_t (MLASCALL MLAS_GEMM_FLOAT_KERNEL)( @@ -727,6 +744,10 @@ extern "C" { #else MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelZero; MLAS_GEMM_FLOAT_KERNEL MlasSgemmKernelAdd; +#if defined(__aarch64__) && defined(__linux__) + MLAS_SBGEMM_FLOAT_KERNEL MlasSbgemmKernelZero; + MLAS_SBGEMM_FLOAT_KERNEL MlasSbgemmKernelAdd; +#endif MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelZero; MLAS_GEMM_DOUBLE_KERNEL MlasDgemmKernelAdd; #endif @@ -856,6 +877,10 @@ extern "C" { #define MLAS_DGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) #define MLAS_QGEMM_THREAD_COMPLEXITY 65536 +#if defined(__aarch64__) && defined(__linux__) +#define MLAS_SBGEMM_THREAD_COMPLEXITY (size_t(64) * size_t(1024)) +#endif + // // Single-threaded single precision matrix/matrix multiply operation. // diff --git a/onnxruntime/core/mlas/lib/platform.cpp b/onnxruntime/core/mlas/lib/platform.cpp index 1310ed3f384b9..de092f7d1d350 100644 --- a/onnxruntime/core/mlas/lib/platform.cpp +++ b/onnxruntime/core/mlas/lib/platform.cpp @@ -60,6 +60,10 @@ MLASCPUIDInfo::MLASCPUIDInfo() #define HWCAP2_SVEI8MM (1 << 9) #endif +#ifndef HWCAP2_BF16 +#define HWCAP2_BF16 (1 << 14) +#endif + #if defined(BUILD_MLAS_NO_ONNXRUNTIME) MLASCPUIDInfo::MLASCPUIDInfo() { @@ -70,6 +74,8 @@ MLASCPUIDInfo::MLASCPUIDInfo() has_arm_neon_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_I8MM) != 0); has_arm_sve_i8mm_ = ((getauxval(AT_HWCAP2) & HWCAP2_SVEI8MM) != 0); + + has_arm_neon_bf16_ = ((getauxval(AT_HWCAP2) & HWCAP2_BF16) != 0); } #endif diff --git a/onnxruntime/core/mlas/lib/sbgemm.h b/onnxruntime/core/mlas/lib/sbgemm.h new file mode 100644 index 0000000000000..de7fd72fad45a --- /dev/null +++ b/onnxruntime/core/mlas/lib/sbgemm.h @@ -0,0 +1,399 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + sbgemm.h + +Abstract: + + This module defines the set of template functions to implement bfloat16 + precision matrix/matrix multiply operation (SBGEMM). + + To implement a new kernel, template functions below need to be specialized: + MlasSBGemmConvertPackB + MlasSBGemmPackedBOffset + MlasSBGemmPackedBLeadingDim + MlasSBGemmKernel + + MlasSBGemmOperation is the shared kernel driver. + + A kernel type should define the following constants: + bool PackNeeded; Whether B needs to be packed + size_t KernelMaxM; Max # rows the vectorized kernel can process + size_t PackedK; Packed alignment on the K dim (power of 2) + size_t PackedN; Packed alignment on the n dim (power of 2) + MLAS_SBGEMM_STRIDES Strides{128, 128, 256}; +--*/ + +#if defined(__aarch64__) && defined(__linux__) + +#pragma once + +#include +#include + +#include "mlasi.h" + +/** + * @brief Define the default striding parameters for + * the bfloat16 precision gemm operation + */ +struct MLAS_SBGEMM_STRIDES { + size_t M; + size_t N; + size_t K; +}; + +/** + * @brief Convert fp32 matrix B to bf16 and pack the data + * + * @tparam KernelType + * @param[out] D Address of packing buffer + * @param[in] B Address of source matrix B in fp32 + * @param[in] ldb Leading dimension of B + * @param[in] CountN # of column to pack + * @param[in] CountK # of rows to pack + */ +template +void +MlasSBGemmConvertPackB( + bfloat16_t* PackedB, const float* B, size_t ldb, size_t CountN, size_t CountK +); + +/** + * @brief Find the location of PackedB[StartK, StartN] + * + * @tparam KernelType + * @param PackedB + * @param DimN Total columns of the packing buffer + * @param DimK Total rows of the packing buffer + * @param StartN + * @param StartK + * @return Address of PackedB[StartK, StartN] + */ +template +MLAS_FORCEINLINE const bfloat16_t* +MlasSBGemmPackedBOffset( + const bfloat16_t* PackedB, size_t DimN, size_t DimK, size_t StartN, size_t StartK +) +{ + // By default the packed buffer is just a row major + // K row by N column buffer + MLAS_UNREFERENCED_PARAMETER(DimK); + return PackedB + StartK * DimN + StartN; +} + +/** + * @brief leading dimension of the packed B buffer + * Related to how B is packed + * @tparam KernelType + * @param DimN + * @param DimK + * @return leading dimension of the packed B buffer + */ +template +MLAS_FORCEINLINE size_t +MlasSBGemmPackedBLeadingDim(size_t DimN, size_t DimK) +{ + // By default the packed buffer is just a row major + // K row by N column buffer + MLAS_UNREFERENCED_PARAMETER(DimK); + return DimN; +} + +template +void +MlasSBGemmKernel(const size_t CountM, const size_t CountN, const size_t CountK, const float* A, const size_t lda, const bfloat16_t* B, float* C, size_t ldc, const float* Bias, const bool ZeroMode); + +template +MLAS_FORCEINLINE void +MlasSBGemmPackedOperation(size_t M, size_t RangeStartN, size_t RangeCountN, size_t AlignedN, size_t K, const float* A, size_t lda, const void* PackedB, float* C, size_t ldc, const float* Bias, void* PostProcessor) +{ + constexpr MLAS_SBGEMM_STRIDES Strides = KernelType::Strides; + size_t PackedStrideN = Strides.N; + size_t PackedStrideK = Strides.K; + + // + // Step through each slice of matrix B along the N dimension. + // + size_t CountN; + for (size_t n = 0; n < RangeCountN; n += CountN) { + const size_t SliceStartN = RangeStartN + n; + CountN = std::min(RangeCountN - n, PackedStrideN); + + // + // Step through each slice of matrix B along the K dimension. + // + size_t CountK; + for (size_t k = 0; k < K; k += CountK) { + bool ZeroMode = (k == 0); + CountK = std::min(K - k, PackedStrideK); + + const bfloat16_t* pb = (const bfloat16_t*)PackedB + AlignedN * k + CountK * SliceStartN; + float* c = C + n; + const float* pbias = ((nullptr == Bias) ? nullptr : Bias + RangeStartN + n); + MlasSBGemmKernel(M, CountN, CountK, A + k, lda, pb, c, ldc, ZeroMode ? pbias : nullptr, ZeroMode); + } + if (PostProcessor != nullptr) { + ((MLAS_SBGEMM_POSTPROCESSOR*)PostProcessor) + ->Process(C + n, M, SliceStartN, M, CountN, ldc); + } + } +} + +template +void +MlasSBGemmNonPackedOperation(size_t M, size_t N, size_t K, const float* A, size_t lda, const float* B, size_t ldb, float* C, size_t ldc, const float* Bias, void* PostProcessor) +{ + // + // Compute the strides to step through slices of the input matrices. + // + // Expand the N stride if K is small or expand the K stride if N is small + // for better utilization of the B panel. Avoid changing the K stride if + // the A panel needs to be used for transposing. + // + constexpr MLAS_SBGEMM_STRIDES Strides = KernelType::Strides; + size_t StrideN = Strides.N; + size_t StrideK = Strides.K; + + if (N >= K) { + while (StrideK / 2 >= K) { + StrideN *= 2; + StrideK /= 2; + } + } else { + while (StrideN > 16 && StrideN / 2 >= N) { + StrideK *= 2; + StrideN /= 2; + } + } + + constexpr size_t packBSize = UpAlignSize(Strides.N * Strides.K * sizeof(bfloat16_t)); + MlasThreadedBufAlloc(packBSize); + uint8_t* p = ThreadedBufHolder.get(); + auto* PanelB = reinterpret_cast(p); + + // + // Step through each slice of matrix B along the N dimension. + // + size_t CountN; + for (size_t n = 0; n < N; n += CountN) { + CountN = std::min(N - n, StrideN); + + // + // Step through each slice of matrix B along the N dimension. + // + size_t CountK; + for (size_t k = 0; k < K; k += CountK) { + CountK = std::min(K - k, StrideK); + + // + // Copy a panel of matrix B to a local packed buffer. + // + MlasSBGemmConvertPackB(PanelB, B + n + k * ldb, ldb, CountN, CountK); + + auto* c = C + n; + const float* pbias = + ((nullptr == Bias) ? nullptr : Bias + n); // TODO: check the SliceNStart + + bool ZeroMode = (k == 0); + MlasSBGemmKernel(M, CountN, CountK, A + k, lda, PanelB, c, ldc, ZeroMode ? pbias : nullptr, ZeroMode); + } + if (PostProcessor != nullptr) { + ((MLAS_SBGEMM_POSTPROCESSOR*)PostProcessor)->Process(C + n, M, N, M, CountN, ldc); + } + } +} + +template +void +MlasSBGemmOperation(const ptrdiff_t ThreadCountM, const ptrdiff_t ThreadCountN, const size_t M, const size_t N, const size_t K, const MLAS_SBGEMM_DATA_PARAMS* DataParams, ptrdiff_t ThreadId) +{ + const ptrdiff_t ThreadIdM = ThreadId / ThreadCountN; + const ptrdiff_t ThreadIdN = ThreadId % ThreadCountN; + + // + // Partition the operation along the M dimension. + // + size_t RangeStartM; + size_t RangeCountM; + + MlasPartitionWork(ThreadIdM, ThreadCountM, M, &RangeStartM, &RangeCountM); + + // + // Partition the operation along the N dimension. + // + size_t RangeStartN; + size_t RangeCountN; + + const size_t BlockedN = + (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) / MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + + MlasPartitionWork(ThreadIdN, ThreadCountN, BlockedN, &RangeStartN, &RangeCountN); + + RangeStartN *= MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + RangeCountN *= MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + + RangeCountN = std::min(N - RangeStartN, RangeCountN); + + // + // Dispatch the partitioned operation. + // + const size_t lda = DataParams->lda; + const size_t ldc = DataParams->ldc; + const float* A = (const float*)DataParams->A + RangeStartM * lda; + float* C = DataParams->C + RangeStartM * ldc + RangeStartN; + const float* bias = DataParams->Bias; + + if (!DataParams->BIsfp32) { + MlasSBGemmPackedOperation( + RangeCountM, RangeStartN, RangeCountN, BlockedN * MLAS_SGEMM_STRIDEN_THREAD_ALIGN, K, A, + lda, DataParams->B, C, ldc, bias, (void*)DataParams->OutputProcessor + ); + } else { + const size_t ldb = DataParams->ldb; + const float* B = (const float*)DataParams->B + RangeStartN; + MlasSBGemmNonPackedOperation(RangeCountM, RangeCountN, K, A, lda, B, ldb, C, ldc, bias, (void*)DataParams->OutputProcessor); + } +} + +// +// dispatch structure. +// +typedef void(MLAS_SBGEMM_OPERATION)(const ptrdiff_t ThreadCountM, const ptrdiff_t ThreadCountN, const size_t M, const size_t N, const size_t K, const MLAS_SBGEMM_DATA_PARAMS* DataParams, ptrdiff_t ThreadId); + +typedef void(MLAS_SBGEMM_CONVERTPACKB_ROUTINE)( + bfloat16_t* D, const float* B, size_t ldb, size_t CountN, size_t CountK +); + +/** + * @brief Hardware dependent dispatch for half precision GEMM + */ +struct MLAS_SBGEMM_DISPATCH { + MLAS_SBGEMM_OPERATION* Operation; /**< HalfGemm driver */ + MLAS_SBGEMM_CONVERTPACKB_ROUTINE* ConvertPackBRoutine; /**< Convert and pack function for B */ + size_t PackedK; + size_t PackedN; + size_t StrideM; + size_t BufOverRead; +}; + +extern const MLAS_SBGEMM_DISPATCH MlasSBGemmDispatchNeon; + +MLAS_FORCEINLINE +const MLAS_SBGEMM_DISPATCH* +MlasSBGemmGetDispatch() +{ +#if defined(MLAS_TARGET_ARM64) + return &MlasSBGemmDispatchNeon; +#else + std::cerr << "SBGemm Kernel is supported only on ARM64 platform."; + exit(1); +#endif +} + +size_t MLASCALL +MlasSBGemmPackBSize(size_t N, size_t K) +{ + // + // Compute the number of bytes required to hold the packed buffer. + // + const auto* dispatch = MlasSBGemmGetDispatch(); + if (dispatch == nullptr) return 0; + + const auto padding = dispatch->BufOverRead; + const auto PackedK = dispatch->PackedK; + const auto PackedN = dispatch->PackedN; + + const size_t AlignedK = (K + PackedK - 1) & ~(PackedK - 1); + const size_t AlignedN = (N + PackedN - 1) & ~(PackedN - 1); + const size_t BytesRequired = AlignedN * AlignedK * sizeof(bfloat16_t) + padding; + const size_t BufferAlignment = MlasGetPreferredBufferAlignment(); + const size_t AlignedBytesRequired = + (BytesRequired + BufferAlignment - 1) & ~(BufferAlignment - 1); + + return AlignedBytesRequired; +} + +void MLASCALL +MlasSBGemmConvertPackB(size_t N, size_t K, const float* B, size_t ldb, void* PackedB) +{ + const auto* dispatch = MlasSBGemmGetDispatch(); + if (dispatch == nullptr) return; + + dispatch->ConvertPackBRoutine((bfloat16_t*)PackedB, B, ldb, N, K); +} + +void MLASCALL +MlasSBGemmBatch(const size_t M, const size_t N, const size_t K, const size_t BatchN, const MLAS_SBGEMM_DATA_PARAMS* Data, MLAS_THREADPOOL* ThreadPool) +{ + const MLAS_SBGEMM_DISPATCH* dispatch = MlasSBGemmGetDispatch(); + if (dispatch == nullptr) return; + + MLAS_SBGEMM_OPERATION* operation = dispatch->Operation; + + // + // Compute the number of target threads given the complexity of the SGEMM + // operation. Small requests should run using the single threaded path. + // + + const double Complexity = double(M) * double(N) * double(K); + + ptrdiff_t TargetThreadCount; + + if (Complexity < double(MLAS_SBGEMM_THREAD_COMPLEXITY * GetMlasPlatform().MaximumThreadCount)) { + TargetThreadCount = ptrdiff_t(Complexity / double(MLAS_SGEMM_THREAD_COMPLEXITY)) + 1; + } else { + TargetThreadCount = GetMlasPlatform().MaximumThreadCount; + } + + ptrdiff_t MaximumThreadCount = MlasGetMaximumThreadCount(ThreadPool); + + if (TargetThreadCount >= MaximumThreadCount) { + TargetThreadCount = MaximumThreadCount; + } + + // + // Segment the operation across multiple threads. + // + // N.B. Currently, the operation is segmented as a 1D partition, which + // works okay for operations involving skinny matrices. + // + ptrdiff_t ThreadsPerGemm = (TargetThreadCount + BatchN - 1) / BatchN; + ptrdiff_t ThreadCountM; + ptrdiff_t ThreadCountN; + + if (N > M) { + const size_t BlockedN = + (N + MLAS_SGEMM_STRIDEN_THREAD_ALIGN - 1) / MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + + if (size_t(ThreadsPerGemm) > BlockedN) { + ThreadsPerGemm = ptrdiff_t(BlockedN); + } + + ThreadCountM = 1; + ThreadCountN = ThreadsPerGemm; + + } else { + if (size_t(ThreadsPerGemm) > M) { + ThreadsPerGemm = ptrdiff_t(M); + } + + ThreadCountM = ThreadsPerGemm; + ThreadCountN = 1; + } + + MlasTrySimpleParallel( + ThreadPool, ThreadsPerGemm * static_cast(BatchN), [=](ptrdiff_t tid) { + ptrdiff_t GemmIdx = tid / ThreadsPerGemm; + ptrdiff_t ThreadIdx = tid % ThreadsPerGemm; + operation(ThreadCountM, ThreadCountN, M, N, K, &(Data[GemmIdx]), ThreadIdx); + } + ); +} +#endif // defined(__aarch64__) && defined(__linux__) diff --git a/onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp b/onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp new file mode 100644 index 0000000000000..a6a73996c548b --- /dev/null +++ b/onnxruntime/core/mlas/lib/sbgemm_kernel_neon.cpp @@ -0,0 +1,362 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + sbgemm_kernel_neon.cpp + +Abstract: + + This module implements bfloat16 precision GEMM kernel for neon. + +--*/ + +#if defined(__aarch64__) && defined(__linux__) + +#include "arm_neon.h" +#include "mlasi.h" +#include "sbgemm.h" + +struct MLAS_SBGEMM_KERNEL_NEON { + static constexpr bool PackNeeded = true; + static constexpr size_t KernelMaxM = 8; // max # rows the vectorized kernel can process + static constexpr size_t PackedK = 4; + static constexpr size_t PackedN = MLAS_SGEMM_STRIDEN_THREAD_ALIGN; + static constexpr MLAS_SBGEMM_STRIDES Strides{128, 128, 256}; // M:N:K +}; + +bool MLASCALL +MlasBf16AccelerationSupported() +{ +#if defined(MLAS_TARGET_ARM64) + return MLAS_CPUIDINFO::GetCPUIDInfo().HasArmNeon_BF16(); +#else + return false; +#endif +} + +/* + This routine converts fp32 to bf16 and copies elements from the source + matrix to the destination packed buffer. + + 4x2 elements from the source matrix are unrolled to be physically + contiguous for better locality inside the SBGEMM kernels. The remaining + rows and columns are padded to 4 and 2 alignment. +*/ +MLAS_FORCEINLINE +void +MlasSBGemmConvertCopyPackB(bfloat16_t* D, const float* B, size_t ldb, size_t CountN, size_t CountK) +{ + // + // Copy data from matrix B into the destination buffer 4x2 blocks at a + // time. + // + // + while (CountN >= 8) { + const float* b = B; + int y = static_cast(CountK); + + while (y > 0) { + MLAS_FLOAT32X4 t0_l = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t0_h = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t1_l = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t1_h = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t2_l = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t2_h = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t3_l = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t3_h = MlasZeroFloat32x4(); + + if (y >= 4) { + t0_l = MlasLoadFloat32x4(&b[ldb * 0]); + t0_h = MlasLoadFloat32x4(&b[ldb * 0 + 4]); + t1_l = MlasLoadFloat32x4(&b[ldb * 1]); + t1_h = MlasLoadFloat32x4(&b[ldb * 1 + 4]); + t2_l = MlasLoadFloat32x4(&b[ldb * 2]); + t2_h = MlasLoadFloat32x4(&b[ldb * 2 + 4]); + t3_l = MlasLoadFloat32x4(&b[ldb * 3]); + t3_h = MlasLoadFloat32x4(&b[ldb * 3 + 4]); + } else { + switch (y) { + case 3: + t0_l = MlasLoadFloat32x4(&b[ldb * 0]); + t0_h = MlasLoadFloat32x4(&b[ldb * 0 + 4]); + t1_l = MlasLoadFloat32x4(&b[ldb * 1]); + t1_h = MlasLoadFloat32x4(&b[ldb * 1 + 4]); + t2_l = MlasLoadFloat32x4(&b[ldb * 2]); + t2_h = MlasLoadFloat32x4(&b[ldb * 2 + 4]); + break; + case 2: + t0_l = MlasLoadFloat32x4(&b[ldb * 0]); + t0_h = MlasLoadFloat32x4(&b[ldb * 0 + 4]); + t1_l = MlasLoadFloat32x4(&b[ldb * 1]); + t1_h = MlasLoadFloat32x4(&b[ldb * 1 + 4]); + break; + case 1: + t0_l = MlasLoadFloat32x4(&b[ldb * 0]); + t0_h = MlasLoadFloat32x4(&b[ldb * 0 + 4]); + break; + } + } + + float32x4x2_t z0_l = vzipq_f32(t0_l, t2_l); + float32x4x2_t z1_l = vzipq_f32(t1_l, t3_l); + float32x4x2_t o0_l = vzipq_f32(z0_l.val[0], z1_l.val[0]); + float32x4x2_t o1_l = vzipq_f32(z0_l.val[1], z1_l.val[1]); + t0_l = o0_l.val[0]; + t1_l = o0_l.val[1]; + t2_l = o1_l.val[0]; + t3_l = o1_l.val[1]; + + bfloat16x8_t t0t1_l_4h = vcvtq_low_bf16_f32(t0_l); + bfloat16x8_t t0t1_l_8h = vcvtq_high_bf16_f32(t0t1_l_4h, t1_l); + + bfloat16x8_t t2t3_l_4h = vcvtq_low_bf16_f32(t2_l); + bfloat16x8_t t2t3_l_8h = vcvtq_high_bf16_f32(t2t3_l_4h, t3_l); + + vst1q_bf16(&D[0], t0t1_l_8h); + vst1q_bf16(&D[8], t2t3_l_8h); + + float32x4x2_t z0_h = vzipq_f32(t0_h, t2_h); + float32x4x2_t z1_h = vzipq_f32(t1_h, t3_h); + float32x4x2_t o0_h = vzipq_f32(z0_h.val[0], z1_h.val[0]); + float32x4x2_t o1_h = vzipq_f32(z0_h.val[1], z1_h.val[1]); + t0_h = o0_h.val[0]; + t1_h = o0_h.val[1]; + t2_h = o1_h.val[0]; + t3_h = o1_h.val[1]; + + bfloat16x8_t t0t1_h_4h = vcvtq_low_bf16_f32(t0_h); + bfloat16x8_t t0t1_h_8h = vcvtq_high_bf16_f32(t0t1_h_4h, t1_h); + + bfloat16x8_t t2t3_h_4h = vcvtq_low_bf16_f32(t2_h); + bfloat16x8_t t2t3_h_8h = vcvtq_high_bf16_f32(t2t3_h_4h, t3_h); + + vst1q_bf16(&D[16], t0t1_h_8h); + vst1q_bf16(&D[24], t2t3_h_8h); + + D += 32; + b += ldb * 4; + y -= 4; + }; + B += 8; + CountN -= 8; + } + + // + // Special case the handling of the remaining columns less than 8 elements + // wide. + // + if (CountN > 0) { + int y = static_cast(CountK); + while (y > 0) { + const float* b = B; + size_t b_inc = 0; + if ((CountN & 4) != 0) { + MLAS_FLOAT32X4 t0 = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t1 = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t2 = MlasZeroFloat32x4(); + MLAS_FLOAT32X4 t3 = MlasZeroFloat32x4(); + if (y >= 4) { + t0 = MlasLoadFloat32x4(&b[ldb * 0]); + t1 = MlasLoadFloat32x4(&b[ldb * 1]); + t2 = MlasLoadFloat32x4(&b[ldb * 2]); + t3 = MlasLoadFloat32x4(&b[ldb * 3]); + } else { + switch (y) { + case 3: + t0 = MlasLoadFloat32x4(&b[ldb * 0]); + t1 = MlasLoadFloat32x4(&b[ldb * 1]); + t2 = MlasLoadFloat32x4(&b[ldb * 2]); + break; + case 2: + t0 = MlasLoadFloat32x4(&b[ldb * 0]); + t1 = MlasLoadFloat32x4(&b[ldb * 1]); + break; + case 1: + t0 = MlasLoadFloat32x4(&b[ldb * 0]); + break; + } + } + + float32x4x2_t z0 = vzipq_f32(t0, t2); + float32x4x2_t z1 = vzipq_f32(t1, t3); + float32x4x2_t o0 = vzipq_f32(z0.val[0], z1.val[0]); + float32x4x2_t o1 = vzipq_f32(z0.val[1], z1.val[1]); + + t0 = o0.val[0]; + t1 = o0.val[1]; + t2 = o1.val[0]; + t3 = o1.val[1]; + + bfloat16x8_t t0t1_4h = vcvtq_low_bf16_f32(t0); + bfloat16x8_t t0t1_8h = vcvtq_high_bf16_f32(t0t1_4h, t1); + + bfloat16x8_t t2t3_4h = vcvtq_low_bf16_f32(t2); + bfloat16x8_t t2t3_8h = vcvtq_high_bf16_f32(t2t3_4h, t3); + + vst1q_bf16(&D[0], t0t1_8h); + vst1q_bf16(&D[8], t2t3_8h); + + D += 16; + b += 4; + b_inc += 4; + } + + if ((CountN & 2) != 0) { + float32x2_t t0 = {0x0, 0x0}; + float32x2_t t1 = {0x0, 0x0}; + float32x2_t t2 = {0x0, 0x0}; + float32x2_t t3 = {0x0, 0x0}; + + if (y >= 4) { + t0 = vld1_f32(&b[ldb * 0]); + t1 = vld1_f32(&b[ldb * 1]); + t2 = vld1_f32(&b[ldb * 2]); + t3 = vld1_f32(&b[ldb * 3]); + } else { + switch (y) { + case 3: + t0 = vld1_f32(&b[ldb * 0]); + t1 = vld1_f32(&b[ldb * 1]); + t2 = vld1_f32(&b[ldb * 2]); + break; + case 2: + t0 = vld1_f32(&b[ldb * 0]); + t1 = vld1_f32(&b[ldb * 1]); + break; + case 1: + t0 = vld1_f32(&b[ldb * 0]); + break; + } + } + + float32x2x2_t z0 = vzip_f32(t0, t2); + float32x2x2_t z1 = vzip_f32(t1, t3); + float32x2x2_t o0 = vzip_f32(z0.val[0], z1.val[0]); + float32x2x2_t o1 = vzip_f32(z0.val[1], z1.val[1]); + + float32x4_t tt0 = vcombine_f32(o0.val[0], o0.val[1]); + float32x4_t tt1 = vcombine_f32(o1.val[0], o1.val[1]); + + bfloat16x8_t t_4h = vcvtq_low_bf16_f32(tt0); + bfloat16x8_t t_8h = vcvtq_high_bf16_f32(t_4h, tt1); + + vst1q_bf16(&D[0], t_8h); + + D += 8; + b += 2; + b_inc += 2; + } + if ((CountN & 1) != 0) { + float a = 0.0f; + float b = 0.0f; + float c = 0.0f; + float d = 0.0f; + + if (y >= 4) { + a = *(float*)(&B[ldb * 0 + b_inc]); + b = *(float*)(&B[ldb * 1 + b_inc]); + c = *(float*)(&B[ldb * 2 + b_inc]); + d = *(float*)(&B[ldb * 3 + b_inc]); + } else { + switch (y) { + case 3: + a = *(float*)(&B[ldb * 0 + b_inc]); + b = *(float*)(&B[ldb * 1 + b_inc]); + c = *(float*)(&B[ldb * 2 + b_inc]); + break; + case 2: + a = *(float*)(&B[ldb * 0 + b_inc]); + b = *(float*)(&B[ldb * 1 + b_inc]); + break; + case 1: + a = *(float*)(&B[ldb * 0 + b_inc]); + break; + } + } + + float32x2_t t0 = {a, 0x0}; + float32x2_t t1 = {b, 0x0}; + float32x2_t t2 = {c, 0x0}; + float32x2_t t3 = {d, 0x0}; + + float32x2x2_t z0 = vzip_f32(t0, t2); + float32x2x2_t z1 = vzip_f32(t1, t3); + float32x2x2_t o0 = vzip_f32(z0.val[0], z1.val[0]); + float32x2x2_t o1 = vzip_f32(z0.val[1], z1.val[1]); + + float32x4_t tt0 = vcombine_f32(o0.val[0], o0.val[1]); + float32x4_t tt1 = vcombine_f32(o1.val[0], o1.val[1]); + + bfloat16x8_t t_4h = vcvtq_low_bf16_f32(tt0); + bfloat16x8_t t_8h = vcvtq_high_bf16_f32(t_4h, tt1); + + vst1q_bf16(&D[0], t_8h); + + D += 8; + b += 1; + b_inc += 1; + } + B += 4 * ldb; + y -= 4; + } + } +} + +template +void +MlasSBGemmConvertPackB( + bfloat16_t* PackedB, const float* B, size_t ldb, size_t CountN, size_t CountK +) +{ + const auto* dispatch = MlasSBGemmGetDispatch(); + if (dispatch == nullptr) return; + + const auto PackedN = dispatch->PackedN; + + const size_t AlignedN = (CountN + PackedN - 1) & ~(PackedN - 1); + + // + // Step through each slice of matrix B along the K dimension. + // + size_t K_block_size; + constexpr MLAS_SBGEMM_STRIDES Strides = KernelType::Strides; + + for (size_t k = 0; k < CountK; k += K_block_size) { + K_block_size = std::min(CountK - k, Strides.K); + + MlasSBGemmConvertCopyPackB((bfloat16_t*)PackedB, B + k * ldb, ldb, CountN, K_block_size); + PackedB = (bfloat16_t*)PackedB + AlignedN * K_block_size; + } +} + +template <> +MLAS_FORCEINLINE void +MlasSBGemmKernel(size_t CountM, size_t CountN, size_t CountK, const float* A, size_t lda, const bfloat16_t* B, float* C, size_t ldc, const float* Bias, const bool ZeroMode) +{ + while (CountM > 0) { + size_t RowsHandled; + if (ZeroMode) { + RowsHandled = MlasSbgemmKernelZero(A, B, C, CountK, CountM, CountN, lda, ldc, Bias); + } else { + RowsHandled = MlasSbgemmKernelAdd(A, B, C, CountK, CountM, CountN, lda, ldc, Bias); + } + C += ldc * RowsHandled; + A += lda * RowsHandled; + CountM -= RowsHandled; + } +} + +const MLAS_SBGEMM_DISPATCH MlasSBGemmDispatchNeon = { + MlasSBGemmOperation, + MlasSBGemmConvertPackB, + MLAS_SBGEMM_KERNEL_NEON::PackedK, + MLAS_SBGEMM_KERNEL_NEON::PackedN, + MLAS_SBGEMM_KERNEL_NEON::KernelMaxM, + 32 // kernel may read beyond buffer end by 32 bytes +}; +#endif // defined(__aarch64__) && defined(__linux__) diff --git a/onnxruntime/core/providers/cpu/math/matmul.cc b/onnxruntime/core/providers/cpu/math/matmul.cc index ec395cf018f5e..583ee759cc2e6 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.cc +++ b/onnxruntime/core/providers/cpu/math/matmul.cc @@ -6,7 +6,6 @@ #include "core/providers/cpu/math/matmul_helper.h" #include "core/util/math.h" #include "core/util/math_cpuonly.h" -#include "core/mlas/inc/mlas.h" namespace onnxruntime { @@ -125,6 +124,44 @@ Status MatMul::Compute(OpKernelContext* ctx) const { return Status::OK(); } +#if defined(__aarch64__) && defined(__linux__) +bool GemmPackBBfloat16(AllocatorPtr& alloc, + const Tensor& tensor_b, + bool trans_b, + IAllocatorUniquePtr& packed_b, + size_t& packed_b_size, + TensorShape& b_shape) { + // Only handle the common case of a 2D weight matrix. Additional matrices + // could be handled by stacking the packed buffers. + if (tensor_b.Shape().NumDimensions() != 2) { + return false; + } + + b_shape = tensor_b.Shape(); + + const size_t K = trans_b ? static_cast(b_shape[1]) : static_cast(b_shape[0]); + const size_t N = trans_b ? static_cast(b_shape[0]) : static_cast(b_shape[1]); + + packed_b_size = MlasSBGemmPackBSize(N, K); + if (packed_b_size == 0) { + return false; + } + + packed_b = IAllocator::MakeUniquePtr(alloc, packed_b_size, true); + auto* packed_b_data = packed_b.get(); + + // Initialize memory to 0 as there could be some padding associated with pre-packed + // buffer memory and we don not want it uninitialized and generate different hashes + // if and when we try to cache this pre-packed buffer for sharing between sessions. + memset(packed_b_data, 0, packed_b_size); + MlasSBGemmConvertPackB(N, + K, + tensor_b.Data(), + trans_b ? K : N, + packed_b_data); + return true; +} +#endif Status MatMul::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc, /*out*/ bool& is_packed, @@ -134,7 +171,24 @@ Status MatMul::PrePack(const Tensor& tensor, int input_idx, /*out*/ Alloc // only pack Matrix B if (input_idx == 1) { size_t packed_b_size; - is_packed = GemmPackBFp32(alloc, tensor, trans_b_attr_ != 0, packed_b_, packed_b_size, b_shape_); +#if defined(__aarch64__) && defined(__linux__) + size_t dim1 = 0; + size_t dim2 = 0; + TensorShape b_shape = tensor.Shape(); + + if (b_shape.NumDimensions() == 2) { + dim1 = static_cast(b_shape[0]); + dim2 = static_cast(b_shape[1]); + } + + if (use_fastmath_mode_ && (trans_b_attr_ == 0) && ((dim1 * dim2) >= kFastMathModeKernelsizeThreshold)) { + is_packed = GemmPackBBfloat16(alloc, tensor, trans_b_attr_ != 0, packed_b_, packed_b_size, b_shape_); + } else +#endif + { + is_packed = GemmPackBFp32(alloc, tensor, trans_b_attr_ != 0, packed_b_, packed_b_size, b_shape_); + } + bool share_prepacked_weights = (prepacked_weights != nullptr); if (is_packed && share_prepacked_weights) { prepacked_weights->buffers_.push_back(std::move(packed_b_)); @@ -186,22 +240,40 @@ Status MatMul::Compute(OpKernelContext* ctx) const { const size_t K = static_cast(helper.K()); const size_t lda = helper.Lda(trans_a); const size_t ldb = helper.Ldb(trans_b); - - std::vector data(max_len); - for (size_t i = 0; i < max_len; i++) { - data[i].BIsPacked = bool(packed_b_); - data[i].A = a_data + helper.LeftOffsets()[i]; - data[i].lda = lda; - data[i].B = data[i].BIsPacked ? (float*)packed_b_.get() : b_data + helper.RightOffsets()[i]; - data[i].ldb = ldb; - data[i].C = y_data + helper.OutputOffsets()[i]; - data[i].ldc = N; - data[i].alpha = alpha_attr_; - data[i].beta = 0.0f; +#if defined(__aarch64__) && defined(__linux__) + if (use_fastmath_mode_ && !trans_b && ((N * K) >= kFastMathModeKernelsizeThreshold)) { + std::vector data(max_len); + for (size_t i = 0; i < max_len; i++) { + data[i].BIsfp32 = !(bool(packed_b_)); + data[i].AIsfp32 = true; + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].B = data[i].BIsfp32 ? b_data + helper.RightOffsets()[i] : (float*)packed_b_.get(); + data[i].ldb = ldb; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; + data[i].Bias = nullptr; + data[i].OutputProcessor = nullptr; + } + MlasSBGemmBatch(M, N, K, max_len, data.data(), thread_pool); + } else +#endif + { + std::vector data(max_len); + for (size_t i = 0; i < max_len; i++) { + data[i].BIsPacked = bool(packed_b_); + data[i].A = a_data + helper.LeftOffsets()[i]; + data[i].lda = lda; + data[i].B = data[i].BIsPacked ? (float*)packed_b_.get() : b_data + helper.RightOffsets()[i]; + data[i].ldb = ldb; + data[i].C = y_data + helper.OutputOffsets()[i]; + data[i].ldc = N; + data[i].alpha = alpha_attr_; + data[i].beta = 0.0f; + } + MlasGemmBatch(trans_a ? CblasTrans : CblasNoTrans, trans_b ? CblasTrans : CblasNoTrans, + M, N, K, data.data(), max_len, thread_pool); } - MlasGemmBatch(trans_a ? CblasTrans : CblasNoTrans, trans_b ? CblasTrans : CblasNoTrans, - M, N, K, data.data(), max_len, thread_pool); - return Status::OK(); } diff --git a/onnxruntime/core/providers/cpu/math/matmul.h b/onnxruntime/core/providers/cpu/math/matmul.h index b960fa4fb0587..b9bbe36583879 100644 --- a/onnxruntime/core/providers/cpu/math/matmul.h +++ b/onnxruntime/core/providers/cpu/math/matmul.h @@ -4,6 +4,8 @@ #pragma once #include "core/framework/op_kernel.h" +#include "core/mlas/inc/mlas.h" +#include "core/session/onnxruntime_session_options_config_keys.h" namespace onnxruntime { @@ -27,6 +29,11 @@ class MatMul final : public OpKernel { info.GetAttrOrDefault("transBatchB", &trans_batch_b_attr, 0); trans_batch_a_ = trans_batch_a_attr != 0; trans_batch_b_ = trans_batch_b_attr != 0; + +#if defined(__aarch64__) && defined(__linux__) + auto config_ops = info.GetConfigOptions().GetConfigEntry(kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16); + use_fastmath_mode_ = (config_ops == "1") && MlasBf16AccelerationSupported(); +#endif } Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc, @@ -48,6 +55,14 @@ class MatMul final : public OpKernel { int64_t trans_b_attr_; bool trans_batch_a_; bool trans_batch_b_; + +#if defined(__aarch64__) && defined(__linux__) + // fastmath mode state + bool use_fastmath_mode_; + // sbgemm kernel is implemented as 8x8 blocks with weights pre-packed to 4 blocks of 4x2 + // so a minimum of 32 elements is defined to outweigh the additional prepacking overhead + const size_t kFastMathModeKernelsizeThreshold = 32; +#endif }; } // namespace onnxruntime diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp index 76b9b308fe98f..45ff25c4fdd90 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorCast.cpp @@ -29,7 +29,7 @@ class DmlOperatorCast : public DmlOperator castDesc.OutputTensor = outputDescs.data(); DML_OPERATOR_DESC opDesc = { DML_OPERATOR_CAST, &castDesc }; - + SetDmlOperatorDesc(opDesc, kernelInfo); } @@ -49,5 +49,6 @@ class DmlOperatorCast : public DmlOperator DML_OP_DEFINE_CREATION_FUNCTION(Cast, DmlOperatorCast); DML_OP_DEFINE_CREATION_FUNCTION(CastLike15, DmlOperatorCast); +DML_OP_DEFINE_CREATION_FUNCTION(CastLike19, DmlOperatorCast); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp index ab8ddbfe91bf0..16bb10f004f91 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorElementWise.cpp @@ -487,7 +487,7 @@ class DmlOperatorElementwisePow : public DmlOperator Initialize(kernelInfo, kernelInputIndices, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0)); std::vector inputDescs = GetDmlInputDescs(); - std::vector outputDescs = GetDmlOutputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); DML_ELEMENT_WISE_CONSTANT_POW_OPERATOR_DESC opDesc = {}; opDesc.InputTensor = &inputDescs[0]; @@ -497,11 +497,11 @@ class DmlOperatorElementwisePow : public DmlOperator SetDmlOperatorDesc({ DML_OPERATOR_ELEMENT_WISE_CONSTANT_POW, &opDesc}, kernelInfo); } else - { + { Initialize(kernelInfo, std::nullopt, std::nullopt, kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0)); std::vector inputDescs = GetDmlInputDescs(); - std::vector outputDescs = GetDmlOutputDescs(); + std::vector outputDescs = GetDmlOutputDescs(); DML_ELEMENT_WISE_POW_OPERATOR_DESC opDesc = {}; opDesc.InputTensor = &inputDescs[0]; @@ -519,13 +519,16 @@ class DmlOperatorElementwiseQLinear : public DmlOperator public: DmlOperatorElementwiseQLinear(const MLOperatorKernelCreationContext& kernelInfo) : DmlOperator(kernelInfo) { - ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() == 3); + + ML_CHECK_VALID_ARGUMENT(kernelInfo.GetInputCount() >= 2); ML_CHECK_VALID_ARGUMENT(kernelInfo.GetOutputCount() == 1); + Initialize(kernelInfo, std::nullopt, std::nullopt); + std::vector outputShape = kernelInfo.GetTensorShapeDescription().GetOutputTensorShape(0); const uint32_t outputShapeDimCount = gsl::narrow_cast(outputShape.size()); - - Initialize(kernelInfo, std::nullopt, std::nullopt); + const DML_TENSOR_DATA_TYPE inputDataType = m_inputTensorDescs[0].GetDmlDataType(); + bool hasZeroPointTensor = kernelInfo.IsInputValid(2); uint32_t axis = 0; @@ -541,9 +544,14 @@ class DmlOperatorElementwiseQLinear : public DmlOperator axis = Dml::HandleNegativeAxis(signedAxis, outputShapeDimCount, /*validateAxis*/ false); } - // Explicitly reshape each of the inputs after the first input (scale and zero point tensors). + // Explicitly reshape each of the inputs after the first input (scale tensor and optional zero point tensor). for (uint32_t index = 1, inputCount = gsl::narrow_cast(m_inputTensorDescs.size()); index < inputCount; ++index) { + if (!kernelInfo.IsInputValid(index)) + { + continue; + } + auto edgeDesc = kernelInfo.GetInputEdgeDescription(index); assert(edgeDesc.edgeType == MLOperatorEdgeType::Tensor); @@ -587,12 +595,8 @@ class DmlOperatorElementwiseQLinear : public DmlOperator TOperatorDesc opDesc = {}; opDesc.InputTensor = &inputDescs[0]; opDesc.ScaleTensor = &inputDescs[1]; - opDesc.ZeroPointTensor = &inputDescs[2]; + opDesc.ZeroPointTensor = hasZeroPointTensor ? &inputDescs[2] : nullptr; opDesc.OutputTensor = &outputDescs[0]; - - TryConvertTensorToBroadcastScalar(kernelInfo, opDesc.ScaleTensor, 1); - TryConvertTensorToBroadcastScalar(kernelInfo, opDesc.ZeroPointTensor, 2); - SetDmlOperatorDesc({ApiTraits::OperatorDescTraits::Type, &opDesc}, kernelInfo); } }; diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp index a014db5adbe61..b243f7e741a70 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/DmlOperatorPadding.cpp @@ -51,6 +51,12 @@ class DmlOperatorPadding : public DmlOperator, public PaddingHelper { mode = DML_PADDING_MODE_REFLECTION; } +#if DML_TARGET_VERSION >= 0x6300 + else if (modeString == AttrValue::Wrap) + { + mode = DML_PADDING_MODE_WRAP; + } +#endif else { ML_INVALID_ARGUMENT("Unknown Pad mode attribute."); @@ -116,5 +122,6 @@ DML_OP_DEFINE_CREATION_FUNCTION(Pad7, VersionedKernel); DML_OP_DEFINE_CREATION_FUNCTION(Pad11, VersionedKernel); DML_OP_DEFINE_CREATION_FUNCTION(Pad13, VersionedKernel); DML_OP_DEFINE_CREATION_FUNCTION(Pad18, VersionedKernel); +DML_OP_DEFINE_CREATION_FUNCTION(Pad19, VersionedKernel); } // namespace Dml diff --git a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp index 15a8051953c79..7b53a1102c5a7 100644 --- a/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp +++ b/onnxruntime/core/providers/dml/DmlExecutionProvider/src/Operators/OperatorRegistration.cpp @@ -358,6 +358,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(Pad7); DML_OP_EXTERN_CREATION_FUNCTION(Pad11); DML_OP_EXTERN_CREATION_FUNCTION(Pad13); DML_OP_EXTERN_CREATION_FUNCTION(Pad18); +DML_OP_EXTERN_CREATION_FUNCTION(Pad19); DML_OP_EXTERN_CREATION_FUNCTION(SpaceToDepth); DML_OP_EXTERN_CREATION_FUNCTION(DepthToSpace); DML_OP_EXTERN_CREATION_FUNCTION(Sqrt); @@ -436,6 +437,7 @@ DML_OP_EXTERN_CREATION_FUNCTION(FusedMatMul); DML_OP_EXTERN_CREATION_FUNCTION(FusedMatMulActivation); DML_OP_EXTERN_CREATION_FUNCTION(Cast); DML_OP_EXTERN_CREATION_FUNCTION(CastLike15); +DML_OP_EXTERN_CREATION_FUNCTION(CastLike19); DML_OP_EXTERN_CREATION_FUNCTION(MemcpyFromHost); DML_OP_EXTERN_CREATION_FUNCTION(MemcpyToHost); DML_OP_EXTERN_CREATION_FUNCTION(TopK7); @@ -746,6 +748,11 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO_VER( 11, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728 {REG_INFO_VER( 13, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2) /*pads, value*/)}, // https://microsoft.visualstudio.com/OS/_workitems/edit/26007728 {REG_INFO_VER( 18, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*pads, value, axes*/)}, + +#if DML_TARGET_VERSION >= 0x6300 + {REG_INFO_VER( 19, Pad, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1, 2, 3) /*pads, value, axes*/)}, +#endif + {REG_INFO( 7, SpaceToDepth, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO( 13, SpaceToDepth, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO( 7, DepthToSpace, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, @@ -785,6 +792,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO_COPY(13, Identity, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY(14, Identity, typeNameListDefaultV, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY(16, Identity, typeNameListDefaultV, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, + {REG_INFO_COPY(19, Identity, typeNameListDefaultV, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY( 7, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY( 9, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, {REG_INFO_COPY(11, Flatten, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported)}, @@ -798,6 +806,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO_COPY( 7, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, {REG_INFO_COPY(13, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, {REG_INFO_COPY(14, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, + {REG_INFO_COPY(19, Reshape, typeNameListDefault, supportedTypeListAllScalars, DmlGraphSupport::Supported, requiredConstantCpuInputs(1))}, // Elementwise {REG_INFO( 7, Sqrt, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, @@ -857,8 +866,10 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 7, Affine, typeNameListDefault, supportedTypeListFloat16to32, DmlGraphSupport::Supported)}, {REG_INFO( 10, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear, DmlGraphSupport::Supported)}, {REG_INFO( 13, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear, DmlGraphSupport::Supported)}, + {REG_INFO( 19, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear19, DmlGraphSupport::Supported)}, {REG_INFO( 10, DequantizeLinear, typeNameListDefault, supportedTypeListDequantizeLinear, DmlGraphSupport::Supported)}, {REG_INFO( 13, DequantizeLinear, typeNameListDefault, supportedTypeListDequantizeLinear, DmlGraphSupport::Supported)}, + {REG_INFO( 19, DequantizeLinear, typeNameListTwo, supportedTypeListDequantizeLinear19, DmlGraphSupport::Supported)}, {REG_INFO_MS( 1, QuantizeLinear, typeNameListTwo, supportedTypeListQuantizeLinear19, DmlGraphSupport::Supported)}, {REG_INFO_MS( 1, DequantizeLinear, typeNameListTwo, supportedTypeListDequantizeLinear19, DmlGraphSupport::Supported)}, {REG_INFO( 9, Sign, typeNameListDefault, supportedTypeListFloat16to32Ints8to64, DmlGraphSupport::Supported)}, @@ -943,6 +954,7 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 7, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison7, DmlGraphSupport::Supported)}, {REG_INFO( 11, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported)}, {REG_INFO( 13, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported)}, + {REG_INFO( 19, Equal, typeNameListLogicalComparison, supportedTypeListLogicalComparison9, DmlGraphSupport::Supported)}, {REG_INFO( 7, Not, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)}, {REG_INFO( 7, And, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)}, {REG_INFO( 7, Or, typeNameListDefault, supportedTypeListBool, DmlGraphSupport::Supported)}, @@ -1004,7 +1016,9 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 7, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, {REG_INFO( 9, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, {REG_INFO( 13, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, + {REG_INFO( 19, Cast, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, {REG_INFO_VER( 15, CastLike, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, + {REG_INFO_VER( 19, CastLike, typeNameListTwo, supportedTypeListCast, DmlGraphSupport::Supported)}, {REG_INFO( 7, MemcpyFromHost, typeNameListDefault, supportedTypeListAll)}, {REG_INFO( 7, MemcpyToHost, typeNameListDefault, supportedTypeListAll)}, {REG_INFO_VER( 7, TopK, typeNameListTopK, supportedTypeListTopK, DmlGraphSupport::Supported)}, @@ -1015,8 +1029,10 @@ constexpr static OperatorRegistrationInformation operatorRegistrationInformation {REG_INFO( 7, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)}, {REG_INFO( 13, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)}, {REG_INFO( 15, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)}, + {REG_INFO( 19, Shape, typeNameShape, supportedTypeListShape, DmlGraphSupport::NotSupported)}, {REG_INFO( 7, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)}, {REG_INFO( 13, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)}, + {REG_INFO( 19, Size, typeNameSize, supportedTypeListSize, DmlGraphSupport::NotSupported)}, {REG_INFO_DYNAMIC_OUTPUTS( 9, NonZero, typeNameListDefault, supportedTypeListNonZero, DmlGraphSupport::NotSupported)}, {REG_INFO_DYNAMIC_OUTPUTS(13, NonZero, typeNameListDefault, supportedTypeListNonZero, DmlGraphSupport::NotSupported)}, diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h index e3df1d00b3e8a..9c5d021f52b36 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/Attributes.h @@ -149,5 +149,6 @@ namespace AttrValue static constexpr const char* NearestNeighbor = "NN"; static constexpr const char* NotSet = "NOTSET"; static constexpr const char* Reflect = "reflect"; + static constexpr const char* Wrap = "wrap"; } // namespace AttrValue diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h index 0e0e6bb1eaf5c..d4b44f6fa8a9d 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.h @@ -1589,6 +1589,7 @@ using ShapeInferenceHelper_Pad7 = VersionedOpsetHelper; using ShapeInferenceHelper_Pad11 = VersionedOpsetHelper; using ShapeInferenceHelper_Pad13 = VersionedOpsetHelper; using ShapeInferenceHelper_Pad18 = VersionedOpsetHelper; +using ShapeInferenceHelper_Pad19 = VersionedOpsetHelper; using ShapeInferenceHelper_SpaceToDepth = SpaceToDepthHelper; using ShapeInferenceHelper_DepthToSpace = DepthToSpaceHelper; @@ -1606,6 +1607,7 @@ using ShapeInferenceHelper_Expand = ExpandHelper; using ShapeInferenceHelper_Reshape7 = ReshapeHelper; using ShapeInferenceHelper_Reshape13 = ReshapeHelper; using ShapeInferenceHelper_Reshape14 = ReshapeHelper; +using ShapeInferenceHelper_Reshape19 = ReshapeHelper; using ShapeInferenceHelper_ConstantOfShape = ConstantOfShapeHelper; using ShapeInferenceHelper_Tile = TileHelper; using ShapeInferenceHelper_Resize10 = VersionedOpsetHelper; @@ -1725,6 +1727,7 @@ using ShapeInferenceHelper_Identity7 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Identity13 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Identity14 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Identity16 = GetOutputShapeAsInputShapeHelper; +using ShapeInferenceHelper_Identity19 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_MatMul = MatMulHelper; using ShapeInferenceHelper_MatMulInteger = MatMulHelper; using ShapeInferenceHelper_QLinearMatMul = QLinearMatMulHelper; @@ -1750,6 +1753,7 @@ using ShapeInferenceHelper_CumSum14 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_Range = RangeHelper; using ShapeInferenceHelper_CastLike15 = GetOutputShapeAsInputShapeHelper; +using ShapeInferenceHelper_CastLike19 = GetOutputShapeAsInputShapeHelper; using ShapeInferenceHelper_DmlFusedConv = ConvHelper; using ShapeInferenceHelper_DmlFusedConvTranspose = ConvTransposeHelper; diff --git a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h index 8438bc620712c..57cb009b72ebc 100644 --- a/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h +++ b/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorVersions.h @@ -413,6 +413,17 @@ namespace OperatorHelper namespace OnnxOperatorSet19 { static const int sc_sinceVer_AveragePool = 19; + static const int sc_sinceVer_Pad = 19; + static const int sc_sinceVer_Cast = 19; + static const int sc_sinceVer_CastLike = 19; + static const int sc_sinceVer_Constant = 19; + static const int sc_sinceVer_Equal = 19; + static const int sc_sinceVer_Identity = 19; + static const int sc_sinceVer_QuantizeLinear = 19; + static const int sc_sinceVer_DequantizeLinear = 19; + static const int sc_sinceVer_Reshape = 19; + static const int sc_sinceVer_Shape = 19; + static const int sc_sinceVer_Size = 19; } namespace MsftOperatorSet1 diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index c2ff2ebc39e13..af9658271d210 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -98,6 +98,7 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomai class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Erf); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Sigmoid); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Sigmoid); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, HardSigmoid); class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 6, 12, Log); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, Log); @@ -392,6 +393,7 @@ std::unique_ptr RegisterKernels() { KERNEL_CREATE_INFO(13, Erf), KERNEL_CREATE_INFO_VERSIONED(6, 12, Sigmoid), KERNEL_CREATE_INFO(13, Sigmoid), + KERNEL_CREATE_INFO(6, HardSigmoid), KERNEL_CREATE_INFO_VERSIONED(6, 12, Log), KERNEL_CREATE_INFO(13, Log), diff --git a/onnxruntime/core/providers/js/operators/unary.cc b/onnxruntime/core/providers/js/operators/unary.cc index 78563d30b0136..9082527e3a8d7 100644 --- a/onnxruntime/core/providers/js/operators/unary.cc +++ b/onnxruntime/core/providers/js/operators/unary.cc @@ -77,6 +77,9 @@ JSEP_KERNEL_IMPL(Sigmoid, Sigmoid) JSEP_ELEMENTWISE_VERSIONED_KERNEL(Sigmoid, 6, 12, Sigmoid) JSEP_ELEMENTWISE_KERNEL(Sigmoid, 13, Sigmoid) +JSEP_CLASS_IMPL_ATTRIBUTE_FLOAT_2_DEFAULT(HardSigmoid, HardSigmoid, alpha, 0.2, beta, 0.5) +JSEP_ELEMENTWISE_KERNEL(HardSigmoid, 6, HardSigmoid) + JSEP_KERNEL_IMPL(Log, Log) JSEP_ELEMENTWISE_VERSIONED_KERNEL(Log, 6, 12, Log) JSEP_ELEMENTWISE_KERNEL(Log, 13, Log) diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc index 193e4f5ff2a31..973b81d337c81 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc @@ -17,6 +17,7 @@ #include "core/framework/endian_utils.h" #include "core/common/logging/capture.h" #include "core/providers/qnn/builder/onnx_ctx_model_helper.h" +#include "core/providers/qnn/builder/qnn_configs_helper.h" #ifdef _WIN32 #include @@ -329,9 +330,37 @@ Status QnnBackendManager::CreateDevice() { return Status::OK(); } + qnn::QnnConfigsBuilder device_configs_builder(QNN_DEVICE_CONFIG_INIT, + {}); + if (qnn_backend_type_ == QnnBackendType::HTP) { + // Set SoC Model. The *enum* Qnn_SocModel_t is deprecated and will not be updated in the future. Therefore, + // must use the latest SDK documentation to get the SoC model of the latest HW. + if (soc_model_ != QNN_SOC_MODEL_UNKNOWN) { + QnnHtpDevice_CustomConfig_t& custom_config = device_configs_builder.PushCustomConfig(); + custom_config.option = QNN_HTP_DEVICE_CONFIG_OPTION_SOC; + custom_config.socModel = soc_model_; + + QnnDevice_Config_t& device_config = device_configs_builder.PushConfig(); + device_config.option = QNN_DEVICE_CONFIG_OPTION_CUSTOM; + device_config.customConfig = &custom_config; + } + + // Set the minimum HTP architecture. The driver will use ops that are compatible with this minimum architecture. + if (htp_arch_ != QNN_HTP_DEVICE_ARCH_NONE) { + QnnHtpDevice_CustomConfig_t& custom_config = device_configs_builder.PushCustomConfig(); + custom_config.option = QNN_HTP_DEVICE_CONFIG_OPTION_ARCH; + custom_config.arch.arch = htp_arch_; + custom_config.arch.deviceId = device_id_; + + QnnDevice_Config_t& device_config = device_configs_builder.PushConfig(); + device_config.option = QNN_DEVICE_CONFIG_OPTION_CUSTOM; + device_config.customConfig = &custom_config; + } + } + LOGS_DEFAULT(INFO) << "Create device."; if (nullptr != qnn_interface_.deviceCreate) { - auto result = qnn_interface_.deviceCreate(log_handle_, nullptr, &device_handle_); + auto result = qnn_interface_.deviceCreate(log_handle_, device_configs_builder.GetQnnConfigs(), &device_handle_); if (QNN_SUCCESS != result) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Failed to create device. Error: ", result); } diff --git a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h index 58f207efb9e95..f7b8947ab84bb 100644 --- a/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h +++ b/onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h @@ -17,6 +17,7 @@ #include #include "HTP/QnnHtpDevice.h" #include "QnnLog.h" +#include "QnnTypes.h" #include "System/QnnSystemInterface.h" #include "core/common/status.h" #include "core/common/logging/logging.h" @@ -35,13 +36,19 @@ class QnnBackendManager { uint32_t rpc_control_latency, HtpPerformanceMode htp_performance_mode, ContextPriority context_priority, - std::string&& qnn_saver_path) + std::string&& qnn_saver_path, + uint32_t device_id, + QnnHtpDevice_Arch_t htp_arch, + uint32_t soc_model) : backend_path_(backend_path), profiling_level_(profiling_level), rpc_control_latency_(rpc_control_latency), htp_performance_mode_(htp_performance_mode), context_priority_(context_priority), - qnn_saver_path_(qnn_saver_path) { + qnn_saver_path_(qnn_saver_path), + device_id_(device_id), + htp_arch_(htp_arch), + soc_model_(soc_model) { } ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QnnBackendManager); @@ -233,6 +240,9 @@ class QnnBackendManager { #endif const std::string qnn_saver_path_; uint32_t htp_power_config_client_id_ = 0; + uint32_t device_id_ = 0; + QnnHtpDevice_Arch_t htp_arch_ = QNN_HTP_DEVICE_ARCH_NONE; + uint32_t soc_model_ = QNN_SOC_MODEL_UNKNOWN; }; } // namespace qnn diff --git a/onnxruntime/core/providers/qnn/builder/qnn_configs_helper.h b/onnxruntime/core/providers/qnn/builder/qnn_configs_helper.h new file mode 100644 index 0000000000000..9dd9bbaa08d64 --- /dev/null +++ b/onnxruntime/core/providers/qnn/builder/qnn_configs_helper.h @@ -0,0 +1,90 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +namespace onnxruntime { +namespace qnn { + +/** + * Helper class for building a null-terminated list of QNN configurations. + * A QNN configuration consists of multiple objects with references to each other. This + * class ensures that all configuration objects have the same lifetime, so that they remain valid + * across calls to qnn_interface.xxxCreate(). + */ +template +class QnnConfigsBuilder { + public: + /** + * Initializes the config build. Provide the initial/default value for each config struct type. + * \param base_config_init The initial/default value for objects of type BaseConfigType. + * \param custom_config_init The initial/default value for objects of type CustomConfigType. + */ + QnnConfigsBuilder(BaseConfigType base_config_init, CustomConfigType custom_config_init) + : base_config_init_(std::move(base_config_init)), custom_config_init_(std::move(custom_config_init)) {} + + /** + * Returns a pointer to the beginning of a null-terminated array of QNN base configurations. + * This result is typically passed to QNN's xxxCreate() APIs. + * + * \return Pointer to null-terminated BaseConfigType* array. + */ + const BaseConfigType** GetQnnConfigs() { + if (config_ptrs_.empty()) { + return nullptr; + } + + if (!IsNullTerminated()) { + config_ptrs_.push_back(nullptr); + } + + return config_ptrs_.data(); + } + + /** + * Creates and returns a reference to a new custom QNN configuration object. The object is initialized to + * the QNN recommended default value. The caller is meant to override fields in this object. + * + * \return A reference to a default CustomConfigType object. + */ + CustomConfigType& PushCustomConfig() { + custom_configs_.push_back(custom_config_init_); + return custom_configs_.back(); + } + + /** + * Creates and returns a reference to a new QNN configuration object. The object is initialized to + * the QNN recommended default value. The caller is meant to override fields in this object. + * + * \return A reference to a default BaseConfigType object. + */ + BaseConfigType& PushConfig() { + configs_.push_back(base_config_init_); + BaseConfigType& config = configs_.back(); + + // Add pointer to this new config to the list of config pointers. + if (IsNullTerminated()) { + config_ptrs_.back() = &config; // Replace last nullptr entry. + } else { + config_ptrs_.push_back(&config); + } + + return config; + } + + private: + bool IsNullTerminated() const { + return !config_ptrs_.empty() && config_ptrs_.back() == nullptr; + } + + BaseConfigType base_config_init_; + CustomConfigType custom_config_init_; + InlinedVector custom_configs_; + InlinedVector configs_; + InlinedVector config_ptrs_; +}; + +} // namespace qnn +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_graph_configs_helper.cc b/onnxruntime/core/providers/qnn/builder/qnn_graph_configs_helper.cc deleted file mode 100644 index 63aa01b48e7e2..0000000000000 --- a/onnxruntime/core/providers/qnn/builder/qnn_graph_configs_helper.cc +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/qnn/builder/qnn_graph_configs_helper.h" - -#include "HTP/QnnHtpGraph.h" - -namespace onnxruntime { -namespace qnn { - -const QnnGraph_Config_t** QnnGraphConfigsBuilder::GetQnnGraphConfigs() { - if (graph_config_ptrs_.empty()) { - return nullptr; - } - - if (!IsNullTerminated()) { - graph_config_ptrs_.push_back(nullptr); - } - - return graph_config_ptrs_.data(); -} - -QnnHtpGraph_CustomConfig_t& QnnGraphConfigsBuilder::PushHtpGraphCustomConfig() { - htp_custom_graph_configs_.push_back(QNN_HTP_GRAPH_CUSTOM_CONFIG_INIT); - return htp_custom_graph_configs_.back(); -} - -QnnGraph_Config_t& QnnGraphConfigsBuilder::PushGraphConfig() { - graph_configs_.push_back(QNN_GRAPH_CONFIG_INIT); - QnnGraph_Config_t& config = graph_configs_.back(); - - // Add pointer to this new graph config to the list of graph config pointers. - if (IsNullTerminated()) { - graph_config_ptrs_.back() = &config; // Replace last nullptr entry. - } else { - graph_config_ptrs_.push_back(&config); - } - - return config; -} - -} // namespace qnn -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/builder/qnn_graph_configs_helper.h b/onnxruntime/core/providers/qnn/builder/qnn_graph_configs_helper.h deleted file mode 100644 index 8c4928fdacbc4..0000000000000 --- a/onnxruntime/core/providers/qnn/builder/qnn_graph_configs_helper.h +++ /dev/null @@ -1,56 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include - -#include "HTP/QnnHtpGraph.h" - -namespace onnxruntime { -namespace qnn { - -/** - * Helper class for building a null-terminated list of QNN Graph configurations. - * A QNN configuration consists of multiple objects with references to each other. This - * class ensures that all configuration objects have the same lifetime, so that they remain valid - * across the call to graphCreate(). - */ -class QnnGraphConfigsBuilder { - public: - /** - * Returns a pointer to the beginning of a null-terminated array of QNN Graph configurations. - * This result is passed QNN's graphCreate() API. - * - * \return Pointer to null-terminated QnnGraph_Config_t* array. - */ - const QnnGraph_Config_t** GetQnnGraphConfigs(); - - /** - * Creates and returns a reference to a new HTP graph configuration object. The object is initialized to - * the QNN recommended default value. The caller is meant to override fields in this object. - * - * \return A reference to a default QnnHtpGraph_CustomConfig_t object. - */ - QnnHtpGraph_CustomConfig_t& PushHtpGraphCustomConfig(); - - /** - * Creates and returns a reference to a new graph configuration object. The object is initialized to - * the QNN recommended default value. The caller is meant to override fields in this object. - * - * \return A reference to a default QnnGraph_Config_t object. - */ - QnnGraph_Config_t& PushGraphConfig(); - - private: - bool IsNullTerminated() const { - return !graph_config_ptrs_.empty() && graph_config_ptrs_.back() == nullptr; - } - - InlinedVector htp_custom_graph_configs_; - InlinedVector graph_configs_; - InlinedVector graph_config_ptrs_; -}; - -} // namespace qnn -} // namespace onnxruntime diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc index 56eb1f4f59f33..0310cc2bc8f26 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.cc +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.cc @@ -111,6 +111,22 @@ void QNNExecutionProvider::ParseHtpGraphFinalizationOptimizationMode(const std:: } } +static void ParseHtpArchitecture(const std::string& htp_arch_string, QnnHtpDevice_Arch_t& qnn_htp_arch) { + if (htp_arch_string.empty() || htp_arch_string == "0") { + qnn_htp_arch = QNN_HTP_DEVICE_ARCH_NONE; + } else if (htp_arch_string == "68") { + qnn_htp_arch = QNN_HTP_DEVICE_ARCH_V68; + } else if (htp_arch_string == "69") { + qnn_htp_arch = QNN_HTP_DEVICE_ARCH_V69; + } else if (htp_arch_string == "73") { + qnn_htp_arch = QNN_HTP_DEVICE_ARCH_V73; + } else if (htp_arch_string == "75") { + qnn_htp_arch = QNN_HTP_DEVICE_ARCH_V75; + } else { + LOGS_DEFAULT(WARNING) << "Invalid HTP architecture: " << htp_arch_string; + } +} + QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_options_map, const SessionOptions* session_options) : IExecutionProvider{onnxruntime::kQnnExecutionProvider, true} { @@ -223,13 +239,49 @@ QNNExecutionProvider::QNNExecutionProvider(const ProviderOptions& provider_optio } } + static const std::string QNN_DEVICE_ID = "device_id"; + uint32_t device_id = 0; + auto dev_id_pos = provider_options_map.find(QNN_DEVICE_ID); + if (dev_id_pos != provider_options_map.end()) { + int value = std::stoi(dev_id_pos->second); + if (value < 0) { + LOGS_DEFAULT(WARNING) << "Invalid device ID '" << value + << "', only >= 0 allowed. Set to " << device_id << "."; + } else { + device_id = static_cast(value); + } + } + + static const std::string QNN_HTP_ARCH = "htp_arch"; + QnnHtpDevice_Arch_t htp_arch = QNN_HTP_DEVICE_ARCH_NONE; + auto htp_arch_pos = provider_options_map.find(QNN_HTP_ARCH); + if (htp_arch_pos != provider_options_map.end()) { + ParseHtpArchitecture(htp_arch_pos->second, htp_arch); + } + + static const std::string QNN_SOC_MODEL = "soc_model"; + uint32_t soc_model = QNN_SOC_MODEL_UNKNOWN; + auto soc_model_pos = provider_options_map.find(QNN_SOC_MODEL); + if (soc_model_pos != provider_options_map.end()) { + int value = std::stoi(soc_model_pos->second); + if (value < 0) { + LOGS_DEFAULT(WARNING) << "Invalid SoC Model '" << value + << "', only >= 0 allowed. Set to " << soc_model << "."; + } else { + soc_model = static_cast(value); + } + } + qnn_backend_manager_ = std::make_unique( std::move(backend_path), profiling_level, rpc_control_latency, htp_performance_mode, context_priority, - std::move(qnn_saver_path)); + std::move(qnn_saver_path), + device_id, + htp_arch, + soc_model); } bool QNNExecutionProvider::IsNodeSupported(qnn::QnnModelWrapper& qnn_model_wrapper, const NodeUnit& node_unit, @@ -512,25 +564,25 @@ Status QNNExecutionProvider::CreateComputeFunc(std::vector& nod return Status::OK(); } -void QNNExecutionProvider::InitQnnGraphConfigs(qnn::QnnGraphConfigsBuilder& configs_builder) const { +void QNNExecutionProvider::InitQnnGraphConfigs(qnn::QnnConfigsBuilder& configs_builder) const { if (qnn_backend_manager_->GetQnnBackendType() == qnn::QnnBackendType::HTP) { if (htp_graph_finalization_opt_mode_ != qnn::HtpGraphFinalizationOptimizationMode::kDefault) { - QnnHtpGraph_CustomConfig_t& htp_graph_opt_config = configs_builder.PushHtpGraphCustomConfig(); + QnnHtpGraph_CustomConfig_t& htp_graph_opt_config = configs_builder.PushCustomConfig(); htp_graph_opt_config.option = QNN_HTP_GRAPH_CONFIG_OPTION_OPTIMIZATION; htp_graph_opt_config.optimizationOption.type = QNN_HTP_GRAPH_OPTIMIZATION_TYPE_FINALIZE_OPTIMIZATION_FLAG; htp_graph_opt_config.optimizationOption.floatValue = static_cast(htp_graph_finalization_opt_mode_); - QnnGraph_Config_t& graph_opt_config = configs_builder.PushGraphConfig(); + QnnGraph_Config_t& graph_opt_config = configs_builder.PushConfig(); graph_opt_config.option = QNN_GRAPH_CONFIG_OPTION_CUSTOM; graph_opt_config.customConfig = &htp_graph_opt_config; } if (vtcm_size_in_mb_ > 0) { - QnnHtpGraph_CustomConfig_t& htp_graph_opt_config_vtcm = configs_builder.PushHtpGraphCustomConfig(); + QnnHtpGraph_CustomConfig_t& htp_graph_opt_config_vtcm = configs_builder.PushCustomConfig(); htp_graph_opt_config_vtcm.option = QNN_HTP_GRAPH_CONFIG_OPTION_VTCM_SIZE; htp_graph_opt_config_vtcm.vtcmSizeInMB = static_cast(vtcm_size_in_mb_); - QnnGraph_Config_t& graph_opt_config_vtcm = configs_builder.PushGraphConfig(); + QnnGraph_Config_t& graph_opt_config_vtcm = configs_builder.PushConfig(); graph_opt_config_vtcm.option = QNN_GRAPH_CONFIG_OPTION_CUSTOM; graph_opt_config_vtcm.customConfig = &htp_graph_opt_config_vtcm; } @@ -547,10 +599,11 @@ Status QNNExecutionProvider::CompileFromOrtGraph(const std::vector qnn_model = std::make_unique(logger, qnn_backend_manager_.get()); - qnn::QnnGraphConfigsBuilder graph_configs_builder; + qnn::QnnConfigsBuilder graph_configs_builder(QNN_GRAPH_CONFIG_INIT, + QNN_HTP_GRAPH_CUSTOM_CONFIG_INIT); InitQnnGraphConfigs(graph_configs_builder); - ORT_RETURN_IF_ERROR(qnn_model->ComposeGraph(graph_viewer, fused_node, graph_configs_builder.GetQnnGraphConfigs())); + ORT_RETURN_IF_ERROR(qnn_model->ComposeGraph(graph_viewer, fused_node, graph_configs_builder.GetQnnConfigs())); ORT_RETURN_IF_ERROR(qnn_model->FinalizeGraphs()); ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput()); diff --git a/onnxruntime/core/providers/qnn/qnn_execution_provider.h b/onnxruntime/core/providers/qnn/qnn_execution_provider.h index d4927f3fa505e..3f75be0efebcd 100644 --- a/onnxruntime/core/providers/qnn/qnn_execution_provider.h +++ b/onnxruntime/core/providers/qnn/qnn_execution_provider.h @@ -5,11 +5,12 @@ #include "core/framework/execution_provider.h" #include "core/framework/session_options.h" +#include "core/graph/model.h" #include #include "core/providers/qnn/builder/qnn_backend_manager.h" #include "core/providers/qnn/builder/qnn_model.h" -#include "core/providers/qnn/builder/qnn_graph_configs_helper.h" -#include "core/graph/model.h" +#include "core/providers/qnn/builder/qnn_configs_helper.h" +#include "HTP/QnnHtpGraph.h" namespace onnxruntime { @@ -58,7 +59,7 @@ class QNNExecutionProvider : public IExecutionProvider { void ParseHtpGraphFinalizationOptimizationMode(const std::string& htp_graph_finalization_opt_mode_string); - void InitQnnGraphConfigs(qnn::QnnGraphConfigsBuilder& configs_holder) const; + void InitQnnGraphConfigs(qnn::QnnConfigsBuilder& configs_builder) const; private: qnn::HtpGraphFinalizationOptimizationMode htp_graph_finalization_opt_mode_ = qnn::HtpGraphFinalizationOptimizationMode::kDefault; diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index d7c5098d9dbe4..d7bec337a6be4 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -170,6 +170,8 @@ ROCMExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId de MIOPEN_CALL_THROW(miopenCreate(&miopen_handle_)); MIOPEN_CALL_THROW(miopenSetStream(miopen_handle_, stream)); + + hip_graph_.SetStream(stream); } ROCMExecutionProvider::PerThreadContext::~PerThreadContext() { @@ -177,6 +179,33 @@ ROCMExecutionProvider::PerThreadContext::~PerThreadContext() { ORT_IGNORE_RETURN_VALUE(MIOPEN_CALL(miopenDestroy(miopen_handle_))); } +bool ROCMExecutionProvider::PerThreadContext::IsGraphCaptureAllowed() const { + return regular_run_count_before_graph_capture_ >= min_num_runs_before_hip_graph_capture_; +} + +void ROCMExecutionProvider::PerThreadContext::CaptureBegin() { + hip_graph_.Reset(); + hip_graph_.CaptureBegin(); +} + +void ROCMExecutionProvider::PerThreadContext::CaptureEnd() { + hip_graph_.CaptureEnd(); + is_graph_captured_ = true; +} + +bool ROCMExecutionProvider::PerThreadContext::IsGraphCaptured() const { + return is_graph_captured_; +} + +Status ROCMExecutionProvider::PerThreadContext::ReplayGraph() { + ORT_ENFORCE(IsGraphCaptured()); + return hip_graph_.Replay(); +} + +void ROCMExecutionProvider::PerThreadContext::IncrementRegularRunCountBeforeGraphCapture() { + ++regular_run_count_before_graph_capture_; +} + void OverrideTunableOpInfoByEnv(ROCMExecutionProviderInfo& info) { if (auto env_tunable_op_enable = onnxruntime::ParseTestOnlyEnvironmentVariable( "ORT_ROCM_TUNABLE_OP_ENABLE", {"0", "1"}, "Use provider_options \"tunable_op_enable\" instead."); @@ -219,6 +248,11 @@ ROCMExecutionProvider::ROCMExecutionProvider(const ROCMExecutionProviderInfo& in if (info.external_allocator_info.UseExternalAllocator()) { use_ep_level_unified_stream_ = true; stream_ = nullptr; + } else if (info.enable_hip_graph) { + // current hip graph implementation only works with single stream + // use EP level unified stream for all the reqeust + HIP_CALL_THROW(hipStreamCreateWithFlags(&stream_, hipStreamNonBlocking)); + use_ep_level_unified_stream_ = true; } else { stream_ = nullptr; } @@ -322,25 +356,58 @@ Status ROCMExecutionProvider::Sync() const { Status ROCMExecutionProvider::OnRunStart() { // always set ROCM device when session::Run() in case it runs in a worker thread HIP_RETURN_IF_ERROR(hipSetDevice(GetDeviceId())); + if (IsGraphCaptureEnabled() && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured()) { + LOGS_DEFAULT(INFO) << "Capturing the hip graph for this model"; + GetPerThreadContext().CaptureBegin(); + } return Status::OK(); } Status ROCMExecutionProvider::OnRunEnd(bool sync_stream) { + if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured()) { + if (GetPerThreadContext().IsGraphCaptureAllowed()) { + GetPerThreadContext().CaptureEnd(); + // HIP work issued to a capturing stream doesn’t actually run on the GPU, + // so run the captured graph here to actually execute the work. + ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph()); + } else { + GetPerThreadContext().IncrementRegularRunCountBeforeGraphCapture(); + } + } + if (sync_stream) { HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast(stream_))); } - // In extreme cases (e.g., 1-op graph and that op fallbacks to CPU), - // PerThreadContext won't be created and there is nothing to - // release. This didn't happen before because we always call - // GetPerThreadContext in OnRunStart. - if (PerThreadContextCache()->find(this) != PerThreadContextCache()->end()) { + // The reason of !IsGraphCaptureEnabled(): + // If hip graph is enabled, the per thread context will not be released + // because the per thread hip graph needs to be maintained and replayed for + // the next run. + // The reason of PerThreadContextCache()->find(this) != PerThreadContextCache()->end(): + // In extreme cases (e.g., 1-op graph and that op fallbacks to CPU), + // PerThreadContext won't be created and there is nothing to + // release. This didn't happen before because we always call + // GetPerThreadContext in OnRunStart. + if (!IsGraphCaptureEnabled() && + PerThreadContextCache()->find(this) != PerThreadContextCache()->end()) { ReleasePerThreadContext(); } return Status::OK(); } +bool ROCMExecutionProvider::IsGraphCaptureEnabled() const { + return info_.enable_hip_graph; +} + +bool ROCMExecutionProvider::IsGraphCaptured() const { + return GetPerThreadContext().IsGraphCaptured(); +} + +Status ROCMExecutionProvider::ReplayGraph() { + return GetPerThreadContext().ReplayGraph(); +} + namespace rocm { // opset 1 to 9 class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MemcpyFromHost); diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.h b/onnxruntime/core/providers/rocm/rocm_execution_provider.h index c4945b9ac2481..37d5f7b42210f 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.h +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.h @@ -10,6 +10,7 @@ #include "core/framework/execution_provider.h" #include "core/platform/ort_mutex.h" #include "core/providers/rocm/rocm_execution_provider_info.h" +#include "core/providers/rocm/rocm_graph.h" #include "core/providers/rocm/rocm_pch.h" #include "core/providers/rocm/shared_inc/rocm_utils.h" #include "core/providers/rocm/shared_inc/rocm_call.h" @@ -73,6 +74,9 @@ class ROCMExecutionProvider : public IExecutionProvider { std::unique_ptr GetProfiler() override; + bool IsGraphCaptureEnabled() const override; + bool IsGraphCaptured() const override; + Status ReplayGraph() override; void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override; OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override; std::vector CreatePreferredAllocators() override; @@ -81,6 +85,7 @@ class ROCMExecutionProvider : public IExecutionProvider { ROCMExecutionProviderInfo info_; hipDeviceProp_t device_prop_; bool external_stream_ = false; + // only used when set user external stream or hip graph hipStream_t stream_ = nullptr; bool use_ep_level_unified_stream_ = false; @@ -133,6 +138,13 @@ class ROCMExecutionProvider : public IExecutionProvider { } } + bool IsGraphCaptureAllowed() const; + void CaptureBegin(); + void CaptureEnd(); + bool IsGraphCaptured() const; + Status ReplayGraph(); + void IncrementRegularRunCountBeforeGraphCapture(); + private: rocblas_handle rocblas_handle_ = nullptr; miopenHandle_t miopen_handle_ = nullptr; @@ -141,6 +153,18 @@ class ROCMExecutionProvider : public IExecutionProvider { std::unique_ptr> constant_ones_double_; std::unique_ptr> constant_ones_half_; std::unique_ptr> constant_ones_bfloat16_; + + // Hip graph with multi threads will be supported in the future, so hip_graph_ + // is put under PerThreadContext. + ROCMGraph hip_graph_; + bool is_graph_captured_ = false; + int regular_run_count_before_graph_capture_ = 0; + + // There is chance that the second regular run allocates GPU memory for causes like: + // (1) memory pattern is enabled. (2) arena allocation for stream. + // Since no GPU memory allocation is allowed during graph capturing, we need at least two regular runs + // to allocate enough memory in Arena before graph capturing. + const int min_num_runs_before_hip_graph_capture_ = 2; // required min regular runs before graph capture for the necessary memory allocations. }; using PerThreadContextMap = std::unordered_map>; diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc index 650635c153640..b557f92287f2b 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.cc @@ -21,6 +21,7 @@ constexpr const char* kGpuExternalAlloc = "gpu_external_alloc"; constexpr const char* kGpuExternalFree = "gpu_external_free"; constexpr const char* kGpuExternalEmptyCache = "gpu_external_empty_cache"; constexpr const char* kMiopenConvUseMaxWorkspace = "miopen_conv_use_max_workspace"; +constexpr const char* kEnableHipGraph = "enable_hip_graph"; constexpr const char* kTunableOpEnable = "tunable_op_enable"; constexpr const char* kTunableOpTuningEnable = "tunable_op_tuning_enable"; constexpr const char* kTunableOpMaxTuningDurationMs = "tunable_op_max_tuning_duration_ms"; @@ -84,6 +85,7 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P info.miopen_conv_exhaustive_search) .AddAssignmentToReference(rocm::provider_option_names::kDoCopyInDefaultStream, info.do_copy_in_default_stream) .AddAssignmentToReference(rocm::provider_option_names::kMiopenConvUseMaxWorkspace, info.miopen_conv_use_max_workspace) + .AddAssignmentToReference(rocm::provider_option_names::kEnableHipGraph, info.enable_hip_graph) .AddValueParser( rocm::provider_option_names::kTunableOpEnable, [&info](const std::string& value_str) -> Status { @@ -121,6 +123,7 @@ ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const ROCMExecution {rocm::provider_option_names::kMiopenConvExhaustiveSearch, MakeStringWithClassicLocale(info.miopen_conv_exhaustive_search)}, {rocm::provider_option_names::kDoCopyInDefaultStream, MakeStringWithClassicLocale(info.do_copy_in_default_stream)}, {rocm::provider_option_names::kMiopenConvUseMaxWorkspace, MakeStringWithClassicLocale(info.miopen_conv_use_max_workspace)}, + {rocm::provider_option_names::kEnableHipGraph, MakeStringWithClassicLocale(info.enable_hip_graph)}, {rocm::provider_option_names::kTunableOpEnable, MakeStringWithClassicLocale(info.tunable_op.enable)}, {rocm::provider_option_names::kTunableOpTuningEnable, MakeStringWithClassicLocale(info.tunable_op.tuning_enable)}, {rocm::provider_option_names::kTunableOpMaxTuningDurationMs, MakeStringWithClassicLocale(info.tunable_op.max_tuning_duration_ms)}, diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h index e35c0cc0afecc..2f549cc1ac143 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider_info.h @@ -63,6 +63,8 @@ struct ROCMExecutionProviderInfo { // If set to false, use fix workspace size (32M) for Conv algo search, the final algo might not be the best. bool miopen_conv_use_max_workspace{true}; + bool enable_hip_graph{false}; + rocm::TunableOpInfo tunable_op{}; static ROCMExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); diff --git a/onnxruntime/core/providers/rocm/rocm_provider_factory.cc b/onnxruntime/core/providers/rocm/rocm_provider_factory.cc index 4d88c25469372..88ef666678b3e 100644 --- a/onnxruntime/core/providers/rocm/rocm_provider_factory.cc +++ b/onnxruntime/core/providers/rocm/rocm_provider_factory.cc @@ -185,6 +185,7 @@ struct ROCM_Provider : Provider { info.has_user_compute_stream = params->has_user_compute_stream != 0; info.user_compute_stream = params->user_compute_stream; info.default_memory_arena_cfg = params->default_memory_arena_cfg; + info.enable_hip_graph = params->enable_hip_graph; info.tunable_op.enable = params->tunable_op_enable; info.tunable_op.tuning_enable = params->tunable_op_tuning_enable; info.tunable_op.max_tuning_duration_ms = params->tunable_op_max_tuning_duration_ms; @@ -215,6 +216,7 @@ struct ROCM_Provider : Provider { rocm_options.user_compute_stream = internal_options.user_compute_stream; } rocm_options.default_memory_arena_cfg = internal_options.default_memory_arena_cfg; + rocm_options.enable_hip_graph = internal_options.enable_hip_graph; rocm_options.tunable_op_enable = internal_options.tunable_op.enable; rocm_options.tunable_op_tuning_enable = internal_options.tunable_op.tuning_enable; rocm_options.tunable_op_max_tuning_duration_ms = internal_options.tunable_op.max_tuning_duration_ms; diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 6de53243f880f..c69d37a6ab35e 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -153,7 +153,7 @@ static bool AreAllComputeNodesAssignedToCudaEp(const Graph& graph) { // Empty node provider means CPU EP if (!node_provider.empty() && - node_provider != kCudaExecutionProvider && + !(node_provider == kCudaExecutionProvider || node_provider == kRocmExecutionProvider) && node_provider != kCpuExecutionProvider) { nodes_on_cpu_and_cuda_eps_only = false; break; @@ -1717,7 +1717,8 @@ common::Status InferenceSession::Initialize() { // now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs. ORT_RETURN_IF_ERROR_SESSIONID_(graph.Resolve()); - // Currently CUDA graph is only considered by CUDA EP and TRT EP. + // Currently CUDA graph is only considered by CUDA EP and TRT EP, and + // HIP graph is only considered by ROCM EP. // // Check for CUDA EP: // If the CUDA EP is part of the providers list for this session AND @@ -1730,47 +1731,58 @@ common::Status InferenceSession::Initialize() { // The TRT EP is configured to do a graph capture AND // All the graph nodes have been assigned to the TRT EP, // Then the TRT EP is cached for triggering a ReplayGraph() in Run(). - std::vector cuda_graph_support_ep_list = {onnxruntime::kTensorrtExecutionProvider, onnxruntime::kCudaExecutionProvider}; + // + // Check for ROCM EP: + // If the ROCM EP is part of the providers list for this session AND + // The ROCM EP is configured to do a graph capture AND + // All the "compute" graph nodes have been assigned to the ROCM EP, + // Then the ROCM EP is cached for triggering a ReplayGraph() in Run(). + // + std::vector graph_support_ep_list = { + onnxruntime::kTensorrtExecutionProvider, + onnxruntime::kCudaExecutionProvider, + onnxruntime::kRocmExecutionProvider}; - for (auto& it : cuda_graph_support_ep_list) { + for (auto& it : graph_support_ep_list) { auto* target_ep = execution_providers_.Get(it); if (target_ep && target_ep->IsGraphCaptureEnabled()) { - // CUDA Graphs can't work with control flow nodes + // CUDA/HIP Graphs can't work with control flow nodes if (HasControlflowNodes(graph)) { - LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user " - << "as the model has control flow nodes which can't be supported by CUDA Graphs."; + LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA/HIP Graph feature as requested by the user " + << "as the model has control flow nodes which can't be supported by CUDA/HIP Graphs."; ORT_RETURN_IF_ERROR_SESSIONID_( ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "This session cannot use the CUDA Graph feature as requested by the user " - "as the model has control flow nodes which can't be supported by CUDA Graphs.")); + "This session cannot use the CUDA/HIP Graph feature as requested by the user " + "as the model has control flow nodes which can't be supported by CUDA/HIP Graphs.")); } - if (strcmp(target_ep->Type().c_str(), onnxruntime::kCudaExecutionProvider) == 0) { + if (strcmp(target_ep->Type().c_str(), onnxruntime::kCudaExecutionProvider) == 0 || + strcmp(target_ep->Type().c_str(), onnxruntime::kRocmExecutionProvider) == 0) { // Ensure that all nodes have been partitioned to CUDA or CPU EP && there are no memcpy nodes // The reasoning behind this logic is that certain shape nodes will be forced onto CPU // and as long as there are no memcpy nodes this is confirmation that no compute nodes have been placed on the CPU EP // which is all we care about. if (!AreAllComputeNodesAssignedToCudaEp(graph)) { - LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user " - << " as all compute graph nodes have not been partitioned to the CUDA EP."; + LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA/HIP Graph feature as requested by the user " + << " as all compute graph nodes have not been partitioned to the CUDA/HIP EP."; ORT_RETURN_IF_ERROR_SESSIONID_( ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, - "This session cannot use the CUDA Graph feature as requested by the user " - " as all compute graph nodes have not been partitioned to the CUDA EP.")); + "This session cannot use the CUDA/HIP Graph feature as requested by the user " + " as all compute graph nodes have not been partitioned to the CUDA/HIP EP.")); } // Log a warning for the user to know that there are shape subgraphs that will execute on CPU if (HasShapeSubgraphNodes(graph)) { LOGS(*session_logger_, WARNING) << "This model has shape massaging nodes that will execute on CPU. " - << "Use the CUDA Graph feature with caution. " + << "Use the CUDA/HIP Graph feature with caution. " << "As long as the intermediate shapes produced in the model " - << "using the representative input used to capture the CUDA graph, " + << "using the representative input used to capture the CUDA/HIP graph, " << "will match the shapes produced in the model for other inputs " << "of the same shape as the representative input (common case), " - << "it is safe to use the CUDA Graph feature."; + << "it is safe to use the CUDA/HIP Graph feature."; } } else { // Following code path is for TRT EP currently. @@ -1789,7 +1801,7 @@ common::Status InferenceSession::Initialize() { } } - LOGS(*session_logger_, INFO) << "This session will use the CUDA Graph feature as requested by the user."; + LOGS(*session_logger_, INFO) << "This session will use the CUDA/HIP Graph feature as requested by the user."; cached_execution_provider_for_graph_replay_.SetExecutionProvider(target_ep); break; // Make sure only one ep can run CUDA graph. } @@ -2479,7 +2491,9 @@ Status InferenceSession::Run(const RunOptions& run_options, // As N+1 inference runs (N for memory allocation and 1 for graph capturing) // are needed before replaying the captured graph, here run N inference runs recursively until graph captured, // so that users just need one session run to capture the graph. - // N is defined in min_num_runs_before_cuda_graph_capture_ for CUDA EP, and the value could be different for other EP. + // N is defined in min_num_runs_before_cuda_graph_capture_ for CUDA EP, + // N is defined in min_num_runs_before_hip_graph_capture_ for ROCM EP, + // and the value could be different for other EP. if (retval.IsOK() && cached_execution_provider_for_graph_replay_.IsGraphCaptureEnabled() && !cached_execution_provider_for_graph_replay_.IsGraphCaptured()) { LOGS(*session_logger_, INFO) << "Start another run for necessary memory allocation or graph capture."; diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 9e8c215af39f0..2330f80eb84df 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -2346,6 +2346,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateROCMProviderOptions, _Outptr_ OrtROCMProvider options->has_user_compute_stream = 0; options->user_compute_stream = nullptr; options->default_memory_arena_cfg = nullptr; + options->enable_hip_graph = false; options->tunable_op_enable = 0; options->tunable_op_tuning_enable = 0; options->tunable_op_max_tuning_duration_ms = 0; diff --git a/onnxruntime/python/onnxruntime_inference_collection.py b/onnxruntime/python/onnxruntime_inference_collection.py index 1a3e22142f80e..09f768f53ea65 100644 --- a/onnxruntime/python/onnxruntime_inference_collection.py +++ b/onnxruntime/python/onnxruntime_inference_collection.py @@ -466,7 +466,7 @@ def _create_inference_session(self, providers, provider_options, disabled_optimi session_options = self._sess_options if self._sess_options else C.get_default_session_options() - self._register_ep_custom_ops(session_options, providers, provider_options) + self._register_ep_custom_ops(session_options, providers, provider_options, available_providers) if self._model_path: sess = C.InferenceSession(session_options, self._model_path, True, self._read_config_from_model) @@ -510,11 +510,15 @@ def _reset_session(self, providers, provider_options): self._sess_options = self._sess_options_initial self._create_inference_session(providers, provider_options) - def _register_ep_custom_ops(self, session_options, providers, provider_options): + def _register_ep_custom_ops(self, session_options, providers, provider_options, available_providers): for i in range(len(providers)): - if providers[i] == "TensorrtExecutionProvider": + if providers[i] in available_providers and providers[i] == "TensorrtExecutionProvider": C.register_tensorrt_plugins_as_custom_ops(session_options, provider_options[i]) - elif isinstance(providers[i], tuple) and providers[i][0] == "TensorrtExecutionProvider": + elif ( + isinstance(providers[i], tuple) + and providers[i][0] in available_providers + and providers[i][0] == "TensorrtExecutionProvider" + ): C.register_tensorrt_plugins_as_custom_ops(session_options, providers[i][1]) diff --git a/onnxruntime/python/tools/transformers/fusion_attention.py b/onnxruntime/python/tools/transformers/fusion_attention.py index d11cb91d98b0c..f48cabd25fc5c 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention.py +++ b/onnxruntime/python/tools/transformers/fusion_attention.py @@ -129,6 +129,9 @@ def __init__( self.num_heads_warning = True self.hidden_size_warning = True + self.shape_infer = None + self.shape_infer_done = True + def get_num_heads_and_hidden_size_from_concat(self, concat: NodeProto) -> Tuple[int, int]: """ Detect num_heads and hidden_size from Concat node in the following subgraph: @@ -202,12 +205,15 @@ def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> Tuple[int, int] return num_heads, hidden_size def get_add_qk_str(self, add_qk: NodeProto): - shape_infer = self.model.infer_runtime_shape(update=True) - if shape_infer is None: + if not self.shape_infer_done: + self.shape_infer = self.model.infer_runtime_shape(update=True) + self.shape_infer_done = True + + if self.shape_infer is None: return None - input_0_shape = shape_infer.get_edge_shape(add_qk.input[0]) - input_1_shape = shape_infer.get_edge_shape(add_qk.input[1]) + input_0_shape = self.shape_infer.get_edge_shape(add_qk.input[0]) + input_1_shape = self.shape_infer.get_edge_shape(add_qk.input[1]) if input_0_shape is None or input_1_shape is None: logger.debug(f"one of the inputs of {add_qk} is None") diff --git a/onnxruntime/python/tools/transformers/fusion_attention_unet.py b/onnxruntime/python/tools/transformers/fusion_attention_unet.py index 250ec5f3eb159..9a353e7e2d675 100644 --- a/onnxruntime/python/tools/transformers/fusion_attention_unet.py +++ b/onnxruntime/python/tools/transformers/fusion_attention_unet.py @@ -28,10 +28,19 @@ def __init__( enable_packed_qkv: bool, enable_packed_kv: bool, ): - super().__init__(model, "MultiHeadAttention" if is_cross_attention else "Attention", ["LayerNormalization"]) + super().__init__( + model, + "Attention" if is_cross_attention and enable_packed_qkv else "MultiHeadAttention", + ["LayerNormalization"], + ) self.hidden_size = hidden_size self.num_heads = num_heads self.is_cross_attention = is_cross_attention + + # Note: pack Q/K/V or K/V weights into one tensor make it harder for updating initializers for LoRA. + # To support LoRA, it is better to use separated Q, K and V inputs in offline optimization, + # and CUDA operator pre-packs those tensors to preferred format based on available kernels. + # In this way, we can support LoRA and get optimal performance at same time. self.enable_packed_qkv = enable_packed_qkv self.enable_packed_kv = enable_packed_kv @@ -170,9 +179,7 @@ def create_attention_node( return None # Sometimes weights are stored in fp16 - if q_weight.data_type == 10: - logger.debug("weights are in fp16. Please run fp16 conversion after optimization") - return None + float_type = q_weight.data_type qw = NumpyHelper.to_array(q_weight) kw = NumpyHelper.to_array(k_weight) @@ -212,7 +219,7 @@ def create_attention_node( matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_QKV") self.add_initializer( name=matmul_node_name + "_weight", - data_type=TensorProto.FLOAT, + data_type=float_type, dims=[qkv_weight.shape[0], qkv_weight.shape[1]], vals=qkv_weight, ) @@ -235,8 +242,11 @@ def create_attention_node( reshape_node = helper.make_node( "Reshape", - inputs=[matmul_node_name + "_out", matmul_node_name + "_reshape_shape"], - outputs=[attention_node_name + "_input"], + inputs=[ + matmul_node_name + "_out", + matmul_node_name + "_reshape_shape", + ], + outputs=[attention_node_name + "_qkv_input"], name=matmul_node_name + "_reshape", ) self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name @@ -251,7 +261,7 @@ def create_attention_node( self.add_initializer( name=attention_node_name + "_qkv_weight", - data_type=TensorProto.FLOAT, + data_type=float_type, dims=[qw_in_size, qkv_weight_dim], vals=qkv_weight, ) @@ -280,7 +290,7 @@ def create_attention_node( matmul_node_name = self.model.create_node_name("MatMul", name_prefix="MatMul_KV") self.add_initializer( name=matmul_node_name + "_weight", - data_type=TensorProto.FLOAT, + data_type=float_type, dims=[kv_weight.shape[0], kv_weight.shape[1]], vals=kv_weight, ) @@ -303,8 +313,11 @@ def create_attention_node( reshape_node = helper.make_node( "Reshape", - inputs=[matmul_node_name + "_out", matmul_node_name + "_reshape_shape"], - outputs=[k_matmul.output[0]], + inputs=[ + matmul_node_name + "_out", + matmul_node_name + "_reshape_shape", + ], + outputs=[attention_node_name + "_kv_input"], name=matmul_node_name + "_reshape", ) self.node_name_to_graph_name[reshape_node.name] = self.this_graph_name @@ -317,7 +330,7 @@ def create_attention_node( self.add_initializer( name=attention_node_name + "_qkv_bias", - data_type=TensorProto.FLOAT, + data_type=float_type, dims=[qkv_bias_dim], vals=qkv_bias, ) @@ -330,7 +343,7 @@ def create_attention_node( attention_node_name + "_qkv_bias", ] else: - attention_inputs = [attention_node_name + "_input"] + attention_inputs = [attention_node_name + "_qkv_input"] else: if not self.enable_packed_kv: attention_inputs = [ @@ -342,7 +355,7 @@ def create_attention_node( else: attention_inputs = [ q_matmul.output[0], - k_matmul.output[0], + attention_node_name + "_kv_input", ] attention_node = helper.make_node( @@ -839,6 +852,9 @@ def create_attention_node_lora( return attention_node def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node): + if self.fuse_a1111_fp16(normalize_node, input_name_to_nodes, output_name_to_node): + return + node_before_layernorm = self.model.match_parent(normalize_node, "Add", 0) # In SD 1.5, for self attention, LayerNorm has parent Reshape @@ -1168,3 +1184,125 @@ def match_lora_path( return (lora_mul_node, lora_matmul_1_node) return None + + def fuse_a1111_fp16(self, normalize_node, input_name_to_nodes, output_name_to_node): + """Fuse attention of fp16 UNet exported in A1111 (stable diffusion webui) extension""" + entry_path = self.model.match_parent_path(normalize_node, ["Cast", "Add"], [0, 0]) + if entry_path is None: + entry_path = self.model.match_parent_path(normalize_node, ["Cast", "Reshape"], [0, 0]) + if entry_path is None: + return False + _cast, node_before_layernorm = entry_path + + root_input = node_before_layernorm.output[0] + + children_nodes = input_name_to_nodes[root_input] + skip_add = None + for node in children_nodes: + if node.op_type == "Add": # SkipLayerNormalization fusion is not applied yet + skip_add = node + break + if skip_add is None: + return False + + match_qkv = self.match_qkv_a1111(root_input, skip_add) + if match_qkv is None: + return False + + ( + reshape_qkv, + transpose_qkv, + reshape_q, + matmul_q, + matmul_k, + matmul_v, + ) = match_qkv + + cast_q = self.model.match_parent(matmul_q, "Cast", 0) + cast_k = self.model.match_parent(matmul_k, "Cast", 0) + cast_v = self.model.match_parent(matmul_v, "Cast", 0) + if not ( + cast_q is not None + and cast_k is not None + and (cast_q == cast_k if not self.is_cross_attention else cast_q != cast_k) + and cast_k == cast_v + ): + return False + + if cast_q.input[0] != normalize_node.output[0]: + return False + + attention_last_node = reshape_qkv + + q_num_heads = self.get_num_heads(reshape_q, True) or self.get_num_heads(reshape_q, False) + if q_num_heads <= 0: + logger.debug("fuse_attention: failed to detect num_heads") + return False + + q_hidden_size = self.get_hidden_size(normalize_node) + + # number of heads are same for all the paths, hence to create attention node, we pass the q_num_heads + new_node = self.create_attention_node( + matmul_q, + matmul_k, + matmul_v, + q_num_heads, + q_hidden_size, + input=matmul_q.input[0], + output=attention_last_node.output[0], + ) + if new_node is None: + return False + + self.nodes_to_add.append(new_node) + self.node_name_to_graph_name[new_node.name] = self.this_graph_name + + self.nodes_to_remove.extend([attention_last_node, transpose_qkv]) + + # Use prune graph to remove nodes since they are shared by all attention nodes. + self.prune_graph = True + return True + + def match_qkv_a1111(self, root_input, skip_add): + """Match Q, K and V paths exported by A1111 (stable diffusion webui) extension""" + another_input = 1 if skip_add.input[0] == root_input else 0 + qkv_nodes = self.model.match_parent_path( + skip_add, + ["Add", "MatMul", "Reshape", "Transpose", "Reshape", "Einsum"], + [another_input, None, None, 0, 0, 0], + ) + + if qkv_nodes is None: + return None + + (_, _, reshape_qkv, transpose_qkv, reshape_einsum, einsum_qkv) = qkv_nodes + + v_nodes = self.model.match_parent_path(einsum_qkv, ["Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0]) + if v_nodes is None: + logger.debug("fuse_attention: failed to match v path") + return None + (_, _, _, matmul_v) = v_nodes + + qk_nodes = self.model.match_parent_path( + einsum_qkv, ["Cast", "Cast", "Softmax", "Mul", "Einsum"], [0, 0, 0, 0, None] + ) + if qk_nodes is not None: + (_, _, _softmax_qk, _, einsum_qk) = qk_nodes + else: + logger.debug("fuse_attention: failed to match qk path") + return None + + q_nodes = self.model.match_parent_path(einsum_qk, ["Reshape", "Transpose", "Reshape", "MatMul"], [0, 0, 0, 0]) + if q_nodes is None: + logger.debug("fuse_attention: failed to match q path") + return None + (_, _transpose_q, reshape_q, matmul_q) = q_nodes + + k_nodes = self.model.match_parent_path(einsum_qk, ["Reshape", "Transpose", "Reshape", "MatMul"], [1, 0, 0, 0]) + if k_nodes is None: + logger.debug("fuse_attention: failed to match k path") + return None + + (_, _, _, matmul_k) = k_nodes + + return reshape_qkv, transpose_qkv, reshape_q, matmul_q, matmul_k, matmul_v diff --git a/onnxruntime/python/tools/transformers/fusion_embedlayer.py b/onnxruntime/python/tools/transformers/fusion_embedlayer.py index bc38399e3cce5..42156d9123383 100644 --- a/onnxruntime/python/tools/transformers/fusion_embedlayer.py +++ b/onnxruntime/python/tools/transformers/fusion_embedlayer.py @@ -28,7 +28,9 @@ def __init__(self, model: OnnxModel, description: str = "no mask"): description, ) self.utils = FusionUtils(model) - self.shape_infer_helper = self.model.infer_runtime_shape({}, update=True) + self.shape_infer = None + self.shape_infer_done = False + # The following will be reset in each fuse call of FusionEmbedLayerNormalization self.attention = None self.embed_node = None @@ -329,9 +331,13 @@ def check_embedding(self, word_embedding_gather, segment_embedding_gather, posit segment_ids = segment_embedding_gather.input[1] if segment_embedding_gather else None position_ids = position_embedding_gather.input[1] - if self.shape_infer_helper is not None: - input_ids_shape = self.shape_infer_helper.get_edge_shape(input_ids) - position_ids_shape = self.shape_infer_helper.get_edge_shape(position_ids) + if not self.shape_infer_done: + self.shape_infer = self.model.infer_runtime_shape(update=True) + self.shape_infer_done = True + + if self.shape_infer is not None: + input_ids_shape = self.shape_infer.get_edge_shape(input_ids) + position_ids_shape = self.shape_infer.get_edge_shape(position_ids) assert input_ids_shape and position_ids_shape if not ( len(input_ids_shape) == 2 @@ -345,11 +351,11 @@ def check_embedding(self, word_embedding_gather, segment_embedding_gather, posit ) return False - if segment_ids and not self.shape_infer_helper.compare_shape(input_ids, segment_ids): + if segment_ids and not self.shape_infer.compare_shape(input_ids, segment_ids): logger.info( "Cannot fuse EmbedLayerNormalization: input_ids and segment_ids does not have same shape: {} != {}".format( input_ids_shape, - self.shape_infer_helper.get_edge_shape(segment_ids), + self.shape_infer.get_edge_shape(segment_ids), ) ) return False diff --git a/onnxruntime/python/tools/transformers/fusion_gemmfastgelu.py b/onnxruntime/python/tools/transformers/fusion_gemmfastgelu.py index f1d803a3cc082..4d9913f427b37 100644 --- a/onnxruntime/python/tools/transformers/fusion_gemmfastgelu.py +++ b/onnxruntime/python/tools/transformers/fusion_gemmfastgelu.py @@ -32,7 +32,7 @@ def get_dimensions(self, input_name: str) -> Union[int, None]: return self.get_dimensions_from_tensor_proto(graph_input) if not self.shape_infer_done: - self.shape_infer = self.model.infer_runtime_shape({}, update=True) + self.shape_infer = self.model.infer_runtime_shape(update=True) self.shape_infer_done = True if self.shape_infer is not None: diff --git a/onnxruntime/python/tools/transformers/fusion_nhwc_conv.py b/onnxruntime/python/tools/transformers/fusion_nhwc_conv.py index 141ebb1f95a11..5233fdf272fbd 100644 --- a/onnxruntime/python/tools/transformers/fusion_nhwc_conv.py +++ b/onnxruntime/python/tools/transformers/fusion_nhwc_conv.py @@ -7,7 +7,8 @@ from typing import List from fusion_base import Fusion -from onnx import TensorProto, helper, numpy_helper +from fusion_utils import FusionUtils +from onnx import helper, numpy_helper from onnx_model import OnnxModel logger = getLogger(__name__) @@ -19,6 +20,7 @@ class FusionNhwcConv(Fusion): def __init__(self, model: OnnxModel, update_weight=False): super().__init__(model, "NhwcConv", ["Conv"], "NhwcConv") self.update_weight = update_weight + self.fusion_utils = FusionUtils(model) def create_transpose_node(self, input_name: str, perm: List[int], output_name=None): """Append a Transpose node after an input""" @@ -49,6 +51,15 @@ def fuse(self, conv, input_name_to_nodes, output_name_to_node): if len(weight.shape) != 4: return + dtype = self.model.get_dtype(nhwc_conv_input) + if not (dtype is not None and weight_tensor.data_type == dtype): + cast_node = self.fusion_utils.add_cast_node( + input_name=nhwc_conv_input, + to_type=weight_tensor.data_type, + output_name_to_node=output_name_to_node, + ) + nhwc_conv_input = cast_node.output[0] + if self.update_weight: # Transpose weights from NCHW to NHWC weight = weight.transpose(0, 2, 3, 1) @@ -56,7 +67,7 @@ def fuse(self, conv, input_name_to_nodes, output_name_to_node): weight_name = node_name + "_weight_NHWC" self.add_initializer( name=weight_name, - data_type=TensorProto.FLOAT, + data_type=weight_tensor.data_type, dims=list(weight.shape), vals=weight, ) diff --git a/onnxruntime/python/tools/transformers/fusion_shape.py b/onnxruntime/python/tools/transformers/fusion_shape.py index bc32d78eda66c..dfa77fc7d0221 100644 --- a/onnxruntime/python/tools/transformers/fusion_shape.py +++ b/onnxruntime/python/tools/transformers/fusion_shape.py @@ -29,12 +29,12 @@ def get_dimensions_from_tensor_proto(self, tensor_proto: TensorProto) -> Union[i return None def get_dimensions(self, input_name: str) -> Union[int, None]: - graph_input = self.model.find_graph_input(input_name) - if graph_input: - return self.get_dimensions_from_tensor_proto(graph_input) + shape = self.model.get_shape(input_name) + if shape is not None: + return len(shape) if not self.shape_infer_done: - self.shape_infer = self.model.infer_runtime_shape({}, update=True) + self.shape_infer = self.model.infer_runtime_shape(update=True) self.shape_infer_done = True if self.shape_infer is not None: diff --git a/onnxruntime/python/tools/transformers/fusion_utils.py b/onnxruntime/python/tools/transformers/fusion_utils.py index afc968fab46c1..726c587ff7043 100644 --- a/onnxruntime/python/tools/transformers/fusion_utils.py +++ b/onnxruntime/python/tools/transformers/fusion_utils.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- from logging import getLogger -from typing import Tuple +from typing import Optional, Tuple import numpy from numpy import array_equal, ndarray @@ -29,17 +29,7 @@ def cast_graph_input_to_int32(self, input_name: str) -> Tuple[bool, str]: return False, input_name def cast_input(self, input_name: str, target_type="int32"): - cast_output = input_name + "_" + target_type - - # Avoid consequent Cast nodes. - inputs = [input_name] - output_name_to_node = self.model.output_name_to_node() - if input_name in output_name_to_node: - parent_node = output_name_to_node[input_name] - if parent_node and parent_node.op_type == "Cast": - inputs = [parent_node.input[0]] - - cast_node = helper.make_node("Cast", inputs=inputs, outputs=[cast_output]) + output_name = input_name + "_" + target_type if target_type == "int32": to_type = int(TensorProto.INT32) @@ -50,10 +40,36 @@ def cast_input(self, input_name: str, target_type="int32"): else: raise ValueError("Invalid target_type: {target_type}") + cast_node = self.add_cast_node(input_name, to_type, output_name) + + return output_name, cast_node + + def add_cast_node( + self, + input_name: str, + to_type: int, + output_name: Optional[str] = None, + output_name_to_node=None, + graph_name: Optional[str] = None, + ): + if output_name is None: + output_name = input_name + f"_cast_to_{to_type}" + + # Avoid consequent Cast nodes. + inputs = [input_name] + if output_name_to_node is None: + output_name_to_node = self.model.output_name_to_node() + if input_name in output_name_to_node: + parent_node = output_name_to_node[input_name] + if parent_node and parent_node.op_type == "Cast": + inputs = [parent_node.input[0]] + + cast_node = helper.make_node("Cast", inputs=inputs, outputs=[output_name]) + cast_node.attribute.extend([helper.make_attribute("to", to_type)]) - self.model.add_node(cast_node) + self.model.add_node(cast_node, graph_name=graph_name) - return cast_output, cast_node + return cast_node def cast_input_to_int32(self, input_name: str): return self.cast_input(input_name, "int32") @@ -224,9 +240,10 @@ def check_node_input_value(self, node, input_index: int, expected_value): def remove_identity_nodes(self): """Remove Identity nodes, except those right before graph output.""" nodes_to_remove = [] + graph_output_names = self.model.get_graphs_output_names() for node in self.model.nodes(): if node.op_type == "Identity": - if node.output[0] not in self.model.get_graphs_output_names(): + if node.output[0] not in graph_output_names: self.model.replace_input_of_all_nodes(node.output[0], node.input[0]) nodes_to_remove.append(node) diff --git a/onnxruntime/python/tools/transformers/import_utils.py b/onnxruntime/python/tools/transformers/import_utils.py new file mode 100644 index 0000000000000..9755a26b7b004 --- /dev/null +++ b/onnxruntime/python/tools/transformers/import_utils.py @@ -0,0 +1,20 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# -------------------------------------------------------------------------- +import importlib.metadata +import importlib.util + + +def is_installed(package): + try: + dist = importlib.metadata.distribution(package) + except importlib.metadata.PackageNotFoundError: + try: + spec = importlib.util.find_spec(package) + except ModuleNotFoundError: + return False + + return spec is not None + + return dist is not None diff --git a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md index b10c10c87ee57..8607485bc265b 100644 --- a/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md +++ b/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md @@ -51,7 +51,7 @@ sh build.sh --config Release --build_shared_lib --parallel --use_cuda --cuda_ve --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=80 \ --allow_running_as_root python3 -m pip install --upgrade pip -python3 -m pip install build/Linux/Release/dist/onnxruntime_gpu-1.17.0-cp310-cp310-linux_x86_64.whl --force-reinstall +python3 -m pip install build/Linux/Release/dist/onnxruntime_gpu-*.whl --force-reinstall ``` If the GPU is not A100, change `CMAKE_CUDA_ARCHITECTURES=80` in the command line according to the GPU compute capacity (like 89 for RTX 4090, or 86 for RTX 3090). diff --git a/onnxruntime/python/tools/transformers/onnx_model.py b/onnxruntime/python/tools/transformers/onnx_model.py index 37b39c91b5c15..9d1066b6e372b 100644 --- a/onnxruntime/python/tools/transformers/onnx_model.py +++ b/onnxruntime/python/tools/transformers/onnx_model.py @@ -40,6 +40,12 @@ def initialize(self, model): self.enable_shape_infer: bool = True self.all_graphs: Optional[List[GraphProto]] = None + # Cache of shape and data type from onnx graph to speed up optimization. + # Be careful that fusion shall not reuse node output name for different shape/type (in adding/removing nodes) + # Note that these do not cache the symbolic shape inference result. + self._dtype_dict: Optional[Dict[str, int]] = None + self._shape_dict: Optional[Dict[str, List]] = None + def disable_shape_inference(self): self.enable_shape_infer = False @@ -519,20 +525,60 @@ def tensor_shape_to_list(self, tensor_type): shape_list.append("?") # shall not happen return shape_list - def get_dtype(self, input_or_output: str): - """Try get data type given a name (could be initializer, graph input or output).""" - tensor_type_map = {obj.name: obj.type for obj in self.model.graph.value_info} + def get_dtype(self, name: str, symbolic_shape_helper: Optional[SymbolicShapeInferenceHelper] = None): + """Try get data type given a name (could be initializer, input or output of graph or node).""" + + if self._dtype_dict is None: + self._dtype_dict = {} + for value_info in itertools.chain( + self.model.graph.value_info, + self.model.graph.input, + self.model.graph.output, + ): + self._dtype_dict[value_info.name] = value_info.type.tensor_type.elem_type + + for initializer in self.model.graph.initializer: + if initializer.name not in self._dtype_dict: + self._dtype_dict[initializer.name] = initializer.data_type - if input_or_output in tensor_type_map: - return tensor_type_map[input_or_output].tensor_type.elem_type + if name in self._dtype_dict: + return self._dtype_dict[name] - graph_input = self.find_graph_input(input_or_output) - if graph_input: - return graph_input.type.tensor_type.elem_type + if symbolic_shape_helper is not None and name in symbolic_shape_helper.known_vi_: + value_info = symbolic_shape_helper.known_vi_[name] + return value_info.type.tensor_type.elem_type + + return None - graph_output = self.find_graph_output(input_or_output) - if graph_output: - return graph_output.type.tensor_type.elem_type + def get_shape(self, name: str, symbolic_shape_helper: Optional[SymbolicShapeInferenceHelper] = None): + """Try get shape given a name (could be initializer, input or output of graph or node).""" + + if self._shape_dict is None: + self._shape_dict = {} + for value_info in itertools.chain( + self.model.graph.value_info, + self.model.graph.input, + self.model.graph.output, + ): + if value_info.type.tensor_type.HasField("shape"): + shape = [] + for dim in value_info.type.tensor_type.shape.dim: + if dim.dim_param: + shape.append(dim.dim_param) + else: + shape.append(dim.dim_value) + self._shape_dict[value_info.name] = shape + + for initializer in self.model.graph.initializer: + if initializer.name not in self._shape_dict: + self._shape_dict[initializer.name] = initializer.dims + + if name in self._shape_dict: + return self._shape_dict[name] + + if symbolic_shape_helper is not None and name in symbolic_shape_helper.known_vi_: + value_info = symbolic_shape_helper.known_vi_[name] + return value_info.type.tensor_type.elem_type return None @@ -566,23 +612,14 @@ def remove_cascaded_cast_nodes(self): def remove_useless_cast_nodes(self): """Remove cast nodes that are not needed: input and output has same data type.""" shape_infer = self.infer_runtime_shape(update=True) - if shape_infer is None: - logger.info("Skip removing useless cast nodes since shape inference failed.") - return - - def get_data_type(input_or_output_name): - dtype = self.get_dtype(input_or_output_name) - if dtype: - return dtype - if shape_infer.known_vi_[input_or_output_name].type.tensor_type.HasField("elem_type"): - return shape_infer.known_vi_[input_or_output_name].type.tensor_type.elem_type - return None + if self.enable_shape_infer and shape_infer is None: + logger.warning("shape inference failed which might impact useless cast node detection.") nodes_to_remove = [] for node in self.nodes(): if node.op_type == "Cast": - input_dtype = get_data_type(node.input[0]) - output_dtype = get_data_type(node.output[0]) + input_dtype = self.get_dtype(node.input[0], shape_infer) + output_dtype = self.get_dtype(node.output[0], shape_infer) if input_dtype and input_dtype == output_dtype: nodes_to_remove.append(node) @@ -601,7 +638,10 @@ def get_data_type(input_or_output_name): self.replace_input_of_all_nodes(node.output[0], node.input[0]) self.remove_node(node) - logger.info("Removed %d Cast nodes with output type same as input", len(nodes_to_remove)) + logger.info( + "Removed %d Cast nodes with output type same as input", + len(nodes_to_remove), + ) def convert_model_float32_to_float16(self, cast_input_output=True): logger.warning( @@ -1214,7 +1254,10 @@ def remove_duplicated_initializer(self, cache: Optional[dict] = None): continue for j in range(i + 1, initializer_count): if OnnxModel.has_same_value( - self.model.graph.initializer[i], self.model.graph.initializer[j], cache, cache + self.model.graph.initializer[i], + self.model.graph.initializer[j], + cache, + cache, ): same[j] = i @@ -1223,7 +1266,8 @@ def remove_duplicated_initializer(self, cache: Optional[dict] = None): if same[i] >= 0: count += 1 self.replace_input_of_all_nodes( - self.model.graph.initializer[i].name, self.model.graph.initializer[same[i]].name + self.model.graph.initializer[i].name, + self.model.graph.initializer[same[i]].name, ) if count > 0: diff --git a/onnxruntime/python/tools/transformers/onnx_model_bert.py b/onnxruntime/python/tools/transformers/onnx_model_bert.py index 51deb67ce5bf3..431e64509e3cc 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_bert.py +++ b/onnxruntime/python/tools/transformers/onnx_model_bert.py @@ -126,7 +126,8 @@ def fuse_rotary_embeddings(self): # Remove non-MS domain functions rot_emb_nodes = list( filter( - lambda node: node.op_type == "RotaryEmbedding" and node.domain != "com.microsoft", self.model.graph.node + lambda node: node.op_type == "RotaryEmbedding" and node.domain != "com.microsoft", + self.model.graph.node, ) ) non_ms_domains_to_keep = set(map(lambda node: node.domain, rot_emb_nodes)) @@ -350,7 +351,11 @@ def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bo self.attention_mask.set_mask_format(options.attention_mask_format) if options.use_multi_head_attention and not isinstance(self.attention_fusion, FusionBartAttention): self.attention_fusion = FusionAttention( - self, self.hidden_size, self.num_heads, self.attention_mask, options.use_multi_head_attention + self, + self.hidden_size, + self.num_heads, + self.attention_mask, + options.use_multi_head_attention, ) if (options is None) or options.enable_attention: @@ -415,7 +420,12 @@ def get_fused_operator_statistics(self): "SkipSimplifiedLayerNormalization", "RotaryEmbedding", ] - q_ops = ["QOrderedAttention", "QOrderedGelu", "QOrderedLayerNormalization", "QOrderedMatMul"] + q_ops = [ + "QOrderedAttention", + "QOrderedGelu", + "QOrderedLayerNormalization", + "QOrderedMatMul", + ] for op in ops + q_ops: nodes = self.get_nodes_by_op_type(op) op_count[op] = len(nodes) diff --git a/onnxruntime/python/tools/transformers/onnx_model_unet.py b/onnxruntime/python/tools/transformers/onnx_model_unet.py index 4d15b9288e7b6..01298b3576eb1 100644 --- a/onnxruntime/python/tools/transformers/onnx_model_unet.py +++ b/onnxruntime/python/tools/transformers/onnx_model_unet.py @@ -3,7 +3,7 @@ # Licensed under the MIT License. # -------------------------------------------------------------------------- -from logging import getLogger +import logging from typing import Optional from fusion_attention_unet import FusionAttentionUnet @@ -14,11 +14,12 @@ from fusion_options import FusionOptions from fusion_skip_group_norm import FusionSkipGroupNorm from fusion_transpose import FusionInsertTranspose, FusionTranspose +from import_utils import is_installed from onnx import ModelProto from onnx_model import OnnxModel from onnx_model_bert import BertOnnxModel -logger = getLogger(__name__) +logger = logging.getLogger(__name__) class UnetOnnxModel(BertOnnxModel): @@ -94,14 +95,24 @@ def fuse_multi_head_attention(self, options: Optional[FusionOptions] = None): # Self Attention enable_packed_qkv = (options is None) or options.enable_packed_qkv self_attention_fusion = FusionAttentionUnet( - self, self.hidden_size, self.num_heads, False, enable_packed_qkv, False + self, + self.hidden_size, + self.num_heads, + is_cross_attention=False, + enable_packed_qkv=enable_packed_qkv, + enable_packed_kv=False, ) self_attention_fusion.apply() # Cross Attention enable_packed_kv = (options is None) or options.enable_packed_kv cross_attention_fusion = FusionAttentionUnet( - self, self.hidden_size, self.num_heads, True, False, enable_packed_kv + self, + self.hidden_size, + self.num_heads, + is_cross_attention=True, + enable_packed_qkv=False, + enable_packed_kv=enable_packed_kv, ) cross_attention_fusion.apply() @@ -110,23 +121,48 @@ def fuse_bias_add(self): fusion.apply() def optimize(self, options: Optional[FusionOptions] = None): + if is_installed("tqdm"): + import tqdm + from tqdm.contrib.logging import logging_redirect_tqdm + + with logging_redirect_tqdm(): + steps = 18 + progress_bar = tqdm.tqdm(range(0, steps), initial=0, desc="fusion") + self._optimize(options, progress_bar) + else: + logger.info("tqdm is not installed. Run optimization without progress bar") + self._optimize(options, None) + + def _optimize(self, options: Optional[FusionOptions] = None, progress_bar=None): if (options is not None) and not options.enable_shape_inference: self.disable_shape_inference() self.utils.remove_identity_nodes() + if progress_bar: + progress_bar.update(1) # Remove cast nodes that having same data type of input and output based on symbolic shape inference. self.utils.remove_useless_cast_nodes() + if progress_bar: + progress_bar.update(1) if (options is None) or options.enable_layer_norm: self.fuse_layer_norm() + if progress_bar: + progress_bar.update(1) if (options is None) or options.enable_gelu: self.fuse_gelu() + if progress_bar: + progress_bar.update(1) self.preprocess() + if progress_bar: + progress_bar.update(1) self.fuse_reshape() + if progress_bar: + progress_bar.update(1) if (options is None) or options.enable_group_norm: channels_last = (options is None) or options.group_norm_channels_last @@ -135,42 +171,66 @@ def optimize(self, options: Optional[FusionOptions] = None): insert_transpose_fusion = FusionInsertTranspose(self) insert_transpose_fusion.apply() + if progress_bar: + progress_bar.update(1) if (options is None) or options.enable_bias_splitgelu: bias_split_gelu_fusion = FusionBiasSplitGelu(self) bias_split_gelu_fusion.apply() + if progress_bar: + progress_bar.update(1) if (options is None) or options.enable_attention: + # self.save_model_to_file("before_mha.onnx") self.fuse_multi_head_attention(options) + if progress_bar: + progress_bar.update(1) if (options is None) or options.enable_skip_layer_norm: self.fuse_skip_layer_norm() + if progress_bar: + progress_bar.update(1) self.fuse_shape() + if progress_bar: + progress_bar.update(1) # Remove reshape nodes that having same shape of input and output based on symbolic shape inference. self.utils.remove_useless_reshape_nodes() + if progress_bar: + progress_bar.update(1) if (options is None) or options.enable_skip_group_norm: skip_group_norm_fusion = FusionSkipGroupNorm(self) skip_group_norm_fusion.apply() + if progress_bar: + progress_bar.update(1) if (options is None) or options.enable_bias_skip_layer_norm: # Fuse SkipLayerNormalization and Add Bias before it. self.fuse_add_bias_skip_layer_norm() + if progress_bar: + progress_bar.update(1) if options is not None and options.enable_gelu_approximation: self.gelu_approximation() + if progress_bar: + progress_bar.update(1) if options is None or options.enable_nhwc_conv: self.convert_conv_to_nhwc() - self.merge_adjacent_transpose() + if progress_bar: + progress_bar.update(1) if options is not None and options.enable_bias_add: self.fuse_bias_add() + if progress_bar: + progress_bar.update(1) self.postprocess() + if progress_bar: + progress_bar.update(1) logger.info(f"opset version: {self.get_opset_version()}") @@ -190,6 +250,7 @@ def get_fused_operator_statistics(self): "NhwcConv", "BiasAdd", ] + for op in ops: nodes = self.get_nodes_by_op_type(op) op_count[op] = len(nodes) diff --git a/onnxruntime/test/global_thread_pools/test_inference.cc b/onnxruntime/test/global_thread_pools/test_inference.cc index 4772e7de2bdd7..f553682975f11 100644 --- a/onnxruntime/test/global_thread_pools/test_inference.cc +++ b/onnxruntime/test/global_thread_pools/test_inference.cc @@ -55,9 +55,15 @@ static void RunSession(OrtAllocator& allocator, Ort::Session& session_object, // size_t total_len = type_info.GetElementCount(); ASSERT_EQ(values_y.size(), static_cast(5)); +// test inference is using onnxruntime_shared_lib_test_LIBS, so HasCudaEnvironment(800) isn't available +#ifdef USE_CUDA + const float tolerance = 1e-5f; +#else + const float tolerance = 1e-6f; +#endif OutT* f = output_tensor->GetTensorMutableData(); for (size_t i = 0; i != static_cast(5); ++i) { - ASSERT_NEAR(values_y[i], f[i], 1e-6f); + ASSERT_NEAR(values_y[i], f[i], tolerance); } } diff --git a/onnxruntime/test/mlas/unittest/test_sbgemm.cpp b/onnxruntime/test/mlas/unittest/test_sbgemm.cpp new file mode 100644 index 0000000000000..941de8f05061f --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_sbgemm.cpp @@ -0,0 +1,141 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + test_sbgemm.cpp + +Abstract: + + Tests for MLAS bf16 precision GEMM. + +--*/ + +#if defined(__aarch64__) && defined(__linux__) + +#include "test_sbgemm.h" + +// +// Short Execute() test helper to register each test seperately by all parameters. +// +template +class SBGemmShortExecuteTest : public MlasTestFixture> { + public: + explicit SBGemmShortExecuteTest(size_t M, size_t N, size_t K, size_t Batch, bool hasBias) + : M_(M), N_(N), K_(K), Batch_(Batch), hasBias_(hasBias) {} + + void TestBody() override { + MlasTestFixture>::mlas_tester->Test(M_, N_, K_, Batch_, hasBias_); + } + + static size_t RegisterSingleTest(size_t M, size_t N, size_t K, size_t Batch, bool hasBias) { + std::stringstream ss; + ss << "Batch" << Batch << "/M" << M << "xN" << N << "xK" << K << "/" + << "hasBias" << hasBias; + auto test_name = ss.str(); + + testing::RegisterTest( + MlasSBGemmTest::GetTestSuiteName(), + test_name.c_str(), + nullptr, + test_name.c_str(), + __FILE__, + __LINE__, + // Important to use the fixture type as the return type here. + [=]() -> MlasTestFixture>* { + return new SBGemmShortExecuteTest( + M, N, K, Batch, hasBias); + }); + + return 1; + } + + static size_t RegisterShortExecuteTests() { + size_t test_registered = 0; + for (size_t b = 1; b < 16; b++) { + test_registered += RegisterSingleTest(b, b, b, 1, false); + test_registered += RegisterSingleTest(b, b, b, 1, true); + } + for (size_t b = 16; b <= 256; b <<= 1) { + test_registered += RegisterSingleTest(b, b, b, 1, false); + test_registered += RegisterSingleTest(b, b, b, 1, true); + } + for (size_t b = 256; b < 320; b += 32) { + test_registered += RegisterSingleTest(b, b, b, 1, true); + } + for (size_t b = 1; b < 96; b++) { + test_registered += RegisterSingleTest(1, b, 32, 1, false); + test_registered += RegisterSingleTest(1, 32, b, 1, true); + test_registered += RegisterSingleTest(1, b, b, 1, false); + if (!Packed) { + test_registered += RegisterSingleTest(1, b, 32, 3, true); + test_registered += RegisterSingleTest(1, 32, b, 5, false); + } + } + // TODO: check why the cosine similary is < 0.99 for this shape alone + // test_registered += RegisterSingleTest(43, 500, 401, 1, true); + test_registered += RegisterSingleTest(1001, 1027, 1031, 1, false); + if (!Packed) { + test_registered += RegisterSingleTest(43, 500, 401, 5, true); + test_registered += RegisterSingleTest(1000, 1029, 1030, 3, false); + } + + return test_registered; + } + + private: + size_t M_, N_, K_, Batch_; + bool hasBias_; +}; + +static size_t SBGemmRegistLongExecute() { + size_t count = 0; + + count += MlasLongExecuteTests>::RegisterLongExecute(); + if (MlasSBGemmPackBSize(128, 128) > 0) { + count += MlasLongExecuteTests>::RegisterLongExecute(); + } + + if (GetMlasThreadPool() != nullptr) { + count += MlasLongExecuteTests>::RegisterLongExecute(); + if (MlasSBGemmPackBSize(128, 128) > 0) { + count += MlasLongExecuteTests>::RegisterLongExecute(); + } + } + + return count; +} + +static size_t SBGemmRegistShortExecute() { + size_t count = 0; + + count += SBGemmShortExecuteTest::RegisterShortExecuteTests(); + if (MlasSBGemmPackBSize(128, 128) > 0) { + count += SBGemmShortExecuteTest::RegisterShortExecuteTests(); + } + + if (GetMlasThreadPool() != nullptr) { + count += SBGemmShortExecuteTest::RegisterShortExecuteTests(); + if (MlasSBGemmPackBSize(128, 128) > 0) { + count += SBGemmShortExecuteTest::RegisterShortExecuteTests(); + } + } + + return count; +} + +static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) { + if (!MlasBf16AccelerationSupported()) { + return false; + } + + if (is_short_execute) { + return SBGemmRegistShortExecute() > 0; + } + return SBGemmRegistLongExecute() > 0; +}); +#endif // defined(__aarch64__) && defined(__linux__) diff --git a/onnxruntime/test/mlas/unittest/test_sbgemm.h b/onnxruntime/test/mlas/unittest/test_sbgemm.h new file mode 100644 index 0000000000000..13701e2e3de46 --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_sbgemm.h @@ -0,0 +1,281 @@ +/*++ + +Copyright (c) Microsoft Corporation. All rights reserved. +Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +Licensed under the MIT License. + +Module Name: + + test_sbgemm.h + +Abstract: + + Tests for MLAS bf16 precision GEMM. + +--*/ + +#if defined(__aarch64__) && defined(__linux__) + +#pragma once + +#include "test_util.h" + +template +void SmallFloatFill(T* start, size_t size) { + constexpr float MinimumFillValue = -11.0f; + auto FillAddress = start; + size_t offset = size % 23; + + for (size_t i = 0; i < size; i++) { + offset = (offset + 21) % 23; + *FillAddress++ = T((MinimumFillValue + offset) / 16.0f); + } +} + +float cosine_similarity(const float* A, const float* B, size_t Vector_Length) { + float dot = 0.0, denom_a = 0.0, denom_b = 0.0; + for (size_t i = 0u; i < Vector_Length; ++i) { + dot += A[i] * B[i]; + denom_a += A[i] * A[i]; + denom_b += B[i] * B[i]; + } + return dot / (sqrt(denom_a) * sqrt(denom_b)); +} + +/** + * @brief Test class for bf16 precision GEMM + * @tparam AType Data type of A matrix, need to be float + * @tparam BType Data type of b matrix, can be either float or prepacked bf16 + */ +template +class MlasSBGemmTest : public MlasTestBase { + private: + MatrixGuardBuffer BufferBPacked; + MatrixGuardBuffer BufferA; + MatrixGuardBuffer BufferB; + MatrixGuardBuffer BufferBias; + MatrixGuardBuffer BufferC; + MatrixGuardBuffer BufferCReference; + MatrixGuardBuffer BufferFloatC; + MLAS_THREADPOOL* threadpool_; + + void* PackB(size_t N, size_t K, const BType* B, size_t ldb) { + size_t PackedBSize = MlasSBGemmPackBSize(N, K); + if (PackedBSize == 0) { + return nullptr; + } + void* PackedB = BufferBPacked.GetBuffer(PackedBSize); + if (std::is_same::value) { + MlasSBGemmConvertPackB(N, K, (const float*)B, ldb, PackedB); + } else { + } + return PackedB; + } + + void CallSBGemm(size_t M, + size_t N, + size_t K, + size_t BatchSize, + const float* A, + size_t lda, + const BType* B, + size_t ldb, + const float* Bias, + float* C, + size_t ldc) { + std::vector GemmParameters(BatchSize); + + for (size_t i = 0; i < GemmParameters.size(); i++) { + auto& params = GemmParameters[i]; + params.A = A + (M * lda * i); + params.lda = lda; + if (nullptr != Bias) { + params.Bias = reinterpret_cast(Bias + N * i); + } else { + params.Bias = nullptr; + } + params.C = reinterpret_cast(C + (M * ldc * i)); + params.ldc = ldc; + params.AIsfp32 = true; + params.BIsfp32 = true; + + if (Packed) { + ASSERT_EQ(BatchSize, size_t(1)) << "Packing B not supported in batching yet!"; + params.B = PackB(N, K, B, ldb); + params.ldb = 0; + params.BIsfp32 = false; + } else { + params.B = B + (K * N * i); + params.ldb = ldb; + } + } + + MlasSBGemmBatch(M, N, K, BatchSize, GemmParameters.data(), threadpool_); + } + + void ReferenceSgemm(size_t M, + size_t N, + size_t K, + size_t BatchSize, + const AType* A, + const BType* B, + const float* Bias, + float* C) { + constexpr size_t KStride = 256; + + for (size_t batch = 0; batch < BatchSize; batch++) { + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++) { + const AType* a = A + M * K * batch + m * K; + const BType* b = B + K * N * batch + n; + float* c = C + (M * N * batch) + (m * N) + n; + + for (size_t k = 0; k < K; k += KStride) { + float sum = 0.0f; + if (k == 0 && Bias != nullptr) { + sum = float(Bias[n]); + } + for (size_t kk = 0; kk < std::min(KStride, K - k); kk++) { + float down(float(*b) * float(*a) + sum); + sum = float(down); + b += N; + a += 1; + } + if (k == 0) { + *c = sum; + } else { + float d(sum + *c); + *c = float(d); + } + } + } + } + if (Bias) { + Bias += N; + } + } + } + + public: + MlasSBGemmTest() : threadpool_(Threaded ? GetMlasThreadPool() : nullptr) {} + + void Test(size_t M, size_t N, size_t K, size_t BatchSize, bool withBias) { + AType* A = BufferA.GetFilledBuffer(K * M * BatchSize + 16, SmallFloatFill); + AType Atail[16]; + std::memcpy(Atail, A + K * M * BatchSize, 16 * sizeof(AType)); + + BType* B = BufferB.GetFilledBuffer(N * K * BatchSize + 16, SmallFloatFill); + BType Btail[16]; + std::memcpy(Btail, B + N * K * BatchSize, 16 * sizeof(BType)); + + float BiasTail[16]; + const float* Bias = nullptr; + if (withBias) { + Bias = BufferBias.GetFilledBuffer(N * BatchSize + 16, SmallFloatFill); + std::memcpy(BiasTail, Bias + N * BatchSize, 16 * sizeof(float)); + } + + float* C = BufferC.GetFilledBuffer(N * M * BatchSize, SmallFloatFill); + float* CReference = BufferCReference.GetFilledBuffer( + N * M * BatchSize, + [](float* start, size_t size) { + std::fill_n(start, size, -1.0f); + }); + this->CallSBGemm(M, N, K, BatchSize, A, K, B, N, Bias, C, N); + ReferenceSgemm(M, N, K, BatchSize, A, B, Bias, CReference); + const float cosine_similarity_threshold = 0.98; + + for (size_t batch = 0, f = 0; batch < BatchSize; batch++) { + for (size_t m = 0; m < M; m++) { + for (size_t n = 0; n < N; n++, f++) { + if (!(CloseEnough(float(C[f]), CReference[f]))) { + float cos_sim = cosine_similarity(C, CReference, (BatchSize * M * N)); + if (abs(cos_sim) < cosine_similarity_threshold) { + ASSERT_TRUE(false) << "cosine similarity check failed" << cos_sim; + } else { + break; + } + } + } + } + } + + ASSERT_EQ(std::memcmp(Atail, A + K * M * BatchSize, 16 * sizeof(AType)), 0) << "Matrix A buffer overwritten!"; + ASSERT_EQ(std::memcmp(Btail, B + N * K * BatchSize, 16 * sizeof(BType)), 0) << "Matrix B buffer overwritten!"; + if (withBias) { + ASSERT_EQ(std::memcmp(BiasTail, Bias + N * BatchSize, 16 * sizeof(float)), 0) << "Bias buffer overwritten!"; + } + } + + private: + public: + static const char* GetTestSuiteName() { + static std::string suite_name = std::string("SBGemmFP") + + (std::is_same::value ? "32" : "16") + + (std::is_same::value ? "32" : "16") + + (Packed ? "_Packed" : "_NoPack") + + (Threaded ? "_Threaded" : "_SingleThread"); + return suite_name.c_str(); + } + + void ExecuteLong(void) override { + for (size_t M = 16; M < 160; M += 32) { + for (size_t N = 16; N < 160; N += 32) { + static const size_t ks[] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 16, 20, 32, 48, 64, 118, 119, 120, 121, 122, 160, 240, 320}; + for (size_t k = 0; k < _countof(ks); k++) { + size_t K = ks[k]; + + Test(M, N, K, 1, false); + Test(M, N, K, 1, true); + Test(M + 1, N, K, 1, false); + Test(M, N + 1, K, 1, true); + Test(M + 1, N + 1, K, 1, false); + Test(M + 3, N + 2, K, 1, true); + Test(M + 4, N, K, 1, false); + Test(M, N + 4, K, 1, true); + Test(M + 4, N + 4, K, 1, false); + Test(M + 3, N + 7, K, 1, true); + Test(M + 8, N, K, 1, false); + Test(M, N + 8, K, 1, true); + Test(M + 12, N + 12, K, 1, false); + Test(M + 13, N, K, 1, true); + Test(M, N + 15, K, 1, false); + Test(M + 15, N + 15, K, 1, false); + if (!Packed) { + Test(M, N, K, 7, false); + Test(M + 3, N, K, 8, true); + Test(M, N + 1, K, 9, false); + Test(M + 12, N, K, 10, true); + Test(M, N + 15, K, 11, false); + Test(M + 15, N + 15, K, 12, true); + } + } + } + printf("M %zd\n", M); + } + + for (size_t M = 1; M < 160; M++) { + for (size_t N = 1; N < 160; N++) { + for (size_t K = 1; K < 160; K++) { + Test(M, N, K, 1, true); + } + } + printf("M %zd\n", M); + } + + for (size_t M = 160; M < 320; M += 24) { + for (size_t N = 112; N < 320; N += 24) { + for (size_t K = 1; K < 16; K++) { + Test(M, N, K, 1, true); + } + for (size_t K = 16; K < 160; K += 32) { + Test(M, N, K, 1, false); + } + } + printf("M %zd\n", M); + } + } +}; + +#endif // defined(__aarch64__) && defined(__linux__) diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index 7e0a811b7d07c..aca609cf94270 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -60,6 +60,10 @@ void usage() { "\t [QNN only] [qnn_saver_path]: QNN Saver backend path. e.g '/folderpath/libQnnSaver.so'.\n" "\t [QNN only] [htp_graph_finalization_optimization_mode]: QNN graph finalization optimization mode, options: \n" "\t '0', '1', '2', '3', default is '0'.\n" + "\t [QNN only] [soc_model]: The SoC Model number. Refer to QNN SDK documentation for specific values. Defaults to '0' (unknown). \n" + "\t [QNN only] [htp_arch]: The minimum HTP architecture. The driver will use ops compatible with this architecture. \n" + "\t Options are '0', '68', '69', '73', '75'. Defaults to '0' (none). \n" + "\t [QNN only] [device_id]: The ID of the device to use when setting 'htp_arch'. Defaults to '0' (for single device). \n" "\t [Usage]: -e -i '| |' \n\n" "\t [Example] [For QNN EP] -e qnn -i \"profiling_level|detailed backend_path|/folderpath/libQnnCpu.so\" \n\n" "\t [SNPE only] [runtime]: SNPE runtime, options: 'CPU', 'GPU', 'GPU_FLOAT16', 'DSP', 'AIP_FIXED_TF'. \n" @@ -483,7 +487,7 @@ int real_main(int argc, char* argv[], Ort::Env& env) { if (supported_profiling_level.find(value) == supported_profiling_level.end()) { ORT_THROW("Supported profiling_level: off, basic, detailed"); } - } else if (key == "rpc_control_latency" || key == "vtcm_mb") { + } else if (key == "rpc_control_latency" || key == "vtcm_mb" || key == "soc_model" || key == "device_id") { // no validation } else if (key == "htp_performance_mode") { std::set supported_htp_perf_mode = {"burst", "balanced", "default", "high_performance", @@ -512,10 +516,20 @@ int real_main(int argc, char* argv[], Ort::Env& env) { std::string str = str_stream.str(); ORT_THROW("Wrong value for htp_graph_finalization_optimization_mode. select from: " + str); } + } else if (key == "htp_arch") { + std::unordered_set supported_htp_archs = {"0", "68", "69", "73", "75"}; + if (supported_htp_archs.find(value) == supported_htp_archs.end()) { + std::ostringstream str_stream; + std::copy(supported_htp_archs.begin(), supported_htp_archs.end(), + std::ostream_iterator(str_stream, ",")); + std::string str = str_stream.str(); + ORT_THROW("Wrong value for htp_arch. select from: " + str); + } } else { ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path', 'profiling_level', 'rpc_control_latency', 'vtcm_mb', 'htp_performance_mode', -'qnn_saver_path', 'htp_graph_finalization_optimization_mode', 'qnn_context_priority'])"); +'qnn_saver_path', 'htp_graph_finalization_optimization_mode', 'qnn_context_priority', +'soc_model', 'htp_arch', 'device_id'])"); } qnn_options[key] = value; diff --git a/onnxruntime/test/optimizer/qdq_transformer_fastmath_test.cc b/onnxruntime/test/optimizer/qdq_transformer_fastmath_test.cc new file mode 100644 index 0000000000000..ec9f78da14a75 --- /dev/null +++ b/onnxruntime/test/optimizer/qdq_transformer_fastmath_test.cc @@ -0,0 +1,730 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// Licensed under the MIT License. + +#include "core/framework/compute_capability.h" +#include "core/graph/model.h" +#include "core/graph/onnx_protobuf.h" +#include "core/mlas/inc/mlas.h" +#include "core/optimizer/qdq_transformer/qdq_final_cleanup.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selectors.h" +#include "core/optimizer/qdq_transformer/selectors_actions/qdq_selector_action_transformer.h" +#include "core/optimizer/qdq_transformer/selectors_actions/shared/utils.h" +#include "core/optimizer/utils.h" +#include "core/providers/partitioning_utils.h" +#include "core/session/onnxruntime_session_options_config_keys.h" +#include "core/session/environment.h" +#include "core/session/inference_session.h" + +#include "test/compare_ortvalue.h" +#include "test/test_environment.h" +#include "test/framework/test_utils.h" +#include "test/util/include/asserts.h" +#include "test/util/include/inference_session_wrapper.h" + +#include "gtest/gtest.h" +#include "graph_transform_test_builder.h" + +#include "qdq_test_utils.h" + +#if defined(__aarch64__) && defined(__linux__) && !defined(DISABLE_CONTRIB_OPS) + +struct QDQOpKeys { + const char* quantize_linear; + const char* dequantize_linear; +}; + +constexpr QDQOpKeys GetQDQOpKeys(bool use_contrib_qdq) { + if (use_contrib_qdq) { + return {"com.microsoft.QuantizeLinear", "com.microsoft.DequantizeLinear"}; + } + return {"QuantizeLinear", "DequantizeLinear"}; +} + +namespace onnxruntime { +namespace test { + +#if !defined(DISABLE_CONTRIB_OPS) + +TEST(QDQTransformerTests, DQ_S8_to_U8_FastMath) { + auto test_case = [](bool use_contrib_qdq) { + const std::vector& input_shape = {19, 37}; + const std::vector& weights_shape = {37, 23}; + + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput(input_shape, -1.f, 1.f); + + // Use full range weight values to expose u8s8 overflow problems + auto* weight = builder.MakeInitializer(weights_shape, -128, 127); + auto* output_arg = builder.MakeOutput(); + + // add QDQ activation + typedef std::numeric_limits Input1Limits; + auto* dq1_output = AddQDQNodePair(builder, input1_arg, .039f, + (int8_t)((Input1Limits::max() + Input1Limits::min()) / 2 + 1), + use_contrib_qdq); + + // add DQ weight + auto* dq_w_output = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(weight, .003f, -10, dq_w_output, use_contrib_qdq); + + builder.AddNode("MatMul", {dq1_output, dq_w_output}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count["com.microsoft.MatMulIntegerToFloat"], 1); + EXPECT_EQ(op_to_count["MatMul"], 0); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 1); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + }; + + auto add_session_options = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "1")); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, /*using NAN as a magic number to trigger cosine similarity*/ + NAN /*relative_per_sample_tolerance*/, + nullptr, add_session_options); + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 18 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + nullptr, add_session_options); + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 19 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + nullptr, add_session_options); + + auto add_session_options_disable_fm = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "0")); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + nullptr, add_session_options_disable_fm); + }; + + test_case(false); // Use ONNX QDQ ops + test_case(true); // Use com.microsoft QDQ ops +} + +template +void QDQTransformerMatMulTests(bool has_output_q, bool disable_fastmath = false) { + auto test_case = [&](const std::vector& input1_shape, const std::vector& input2_shape, + bool use_contrib_qdq = false) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput(input1_shape, -1.f, 1.f); + auto* input2_arg = builder.MakeInput(input2_shape, -1.f, 1.f); + auto* output_arg = builder.MakeOutput(); + + typedef std::numeric_limits Input1Limits; + typedef std::numeric_limits Input2Limits; + typedef std::numeric_limits OutputTypeLimits; + + // add QDQ 1 + auto* q1_output = builder.MakeIntermediate(); + auto* dq1_output = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(input1_arg, + .039f, + (Input1Limits::max() + Input1Limits::min()) / 2 + 1, + q1_output, use_contrib_qdq); + builder.AddDequantizeLinearNode(q1_output, + .039f, + (Input2Limits::max() + Input1Limits::min()) / 2 + 1, + dq1_output, use_contrib_qdq); + + // add QDQ 2 + auto* q2_output = builder.MakeIntermediate(); + auto* dq2_output = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(input2_arg, + .04f, + (Input2Limits::max() + Input2Limits::min()) / 2 + 1, + q2_output, use_contrib_qdq); + builder.AddDequantizeLinearNode(q2_output, + .04f, + (Input2Limits::max() + Input2Limits::min()) / 2 + 1, + dq2_output, use_contrib_qdq); + + if (has_output_q) { + // add binary operator + auto* matmul_op_output = builder.MakeIntermediate(); + builder.AddNode("MatMul", {dq1_output, dq2_output}, {matmul_op_output}); + + // add QDQ output + auto* q3_output = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(matmul_op_output, + .039f, + (OutputTypeLimits::max() + OutputTypeLimits::min()) / 2 + 1, + q3_output, use_contrib_qdq); + builder.AddDequantizeLinearNode(q3_output, + .039f, + (OutputTypeLimits::max() + OutputTypeLimits::min()) / 2 + 1, + output_arg, use_contrib_qdq); + } else { + builder.AddNode("MatMul", {dq1_output, dq2_output}, {output_arg}); + } + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + if (has_output_q) { + if constexpr (std::is_same::value && + (std::is_same::value || + QDQIsInt8Allowed() && std::is_same::value)) { + EXPECT_EQ(op_to_count["QLinearMatMul"], 1); + EXPECT_EQ(op_to_count["MatMul"], 0); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 2); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 1); + } else { + EXPECT_EQ(op_to_count["QLinearMatMul"], 0); + EXPECT_EQ(op_to_count["MatMul"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 3); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 3); + } + } else { + if constexpr (std::is_same::value || + (QDQIsInt8Allowed() && std::is_same::value)) { + EXPECT_EQ(op_to_count["com.microsoft.MatMulIntegerToFloat"], 1); + EXPECT_EQ(op_to_count["MatMul"], 0); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 2); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + } else { + EXPECT_EQ(op_to_count["com.microsoft.MatMulIntegerToFloat"], 0); + EXPECT_EQ(op_to_count["MatMul"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 2); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 2); + } + } + }; + + auto add_session_options = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "1")); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + std::make_unique(QDQIsInt8Allowed()), + add_session_options); + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 18 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + std::make_unique(QDQIsInt8Allowed()), + add_session_options); + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 19 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + std::make_unique(QDQIsInt8Allowed()), + add_session_options); + + if (disable_fastmath) { + auto add_session_options = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "0")); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + std::make_unique(QDQIsInt8Allowed()), + add_session_options); + } + }; + + test_case({1, 2, 2}, {1, 2, 4}); + test_case({1, 23, 13, 13}, {13, 13}); + test_case({1, 22, 11, 13, 15}, {1, 22, 11, 15, 15}); + test_case({1, 2, 2}, {1, 2, 4}, true); // Use com.microsoft QDQ ops +} + +TEST(QDQTransformerTests, MatMul_U8U8U8_FastMath) { + QDQTransformerMatMulTests(false); + QDQTransformerMatMulTests(true); +} + +TEST(QDQTransformerTests, MatMul_U8S8S8_FastMath) { + QDQTransformerMatMulTests(false); + QDQTransformerMatMulTests(true); +} + +TEST(QDQTransformerTests, MatMul_U8U8S8_FastMath) { + QDQTransformerMatMulTests(false); + QDQTransformerMatMulTests(true); +} + +TEST(QDQTransformerTests, MatMul_U8S8U8_FastMath) { + QDQTransformerMatMulTests(false); + QDQTransformerMatMulTests(true); +} + +TEST(QDQTransformerTests, MatMul_S8S8S8_FastMath) { + QDQTransformerMatMulTests(false); + QDQTransformerMatMulTests(true); +} + +TEST(QDQTransformerTests, MatMul_S8U8U8_FastMath) { + QDQTransformerMatMulTests(false); + QDQTransformerMatMulTests(true); +} + +TEST(QDQTransformerTests, MatMul_S8U8S8_FastMath) { + QDQTransformerMatMulTests(false); + QDQTransformerMatMulTests(true); +} + +TEST(QDQTransformerTests, MatMul_S8S8U8_FastMath) { + QDQTransformerMatMulTests(false); + QDQTransformerMatMulTests(true); +} + +// dummy test to disable the fastmath session op +TEST(QDQTransformerTests, MatMul_S8S8U8_DisableFastMath) { + QDQTransformerMatMulTests(false, true); + QDQTransformerMatMulTests(true, true); +} + +template +void QDQTransformerGemmTests(bool has_output_q, bool has_bias, bool beta_not_one = false, bool disable_fastmath = false) { + auto test_case = [&](const std::vector& input1_shape, const std::vector& input2_shape, + bool use_contrib_qdq = false) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput(input1_shape, -1.f, 1.f); + auto* input2_arg = builder.MakeInput(input2_shape, -1.f, 1.f); + auto* output_arg = builder.MakeOutput(); + + typedef std::numeric_limits Input1Limits; + typedef std::numeric_limits Input2Limits; + typedef std::numeric_limits OutputTypeLimits; + + std::vector input_args; + + // add QDQ A + auto* q1_output = builder.MakeIntermediate(); + auto* dq1_output = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(input1_arg, + .039f, + (Input1Limits::max() + Input1Limits::min()) / 2 + 1, + q1_output, use_contrib_qdq); + builder.AddDequantizeLinearNode(q1_output, + .039f, + (Input2Limits::max() + Input1Limits::min()) / 2 + 1, + dq1_output, use_contrib_qdq); + + input_args.push_back(dq1_output); + + // add QDQ B + auto* q2_output = builder.MakeIntermediate(); + auto* dq2_output = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(input2_arg, + .04f, + (Input2Limits::max() + Input2Limits::min()) / 2 + 1, + q2_output, use_contrib_qdq); + builder.AddDequantizeLinearNode(q2_output, + .04f, + (Input2Limits::max() + Input2Limits::min()) / 2 + 1, + dq2_output, use_contrib_qdq); + input_args.push_back(dq2_output); + + if (has_bias) { + auto* dq_bias_output = builder.MakeIntermediate(); + auto* bias = builder.MakeInitializer({input2_shape[1]}, static_cast(0), static_cast(127)); + builder.AddDequantizeLinearNode(bias, 0.00156f, + 0, + dq_bias_output, use_contrib_qdq); + input_args.push_back(dq_bias_output); + } + + Node* gemm_node = nullptr; + + if (has_output_q) { + auto* gemm_op_output = builder.MakeIntermediate(); + gemm_node = &builder.AddNode("Gemm", input_args, {gemm_op_output}); + + // add QDQ output + auto* q3_output = builder.MakeIntermediate(); + builder.AddQuantizeLinearNode(gemm_op_output, + .039f, + (OutputTypeLimits::max() + OutputTypeLimits::min()) / 2 + 1, + q3_output, use_contrib_qdq); + builder.AddDequantizeLinearNode(q3_output, + .039f, + (OutputTypeLimits::max() + OutputTypeLimits::min()) / 2 + 1, + output_arg, use_contrib_qdq); + } else { + gemm_node = &builder.AddNode("Gemm", input_args, {output_arg}); + } + + if (beta_not_one) { + gemm_node->AddAttribute("beta", 2.0f); + } + }; + + auto check_binary_op_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + if ((!has_output_q || std::is_same_v)&&(!has_bias || (std::is_same_v && !beta_not_one)) && + (std::is_same_v || std::is_same_v)) { + EXPECT_EQ(op_to_count["com.microsoft.QGemm"], 1); + EXPECT_EQ(op_to_count["Gemm"], 0); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 2); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], has_output_q ? 1 : 0); + } else { + int q_count = 2; // Q for A and B + int dq_count = 2; // DQ for A and B + if (has_bias) { + dq_count++; + } + if (has_output_q) { + q_count++; + dq_count++; + } + EXPECT_EQ(op_to_count["com.microsoft.QGemm"], 0); + EXPECT_EQ(op_to_count["Gemm"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], q_count); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], dq_count); + } + }; + + auto add_session_options = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "1")); + }; + + TransformerTester(build_test_case, + check_binary_op_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + std::make_unique(QDQIsInt8Allowed()), + add_session_options); + TransformerTester(build_test_case, + check_binary_op_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 18 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + std::make_unique(QDQIsInt8Allowed()), + add_session_options); + TransformerTester(build_test_case, + check_binary_op_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 19 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + std::make_unique(QDQIsInt8Allowed()), + add_session_options); + + if (disable_fastmath) { + auto add_session_options = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "0")); + }; + + TransformerTester(build_test_case, + check_binary_op_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + std::make_unique(QDQIsInt8Allowed()), + add_session_options); + } + }; + + test_case({2, 2}, {2, 4}); + test_case({13, 15}, {15, 15}); + test_case({2, 2}, {2, 4}, true); // Use com.microsoft QDQ ops +} + +template +void QDQTransformerGemmTests() { + QDQTransformerGemmTests(false, false); + QDQTransformerGemmTests(false, true); + QDQTransformerGemmTests(true, false); + QDQTransformerGemmTests(true, true); + QDQTransformerGemmTests(false, false, true); + QDQTransformerGemmTests(false, true, true); + QDQTransformerGemmTests(true, false, true); + QDQTransformerGemmTests(true, true, true); + // dummy test to disable the fastmath session + QDQTransformerGemmTests(true, true, true, true); +} + +TEST(QDQTransformerTests, Gemm_U8U8U8_FastMath) { + QDQTransformerGemmTests(); + QDQTransformerGemmTests(); +} + +TEST(QDQTransformerTests, Gemm_U8S8S8_FastMath) { + QDQTransformerGemmTests(); + QDQTransformerGemmTests(); +} + +TEST(QDQTransformerTests, Gemm_U8U8S8_FastMath) { + QDQTransformerGemmTests(); + QDQTransformerGemmTests(); +} + +TEST(QDQTransformerTests, Gemm_U8S8U8_FastMath) { + QDQTransformerGemmTests(); + QDQTransformerGemmTests(); +} + +TEST(QDQTransformerTests, Gemm_S8S8S8_FastMath) { + QDQTransformerGemmTests(); + QDQTransformerGemmTests(); +} + +TEST(QDQTransformerTests, Gemm_S8U8U8_FastMath) { + QDQTransformerGemmTests(); + QDQTransformerGemmTests(); +} + +TEST(QDQTransformerTests, Gemm_S8U8S8_FastMath) { + QDQTransformerGemmTests(); + QDQTransformerGemmTests(); +} + +TEST(QDQTransformerTests, Gemm_S8S8U8_FastMath) { + QDQTransformerGemmTests(); + QDQTransformerGemmTests(); +} + +TEST(QDQTransformerTests, MatMul_No_Fusion_FastMath) { + auto test_case = [&](const std::vector& input1_shape, const std::vector& input2_shape, + bool use_contrib_qdq) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput(input1_shape, -1.f, 1.f); + auto* input2_arg = builder.MakeInput(input2_shape, -1.f, 1.f); + auto* output_arg = builder.MakeOutput(); + + // add QDQ + MatMul + auto* matmul_output = builder.MakeIntermediate(); + auto* dq_matmul_output1 = AddQDQNodePair(builder, input1_arg, .004f, 129, use_contrib_qdq); + builder.AddNode("MatMul", {dq_matmul_output1, input2_arg}, {matmul_output}); + + // add Q + builder.AddQuantizeLinearNode(matmul_output, .0039f, 135, output_arg, use_contrib_qdq); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count["MatMul"], 1); + EXPECT_EQ(op_to_count["QLinearMatMul"], 0); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 2); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 1); + }; + + auto add_session_options = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "1")); + }; + + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + nullptr, add_session_options); + + auto add_session_options_disable_fm = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "0")); + }; + + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + nullptr, add_session_options_disable_fm); + }; + + test_case({12, 37}, {37, 12}, false /*use_contrib_qdq*/); + test_case({12, 37}, {37, 12}, true /*use_contrib_qdq*/); +} + +TEST(QDQTransformerTests, MatMul_1st_Input_Int8_FastMath) { + auto test_case = [&](const std::vector& input1_shape, const std::vector& input2_shape, + bool use_contrib_qdq) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput(input1_shape, -128, 127); + auto* input2_arg = builder.MakeInput(input2_shape, -1.f, 1.f); + auto* output_arg = builder.MakeOutput(); + + // add DQ with type int8 + auto* dq_output_1 = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(input1_arg, .004f, 1, dq_output_1, use_contrib_qdq); + + // add QDQ + MatMul + auto* matmul_output = builder.MakeIntermediate(); + auto* dq_matmul_output2 = AddQDQNodePair(builder, input2_arg, .004f, 129, use_contrib_qdq); + builder.AddNode("MatMul", {dq_output_1, dq_matmul_output2}, {matmul_output}); + + // add Q + builder.AddQuantizeLinearNode(matmul_output, .0039f, 135, output_arg, use_contrib_qdq); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count["MatMul"], 1); + EXPECT_EQ(op_to_count["QLinearMatMul"], 0); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 2); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 2); + }; + + auto add_session_options = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "1")); + }; + + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + nullptr, add_session_options); + + auto add_session_options_disable_fm = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "0")); + }; + + TransformerTester(build_test_case, check_graph, TransformerLevel::Level1, TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + nullptr, add_session_options_disable_fm); + }; + + test_case({12, 37}, {37, 12}, false /*use_contrib_qdq*/); + test_case({12, 37}, {37, 12}, true /*use_contrib_qdq*/); + test_case({23, 13, 13}, {13, 13}, false /*use_contrib_qdq*/); + test_case({22, 11, 13, 15}, {15, 13}, false /*use_contrib_qdq*/); +} + +TEST(QDQTransformerTests, MatMulIntegerToFloat_FastMath) { + auto test_case = [&](const std::vector& input1_shape, const std::vector& input2_shape, + bool use_contrib_qdq) { + auto build_test_case = [&](ModelTestBuilder& builder) { + auto* input1_arg = builder.MakeInput(input1_shape, + std::numeric_limits::min(), + std::numeric_limits::max()); + auto* input2_arg = builder.MakeInput(input2_shape, + std::numeric_limits::min(), + std::numeric_limits::max()); + auto* output_arg = builder.MakeOutput(); + + // add DQ + auto* dq_output_1 = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(input1_arg, .0035f, 135, dq_output_1, use_contrib_qdq); + + auto* dq_output_2 = builder.MakeIntermediate(); + builder.AddDequantizeLinearNode(input2_arg, .0035f, 135, dq_output_2, use_contrib_qdq); + + builder.AddNode("MatMul", {dq_output_1, dq_output_2}, {output_arg}); + }; + + auto check_graph = [&](InferenceSessionWrapper& session) { + auto op_to_count = CountOpsInGraph(session.GetGraph()); + const QDQOpKeys qdq_keys = GetQDQOpKeys(use_contrib_qdq); + EXPECT_EQ(op_to_count["com.microsoft.MatMulIntegerToFloat"], 1); + EXPECT_EQ(op_to_count[qdq_keys.quantize_linear], 0); + EXPECT_EQ(op_to_count[qdq_keys.dequantize_linear], 0); + }; + + auto add_session_options = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "1")); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + nullptr, + add_session_options); + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 19 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + nullptr, + add_session_options); + + auto add_session_options_disable_fm = [&](SessionOptions& so) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "0")); + }; + + TransformerTester(build_test_case, + check_graph, + TransformerLevel::Level1, + TransformerLevel::Level2, + 12 /*opset_version*/, + NAN /*per_sample_tolerance*/, + NAN /*relative_per_sample_tolerance*/, + nullptr, + add_session_options_disable_fm); + }; + + test_case({12, 37}, {37, 12}, false /*use_contrib_qdq*/); + test_case({12, 37}, {37, 12}, true /*use_contrib_qdq*/); + test_case({23, 13, 13}, {13, 13}, false /*use_contrib_qdq*/); + test_case({22, 11, 13, 15}, {15, 13}, false /*use_contrib_qdq*/); +} + +#endif // !defined(DISABLE_CONTRIB_OPS) && defined(__aarch64) + +} // namespace test +} // namespace onnxruntime + +#endif // defined(__aarch64) && defined(__linux__) && !defined(DISABLE_CONTRIB_OPS) diff --git a/onnxruntime/test/perftest/command_args_parser.cc b/onnxruntime/test/perftest/command_args_parser.cc index ef04e2be8fd29..6c1d447c7b3a3 100644 --- a/onnxruntime/test/perftest/command_args_parser.cc +++ b/onnxruntime/test/perftest/command_args_parser.cc @@ -78,6 +78,10 @@ namespace perftest { "\t [QNN only] [qnn_saver_path]: QNN Saver backend path. e.g '/folderpath/libQnnSaver.so'.\n" "\t [QNN only] [htp_graph_finalization_optimization_mode]: QNN graph finalization optimization mode, options: \n" "\t '0', '1', '2', '3', default is '0'.\n" + "\t [QNN only] [soc_model]: The SoC Model number. Refer to QNN SDK documentation for specific values. Defaults to '0' (unknown). \n" + "\t [QNN only] [htp_arch]: The minimum HTP architecture. The driver will use ops compatible with this architecture. \n" + "\t Options are '0', '68', '69', '73', '75'. Defaults to '0' (none). \n" + "\t [QNN only] [device_id]: The ID of the device to use when setting 'htp_arch'. Defaults to '0' (for single device). \n" "\t [Usage]: -e -i '| |'\n\n" "\t [Example] [For OpenVINO EP] -e openvino -i \"device_type|CPU_FP32 enable_npu_fast_compile|true num_of_threads|5 enable_opencl_throttling|true cache_dir|\"\"\"\n" "\t [Example] [For QNN EP] -e qnn -i \"backend_path|/folderpath/libQnnCpu.so\" \n\n" diff --git a/onnxruntime/test/perftest/ort_test_session.cc b/onnxruntime/test/perftest/ort_test_session.cc index f8a012af5bb13..6854a2649060a 100644 --- a/onnxruntime/test/perftest/ort_test_session.cc +++ b/onnxruntime/test/perftest/ort_test_session.cc @@ -343,7 +343,7 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device if (supported_profiling_level.find(value) == supported_profiling_level.end()) { ORT_THROW("Supported profiling_level: off, basic, detailed"); } - } else if (key == "rpc_control_latency" || key == "vtcm_mb") { + } else if (key == "rpc_control_latency" || key == "vtcm_mb" || key == "soc_model" || key == "device_id") { // no validation } else if (key == "htp_performance_mode") { std::set supported_htp_perf_mode = {"burst", "balanced", "default", "high_performance", @@ -372,10 +372,20 @@ OnnxRuntimeTestSession::OnnxRuntimeTestSession(Ort::Env& env, std::random_device if (supported_qnn_context_priority.find(value) == supported_qnn_context_priority.end()) { ORT_THROW("Supported qnn_context_priority: low, normal, normal_high, high"); } + } else if (key == "htp_arch") { + std::unordered_set supported_htp_archs = {"0", "68", "69", "73", "75"}; + if (supported_htp_archs.find(value) == supported_htp_archs.end()) { + std::ostringstream str_stream; + std::copy(supported_htp_archs.begin(), supported_htp_archs.end(), + std::ostream_iterator(str_stream, ",")); + std::string str = str_stream.str(); + ORT_THROW("Wrong value for htp_arch. select from: " + str); + } } else { ORT_THROW(R"(Wrong key type entered. Choose from options: ['backend_path', 'profiling_level', 'rpc_control_latency', 'vtcm_mb', 'htp_performance_mode', -'qnn_saver_path', 'htp_graph_finalization_optimization_mode', 'qnn_context_priority'])"); +'qnn_saver_path', 'htp_graph_finalization_optimization_mode', 'qnn_context_priority', 'soc_model', +'htp_arch', 'device_id'])"); } qnn_options[key] = value; diff --git a/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc b/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc new file mode 100644 index 0000000000000..75e0c06b04f0d --- /dev/null +++ b/onnxruntime/test/providers/cpu/math/matmul_fastmath_test.cc @@ -0,0 +1,305 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. +// Licensed under the MIT License. + +#include "core/session/onnxruntime_session_options_config_keys.h" +#include "gtest/gtest.h" +#include "test/providers/provider_test_utils.h" +#include "test/providers/run_options_config_keys.h" +#include "test/common/dnnl_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" +#include "test/common/tensor_op_test_utils.h" +#include "default_providers.h" + +#if defined(__aarch64__) && defined(__linux__) + +namespace onnxruntime { +namespace test { + +namespace { + +const onnxruntime::RunOptions run_options = []() { + onnxruntime::RunOptions options{}; + ORT_THROW_IF_ERROR(options.config_options.AddConfigEntry(kOpTesterRunOptionsConfigTestTunableOp, "true")); + return options; +}(); + +const constexpr auto run_with_tunable_op = &run_options; + +} // namespace + +template +struct MatMulTestData { + std::string name; + std::vector input0_dims; + std::vector input1_dims; + std::vector expected_dims; + std::vector expected_vals; +}; + +template +std::vector> GenerateTestCases() { + std::vector> test_cases; + test_cases.push_back( + {"test padding and broadcast A > B", + {3, 1, 1, 6}, + {2, 6, 7}, + {3, 2, 1, 7}, + {385, 400, 415, 430, 445, 460, 475, 1015, 1030, 1045, 1060, 1075, 1090, 1105, 1015, 1066, 1117, 1168, 1219, 1270, 1321, 3157, 3208, 3259, 3310, 3361, 3412, 3463, 1645, 1732, 1819, 1906, 1993, 2080, 2167, 5299, 5386, 5473, 5560, 5647, 5734, 5821}}); + + test_cases.push_back( + {"test padding and broadcast B > A", + {2, 3, 12}, + {3, 2, 12, 3}, + {3, 2, 3, 3}, + {1518, 1584, 1650, 3894, 4104, 4314, 6270, 6624, 6978, 26574, 27072, 27570, 34134, 34776, 35418, 41694, 42480, 43266, 6270, 6336, 6402, 19014, 19224, 19434, 31758, 32112, 32466, 62430, 62928, 63426, 80358, 81000, 81642, 98286, 99072, 99858, 11022, 11088, 11154, 34134, 34344, 34554, 57246, 57600, 57954, 98286, 98784, 99282, 126582, 127224, 127866, 154878, 155664, 156450}}); + + test_cases.push_back( + {"test 2D", + {8, 6}, + {6, 6}, + {8, 6}, + {330, 345, 360, 375, 390, 405, 870, 921, 972, 1023, 1074, 1125, 1410, 1497, 1584, 1671, 1758, 1845, 1950, 2073, 2196, 2319, 2442, 2565, 2490, 2649, 2808, 2967, 3126, 3285, 3030, 3225, 3420, 3615, 3810, 4005, 3570, 3801, 4032, 4263, 4494, 4725, 4110, 4377, 4644, 4911, 5178, 5445}}); + + test_cases.push_back( + {"test 2D special", + {2, 2, 16}, + {16, 4}, + {2, 2, 4}, + {4960, 5080, 5200, 5320, 12640, 13016, 13392, 13768, 20320, 20952, 21584, 22216, 28000, 28888, 29776, 30664}}); + + test_cases.push_back( + {"test 2D special 2", + {2, 2, 9}, + {1, 9, 4}, + {2, 2, 4}, + {816, 852, 888, 924, 2112, 2229, 2346, 2463, 3408, 3606, 3804, 4002, 4704, 4983, 5262, 5541}}); + + test_cases.push_back( + {"test 2D special 3", + {2, 12}, + {1, 1, 12, 3}, + {1, 1, 2, 3}, + {1518, 1584, 1650, 3894, 4104, 4314}}); + + test_cases.push_back( + {"test 3D batch", + {3, 1, 18}, + {3, 18, 2}, + {3, 1, 2}, + { + // clang-format off + 3570, 3723, + 26250, 26727, + 72258, 73059, + // clang-format on + }}); + + test_cases.push_back( + {"test 4D batch", + {2, 2, 1, 20}, + {2, 2, 20, 2}, + {2, 2, 1, 2}, + { + // clang-format off + 4940, 5130, + 36140, 36730, + 99340, 100330, + 194540, 195930, + // clang-format on + }}); + + return test_cases; +} + +template +void RunMatMulTest(int32_t opset_version, bool is_a_constant, bool is_b_constant, bool disable_fastmath) { + for (auto t : GenerateTestCases()) { + SCOPED_TRACE("test case: " + t.name); + + OpTester test("MatMul", opset_version); + + int64_t size0 = TensorShape::FromExistingBuffer(t.input0_dims).SizeHelper(0, t.input0_dims.size()); + std::vector input0_vals = ValueRange(size0); + + test.AddInput("A", t.input0_dims, input0_vals, is_a_constant); + + int64_t size1 = TensorShape::FromExistingBuffer(t.input1_dims).SizeHelper(0, t.input1_dims.size()); + std::vector input1_vals = ValueRange(size1); + test.AddInput("B", t.input1_dims, input1_vals, is_b_constant); + + test.AddOutput("Y", t.expected_dims, t.expected_vals); + + // OpenVINO EP: Disabled temporarily matmul broadcasting not fully supported + // Disable TensorRT because of unsupported data type + std::unordered_set excluded_providers{kTensorrtExecutionProvider, kOpenVINOExecutionProvider}; + if (t.name == "test 2D empty input") { + // NNAPI: currently fails for the "test 2D empty input" case + excluded_providers.insert(kNnapiExecutionProvider); + } + + if ("test padding and broadcast A > B" == t.name || "test 2D empty input" == t.name) { + // QNN can't handle 0 shap + excluded_providers.insert(kQnnExecutionProvider); + } + + SessionOptions so; + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "1")); + + test.ConfigExcludeEps(excluded_providers) + .Config(run_with_tunable_op) + .Config(so) + .RunWithConfig(); + + if (disable_fastmath) { + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "0")); + + test.ConfigExcludeEps(excluded_providers) + .Config(run_with_tunable_op) + .Config(so) + .RunWithConfig(); + } + } +} + +template +void RunMatMulTest(int32_t opset_version) { + RunMatMulTest(opset_version, false, false, false); +} + +TEST(MathOpTest, MatMulFloatType_FastMath) { + // TODO: Unskip when fixed #41968513 + if (DefaultDmlExecutionProvider().get() != nullptr) { + GTEST_SKIP() << "Skipping because of the following error: Assertion failed: m_bufferTensorDesc.TotalTensorSizeInBytes >= ComputeByteSizeFromDimensions(nonBroadcastDimensions, dataType)"; + } + RunMatMulTest(7, false, false, false); +} + +TEST(MathOpTest, MatMulFloatTypeInitializer_FastMath) { + // TODO: Unskip when fixed #41968513 + if (DefaultDmlExecutionProvider().get() != nullptr) { + GTEST_SKIP() << "Skipping because of the following error: Assertion failed: m_bufferTensorDesc.TotalTensorSizeInBytes >= ComputeByteSizeFromDimensions(nonBroadcastDimensions, dataType)"; + } + RunMatMulTest(7, false, true, false); +} + +TEST(MathOpTest, MatMulInt32Type_FastMath) { + RunMatMulTest(9); +} + +TEST(MathOpTest, MatMulUint32Type_FastMath) { + RunMatMulTest(9); +} + +TEST(MathOpTest, MatMulInt64Type_FastMath) { + RunMatMulTest(9); +} + +TEST(MathOpTest, MatMulUint64Type_FastMath) { + RunMatMulTest(9); +} + +#ifndef ENABLE_TRAINING +// Prepacking is disabled in full training build so no need to test the feature in a training build. +TEST(MathOpTest, MatMulSharedPrepackedWeights_FastMath) { + OpTester test("MatMul"); + + std::vector b_init_values(32, 1.0f); + test.AddInput("A", {8, 4}, + {1.0f, 2.0f, 3.0f, 4.0f, + -1.0f, -2.0f, -3.0f, -4.0f, + 1.0f, 2.0f, 3.0f, 4.0f, + -1.0f, -2.0f, -3.0f, -4.0f, + 1.0f, 2.0f, 3.0f, 4.0f, + -1.0f, -2.0f, -3.0f, -4.0f, + 1.0f, 2.0f, 3.0f, 4.0f, + -1.0f, -2.0f, -3.0f, -4.0f}); + // B is to be an initializer for triggering pre-packing + test.AddInput("B", {4, 8}, b_init_values, true); + + test.AddOutput("Y", {8, 8}, + {10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, + -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, + 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, + -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, + 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, + -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, + 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, 10.0f, + -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f, -10.0f}); + + OrtValue b; + Tensor::InitOrtValue(DataTypeImpl::GetType(), TensorShape({4, 8}), + b_init_values.data(), OrtMemoryInfo(CPU, OrtAllocatorType::OrtDeviceAllocator), b); + + SessionOptions so; + // Set up B as a shared initializer to be shared between sessions + ASSERT_EQ(so.AddInitializer("B", &b), Status::OK()); + ASSERT_STATUS_OK(so.config_options.AddConfigEntry( + kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16, "1")); + + // We want all sessions running using this OpTester to be able to share pre-packed weights if applicable + test.EnableSharingOfPrePackedWeightsAcrossSessions(); + + // Pre-packing is limited just to the CPU EP for now and we will only test the CPU EP + // and we want to ensure that it is available in this build + auto cpu_ep = []() -> std::vector> { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + return execution_providers; + }; + + size_t number_of_pre_packed_weights_counter_session_1 = 0; + size_t number_of_shared_pre_packed_weights_counter = 0; + + // Session 1 + { + test.Config(so) + .Config(run_with_tunable_op) + .ConfigEps(cpu_ep()) + .RunWithConfig(&number_of_pre_packed_weights_counter_session_1, &number_of_shared_pre_packed_weights_counter); + // Assert that no pre-packed weights have been shared thus far + ASSERT_EQ(number_of_shared_pre_packed_weights_counter, static_cast(0)); + } + + auto number_of_elements_in_shared_prepacked_buffers_container = + test.GetNumPrePackedWeightsShared(); + // Assert that the number of elements in the shared container + // is the same as the number of weights that have been pre-packed + ASSERT_EQ(number_of_pre_packed_weights_counter_session_1, number_of_elements_in_shared_prepacked_buffers_container); + + // On some platforms/architectures MLAS may choose to not do any pre-packing and the number of elements + // that have been pre-packed will be zero in which case we do not continue with the testing + // of "sharing" of pre-packed weights as there are no pre-packed weights to be shared at all. + if (number_of_pre_packed_weights_counter_session_1 == 0) + return; + + // Session 2 + { + size_t number_of_pre_packed_weights_counter_session_2 = 0; + test.Config(so) + .Config(run_with_tunable_op) + .ConfigEps(cpu_ep()) + .RunWithConfig(&number_of_pre_packed_weights_counter_session_2, &number_of_shared_pre_packed_weights_counter); + + // Assert that the same number of weights were pre-packed in both sessions + ASSERT_EQ(number_of_pre_packed_weights_counter_session_1, number_of_pre_packed_weights_counter_session_2); + + // Assert that the number of pre-packed weights that were shared equals + // the number of pre-packed weights in the second session + ASSERT_EQ(number_of_pre_packed_weights_counter_session_2, + static_cast(number_of_shared_pre_packed_weights_counter)); + } +} + +#endif + +// Dummy run to disable the FastMath mode for the current session +TEST(MathOpTest, MatMulUint64Type_DisableFastMath) { + RunMatMulTest(9, false, false, true); +} + +} // namespace test +} // namespace onnxruntime +#endif // defined(__aarch64__) && defined(__linux__) diff --git a/onnxruntime/test/providers/cpu/model_tests.cc b/onnxruntime/test/providers/cpu/model_tests.cc index 859e082716760..8128c170c5211 100644 --- a/onnxruntime/test/providers/cpu/model_tests.cc +++ b/onnxruntime/test/providers/cpu/model_tests.cc @@ -39,6 +39,8 @@ #include "core/providers/armnn/armnn_provider_factory.h" #endif +#include "test/common/cuda_op_test_utils.h" + // test infrastructure #include "test/onnx/testenv.h" #include "test/onnx/TestCase.h" @@ -94,6 +96,21 @@ TEST_P(ModelTest, Run) { std::unique_ptr model_info = std::make_unique(model_path.c_str()); +#if defined(__linux__) + // ORT enables TF32 in GEMM for A100. TF32 will cause precsion loss and fail this test. + if (HasCudaEnvironment(800) && provider_name == "cuda") { + per_sample_tolerance = 1e-1; + if (model_path.find(ORT_TSTR("SSD")) > 0 || + model_path.find(ORT_TSTR("ssd")) > 0 || + model_path.find(ORT_TSTR("yolov3")) > 0 || + model_path.find(ORT_TSTR("mask_rcnn")) > 0 || + model_path.find(ORT_TSTR("FNS")) > 0) { + SkipTest("Skipping SSD test for big tolearance failure or other errors"); + return; + } + } +#endif + if (model_info->HasDomain(ONNX_NAMESPACE::AI_ONNX_TRAINING_DOMAIN) || model_info->HasDomain(ONNX_NAMESPACE::AI_ONNX_PREVIEW_TRAINING_DOMAIN)) { SkipTest("it has the training domain. No pipeline should need to run these tests."); diff --git a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc index 026bb07edf44c..0c8d6c46d4639 100644 --- a/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/quantize_linear_test.cc @@ -34,11 +34,6 @@ TEST(DequantizeLinearOpTest, Int8) { // scalar zero & scale with int8 TEST(DequantizeLinearOpTest, Int32) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: AbiCustomRegistry.cpp(507): The parameter is incorrect"; - } - OpTester test("DequantizeLinear", 10); std::vector dims{4}; test.AddInput("x", dims, {-30, -3, 100, 127}); @@ -98,11 +93,6 @@ TEST(DequantizeLinearOpMLFloat16Test, Scalar) { // dequantize without zero point TEST(DequantizeLinearOpTest, Without_Zero_Point) { - // TODO: Unskip when fixed #41968513 - if (DefaultDmlExecutionProvider().get() != nullptr) { - GTEST_SKIP() << "Skipping because of the following error: AbiCustomRegistry.cpp(507): The parameter is incorrect"; - } - OpTester test("DequantizeLinear", 10); test.AddInput("x", {}, {100}); test.AddInput("x_scale", {}, {2.0f}); diff --git a/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc b/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc index 06da2a5304716..6514feadf0ff7 100644 --- a/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc +++ b/onnxruntime/test/providers/cuda/nhwc/conv_transpose_test.cc @@ -70,7 +70,11 @@ TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcBias) { auto op = ConvTransposeOp{.input_dims = {1, 8, 80, 80}, .kernel_shape = {5, 5}, .channels = 16, .bias = true}; - MAKE_PROVIDERS_EPS_TYPE(TypeParam) + if (HasCudaEnvironment(800)) { + MAKE_PROVIDERS_EPS(1e-2) + } else { + MAKE_PROVIDERS_EPS_TYPE(TypeParam) + } } TYPED_TEST(CudaNhwcTypedTest, ConvTransposeNhwcPad) { diff --git a/onnxruntime/test/providers/qnn/qnn_basic_test.cc b/onnxruntime/test/providers/qnn/qnn_basic_test.cc index bc40682cf87b7..c50b1002fa8c8 100644 --- a/onnxruntime/test/providers/qnn/qnn_basic_test.cc +++ b/onnxruntime/test/providers/qnn/qnn_basic_test.cc @@ -176,7 +176,10 @@ TEST(QnnEP, TestDisableCPUFallback_ConflictingConfig) { // types and shapes. static void RunNHWCResizeModel(const ORTCHAR_T* ort_model_path, bool use_htp, bool enable_qnn_saver = false, std::string htp_graph_finalization_opt_mode = "", - std::string qnn_context_priority = "") { + std::string qnn_context_priority = "", + std::string soc_model = "", + std::string htp_arch = "", + std::string device_id = "") { Ort::SessionOptions so; // Ensure all type/shape inference warnings result in errors! @@ -205,6 +208,18 @@ static void RunNHWCResizeModel(const ORTCHAR_T* ort_model_path, bool use_htp, bo options["qnn_context_priority"] = std::move(qnn_context_priority); } + if (!soc_model.empty()) { + options["soc_model"] = std::move(soc_model); + } + + if (!htp_arch.empty()) { + options["htp_arch"] = std::move(htp_arch); + } + + if (!device_id.empty()) { + options["device_id"] = std::move(device_id); + } + so.AppendExecutionProvider("QNN", options); Ort::Session session(*ort_env, ort_model_path, so); @@ -519,6 +534,45 @@ TEST_F(QnnHTPBackendTests, HTPGraphFinalizationOptimizationModes) { } } +// Test that models run with various SoC model values +TEST_F(QnnHTPBackendTests, HTPSocModels) { + constexpr std::array soc_models = { "", // No explicit SoC model specified + "0", // "Unknown" +#if defined(_M_ARM64) + "37" }; // SC8280X +#elif defined(__linux__) + "30" }; // SM8350 +#else + "" }; +#endif + + for (auto soc_model : soc_models) { + RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.quant.onnx", + true, // use_htp + false, // enable_qnn_saver + "", // htp_graph_finalization_opt_mode + "", // qnn_context_priority + soc_model); + } +} + +// Test that models run with various HTP architecture values (and set device_id) +TEST_F(QnnHTPBackendTests, HTPArchValues) { + constexpr std::array htp_archs = {"", // No explicit arch specified + "0", // "None" + "68"}; // v68 + for (auto htp_arch : htp_archs) { + RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.quant.onnx", + true, // use_htp + false, // enable_qnn_saver + "", // htp_graph_finalization_opt_mode + "", // qnn_context_priority + "", // soc_model + htp_arch, // htp_arch + "0"); // device_id + } +} + // Test that models run with high QNN context priority. TEST_F(QnnHTPBackendTests, QnnContextPriorityHigh) { RunNHWCResizeModel(ORT_MODEL_FOLDER "nhwc_resize_sizes_opset18.quant.onnx", diff --git a/onnxruntime/test/python/transformers/rotary_flash.py b/onnxruntime/test/python/transformers/rotary_flash.py new file mode 100644 index 0000000000000..42bff9c92b41b --- /dev/null +++ b/onnxruntime/test/python/transformers/rotary_flash.py @@ -0,0 +1,693 @@ +# Copyright (c) 2023, Tri Dao. + + +from typing import Optional, Tuple, Union + +import torch +import triton +import triton.language as tl +from einops import rearrange, repeat + +##### TRITON KERNEL FOR ROTARY ##### + + +# @triton.autotune( +# configs=[ +# triton.Config({"block_m": 2}), +# triton.Config({"block_m": 4}), +# triton.Config({"block_m": 8}), +# triton.Config({"block_m": 16}), +# ], +# key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"], +# ) +@triton.jit +def rotary_kernel( + out_, # Pointers to matrices + x_, + cos_, + sin_, + CU_SEQLENS, + SEQLEN_OFFSETS, # this could be int or a pointer + # Matrix dimensions + seqlen, + nheads, + rotary_dim, + seqlen_ro, + CACHE_KEY_SEQLEN, + # strides + stride_out_batch, + stride_out_seqlen, + stride_out_nheads, + stride_out_headdim, + stride_x_batch, + stride_x_seqlen, + stride_x_nheads, + stride_x_headdim, + # Meta-parameters + block_k: tl.constexpr, + IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr, + IS_VARLEN: tl.constexpr, + INTERLEAVED: tl.constexpr, + CONJUGATE: tl.constexpr, + block_m: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + pid_batch = tl.program_id(axis=1) + pid_head = tl.program_id(axis=2) + rotary_dim_half = rotary_dim // 2 + + if not IS_VARLEN: + x_ = x_ + pid_batch * stride_x_batch + pid_head * stride_x_nheads + out_ = out_ + pid_batch * stride_out_batch + pid_head * stride_out_nheads + else: + start_idx = tl.load(CU_SEQLENS + pid_batch) + seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx + x_ = x_ + start_idx * stride_x_seqlen + pid_head * stride_x_nheads + out_ = out_ + start_idx * stride_out_seqlen + pid_head * stride_out_nheads + + if pid_m * block_m >= seqlen: + return + rm = pid_m * block_m + tl.arange(0, block_m) + if not IS_SEQLEN_OFFSETS_TENSOR: + rm_cs = rm + SEQLEN_OFFSETS + else: + rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch) + rk = tl.arange(0, block_k) + rk_half = tl.arange(0, block_k // 2) + + if not INTERLEAVED: + # Load the 1st and 2nd halves of x_, do calculation, then store to 1st and 2nd halves of out_ + x_ = x_ + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim) + cos_ = cos_ + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :]) + sin_ = sin_ + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :]) + cos = tl.load(cos_, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0).to( + tl.float32 + ) + sin = tl.load(sin_, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0).to( + tl.float32 + ) + x0 = tl.load(x_, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0).to(tl.float32) + x1 = tl.load( + x_ + rotary_dim_half * stride_x_headdim, + mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), + other=0.0, + ).to(tl.float32) + if CONJUGATE: + sin = -sin + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + # write back result + out_ = out_ + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim) + tl.store(out_, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half)) + tl.store( + out_ + rotary_dim_half * stride_out_headdim, + o1, + mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), + ) + else: + # We don't want to load x_[0, 2, 4, ...] and x_[1, 3, 5, ...] separately since both are slow. + # Instead, we load x0 = x_[0, 1, 2, 3, ...] and x1 = x_[1, 0, 3, 2, ...]. + # Loading x0 will be fast but x1 will be slow. + # Then we load cos = cos_[0, 0, 1, 1, ...] and sin = sin_[0, 0, 1, 1, ...]. + # Then we do the calculation and use tl.where to pick put the right outputs for the even + # and for the odd indices. + rk_swap = rk + ((rk + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ... + rk_repeat = tl.arange(0, block_k) // 2 + x0_ = x_ + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim) + x1_ = x_ + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim) + cos_ = cos_ + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :]) + sin_ = sin_ + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :]) + cos = tl.load( + cos_, + mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half), + other=1.0, + ).to(tl.float32) + sin = tl.load( + sin_, + mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half), + other=0.0, + ).to(tl.float32) + x0 = tl.load(x0_, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(tl.float32) + x1 = tl.load(x1_, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0).to(tl.float32) + if CONJUGATE: + sin = -sin + x0_cos = x0 * cos + x1_sin = x1 * sin + out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin) + out_ = out_ + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim) + tl.store(out_, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim)) + + +def apply_rotary( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + interleaved=False, + inplace=False, + conjugate=False, +) -> torch.Tensor: + """ + Arguments: + x: (batch, seqlen, nheads, headdim) if cu_seqlens is None + else (total_seqlen, nheads, headdim). + cos: (seqlen_ro, rotary_dim / 2) + sin: (seqlen_ro, rotary_dim / 2) + seqlen_offsets: integer or integer tensor of size (batch,) + cu_seqlens: (batch + 1,) or None + max_seqlen: int + Returns: + y: (batch, seqlen, nheads, headdim) + """ + is_varlen = cu_seqlens is not None + if not is_varlen: + batch, seqlen, nheads, headdim = x.shape + else: + assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed" + total_seqlen, nheads, headdim = x.shape + batch_p_1 = cu_seqlens.shape[0] + batch = batch_p_1 - 1 + seqlen = max_seqlen + seqlen_ro, rotary_dim = cos.shape + assert sin.shape == cos.shape + rotary_dim *= 2 + assert rotary_dim <= headdim, "rotary_dim must be <= headdim" + assert headdim <= 256, "Only support headdim <= 256" + assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen" + + assert cos.dtype == sin.dtype, f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}" + assert x.dtype == cos.dtype, f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}" + + cos, sin = cos.contiguous(), sin.contiguous() + if isinstance(seqlen_offsets, torch.Tensor): + assert seqlen_offsets.shape == (batch,) + assert seqlen_offsets.dtype in [torch.int32, torch.int64] + seqlen_offsets = seqlen_offsets.contiguous() + else: + assert seqlen_offsets + seqlen <= seqlen_ro + + output = torch.empty_like(x) if not inplace else x + if rotary_dim < headdim and not inplace: + output[..., rotary_dim:].copy_(x[..., rotary_dim:]) + + block_k = 32 if rotary_dim <= 32 else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256)) + grid = lambda META: (triton.cdiv(seqlen, META["block_m"]), batch, nheads) # noqa + block_m = 4 if interleaved else (8 if rotary_dim <= 64 else 4) + + # Need this, otherwise Triton tries to launch from cuda:0 and we get + # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?) + with torch.cuda.device(x.device.index): + rotary_kernel[grid]( + output, # data ptrs + x, + cos, + sin, + cu_seqlens, + seqlen_offsets, + seqlen, # shapes + nheads, + rotary_dim, + seqlen_ro, + seqlen // 128, # key for triton cache (limit number of compilations) + output.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0 + output.stride(-3), # seqlen_stride or total_seqlen_stride + output.stride(-2), # nheads_stride + output.stride(-1), # headdim_stride + x.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0 + x.stride(-3), # seqlen stride or total_seqlen_stride + x.stride(-2), # nheads stride + x.stride(-1), # headdim stride + block_k, + isinstance(seqlen_offsets, torch.Tensor), + is_varlen, + interleaved, + conjugate, + block_m, + ) + return output + + +##### ROTARY API ##### + + +def rotate_half(x, interleaved=False): + if not interleaved: + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + else: + x1, x2 = x[..., ::2], x[..., 1::2] + return rearrange(torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2) + + +def apply_rotary_emb_torch(x, cos, sin, interleaved=False): + """ + x: (batch_size, seqlen, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2) + """ + ro_dim = cos.shape[-1] * 2 + assert ro_dim <= x.shape[-1] + cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)") + return torch.cat( + [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]], + dim=-1, + ) + + +class ApplyRotaryEmb(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + cos, + sin, + interleaved=False, + inplace=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, + ): + out = apply_rotary( + x, + cos, + sin, + seqlen_offsets=seqlen_offsets, + cu_seqlens=cu_seqlens, + max_seqlen=max_seqlen, + interleaved=interleaved, + inplace=inplace, + ) + if isinstance(seqlen_offsets, int): + ctx.save_for_backward(cos, sin, cu_seqlens) # Can't save int with save_for_backward + ctx.seqlen_offsets = seqlen_offsets + else: + ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets) + ctx.seqlen_offsets = None + ctx.interleaved = interleaved + ctx.inplace = inplace + ctx.max_seqlen = max_seqlen + return out if not inplace else x + + @staticmethod + def backward(ctx, do): + seqlen_offsets = ctx.seqlen_offsets + if seqlen_offsets is None: + cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors + else: + cos, sin, cu_seqlens = ctx.saved_tensors + # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with + # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works. + if not ctx.interleaved and not ctx.inplace: + do = do.clone() + dx = apply_rotary( + do, + cos, + sin, + seqlen_offsets=seqlen_offsets, + cu_seqlens=cu_seqlens, + max_seqlen=ctx.max_seqlen, + interleaved=ctx.interleaved, + inplace=ctx.inplace, + conjugate=True, + ) + return dx, None, None, None, None, None, None, None + + +def apply_rotary_emb( + x, + cos, + sin, + interleaved=False, + inplace=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, + cu_seqlens: Optional[torch.Tensor] = None, + max_seqlen: Optional[int] = None, +): + """ + Arguments: + x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None + else (total_seqlen, nheads, headdim) + cos, sin: (seqlen_rotary, rotary_dim / 2) + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). + inplace: if True, apply rotary embedding in-place. + seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount. + Most commonly used in inference when we have KV cache. + cu_seqlens: (batch + 1,) or None + max_seqlen: int + Return: + out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None + else (total_seqlen, nheads, headdim) + rotary_dim must be <= headdim + Apply rotary embedding to the first rotary_dim of x. + """ + return ApplyRotaryEmb.apply(x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen) + + +# For backward compatibility +apply_rotary_emb_func = apply_rotary_emb + + +class ApplyRotaryEmbQKV(torch.autograd.Function): + @staticmethod + def forward( + ctx, + qkv, + cos, + sin, + cos_k=None, + sin_k=None, + interleaved=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, + ): + batch, seqlen, three, nheads, headdim = qkv.shape + assert three == 3 + if cos_k is None and sin_k is None and qkv.is_contiguous(): + # Call 1 kernel instead of 2 kernels + # We need qkv to be contiguous so that when we reshape to combine (3, nheads) + # dimensions, we get the same tensor + # qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d") + qk = qkv[:, :, :2].reshape(batch, seqlen, -1, headdim) + apply_rotary(qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True) + else: + cos_k = cos if cos_k is None else cos_k + sin_k = sin if sin_k is None else sin_k + q, k = qkv[:, :, 0], qkv[:, :, 1] + apply_rotary(q, cos, sin, seqlen_offsets, interleaved=interleaved, inplace=True) + apply_rotary(k, cos_k, sin_k, seqlen_offsets, interleaved=interleaved, inplace=True) + ctx.save_for_backward(cos, sin, cos_k, sin_k) + if isinstance(seqlen_offsets, int): + ctx.save_for_backward(cos, sin, cos_k, sin_k) + ctx.seqlen_offsets = seqlen_offsets + else: + ctx.save_for_backward(cos, sin, cos_k, sin_k, seqlen_offsets) + ctx.seqlen_offsets = None + ctx.interleaved = interleaved + return qkv + + @staticmethod + def backward(ctx, dqkv): + seqlen_offsets = ctx.seqlen_offsets + if seqlen_offsets is None: + cos, sin, cos_k, sin_k, seqlen_offsets = ctx.saved_tensors + else: + cos, sin, cos_k, sin_k = ctx.saved_tensors + if cos_k is None and sin_k is None and dqkv.is_contiguous(): + # Call 1 kernel instead of 2 kernels + # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) + # dimensions, we get the same tensor + dqk = rearrange(dqkv[:, :, :2], "b s t h d -> b s (t h) d") + apply_rotary( + dqk, + cos, + sin, + seqlen_offsets=seqlen_offsets, + interleaved=ctx.interleaved, + inplace=True, + conjugate=True, + ) + else: + cos_k = cos if cos_k is None else cos_k + sin_k = sin if sin_k is None else sin_k + dq, dk = dqkv[:, :, 0], dqkv[:, :, 1] + apply_rotary(dq, cos, sin, seqlen_offsets, interleaved=ctx.interleaved, inplace=True, conjugate=True) + apply_rotary( + dk, + cos_k, + sin_k, + seqlen_offsets, + interleaved=ctx.interleaved, + inplace=True, + conjugate=True, + ) + return dqkv, None, None, None, None, None, None + + +def apply_rotary_emb_qkv_( + qkv, + cos, + sin, + cos_k=None, + sin_k=None, + interleaved=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, +): + """ + Arguments: + qkv: (batch_size, seqlen, 3, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) + cos_k, sin_k: (seqlen, rotary_dim / 2), optional + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of + 1st half and 2nd half (GPT-NeoX style). + seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount. + Most commonly used in inference when we have KV cache. + Return: + qkv: (batch_size, seqlen, 3, nheads, headdim) + rotary_dim must be <= headdim + Apply rotary embedding *inplace* to the first rotary_dim of Q and K. + """ + return ApplyRotaryEmbQKV.apply(qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets) + + +class ApplyRotaryEmbKV(torch.autograd.Function): + @staticmethod + def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0): + batch, seqlen, two, nheads, headdim = kv.shape + assert two == 2 + k = kv[:, :, 0] + apply_rotary(k, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True) + if isinstance(seqlen_offsets, int): + ctx.save_for_backward(cos, sin) # Can't save int with save_for_backward + ctx.seqlen_offsets = seqlen_offsets + else: + ctx.save_for_backward(cos, sin, seqlen_offsets) + ctx.seqlen_offsets = None + ctx.interleaved = interleaved + return kv + + @staticmethod + def backward(ctx, dkv): + seqlen_offsets = ctx.seqlen_offsets + if seqlen_offsets is None: + cos, sin, seqlen_offsets = ctx.saved_tensors + else: + cos, sin = ctx.saved_tensors + apply_rotary( + dkv[:, :, 0], + cos, + sin, + seqlen_offsets=seqlen_offsets, + interleaved=ctx.interleaved, + inplace=True, + conjugate=True, + ) + return dkv, None, None, None, None + + +apply_rotary_emb_kv_ = ApplyRotaryEmbKV.apply + + +def apply_rotary_emb_kv_( + kv, + cos, + sin, + interleaved=False, + seqlen_offsets: Union[int, torch.Tensor] = 0, +): + """ + Arguments: + kv: (batch_size, seqlen, 2, nheads, headdim) + cos, sin: (seqlen, rotary_dim / 2) + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of + 1st half and 2nd half (GPT-NeoX style). + seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount. + Most commonly used in inference when we have KV cache. + Return: + kv: (batch_size, seqlen, 2, nheads, headdim) + rotary_dim must be <= headdim + Apply rotary embedding *inplace* to the first rotary_dim of K. + """ + return ApplyRotaryEmbKV.apply(kv, cos, sin, interleaved, seqlen_offsets) + + +class RotaryEmbedding(torch.nn.Module): + """ + The rotary position embeddings from RoFormer_ (Su et. al). + A crucial insight from the method is that the query and keys are + transformed by rotation matrices which depend on the relative positions. + + Other implementations are available in the Rotary Transformer repo_ and in + GPT-NeoX_, GPT-NeoX was an inspiration + + .. _RoFormer: https://arxiv.org/abs/2104.09864 + .. _repo: https://github.com/ZhuiyiTechnology/roformer + .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox + + If scale_base is not None, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554). + A recommended value for scale_base is 512: https://github.com/HazyResearch/flash-attention/issues/96 + Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py + """ + + def __init__( + self, + dim: int, + base=10000.0, + interleaved=False, + scale_base=None, + pos_idx_in_fp32=True, + device=None, + ): + """ + interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead + of 1st half and 2nd half (GPT-NeoX style). + pos_idx_in_fp32: if True, the position indices [0.0, ..., seqlen - 1] are in fp32, + otherwise they might be in lower precision. + This option was added because previously (before 2023-07-02), when we construct + the position indices, we use the dtype of self.inv_freq. In most cases this would + be fp32, but if the model is trained in pure bf16 (not mixed precision), then + self.inv_freq would be bf16, and the position indices are also in bf16. + Because of the limited precision of bf16 (e.g. 1995.0 is rounded to 2000.0), the + embeddings for some positions will coincide. + To maintain compatibility with models previously trained in pure bf16, + we add this option. + """ + super().__init__() + self.dim = dim + self.base = float(base) + self.pos_idx_in_fp32 = pos_idx_in_fp32 + # Generate and save the inverse frequency buffer (non trainable) + inv_freq = self._compute_inv_freq(device) + self.register_buffer("inv_freq", inv_freq, persistent=False) + self.interleaved = interleaved + self.scale_base = scale_base + scale = ( + (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim) + if scale_base is not None + else None + ) + self.register_buffer("scale", scale, persistent=False) + + self._seq_len_cached = 0 + self._cos_cached = None + self._sin_cached = None + self._cos_k_cached = None + self._sin_k_cached = None + + def _compute_inv_freq(self, device=None): + return 1.0 / (self.base ** (torch.arange(0, self.dim, 2, device=device, dtype=torch.float32) / self.dim)) + + def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): + # Reset the tables if the sequence length has changed, + # if we're on a new device (possibly due to tracing for instance), + # or if we're switching from inference mode to training + if ( + seqlen > self._seq_len_cached + or self._cos_cached is None + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + or (self.training and self._cos_cached.is_inference()) + ): + self._seq_len_cached = seqlen + # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16 + # And the output of arange can be quite large, so bf16 would lose a lot of precision. + # However, for compatibility reason, we add an option to use the dtype of self.inv_freq. + if self.pos_idx_in_fp32: + t = torch.arange(seqlen, device=device, dtype=torch.float32) + # We want fp32 here as well since inv_freq will be multiplied with t, and the output + # will be large. Having it in bf16 will lose a lot of precision and cause the + # cos & sin output to change significantly. + # We want to recompute self.inv_freq if it was not loaded in fp32 + if self.inv_freq.dtype != torch.float32: + inv_freq = self._compute_inv_freq(device=device) + else: + inv_freq = self.inv_freq + else: + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + inv_freq = self.inv_freq + # Don't do einsum, it converts fp32 to fp16 under AMP + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + freqs = torch.outer(t, inv_freq) + if self.scale is None: + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + else: + power = ( + torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) - seqlen // 2 + ) / self.scale_base + scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1") + # We want the multiplication by scale to happen in fp32 + self._cos_cached = (torch.cos(freqs) * scale).to(dtype) + self._sin_cached = (torch.sin(freqs) * scale).to(dtype) + self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype) + self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype) + + def forward( + self, + qkv: torch.Tensor, + kv: Optional[torch.Tensor] = None, + seqlen_offset: Union[int, torch.Tensor] = 0, + max_seqlen: Optional[int] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """ + qkv: (batch, seqlen, 3, nheads, headdim) if kv is none, + else it's just q of shape (batch, seqlen, nheads, headdim) + kv: (batch, seqlen, 2, nheads, headdim) + seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount. + Most commonly used in inference when we have KV cache. + If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one + should pass in max_seqlen, which will update the cos / sin cache up to that length. + Apply rotary embedding *inplace* to qkv and / or kv. + """ + seqlen = qkv.shape[1] + if max_seqlen is not None: + self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype) + elif isinstance(seqlen_offset, int): + self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype) + if kv is None: + if self.scale is None: + return apply_rotary_emb_qkv_( + qkv, + self._cos_cached, + self._sin_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + ) + else: + return apply_rotary_emb_qkv_( + qkv, + self._cos_cached, + self._sin_cached, + self._cos_k_cached, + self._sin_k_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + ) + else: + q = qkv + q = apply_rotary_emb_func( + q, + self._cos_cached, + self._sin_cached, + interleaved=self.interleaved, + inplace=True, + seqlen_offsets=seqlen_offset, + ) + if self.scale is None: + kv = apply_rotary_emb_kv_( + kv, + self._cos_cached, + self._sin_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + ) + else: + kv = apply_rotary_emb_kv_( + kv, + self._cos_k_cached, + self._sin_k_cached, + interleaved=self.interleaved, + seqlen_offsets=seqlen_offset, + ) + return q, kv diff --git a/onnxruntime/test/python/transformers/test_flash_attn.py b/onnxruntime/test/python/transformers/test_flash_attn.py index 8a839875de2a2..90d28872d3cc8 100644 --- a/onnxruntime/test/python/transformers/test_flash_attn.py +++ b/onnxruntime/test/python/transformers/test_flash_attn.py @@ -20,6 +20,7 @@ from bert_padding import pad_input, unpad_input from einops import rearrange, repeat from onnx import TensorProto, helper +from rotary_flash import apply_rotary_emb from onnxruntime import InferenceSession, OrtValue, SessionOptions @@ -184,7 +185,13 @@ def create_multihead_attention_graph(config): def create_group_query_attention_graph_prompt( - config, past_kv_format=Formats.BSNH, share_buffer=True, local_window_size=-1 + config, + past_kv_format=Formats.BSNH, + share_buffer=True, + local_window_size=-1, + rotary=False, + rotary_interleaved=False, + packed=False, ): past_kv_seqlen = config.buffer_sequence_length if share_buffer else 0 present_kv_seqlen = config.buffer_sequence_length if share_buffer else config.kv_sequence_length @@ -193,18 +200,22 @@ def create_group_query_attention_graph_prompt( "GroupQueryAttention", [ "query", - "key", - "value", + "key" if not packed else "", + "value" if not packed else "", "past_key" if share_buffer else "", "past_value" if share_buffer else "", "seqlens_k", "total_sequence_length", + "cos_cache" if rotary else "", + "sin_cache" if rotary else "", ], ["output", "present_key", "present_value"], "GroupQueryAttention_0", num_heads=config.num_heads, kv_num_heads=config.kv_num_heads, local_window_size=local_window_size, + do_rotary=rotary, + rotary_interleaved=rotary_interleaved, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, domain="com.microsoft", @@ -218,25 +229,9 @@ def create_group_query_attention_graph_prompt( [ config.batch_size, config.q_sequence_length, - config.num_heads * config.head_size, - ], - ), - helper.make_tensor_value_info( - "key", - TensorProto.FLOAT16, - [ - config.batch_size, - config.kv_sequence_length, - config.kv_num_heads * config.head_size, - ], - ), - helper.make_tensor_value_info( - "value", - TensorProto.FLOAT16, - [ - config.batch_size, - config.kv_sequence_length, - config.kv_num_heads * config.head_size, + (config.num_heads * config.head_size) + if not packed + else (config.num_heads * config.head_size + 2 * config.kv_num_heads * config.head_size), ], ), helper.make_tensor_value_info( @@ -250,6 +245,27 @@ def create_group_query_attention_graph_prompt( [1], ), ] + if not packed: + graph_input += [ + helper.make_tensor_value_info( + "key", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length, + config.kv_num_heads * config.head_size, + ], + ), + helper.make_tensor_value_info( + "value", + TensorProto.FLOAT16, + [ + config.batch_size, + config.kv_sequence_length, + config.kv_num_heads * config.head_size, + ], + ), + ] if share_buffer: graph_input += [ helper.make_tensor_value_info( @@ -273,6 +289,25 @@ def create_group_query_attention_graph_prompt( ], ), ] + if rotary: + graph_input += [ + helper.make_tensor_value_info( + "cos_cache", + TensorProto.FLOAT16, + [ + config.buffer_sequence_length if share_buffer else config.kv_sequence_length, + (math.floor(config.head_size / 16) * 16) // 2, + ], + ), + helper.make_tensor_value_info( + "sin_cache", + TensorProto.FLOAT16, + [ + config.buffer_sequence_length if share_buffer else config.kv_sequence_length, + (math.floor(config.head_size / 16) * 16) // 2, + ], + ), + ] graph_output = [ helper.make_tensor_value_info( @@ -334,7 +369,13 @@ def create_group_query_attention_graph_prompt( def create_group_query_attention_graph_past( - config, past_kv_format=Formats.BSNH, share_buffer=True, local_window_size=-1 + config, + past_kv_format=Formats.BSNH, + share_buffer=True, + local_window_size=-1, + rotary=False, + rotary_interleaved=False, + packed=False, ): past_kv_seqlen = config.kv_sequence_length present_kv_seqlen = ( @@ -345,18 +386,22 @@ def create_group_query_attention_graph_past( "GroupQueryAttention", [ "query", - "key", - "value", + "key" if not packed else "", + "value" if not packed else "", "past_key", "past_value", "seqlens_k", "total_sequence_length", + "cos_cache" if rotary else "", + "sin_cache" if rotary else "", ], ["output", "present_key", "present_value"], "GroupQueryAttention_0", num_heads=config.num_heads, kv_num_heads=config.kv_num_heads, local_window_size=local_window_size, + do_rotary=rotary, + rotary_interleaved=rotary_interleaved, # is_past_bsnh=1 if past_kv_format == Formats.BSNH else 0, # kv_share_buffer=1 if share_buffer else 0, domain="com.microsoft", @@ -370,25 +415,9 @@ def create_group_query_attention_graph_past( [ config.batch_size, config.sequence_length, - config.num_heads * config.head_size, - ], - ), - helper.make_tensor_value_info( - "key", - TensorProto.FLOAT16, - [ - config.batch_size, - config.sequence_length, - config.kv_num_heads * config.head_size, - ], - ), - helper.make_tensor_value_info( - "value", - TensorProto.FLOAT16, - [ - config.batch_size, - config.sequence_length, - config.kv_num_heads * config.head_size, + (config.num_heads * config.head_size) + if not packed + else (config.num_heads * config.head_size + 2 * config.kv_num_heads * config.head_size), ], ), helper.make_tensor_value_info( @@ -411,8 +440,6 @@ def create_group_query_attention_graph_past( config.head_size, ], ), - ] - graph_input += [ helper.make_tensor_value_info( "seqlens_k", TensorProto.INT32, @@ -424,6 +451,46 @@ def create_group_query_attention_graph_past( [1], ), ] + if not packed: + graph_input += [ + helper.make_tensor_value_info( + "key", + TensorProto.FLOAT16, + [ + config.batch_size, + config.sequence_length, + config.kv_num_heads * config.head_size, + ], + ), + helper.make_tensor_value_info( + "value", + TensorProto.FLOAT16, + [ + config.batch_size, + config.sequence_length, + config.kv_num_heads * config.head_size, + ], + ), + ] + if rotary: + graph_input += [ + helper.make_tensor_value_info( + "cos_cache", + TensorProto.FLOAT16, + [ + config.kv_sequence_length + (0 if share_buffer else config.sequence_length), + (math.floor(config.head_size / 16) * 16) // 2, + ], + ), + helper.make_tensor_value_info( + "sin_cache", + TensorProto.FLOAT16, + [ + config.kv_sequence_length + (0 if share_buffer else config.sequence_length), + (math.floor(config.head_size / 16) * 16) // 2, + ], + ), + ] graph_output = [ helper.make_tensor_value_info( @@ -663,21 +730,38 @@ def mha_func(q, k, v, config): def gqa_prompt_func( - q, k, v, config, new_k, new_v, seqlens_k=None, window_size=-1, past_kv_format=Formats.BSNH, share_buffer=True + q, + k, + v, + config, + new_k, + new_v, + cos=None, + sin=None, + seqlens_k=None, + window_size=-1, + past_kv_format=Formats.BSNH, + share_buffer=True, + rotary_interleaved=False, ): onnx_model_str = create_group_query_attention_graph_prompt( - config, past_kv_format, share_buffer, local_window_size=window_size + config, + past_kv_format, + share_buffer, + local_window_size=window_size, + rotary=cos is not None, + rotary_interleaved=rotary_interleaved, + packed=new_k is None, ) q = torch.reshape(q, (config.batch_size, config.q_sequence_length, -1)) past_k = k.clone() if share_buffer else None past_v = v.clone() if share_buffer else None - new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, -1)) - new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, -1)) + if new_k is not None: + new_k = torch.reshape(new_k, (config.batch_size, config.kv_sequence_length, -1)) + new_v = torch.reshape(new_v, (config.batch_size, config.kv_sequence_length, -1)) if share_buffer: ort_inputs = { "query": q.detach().cpu().numpy(), - "key": new_k.detach().cpu().numpy(), - "value": new_v.detach().cpu().numpy(), "past_key": OrtValue.ortvalue_from_numpy(past_k.detach().cpu().numpy(), "cuda", 0), "past_value": OrtValue.ortvalue_from_numpy(past_v.detach().cpu().numpy(), "cuda", 0), "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), @@ -686,9 +770,17 @@ def gqa_prompt_func( sess_options = SessionOptions() ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) io_binding = ort_session.io_binding() + if new_k is not None: + ort_inputs["key"] = new_k.detach().cpu().numpy() + ort_inputs["value"] = new_v.detach().cpu().numpy() + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + if cos is not None: + ort_inputs["cos_cache"] = cos.detach().cpu().numpy() + ort_inputs["sin_cache"] = sin.detach().cpu().numpy() + io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) + io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) io_binding.bind_cpu_input("query", ort_inputs["query"]) - io_binding.bind_cpu_input("key", ort_inputs["key"]) - io_binding.bind_cpu_input("value", ort_inputs["value"]) io_binding.bind_input( "past_key", "cuda", 0, numpy.float16, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() ) @@ -713,17 +805,23 @@ def gqa_prompt_func( else: ort_inputs = { "query": q.detach().cpu().numpy(), - "key": new_k.detach().cpu().numpy(), - "value": new_v.detach().cpu().numpy(), "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), "total_sequence_length": torch.tensor([config.q_sequence_length], dtype=torch.int32).detach().cpu().numpy(), } sess_options = SessionOptions() ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) io_binding = ort_session.io_binding() + if new_k is not None: + ort_inputs["key"] = new_k.detach().cpu().numpy() + ort_inputs["value"] = new_v.detach().cpu().numpy() + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + if cos is not None: + ort_inputs["cos_cache"] = cos.detach().cpu().numpy() + ort_inputs["sin_cache"] = sin.detach().cpu().numpy() + io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) + io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) io_binding.bind_cpu_input("query", ort_inputs["query"]) - io_binding.bind_cpu_input("key", ort_inputs["key"]) - io_binding.bind_cpu_input("value", ort_inputs["value"]) io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) io_binding.bind_cpu_input("total_sequence_length", ort_inputs["total_sequence_length"]) io_binding.bind_output("output") @@ -737,21 +835,38 @@ def gqa_prompt_func( def gqa_past_func( - q, k, v, config, new_k, new_v, seqlens_k=None, past_kv_format=Formats.BSNH, share_buffer=True, window_size=-1 + q, + k, + v, + config, + new_k, + new_v, + cos=None, + sin=None, + seqlens_k=None, + past_kv_format=Formats.BSNH, + share_buffer=True, + window_size=-1, + rotary_interleaved=False, ): onnx_model_str = create_group_query_attention_graph_past( - config, past_kv_format, share_buffer, local_window_size=window_size + config, + past_kv_format, + share_buffer, + local_window_size=window_size, + rotary=cos is not None, + rotary_interleaved=rotary_interleaved, + packed=new_k is None, ) q = torch.reshape(q, (config.batch_size, config.sequence_length, -1)) past_k = k.clone() past_v = v.clone() - new_k = torch.reshape(new_k, (config.batch_size, config.sequence_length, -1)) - new_v = torch.reshape(new_v, (config.batch_size, config.sequence_length, -1)) + if new_k is not None: + new_k = torch.reshape(new_k, (config.batch_size, config.sequence_length, -1)) + new_v = torch.reshape(new_v, (config.batch_size, config.sequence_length, -1)) if share_buffer: ort_inputs = { "query": q.detach().cpu().numpy(), - "key": new_k.detach().cpu().numpy(), - "value": new_v.detach().cpu().numpy(), "past_key": OrtValue.ortvalue_from_numpy(past_k.detach().cpu().numpy(), "cuda", 0), "past_value": OrtValue.ortvalue_from_numpy(past_v.detach().cpu().numpy(), "cuda", 0), "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), @@ -763,9 +878,17 @@ def gqa_past_func( sess_options = SessionOptions() ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) io_binding = ort_session.io_binding() + if new_k is not None: + ort_inputs["key"] = new_k.detach().cpu().numpy() + ort_inputs["value"] = new_v.detach().cpu().numpy() + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + if cos is not None: + ort_inputs["cos_cache"] = cos.detach().cpu().numpy() + ort_inputs["sin_cache"] = sin.detach().cpu().numpy() + io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) + io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) io_binding.bind_cpu_input("query", ort_inputs["query"]) - io_binding.bind_cpu_input("key", ort_inputs["key"]) - io_binding.bind_cpu_input("value", ort_inputs["value"]) io_binding.bind_input( "past_key", "cuda", 0, numpy.float16, ort_inputs["past_key"].shape(), ort_inputs["past_key"].data_ptr() ) @@ -790,8 +913,6 @@ def gqa_past_func( else: ort_inputs = { "query": q.detach().cpu().numpy(), - "key": new_k.detach().cpu().numpy(), - "value": new_v.detach().cpu().numpy(), "past_key": past_k.detach().cpu().numpy(), "past_value": past_v.detach().cpu().numpy(), "seqlens_k": seqlens_k.detach().cpu().numpy().astype(numpy.int32), @@ -805,9 +926,17 @@ def gqa_past_func( sess_options = SessionOptions() ort_session = InferenceSession(onnx_model_str, sess_options, providers=["CUDAExecutionProvider"]) io_binding = ort_session.io_binding() + if new_k is not None: + ort_inputs["key"] = new_k.detach().cpu().numpy() + ort_inputs["value"] = new_v.detach().cpu().numpy() + io_binding.bind_cpu_input("key", ort_inputs["key"]) + io_binding.bind_cpu_input("value", ort_inputs["value"]) + if cos is not None: + ort_inputs["cos_cache"] = cos.detach().cpu().numpy() + ort_inputs["sin_cache"] = sin.detach().cpu().numpy() + io_binding.bind_cpu_input("cos_cache", ort_inputs["cos_cache"]) + io_binding.bind_cpu_input("sin_cache", ort_inputs["sin_cache"]) io_binding.bind_cpu_input("query", ort_inputs["query"]) - io_binding.bind_cpu_input("key", ort_inputs["key"]) - io_binding.bind_cpu_input("value", ort_inputs["value"]) io_binding.bind_cpu_input("past_key", ort_inputs["past_key"]) io_binding.bind_cpu_input("past_value", ort_inputs["past_value"]) io_binding.bind_cpu_input("seqlens_k", ort_inputs["seqlens_k"]) @@ -1029,9 +1158,12 @@ def parity_check_mha( def parity_check_gqa_prompt( config, - causal=False, + causal=True, local=False, past_format=Formats.BSNH, + rotary=False, + rotary_interleaved=False, + packed=False, rtol=1e-3, atol=1e-3, ): @@ -1080,6 +1212,8 @@ def parity_check_gqa_prompt( dtype=torch.float16, requires_grad=False, ) + # print(k.shape) + # print(new_k.shape) window_size = (-1, -1) left_window_size = -1 @@ -1105,19 +1239,47 @@ def parity_check_gqa_prompt( # device="cuda", # ) # cache_seqlens[random.randint(0, cache_seqlens.size(dim=0) - 1)] = config.kv_sequence_length + rotary_seqlens = torch.tensor([0], device="cuda").repeat(config.batch_size) + + if rotary: + rotary_fraction = 1.0 + rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 + angle = torch.rand(config.buffer_sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi + cos = torch.cos(angle).to(dtype=torch.float16) + sin = torch.sin(angle).to(dtype=torch.float16) + if causal or local: + q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=config.q_sequence_length, + ) + # q_ro = q + k_ro = apply_rotary_emb(new_k, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved) + else: + cos, sin = None, None + q_ro, k_ro = q, new_k + rearrange(torch.arange(config.kv_sequence_length, device="cuda"), "s -> 1 s") arange = rearrange(torch.arange(config.buffer_sequence_length, device="cuda"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") kv_seqlens = torch.tensor([config.kv_sequence_length], device="cuda").repeat(config.batch_size) kv_seqlens_expanded = rearrange(kv_seqlens, "b -> b 1") update_mask = arange < kv_seqlens_expanded - k_cache_ref[update_mask] = rearrange(new_k, "b s ... -> (b s) ...") + k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...") v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...") k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded out_ref, _ = attention_ref( - q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1125,13 +1287,47 @@ def parity_check_gqa_prompt( v_cache_ref = v_cache_ref.transpose(1, 2) # Flash function - out, present_k, present_v = gqa_prompt_func( - q, k, v, config, new_k, new_v, cache_seqlens, left_window_size, past_format, True - ) + if packed: + packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) + out, present_k, present_v = gqa_prompt_func( + packed_qkv, + k, + v, + config, + None, + None, + cos, + sin, + cache_seqlens, + left_window_size, + past_format, + True, + rotary_interleaved, + ) + else: + out, present_k, present_v = gqa_prompt_func( + q, + k, + v, + config, + new_k, + new_v, + cos, + sin, + cache_seqlens, + left_window_size, + past_format, + True, + rotary_interleaved, + ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() + # print(cache_seqlens[0]) + # print((present_k - k_cache_ref.detach().cpu().numpy())[0, 0, :, 0]) + # print((out - out_ref)[0, :, 0, 0]) + # Make sure past-present buffer updating correctly assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) @@ -1139,10 +1335,16 @@ def parity_check_gqa_prompt( # Compare results print( "KV-buffer", + " packed:", + packed, " causal:", causal, " local:", local, + " rotary:", + rotary, + " rotary_interleaved:", + rotary_interleaved, "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", " B:", @@ -1171,9 +1373,12 @@ def parity_check_gqa_prompt( def parity_check_gqa_prompt_no_buff( config, - causal=False, + causal=True, local=False, past_format=Formats.BSNH, + rotary=False, + rotary_interleaved=False, + packed=False, rtol=1e-3, atol=1e-3, ): @@ -1229,13 +1434,42 @@ def parity_check_gqa_prompt_no_buff( # device="cuda", # ) # cache_seqlens[random.randint(0, cache_seqlens.size(dim=0) - 1)] = config.kv_sequence_length + rotary_seqlens = torch.tensor([0], device="cuda").repeat(config.batch_size) + + if rotary: + rotary_fraction = 1.0 + rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 + angle = torch.rand(config.kv_sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi + cos = torch.cos(angle).to(dtype=torch.float16) + sin = torch.sin(angle).to(dtype=torch.float16) + if causal or local: + q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + cos, + sin, + seqlen_offsets=rotary_seqlens, + interleaved=rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=config.q_sequence_length, + ) + # q_ro = q + k_ro = apply_rotary_emb(k_cache_ref, cos, sin, seqlen_offsets=rotary_seqlens, interleaved=rotary_interleaved) + else: + cos, sin = None, None + q_ro, k_ro = q, k_cache_ref + k_cache_ref = k_ro + brange = rearrange(torch.arange(config.kv_sequence_length, device="cuda"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") new_mask = brange < cache_seqlens_expanded k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) out_ref, _ = attention_ref( - q, k_cache_rep, v_cache_rep, None, new_mask, 0.0, None, causal=True, window_size=window_size + q_ro, k_cache_rep, v_cache_rep, None, new_mask, 0.0, None, causal=True, window_size=window_size ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1243,9 +1477,39 @@ def parity_check_gqa_prompt_no_buff( v_cache_ref = v_cache_ref.transpose(1, 2) # Flash function - out, present_k, present_v = gqa_prompt_func( - q, None, None, config, new_k, new_v, cache_seqlens, left_window_size, past_format, False - ) + if packed: + packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) + out, present_k, present_v = gqa_prompt_func( + packed_qkv, + None, + None, + config, + None, + None, + cos, + sin, + cache_seqlens, + left_window_size, + past_format, + False, + rotary_interleaved, + ) + else: + out, present_k, present_v = gqa_prompt_func( + q, + None, + None, + config, + new_k, + new_v, + cos, + sin, + cache_seqlens, + left_window_size, + past_format, + False, + rotary_interleaved, + ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.q_sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() @@ -1256,7 +1520,17 @@ def parity_check_gqa_prompt_no_buff( # Compare results print( - "KV-buffer", + "No buff", + " packed:", + packed, + " causal:", + causal, + " local:", + local, + " rotary:", + rotary, + " rotary_interleaved:", + rotary_interleaved, "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", " B:", @@ -1285,9 +1559,12 @@ def parity_check_gqa_prompt_no_buff( def parity_check_gqa_past( config, - causal=False, + causal=True, local=False, past_format=Formats.BSNH, + rotary=False, + rotary_interleaved=False, + packed=False, rtol=1e-3, atol=1e-3, ): @@ -1336,6 +1613,7 @@ def parity_check_gqa_past( dtype=torch.float16, requires_grad=False, ) + window_size = (-1, -1) left_window_size = -1 if local: @@ -1359,18 +1637,45 @@ def parity_check_gqa_past( dtype=torch.int32, device="cuda", ) + + if rotary: + rotary_fraction = 1.0 + rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 + angle = torch.rand(config.kv_sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi + cos = torch.cos(angle).to(dtype=torch.float16) + sin = torch.sin(angle).to(dtype=torch.float16) + if causal or local: + q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + cos, + sin, + seqlen_offsets=cache_seqlens, + interleaved=rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=config.sequence_length, + ) + # q_ro = q + k_ro = apply_rotary_emb(new_k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved) + else: + cos, sin = None, None + q_ro, k_ro = q, new_k + arange = rearrange(torch.arange(config.kv_sequence_length, device="cuda"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") update_mask = torch.logical_and( cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + config.sequence_length ) - k_cache_ref[update_mask] = rearrange(new_k, "b s ... -> (b s) ...") + k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...") v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...") k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length out_ref, _ = attention_ref( - q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1378,13 +1683,46 @@ def parity_check_gqa_past( v_cache_ref = v_cache_ref.transpose(1, 2) # Flash function - out, present_k, present_v = gqa_past_func( - q, k, v, config, new_k, new_v, cache_seqlens, past_format, True, left_window_size - ) + if packed: + packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) + out, present_k, present_v = gqa_past_func( + packed_qkv, + k, + v, + config, + None, + None, + cos, + sin, + cache_seqlens, + past_format, + True, + left_window_size, + rotary_interleaved, + ) + else: + out, present_k, present_v = gqa_past_func( + q, + k, + v, + config, + new_k, + new_v, + cos, + sin, + cache_seqlens, + past_format, + True, + left_window_size, + rotary_interleaved, + ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() + # print(cache_seqlens[0]) + # print((present_k - k_cache_ref.detach().cpu().numpy())[0, 0, cache_seqlens[0], :]) + # Make sure past-present buffer updating correctly assert numpy.allclose(present_k, k_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) assert numpy.allclose(present_v, v_cache_ref.detach().cpu().numpy(), rtol=rtol, atol=atol, equal_nan=True) @@ -1394,10 +1732,16 @@ def parity_check_gqa_past( "KV-buffer", "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", + " packed:", + packed, " causal:", causal, " local:", local, + " rotary:", + rotary, + " rotary_interleaved:", + rotary_interleaved, " B:", config.batch_size, " S:", @@ -1427,6 +1771,9 @@ def parity_check_gqa_past_no_buff( causal=False, local=False, past_format=Formats.BSNH, + rotary=False, + rotary_interleaved=False, + packed=False, rtol=1e-3, atol=1e-3, ): @@ -1503,18 +1850,47 @@ def parity_check_gqa_past_no_buff( device="cuda", ) cache_seqlens[random.randint(0, config.batch_size - 1)] = config.kv_sequence_length + + if rotary: + rotary_fraction = 1.0 + rotary_dim = math.floor(int(rotary_fraction * config.head_size) / 16) * 16 + angle = ( + torch.rand(config.kv_sequence_length + config.sequence_length, rotary_dim // 2, device="cuda") * 2 * math.pi + ) + cos = torch.cos(angle).to(dtype=torch.float16) + sin = torch.sin(angle).to(dtype=torch.float16) + if causal or local: + q_ro = apply_rotary_emb(q, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved) + else: + q_ro = rearrange( + apply_rotary_emb( + rearrange(q, "b s h d -> b 1 (s h) d"), + cos, + sin, + seqlen_offsets=cache_seqlens, + interleaved=rotary_interleaved, + ), + "b 1 (s h) d -> b s h d", + s=config.sequence_length, + ) + # q_ro = q + k_ro = apply_rotary_emb(new_k, cos, sin, seqlen_offsets=cache_seqlens, interleaved=rotary_interleaved) + else: + cos, sin = None, None + q_ro, k_ro = q, new_k + arange = rearrange(torch.arange(config.kv_sequence_length + config.sequence_length, device="cuda"), "s -> 1 s") cache_seqlens_expanded = rearrange(cache_seqlens, "b -> b 1") update_mask = torch.logical_and( cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + config.sequence_length ) - k_cache_ref[update_mask] = rearrange(new_k, "b s ... -> (b s) ...") + k_cache_ref[update_mask] = rearrange(k_ro, "b s ... -> (b s) ...") v_cache_ref[update_mask] = rearrange(new_v, "b s ... -> (b s) ...") k_cache_rep = repeat(k_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) v_cache_rep = repeat(v_cache_ref, "b s h d -> b s (h g) d", g=config.num_heads // config.kv_num_heads) key_padding_mask = arange < cache_seqlens_expanded + config.sequence_length out_ref, _ = attention_ref( - q, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size + q_ro, k_cache_rep, v_cache_rep, None, key_padding_mask, 0.0, None, causal=True, window_size=window_size ) out_ref = out_ref.detach().cpu().numpy() if past_format == Formats.BNSH: @@ -1522,13 +1898,47 @@ def parity_check_gqa_past_no_buff( v_cache_ref = v_cache_ref.transpose(1, 2) # Flash function - out, present_k, present_v = gqa_past_func( - q, k, v, config, new_k, new_v, cache_seqlens, past_format, False, window_size=left_window_size - ) + if packed: + packed_qkv = torch.concatenate([q, new_k, new_v], dim=2) + out, present_k, present_v = gqa_past_func( + packed_qkv, + k, + v, + config, + None, + None, + cos, + sin, + cache_seqlens, + past_format, + False, + window_size=left_window_size, + rotary_interleaved=rotary_interleaved, + ) + else: + out, present_k, present_v = gqa_past_func( + q, + k, + v, + config, + new_k, + new_v, + cos, + sin, + cache_seqlens, + past_format, + False, + window_size=left_window_size, + rotary_interleaved=rotary_interleaved, + ) out = torch.squeeze(out, 0) out = torch.reshape(out, (config.batch_size, config.sequence_length, config.num_heads, config.head_size)) out = out.detach().cpu().numpy() + # print(cache_seqlens[0]) + # print((out - out_ref)[0]) + # print((present_k - k_cache_ref.detach().cpu().numpy())[0, 0, :, 0]) + # Make sure past-present buffer updating correctly # assert numpy.allclose( # present_k[:, :, :-1, :], k_cache_ref.detach().cpu().numpy()[:, :, :-1, :], rtol=rtol, atol=atol, equal_nan=True @@ -1540,10 +1950,16 @@ def parity_check_gqa_past_no_buff( # Compare results print( "NO buff", + " packed:", + packed, " causal:", causal, " local:", local, + " rotary:", + rotary, + " rotary_interleaved:", + rotary_interleaved, "past kv format:", "BSNH" if past_format == Formats.BSNH else "BNSH", " B:", @@ -1671,10 +2087,25 @@ def test_gqa_no_past(self): for n, n2 in num_h: for h in h_sizes: for local in [False, True]: - for past_kv_format in [Formats.BNSH]: - config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) - parity_check_gqa_prompt(config, local=local, past_format=past_kv_format) - parity_check_gqa_prompt_no_buff(config, local=local, past_format=past_kv_format) + for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]: + for past_kv_format, packed in [(Formats.BNSH, False), (Formats.BNSH, True)]: + config = PromptConfig(b, sq, skv, sq + skv + 8, n, n2, h) + parity_check_gqa_prompt( + config, + local=local, + past_format=past_kv_format, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + parity_check_gqa_prompt_no_buff( + config, + local=local, + past_format=past_kv_format, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) def test_gqa_past(self): if not torch.cuda.is_available(): @@ -1684,7 +2115,6 @@ def test_gqa_past(self): return os.environ["ORT_DISABLE_FLASH_ATTENTION"] = "1" print("-------- TEST GQA PAST (TOKEN GEN) ---------") - print("-------- MEMORY EFFICIENT (TOKEN GEN) --------") batches = [5] if pipeline_mode else [1, 3, 5] seqs = ( [(1, 128), (1, 1024), (1, 2048)] @@ -1706,6 +2136,7 @@ def test_gqa_past(self): num_h = [(32, 32), (9, 3), (4, 4)] if pipeline_mode else [(6, 6), (6, 3), (9, 9), (9, 3)] h_sizes = [16, 128, 256] if pipeline_mode else [32, 40, 64, 80, 96, 128, 160, 192, 224, 256] random.seed(69) + print("-------- MEMORY EFFICIENT (TOKEN GEN) --------") for b in batches: for s, s2 in seqs: for n, n2 in num_h: @@ -1734,23 +2165,30 @@ def test_gqa_past(self): for n, n2 in num_h: for h in h_sizes: for local in [False, True]: - for past_kv_format in [Formats.BNSH]: - sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 - config = Config(b, s, s2, sp, n, n2, h) - parity_check_gqa_past( - config, - local=local, - past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, - ) - parity_check_gqa_past_no_buff( - config, - local=local, - past_format=past_kv_format, - rtol=1e-3, - atol=1e-3, - ) + for rotary, rotary_interleaved in [(True, False), (True, True), (False, False)]: + for past_kv_format, packed in [(Formats.BNSH, False), (Formats.BNSH, True)]: + sp = random.randint(1, s2 - s) if s2 - s > 0 else 0 + config = Config(b, s, s2, sp, n, n2, h) + parity_check_gqa_past( + config, + local=local, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) + parity_check_gqa_past_no_buff( + config, + local=local, + past_format=past_kv_format, + rtol=1e-3, + atol=1e-3, + rotary=rotary, + rotary_interleaved=rotary_interleaved, + packed=packed, + ) if __name__ == "__main__": diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 6ffe72f81bd24..8dad2c8e2d10d 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -43,6 +43,10 @@ #include #endif +#ifdef USE_ROCM +#include +#endif + // Once we use C++17 this could be replaced with std::size template constexpr size_t countof(T (&)[N]) { return N; } @@ -1762,6 +1766,27 @@ TEST(CApiTest, get_allocator_cuda) { } #endif +#ifdef USE_ROCM +TEST(CApiTest, get_allocator_rocm) { + Ort::SessionOptions session_options; + Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_ROCM(session_options, 0)); + Ort::Session session(*ort_env, NAMED_AND_ANON_DIM_PARAM_URI, session_options); + + Ort::MemoryInfo info_rocm("Hip", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); + Ort::Allocator rocm_allocator(session, info_rocm); + + auto allocator_info = rocm_allocator.GetInfo(); + ASSERT_TRUE(info_rocm == allocator_info); + void* p = rocm_allocator.Alloc(1024); + ASSERT_NE(p, nullptr); + rocm_allocator.Free(p); + + auto mem_allocation = rocm_allocator.GetAllocation(1024); + ASSERT_NE(nullptr, mem_allocation.get()); + ASSERT_EQ(1024U, mem_allocation.size()); +} +#endif + TEST(CApiTest, io_binding) { Ort::SessionOptions session_options; Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CPU(session_options, 1)); @@ -1937,7 +1962,7 @@ TEST(CApiTest, io_binding_cuda) { } #endif -#if defined(USE_CUDA) || defined(USE_TENSORRT) +#if defined(USE_CUDA) || defined(USE_TENSORRT) || defined(USE_ROCM) TEST(CApiTest, basic_cuda_graph) { const auto& api = Ort::GetApi(); Ort::SessionOptions session_options; @@ -1955,7 +1980,7 @@ TEST(CApiTest, basic_cuda_graph) { ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_TensorRT_V2( static_cast(session_options), rel_trt_options.get()) == nullptr); -#else +#elif defined(USE_CUDA) // Enable cuda graph in cuda provider option. OrtCUDAProviderOptionsV2* cuda_options = nullptr; ASSERT_TRUE(api.CreateCUDAProviderOptions(&cuda_options) == nullptr); @@ -1968,34 +1993,55 @@ TEST(CApiTest, basic_cuda_graph) { ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_CUDA_V2( static_cast(session_options), rel_cuda_options.get()) == nullptr); +#elif defined(USE_ROCM) + // Enable hip graph in rocm provider option. + OrtROCMProviderOptions* rocm_options = nullptr; + ASSERT_TRUE(api.CreateROCMProviderOptions(&rocm_options) == nullptr); + std::unique_ptr + rel_rocm_options(rocm_options, api.ReleaseROCMProviderOptions); + std::vector keys{"enable_hip_graph"}; + std::vector values{"1"}; + ASSERT_TRUE(api.UpdateROCMProviderOptions(rel_rocm_options.get(), keys.data(), values.data(), 1) == nullptr); + + ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_ROCM( + static_cast(session_options), + rel_rocm_options.get()) == nullptr); #endif Ort::Session session(*ort_env, MODEL_URI, session_options); - Ort::MemoryInfo info_cuda("Cuda", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); +#if defined(USE_ROCM) +// local hipify +#define cudaMemcpy hipMemcpy +#define cudaMemcpyHostToDevice hipMemcpyHostToDevice +#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost + Ort::MemoryInfo info_mem("Hip", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); +#else + Ort::MemoryInfo info_mem("Cuda", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault); +#endif - Ort::Allocator cuda_allocator(session, info_cuda); - auto allocator_info = cuda_allocator.GetInfo(); - ASSERT_TRUE(info_cuda == allocator_info); + Ort::Allocator allocator(session, info_mem); + auto allocator_info = allocator.GetInfo(); + ASSERT_TRUE(info_mem == allocator_info); const std::array x_shape = {3, 2}; std::array x_values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f}; - auto input_data = cuda_allocator.GetAllocation(x_values.size() * sizeof(float)); + auto input_data = allocator.GetAllocation(x_values.size() * sizeof(float)); ASSERT_NE(input_data.get(), nullptr); - cudaMemcpy(input_data.get(), x_values.data(), sizeof(float) * x_values.size(), cudaMemcpyHostToDevice); + (void)cudaMemcpy(input_data.get(), x_values.data(), sizeof(float) * x_values.size(), cudaMemcpyHostToDevice); // Create an OrtValue tensor backed by data on CUDA memory - Ort::Value bound_x = Ort::Value::CreateTensor(info_cuda, reinterpret_cast(input_data.get()), x_values.size(), + Ort::Value bound_x = Ort::Value::CreateTensor(info_mem, reinterpret_cast(input_data.get()), x_values.size(), x_shape.data(), x_shape.size()); const std::array expected_y_shape = {3, 2}; std::array expected_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f}; - auto output_data = cuda_allocator.GetAllocation(expected_y.size() * sizeof(float)); + auto output_data = allocator.GetAllocation(expected_y.size() * sizeof(float)); ASSERT_NE(output_data.get(), nullptr); // Create an OrtValue tensor backed by data on CUDA memory - Ort::Value bound_y = Ort::Value::CreateTensor(info_cuda, reinterpret_cast(output_data.get()), + Ort::Value bound_y = Ort::Value::CreateTensor(info_mem, reinterpret_cast(output_data.get()), expected_y.size(), expected_y_shape.data(), expected_y_shape.size()); // Create IoBinding for inputs and outputs. @@ -2008,31 +2054,37 @@ TEST(CApiTest, basic_cuda_graph) { // Check the values against the bound raw memory (needs copying from device to host first) std::array y_values; - cudaMemcpy(y_values.data(), output_data.get(), sizeof(float) * y_values.size(), cudaMemcpyDeviceToHost); + (void)cudaMemcpy(y_values.data(), output_data.get(), sizeof(float) * y_values.size(), cudaMemcpyDeviceToHost); ASSERT_THAT(y_values, ::testing::ContainerEq(expected_y)); // Replay the captured CUDA graph session.Run(Ort::RunOptions(), binding); - cudaMemcpy(y_values.data(), output_data.get(), sizeof(float) * y_values.size(), cudaMemcpyDeviceToHost); + (void)cudaMemcpy(y_values.data(), output_data.get(), sizeof(float) * y_values.size(), cudaMemcpyDeviceToHost); ASSERT_THAT(y_values, ::testing::ContainerEq(expected_y)); // Change the input and replay the CUDA graph again. x_values = {10.0f, 20.0f, 30.0f, 40.0f, 50.0f, 60.0f}; - cudaMemcpy(input_data.get(), x_values.data(), sizeof(float) * x_values.size(), cudaMemcpyHostToDevice); + (void)cudaMemcpy(input_data.get(), x_values.data(), sizeof(float) * x_values.size(), cudaMemcpyHostToDevice); binding.SynchronizeInputs(); session.Run(Ort::RunOptions(), binding); - cudaMemcpy(y_values.data(), output_data.get(), sizeof(float) * y_values.size(), cudaMemcpyDeviceToHost); + (void)cudaMemcpy(y_values.data(), output_data.get(), sizeof(float) * y_values.size(), cudaMemcpyDeviceToHost); expected_y = {10.0f, 40.0f, 90.0f, 160.0f, 250.0f, 360.0f}; ASSERT_THAT(y_values, ::testing::ContainerEq(expected_y)); // Clean up binding.ClearBoundInputs(); binding.ClearBoundOutputs(); +#if defined(USE_ROCM) +#undef cudaMemcpy +#undef cudaMemcpyHostToDevice +#undef cudaMemcpyDeviceToHost +#endif } -#ifndef REDUCED_OPS_BUILD // The following test uses some ops not supported in the reduced ops build +#ifndef REDUCED_OPS_BUILD +#if defined(USE_CUDA) || defined(USE_TENSORRT) TEST(CApiTest, cuda_graph_with_shape_nodes) { const auto& api = Ort::GetApi(); @@ -2053,10 +2105,34 @@ TEST(CApiTest, cuda_graph_with_shape_nodes) { // Successful loading of the ONNX model with shape nodes with cuda graph feature enabled Ort::Session session(*ort_env, TSTR("testdata/cuda_graph_with_shape_nodes.onnx"), session_options); } +#endif // defined(USE_CUDA) || defined(USE_TENSORRT) -#endif +#if defined(USE_ROCM) +TEST(CApiTest, hip_graph_with_shape_nodes) { + const auto& api = Ort::GetApi(); -#endif + // Enable hip graph in rocm provider option. + OrtROCMProviderOptions* rocm_options = nullptr; + ASSERT_TRUE(api.CreateROCMProviderOptions(&rocm_options) == nullptr); + std::unique_ptr + rel_rocm_options(rocm_options, api.ReleaseROCMProviderOptions); + std::vector keys{"enable_hip_graph"}; + std::vector values{"1"}; + ASSERT_TRUE(api.UpdateROCMProviderOptions(rel_rocm_options.get(), keys.data(), values.data(), 1) == nullptr); + + Ort::SessionOptions session_options; + ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_ROCM( + static_cast(session_options), + rel_rocm_options.get()) == nullptr); + + // Successful loading of the ONNX model with shape nodes with hip graph feature enabled + Ort::Session session(*ort_env, TSTR("testdata/cuda_graph_with_shape_nodes.onnx"), session_options); +} +#endif // defined(USE_ROCM) + +#endif // REDUCED_OPS_BUILD + +#endif // defined(USE_CUDA) || defined(USE_TENSORRT) || defined(USE_ROCM) TEST(CApiTest, create_tensor) { const char* s[] = {"abc", "kmp"}; diff --git a/onnxruntime/test/util/compare_ortvalue.cc b/onnxruntime/test/util/compare_ortvalue.cc index 3d53d4a3a0193..64ebe24188762 100644 --- a/onnxruntime/test/util/compare_ortvalue.cc +++ b/onnxruntime/test/util/compare_ortvalue.cc @@ -1,4 +1,5 @@ // Copyright (c) Microsoft Corporation. All rights reserved. +// Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. // Licensed under the MIT License. #include "test/compare_ortvalue.h" @@ -65,6 +66,54 @@ const char* ElementTypeToString(MLDataType type) { return DataTypeImpl::ToString(type); } +#if defined(__aarch64__) && defined(__linux__) +template +std::pair CheckCosineSimilarity(const Tensor& outvalue, const Tensor& expected_value) { + const size_t tensor_size = static_cast(expected_value.Shape().Size()); + const T* expected_output = expected_value.Data(); + const T* real_output = outvalue.Data(); + std::pair res = std::make_pair(COMPARE_RESULT::SUCCESS, ""); + const T cosine_similarity_threshold = 0.99f; + + T dot = 0.0f, denom_a = 0.0f, denom_b = 0.0f; + for (size_t i = 0u; i < tensor_size; ++i) { + if (isnan(expected_output[i]) && isnan(real_output[i])) + continue; + if (isinf(expected_output[i]) && isinf(real_output[i])) + continue; + dot += expected_output[i] * real_output[i]; + denom_a += expected_output[i] * expected_output[i]; + denom_b += real_output[i] * real_output[i]; + } + + T cos_factor = abs(dot / (sqrt(denom_a) * sqrt(denom_b))); + if (cos_factor < cosine_similarity_threshold) { + res.first = COMPARE_RESULT::RESULT_DIFFERS; + std::ostringstream oss; + oss << std::hex << "results differed, cosine similarity factor is " << cos_factor << "."; + res.second = oss.str(); + } + return res; +} + +template +std::pair CheckCloseMatch(const Tensor& outvalue, const Tensor& expected_value) { + const size_t size1 = static_cast(expected_value.Shape().Size()); + const T* expected_output = expected_value.Data(); + const T* real_output = outvalue.Data(); + const T close_match_threshold = 1.0; + + for (size_t di = 0; di != size1; ++di) { + const T diff = expected_output[di] - real_output[di]; + if (std::fabs(diff) > close_match_threshold) { + std::ostringstream oss; + oss << "expected " << expected_output[di] << ", got " << real_output[di]; + return std::make_pair(COMPARE_RESULT::RESULT_DIFFERS, oss.str()); + } + } + return std::make_pair(COMPARE_RESULT::SUCCESS, ""); +} +#endif /** * @brief Check if two values are closely matched with given tolerance. @@ -207,6 +256,37 @@ std::pair CompareTwoTensors(const Tensor& outvalue, oss << "shape mismatch, expect " << expected_tensor.Shape().ToString() << " got " << outvalue.Shape().ToString(); return std::make_pair(COMPARE_RESULT::SHAPE_MISMATCH, oss.str()); } + +#if defined(__aarch64__) && defined(__linux__) + if (isnan(per_sample_tolerance) || isnan(per_sample_tolerance)) { + if (outvalue.IsDataType()) { + return CheckCosineSimilarity(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return CheckCosineSimilarity(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return CheckCloseMatch(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return CheckCloseMatch(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return CheckCloseMatch(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return CheckCloseMatch(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return CheckCloseMatch(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return CheckCloseMatch(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return CheckCloseMatch(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return CheckCloseMatch(outvalue, expected_tensor); + } else if (outvalue.IsDataType()) { + return CheckCloseMatch(outvalue, expected_tensor); + } else { + return std::make_pair(COMPARE_RESULT::NOT_SUPPORT, ""); + } + } +#endif + if (outvalue.IsDataType()) { return CompareFloatResult(outvalue, expected_tensor, per_sample_tolerance, relative_per_sample_tolerance, post_processing); diff --git a/onnxruntime/wasm/js_internal_api.js b/onnxruntime/wasm/js_internal_api.js index 25ece9c700d5d..7c70515e73eab 100644 --- a/onnxruntime/wasm/js_internal_api.js +++ b/onnxruntime/wasm/js_internal_api.js @@ -186,4 +186,7 @@ Module['jsepInit'] = (backend, alloc, free, copy, copyAsync, createKernel, relea Module['jsepCreateDownloader'] = (gpuBuffer, size, type) => { return backend['createDownloader'](gpuBuffer, size, type); }; + Module['jsepOnRunStart'] = () => { + return backend['onRunStart'](); + }; }; diff --git a/orttraining/orttraining/test/python/orttraining_test_dort.py b/orttraining/orttraining/test/python/orttraining_test_dort.py index f0b6b9c5fba28..573ec85d76013 100644 --- a/orttraining/orttraining/test/python/orttraining_test_dort.py +++ b/orttraining/orttraining/test/python/orttraining_test_dort.py @@ -216,7 +216,12 @@ def elementwise_model(tensor_x: torch.Tensor): tensor_q = tensor_p.relu() return tensor_q - local_backend = make_local_backend(dynamic=True, use_aot_autograd=False) + # TODO: Set use_aot_autograd=False. In order to decompose torch + # function calls to aten ops, we need to set + # user_aot_autograd=True because there is no decomposition in DORT + # anymore. A long-term fix will be brining # decomposition pass back + # into DORT. + local_backend = make_local_backend(dynamic=True, use_aot_autograd=True) optimized_elementwise_model = torch.compile(elementwise_model, backend=local_backend, dynamic=True) def run(fun, list_x): diff --git a/tools/ci_build/build.py b/tools/ci_build/build.py index 1034a82cb2854..6e5cd7b57e403 100644 --- a/tools/ci_build/build.py +++ b/tools/ci_build/build.py @@ -2046,7 +2046,8 @@ def run_onnxruntime_tests(args, source_dir, ctest_path, build_dir, configs): numpy_init_version = numpy.__version__ pb_init_version = google.protobuf.__version__ run_subprocess( - [sys.executable, "-m", "pip", "install", "-r", "requirements.txt"], cwd=SCRIPT_DIR + [sys.executable, "-m", "pip", "install", "-r", "requirements-transformers-test.txt"], + cwd=SCRIPT_DIR, ) run_subprocess([sys.executable, "-m", "pytest", "transformers"], cwd=cwd) # Restore initial numpy/protobuf version in case other tests use it diff --git a/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml new file mode 100644 index 0000000000000..ff2e7c0468a21 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/bigmodels-ci-pipeline.yml @@ -0,0 +1,259 @@ +# reference: https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/transformers/models/stable_diffusion/README.md +parameters: +- name: specificArtifact + displayName: Use Specific Artifact + type: boolean + default: false +- name: BuildId + displayName: Specific Artifact's RunId + type: number + default: 0 + +resources: + repositories: + - repository: manylinux + type: Github + endpoint: Microsoft + name: pypa/manylinux + ref: 5eda9aded5462201e6310105728d33016e637ea7 + + - repository: LLaMa2Onnx + type: Github + endpoint: Microsoft + name: Microsoft/Llama-2-Onnx + ref: main + +variables: + - template: templates/common-variables.yml + - name: docker_base_image + value: nvidia/cuda:11.8.0-cudnn8-devel-ubi8 + - name: linux_trt_version + value: 8.6.1.6-1.cuda11.8 + +stages: +- stage: Build_Onnxruntime_Cuda + jobs: + - job: Linux_Build + timeoutInMinutes: 120 + variables: + skipComponentGovernanceDetection: true + CCACHE_DIR: $(Pipeline.Workspace)/ccache + workspace: + clean: all + pool: onnxruntime-Ubuntu2204-AMD-CPU + steps: + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() + + - checkout: self + clean: true + submodules: none + + - template: templates/get-docker-image-steps.yml + parameters: + Dockerfile: tools/ci_build/github/linux/docker/Dockerfile.manylinux2_28_cuda + Context: tools/ci_build/github/linux/docker + DockerBuildArgs: " + --network=host + --build-arg BASEIMAGE=$(docker_base_image) + --build-arg TRT_VERSION=$(linux_trt_version) + --build-arg BUILD_UID=$( id -u ) + " + Repository: onnxruntimecuda11build + + - task: Cache@2 + inputs: + key: '"ccache" | "$(Build.SourceBranch)" | "$(Build.SourceVersion)"' + path: $(CCACHE_DIR) + restoreKeys: | + "ccache" | "$(Build.SourceBranch)" + "ccache" + cacheHitVar: CACHE_RESTORED + displayName: Cach Task + + - script: | + sudo mkdir -p $(Pipeline.Workspace)/ccache + condition: ne(variables.CACHE_RESTORED, 'true') + displayName: Create Cache Dir + + - task: CmdLine@2 + inputs: + script: | + mkdir -p $HOME/.onnx + docker run -e CFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" -e CXXFLAGS="-Wp,-D_FORTIFY_SOURCE=2 -Wp,-D_GLIBCXX_ASSERTIONS -fstack-protector-strong -fstack-clash-protection -fcf-protection -O3 -Wl,--strip-all" --rm \ + --volume /data/onnx:/data/onnx:ro \ + --volume $(Build.SourcesDirectory):/onnxruntime_src \ + --volume $(Build.BinariesDirectory):/build \ + --volume /data/models:/build/models:ro \ + --volume $HOME/.onnx:/home/onnxruntimedev/.onnx \ + --volume $(Pipeline.Workspace)/ccache:/cache \ + -e ALLOW_RELEASED_ONNX_OPSET_ONLY=0 \ + -e NIGHTLY_BUILD \ + -e BUILD_BUILDNUMBER \ + -e CCACHE_DIR=/cache \ + onnxruntimecuda11build \ + /bin/bash -c " + set -ex; \ + env; \ + ccache -s; \ + /opt/python/cp38-cp38/bin/python3 /onnxruntime_src/tools/ci_build/build.py \ + --build_dir /build --cmake_generator Ninja \ + --config Release --update --build \ + --skip_submodule_sync \ + --build_shared_lib \ + --parallel \ + --build_wheel \ + --enable_onnx_tests --use_cuda --cuda_version=${{variables.common_cuda_version}} --cuda_home=/usr/local/cuda-${{variables.common_cuda_version}} --cudnn_home=/usr/local/cuda-${{variables.common_cuda_version}} \ + --enable_cuda_profiling --enable_cuda_nhwc_ops \ + --enable_pybind --build_java \ + --use_cache \ + --cmake_extra_defines 'CMAKE_CUDA_ARCHITECTURES=75;86' ; \ + ccache -sv; \ + ccache -z" + workingDirectory: $(Build.SourcesDirectory) + + - task: CmdLine@2 + inputs: + script: | + rm -rf $(Build.BinariesDirectory)/Release/onnxruntime $(Build.BinariesDirectory)/Release/pybind11 + rm -f $(Build.BinariesDirectory)/Release/models + find $(Build.BinariesDirectory)/Release/_deps -mindepth 1 ! -regex '^$(Build.BinariesDirectory)/Release/_deps/onnx-src\(/.*\)?' -delete + cd $(Build.BinariesDirectory)/Release + find -executable -type f > $(Build.BinariesDirectory)/Release/perms.txt + + - script: | + set -ex + mkdir -p $(Agent.TempDirectory)/ort + cp $(Build.BinariesDirectory)/Release/dist/*.whl $(Agent.TempDirectory)/ort/ + displayName: 'Copy Wheels' + + - task: PublishPipelineArtifact@0 + displayName: 'Publish Pipeline Artifact' + inputs: + artifactName: 'drop-ort-linux-gpu' + targetPath: '$(Agent.TempDirectory)/ort' + + - template: templates/explicitly-defined-final-tasks.yml + +- stage: Stale_Diffusion + dependsOn: + - Build_Onnxruntime_Cuda + jobs: + - job: Stale_Diffusion + variables: + skipComponentGovernanceDetection: true + CCACHE_DIR: $(Pipeline.Workspace)/ccache + workspace: + clean: all + pool: onnxruntime-Linux-GPU-A10-12G + steps: + - checkout: self + clean: true + submodules: none + + - template: templates/flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download Onnxruntime Artifact' + ArtifactName: 'drop-ort-linux-gpu' + TargetPath: '$(Build.BinariesDirectory)/Release' + SpecificArtifact: ${{ parameters.specificArtifact }} + BuildId: ${{ parameters.BuildId }} + + - script: | + docker run --rm --gpus all -v $PWD:/workspace -v $(Build.BinariesDirectory)/Release:/Release nvcr.io/nvidia/pytorch:22.11-py3 \ + bash -c " + set -ex; \ + python3 --version; \ + python3 -m pip install --upgrade pip; \ + python3 -m pip install /Release/*.whl; \ + pushd /workspace/onnxruntime/python/tools/transformers/models/stable_diffusion; \ + python3 -m pip install -r requirements-cuda11.txt; \ + python3 -m pip install --upgrade polygraphy onnx-graphsurgeon --extra-index-url https://pypi.ngc.nvidia.com; \ + echo Generate an image guided by a text prompt; \ + python3 demo_txt2img.py "astronaut riding a horse on mars"; \ + echo Generate an image with Stable Diffusion XL guided by a text prompt; \ + python3 demo_txt2img_xl.py 'starry night over Golden Gate Bridge by van gogh'; \ + python3 demo_txt2img_xl.py --enable-refiner 'starry night over Golden Gate Bridge by van gogh'; \ + echo Generate an image guided by a text prompt using LCM LoRA; \ + python3 demo_txt2img_xl.py --scheduler LCM --lora-weights latent-consistency/lcm-lora-sdxl --denoising-steps 4 "Self-portrait oil painting, a beautiful cyborg with golden hair, 8k"; \ + popd; \ + " + displayName: 'Run stable diffusion demo' + workingDirectory: $(Build.SourcesDirectory) + +- stage: Llama2_ONNX_FP16 + dependsOn: + - Build_Onnxruntime_Cuda + jobs: + - job: Llama2_ONNX_FP16 + variables: + skipComponentGovernanceDetection: true + workspace: + clean: all + pool: onnxruntime-Linux-GPU-T4 + steps: + - task: mspremier.PostBuildCleanup.PostBuildCleanup-task.PostBuildCleanup@3 + displayName: 'Clean Agent Directories' + condition: always() + + - checkout: self + clean: true + submodules: none + + - checkout: LLaMa2Onnx + clean: true + submodules: none + + - template: templates/flex-downloadPipelineArtifact.yml + parameters: + StepName: 'Download Onnxruntime Artifact' + ArtifactName: 'drop-ort-linux-gpu' + TargetPath: '$(Build.BinariesDirectory)/ort-artifact/' + SpecificArtifact: ${{ parameters.specificArtifact }} + BuildId: ${{ parameters.BuildId }} + + - task: DownloadPackage@1 + displayName: 'Download Llama2 model' + inputs: + packageType: upack + feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' + version: 1.0.0 + definition: '772ebce3-7e06-46d5-b3cc-82040ec4b2ce' + downloadPath: $(Agent.TempDirectory)/llama2_onnx_ft16 + + - template: templates/get-docker-image-steps.yml + parameters: + Dockerfile: onnxruntime/tools/ci_build/github/linux/docker/Dockerfile.package_ubi8_cuda11_8_tensorrt8_6 + Context: onnxruntime/tools/ci_build/github/linux/docker/ + ScriptName: onnxruntime/tools/ci_build/get_docker_image.py + DockerBuildArgs: "--build-arg BUILD_UID=$( id -u )" + Repository: onnxruntimeubi8packagestest + UpdateDepsTxt: false + + - script: | + docker run --rm --gpus all -v $(Build.SourcesDirectory)/Llama-2-Onnx:/workspace \ + -v $(Build.BinariesDirectory)/ort-artifact/:/ort-artifact \ + -v $(Agent.TempDirectory)/llama2_onnx_ft16:/models \ + onnxruntimeubi8packagestest \ + bash -c " + set -ex; \ + python3 -m pip install --upgrade pip ; \ + python3 -m pip install /ort-artifact/*.whl ; \ + python3 -m pip install torch --index-url https://download.pytorch.org/whl/cu118 ; \ + python3 -m pip install sentencepiece ; \ + pushd /workspace ; \ + python3 MinimumExample/Example_ONNX_LlamaV2.py --onnx_file /models/ONNX/LlamaV2_7B_FT_float16.onnx \ + --embedding_file /models/embeddings.pth --tokenizer_path tokenizer.model --prompt 'What is the lightest element?' > /workspace/answer.txt ; \ + popd ; \ + " + displayName: 'Run Llama2 demo' + workingDirectory: $(Build.SourcesDirectory) + + - script: | + set -ex + real=$(cat $(Build.SourcesDirectory)/Llama-2-Onnx/answer.txt) + trim_actual=$(tr -dc '[[:print:]]' <<< "$real") + expected="The lightest element is hydrogen. Hydrogen is the lightest element on the periodic table, with an atomic mass of 1.00794 u (unified atomic mass units)." + [ "$expected" == "$trim_actual" ] && exit 0 || exit 1 + displayName: 'Check result' diff --git a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml index aa1a75bfcda45..5a50a9964bead 100644 --- a/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml +++ b/tools/ci_build/github/azure-pipelines/c-api-noopenmp-packaging-pipelines.yml @@ -1023,7 +1023,7 @@ stages: - template: nuget/templates/test_win.yml parameters: - AgentPool : 'onnxruntime-Win2022-GPU-T4' + AgentPool : 'onnxruntime-Win2022-GPU-A10' NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu' ArtifactSuffix: 'GPU' StageSuffix: 'GPU' @@ -1034,7 +1034,7 @@ stages: - template: nuget/templates/test_win.yml parameters: - AgentPool : 'onnxruntime-Win2022-GPU-T4' + AgentPool : 'onnxruntime-Win2022-GPU-A10' NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu.Windows' ArtifactSuffix: 'GPU' StageSuffix: 'GPU' @@ -1046,7 +1046,7 @@ stages: - template: nuget/templates/test_linux.yml parameters: - AgentPool : Onnxruntime-Linux-GPU + AgentPool : Onnxruntime-Linux-GPU-A10 ArtifactSuffix: 'GPU' StageSuffix: 'GPU' NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu' @@ -1055,7 +1055,7 @@ stages: - template: nuget/templates/test_linux.yml parameters: - AgentPool : Onnxruntime-Linux-GPU + AgentPool : Onnxruntime-Linux-GPU-A10 ArtifactSuffix: 'GPU' StageSuffix: 'GPU' MoreSuffix: '_Linux' diff --git a/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml index 1d2ba88652f48..0c24d4897ddf1 100644 --- a/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/cuda-packaging-pipeline.yml @@ -151,7 +151,7 @@ stages: # Testing - template: nuget/templates/test_win.yml parameters: - AgentPool : 'onnxruntime-Win2022-GPU-T4' + AgentPool : 'onnxruntime-Win2022-GPU-A10' NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu' ArtifactSuffix: 'GPU' StageSuffix: 'GPU' @@ -162,7 +162,7 @@ stages: - template: nuget/templates/test_win.yml parameters: - AgentPool : 'onnxruntime-Win2022-GPU-T4' + AgentPool : 'onnxruntime-Win2022-GPU-A10' NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu.Windows' ArtifactSuffix: 'GPU' StageSuffix: 'GPU' @@ -174,7 +174,7 @@ stages: - template: nuget/templates/test_linux.yml parameters: - AgentPool : Onnxruntime-Linux-GPU + AgentPool : Onnxruntime-Linux-GPU-A10 ArtifactSuffix: 'GPU' StageSuffix: 'GPU' NugetPackageName : 'Microsoft.ML.OnnxRuntime.Gpu' @@ -184,7 +184,7 @@ stages: - template: nuget/templates/test_linux.yml parameters: - AgentPool : Onnxruntime-Linux-GPU + AgentPool : Onnxruntime-Linux-GPU-A10 ArtifactSuffix: 'GPU' StageSuffix: 'GPU' MoreSuffix: '_Linux' diff --git a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml index 1060a0138e0b7..5779b1da3fd43 100644 --- a/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/linux-gpu-ci-pipeline.yml @@ -137,7 +137,7 @@ jobs: --enable_cuda_profiling --enable_cuda_nhwc_ops \ --enable_pybind --build_java \ --use_cache \ - --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=75; \ + --cmake_extra_defines CMAKE_CUDA_ARCHITECTURES=86; \ ccache -sv; \ ccache -z" workingDirectory: $(Build.SourcesDirectory) @@ -166,7 +166,7 @@ jobs: skipComponentGovernanceDetection: true workspace: clean: all - pool: Onnxruntime-Linux-GPU-T4 + pool: onnxruntime-Linux-GPU-A10 dependsOn: - Linux_Build steps: diff --git a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml index 06cca0068523d..5349b1ca67ab1 100644 --- a/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml +++ b/tools/ci_build/github/azure-pipelines/py-packaging-pipeline.yml @@ -29,6 +29,11 @@ parameters: type: boolean default: true +- name: enable_windows_arm64_qnn + displayName: 'Whether Windows ARM64 package with QNN EP is built.' + type: boolean + default: true + - name: build_py_parameters displayName: 'Specify extra build parameters' type: string @@ -64,5 +69,6 @@ stages: enable_windows_gpu: ${{ parameters.enable_windows_gpu }} enable_mac_cpu: ${{ parameters.enable_mac_cpu }} enable_linux_arm: ${{ parameters.enable_linux_arm }} + enable_windows_arm64_qnn: ${{ parameters.enable_windows_arm64_qnn }} build_py_parameters: ${{ parameters.build_py_parameters }} - cmake_build_type: ${{ parameters.cmake_build_type }} \ No newline at end of file + cmake_build_type: ${{ parameters.cmake_build_type }} diff --git a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml index 537175f6bec73..55f6561b7a44a 100644 --- a/tools/ci_build/github/azure-pipelines/templates/download-deps.yml +++ b/tools/ci_build/github/azure-pipelines/templates/download-deps.yml @@ -11,7 +11,7 @@ steps: packageType: upack feed: '/7424c8e4-5c62-490e-95c4-79446f31017c' definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0' - version: 1.0.129 + version: 1.0.132 downloadPath: $(Build.BinariesDirectory)/deps # The private ADO project @@ -22,7 +22,7 @@ steps: packageType: upack feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325' definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a' - version: 1.0.129 + version: 1.0.132 downloadPath: $(Build.BinariesDirectory)/deps # You can add more ADO accounts at here. diff --git a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml index 8669a883c31f1..297498843c38d 100644 --- a/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml +++ b/tools/ci_build/github/azure-pipelines/templates/py-packaging-stage.yml @@ -35,6 +35,11 @@ parameters: type: boolean default: true +- name: enable_windows_arm64_qnn + displayName: 'Whether Windows ARM64 package with QNN EP is built.' + type: boolean + default: true + # TODO: Now the Windows jobs use a different cmake build type. Consider to merge it. - name: cmake_build_type type: string @@ -446,3 +451,11 @@ stages: machine_pool: 'onnxruntime-Ubuntu2204-AMD-CPU' extra_build_arg: ${{ parameters.build_py_parameters }} cmake_build_type: ${{ parameters.cmake_build_type }} + + - ${{ if eq(parameters.enable_windows_arm64_qnn, true) }}: + - template: py-win-arm64-qnn.yml + parameters: + MACHINE_POOL: 'onnxruntime-qnn-windows-vs-2022-arm64' + QNN_SDK: 'qnn-v2.18.0.240101_win' + PYTHON_VERSION: '3.11' + NUMPY_VERSION: '1.25.2' diff --git a/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml new file mode 100644 index 0000000000000..adf7aa9c43205 --- /dev/null +++ b/tools/ci_build/github/azure-pipelines/templates/py-win-arm64-qnn.yml @@ -0,0 +1,165 @@ +parameters: + +- name: MACHINE_POOL + type: string + default: 'onnxruntime-qnn-windows-vs-2022-arm64' + +- name: QNN_SDK + displayName: QNN Windows SDK path + type: string + default: qnn-v2.18.0.240101_win + +- name: PYTHON_VERSION + type: string + default: '3.11' + +- name: NUMPY_VERSION + type: string + default: '1.25.2' + +- name: ENV_SETUP_SCRIPT + type: string + default: '' + +- name: BUILD_PY_PARAMETERS + displayName: > + Extra parameters to pass to build.py. Don't put newlines in here. + type: string + default: '' + +jobs: +- job: Win_py_arm64_qnn_Wheels_${{ replace(parameters.PYTHON_VERSION,'.','_') }} + timeoutInMinutes: 210 + workspace: + clean: all + pool: + name: ${{ parameters.MACHINE_POOL }} + variables: + GRADLE_OPTS: '-Dorg.gradle.daemon=false' + VSGenerator: 'Visual Studio 17 2022' + QNN_SDK_ROOTDIR: 'C:\data\qnnsdk\${{parameters.QNN_SDK}}' + steps: + - checkout: self + clean: true + submodules: recursive + + - template: telemetry-steps.yml + + - script: | + DIR C:\data\qnnsdk + displayName: Check available QNN SDKs + + - script: | + MKDIR $(Agent.ToolsDirectory)\Python\3.11.0\arm64 + XCOPY /s /y /h /e /c /q "C:\Python\Python311\*.*" $(Agent.ToolsDirectory)\Python\3.11.0\arm64\ + COPY NUL $(Agent.ToolsDirectory)\Python\3.11.0\arm64.complete + DIR $(Agent.ToolsDirectory)\Python + DIR $(Agent.ToolsDirectory)\Python\3.11.0 + DIR $(Agent.ToolsDirectory)\Python\3.11.0\arm64 + displayName: Copy python 3.11.0 version to agent tools directory + + - task: UsePythonVersion@0 + inputs: + versionSpec: ${{ parameters.PYTHON_VERSION }} + addToPath: true + architecture: 'arm64' + + - task: onebranch.pipeline.tsaoptions@1 + displayName: 'OneBranch TSAOptions' + inputs: + tsaConfigFilePath: '$(Build.SourcesDirectory)\.config\tsaoptions.json' + appendSourceBranchName: false + + - task: PythonScript@0 + inputs: + scriptSource: inline + script: | + import subprocess + subprocess.call(['pip', 'install', '-q', 'setuptools', 'wheel', 'numpy==${{parameters.NUMPY_VERSION}}']) + workingDirectory: '$(Build.BinariesDirectory)' + displayName: 'Install python modules' + + - template: set-nightly-build-option-variable-step.yml + + - task: PythonScript@0 + displayName: 'Generate cmake config' + inputs: + scriptPath: '$(Build.SourcesDirectory)\tools\ci_build\build.py' + arguments: > + --config RelWithDebInfo + --build_dir $(Build.BinariesDirectory) + --skip_submodule_sync + --cmake_generator "$(VSGenerator)" + --use_qnn + --qnn_home $(QNN_SDK_ROOTDIR) + --enable_pybind + --parallel --update + --numpy_version ${{ parameters.NUMPY_VERSION }} + $(TelemetryOption) ${{ parameters.BUILD_PY_PARAMETERS }} + workingDirectory: '$(Build.BinariesDirectory)' + + - task: VSBuild@1 + displayName: 'Build' + inputs: + solution: '$(Build.BinariesDirectory)\RelWithDebInfo\onnxruntime.sln' + platform: 'arm64' + configuration: RelWithDebInfo + msbuildArchitecture: 'arm64' + maximumCpuCount: true + logProjectEvents: true + workingFolder: '$(Build.BinariesDirectory)\RelWithDebInfo' + createLogFile: true + + # Esrp signing + - template: win-esrp-dll.yml + parameters: + FolderPath: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\onnxruntime\capi' + DisplayName: 'ESRP - Sign Native dlls' + DoEsrp: true + Pattern: '*.pyd,*.dll' + + - task: PythonScript@0 + displayName: 'Build wheel' + inputs: + scriptPath: '$(Build.SourcesDirectory)\setup.py' + arguments: 'bdist_wheel ${{ parameters.BUILD_PY_PARAMETERS }} $(NightlyBuildOption) --wheel_name_suffix=qnn' + workingDirectory: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo' + + - task: CopyFiles@2 + displayName: 'Copy Python Wheel to: $(Build.ArtifactStagingDirectory)' + inputs: + SourceFolder: '$(Build.BinariesDirectory)\RelWithDebInfo\RelWithDebInfo\dist' + Contents: '*.whl' + TargetFolder: '$(Build.ArtifactStagingDirectory)' + + - task: PublishBuildArtifacts@1 + displayName: 'Publish Artifact: ONNXRuntime python wheel' + inputs: + ArtifactName: onnxruntime_qnn + + - script: | + 7z x *.whl + workingDirectory: '$(Build.ArtifactStagingDirectory)' + displayName: 'unzip the package' + + - task: CredScan@3 + displayName: 'Run CredScan' + inputs: + debugMode: false + continueOnError: true + + - task: BinSkim@4 + displayName: 'Run BinSkim' + inputs: + AnalyzeTargetGlob: '+:file|$(Build.ArtifactStagingDirectory)\**\*.dll' + + - task: TSAUpload@2 + displayName: 'TSA upload' + condition: and (succeeded(), eq(variables['Build.SourceBranch'], 'refs/heads/main')) + inputs: + GdnPublishTsaOnboard: false + GdnPublishTsaConfigFile: '$(Build.sourcesDirectory)\.gdn\.gdntsa' + + - template: component-governance-component-detection-steps.yml + parameters: + condition: 'succeeded' diff --git a/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml b/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml index b8f9566274acc..db39c2cd2087f 100644 --- a/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml +++ b/tools/ci_build/github/azure-pipelines/win-ci-fuzz-testing.yml @@ -28,7 +28,7 @@ jobs: parameters: EnvSetupScript: $(EnvSetupScript) DownloadCUDA: false - BuildArch: $(buildArch) + BuildArch: x64 BuildConfig: $(BuildConfig) MachinePool: 'onnxruntime-Win-CPU-2022' WithCache: true diff --git a/tools/ci_build/requirements.txt b/tools/ci_build/requirements-transformers-test.txt similarity index 94% rename from tools/ci_build/requirements.txt rename to tools/ci_build/requirements-transformers-test.txt index 57fc8f08336d2..a5279781462a7 100644 --- a/tools/ci_build/requirements.txt +++ b/tools/ci_build/requirements-transformers-test.txt @@ -3,7 +3,8 @@ packaging protobuf==3.20.2 numpy==1.24.0 ; python_version < '3.12' numpy==1.26.0 ; python_version >= '3.12' +torch coloredlogs==15.0 transformers==4.36.0 psutil -einops \ No newline at end of file +einops