diff --git a/cmake/external/onnxruntime_external_deps.cmake b/cmake/external/onnxruntime_external_deps.cmake
index 78f63227c8392..403b4b2c4107a 100644
--- a/cmake/external/onnxruntime_external_deps.cmake
+++ b/cmake/external/onnxruntime_external_deps.cmake
@@ -108,41 +108,14 @@ FetchContent_Declare(
)
# Download a protoc binary from Internet if needed
-if(CMAKE_CROSSCOMPILING AND NOT ONNX_CUSTOM_PROTOC_EXECUTABLE)
+if(NOT ONNX_CUSTOM_PROTOC_EXECUTABLE)
# This part of the code is only for the user's convenience; it cannot handle every case. Users can always manually
# download protoc from Protobuf's GitHub release page and pass the local path to the ONNX_CUSTOM_PROTOC_EXECUTABLE
# variable.
- message("CMAKE_HOST_SYSTEM_NAME: ${CMAKE_HOST_SYSTEM_NAME}")
- if(CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows")
- if(CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "AMD64")
- FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_win64} URL_HASH SHA1=${DEP_SHA1_protoc_win64})
- FetchContent_Populate(protoc_binary)
- elseif(CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "x86")
- FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_win32} URL_HASH SHA1=${DEP_SHA1_protoc_win32})
- FetchContent_Populate(protoc_binary)
- endif()
- if(protoc_binary_SOURCE_DIR)
- message("Use prebuilt protoc")
- set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc.exe)
- set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE})
- endif()
- elseif(CMAKE_HOST_SYSTEM_NAME STREQUAL "Linux")
- if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$")
- FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_linux_x64} URL_HASH SHA1=${DEP_SHA1_protoc_linux_x64})
- FetchContent_Populate(protoc_binary)
- elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i.86|x86?)$")
- FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_linux_x86} URL_HASH SHA1=${DEP_SHA1_protoc_linux_x86})
- FetchContent_Populate(protoc_binary)
- elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64.*")
- FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_linux_aarch64} URL_HASH SHA1=${DEP_SHA1_protoc_linux_aarch64})
- FetchContent_Populate(protoc_binary)
- endif()
- if(protoc_binary_SOURCE_DIR)
- message("Use prebuilt protoc")
- set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc)
- set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE})
- endif()
- elseif ((CMAKE_SYSTEM_NAME STREQUAL "Emscripten" OR CMAKE_SYSTEM_NAME STREQUAL "Android" OR CMAKE_SYSTEM_NAME STREQUAL "iOS") AND CMAKE_HOST_SYSTEM_NAME STREQUAL "Darwin")
+ if (CMAKE_HOST_APPLE)
+ # Using CMAKE_CROSSCOMPILING is not recommended for Apple target devices.
+ # https://cmake.org/cmake/help/v3.26/variable/CMAKE_CROSSCOMPILING.html
+ # To keep it simple, just download and use the universal protoc binary for all Apple host builds.
FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_mac_universal} URL_HASH SHA1=${DEP_SHA1_protoc_mac_universal})
FetchContent_Populate(protoc_binary)
if(protoc_binary_SOURCE_DIR)
@@ -150,6 +123,38 @@ if(CMAKE_CROSSCOMPILING AND NOT ONNX_CUSTOM_PROTOC_EXECUTABLE)
set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc)
set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE})
endif()
+ elseif (CMAKE_CROSSCOMPILING)
+ message("CMAKE_HOST_SYSTEM_NAME: ${CMAKE_HOST_SYSTEM_NAME}")
+ if(CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows")
+ if(CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "AMD64")
+ FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_win64} URL_HASH SHA1=${DEP_SHA1_protoc_win64})
+ FetchContent_Populate(protoc_binary)
+ elseif(CMAKE_HOST_SYSTEM_PROCESSOR STREQUAL "x86")
+ FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_win32} URL_HASH SHA1=${DEP_SHA1_protoc_win32})
+ FetchContent_Populate(protoc_binary)
+ endif()
+ if(protoc_binary_SOURCE_DIR)
+ message("Use prebuilt protoc")
+ set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc.exe)
+ set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE})
+ endif()
+ elseif(CMAKE_HOST_SYSTEM_NAME STREQUAL "Linux")
+ if(CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "^(x86_64|amd64)$")
+ FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_linux_x64} URL_HASH SHA1=${DEP_SHA1_protoc_linux_x64})
+ FetchContent_Populate(protoc_binary)
+ elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^(i.86|x86?)$")
+ FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_linux_x86} URL_HASH SHA1=${DEP_SHA1_protoc_linux_x86})
+ FetchContent_Populate(protoc_binary)
+ elseif(CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64.*")
+ FetchContent_Declare(protoc_binary URL ${DEP_URL_protoc_linux_aarch64} URL_HASH SHA1=${DEP_SHA1_protoc_linux_aarch64})
+ FetchContent_Populate(protoc_binary)
+ endif()
+ if(protoc_binary_SOURCE_DIR)
+ message("Use prebuilt protoc")
+ set(ONNX_CUSTOM_PROTOC_EXECUTABLE ${protoc_binary_SOURCE_DIR}/bin/protoc)
+ set(PROTOC_EXECUTABLE ${ONNX_CUSTOM_PROTOC_EXECUTABLE})
+ endif()
+ endif()
endif()
endif()
@@ -184,9 +189,9 @@ FetchContent_Declare(
)
set(protobuf_BUILD_TESTS OFF CACHE BOOL "Build protobuf tests" FORCE)
-#TODO: we'd better to turn the following option off. However, it will cause
+#TODO: we'd better turn the following option off. However, it will cause
# ".\build.bat --config Debug --parallel --skip_submodule_sync --update" fail with an error message:
-# install(EXPORT "ONNXTargets" ...) includes target "onnx_proto" which requires target "libprotobuf-lite" that is
+# install(EXPORT "ONNXTargets" ...) includes target "onnx_proto" which requires target "libprotobuf-lite" that is
# not in any export set.
#set(protobuf_INSTALL OFF CACHE BOOL "Install protobuf binaries and files" FORCE)
set(protobuf_USE_EXTERNAL_GTEST ON CACHE BOOL "" FORCE)
@@ -562,4 +567,3 @@ endif()
FILE(TO_NATIVE_PATH ${CMAKE_BINARY_DIR} ORT_BINARY_DIR)
FILE(TO_NATIVE_PATH ${PROJECT_SOURCE_DIR} ORT_SOURCE_DIR)
-
diff --git a/cmake/onnxruntime_mlas.cmake b/cmake/onnxruntime_mlas.cmake
index f89d2150a6830..17de2aa4aaea6 100644
--- a/cmake/onnxruntime_mlas.cmake
+++ b/cmake/onnxruntime_mlas.cmake
@@ -355,19 +355,23 @@ else()
${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S
${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S
${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S
+ ${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S
${MLAS_SRC_DIR}/activate_fp16.cpp
${MLAS_SRC_DIR}/dwconv.cpp
${MLAS_SRC_DIR}/halfgemm_kernel_neon.cpp
${MLAS_SRC_DIR}/pooling_fp16.cpp
${MLAS_SRC_DIR}/qgemm_kernel_smmla.cpp
${MLAS_SRC_DIR}/qgemm_kernel_ummla.cpp
+ ${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp
)
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/HalfGemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmS8S8KernelSmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
set_source_files_properties(${MLAS_SRC_DIR}/aarch64/QgemmU8X8KernelUmmla.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+i8mm ")
+ set_source_files_properties(${MLAS_SRC_DIR}/aarch64/SbgemmKernelNeon.S PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
set_source_files_properties(${MLAS_SRC_DIR}/activate_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/dwconv.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
set_source_files_properties(${MLAS_SRC_DIR}/pooling_fp16.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+fp16 ")
+ set_source_files_properties(${MLAS_SRC_DIR}/sbgemm_kernel_neon.cpp PROPERTIES COMPILE_FLAGS " -march=armv8.2-a+bf16 ")
endif()
if(ONNXRUNTIME_MLAS_MULTI_ARCH)
diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake
index fa395802d95ff..0987d6d164dbd 100644
--- a/cmake/onnxruntime_unittests.cmake
+++ b/cmake/onnxruntime_unittests.cmake
@@ -1277,6 +1277,9 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
if (onnxruntime_USE_CUDA)
list(APPEND onnxruntime_shared_lib_test_LIBS cudart)
endif()
+ if (onnxruntime_USE_ROCM)
+ list(APPEND onnxruntime_shared_lib_test_LIBS hip::host)
+ endif()
if (onnxruntime_USE_TENSORRT)
list(APPEND onnxruntime_shared_lib_test_LIBS ${TENSORRT_LIBRARY_INFER})
endif()
@@ -1294,6 +1297,10 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
target_include_directories(onnxruntime_shared_lib_test PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_sources(onnxruntime_shared_lib_test PRIVATE ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_ops.cu)
endif()
+ if (onnxruntime_USE_ROCM)
+ target_include_directories(onnxruntime_shared_lib_test PRIVATE ${onnxruntime_ROCM_HOME}/include)
+ target_compile_definitions(onnxruntime_shared_lib_test PRIVATE __HIP_PLATFORM_AMD__)
+ endif()
if (CMAKE_SYSTEM_NAME STREQUAL "Android")
target_sources(onnxruntime_shared_lib_test PRIVATE
"${ONNXRUNTIME_ROOT}/core/platform/android/cxa_demangle.cc"
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 45c0e6f822ce9..e7b537d6894c8 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -2398,24 +2398,28 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Attributes
+- do_rotary : int
+- Whether to use rotary position embedding. Default value is 0.
- kv_num_heads : int (required)
- Number of attention heads for k and v
- local_window_size : int
- left_window_size for local attention (like Mistral). Default value is -1 meaning unused.
- num_heads : int (required)
- Number of attention heads for q
+- rotary_interleaved : int
+- Rotate using interleaved pattern. Default value is 0 (False).
- scale : float
- Custom scale will be used if specified. Default value is 1/sqrt(head_size)
-#### Inputs
+#### Inputs (7 - 9)
- query : T
-- Query with shape (batch_size, sequence_length, hidden_size)
-- key : T
+- Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape (batch_size, sequence_length, d) where d is (num_heads * head_size + 2 * kv_num_heads * head_size); see the worked example after this input list.
+- key (optional) : T
- Key with shape (batch_size, kv_sequence_length, kv_hidden_size)
-- value : T
+- value (optional) : T
- Value with shape (batch_size, kv_sequence_length, kv_hidden_size)
- past_key (optional) : T
- past state key with support for format BNSH. When past_key uses same tensor as present_key(k-v cache), it is of length max_sequence_length... otherwise of length past_sequence_length.
@@ -2425,6 +2429,10 @@ This version of the operator has been available since version 1 of the 'com.micr
- 1d Tensor of shape (batch_size). Indicates past sequence lengths for token generation case.
- total_sequence_length : M
- Scalar tensor of total sequence length (past + new).
+- cos_cache (optional) : T
+- 2D tensor with shape (max_sequence_length, head_size / 2).
+- sin_cache (optional) : T
+- 2D tensor with shape (max_sequence_length, head_size / 2).
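
A quick worked example of the packed-QKV dimension d described for the query input above; the head counts and head size below are illustrative assumptions, not values mandated by the operator.

```cpp
// Illustrative only: d = num_heads * head_size + 2 * kv_num_heads * head_size.
#include <iostream>

int main() {
  const int num_heads = 32;     // assumed Q head count
  const int kv_num_heads = 8;   // assumed K/V head count
  const int head_size = 128;    // assumed head size
  const int d = num_heads * head_size + 2 * kv_num_heads * head_size;
  std::cout << "packed QKV last dimension d = " << d << '\n';  // 4096 + 2048 = 6144
  return 0;
}
```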
#### Outputs
@@ -3031,6 +3039,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Number of attention heads
scale : float
Custom scale will be used if specified. Default value is 1/sqrt(head_size)
+unidirectional : int
+Whether every token can only attend to previous tokens. Default value is 0.
#### Inputs (1 - 8)
@@ -5021,6 +5031,10 @@ This version of the operator has been available since version 1 of the 'com.micr
- interleaved : int
- Rotate using interleaved pattern. Default value is 0 (False).
+- num_heads : int
+- Number of attention heads. Default value is 0. Must be used together with rotary_embedding_dim.
+- rotary_embedding_dim : int
+- Rotary embedding dimension. Default value is 0.
- scale : float
- Custom scale will be used if specified. Default value is 1.0
@@ -5033,9 +5047,9 @@ This version of the operator has been available since version 1 of the 'com.micr
position_ids : M
1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)
cos_cache : T
-2D tensor with shape (max_sequence_length, head_size / 2).
+2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)
sin_cache : T
-2D tensor with shape (max_sequence_length, head_size / 2).
+2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)
#### Outputs
@@ -5048,7 +5062,7 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Type Constraints
-- T : tensor(float), tensor(float16)
+- T : tensor(float), tensor(float16), tensor(bfloat16)
- Constrain input and output types to float tensors.
- M : tensor(int64)
- Constrain input and output types to integer tensors
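
As a quick illustration of the two cache shapes listed above, a minimal sketch; all sizes are assumed example values.

```cpp
// Illustrative cache-shape computation for partial rotary embedding (values assumed).
#include <cstdio>

int main() {
  const int max_sequence_length = 4096;
  const int head_size = 128;
  const int rotary_embedding_dim = 64;  // 0 would mean "use the full head_size"
  const int cache_cols = (rotary_embedding_dim > 0 ? rotary_embedding_dim : head_size) / 2;
  std::printf("cos_cache/sin_cache shape: (%d, %d)\n", max_sequence_length, cache_cols);  // (4096, 32)
  return 0;
}
```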
@@ -5755,7 +5769,7 @@ This version of the operator has been available since version 1 of the 'com.micr
- Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape
-#### Inputs (5 - 14)
+#### Inputs (5 - 15)
- input_ids : F
@@ -5786,6 +5800,8 @@ This version of the operator has been available since version 1 of the 'com.micr
- Only keep this list of (layer, head) of QK in the final cross_qk output when use_cross_qk is set. Default: collect all. Its shape is (number of (layer, head) to keep, 2), i.e., [[layer_id1, head_id1], [layer_id2, head_id2]......]
- extra_decoding_ids (optional) : I
- Part of the decoder_input_ids that we need cross qk for it. it is of shape (batch_size, extra_decoding_ids_len).In such case, we should remove this from the tail of the decoder_input_ids, and put it here. ids < 0 in it (for multiple batch) are treated as stop of the extra_decoding_ids for corresponding batch.
+- temperature (optional) : T
+- Temperature value to apply to logits processing during this execution's decoding. Shape is (1)
#### Outputs (1 - 5)
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 394bd7ad2abae..31cca232fde34 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -499,7 +499,7 @@ Do not modify directly.*
|TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)|
|Trilu|*in* X:**T**
*in* k:**tensor(int64)**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int64)|
|Unique|*in* x:**T**
*out* y:**T**
*out* idx:**tensor(int64)**
*out* counts:**tensor(int64)**|1+|**T** = tensor(float)|
-|WhisperBeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*in* cross_qk_layer_head:**I**
*in* extra_decoding_ids:**I**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**
*out* cross_qk:**V**
*out* non_speech_probs:**T**|1+|**T** = tensor(float)|
+|WhisperBeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*in* cross_qk_layer_head:**I**
*in* extra_decoding_ids:**I**
*in* temperature:**T**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**
*out* cross_qk:**V**
*out* non_speech_probs:**T**|1+|**T** = tensor(float)|
|WordConvEmbedding|*in* Sequence:**T**
*in* W:**T1**
*in* B:**T1**
*in* C:**T1**
*out* Y:**T1**|1+|**T** = tensor(int32)
**T1** = tensor(float)|
| |
| |
@@ -843,7 +843,7 @@ Do not modify directly.*
|GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
|GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)|
|GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
-|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)|
+|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(bfloat16), tensor(float16)|
|Inverse|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Irfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask:**T**
*in* global_weight:**T**
*in* global_bias:**T**
*in* global:**G**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
@@ -868,7 +868,7 @@ Do not modify directly.*
|RemovePadding|*in* input:**T**
*in* sequence_token_count:**M**
*out* output:**T**
*out* token_offset:**M**
*out* cumulated_seq_len:**M**
*out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)|
|RestorePadding|*in* input:**T**
*in* token_offset:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|Rfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
-|RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(float), tensor(float16)|
+|RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(bfloat16), tensor(float), tensor(float16)|
|Sampling|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*in* presence_mask:**I**
*in* seed:**I**
*out* sequences:**I**
*out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipGroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*in* skip:**T**
*in* bias:**T**
*out* Y:**T**
*out* S:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
@@ -876,7 +876,7 @@ Do not modify directly.*
|TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|Trilu|*in* X:**T**
*in* k:**tensor(int64)**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|UnfoldTensor|*in* input:**T**
*out* output:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|WhisperBeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*in* cross_qk_layer_head:**I**
*in* extra_decoding_ids:**I**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**
*out* cross_qk:**V**
*out* non_speech_probs:**T**|1+|**T** = tensor(float), tensor(float16)|
+|WhisperBeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*in* cross_qk_layer_head:**I**
*in* extra_decoding_ids:**I**
*in* temperature:**T**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**
*out* cross_qk:**V**
*out* non_speech_probs:**T**|1+|**T** = tensor(float), tensor(float16)|
| |
| |
@@ -922,10 +922,12 @@ Do not modify directly.*
|BitwiseNot|*in* X:**T**
*out* Y:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|BitwiseOr|*in* A:**T**
*in* B:**T**
*out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|BitwiseXor|*in* A:**T**
*in* B:**T**
*out* C:**T**|18+|**T** = tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|Cast|*in* input:**T1**
*out* output:**T2**|13+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|Cast|*in* input:**T1**
*out* output:**T2**|19+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||13+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||9+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||6+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|CastLike|*in* input:**T1**
*in* target_type:**T2**
*out* output:**T2**|15+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|CastLike|*in* input:**T1**
*in* target_type:**T2**
*out* output:**T2**|19+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||15+|**T1** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Ceil|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(float), tensor(float16)|
|||6+|**T** = tensor(float), tensor(float16)|
|Celu|*in* X:**T**
*out* Y:**T**|12+|**T** = tensor(float), tensor(float16)|
@@ -952,7 +954,8 @@ Do not modify directly.*
|DepthToSpace|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|DequantizeLinear|*in* x:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*out* y:**tensor(float)**
or
*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|13+|**T** = tensor(int32), tensor(int8), tensor(uint8)|
+|DequantizeLinear|*in* x:**T**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T**
*out* y:**tensor(float)**
or
*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|19+|**T1** = tensor(int32), tensor(int8), tensor(uint8)
**T2** = tensor(float), tensor(float16)|
+|||13+|**T** = tensor(int32), tensor(int8), tensor(uint8)|
|||10+|**T** = tensor(int32), tensor(int8), tensor(uint8)|
|Div|*in* A:**T**
*in* B:**T**
*out* C:**T**|14+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
@@ -961,7 +964,8 @@ Do not modify directly.*
|DynamicQuantizeLinear|*in* x:**T1**
*out* y:**T2**
*out* y_scale:**tensor(float)**
*out* y_zero_point:**T2**|11+|**T1** = tensor(float)
**T2** = tensor(int8), tensor(uint8)|
|Einsum|*in* Inputs:**T**
*out* Output:**T**|12+|**T** = tensor(float), tensor(float16)|
|Elu|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(float), tensor(float16)|
-|Equal|*in* A:**T**
*in* B:**T**
*out* C:**T1**|13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)|
+|Equal|*in* A:**T**
*in* B:**T**
*out* C:**T1**|19+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)|
+|||13+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)|
|||11+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(bool)|
|||7+|**T** = tensor(float), tensor(float16)
**T1** = tensor(bool)|
|Erf|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(float), tensor(float16)|
@@ -1004,7 +1008,8 @@ Do not modify directly.*
|Hardmax|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(float), tensor(float16)|
|||11+|**T** = tensor(float), tensor(float16)|
|||1+|**T** = tensor(float), tensor(float16)|
-|Identity|*in* input:**T**
*out* output:**T**
or
*in* input:**V**
*out* output:**V**|16+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|Identity|*in* input:**T**
*out* output:**T**
or
*in* input:**V**
*out* output:**V**|19+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||16+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||14+|**V** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||1+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
@@ -1099,7 +1104,8 @@ Do not modify directly.*
|||7+|**T** = tensor(float), tensor(float16)|
|QLinearConv|*in* x:**T1**
*in* x_scale:**tensor(float)**
*in* x_zero_point:**T1**
*in* w:**T2**
*in* w_scale:**tensor(float)**
*in* w_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*in* B:**T4**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)
**T4** = tensor(int32)|
|QLinearMatMul|*in* a:**T1**
*in* a_scale:**tensor(float)**
*in* a_zero_point:**T1**
*in* b:**T2**
*in* b_scale:**tensor(float)**
*in* b_zero_point:**T2**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T3**
*out* y:**T3**|10+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(int8), tensor(uint8)|
-|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**
or
*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|13+|**T1** = tensor(float), tensor(int32)
**T2** = tensor(int8), tensor(uint8)|
+|QuantizeLinear|*in* x:**T1**
*in* y_scale:**T1**
*in* y_zero_point:**T2**
*out* y:**T2**
or
*in* x:**T1**
*in* y_scale:**tensor(float)**
*in* y_zero_point:**T2**
*out* y:**T2**|19+|**T1** = tensor(float), tensor(float16), tensor(int32)
**T2** = tensor(int8), tensor(uint8)|
+|||13+|**T1** = tensor(float), tensor(int32)
**T2** = tensor(int8), tensor(uint8)|
|||10+|**T1** = tensor(float), tensor(int32)
**T2** = tensor(int8), tensor(uint8)|
|RNN|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*out* Y:**T**
*out* Y_h:**T**|14+|**T** = tensor(float), tensor(float16)|
|||7+|**T** = tensor(float), tensor(float16)|
@@ -1150,7 +1156,8 @@ Do not modify directly.*
|Relu|*in* X:**T**
*out* Y:**T**|14+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8)|
|||13+|**T** = tensor(float), tensor(float16)|
|||6+|**T** = tensor(float), tensor(float16)|
-|Reshape|*in* data:**T**
*in* shape:**tensor(int64)**
*out* reshaped:**T**
or
*in* data:**T**
*out* reshaped:**T**|14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|Reshape|*in* data:**T**
*in* shape:**tensor(int64)**
*out* reshaped:**T**
or
*in* data:**T**
*out* reshaped:**T**|19+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||14+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||5+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Resize|*in* X:**T**
*in* scales:**tensor(float)**
*out* Y:**T**
or
*in* X:**T1**
*in* roi:**T2**
*in* scales:**tensor(float)**
*in* sizes:**tensor(int64)**
*out* Y:**T1**|13+|**T1** = tensor(float), tensor(float16), tensor(int8), tensor(uint8)
**T2** = tensor(float), tensor(float16)|
@@ -1178,7 +1185,8 @@ Do not modify directly.*
|SequenceErase|*in* input_sequence:**S**
*in* position:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
|SequenceInsert|*in* input_sequence:**S**
*in* tensor:**T**
*in* position:**I**
*out* output_sequence:**S**|11+|**I** = tensor(int32), tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
|SequenceLength|*in* input_sequence:**S**
*out* length:**I**|11+|**I** = tensor(int64)
**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
-|Shape|*in* data:**T**
*out* shape:**T1**|15+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
+|Shape|*in* data:**T**
*out* shape:**T1**|19+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
+|||15+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
|||13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
|||1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
|Shrink|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint8)|
@@ -1188,7 +1196,8 @@ Do not modify directly.*
|||9+|**T** = tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Sin|*in* input:**T**
*out* output:**T**|7+|**T** = tensor(float), tensor(float16)|
|Sinh|*in* input:**T**
*out* output:**T**|9+|**T** = tensor(float), tensor(float16)|
-|Size|*in* data:**T**
*out* size:**T1**|13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
+|Size|*in* data:**T**
*out* size:**T1**|19+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
+|||13+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
|||1+|**T** = seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**T1** = tensor(int64)|
|Slice|*in* data:**T**
*in* starts:**Tind**
*in* ends:**Tind**
*in* axes:**Tind**
*in* steps:**Tind**
*out* output:**T**
or
*in* data:**T**
*out* output:**T**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)|
|||11+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)
**Tind** = tensor(int32), tensor(int64)|
diff --git a/docs/python/conf.py b/docs/python/conf.py
index 065149441b72c..7ab2d42aa15e1 100644
--- a/docs/python/conf.py
+++ b/docs/python/conf.py
@@ -17,7 +17,7 @@
# -- Project information -----------------------------------------------------
project = "Python API"
-copyright = "2018-2023, Microsoft"
+copyright = "2018-2024, Microsoft"
author = "Microsoft"
# -- General configuration ---------------------------------------------------
diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h
index ea4f52f99649d..1de0217c7e1fa 100644
--- a/include/onnxruntime/core/framework/execution_provider.h
+++ b/include/onnxruntime/core/framework/execution_provider.h
@@ -326,6 +326,15 @@ class IExecutionProvider {
*/
virtual std::vector<AllocatorPtr> CreatePreferredAllocators() { return std::vector<AllocatorPtr>(); };
+ /**
+ * Get the array of pointers for EPContext nodes.
+ * An EP needs to implement this only if it generates EP context cache models; otherwise it can be left alone.
+ * By default it returns an empty vector when the execution provider does not provide an implementation.
+ */
+ virtual const InlinedVector<const Node*> GetEpContextNodes() const {
+ return InlinedVector<const Node*>();
+ }
+
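
A minimal sketch of how an execution provider might override this new hook; the class name, the member variable, the include path, and the single-argument base constructor are assumptions for illustration, not part of this change.

```cpp
// Hypothetical EP override of GetEpContextNodes(); names below are illustrative only.
#include "core/framework/execution_provider.h"

namespace onnxruntime {

class MyExecutionProvider : public IExecutionProvider {
 public:
  MyExecutionProvider() : IExecutionProvider("MyExecutionProvider") {}

  const InlinedVector<const Node*> GetEpContextNodes() const override {
    // Nodes produced at compile time that describe the EP context cache model.
    return ep_context_nodes_;
  }

 private:
  InlinedVector<const Node*> ep_context_nodes_;
};

}  // namespace onnxruntime
```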
private:
const std::string type_;
diff --git a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
index 60196d0c80cbb..32a9f06464ace 100644
--- a/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
+++ b/include/onnxruntime/core/providers/tensorrt/tensorrt_provider_options.h
@@ -11,6 +11,8 @@
/// User can only get the instance of OrtTensorRTProviderOptionsV2 via CreateTensorRTProviderOptions.
///
struct OrtTensorRTProviderOptionsV2 {
+ OrtTensorRTProviderOptionsV2& operator=(const OrtTensorRTProviderOptionsV2& other); // copy assignment operator
+
int device_id{0}; // cuda device id.
int has_user_compute_stream{0}; // indicator of user specified CUDA compute stream.
void* user_compute_stream{nullptr}; // user specified CUDA compute stream.
@@ -46,8 +48,26 @@ struct OrtTensorRTProviderOptionsV2 {
const char* trt_profile_max_shapes{nullptr}; // Specify the range of the input shapes to build the engine with
const char* trt_profile_opt_shapes{nullptr}; // Specify the range of the input shapes to build the engine with
int trt_cuda_graph_enable{0}; // Enable CUDA graph in ORT TRT
- int trt_dump_ep_context_model{0}; // Dump EP context node model
- int trt_ep_context_embed_mode{0}; // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data
- int trt_ep_context_compute_capability_enable{1}; // Add GPU compute capability as an EP context node's attribute
- const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix
+
+ /*
+ * Please note that there are rules for using the following context-model-related provider options:
+ *
+ * 1. In the case of dumping the context model and loading the context model,
+ * for security reasons, the TRT EP does not allow the "ep_cache_context" node attribute of an EP context node to be
+ * an absolute path, or a relative path that points outside of the context model directory.
+ * This means the engine cache needs to be in the same directory as the context model, or in a sub-directory of it.
+ *
+ * 2. In the case of dumping the context model, the engine cache path is rewritten to be relative to the context model directory.
+ * For example:
+ * If "trt_dump_ep_context_model" is enabled and "trt_engine_cache_enable" is enabled,
+ * if "trt_ep_context_file_path" is "./context_model_dir",
+ * - if "trt_engine_cache_path" is "" -> the engine cache will be saved to "./context_model_dir"
+ * - if "trt_engine_cache_path" is "engine_dir" -> the engine cache will be saved to "./context_model_dir/engine_dir"
+ *
+ */
+ int trt_dump_ep_context_model{0}; // Dump EP context node model
+ const char* trt_ep_context_file_path{nullptr}; // Specify the file name for dumping the EP context node model. Can be a directory path, a file name, or a file name with a path.
+ int trt_ep_context_embed_mode{0}; // Specify EP context embed mode. Default 0 = context is engine cache path, 1 = context is engine binary data
+
+ const char* trt_engine_cache_prefix{nullptr}; // specify engine cache prefix
};
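
A hedged usage sketch of the options documented above via the C++ API wrapper; the paths and the surrounding session setup are illustrative assumptions, not part of this patch.

```cpp
// Sketch: enabling EP-context model dumping with the TensorRT provider options above.
#include <onnxruntime_cxx_api.h>

void create_session_with_trt_context_cache(const ORTCHAR_T* model_path) {
  OrtTensorRTProviderOptionsV2 trt_options{};
  trt_options.trt_engine_cache_enable = 1;           // cache built engines
  trt_options.trt_dump_ep_context_model = 1;         // dump the EP context node model
  trt_options.trt_ep_context_file_path = "./context_model_dir";
  trt_options.trt_engine_cache_path = "engine_dir";  // relative to the context model dir, per the rules above

  Ort::Env env;
  Ort::SessionOptions so;
  so.AppendExecutionProvider_TensorRT_V2(trt_options);
  Ort::Session session(env, model_path, so);
}
```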
diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h
index aca9f4896fbdb..2ce9d361e8e56 100644
--- a/include/onnxruntime/core/session/onnxruntime_c_api.h
+++ b/include/onnxruntime/core/session/onnxruntime_c_api.h
@@ -496,6 +496,7 @@ typedef struct OrtROCMProviderOptions {
has_user_compute_stream{},
user_compute_stream{},
default_memory_arena_cfg{},
+ enable_hip_graph{false},
tunable_op_enable{false},
tunable_op_tuning_enable{false},
tunable_op_max_tuning_duration_ms{} {}
@@ -548,6 +549,8 @@ typedef struct OrtROCMProviderOptions {
*/
OrtArenaCfg* default_memory_arena_cfg;
+ int enable_hip_graph;
+
/** \brief Enable TunableOp for using.
* Set it to 1/0 to enable/disable TunableOp. Otherwise, it is disabled by default.
* This option can be overriden by environment variable ORT_ROCM_TUNABLE_OP_ENABLE.
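
A hedged sketch of turning on the new enable_hip_graph field through the C++ API; other fields keep their defaults and the setup is illustrative only.

```cpp
// Sketch: creating ROCm session options with HIP graph capture enabled (assumed usage).
#include <onnxruntime_cxx_api.h>

Ort::SessionOptions make_rocm_session_options() {
  OrtROCMProviderOptions rocm_options{};
  rocm_options.device_id = 0;
  rocm_options.enable_hip_graph = 1;  // capture/replay the model as a HIP graph
  Ort::SessionOptions so;
  so.AppendExecutionProvider_ROCM(rocm_options);
  return so;
}
```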
@@ -3608,6 +3611,14 @@ struct OrtApi {
* - "1": Faster preparation time, less optimal graph.
* - "2": Longer preparation time, more optimal graph.
* - "3": Longest preparation time, most likely even more optimal graph. See QNN SDK documentation for specific details.
+ * "soc_model": The SoC model number. Refer to the QNN SDK documentation for valid values. Defaults to "0" (unknown).
+ * "htp_arch": The minimum HTP architecture the driver will use to select compatible QNN operators. Available options:
+ * - "0": Default (none).
+ * - "68"
+ * - "69"
+ * - "73"
+ * - "75"
+ * "device_id": The ID of the device to use when setting 'htp_arch'. Defaults to "0" (for single device).
*
* SNPE supported keys:
* "runtime": SNPE runtime engine, options: "CPU", "CPU_FLOAT32", "GPU", "GPU_FLOAT32_16_HYBRID", "GPU_FLOAT16",
diff --git a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
index df79cb6e5b21b..b282438795eb5 100644
--- a/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
+++ b/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h
@@ -236,7 +236,7 @@ static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersFil
static const char* const kOrtSessionOptionsOptimizedModelExternalInitializersMinSizeInBytes =
"session.optimized_model_external_initializers_min_size_in_bytes";
-// Enable EP context feature to dump the partitioned graph which include the EP context into Onnx file.
+// Enable the EP context feature to dump the partitioned graph, which includes the EP context, into an Onnx file.
// The dumped Onnx model with EP context can be used for future inference to avoid the EP graph partitioning/compile overhead.
// "0": disable. (default)
// "1": enable.
@@ -249,4 +249,10 @@ static const char* const kOrtSessionOptionEpContextFilePath = "ep.context_file_p
// Flag to specify whether to dump the EP context into the Onnx model.
// "0": dump the EP context into separate file, keep the file name in the Onnx model.
// "1": dump the EP context into the Onnx model. (default).
-static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode";
\ No newline at end of file
+static const char* const kOrtSessionOptionEpContextEmbedMode = "ep.context_embed_mode";
+
+// Gemm fastmath mode provides fp32 gemm acceleration with bfloat16 based matmul.
+// Option values:
+// - "0": Gemm FastMath mode is not enabled. [DEFAULT]
+// - "1": Gemm FastMath mode is enabled.
+static const char* const kOrtSessionOptionsMlasGemmFastMathArm64Bfloat16 = "mlas.enable_gemm_fastmath_arm64_bfloat16";
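
A hedged sketch of enabling the new session option from the C++ API; purely illustrative.

```cpp
// Sketch: opt in to the ARM64 bfloat16 Gemm FastMath mode via the key added above.
#include <onnxruntime_cxx_api.h>

Ort::SessionOptions make_fastmath_session_options() {
  Ort::SessionOptions so;
  so.AddConfigEntry("mlas.enable_gemm_fastmath_arm64_bfloat16", "1");
  return so;
}
```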
diff --git a/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c b/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c
index 9f7b8d3a3dcfc..464234c34798a 100644
--- a/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c
+++ b/java/src/main/native/ai_onnxruntime_OrtTrainingSession.c
@@ -66,7 +66,7 @@ JNIEXPORT jlong JNICALL Java_ai_onnxruntime_OrtTrainingSession_createTrainingSes
}
}
wchar_t* optimizerStr = NULL;
- if (optimizerPath == NULL) {
+ if (optimizerPath != NULL) {
optimizerStr = copyAndPad(jniEnv, optimizerPath);
if (optimizerStr == NULL) {
// exception has been thrown in Java, go to cleanup and return null.
diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md
index 2f510308d9306..2557971eb4ded 100644
--- a/js/web/docs/webgpu-operators.md
+++ b/js/web/docs/webgpu-operators.md
@@ -52,6 +52,7 @@ Do not modify directly.*
| GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
| Greater | ai.onnx(7-8,9-12,13+) | |
| GreaterOrEqual | ai.onnx(12-15,16+) | |
+| HardSigmoid | ai.onnx(6+) | |
| If | ai.onnx(1-10,11-12,13-18,19+) | |
| InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | |
| LayerNormalization | ai.onnx(17+) | |
diff --git a/js/web/lib/wasm/binding/ort-wasm.d.ts b/js/web/lib/wasm/binding/ort-wasm.d.ts
index 9d4d5875310b7..68054210e79a7 100644
--- a/js/web/lib/wasm/binding/ort-wasm.d.ts
+++ b/js/web/lib/wasm/binding/ort-wasm.d.ts
@@ -182,6 +182,10 @@ export interface OrtWasmModule extends EmscriptenModule {
jsepCreateDownloader:
(gpuBuffer: GPUBuffer, size: number,
type: Tensor.GpuBufferDataTypes) => () => Promise;
+ /**
+ * [exported from js_internal_api.js] Called when InferenceSession.run started.
+ */
+ jsepOnRunStart: () => void;
// #endregion
}
diff --git a/js/web/lib/wasm/jsep/backend-webgpu.ts b/js/web/lib/wasm/jsep/backend-webgpu.ts
index 2956ec1cad4da..afef7042a4280 100644
--- a/js/web/lib/wasm/jsep/backend-webgpu.ts
+++ b/js/web/lib/wasm/jsep/backend-webgpu.ts
@@ -208,7 +208,7 @@ export class WebGpuBackend {
Object.defineProperty(this.env.webgpu, 'device', {value: this.device});
- // init queryType, which is necessary for createKernel
+ // init queryType, which is necessary for InferenceSession.create
this.setQueryType();
}
@@ -223,8 +223,6 @@ export class WebGpuBackend {
if (!this.commandEncoder) {
this.commandEncoder = this.device.createCommandEncoder();
- // refresh queryType, as sometimes we only need to enable query for a specific run
- this.setQueryType();
if (this.queryType !== 'none' && typeof this.querySet === 'undefined') {
this.querySet = this.device.createQuerySet({
type: 'timestamp',
@@ -639,6 +637,7 @@ export class WebGpuBackend {
return createView(data.buffer, type);
};
}
+ // #endregion
writeTimestamp(index: number): void {
if (this.queryType !== 'inside-passes') {
return;
@@ -657,5 +656,7 @@ export class WebGpuBackend {
}
}
}
- // #endregion
+ onRunStart(): void {
+ this.setQueryType();
+ }
}
diff --git a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
index 90e02da986b8f..cc504093ca0d7 100644
--- a/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
+++ b/js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
@@ -82,6 +82,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map = new
['GlobalMaxPool', [pool.globalMaxPool, pool.parseGlobalMaxPoolAttributes]],
['Greater', [binaryOps.greater]],
['GreaterOrEqual', [binaryOps.greaterOrEqual]],
+ ['HardSigmoid', [unaryOps.hardSigmoid, unaryOps.parseHardSigmoidAttributes]],
['InstanceNormalization', [instanceNorm]],
['LayerNormalization', [layerNorm]],
['LeakyRelu', [unaryOps.leakyRelu, unaryOps.parseAlphaAttributes]],
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
index 2b390e4e633c5..76929efb32537 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/unary-op.ts
@@ -242,6 +242,26 @@ export const sigmoid = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sigmoid', a => `(1.0 / (1.0 + exp(-${a})))`));
};
+export interface HardSigmoidAttributes extends AttributeWithCacheKey {
+ readonly alpha: number;
+ readonly beta: number;
+}
+
+export const parseHardSigmoidAttributes = (attributes: Record<string, unknown>): HardSigmoidAttributes =>
+ createAttributeWithCacheKey(attributes as {
+ alpha: number;
+ beta: number;
+ });
+
+export const hardSigmoid = (context: ComputeContext, attributes: HardSigmoidAttributes): void => {
+ const dataType = tensorTypeToWsglValueType(context.inputs[0].dataType);
+ context.compute(createElementwiseProgramInfo(
+ context.inputs[0], 'HardSigmoid',
+ a => `max(vec4<${dataType}>(0.0), min(vec4<${dataType}>(1.0), ${attributes.alpha} * ${a} + vec4<${dataType}>(${
+ attributes.beta})))`,
+ undefined, attributes.cacheKey));
+};
+
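
A scalar reference for the shader expression above, max(0, min(1, alpha * x + beta)); the default alpha/beta follow the ONNX HardSigmoid spec and the sample inputs are arbitrary.

```cpp
// Reference HardSigmoid, matching the WGSL expression element-wise.
#include <algorithm>
#include <cstdio>

float hard_sigmoid(float x, float alpha = 0.2f, float beta = 0.5f) {
  return std::max(0.0f, std::min(1.0f, alpha * x + beta));
}

int main() {
  std::printf("%f %f %f\n", hard_sigmoid(-4.0f), hard_sigmoid(0.0f), hard_sigmoid(4.0f));
  // prints 0.000000 0.500000 1.000000
  return 0;
}
```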
export const sin = (context: ComputeContext): void => {
context.compute(createElementwiseProgramInfo(context.inputs[0], 'Sin', 'sin'));
};
diff --git a/js/web/lib/wasm/wasm-core-impl.ts b/js/web/lib/wasm/wasm-core-impl.ts
index 5821fac3c468f..8768643fa7257 100644
--- a/js/web/lib/wasm/wasm-core-impl.ts
+++ b/js/web/lib/wasm/wasm-core-impl.ts
@@ -488,8 +488,8 @@ export const run = async(
}
}
+ wasm.jsepOnRunStart?.();
let errorCode: number;
-
if (!BUILD_DEFS.DISABLE_WEBGPU && ioBindingState) {
errorCode = await wasm._OrtRunWithBinding(
sessionHandle, ioBindingState.handle, outputCount, outputValuesOffset, runOptionsHandle);
diff --git a/js/web/test/data/ops/fused-conv.jsonc b/js/web/test/data/ops/fused-conv.jsonc
index 812e9d7c2def0..ad1c0a72c11d3 100644
--- a/js/web/test/data/ops/fused-conv.jsonc
+++ b/js/web/test/data/ops/fused-conv.jsonc
@@ -108,5 +108,39 @@
]
}
]
+ },
+ {
+ "name": "fused conv with clip",
+ "operator": "FusedConv",
+ "attributes": [
+ { "name": "activation", "data": "Clip", "type": "string" },
+ { "name": "kernel_shape", "data": [2, 2], "type": "ints" },
+ { "name": "activation_params", "data": [400.0, 600.0], "type": "floats" }
+ ],
+ "opset": { "domain": "com.microsoft", "version": 1 },
+ "cases": [
+ {
+ "name": "T[0]",
+ "inputs": [
+ {
+ "data": [10, 20, 30, 40, 50, 60, 70, 80, 90],
+ "dims": [1, 1, 3, 3],
+ "type": "float32"
+ },
+ {
+ "data": [1, 2, 3, 4],
+ "dims": [1, 1, 2, 2],
+ "type": "float32"
+ }
+ ],
+ "outputs": [
+ {
+ "data": [400, 470, 600, 600],
+ "dims": [1, 1, 2, 2],
+ "type": "float32"
+ }
+ ]
+ }
+ ]
}
]
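
A hand check of the expected output in the new test case: a 2x2 valid convolution of the 3x3 input with the [1, 2, 3, 4] kernel gives [370, 470, 670, 770], which the Clip activation bounds to [400, 600]. A small sketch that reproduces it:

```cpp
// Reproduces the expected FusedConv+Clip output: 400 470 600 600.
#include <algorithm>
#include <cstdio>

int main() {
  const float x[3][3] = {{10, 20, 30}, {40, 50, 60}, {70, 80, 90}};
  const float w[2][2] = {{1, 2}, {3, 4}};
  for (int i = 0; i < 2; ++i) {
    for (int j = 0; j < 2; ++j) {
      float acc = 0.0f;
      for (int ki = 0; ki < 2; ++ki)
        for (int kj = 0; kj < 2; ++kj)
          acc += x[i + ki][j + kj] * w[ki][kj];
      std::printf("%g ", std::clamp(acc, 400.0f, 600.0f));  // Clip to [400, 600]
    }
  }
  std::printf("\n");
  return 0;
}
```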
diff --git a/js/web/test/suite-test-list.jsonc b/js/web/test/suite-test-list.jsonc
index 83fe299080ed6..56db28b0a379c 100644
--- a/js/web/test/suite-test-list.jsonc
+++ b/js/web/test/suite-test-list.jsonc
@@ -597,9 +597,9 @@
// // "test_hardmax_example",
// // "test_hardmax_negative_axis",
// // "test_hardmax_one_hot",
- // // "test_hardsigmoid_default",
- // // "test_hardsigmoid_example",
- // // "test_hardsigmoid",
+ "test_hardsigmoid_default",
+ "test_hardsigmoid_example",
+ "test_hardsigmoid",
// // "test_hardswish_expanded",
// // "test_hardswish",
"test_if",
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc
index 4711ccf487cc8..768676259aa14 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc
@@ -211,6 +211,12 @@ Status Attention::Compute(OpKernelContext* context) const {
relative_position_bias,
¶meters));
+ if (parameters.do_rotary) {
+ ORT_NOT_IMPLEMENTED(
+ "Rotary embedding is not supported in Attention CPU kernel. \
+ Please fuse the model with MHA + RotaryEmbedding.");
+ }
+
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int input_hidden_size = parameters.input_hidden_size;
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h
index da489a6901512..8afeb874750b4 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -99,10 +99,15 @@ struct GroupQueryAttentionParameters {
bool is_unidirectional; // causal
int local_window_size;
bool kv_share_buffer;
+ bool is_packed_qkv;
bool is_prompt; // determines if seqlens_k is past or kv sequence length tensor
+ bool do_rotary;
+ bool rotary_interleaved;
float scale;
AttentionQkvFormat qkv_format;
AttentionQkvFormat past_kv_format;
+ int zeros_count;
+ int* zero_ptr;
};
namespace attention {
diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
index 694c40bf3eda6..eb25d0fd7cc1e 100644
--- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
@@ -40,6 +40,7 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) : OpKernel(i
num_heads_ = static_cast<int>(num_heads);
mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f);
+ is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1;
}
// Reshape Q/K/V from BxSxD to BxSxNxH
@@ -283,8 +284,9 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const {
nullptr,
¶meters,
num_heads_,
- scale,
mask_filter_value_,
+ scale,
+ is_unidirectional_,
past_present_share_buffer,
false));
diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h
index 4c86b777e9842..fb7da78a5c0a5 100644
--- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h
+++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h
@@ -18,6 +18,7 @@ class MultiHeadAttention final : public OpKernel, public AttentionCPUBase {
protected:
int num_heads_; // number of attention heads
float mask_filter_value_;
+ bool is_unidirectional_;
};
} // namespace contrib
diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
index 00e82c9844b3d..c91f5b601b4e9 100644
--- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
+++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
@@ -25,6 +25,7 @@ Status CheckInputs(const T* query,
int num_heads,
float mask_filter_value,
float scale,
+ bool is_unidirectional,
bool past_present_share_buffer,
bool dmmha_packing) {
// key_padding_mask (K/V) : (B) or (2*B + 1) or (B, L) or None
@@ -315,7 +316,7 @@ Status CheckInputs(const T* query,
output_parameters->head_size = hidden_size / num_heads;
output_parameters->v_head_size = v_hidden_size / num_heads;
output_parameters->num_heads = num_heads;
- output_parameters->is_unidirectional = false;
+ output_parameters->is_unidirectional = is_unidirectional;
output_parameters->past_present_share_buffer = past_present_share_buffer;
output_parameters->mask_filter_value = mask_filter_value;
output_parameters->mask_type = mask_type;
@@ -342,6 +343,7 @@ Status CheckInputs(const T* query,
int num_heads,
float mask_filter_value,
float scale,
+ bool is_unidirectional,
bool past_present_share_buffer,
bool dmmha_packing,
int max_threads_per_block) {
@@ -350,8 +352,8 @@ Status CheckInputs(const T* query,
}
return CheckInputs(query, key, value, bias, key_padding_mask, relative_position_bias, past_key, past_value,
- past_seq_len, parameters, num_heads, mask_filter_value, scale, past_present_share_buffer,
- dmmha_packing);
+ past_seq_len, parameters, num_heads, mask_filter_value, scale, is_unidirectional,
+ past_present_share_buffer, dmmha_packing);
}
} // namespace multihead_attention_helper
diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
index 47f462d75fcc4..aa8b5b5f608fa 100644
--- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
@@ -27,7 +27,13 @@ ONNX_OPERATOR_TYPED_KERNEL_EX(
 template <typename T>
 RotaryEmbedding<T>::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) {
   scale = info.GetAttrOrDefault<float>("scale", 1.0);
+  rotary_embedding_dim = static_cast<int>(info.GetAttrOrDefault<int64_t>("rotary_embedding_dim", 0));
+  num_heads = static_cast<int>(info.GetAttrOrDefault<int64_t>("num_heads", 0));
   interleaved = (info.GetAttrOrDefault<int64_t>("interleaved", 0) == 1);
+
+ if (rotary_embedding_dim > 0) {
+ ORT_ENFORCE(num_heads > 0, "num_heads must be provided if rotary_embedding_dim is specified");
+ }
}
 template <typename T>
@@ -42,6 +48,8 @@ Status RotaryEmbedding::Compute(OpKernelContext* context) const {
position_ids,
cos_cache,
sin_cache,
+ num_heads,
+ rotary_embedding_dim,
                                                                   &parameters));
Tensor* output = context->Output(0, input->Shape());
@@ -59,61 +67,66 @@ Status RotaryEmbedding::Compute(OpKernelContext* context) const {
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
- const int num_heads = parameters.num_heads;
+ const int n_heads = parameters.num_heads;
const int head_size = parameters.head_size;
const int position_ids_format = parameters.position_ids_format;
- const int half_head_size = head_size / 2;
+ const int rotary_emb_dim = parameters.rotary_embedding_dim;
+ const int half_rotary_emb_dim = rotary_emb_dim / 2;
+
// Default input tensor shape is [batch, seq_len, hidden_size]
int head_stride = head_size;
- int seq_stride = num_heads * head_stride;
+ int seq_stride = n_heads * head_stride;
int batch_stride = sequence_length * seq_stride;
if (parameters.transposed) {
- // Transposed input tensor shape is [batch, num_heads, seq_len, head_size]
+ // Transposed input tensor shape is [batch, n_heads, seq_len, head_size]
seq_stride = head_size;
head_stride = sequence_length * seq_stride;
- batch_stride = num_heads * head_stride;
+ batch_stride = n_heads * head_stride;
}
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
auto* tp = context->GetOperatorThreadPool();
- const int loop_len = batch_size * sequence_length * num_heads;
-  const double cost = static_cast<double>(head_size);
+  const int loop_len = batch_size * sequence_length * n_heads;
+  const double cost = static_cast<double>(rotary_emb_dim);
ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) {
-      const int b = static_cast<int>((ptr / num_heads) / sequence_length);
-      const int s = static_cast<int>((ptr / num_heads) % sequence_length);
-      const int n = static_cast<int>(ptr % num_heads);
+      const int b = static_cast<int>((ptr / n_heads) / sequence_length);
+      const int s = static_cast<int>((ptr / n_heads) % sequence_length);
+      const int n = static_cast<int>(ptr % n_heads);
const int block_offset = b * batch_stride + s * seq_stride + n * head_stride;
const T* input_data = input_src + block_offset;
T* output_data = output_dest + block_offset;
- // Cache is (M, H/2)
+ // Cache is (M, H/2) or (M, rotary_embedding_dim/2)
const int position_id = (position_ids_format == 0)
                                   ? static_cast<int>(pos_ids_data[0]) + s
                                   : static_cast<int>(pos_ids_data[b * sequence_length + s]);
- const int cache_offset = position_id * half_head_size;
+ const int cache_offset = position_id * half_rotary_emb_dim;
const T* cos_data = cos_cache_data + cache_offset;
const T* sin_data = sin_cache_data + cache_offset;
int cache_idx = 0;
T sign = 0;
int j = 0;
- for (int i = 0; i < head_size; i++) {
+ for (int i = 0; i < rotary_emb_dim; i++) {
if (interleaved) {
- cache_idx = (i / 2) % half_head_size;
+ cache_idx = (i / 2) % half_rotary_emb_dim;
           sign = (i % 2 == 0) ? static_cast<T>(-1) : static_cast<T>(1);
j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign
} else {
- cache_idx = i % half_head_size;
-        sign = (i < half_head_size) ? static_cast<T>(-1) : static_cast<T>(1);
- j = (i + half_head_size) % head_size;
+ cache_idx = i % half_rotary_emb_dim;
+        sign = (i < half_rotary_emb_dim) ? static_cast<T>(-1) : static_cast<T>(1);
+ j = (i + half_rotary_emb_dim) % rotary_emb_dim;
}
output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx];
}
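+      // Dimensions beyond rotary_embedding_dim are copied through unrotated (partial rotary embedding).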
+ for (int i = rotary_emb_dim; i < head_size; i++) {
+ output_data[i] = input_data[i];
+ }
}
});
diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h
index be834a66cdc69..4e32424a22b6c 100644
--- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h
+++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h
@@ -16,6 +16,8 @@ class RotaryEmbedding final : public OpKernel {
protected:
float scale;
+ int num_heads;
+ int rotary_embedding_dim;
bool interleaved;
};
diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h
index 7b2e8289f7b06..dcbb36d1c4a3c 100644
--- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h
+++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h
@@ -11,14 +11,15 @@ namespace rotary_embedding_helper {
// Parameters deduced from node attributes and inputs/outputs.
struct RotaryParameters {
- int batch_size; // Batch size used by input
- int sequence_length; // Sequence length used by input
- int hidden_size; // Hidden size used by input
- int head_size; // Head size used by cos/sin cache * 2
- int num_heads; // num_heads = hidden_size / head_size
- int max_sequence_length; // Sequence length used by cos/sin cache
- int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length)
- bool transposed; // Whether the input tensor has been transposed into (batch, num_heads, seq_len, hidden)
+ int batch_size; // Batch size used by input
+ int sequence_length; // Sequence length used by input
+ int hidden_size; // Hidden size used by input
+ int head_size; // Head size
+ int rotary_embedding_dim; // Rotary embedding dimension.
+ int num_heads; // num_heads = hidden_size / head_size
+ int max_sequence_length; // Sequence length used by cos/sin cache
+ int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length)
+ bool transposed; // Whether the input tensor has been transposed into (batch, num_heads, seq_len, hidden)
};
 template <typename T>
@@ -26,11 +27,13 @@ Status CheckInputs(const T* input,
const T* position_ids,
const T* cos_cache,
const T* sin_cache,
+ int num_heads,
+ int rotary_embedding_dim,
void* parameters) {
// input : (batch_size, sequence_length, hidden_size)
// position ids : (1) or (batch_size, sequence_length)
- // cos cache : (max_sequence_length, head_size / 2)
- // sin cache : (max_sequence_length, head_size / 2)
+ // cos cache : (max_sequence_length, rotary_embedding_dim / 2)
+ // sin cache : (max_sequence_length, rotary_embedding_dim / 2)
// Check input
const auto& input_dims = input->Shape().GetDims();
@@ -60,6 +63,12 @@ Status CheckInputs(const T* input,
"the same shape");
}
+ // Check num_heads and rotary_embedding_dim
+ if (rotary_embedding_dim > 0 && num_heads == 0) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads must be provided if rotary_embedding_dim is ",
+ "specified");
+ }
+
// Get attributes from inputs
   int batch_size = static_cast<int>(input_dims[0]);
   int sequence_length = static_cast<int>(input_dims[1]);
@@ -73,8 +82,13 @@ Status CheckInputs(const T* input,
transposed = true;
}
   int max_sequence_length = static_cast<int>(cos_cache_dims[0]);
-  int head_size = static_cast<int>(cos_cache_dims[1]) * 2;
-  int num_heads = hidden_size / head_size;
+  int head_size = rotary_embedding_dim == 0 ? static_cast<int>(cos_cache_dims[1]) * 2
+                                            : static_cast<int>(hidden_size / num_heads);
+ if (rotary_embedding_dim > 0 && rotary_embedding_dim > head_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "rotary_embedding_dim must be less than or equal to ",
+ "head_size");
+ }
+
int position_ids_format = -1;
// Check position_ids input shapes
@@ -91,23 +105,15 @@ Status CheckInputs(const T* input,
} else {
position_ids_format = 0;
}
+
// Check cos_cache input shapes
   if (max_sequence_length != static_cast<int>(cos_cache_dims[0])) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 0 should be same as ",
"max_sequence_length, got ", cos_cache_dims[0]);
}
-  if ((head_size / 2) != static_cast<int>(cos_cache_dims[1])) {
+  if ((head_size / 2) != static_cast<int>(cos_cache_dims[1]) && (rotary_embedding_dim > 0 && (rotary_embedding_dim / 2) != static_cast<int>(cos_cache_dims[1]))) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 1 should be same as ",
- "head_size / 2, got ", cos_cache_dims[1]);
- }
- // Check sin_cache input shapes
-  if (max_sequence_length != static_cast<int>(sin_cache_dims[0])) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' dimension 0 should be same as ",
- "max_sequence_length, got ", sin_cache_dims[0]);
- }
-  if ((head_size / 2) != static_cast<int>(sin_cache_dims[1])) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' dimension 1 should be same as ",
- "head_size / 2, got ", sin_cache_dims[1]);
+ "head_size / 2 or rotary_embedding_dim / 2, got ", cos_cache_dims[1]);
}
// Set rotary parameters
@@ -117,10 +123,11 @@ Status CheckInputs(const T* input,
output_parameters->sequence_length = sequence_length;
output_parameters->hidden_size = hidden_size;
output_parameters->head_size = head_size;
- output_parameters->num_heads = num_heads;
+  output_parameters->num_heads = num_heads > 0 ? num_heads : static_cast<int>(hidden_size / head_size);
output_parameters->max_sequence_length = max_sequence_length;
output_parameters->position_ids_format = position_ids_format;
output_parameters->transposed = transposed;
+ output_parameters->rotary_embedding_dim = rotary_embedding_dim > 0 ? rotary_embedding_dim : head_size;
}
return Status::OK();
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h
index 56d950ca2f41e..dc72a038c3d58 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h
@@ -397,12 +397,8 @@ Status BeamSearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fetch
output_sequences_scores);
// Output per token scores
- if (output_scores) {
-    gsl::span<float> target = output_scores->MutableDataAsSpan<float>();
-    gsl::span<const float> source = beam_state.scores;
- assert(target.size() == source.size());
- ORT_RETURN_IF_ERROR(this->device_copy_func_(target, source, nullptr, DeviceCopyDirection::deviceToDevice));
- }
+  gsl::span<const float> per_token_scores = beam_state.scores;
+ this->beam_scorer_->OutputScores(per_token_scores, output_scores);
return status;
}
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h
index 94547887d3a90..cd891a9508019 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h
@@ -404,12 +404,8 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches
output_sequences_scores);
// Output per token scores
- if (output_scores) {
-    gsl::span<float> target = output_scores->MutableDataAsSpan<float>();
-    gsl::span<const float> source = beam_state.scores;
- assert(target.size() == source.size());
- ORT_RETURN_IF_ERROR(this->device_copy_func_(target, source, nullptr, DeviceCopyDirection::deviceToDevice));
- }
+  gsl::span<const float> per_token_scores = beam_state.scores;
+ this->beam_scorer_->OutputScores(per_token_scores, output_scores);
return status;
}
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h
index 91b93a125ad7a..4d6643c68a98b 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_whisper.h
@@ -500,12 +500,8 @@ Status BeamSearchWhisper::Execute(const FeedsFetchesManager& encoder_feeds_fe
output_sequences_scores);
// Output per token scores
- if (output_scores) {
-    gsl::span<float> target = output_scores->MutableDataAsSpan<float>();
-    gsl::span<const float> source = beam_state.scores;
- assert(target.size() == source.size());
- ORT_RETURN_IF_ERROR(this->device_copy_func_(target, source, nullptr, DeviceCopyDirection::deviceToDevice));
- }
+  gsl::span<const float> per_token_scores = beam_state.scores;
+ this->beam_scorer_->OutputScores(per_token_scores, output_scores);
return status;
}
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc
index 3962486d5b5eb..bb6885c3216bc 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc
@@ -123,8 +123,20 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
   logits_processor = logits_processor_tensor ? static_cast<int>(*logits_processor_tensor->Data<int32_t>()) : 0;
ORT_ENFORCE(logits_processor >= 0,
"logits_processor shall be a non-negative integer, got ", logits_processor);
-}
+ if (this->model_type == IGenerationParameters::kModelTypeWhisper) {
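+    // Whisper exposes temperature as an optional input (index 14) that may be float or MLFloat16; default to 1.0f when absent.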
+    auto* temperature_tensor = context->Input<Tensor>(14);
+    if (temperature_tensor) {
+      if (temperature_tensor->IsDataType<float>()) {
+        temperature = *temperature_tensor->Data<float>();
+      } else {
+        temperature = static_cast<float>(*temperature_tensor->Data<MLFloat16>());
+      }
+ } else {
+ temperature = 1.0f;
+ }
+ }
+}
void BeamSearchParameters::SetSubgraphParameters(int vocabulary_size, int heads, int hidden_size_per_head, int layers) {
// Override vocab_size using the inferred shape from the decoder subgraph ONLY IF
// the vocab_size hasn't been explicitly specified by the user (as an attribute of BeamSearch)
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc
index 7e2e5b2129221..0eccbe26605f5 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc
@@ -50,11 +50,12 @@ bool BeamHypotheses::CanImprove(float best_sum_logprobs, int current_length) con
return beams_.back().score < current_score;
}
+template <typename T>
void BeamHypotheses::Output(
int top_k,
int max_length,
-    gsl::span<int32_t>& sequences,       // buffer filled with pad token ID, shape (num_return_sequences, max_length)
-    gsl::span<float>& sequences_scores)  // buffer of shape (num_return_sequences) or empty
+    gsl::span<int32_t>& sequences,       // buffer filled with pad token ID, shape (num_return_sequences, max_length)
+    gsl::span<T>& sequences_scores)      // buffer of shape (num_return_sequences) or empty
{
// Copy the top_k beams into the sequences
ORT_ENFORCE(top_k <= beams_used_);
@@ -67,7 +68,7 @@ void BeamHypotheses::Output(
gsl::copy(item.hypothesis, target);
if (!sequences_scores.empty())
- sequences_scores[index] = item.score;
+ sequences_scores[index] = (T)item.score;
}
}
@@ -181,21 +182,21 @@ void BeamSearchScorer::Process(ISequences& sequences,
}
}
-void BeamSearchScorer::Finalize(ISequences& sequences,
-                                gsl::span<float>& final_beam_scores,
- Tensor* output_sequences,
- Tensor* output_sequence_scores) {
- ORT_ENFORCE(output_sequences != nullptr);
-
+template <typename T>
+void OutputSequenceScores(BeamSearchScorer* scorer,
+                          ISequences& sequences,
+                          gsl::span<float>& final_beam_scores,
+ Tensor* output_sequences,
+ Tensor* output_sequence_scores) {
// Finalize all open beam hypotheses and add to generated hypotheses.
- for (size_t batch_index = 0; batch_index < batch_size_; batch_index++) {
- BeamHypotheses& beam_hyp = beam_hyps_[batch_index];
+ for (size_t batch_index = 0; batch_index < scorer->batch_size_; batch_index++) {
+ BeamHypotheses& beam_hyp = scorer->beam_hyps_[batch_index];
if (beam_hyp.done_) {
continue;
}
- for (size_t beam_index = 0; beam_index < num_beams_; beam_index++) {
- size_t batch_beam_index = batch_index * num_beams_ + beam_index;
+ for (size_t beam_index = 0; beam_index < scorer->num_beams_; beam_index++) {
+ size_t batch_beam_index = batch_index * scorer->num_beams_ + beam_index;
float final_score = final_beam_scores[batch_beam_index];
       auto final_tokens = sequences.GetSequence(narrow<int>(batch_beam_index));
beam_hyp.Add(final_tokens, final_score);
@@ -206,26 +207,59 @@ void BeamSearchScorer::Finalize(ISequences& sequences,
   gsl::span<int32_t> output = output_sequences->MutableDataAsSpan<int32_t>();
// Fill output sequences with pad token ID so that we do not need append it later.
- std::fill_n(output.data(), output.size(), pad_token_id_);
+ std::fill_n(output.data(), output.size(), scorer->pad_token_id_);
// Score of each sequence, with shape (batch_size * num_return_sequences).
-  gsl::span<float> sequence_scores;
+  gsl::span<T> sequence_scores;
   if (output_sequence_scores) {
-    sequence_scores = output_sequence_scores->MutableDataAsSpan<float>();
+    sequence_scores = output_sequence_scores->MutableDataAsSpan<T>();
}
// Select the best hypotheses according to number of sequences to return.
- for (size_t batch_index = 0; batch_index < batch_size_; batch_index++) {
- BeamHypotheses& beam_hyp = beam_hyps_[batch_index];
+ for (size_t batch_index = 0; batch_index < scorer->batch_size_; batch_index++) {
+ BeamHypotheses& beam_hyp = scorer->beam_hyps_[batch_index];
- auto batch_output = output.subspan(batch_index * num_return_sequences_ * max_length_,
- num_return_sequences_ * max_length_);
-    gsl::span<float> sequence_scores_buffer;
+ auto batch_output = output.subspan(batch_index * scorer->num_return_sequences_ * scorer->max_length_,
+ scorer->num_return_sequences_ * scorer->max_length_);
+    gsl::span<T> sequence_scores_buffer;
if (!sequence_scores.empty())
- sequence_scores_buffer = sequence_scores.subspan(batch_index * num_return_sequences_, num_return_sequences_);
+ sequence_scores_buffer = sequence_scores.subspan(batch_index * scorer->num_return_sequences_, scorer->num_return_sequences_);
+
+    beam_hyp.template Output<T>(narrow<int>(scorer->num_return_sequences_), narrow<int>(scorer->max_length_), batch_output,
+ sequence_scores_buffer);
+ }
+}
+
+void BeamSearchScorer::Finalize(ISequences& sequences,
+                                gsl::span<float>& final_beam_scores,
+ Tensor* output_sequences,
+ Tensor* output_sequence_scores) {
+ ORT_ENFORCE(output_sequences != nullptr);
-    beam_hyp.Output(narrow<int>(num_return_sequences_), narrow<int>(max_length_), batch_output,
- sequence_scores_buffer);
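+  // Dispatch on the requested output type: sequence scores can be written as float directly or as MLFloat16 via the templated helper.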
+  if (output_sequence_scores == nullptr || output_sequence_scores->IsDataType<float>()) {
+    OutputSequenceScores<float>(this, sequences, final_beam_scores, output_sequences, output_sequence_scores);
+  } else {
+    ORT_ENFORCE(output_sequence_scores->IsDataType<MLFloat16>());
+    OutputSequenceScores<MLFloat16>(this, sequences, final_beam_scores, output_sequences, output_sequence_scores);
+ }
+}
+
+void BeamSearchScorer::OutputScores(gsl::span<const float>& final_scores, Tensor* output_scores) {
+  if (output_scores) {
+    if (output_scores->IsDataType<float>()) {
+      gsl::span<float> target = output_scores->MutableDataAsSpan<float>();
+      ORT_ENFORCE(target.size() == final_scores.size());
+      std::copy_n(final_scores.data(), final_scores.size(), target.data());
+    } else {
+      ORT_ENFORCE(output_scores->IsDataType<MLFloat16>());
+      gsl::span<MLFloat16> target = output_scores->MutableDataAsSpan<MLFloat16>();
+ ORT_ENFORCE(target.size() == final_scores.size());
+ const float* src = final_scores.data();
+ MLFloat16* dst = target.data();
+ for (size_t i = 0; i < target.size(); i++) {
+ dst[i] = MLFloat16(src[i]);
+ }
+ }
}
}
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h
index 94b6d340d9f4a..dc92e8038a68e 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h
@@ -35,10 +35,11 @@ struct BeamHypotheses {
bool CanImprove(float best_sum_logprobs, int current_length) const;
// Output results
-  void Output(int top_k,                            // number of sequences to return
-              int max_length,                       // max sequence length
-              gsl::span<int32_t>& sequences,        // buffer with pad token, shape (num_return_sequences, max_length)
-              gsl::span<float>& sequences_scores);  // buffer for sequence scores, with shape (num_return_sequences)
+  template <typename T>
+  void Output(int top_k,                        // number of sequences to return
+              int max_length,                   // max sequence length
+              gsl::span<int32_t>& sequences,    // buffer with pad token, shape (num_return_sequences, max_length)
+              gsl::span<T>& sequences_scores);  // buffer for sequence scores, with shape (num_return_sequences)
   gsl::span<HypothesisScore> beams_;  // Beam width sized array of hypotheses, sorted by highest scoring
int beams_used_; // Number of elements used in beams_
@@ -60,13 +61,14 @@ struct BeamSearchScorer : IBeamScorer {
Tensor* output_sequences,
Tensor* output_sequence_scores) override;
+  void OutputScores(gsl::span<const float>& final_scores, Tensor* output_scores) override;
+
bool IsDone() const override { return not_done_count_ == 0; }
   gsl::span<float> GetNextScores() override { return next_beam_scores_; }
   gsl::span<int32_t> GetNextTokens() override { return next_beam_tokens_; }
   gsl::span<int32_t> GetNextIndicesCPU() override { return next_beam_indices_; }
- private:
size_t batch_size_;
size_t num_beams_;
size_t max_length_;
diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
index f6faf2e325f8f..cb62e2f7bf4da 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
@@ -120,6 +120,9 @@ struct IBeamScorer {
Tensor* output_sequences,
Tensor* output_sequence_scores) = 0;
+  virtual void OutputScores(gsl::span<const float>& final_scores,
+ Tensor* output_scores) = 0;
+
virtual bool IsDone() const = 0; // GPU version will return false here, as it asynchronously queues up the event
virtual bool IsDoneLater() const { return false; } // GPU version waits for the asynchous result to complete here
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc
index d6eb87228bb4a..2c296bf4f8483 100644
--- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.cc
@@ -355,13 +355,15 @@ bool is_supported(const cudaDeviceProp& dprops, int head_size, int num_heads, in
Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
cudaStream_t stream,
void* q, // batch_size x seqlen_q x num_heads x head_size
- void* kcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x head_size
- void* vcache, // batch_size x seqlen_k x num_heads_k x head_size or batch_size x num_heads_k seqlen_k x head_size
- void* k, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
- void* v, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
+ void* kcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size
+ void* vcache, // batch_size x seqlen_k_max x num_heads_k x head_size or batch_size x num_heads_k seqlen_k_max x head_size
+ void* k_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
+ void* v_new, // (optional) batch_size x seqlen_k_new x num_heads_k x head_size
void* out, // batch_size x seqlen_q x num_heads x head_size
void* softmax_lse, // batch_size x num_heads x seqlen_q
void* seqlens_k_, // batch_size
+ void* rotary_cos, // seqlen_ro x (rotary_dim / 2)
+ void* rotary_sin, // seqlen_ro x (rotary_dim / 2)
int batch_size,
int num_heads,
int num_heads_k,
@@ -376,16 +378,15 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
int num_splits,
void* softmax_lse_accum, // num_splits x batch_size x seqlen_q x num_heads
void* out_accum, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
- int local_window_size) {
- // if (seqlen_q == 1) {
- // is_causal = false;
- // } // causal=true is the same as causal=false in this case
-
+ int local_window_size,
+ bool is_rotary_interleaved,
+ bool is_packed_qkv) {
auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
const int head_size_rounded = round_multiple(head_size, 32);
const int seqlen_q_rounded = round_multiple(seqlen_q, 128);
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);
+  // In the kv-cache case, seqlen_k_max is used as the kv sequence length.
Flash_fwd_params params;
set_params_fprop(params,
batch_size,
@@ -406,15 +407,24 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
is_causal ? 0 : -1);
params.dprops = &dprops;
- if (k != nullptr && v != nullptr) {
+ if (k_new != nullptr && v_new != nullptr) {
params.seqlen_knew = seqlen_k_new;
- params.knew_ptr = k;
- params.vnew_ptr = v;
+ params.knew_ptr = k_new;
+ params.vnew_ptr = v_new;
// All stride are in elements, not bytes.
- params.knew_batch_stride = seqlen_k_new * num_heads_k * head_size;
- params.vnew_batch_stride = seqlen_k_new * num_heads_k * head_size;
- params.knew_row_stride = num_heads_k * head_size;
- params.vnew_row_stride = num_heads_k * head_size;
+ if (is_packed_qkv) {
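+      // Packed QKV layout: each token row holds Q (num_heads * head_size) followed by new K and V (num_heads_k * head_size each), so the strides span all three.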
+ params.q_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size);
+ params.q_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size);
+ params.knew_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size);
+ params.vnew_batch_stride = (seqlen_q * num_heads * head_size) + (2 * seqlen_k_new * num_heads_k * head_size);
+ params.knew_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size);
+ params.vnew_row_stride = (num_heads * head_size) + (2 * num_heads_k * head_size);
+ } else {
+ params.knew_batch_stride = seqlen_k_new * num_heads_k * head_size;
+ params.vnew_batch_stride = seqlen_k_new * num_heads_k * head_size;
+ params.knew_row_stride = num_heads_k * head_size;
+ params.vnew_row_stride = num_heads_k * head_size;
+ }
params.knew_head_stride = head_size;
params.vnew_head_stride = head_size;
} else {
@@ -434,6 +444,13 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
     params.cu_seqlens_k = static_cast<int*>(seqlens_k_);
}
+ if (rotary_cos != nullptr) {
+ params.rotary_cos_ptr = rotary_cos;
+ params.rotary_sin_ptr = rotary_sin;
+ params.is_rotary_interleaved = is_rotary_interleaved;
+ params.rotary_dim = (head_size / 16) * 16;
+ }
+
params.num_splits = num_splits;
if (params.num_splits > 1 && softmax_lse_accum != nullptr && out_accum != nullptr) {
params.softmax_lseaccum_ptr = softmax_lse_accum;
@@ -444,7 +461,7 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
}
// Only split kernel supports appending to KV cache
- run_mha_fwd(params, stream, /*force_split_kernel=*/k != nullptr);
+ run_mha_fwd(params, stream, /*force_split_kernel=*/k_new != nullptr);
return Status::OK();
}
diff --git a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h
index 3d75d6834b8e0..387d1cf9d84fe 100644
--- a/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h
+++ b/onnxruntime/contrib_ops/cuda/bert/flash_attention/flash_api.h
@@ -87,6 +87,8 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
void* out, // batch_size x seqlen_q x num_heads x head_size
void* softmax_lse, // batch_size x num_heads x seqlen_q
void* seqlens_k_, // batch_size
+                       void* rotary_cos,        // seqlen_ro x (rotary_dim / 2)
+                       void* rotary_sin,        // seqlen_ro x (rotary_dim / 2)
int batch_size,
int num_heads,
int num_heads_k,
@@ -101,7 +103,9 @@ Status mha_fwd_kvcache(const cudaDeviceProp& dprops,
int num_splits = 0,
void* softmax_lse_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads
void* out_accum = nullptr, // num_splits x batch_size x seqlen_q x num_heads x head_size_rounded
- int local_window_size = -1);
+ int local_window_size = -1,
+ bool is_rotary_interleaved = false,
+ bool is_packed_qkv = false);
size_t get_softmax_lse_size(int max_seqlen_q, int batch_size, int num_heads);
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
index fd6fb79742cac..fe56f84f0a886 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.cc
@@ -47,6 +47,8 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info)
   kv_num_heads_ = static_cast<int>(kv_num_heads);
   is_past_bsnh_ = false;  // info.GetAttrOrDefault<int64_t>("is_past_bsnh", 1) == 1;
   local_window_size_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("local_window_size", -1));
+  do_rotary_ = info.GetAttrOrDefault<int64_t>("do_rotary", 0) == 1;
+  rotary_interleaved_ = info.GetAttrOrDefault<int64_t>("rotary_interleaved", 0) == 1;
   scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
#if USE_FLASH_ATTENTION
@@ -62,6 +64,9 @@ GroupQueryAttention::GroupQueryAttention(const OpKernelInfo& info)
#else
disable_memory_efficient_attention_ = true;
#endif
+ if (!disable_flash_attention_) {
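+    // Pre-allocate a zero-filled buffer reused as the per-batch append offsets (seqlens_k) for flash attention in the prompt case.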
+    zeros_ = this->GetScratchBuffer<int>(kZerosCount, nullptr);
+ }
}
 template <typename T>
@@ -73,6 +78,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const {
   const Tensor* past_value = context->Input<Tensor>(4);
   const Tensor* seqlens_k = context->Input<Tensor>(5);
   const Tensor* total_seqlen = context->Input<Tensor>(6);
+  const Tensor* cos_cache = context->Input<Tensor>(7);
+  const Tensor* sin_cache = context->Input<Tensor>(8);
auto& device_prop = GetDeviceProp();
GroupQueryAttentionParameters parameters;
@@ -84,6 +91,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const {
value,
past_key,
past_value,
+ cos_cache,
+ sin_cache,
                                                                &parameters,
num_heads_,
kv_num_heads_,
@@ -93,7 +102,13 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const {
scale_,
device_prop.maxThreadsPerBlock));
parameters.local_window_size = local_window_size_;
+ parameters.is_unidirectional = is_unidirectional_;
+ parameters.zeros_count = kZerosCount;
+ parameters.zero_ptr = zeros_.get();
+ // parameters.left_padding = left_padding_;
int sequence_length = parameters.sequence_length;
+ parameters.do_rotary = do_rotary_;
+ parameters.rotary_interleaved = rotary_interleaved_;
TensorShapeVector output_shape(3);
   output_shape[0] = static_cast<int64_t>(parameters.batch_size);
@@ -139,6 +154,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const {
!use_flash_attention &&
!disable_memory_efficient_attention_ &&
local_window_size_ == -1 &&
+ do_rotary_ == false &&
+ key != nullptr &&
(parameters.head_size & 7) == 0 &&
parameters.sequence_length <= parameters.seqlen_past_kv_cache + parameters.sequence_length &&
(sizeof(T) == 2 || parameters.sequence_length >= attention::kMinSeqLenForMemoryEfficientAttentionFp32) &&
@@ -182,8 +199,8 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const {
Tensor* present_value = context->Output(2, present_shape);
   data.query = reinterpret_cast<const CudaT*>(query->Data<T>());
-  data.key = reinterpret_cast<const CudaT*>(key->Data<T>());
-  data.value = reinterpret_cast<const CudaT*>(value->Data<T>());
+  data.key = key == nullptr ? nullptr : reinterpret_cast<const CudaT*>(key->Data<T>());
+  data.value = value == nullptr ? nullptr : reinterpret_cast<const CudaT*>(value->Data<T>());
   data.past_key = (nullptr == past_key) ? nullptr : reinterpret_cast<const CudaT*>(past_key->Data<T>());
   data.past_value = (nullptr == past_value) ? nullptr : reinterpret_cast<const CudaT*>(past_value->Data<T>());
   data.output = reinterpret_cast<CudaT*>(output->MutableData<T>());
@@ -229,6 +246,11 @@ Status GroupQueryAttention::ComputeInternal(OpKernelContext* context) const {
if (fmha_buffer != nullptr) {
     data.fmha_buffer = reinterpret_cast<CudaT*>(fmha_buffer.get());
}
+ // Rotary
+ if (parameters.do_rotary) {
+    data.cos_cache = reinterpret_cast<const CudaT*>(cos_cache->Data<T>());
+    data.sin_cache = reinterpret_cast<const CudaT*>(sin_cache->Data<T>());
+ }
cublasHandle_t cublas = GetCublasHandle(context);
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h
index 54a8127e29e7b..15573ece166fc 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention.h
@@ -23,10 +23,15 @@ class GroupQueryAttention final : public CudaKernel {
int num_heads_; // number of attention heads
int kv_num_heads_; // different for k and v for group query attention
int local_window_size_;
+ bool is_unidirectional_;
bool is_past_bsnh_;
+ bool do_rotary_;
+ bool rotary_interleaved_;
float scale_;
bool disable_flash_attention_;
bool disable_memory_efficient_attention_;
+ static constexpr int kZerosCount = 256; // In prompt case we create a zero buffer of size 256 for seqlen (assume batch_size <= 256)
+  IAllocatorUniquePtr<int> zeros_;
};
} // namespace cuda
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h
index 2cb9955807f26..853e1a710cb24 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_helper.h
@@ -16,6 +16,8 @@ Status CheckInputs(const Tensor* query,
const Tensor* value,
const Tensor* past_key,
const Tensor* past_value,
+ const Tensor* cos_cache,
+ const Tensor* sin_cache,
void* parameters,
int num_heads,
int kv_num_heads,
@@ -24,19 +26,18 @@ Status CheckInputs(const Tensor* query,
bool is_past_bsnh,
float scale) {
// Note: Here S* is past_cache_sequence_length, S- is past_sequence_length, S+ is sequence_length
- // past_key : (B, N_k, S*, H) or (B, N_k, S-, H)
- // past_value : (B, N_k, S*, H) or (B, N_k, S-, H)
+ // past_key : (B, N_k, S*, H) or (B, N_k, S-, H) or nullptr
+ // past_value : (B, N_k, S*, H) or (B, N_k, S-, H) or nullptr
// no packing for q/k/v:
- // query (Q) : (B, S, D)
- // key (K) : (B, S, D_kv)
- // value (V) : (B, S, D_kv)
+ // query (Q) : (B, S, D) or (B, S, (D_q + 2 D_kv))
+ // key (K) : (B, S, D_kv) or nullptr
+ // value (V) : (B, S, D_kv) or nullptr
ORT_UNUSED_PARAMETER(value);
AttentionQkvFormat qkv_format = Q_K_V_BSNH;
AttentionQkvFormat past_kv_format = is_past_bsnh ? Q_K_V_BSNH : Q_K_V_BNSH;
-
+ const bool is_packed_qkv = key == nullptr;
const auto& query_dims = query->Shape().GetDims();
- const auto& key_dims = key->Shape().GetDims();
if (query_dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'query' is expected to have 3 dimensions, got ",
@@ -46,10 +47,69 @@ Status CheckInputs(const Tensor* query,
   int batch_size = static_cast<int>(query_dims[0]);
   int sequence_length = static_cast<int>(query_dims[1]);
   int q_hidden_size = static_cast<int>(query_dims[2]);
-  int head_size = static_cast<int>(q_hidden_size) / num_heads;
+ int head_size = 0;
+
+ if (num_heads % kv_num_heads != 0) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ",
+ num_heads % kv_num_heads);
+ }
-  int kv_hidden_size = static_cast<int>(key_dims[2]);
+ int kv_hidden_size = 0;
+ // Check key and value when not packed
+ if (!is_packed_qkv) {
+    head_size = static_cast<int>(q_hidden_size) / num_heads;
+ if (head_size % 8 != 0) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "head_size must be a multiple of 8. Got head_size % 8 == ",
+ head_size % 8);
+ }
+ if (value == nullptr) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv.");
+ }
+ const auto& key_dims = key->Shape().GetDims();
+ if (key_dims.size() != 3) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ",
+ key_dims.size());
+ } else if (query_dims[0] != key_dims[0]) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'query' and 'key' shall have same dim 0 (batch size)");
+ } else if (query_dims[1] != key_dims[1]) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'query' and 'key' shall have same dim 1 (sequence length)");
+ }
+    kv_hidden_size = static_cast<int>(key_dims[2]);
+ const auto& value_dims = value->Shape().GetDims();
+ if (value_dims.size() != 3) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ",
+ value_dims.size());
+ } else if (query_dims[0] != value_dims[0]) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'query' and 'value' shall have same dim 0 (batch size)");
+ } else if (query_dims[1] != value_dims[1]) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'query' and 'value' shall have same dim 1 (sequence length)");
+ } else if (value_dims[2] != kv_hidden_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key.");
+ }
+ } else {
+ // Check packed qkv
+    head_size = static_cast<int>(q_hidden_size) / (num_heads + 2 * kv_num_heads);
+ if (head_size % 8 != 0) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "head_size must be a multiple of 8. Got head_size % 8 == ",
+ head_size % 8);
+ }
+ if (value != nullptr) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'key' and 'value' shall be both present, or both absent in the case of packed qkv.");
+ }
+ q_hidden_size = head_size * num_heads;
+ kv_hidden_size = head_size * kv_num_heads;
+ }
+ // Check past-present KV
int32_t past_sequence_length = 0;
if (past_key != nullptr && past_value != nullptr) {
const auto& past_key_dims = past_key->Shape().GetDims();
@@ -130,41 +190,6 @@ Status CheckInputs(const Tensor* query,
"Input 'past_key' and 'past_value' shall be both present or both absent.");
}
- if (key_dims.size() != 3) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'key' is expected to have 3 dimensions, got ",
- key_dims.size());
- }
- if (query_dims[0] != key_dims[0]) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'query' and 'key' shall have same dim 0 (batch size)");
- }
-
- if (num_heads % kv_num_heads != 0) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "num_heads must be a multiple of kv_num_heads. Got num_heads % kv_num_heads == ",
- num_heads % kv_num_heads);
- }
-
- const auto& value_dims = value->Shape().GetDims();
- if (value_dims.size() != 3) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have 3 dimensions, got ",
- value_dims.size());
- }
-
- if (query_dims[0] != value_dims[0]) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'query' and 'value' shall have same dim 0 (batch_size)");
- }
-
-  if (static_cast<int64_t>(sequence_length) != value_dims[1]) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'query,' 'key,' and 'value' shall have the same dim 1 (sequence_length)");
- }
-
- if (value_dims[2] != kv_hidden_size) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'value' is expected to have same hidden size as key.");
- }
-
// Check seqlens_k tensor (holding past seqlen for token gen)
const auto& seqlens_dim = seqlens_k->Shape().GetDims();
if (seqlens_dim.size() != 1 && seqlens_dim[0] != batch_size) {
@@ -180,6 +205,36 @@ Status CheckInputs(const Tensor* query,
   int total_sequence_length = *((*total_seqlen).template Data<int32_t>());
int present_sequence_length = std::max(total_sequence_length, past_sequence_length);
+ if (cos_cache != nullptr && sin_cache != nullptr) {
+ const auto& cos_dims = cos_cache->Shape().GetDims();
+ const auto& sin_dims = sin_cache->Shape().GetDims();
+
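+    // The flash kernel rotates rotary_dim = (head_size / 16) * 16 dimensions (see flash_api.cc), so the caches' second dim must equal rotary_dim / 2 = (head_size / 16) * 8.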
+ if (head_size % 16 != 0) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "head_size shall be a multiple of 16. Got head_size % 16 == ",
+ head_size % 16);
+ }
+ if (cos_dims[0] != present_sequence_length) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "cos_cache dimension 0 must be of present_sequence_length.");
+ }
+ if (sin_dims[0] != present_sequence_length) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "sin_cache dimension 0 must be of present_sequence_length.");
+ }
+ if (cos_dims[1] != (head_size / 16) * 8) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "cos_cache dimension 1 must be <= head_size / 2 and a multiple of 8.");
+ }
+ if (sin_dims[1] != (head_size / 16) * 8) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "sin_cache dimension 1 must be <= head_size / 2 and a multiple of 8.");
+ }
+ } else if (cos_cache != nullptr || sin_cache != nullptr) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'cos_cache' and 'sin_cache' shall be both present or both absent.");
+ }
+
bool is_prompt = sequence_length != 1;
if (parameters != nullptr) {
@@ -190,9 +245,10 @@ Status CheckInputs(const Tensor* query,
output_parameters->seqlen_present_kv_cache = present_sequence_length; // max sequence length of present kv tensors
output_parameters->hidden_size = q_hidden_size;
output_parameters->num_heads = num_heads;
- output_parameters->head_size = q_hidden_size / num_heads;
+ output_parameters->head_size = head_size;
output_parameters->kv_hidden_size = kv_hidden_size;
output_parameters->kv_num_heads = kv_num_heads;
+ output_parameters->is_packed_qkv = is_packed_qkv;
output_parameters->is_unidirectional = true;
output_parameters->is_prompt = is_prompt;
output_parameters->scale = scale;
@@ -208,6 +264,8 @@ Status CheckInputs(const Tensor* query,
const Tensor* value,
const Tensor* past_key,
const Tensor* past_value,
+ const Tensor* cos_cache,
+ const Tensor* sin_cache,
void* parameters,
int num_heads,
int kv_num_heads,
@@ -220,7 +278,7 @@ Status CheckInputs(const Tensor* query,
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block);
}
- return CheckInputs(query, key, value, past_key, past_value, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, is_past_bsnh, scale);
+ return CheckInputs(query, key, value, past_key, past_value, cos_cache, sin_cache, parameters, num_heads, kv_num_heads, seqlens_k, total_seqlen, is_past_bsnh, scale);
}
} // namespace group_query_attention_helper
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
index 5b0f5d0cfe601..d88e9a49fb5ee 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.cu
@@ -151,9 +151,10 @@ template
Status LaunchConcatNewToPastKV(contrib::GroupQueryAttentionParameters& parameters,
GroupQueryAttentionData& data,
cudaStream_t stream,
- const int max_threads_per_block) {
+ const int max_threads_per_block,
+ const bool past_only = false) {
const int batch_size = parameters.batch_size;
- const int kv_sequence_length = parameters.sequence_length;
+ const int kv_sequence_length = past_only ? 0 : parameters.sequence_length;
const int past_sequence_length = parameters.seqlen_past_kv_cache;
const int present_sequence_length = parameters.seqlen_present_kv_cache;
const int kv_num_heads = parameters.kv_num_heads;
@@ -441,7 +442,6 @@ Status LaunchUngroup(contrib::GroupQueryAttentionParameters& parameters,
return CUDA_CALL(cudaGetLastError());
}
-
__global__ void PastToTotalSeqlen(int32_t* seqlens_k,
int32_t* seqlens_k_buff,
const int add_seqlen) {
@@ -451,7 +451,7 @@ __global__ void PastToTotalSeqlen(int32_t* seqlens_k,
// Convert Past to Total sequence length tensor
Status LaunchGetSeqlenBuff(contrib::GroupQueryAttentionParameters& parameters, int32_t* seqlens_k,
int32_t* seqlens_k_buff, bool is_total, cudaStream_t stream,
- const int threads_per_block) {
+ const int threads_per_block) {
if (parameters.is_prompt) {
return Status::OK();
}
@@ -482,91 +482,63 @@ Status FlashAttention(
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int kv_sequence_length = parameters.sequence_length;
- const int present_sequence_length = parameters.seqlen_present_kv_cache;
const int num_heads = parameters.num_heads;
const int kv_num_heads = parameters.kv_num_heads;
const int head_size = parameters.head_size;
AttentionQkvFormat past_kv_format = parameters.past_kv_format;
-
-  void* query = reinterpret_cast<void*>(const_cast<T*>(data.query));
-  void* key = reinterpret_cast<void*>(const_cast<T*>(data.key));
-  void* value = reinterpret_cast<void*>(const_cast<T*>(data.value));
-
bool is_causal = true;
-
   bool is_bf16 = std::is_same<T, BFloat16>::value;
- // Note: seqlens_k is past sequence length for flash
- if (parameters.is_prompt) {
- // Launch kernel to copy seqlen
- constexpr int thr_per_blk = 256;
- int blk_in_grid = (batch_size + thr_per_blk -1) / thr_per_blk;
-    repeat_seqlen<<<blk_in_grid, thr_per_blk, 0, stream>>>(data.seqlens_k_total, parameters.sequence_length, batch_size);
- }
-
-  void* seqlens_k = reinterpret_cast<void*>(data.seqlens_k);
-
- if (parameters.kv_share_buffer) {
- // Share buffer case
- if (data.past_key == nullptr || data.past_key != data.present_key) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Past and present kv shall share the same tensor when kv_share_buffer is on.");
- }
-
- if (parameters.is_prompt) {
- ORT_RETURN_IF_ERROR(LaunchConcatKVInPlace(parameters, data, stream, max_threads_per_block));
- key = nullptr;
- value = nullptr;
-      seqlens_k = reinterpret_cast<void*>(data.seqlens_k_total);
- }
-
-    void* present_key = reinterpret_cast<void*>(const_cast<T*>(data.present_key));
-    void* present_value = reinterpret_cast<void*>(const_cast<T*>(data.present_value));
-
- DUMP_TENSOR_INIT();
-    DUMP_TENSOR("seqlens_k", reinterpret_cast<int*>(seqlens_k), batch_size, 1);
+  void* query = reinterpret_cast<void*>(const_cast<T*>(data.query));
+ void* key;
+ void* value;
- bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
- ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache(
-        device_prop, stream, query, present_key, present_value, key, value, data.output, reinterpret_cast<void*>(data.softmax_lse),
- seqlens_k, batch_size, num_heads, kv_num_heads,
- head_size, sequence_length, present_sequence_length, kv_sequence_length,
-        scale, is_causal, is_bf16, past_bsnh, parameters.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum),
-        reinterpret_cast<void*>(data.out_accum), parameters.local_window_size));
+ if (!parameters.is_packed_qkv) {
+    key = reinterpret_cast<void*>(const_cast<T*>(data.key));
+    value = reinterpret_cast<void*>(const_cast<T*>(data.value));
} else {
- // Not share buffer case
- // Note that Flash Attention kv-caching operates in place on a buffer... therefore this path is inneficient
- if (data.past_key != nullptr && data.past_key == data.present_key) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Past and present kv share the same tensor but kv_share_buffer is not on.");
- }
-
- ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block));
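+    // Packed QKV case: q, new k, and new v are contiguous per token, so k and v are fixed element offsets from q.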
+    const size_t key_offset = static_cast<size_t>(num_heads * head_size);
+    const size_t value_offset = static_cast<size_t>(kv_num_heads * head_size);
+    key = reinterpret_cast<T*>(query) + key_offset;
+    value = reinterpret_cast<T*>(key) + value_offset;
+ }
- if (!parameters.is_prompt) {
- ORT_RETURN_IF_ERROR(LaunchGetSeqlenBuff(parameters, data.seqlens_k, data.seqlens_k_total, true, stream, 256));
+  void* seqlens_k = reinterpret_cast<void*>(data.seqlens_k);
+ if (parameters.is_prompt) {
+ // set seqlens_k to zeros... flash api uses seqlens_k to indicate where to append key and value
+ // user should use seqlens_k to index into output to get new tokens
+ if (batch_size <= parameters.zeros_count) {
+ seqlens_k = parameters.zero_ptr;
+ } else {
+ // Launch kernel to create larger seqlen tensor when batch_size > 256
+ constexpr int thr_per_blk = 256;
+ int blk_in_grid = (batch_size + thr_per_blk - 1) / thr_per_blk;
+      repeat_seqlen<<<blk_in_grid, thr_per_blk, 0, stream>>>(data.seqlens_k_total, 0, batch_size);
+ seqlens_k = data.seqlens_k_total;
}
-
-    seqlens_k = reinterpret_cast<void*>(data.seqlens_k_total);
-
-    void* present_key = reinterpret_cast<void*>(const_cast<T*>(data.present_key));
-    void* present_value = reinterpret_cast<void*>(const_cast<T*>(data.present_value));
-
- DUMP_TENSOR_INIT();
-    DUMP_TENSOR("seqlens_k", reinterpret_cast<int*>(seqlens_k), batch_size, 1);
- DUMP_TENSOR("Q", data.query, batch_size, sequence_length, num_heads, head_size);
- DUMP_TENSOR("K", data.present_key, batch_size, kv_num_heads, present_sequence_length, head_size);
- DUMP_TENSOR("V", data.present_value, batch_size, kv_num_heads, present_sequence_length, head_size);
-
- bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
- ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache(
-        device_prop, stream, query, present_key, present_value, nullptr, nullptr, data.output, reinterpret_cast<void*>(data.softmax_lse),
- seqlens_k, batch_size, num_heads, kv_num_heads,
- head_size, sequence_length, present_sequence_length, 0,
-        scale, is_causal, is_bf16, past_bsnh, parameters.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum),
-        reinterpret_cast<void*>(data.out_accum), parameters.local_window_size));
+ } else if (!parameters.kv_share_buffer) { // copy past kv to present kv
+ ORT_RETURN_IF_ERROR(LaunchConcatNewToPastKV(parameters, data, stream, max_threads_per_block, true));
}
+  void* present_key = reinterpret_cast<void*>(const_cast<T*>(data.present_key));
+  void* present_value = reinterpret_cast<void*>(const_cast<T*>(data.present_value));
+  void* cos_cache = reinterpret_cast<void*>(const_cast<T*>(data.cos_cache));
+  void* sin_cache = reinterpret_cast<void*>(const_cast<T*>(data.sin_cache));
+
+ bool past_bsnh = past_kv_format == AttentionQkvFormat::Q_K_V_BSNH;
+ ORT_RETURN_IF_ERROR(onnxruntime::flash::mha_fwd_kvcache(
+ device_prop, stream, query, present_key, present_value, key, value, data.output,
+      reinterpret_cast<void*>(data.softmax_lse), seqlens_k, cos_cache, sin_cache,
+ batch_size, num_heads, kv_num_heads, head_size, sequence_length,
+ parameters.seqlen_present_kv_cache, kv_sequence_length,
+      scale, is_causal, is_bf16, past_bsnh, parameters.num_splits, reinterpret_cast<void*>(data.softmax_lse_accum),
+      reinterpret_cast<void*>(data.out_accum), parameters.local_window_size, parameters.rotary_interleaved,
+ parameters.is_packed_qkv));
+
+ // if (parameters.left_padding && parameters.is_prompt) {
+ // ORT_RETURN_IF_ERROR(LaunchLeftPadLast(parameters, data, stream, device_prop.maxThreadsPerBlock));
+ // }
+
DUMP_TENSOR_INIT();
DUMP_TENSOR("flash attention output", data.output, batch_size, sequence_length, num_heads, head_size);
@@ -672,7 +644,6 @@ Status EfficientAttention(
p.has_custom_right_padding = true;
run_memory_efficient_attention(p);
- DUMP_TENSOR_INIT();
DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size);
return Status::OK();
diff --git a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h
index de32d7ea93163..1bf91f9c875eb 100644
--- a/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h
+++ b/onnxruntime/contrib_ops/cuda/bert/group_query_attention_impl.h
@@ -21,6 +21,8 @@ struct GroupQueryAttentionData {
const T* past_key = nullptr;
const T* past_value = nullptr;
int* seqlens_k = nullptr;
+ const T* cos_cache = nullptr;
+ const T* sin_cache = nullptr;
// Flash buffers
T* softmax_lse = nullptr;
T* softmax_lse_accum = nullptr;
diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
index ebd66d8c6528e..f978f50c6851f 100644
--- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
@@ -44,6 +44,8 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info)
   mask_filter_value_ = info.GetAttrOrDefault<float>("mask_filter_value", -10000.0f);
   scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
+  is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
+ ORT_ENFORCE(!is_unidirectional_, "Unidirectional MHA does not support CUDA kernel. Consider using Attention or GQA instead.");
disable_fused_self_attention_ = sizeof(T) != 2 ||
                                   ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedSelfAttention, false);
@@ -105,6 +107,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const {
num_heads_,
mask_filter_value_,
scale_,
+ is_unidirectional_,
false, // past_present_share_buffer
false, // dmmha_packing
device_prop.maxThreadsPerBlock));
diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h
index c162f7133cc1c..86a32c92ce003 100644
--- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h
+++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h
@@ -25,6 +25,7 @@ class MultiHeadAttention final : public CudaKernel {
int num_heads_; // number of attention heads
float mask_filter_value_;
float scale_;
+ bool is_unidirectional_;
bool disable_fused_self_attention_;
bool enable_trt_flash_attention_;
bool disable_fused_cross_attention_;
diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc
index 2d12e975d88d7..9de7ba3885c3c 100644
--- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc
@@ -29,10 +29,13 @@ namespace cuda {
REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
+REGISTER_KERNEL_TYPED(BFloat16)
 template <typename T>
 RotaryEmbedding<T>::RotaryEmbedding(const OpKernelInfo& info) : CudaKernel(info) {
   scale = info.GetAttrOrDefault<float>("scale", 1.0);
+  rotary_embedding_dim = static_cast<int>(info.GetAttrOrDefault<int64_t>("rotary_embedding_dim", 0));
+  num_heads = static_cast<int>(info.GetAttrOrDefault<int64_t>("num_heads", 0));
   interleaved = (info.GetAttrOrDefault<int64_t>("interleaved", 0) == 1);
}
@@ -48,6 +51,8 @@ Status RotaryEmbedding::ComputeInternal(OpKernelContext* context) const {
position_ids,
cos_cache,
sin_cache,
+ num_heads,
+ rotary_embedding_dim,
                                                                   &parameters));
Tensor* output = context->Output(0, input->Shape());
@@ -71,6 +76,7 @@ Status RotaryEmbedding::ComputeInternal(OpKernelContext* context) const {
parameters.sequence_length,
parameters.num_heads,
parameters.head_size,
+ parameters.rotary_embedding_dim,
parameters.max_sequence_length,
parameters.position_ids_format,
interleaved,
diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h
index 6dab2ad56749e..d52f61d670444 100644
--- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h
+++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h
@@ -19,6 +19,8 @@ class RotaryEmbedding final : public CudaKernel {
protected:
float scale;
+ int num_heads;
+ int rotary_embedding_dim;
bool interleaved;
};
diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu
index e1b83bd8caf54..c6637041f05bd 100644
--- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu
@@ -26,6 +26,7 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH
const int sequence_length,
const int num_heads,
const int head_size,
+ const int rotary_embedding_dim,
const int position_ids_format,
const bool interleaved,
const int batch_stride,
@@ -33,24 +34,33 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH
const int head_stride) {
// B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length
// Use .x in innermost loop to access global memory efficiently
-
+
const int b = blockIdx.z;
const int s = blockIdx.y;
const int n = blockIdx.x;
const int i = threadIdx.x;
+ if (i >= head_size) {
+ return;
+ }
+
const int block_offset = b * batch_stride + s * seq_stride + n * head_stride;
const T* input_data = input + block_offset;
T* output_data = output + block_offset;
+ if (i >= rotary_embedding_dim) {
+ output_data[i] = input_data[i];
+ return;
+ }
+
// Cache is (M, H/2)
- const int half_head_size = head_size / 2;
+ const int half_rotary_embedding_dim = rotary_embedding_dim / 2;
  const int position_id = (position_ids_format == 0) ? \
      static_cast<int>(position_ids[0]) + s \
      : static_cast<int>(position_ids[b * sequence_length + s]);
- const int cache_offset = position_id * half_head_size;
+ const int cache_offset = position_id * half_rotary_embedding_dim;
const T* cos_data = cos_cache + cache_offset;
const T* sin_data = sin_cache + cache_offset;
@@ -58,13 +68,13 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH
T sign = 0;
int j = 0;
if (interleaved) {
- cache_idx = (i / 2) % half_head_size;
+ cache_idx = (i / 2) % half_rotary_embedding_dim;
sign = (i % 2 == 0) ? -1 : 1;
j = (i % 2 == 0) ? i+1 : i-1; // i - sign
} else {
- cache_idx = i % half_head_size;
- sign = (i < half_head_size) ? -1 : 1;
- j = (i + half_head_size) % head_size;
+ cache_idx = i % half_rotary_embedding_dim;
+ sign = (i < half_rotary_embedding_dim) ? -1 : 1;
+ j = (i + half_rotary_embedding_dim) % rotary_embedding_dim;
}
output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx];
}
@@ -82,20 +92,23 @@ Status LaunchRotaryEmbeddingKernel(
const int sequence_length,
const int num_heads,
const int head_size,
+ const int rotary_embedding_dim,
const int max_sequence_length,
const int position_ids_format,
const bool interleaved,
const int max_threads_per_block,
const bool transposed) {
-
- constexpr int smem_size = 0;
- const dim3 grid(num_heads, sequence_length, batch_size);
- const dim3 block(head_size, 1, 1);
-
// Note: Current implementation assumes head_size <= max_threads_per_block
// because head_size is currently large for LLaMA-2. For smaller head_size
// and num_heads values, we can create a block as `block(num_heads, head_size, 1)`
// instead. This will require kernel changes to support.
+ ORT_ENFORCE(head_size <= max_threads_per_block,
+ "Rotary embedding dim must be <= max_threads_per_block");
+
+ int tpb = (head_size + 31)/32*32;
+
+ const dim3 block(tpb);
+ const dim3 grid(num_heads, sequence_length, batch_size);
// Default input tensor shape is [batch, seq, hidden_size]
int head_stride = head_size;
@@ -109,10 +122,9 @@ Status LaunchRotaryEmbeddingKernel(
}
assert(head_size <= max_threads_per_block);
-  RotaryEmbeddingBSNH<<<grid, block, smem_size, stream>>>(
-      output, input, cos_cache, sin_cache, position_ids,
-      sequence_length, num_heads, head_size, position_ids_format, interleaved,
-      batch_stride, seq_stride, head_stride
+  RotaryEmbeddingBSNH<<<grid, block, 0, stream>>>(
+      output, input, cos_cache, sin_cache, position_ids, sequence_length, num_heads, head_size,
+      rotary_embedding_dim, position_ids_format, interleaved, batch_stride, seq_stride, head_stride
);
return CUDA_CALL(cudaGetLastError());
@@ -129,6 +141,7 @@ template Status LaunchRotaryEmbeddingKernel<float>(
const int sequence_length,
const int num_heads,
const int head_size,
+ const int rotary_embedding_dim,
const int max_sequence_length,
const int position_ids_format,
const bool interleaved,
@@ -146,6 +159,25 @@ template Status LaunchRotaryEmbeddingKernel<half>(
const int sequence_length,
const int num_heads,
const int head_size,
+ const int rotary_embedding_dim,
+ const int max_sequence_length,
+ const int position_ids_format,
+ const bool interleaved,
+ const int max_threads_per_block,
+ const bool transposed);
+
+template Status LaunchRotaryEmbeddingKernel<BFloat16>(
+ cudaStream_t stream,
+ BFloat16* output,
+ const BFloat16* input,
+ const int64_t* position_ids,
+ const BFloat16* cos_cache,
+ const BFloat16* sin_cache,
+ const int batch_size,
+ const int sequence_length,
+ const int num_heads,
+ const int head_size,
+ const int rotary_embedding_dim,
const int max_sequence_length,
const int position_ids_format,
const bool interleaved,
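
One more note on the launch configuration above: the block size is rounded up to a whole number of warps, and the new `if (i >= head_size) return;` guard keeps the padding threads of the last warp idle. A small host-side sketch of that sizing (the function name is mine, not ORT's):

```cpp
#include <cassert>
#include <cuda_runtime.h>

// Round threads-per-block up to a multiple of the 32-thread warp size.
// E.g. head_size = 100 -> 128 threads; threads 100..127 hit the kernel's early return.
dim3 MakeRotaryBlock(int head_size, int max_threads_per_block) {
  assert(head_size <= max_threads_per_block);
  const int tpb = (head_size + 31) / 32 * 32;
  return dim3(tpb, 1, 1);
}
```
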
diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h
index ee1ccc43dcbff..36300fe7a660f 100644
--- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h
+++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h
@@ -21,6 +21,7 @@ Status LaunchRotaryEmbeddingKernel(
const int sequence_length,
const int num_heads,
const int head_size,
+ const int rotary_embedding_dim,
const int max_sequence_length,
const int position_ids_format,
const bool interleaved,
diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
index 34b44694a5fcc..fa73950c9c6f5 100644
--- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -98,6 +98,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ParametricSoftplus);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RotaryEmbedding);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, RotaryEmbedding);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Sampling);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh);
@@ -299,6 +300,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo