Cherry-pick LLaMA/SDXL to rel-1.16.2 #18202

Merged Nov 1, 2023 (40 commits)
6918ba3
Add LLaMA scripts (#17020)
kunal-vaishnavi Aug 23, 2023
366df09
Add Whisper scripts (#17043)
kunal-vaishnavi Aug 23, 2023
438f794
Fix attention fusion for UNet onnx model export when using LoRA weigh…
PatriceVignola Aug 29, 2023
3b2f819
Handle dtype attribute in float16 conversion script (#17321)
tianleiwu Aug 30, 2023
0bf7aa9
fix assert error in attention fusion script (#17375)
tianleiwu Sep 1, 2023
0acc737
Fix weight tensors in transformers optimizer not saved to external da…
tianleiwu Sep 6, 2023
e529526
fix embed layer norm fusion with embedding sum output (#17460)
tianleiwu Sep 8, 2023
c947639
Attention fusion for stable diffusion clip model (#17445)
tianleiwu Sep 8, 2023
a17bda1
Improve performance of prune_graph in onnx_model.py (#17502)
tianleiwu Sep 12, 2023
9b38999
Update optimize_pipeline for SDXL (#17536)
tianleiwu Sep 15, 2023
da37196
Refactoring of attention cuda kernel: move prepare qkv and concat_pas…
tianleiwu Sep 15, 2023
3cca12f
Reduce GPU memory for Whisper models converted to ONNX (#17378)
petermcaughan Sep 5, 2023
865350e
Provide kwargs to remove_shared_initializers (#17539)
jambayk Sep 18, 2023
adfc75c
Refactor Attention cuda kernel (#17578)
tianleiwu Sep 19, 2023
c4ba739
Bugfix: Add initializer to model in AttentionMask directly (#17719)
jambayk Sep 27, 2023
b181a43
[ROCm] Update whisper benchmark script (#17391)
PeixuanZuo Sep 18, 2023
86e334d
StableDiffusion XL with TensorRT EP (#17748)
tianleiwu Oct 4, 2023
ef73084
Add CUDA EP in StableDiffusion demo (#17788)
tianleiwu Oct 5, 2023
bcad2ed
[BeamSearch]optimize key cache reordering (#17771)
wangyems Oct 5, 2023
f4c4d3e
[CUDA] GroupQueryAttention operator using FlashAttention (#17674)
aciddelgado Oct 9, 2023
fb98485
[CUDA/ROCm] Update BiasSplitGelu for SD XL Refiner model (#17849)
tianleiwu Oct 10, 2023
52f3004
[CUDA/ROCm] Remove limitation of BiasAdd (#17848)
tianleiwu Oct 11, 2023
cc975aa
Fix GroupNorm fusion: skip if num of channels not supported (#17869)
tianleiwu Oct 12, 2023
d5be7fb
[CUDA] Fix SkipLayerNorm strict mode when skip has broadcast (#17896)
tianleiwu Oct 13, 2023
cadd86c
Fix build break with cuda 12.2 (#17922)
yufenglee Oct 13, 2023
68b7e0c
Provide kwargs to remove_shared_initializers (#17539)
yufenglee Oct 13, 2023
e1a5c5c
[CUDA] Fix SkipLayerNorm vectorized kernel out-of-bounds read (#17943)
tianleiwu Oct 16, 2023
961342e
Fix MHA shape inference (#18009)
PatriceVignola Oct 18, 2023
00f85d3
[CUDA] StableDiffusion XL demo with CUDA EP (#17997)
tianleiwu Oct 18, 2023
8ecd9a5
Fix Packed MultiHead Attention (#17996)
aciddelgado Oct 18, 2023
d07e08c
Support large model export using multi-gpu (#17990)
wejoncy Oct 22, 2023
f2d5e8b
LLaMA Model Optimization (#18021)
kunal-vaishnavi Oct 23, 2023
e5dce65
Add MatMul FP4 and NF4 Support (#18066)
jambayk Oct 25, 2023
2a48cd2
Add updates to LLaMA scripts (#18076)
kunal-vaishnavi Oct 27, 2023
f1a77e0
optimize SLN with large dimension (#18138)
yufenglee Oct 30, 2023
69948b9
Add one more MHA mask pattern (#18164)
PatriceVignola Oct 31, 2023
99aa0a2
[CUDA] Update GroupNorm and Add SkipGroupNorm (#18091)
tianleiwu Oct 31, 2023
6cdf3a1
fix windows build
tianleiwu Oct 31, 2023
3a42ede
Reduce LLaMA memory usage (#18181)
kunal-vaishnavi Nov 1, 2023
1bf715b
Split KV on MHA and Attention ops (#18007)
aciddelgado Nov 1, 2023
7 changes: 7 additions & 0 deletions cmake/onnxruntime_python.cmake
@@ -466,6 +466,9 @@ file(GLOB onnxruntime_python_transformers_models_bert_src CONFIGURE_DEPENDS
file(GLOB onnxruntime_python_transformers_models_gpt2_src CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/python/tools/transformers/models/gpt2/*.py"
)
file(GLOB onnxruntime_python_transformers_models_llama_src CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/python/tools/transformers/models/llama/*.py"
)
file(GLOB onnxruntime_python_transformers_models_longformer_src CONFIGURE_DEPENDS
"${ONNXRUNTIME_ROOT}/python/tools/transformers/models/longformer/*.py"
)
@@ -537,6 +540,7 @@ add_custom_command(
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/bart
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/bert
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/gpt2
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/llama
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/longformer
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/stable_diffusion
COMMAND ${CMAKE_COMMAND} -E make_directory $<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/t5
@@ -628,6 +632,9 @@ add_custom_command(
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_transformers_models_gpt2_src}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/gpt2/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_transformers_models_llama_src}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/llama/
COMMAND ${CMAKE_COMMAND} -E copy
${onnxruntime_python_transformers_models_longformer_src}
$<TARGET_FILE_DIR:${build_output_target}>/onnxruntime/transformers/models/longformer/
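With the globs and copy steps above, the new LLaMA scripts ship inside the onnxruntime Python wheel alongside the other transformer model folders. A quick import check (a sketch; it assumes a wheel built with this change, and that convert_to_onnx.py is among the scripts added in #17020):

```python
# Verify the packaged LLaMA tooling is importable from an installed wheel.
# Module path mirrors the copy destination above; the script name is an assumption.
import importlib

mod = importlib.import_module("onnxruntime.transformers.models.llama.convert_to_onnx")
print(mod.__file__)
```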
13 changes: 13 additions & 0 deletions cmake/onnxruntime_rocm_hipify.cmake
@@ -10,6 +10,9 @@ set(contrib_ops_excluded_files
"bert/attention_impl.cu"
"bert/attention_softmax.h"
"bert/attention_softmax.cu"
"bert/attention_prepare_qkv.cu"
"bert/decoder_attention_impl.h"
"bert/decoder_attention_impl.cu"
"bert/decoder_masked_multihead_attention.h"
"bert/decoder_masked_multihead_attention.cc"
"bert/decoder_masked_self_attention.h"
@@ -58,6 +61,16 @@ set(contrib_ops_excluded_files
"quantization/attention_quantization.h"
"quantization/attention_quantization_impl.cu"
"quantization/attention_quantization_impl.cuh"
"quantization/dequantize_blockwise.cuh"
"quantization/dequantize_blockwise.cu"
"quantization/dequantize_blockwise_bnb4.cuh"
"quantization/dequantize_blockwise_bnb4.cu"
"quantization/matmul_bnb4.cc"
"quantization/matmul_bnb4.cuh"
"quantization/matmul_bnb4.cu"
"quantization/matmul_nbits.cc"
"quantization/matmul_nbits.cuh"
"quantization/matmul_nbits.cu"
"quantization/quantize_dequantize_linear.cc"
"quantization/qordered_ops/qordered_attention_impl.cu"
"quantization/qordered_ops/qordered_attention_impl.h"
319 changes: 315 additions & 4 deletions docs/ContribOperators.md

Large diffs are not rendered by default.

9 changes: 9 additions & 0 deletions docs/OperatorKernels.md
@@ -453,9 +453,11 @@ Do not modify directly.*
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float)|
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)|
|MatMulFpQ4|*in* A:**T1**<br> *in* B:**T2**<br> *in* B_shape:**T3**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(int64)|
|MatMulInteger16|*in* A:**T1**<br> *in* B:**T2**<br> *out* Y:**T3**|1+|**T1** = tensor(int16)<br/> **T2** = tensor(int16)<br/> **T3** = tensor(int32)|
|MatMulIntegerToFloat|*in* A:**T1**<br> *in* B:**T2**<br> *in* a_scale:**T3**<br> *in* b_scale:**T3**<br> *in* a_zero_point:**T1**<br> *in* b_zero_point:**T2**<br> *in* bias:**T3**<br> *out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(int8), tensor(uint8)<br/> **T3** = tensor(float)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T2**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)|
|MaxpoolWithMask|*in* X:**T**<br> *in* M:**tensor(int32)**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float)|
|MurmurHash3|*in* X:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)<br/> **T2** = tensor(int32), tensor(uint32)|
@@ -475,9 +477,11 @@ Do not modify directly.*
|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(int8), tensor(uint8)|
|QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|Range|*in* start:**T**<br> *in* limit:**T**<br> *in* delta:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
|RotaryEmbedding|*in* input:**T**<br> *in* position_ids:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**|1+|**M** = tensor(int64)<br/> **T** = tensor(float)|
|SampleOp|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *in* seed:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)|
|SkipSimplifiedLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)|
|SparseToDenseMatMul|*in* A:**T**<br> *in* B:**T1**<br> *out* Y:**T1**|1+|**T** = sparse_tensor(double), sparse_tensor(float), sparse_tensor(int32), sparse_tensor(int64), sparse_tensor(uint32), sparse_tensor(uint64)<br/> **T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|Tokenizer|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(string)|
|TransposeMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
@@ -838,9 +842,12 @@ Do not modify directly.*
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
|GroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32), tensor(int64)<br/> **T** = tensor(float16)|
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Irfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|LongformerAttention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask:**T**<br> *in* global_weight:**T**<br> *in* global_bias:**T**<br> *in* global:**G**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
|MatMulNBits|*in* A:**T1**<br> *in* B:**T2**<br> *in* scales:**T1**<br> *in* zero_points:**T2**<br> *out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)<br/> **T2** = tensor(uint8)|
|MultiHeadAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**M**<br> *in* relative_position_bias:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)|
|NGramRepeatBlock|*in* input_ids:**Tid**<br> *in* scores:**T**<br> *out* scores_out:**T**|1+|**T** = tensor(float)<br/> **Tid** = tensor(int64)|
|NhwcConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
@@ -859,7 +866,9 @@ Do not modify directly.*
|RemovePadding|*in* input:**T**<br> *in* sequence_token_count:**M**<br> *out* output:**T**<br> *out* token_offset:**M**<br> *out* cumulated_seq_len:**M**<br> *out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)|
|RestorePadding|*in* input:**T**<br> *in* token_offset:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|Rfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|RotaryEmbedding|*in* input:**T**<br> *in* position_ids:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**|1+|**M** = tensor(int64)<br/> **T** = tensor(float), tensor(float16)|
|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *in* seed:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipGroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *in* skip:**T**<br> *in* bias:**T**<br> *out* Y:**T**<br> *out* S:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipSimplifiedLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|TransposeMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
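The RotaryEmbedding rows above document the new contrib op's signature (com.microsoft domain). A minimal sketch of constructing a matching node with the ONNX helper API; the shapes are illustrative (hidden_size 256, head_size 64, so the caches are max_seq x 32), not taken from this PR:

```python
import onnx
from onnx import TensorProto, helper

# RotaryEmbedding as documented: (input, position_ids, cos_cache, sin_cache) -> output.
node = helper.make_node(
    "RotaryEmbedding",
    inputs=["input", "position_ids", "cos_cache", "sin_cache"],
    outputs=["output"],
    domain="com.microsoft",
    interleaved=0,  # attribute read by the CPU kernel added below; default 0
)
graph = helper.make_graph(
    [node], "rotary_example",
    inputs=[
        helper.make_tensor_value_info("input", TensorProto.FLOAT, ["B", "S", 256]),
        helper.make_tensor_value_info("position_ids", TensorProto.INT64, ["B", "S"]),
        helper.make_tensor_value_info("cos_cache", TensorProto.FLOAT, ["M", 32]),
        helper.make_tensor_value_info("sin_cache", TensorProto.FLOAT, ["M", 32]),
    ],
    outputs=[helper.make_tensor_value_info("output", TensorProto.FLOAT, ["B", "S", 256])],
)
model = helper.make_model(
    graph,
    opset_imports=[helper.make_opsetid("", 17), helper.make_opsetid("com.microsoft", 1)],
)
onnx.save(model, "rotary_example.onnx")  # contrib ops execute under onnxruntime, not the ONNX reference runtime
```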
21 changes: 21 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -55,6 +55,7 @@ struct AttentionParameters {
int v_hidden_size; // hidden size of V
int v_head_size; // hidden size per head of V
int num_heads;
int num_splits;
bool is_unidirectional;
bool past_present_share_buffer;
bool do_rotary;
@@ -82,6 +83,26 @@ struct PackedAttentionParameters {
bool broadcast_res_pos_bias;
};

// Parameters deduced from node attributes and inputs/outputs.
struct GroupQueryAttentionParameters {
int batch_size;
int sequence_length;
int past_sequence_length; // actual sequence length of past_key and past_value
int kv_sequence_length; // sequence length of key and value (or new_k and new_v when past is present)
int present_sequence_length; // past_sequence_length + kv_sequence_length
int max_sequence_length; // allocated length of past_key and past_value
int hidden_size;
int num_heads;
int head_size;
int kv_hidden_size;
int kv_num_heads;
int num_splits; // number of splits for splitkv
bool is_unidirectional; // causal
float scale;
AttentionQkvFormat qkv_format;
AttentionQkvFormat past_kv_format;
};

namespace attention {
// Environment variable to enable or disable TRT fused self attention kernel. Default is 0 (enabled).
constexpr const char* kDisableFusedSelfAttention = "ORT_DISABLE_FUSED_ATTENTION";
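For orientation, the new GroupQueryAttentionParameters fields relate as follows (a Python sketch with made-up values; field names follow the struct above):

```python
# Illustrative GroupQueryAttention shape bookkeeping; all numbers are examples.
batch_size, sequence_length = 2, 8            # new query tokens this step
num_heads, kv_num_heads, head_size = 32, 8, 128
assert num_heads % kv_num_heads == 0          # each KV head serves num_heads // kv_num_heads query heads

past_sequence_length = 16                     # tokens already in past_key/past_value
kv_sequence_length = sequence_length          # new key/value tokens
present_sequence_length = past_sequence_length + kv_sequence_length  # 24, per the comment above

hidden_size = num_heads * head_size           # 4096 (query projection width)
kv_hidden_size = kv_num_heads * head_size     # 1024 (key/value projection width)
```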
1 change: 0 additions & 1 deletion onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
@@ -16,7 +16,6 @@

#include <unsupported/Eigen/SpecialFunctions>
#include <vector>
#include <iostream>

using onnxruntime::concurrency::ThreadPool;

14 changes: 11 additions & 3 deletions onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
@@ -205,6 +205,7 @@ Status CheckInputs(const T* query,
}
}

int total_sequence_length = past_sequence_length + kv_sequence_length;
AttentionMaskType mask_type = AttentionMaskType::MASK_NONE;
if (key_padding_mask != nullptr) {
mask_type = AttentionMaskType::MASK_UNKNOWN;
@@ -215,13 +216,21 @@
} else if (mask_dims[0] == static_cast<int64_t>(3) * static_cast<int64_t>(batch_size) + static_cast<int64_t>(2)) {
mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START;
}
} else if (mask_dims.size() == 2 && mask_dims[0] == static_cast<int64_t>(batch_size) && mask_dims[1] == static_cast<int64_t>(kv_sequence_length)) {
} else if (mask_dims.size() == 2 && mask_dims[0] == static_cast<int64_t>(batch_size) &&
mask_dims[1] == static_cast<int64_t>(kv_sequence_length)) {
mask_type = AttentionMaskType::MASK_2D_KEY_PADDING;
} else if (mask_dims.size() == 2 && mask_dims[0] == static_cast<int64_t>(batch_size) &&
mask_dims[1] == static_cast<int64_t>(total_sequence_length)) {
mask_type = AttentionMaskType::MASK_2D_KEY_PADDING;
} else if (mask_dims.size() == 3 && mask_dims[0] == static_cast<int64_t>(batch_size) &&
mask_dims[1] == static_cast<int64_t>(sequence_length) &&
mask_dims[2] == static_cast<int64_t>(total_sequence_length)) {
mask_type = AttentionMaskType::MASK_3D_ATTENTION;
}

if (mask_type == AttentionMaskType::MASK_UNKNOWN) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'key_padding_mask' shape shall be (batch_size) or (batch_size, kv_sequence_length)");
"Input 'key_padding_mask' shape shall be 1D, 2D, or 3D");
}
}

@@ -256,7 +265,6 @@ Status CheckInputs(const T* query,
}
}

int total_sequence_length = past_sequence_length + kv_sequence_length;
bool broadcast_res_pos_bias = false;
if (relative_position_bias != nullptr) {
const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims();
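The hunks above extend CheckInputs so key_padding_mask may also cover the total (past + new) sequence, in 2D or 3D form. A compact restatement of the accepted shapes (a sketch mirroring the C++ dispatch, not the shipped code):

```python
# B = batch_size, S = sequence_length, L = kv_sequence_length,
# T = past_sequence_length + kv_sequence_length (total_sequence_length).
def classify_key_padding_mask(dims, B, S, L, T):
    if dims == [B] or dims == [3 * B + 2]:
        return "1D"                 # per-batch key lengths, or packed lengths + starts
    if dims == [B, L] or dims == [B, T]:
        return "2D key padding"     # (B, T) is newly accepted by this change
    if dims == [B, S, T]:
        return "3D attention"       # newly accepted by this change
    return "unknown"                # rejected: "shape shall be 1D, 2D, or 3D"
```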
115 changes: 115 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
@@ -0,0 +1,115 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "contrib_ops/cpu/bert/rotary_embedding.h"
#include "contrib_ops/cpu/bert/rotary_embedding_helper.h"

#include "core/platform/threadpool.h"

using onnxruntime::concurrency::ThreadPool;
using namespace onnxruntime::contrib::rotary_embedding_helper;

[cpplint warning, GitHub Actions, line 10: Do not use namespace using-directives. Use using-declarations instead. [build/namespaces] [5]]

namespace onnxruntime {
namespace contrib {

// These ops are internal-only, so register outside of onnx
ONNX_OPERATOR_TYPED_KERNEL_EX(
RotaryEmbedding,
kMSDomain,
1,
float,
kCpuExecutionProvider,
KernelDefBuilder()
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("M", DataTypeImpl::GetTensorType<int64_t>()),
RotaryEmbedding<float>);

template <typename T>
RotaryEmbedding<T>::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) {
scale = info.GetAttrOrDefault<float>("scale", 1.0);
interleaved = (info.GetAttrOrDefault<int64_t>("interleaved", 0) == 1);
}

template <typename T>
Status RotaryEmbedding<T>::Compute(OpKernelContext* context) const {
const Tensor* input = context->Input<Tensor>(0);
const Tensor* position_ids = context->Input<Tensor>(1);
const Tensor* cos_cache = context->Input<Tensor>(2);
const Tensor* sin_cache = context->Input<Tensor>(3);

RotaryParameters parameters = {};
ORT_RETURN_IF_ERROR(rotary_embedding_helper::CheckInputs<Tensor>(input,
position_ids,
cos_cache,
sin_cache,
&parameters));

Tensor* output = context->Output(0, input->Shape());

if (parameters.sequence_length > parameters.max_sequence_length) {
// Launch update_cos_sin_cache kernel with scale
ORT_NOT_IMPLEMENTED("Updating cos_cache and sin_cache in RotaryEmbedding is not currently supported");
}

const T* input_src = input->Data<T>();
const int64_t* pos_ids_data = position_ids->Data<int64_t>();
const T* cos_cache_data = cos_cache->Data<T>();
const T* sin_cache_data = sin_cache->Data<T>();
T* output_dest = output->MutableData<T>();

const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int num_heads = parameters.num_heads;
const int head_size = parameters.head_size;
const int position_ids_format = parameters.position_ids_format;
const int half_head_size = head_size / 2;

AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
auto* tp = context->GetOperatorThreadPool();

const int loop_len = batch_size * sequence_length * num_heads;
const double cost = static_cast<double>(head_size);
ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) {
const int b = static_cast<int>((ptr / num_heads) / sequence_length);
const int s = static_cast<int>((ptr / num_heads) % sequence_length);
const int n = static_cast<int>(ptr % num_heads);

const int block_offset = b * sequence_length * num_heads + s * num_heads + n;
const int data_offset = block_offset * head_size;

const T* input_data = input_src + data_offset;
T* output_data = output_dest + data_offset;

// Cache is (M, H/2)
const int position_id = (position_ids_format == 0)
? static_cast<int>(pos_ids_data[0]) + s
: static_cast<int>(pos_ids_data[b * sequence_length + s]);
const int cache_offset = position_id * half_head_size;
const T* cos_data = cos_cache_data + cache_offset;
const T* sin_data = sin_cache_data + cache_offset;

int cache_idx = 0;
T sign = 0;
int j = 0;
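// Each element i is rotated against a partner j: adjacent pairs (2k, 2k+1) when interleaved, opposite halves (k, k + half_head_size) otherwise.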
for (int i = 0; i < head_size; i++) {
if (interleaved) {
cache_idx = (i / 2) % half_head_size;
sign = (i % 2 == 0) ? static_cast<T>(-1) : static_cast<T>(1);
j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign
} else {
cache_idx = i % half_head_size;
sign = (i < half_head_size) ? static_cast<T>(-1) : static_cast<T>(1);
j = (i + half_head_size) % head_size;
}
output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx];
}
}
});

return Status::OK();
}

} // namespace contrib
} // namespace onnxruntime
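For readers cross-checking the indexing, a NumPy transcription of the per-head rotation in Compute (a reference sketch for one head at one position, not the shipped kernel):

```python
import numpy as np

def rotary_embedding_ref(x, cos_row, sin_row, interleaved):
    """Rotate one head vector x (length head_size); cos_row/sin_row are the
    cache rows for this position (length head_size // 2). Mirrors the scalar
    loop in rotary_embedding.cc above."""
    head_size = x.shape[0]
    half = head_size // 2
    out = np.empty_like(x)
    for i in range(head_size):
        if interleaved:
            cache_idx = (i // 2) % half
            sign = -1.0 if i % 2 == 0 else 1.0
            j = i + 1 if i % 2 == 0 else i - 1  # i - sign
        else:
            cache_idx = i % half
            sign = -1.0 if i < half else 1.0
            j = (i + half) % head_size
        out[i] = x[i] * cos_row[cache_idx] + sign * x[j] * sin_row[cache_idx]
    return out
```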