diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index c60b25f3418f6..0048190f9063b 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -180,8 +180,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Attention mask with shape (batch_size, 1, max_sequence_length, max_sequence_length), (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length), or index with shape (batch_size) or (2 * batch_size) or (3 * batch_size + 2)
past (optional) : T
past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size)When past_present_share_buffer is set, its shape is (2, batch_size, num_heads, max_sequence_length, head_size)
-relative_position_bias (optional) : T
-additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)
+attention_bias (optional) : T
+additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)
past_sequence_length (optional) : M
When past_present_share_buffer is used, it is required to specify past_sequence_length (could be 0).
@@ -1166,7 +1166,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Value with shape (batch_size, 1, v_hidden_size) for self attention or past_value with shape (batch_size, num_heads, kv_sequence_length, head_size) for cross attention
mask_index (optional) : M
Mask values of shape (batch_size, total_sequence_length) or (batch_size, kv_sequence_length)
-relative_position_bias (optional) : T
+attention_bias (optional) : T
additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)
past_key (optional) : T
past state for key with shape (batch_size, num_heads, past_sequence_length, head_size) for self attentionWhen past_present_share_buffer is set, its shape is (batch_size, num_heads, max_sequence_length, head_size). The keys buffer is re-ordered in such a way that its virtual sub-tensor of shape (batch_size, num_heads, max_sequence_length, head_size) which may be perceived as being of shape (batch_size, num_heads, max_sequence_length, head_size / x, x) is reordered to become (batch_size, num_heads, head_size / x, max_sequence_length, x) where `x = 16 / sizeof(T)`.
@@ -1256,8 +1256,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Mask values of shape (batch_size, total_sequence_length)
past : T
past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size)When past_present_share_buffer is set, its shape is (2, batch_size, num_heads, max_sequence_length, head_size). The first `batch_size * num_heads * max_sequence_length * head_size` elements correspond to keys and the next `batch_size * num_heads * max_sequence_length * head_size` elements correspond to values. The keys buffer is re-ordered in such a way that its virtual sub-tensor of shape (batch_size, num_heads, max_sequence_length, head_size) which may be perceived as being of shape (batch_size, num_heads, max_sequence_length, head_size / x, x) is reordered to become (batch_size, num_heads, head_size / x, max_sequence_length, x) where `x = 16 / sizeof(T)`.
-relative_position_bias (optional) : T
-additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)
+attention_bias (optional) : T
+additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)
past_sequence_length : M
When past_present_share_buffer is used, it is required to specify past_sequence_length (could be 0).
beam_width (optional) : M
@@ -3202,8 +3202,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection
key_padding_mask (optional) : M
Key padding mask with shape (batch_size), (3 * batch_size + 2), (batch_size, kv_sequence_length), (batch_size, total_sequence_length), or (batch_size, sequence_length, total_sequence_length)
-relative_position_bias (optional) : T
-relative position bias: addition to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length) or (1, num_heads, sequence_length, total_sequence_length)
+attention_bias (optional) : T
+bias added to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)
past_key (optional) : T
past state for self attention key with shape (batch_size, num_heads, past_sequence_length, head_size)
past_value (optional) : T
@@ -3516,8 +3516,8 @@ This version of the operator has been available since version 1 of the 'com.micr
In packing mode, it specifies the offset of each token(batch_size, sequence_length).
cumulative_sequence_length : M
A tensor with shape (batch_size + 1). It specifies the cumulative sequence length.
-relative_position_bias (optional) : T
-A tensor with shape (batch_size, num_heads, sequence_length, sequence_length)or (1, num_heads, sequence_length, sequence_length).It specifies the additional bias to QxK'
+attention_bias (optional) : T
+A tensor with shape (batch_size or 1, num_heads or 1, sequence_length, sequence_length). It specifies the additional bias to QxK'
#### Outputs
@@ -3591,8 +3591,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Offset of each token before packing, with shape (batch_size, sequence_length).
cumulative_sequence_length : M
A tensor with shape (batch_size + 1). It specifies the cumulative sequence length.
-relative_position_bias (optional) : T
-It specifies the additional bias to QxK'. The shape is (batch_size, num_heads, sequence_length, sequence_length) or (1, num_heads, sequence_length, sequence_length)
+attention_bias (optional) : T
+It specifies the additional bias to QxK'. The shape is (batch_size or 1, num_heads or 1, sequence_length, sequence_length)
#### Outputs
@@ -4468,7 +4468,7 @@ This version of the operator has been available since version 1 of the 'com.micr
left-side padding, mask_index has shape (2 * batch_size), where the values are the exclusive end positions followed by
the inclusive start positions. When unidirectional is 1, and each token only attend to previous tokens. For GPT-2, both past
and present state are optional. Present state could appear in output even when past state is not in input.
- Current version does not support past/present, relative_position_bias and qkv_hidden_sizes.
+ Current version does not support past/present, attention_bias and qkv_hidden_sizes.
TODO: Support them if needed in the future.
#### Version
@@ -4533,8 +4533,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Attention mask with shape (batch_size, 1, max_sequence_length, max_sequence_length), (batch_size, past_sequence_length + sequence_length)or (batch_size, sequence_length, past_sequence_length + sequence_length), or index with shape (batch_size) or (2 * batch_size).
past (optional) : Q
past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size).
-relative_position_bias (optional) : S
-additional add to QxK' with shape (batch_size, num_heads, sequence_length, sequence_length).
+attention_bias (optional) : S
+additional add to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length).
#### Outputs
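
> The documentation changes above rename `relative_position_bias` to `attention_bias` and relax its first two dimensions to `batch_size or 1` and `num_heads or 1`. As a rough illustration of the intended semantics (a NumPy sketch written for this review, not code from the PR), the bias is broadcast over batch and heads and added to the scaled QxK' scores before softmax:

```python
# Illustrative sketch only: shows how an attention_bias of shape
# (batch_size or 1, num_heads or 1, S, T) is applied to the attention scores.
import numpy as np

B, N, S, T, H = 2, 4, 3, 5, 8
rng = np.random.default_rng(0)
q = rng.standard_normal((B, N, S, H)).astype(np.float32)
k = rng.standard_normal((B, N, T, H)).astype(np.float32)

# Dim 0 and/or dim 1 may be 1; NumPy broadcasting mirrors the documented behavior.
attention_bias = rng.standard_normal((1, N, S, T)).astype(np.float32)

scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(H)   # (B, N, S, T), scaled QxK'
scores = scores + attention_bias                    # broadcast over batch (and heads if dim 1 == 1)
probs = np.exp(scores - scores.max(-1, keepdims=True))
probs /= probs.sum(-1, keepdims=True)               # softmax over total_sequence_length
```
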
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index f0aa332ff39eb..96173b5a4ea4a 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -460,7 +460,7 @@ Do not modify directly.*
| |
| |
|**Operator Domain:** *com.microsoft*||||
-|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* relative_position_bias:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float)|
+|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* attention_bias:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float)|
|AttnLSTM|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*in* QW:**T**
*in* MW:**T**
*in* V:**T**
*in* M:**T**
*in* memory_seq_lens:**T1**
*in* AW:**T**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|1+|**T** = tensor(double), tensor(float)
**T1** = tensor(int32)|
|BeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**|1+|**T** = tensor(float)|
|BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float)|
@@ -490,7 +490,7 @@ Do not modify directly.*
|MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float)|
|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float)
**T2** = tensor(uint8)
**T3** = tensor(float), tensor(uint8)
**T4** = tensor(int32)|
|MaxpoolWithMask|*in* X:**T**
*in* M:**tensor(int32)**
*out* Y:**T**|1+|**T** = tensor(float)|
-|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float)|
+|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float)|
|MurmurHash3|*in* X:**T1**
*out* Y:**T2**|1+|**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(string), tensor(uint32), tensor(uint64)
**T2** = tensor(int32), tensor(uint32)|
|NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)|
|NhwcMaxPool|*in* x:**T**
*out* y:**T**|1+|**T** = tensor(int8), tensor(uint8)|
@@ -848,7 +848,7 @@ Do not modify directly.*
| |
| |
|**Operator Domain:** *com.microsoft*||||
-|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* relative_position_bias:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float), tensor(float16)|
+|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* attention_bias:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float), tensor(float16)|
|BeamSearch|*in* input_ids:**F**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*in* prefix_vocab_mask:**M**
*in* attention_mask:**I**
*in* decoder_input_ids:**I**
*in* logits_processor:**I**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**|1+|**T** = tensor(float), tensor(float16)|
|BiasAdd|*in* X:**T**
*in* bias:**T**
*in* skip:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|BiasDropout|*in* data:**T**
*in* bias:**T**
*in* residual:**T**
*in* ratio:**T1**
*in* training_mode:**T2**
*out* output:**T**
*out* mask:**T2**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T1** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)
**T2** = tensor(bool)|
@@ -861,8 +861,8 @@ Do not modify directly.*
|ComplexMulConj|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float), tensor(float16)|
|ConvTransposeWithDynamicPads|*in* X:**T**
*in* W:**T**
*in* Pads:**tensor(int64)**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)|
|DecoderAttention|*in* query:**T**
*in* key:**T**
*in* q_weight:**T**
*in* kv_weight:**T**
*in* bias:**T**
*in* key_padding_mask:**B**
*in* key_cache:**T**
*in* value_cache:**T**
*in* static_kv:**B**
*in* use_past:**B**
*in* has_layer_state:**B**
*in* has_key_padding_mask:**B**
*out* output:**T**
*out* new_key_cache:**T**
*out* new_value_cache:**T**|1+|**T** = tensor(float), tensor(float16)|
-|DecoderMaskedMultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* mask_index:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* beam_width:**M**
*in* cache_indirection:**M**
*in* bias:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* qk:**V**|1+|**T** = tensor(float), tensor(float16)|
-|DecoderMaskedSelfAttention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* relative_position_bias:**T**
*in* past_sequence_length:**M**
*in* beam_width:**M**
*in* cache_indirection:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float), tensor(float16)|
+|DecoderMaskedMultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* mask_index:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* past_sequence_length:**M**
*in* beam_width:**M**
*in* cache_indirection:**M**
*in* bias:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**
*out* qk:**V**|1+|**T** = tensor(float), tensor(float16)|
+|DecoderMaskedSelfAttention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* attention_bias:**T**
*in* past_sequence_length:**M**
*in* beam_width:**M**
*in* cache_indirection:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float), tensor(float16)|
|DequantizeLinear|*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(float16)|
|DequantizeWithOrder|*in* input:**Q**
*in* scale_input:**S**
*out* output:**F**|1+|**F** = tensor(float), tensor(float16)
**Q** = tensor(int8)
**S** = tensor(float)|
|DynamicTimeWarping|*in* input:**F**
*out* output:**I**|1+|**F** = tensor(float)
**I** = tensor(int32)|
@@ -884,14 +884,14 @@ Do not modify directly.*
|MatMulBnb4|*in* A:**T1**
*in* B:**T2**
*in* absmax:**T1**
*out* Y:**T1**|1+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)
**T2** = tensor(uint8)|
|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)|
|MoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T**
*in* fc3_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
-|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)|
+|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**T** = tensor(float), tensor(float16)|
|NGramRepeatBlock|*in* input_ids:**Tid**
*in* scores:**T**
*out* scores_out:**T**|1+|**T** = tensor(float)
**Tid** = tensor(int64)|
|NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
-|PackedAttention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* token_offset:**M**
*in* cumulative_sequence_length:**M**
*in* relative_position_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
-|PackedMultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* token_offset:**M**
*in* cumulative_sequence_length:**M**
*in* relative_position_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
+|PackedAttention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* token_offset:**M**
*in* cumulative_sequence_length:**M**
*in* attention_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
+|PackedMultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* token_offset:**M**
*in* cumulative_sequence_length:**M**
*in* attention_bias:**T**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|QAttention|*in* input:**T1**
*in* weight:**T2**
*in* bias:**T3**
*in* input_scale:**T3**
*in* weight_scale:**T3**
*in* mask_index:**T4**
*in* input_zero_point:**T1**
*in* weight_zero_point:**T2**
*in* past:**T3**
*out* output:**T3**
*out* present:**T3**|1+|**T1** = tensor(int8)
**T2** = tensor(int8)
**T3** = tensor(float), tensor(float16)
**T4** = tensor(int32)|
|QMoE|*in* input:**T**
*in* router_probs:**T**
*in* fc1_experts_weights:**T1**
*in* fc1_scales:**T**
*in* fc1_experts_bias:**T**
*in* fc2_experts_weights:**T1**
*in* fc2_scales:**T**
*in* fc2_experts_bias:**T**
*in* fc3_experts_weights:**T1**
*in* fc3_scales:**T**
*in* fc3_experts_bias:**T**
*out* output:**T**|1+|**T** = tensor(float16)
**T1** = tensor(uint8)|
-|QOrderedAttention|*in* input:**Q**
*in* scale_input:**S**
*in* scale_Q_gemm:**S**
*in* scale_K_gemm:**S**
*in* scale_V_gemm:**S**
*in* Q_weight:**Q**
*in* K_weight:**Q**
*in* V_weight:**Q**
*in* scale_Q_weight:**S**
*in* scale_K_weight:**S**
*in* scale_V_weight:**S**
*in* Q_bias:**S**
*in* K_bias:**S**
*in* V_bias:**S**
*in* scale_QKT_gemm:**S**
*in* scale_QKT_softmax:**S**
*in* scale_values_gemm:**S**
*in* mask_index:**G**
*in* past:**Q**
*in* relative_position_bias:**S**
*out* output:**Q**|1+|**G** = tensor(int32)
**Q** = tensor(int8)
**S** = tensor(float)|
+|QOrderedAttention|*in* input:**Q**
*in* scale_input:**S**
*in* scale_Q_gemm:**S**
*in* scale_K_gemm:**S**
*in* scale_V_gemm:**S**
*in* Q_weight:**Q**
*in* K_weight:**Q**
*in* V_weight:**Q**
*in* scale_Q_weight:**S**
*in* scale_K_weight:**S**
*in* scale_V_weight:**S**
*in* Q_bias:**S**
*in* K_bias:**S**
*in* V_bias:**S**
*in* scale_QKT_gemm:**S**
*in* scale_QKT_softmax:**S**
*in* scale_values_gemm:**S**
*in* mask_index:**G**
*in* past:**Q**
*in* attention_bias:**S**
*out* output:**Q**|1+|**G** = tensor(int32)
**Q** = tensor(int8)
**S** = tensor(float)|
|QOrderedGelu|*in* X:**Q**
*in* scale_X:**S**
*in* scale_Y:**S**
*out* Y:**Q**|1+|**Q** = tensor(int8)
**S** = tensor(float)|
|QOrderedLayerNormalization|*in* X:**Q**
*in* scale_X:**S**
*in* scale:**F**
*in* B:**F**
*in* scale_Y:**S**
*out* Y:**Q**|1+|**F** = tensor(float), tensor(float16)
**Q** = tensor(int8)
**S** = tensor(float)|
|QOrderedLongformerAttention|*in* input:**Q**
*in* scale_input:**S**
*in* weight:**Q**
*in* scale_weight:**S**
*in* bias:**S**
*in* scale_bias:**S**
*in* scale_qkv_gemm:**S**
*in* mask:**F**
*in* global_weight:**Q**
*in* scale_global_weight:**S**
*in* global_bias:**S**
*in* scale_global_gemm:**S**
*in* global:**G**
*in* scale_output:**S**
*out* output:**Q**|1+|**F** = tensor(float16)
**G** = tensor(int32)
**Q** = tensor(int8)
**S** = tensor(float)|
@@ -1296,7 +1296,7 @@ Do not modify directly.*
| |
| |
|**Operator Domain:** *com.microsoft*||||
-|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* relative_position_bias:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)|
+|Attention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* attention_bias:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)|
|BiasAdd|*in* X:**T**
*in* bias:**T**
*in* skip:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float), tensor(float16)|
|BiasSplitGelu|*in* X:**T**
*in* bias:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
@@ -1312,7 +1312,7 @@ Do not modify directly.*
|GroupQueryAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* past_key:**T**
*in* past_value:**T**
*in* seqlens_k:**M**
*in* total_sequence_length:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)|
|MatMulIntegerToFloat|*in* A:**T1**
*in* B:**T2**
*in* a_scale:**T3**
*in* b_scale:**T3**
*in* a_zero_point:**T1**
*in* b_zero_point:**T2**
*in* bias:**T3**
*out* Y:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float), tensor(float16)|
|MatMulNBits|*in* A:**T1**
*in* B:**T2**
*in* scales:**T1**
*in* zero_points:**T3**
*in* g_idx:**T4**
*in* bias:**T1**
*out* Y:**T1**|1+|**T1** = tensor(float), tensor(float16)
**T2** = tensor(uint8)|
-|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* relative_position_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)|
+|MultiHeadAttention|*in* query:**T**
*in* key:**T**
*in* value:**T**
*in* bias:**T**
*in* key_padding_mask:**M**
*in* attention_bias:**T**
*in* past_key:**T**
*in* past_value:**T**
*out* output:**T**
*out* present_key:**T**
*out* present_value:**T**|1+|**M** = tensor(int32)
**T** = tensor(float), tensor(float16)|
|NhwcConv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
|QAttention|*in* input:**T1**
*in* weight:**T2**
*in* bias:**T3**
*in* input_scale:**T3**
*in* weight_scale:**T3**
*in* mask_index:**T4**
*in* input_zero_point:**T1**
*in* weight_zero_point:**T2**
*in* past:**T3**
*out* output:**T3**
*out* present:**T3**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(int8), tensor(uint8)
**T3** = tensor(float), tensor(float16)
**T4** = tensor(int32)|
|QLinearAdd|*in* A:**T**
*in* A_scale:**tensor(float)**
*in* A_zero_point:**T**
*in* B:**T**
*in* B_scale:**tensor(float)**
*in* B_zero_point:**T**
*in* C_scale:**tensor(float)**
*in* C_zero_point:**T**
*out* C:**T**|1+|**T** = tensor(int8), tensor(uint8)|
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts
index 0008fd1aff62e..8840ef97b4279 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts
@@ -101,7 +101,7 @@ const validateAttentionInputs = (inputs: readonly TensorView[], attributes: Atte
// bias (Q/K/V) : (D + D + D_v)
// mask_index : see below
// past (K/V) : (2, B, N, P, H) or NULL
- // relative_position_bias : (B, N, S, T) or NULL
+ // attention_bias : (B, N, S, T) or NULL
// For mask_index, the following shapes are supported:
// NULL, (B, 1), (1, 1)
@@ -118,10 +118,10 @@ const validateAttentionInputs = (inputs: readonly TensorView[], attributes: Atte
const bias = inputs[2];
const maskIndex = inputs[3];
const past = inputs[4];
- const relativePositionBias = inputs[5];
+ const attentionBias = inputs[5];
- if (past && relativePositionBias) {
- throw new Error('Attention cannot have both past and relative_position_bias');
+ if (past && attentionBias) {
+ throw new Error('Attention cannot have both past and attention_bias');
}
if (input.dims.length !== 3) {
@@ -217,6 +217,22 @@ const validateAttentionInputs = (inputs: readonly TensorView[], attributes: Atte
throw new Error('past is not supported');
}
+ if (attentionBias) {
+ if (attentionBias.dims.length !== 4) {
+ throw new Error('Input "attention_bias" must have 4 dimensions');
+ }
+
+ // TODO: support broadcasting the first and second dimensions of attention_bias
+ if (
+ attentionBias.dims[0] !== batchSize ||
+ attentionBias.dims[1] !== attributes.numHeads ||
+ attentionBias.dims[2] !== sequenceLength ||
+ attentionBias.dims[3] !== totalSequenceLength
+ ) {
+ throw new Error('Expect "attention_bias" shape (batch_size, num_heads, sequence_length, total_sequence_length)');
+ }
+ }
+
return {
batchSize,
sequenceLength,
@@ -348,7 +364,7 @@ const createAttentionProbsProgramInfo = (
q: TensorView,
key: TensorView,
pastKey: TensorView | undefined,
- relativePositionBias: TensorView | undefined,
+ attentionBias: TensorView | undefined,
parameters: AttentionParameters,
attributes: AttentionAttrs,
pastSequenceLength: number,
@@ -385,7 +401,7 @@ const createAttentionProbsProgramInfo = (
if (pastKey) {
inputDependencies.push('type');
}
- if (relativePositionBias) {
+ if (attentionBias) {
inputDependencies.push('type');
}
const outputs = [{ dims: probsShape, dataType: q.dataType, gpuDataType: GpuDataType.default }];
@@ -400,8 +416,8 @@ const createAttentionProbsProgramInfo = (
const pastKeyInput = inputVariable('past_key', pastKey.dataType, pastKey.dims, components);
inputVars.push(pastKeyInput);
}
- if (relativePositionBias) {
- inputVars.push(inputVariable('relative_position_bias', relativePositionBias.dataType, relativePositionBias.dims));
+ if (attentionBias) {
+ inputVars.push(inputVariable('attention_bias', attentionBias.dataType, attentionBias.dims));
}
const output = outputVariable('output', q.dataType, probsShape);
const outputVars = [output];
@@ -491,7 +507,7 @@ const createAttentionProbsProgramInfo = (
}
})()};
output[outputIdx] = ${output.type.value} (sum * uniforms.alpha) + ${
- relativePositionBias ? 'relative_position_bias[outputIdx]' : '0.0'
+ attentionBias ? 'attention_bias[outputIdx]' : '0.0'
};
}
}`;
@@ -499,7 +515,7 @@ const createAttentionProbsProgramInfo = (
return {
name: 'AttentionProbs',
shaderCache: {
- hint: `${components};${relativePositionBias !== undefined};${pastKey !== undefined};${context.outputCount}`,
+ hint: `${components};${attentionBias !== undefined};${pastKey !== undefined};${context.outputCount}`,
inputDependencies,
},
getRunData: () => ({ outputs, dispatchGroup: dispatch, programUniforms }),
@@ -648,7 +664,7 @@ export const applyAttention = (
_past: TensorView | undefined,
pastKey: TensorView | undefined,
pastValue: TensorView | undefined,
- relativePositionBias: TensorView | undefined,
+ attentionBias: TensorView | undefined,
parameters: AttentionParameters,
attributes: AttentionAttrs,
) => {
@@ -657,8 +673,8 @@ export const applyAttention = (
const totalSequenceLength = pastSequenceLength + parameters.kvSequenceLength;
const inputsK = parameters.kvNumHeads === undefined && outputCount > 1 && pastKey ? [q, k, pastKey] : [q, k];
- if (relativePositionBias) {
- inputsK.push(relativePositionBias);
+ if (attentionBias) {
+ inputsK.push(attentionBias);
}
// Run AttentionProbs
@@ -668,7 +684,7 @@ export const applyAttention = (
q,
k,
outputCount > 1 ? pastKey : undefined,
- relativePositionBias,
+ attentionBias,
parameters,
attributes,
pastSequenceLength,
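
> Context for the `attention.ts` change above: the `AttentionProbs` shader adds `attention_bias[outputIdx]` using the same flattened index as the output element, which is only valid when the bias shape equals (batch_size, num_heads, sequence_length, total_sequence_length) exactly; that is why the validation added earlier in this file rejects broadcast shapes for now. A small NumPy sketch (illustrative, names are mine) of that flat-index correspondence:

```python
# Sketch: flat indexing into the bias matches the 4-D elementwise add only
# when the bias has the full (B, N, S, T) shape, i.e. no broadcasting.
import numpy as np

B, N, S, T = 2, 3, 4, 5
scores = np.random.default_rng(1).standard_normal((B, N, S, T)).astype(np.float32)
bias = np.random.default_rng(2).standard_normal((B, N, S, T)).astype(np.float32)

flat = scores.reshape(-1) + bias.reshape(-1)        # what the shader does per outputIdx
assert np.allclose(flat.reshape(B, N, S, T), scores + bias)
```
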
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts
index 1e0902eb0ff56..72e09303ba76f 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts
@@ -26,53 +26,60 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr
const value = getInput(inputs, 2);
const bias = getInput(inputs, 3);
const keyPaddingMask = getInput(inputs, 4);
- const relativePositionBias = getInput(inputs, 5);
+ const attentionBias = getInput(inputs, 5);
const pastKey = getInput(inputs, 6);
const pastValue = getInput(inputs, 7);
- // Abbreviation and Meanings:
- // B: batch_size
- // S: sequence_length (input sequence length of query)
- // P: past_sequence_length (past sequence length of key or value)
- // L: kv_sequence_length (input sequence length of key or value)
- // M: max_sequence_length
- // T: total_sequence_length = past_sequence_length + kv_sequence_length
- // N: num_heads
- // H: head size for Q and K, aka q_head_size or k_head_size or qk_head_size
- // H_v: v_head_size
- // D_i: input hidden size
- // D: hidden size for Q and K (D = N * H), aka q_hidden_size or k_hidden_size or qk_hidden_size
- // D_v: v_hidden_size = num_heads * v_head_size
-
- // key_padding_mask (K/V) : (B) or (2*B + 1) or (B, L) or None
- // relative_position_bias : (B, 1, S, L)
- // past_key : (B, N, S*, H)
- // past_value : (B, N, S*, H)
- // When no packing for q/k/v:
+ // ---------------------------------------------------------------
+ // Notations:
+ // B: batch_size
+ // N: num_heads
+ // H: head_size of Q and K
+ // H_v: head_size of V
+ // D: hidden_size for Q and K, where D = N * H
+ // D_v: hidden_size of V, where D_v = N * H_v
+ // S: q_sequence_length
+ // P: past_sequence_length of kv cache
+ // L: kv_sequence_length
+ // T: total_sequence_length = P + L
+ // M: max_sequence_length of kv cache when past and present share buffer
+ // ---------------------------------------------------------------
+ // MultiHeadAttention inputs:
+ // ---------------------------------------------------------------
+ // Q_K_V_BSNH - no packing:
// query (Q) : (B, S, D)
- // key (K) : (B, L, D) or (B, N, S*, H)
- // value (V) : (B, L, D_v) or (B, N, S*, H)
- // bias (Q/K/V) : (D + D + D_v)
- // When packed kv is used:
+ // key (K) : (B, L, D)
+ // value (V) : (B, L, D_v)
+ // Q_K_V_BSNH_BNSH_BNSH - cross attention (kv cache is not used, L == T, D == D_v):
// query (Q) : (B, S, D)
- // key (K) : (B, L, N, 2, H)
- // value (V) : None
- // bias (Q/K/V) : None
- // When packed qkv is used:
- // query (Q) : (B, L, N, 3, H) or (B, S, 3*D)
- // key (K) : None
- // value (V) : None
+ // key (K) : (B, N, L, H)
+ // value (V) : (B, N, L, H_v)
+ // Q_KV_BSNH_BSN2H - packed kv (kv cache is not used, bias is not allowed for packed kv):
+ // query (Q) : (B, S, D)
+ // key (K/V) : (B, L, N, 2, H)
+ // value : None
+ // QKV_BSN3H - packed qkv (kv cache is not used, S == L, D == D_v):
+ // query (Q/K/V) : (B, S, N, 3, H)
+ // key : None
+ // value : None
+ //
+ // Other inputs:
// bias (Q/K/V) : None or (D + D + D_v)
+ // key_padding_mask (K/V) : (B) or (3 * B + 2) or (B, T) or (B, S, T)
+ // attention_bias : None or (B, N, S, T), (1, N, S, T), (B, 1, S, T) or (1, 1, S, T)
+ // past_key : (B, N, P, H) or None. Past state is only allowed for Q_K_V_BSNH.
+ // past_value : (B, N, P, H) or None. Past state is only allowed for Q_K_V_BSNH.
+ //
+ // Not Supported:
+ // key_padding_mask, packed kv, packed qkv, and broadcast for attention_bias.
if (query.dims.length !== 3 && query.dims.length !== 5) {
throw new Error('Input query is expected to have 3 or 5 dimensions');
}
- const dmmhaPacking = false;
const batchSize = query.dims[0];
const sequenceLength = query.dims[1];
- const hiddenSize =
- query.dims.length === 3 ? (dmmhaPacking ? query.dims[2] / 3 : query.dims[2]) : attributes.numHeads * query.dims[4];
+ const hiddenSize = query.dims.length === 3 ? query.dims[2] : attributes.numHeads * query.dims[4];
let kvSequenceLength = sequenceLength;
let pastSequenceLength = 0;
@@ -137,15 +144,15 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr
throw new Error('Expect "key" shape (batch_size, num_heads, kv_sequence_length, head_size) for past_key');
}
- qkvFormat = AttentionQkvFormat.unknown;
+ qkvFormat = AttentionQkvFormat.unknown; // Q_K_V_BSNH_BNSH_BNSH
kvSequenceLength = key.dims[2];
}
} else {
// packed QKV
- if (query.dims.length !== 3 && query.dims.length !== 5) {
- throw new Error('Input "query" is expected to have 3 or 5 dimensions when key is empty');
+ if (query.dims.length !== 5) {
+ throw new Error('Input "query" is expected to have 5 dimensions when key is empty');
}
- if (query.dims.length === 5 && (query.dims[2] !== attributes.numHeads || query.dims[3] !== 3)) {
+ if (query.dims[2] !== attributes.numHeads || query.dims[3] !== 3) {
throw new Error('Expect "query" shape (batch_size, kv_sequence_length, num_heads, 3, head_size) for packed kv');
}
@@ -157,13 +164,15 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr
throw new Error('Input "bias" is expected to have 1 dimension');
}
- if (value) {
- if (query.dims.length === 5 && query.dims[3] === 2) {
+ if (key) {
+ if (key.dims.length === 5 && key.dims[3] === 2) {
throw new Error('bias is not allowed for packed kv.');
}
}
}
+ const totalSequenceLength = pastSequenceLength + kvSequenceLength;
+
let maskType: AttentionMaskType = AttentionMaskType.none;
if (keyPaddingMask) {
maskType = AttentionMaskType.maskUnknown;
@@ -174,11 +183,11 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr
} else if (maskDims[0] === 3 * batchSize + 2) {
maskType = AttentionMaskType.mask1DKeySeqLenStart;
}
- } else if (maskDims.length === 2 && maskDims[0] === batchSize && maskDims[1] === kvSequenceLength) {
+ } else if (maskDims.length === 2 && maskDims[0] === batchSize && maskDims[1] === totalSequenceLength) {
maskType = AttentionMaskType.mask2dKeyPadding;
}
if (maskType === AttentionMaskType.maskUnknown) {
- throw new Error('Input "key_padding_mask" shape shall be (batch_size) or (batch_size, kv_sequence_length)');
+ throw new Error('Input "key_padding_mask" shape shall be (batch_size) or (batch_size, total_sequence_length)');
}
throw new Error('Mask not supported');
}
@@ -200,32 +209,34 @@ const validateInputs = (inputs: readonly TensorView[], attributes: AttentionAttr
}
vHiddenSize = value.dims[2];
} else {
+ // Q_K_V_BSNH_BNSH_BNSH
if (kvSequenceLength !== value.dims[2]) {
- throw new Error('Input "past_key" and "past_value" shall have the same dim 2 (kv_sequence_length)');
+ throw new Error('Input "key" and "value" shall have the same dim 2 (kv_sequence_length)');
}
vHiddenSize = value.dims[1] * value.dims[3];
passPastInKv = true;
}
}
- const totalSequenceLength = pastSequenceLength + kvSequenceLength;
const broadcastResPosBias = false;
if (keyPaddingMask) {
throw new Error('Key padding mask is not supported');
}
- if (relativePositionBias) {
- if (relativePositionBias.dims.length !== 4) {
- throw new Error('Input "relative_position_bias" is expected to have 4 dimensions');
+ if (attentionBias) {
+ if (attentionBias.dims.length !== 4) {
+ throw new Error('Input "attention_bias" is expected to have 4 dimensions');
}
+
+ // TODO: support broadcasting the first and second dimensions of attention_bias.
if (
- (relativePositionBias.dims[0] !== batchSize && relativePositionBias.dims[0] !== 1) ||
- relativePositionBias.dims[1] !== attributes.numHeads ||
- relativePositionBias.dims[2] !== sequenceLength ||
- relativePositionBias.dims[3] !== totalSequenceLength
+ attentionBias.dims[0] !== batchSize ||
+ attentionBias.dims[1] !== attributes.numHeads ||
+ attentionBias.dims[2] !== sequenceLength ||
+ attentionBias.dims[3] !== totalSequenceLength
) {
- throw new Error('Input "relative_position_bias" shape (batch_size, 1, sequence_length, kv_sequence_length)');
+ throw new Error('Expect "attention_bias" shape (batch_size, num_heads, sequence_length, total_sequence_length)');
}
}
@@ -360,7 +371,7 @@ export const multiHeadAttention = (context: ComputeContext, attributes: Attentio
const value = getInput(context.inputs, 2);
const bias = getInput(context.inputs, 3);
const keyPaddingMask = getInput(context.inputs, 4);
- const relativePositionBias = getInput(context.inputs, 5);
+ const attentionBias = getInput(context.inputs, 5);
const pastKey = getInput(context.inputs, 6);
const pastValue = getInput(context.inputs, 7);
if (query.dims.length === 5) {
@@ -395,7 +406,7 @@ export const multiHeadAttention = (context: ComputeContext, attributes: Attentio
undefined,
pastKey,
pastValue,
- relativePositionBias,
+ attentionBias,
params,
attributes,
);
@@ -425,17 +436,5 @@ export const multiHeadAttention = (context: ComputeContext, attributes: Attentio
2 * params.hiddenSize,
);
- applyAttention(
- context,
- Q,
- K,
- V,
- keyPaddingMask,
- undefined,
- pastKey,
- pastValue,
- relativePositionBias,
- params,
- attributes,
- );
+ applyAttention(context, Q, K, V, keyPaddingMask, undefined, pastKey, pastValue, attentionBias, params, attributes);
};
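
> The rewritten `validateInputs` above now checks `attention_bias` strictly against (batch_size, num_heads, sequence_length, total_sequence_length) until broadcasting is implemented in the WebGPU kernels. A hedged Python sketch of the equivalent check (names are illustrative, not the TypeScript identifiers):

```python
# Sketch of the shape validation performed for attention_bias in the WebGPU path.
def check_attention_bias(dims, batch_size, num_heads, sequence_length, total_sequence_length):
    if len(dims) != 4:
        raise ValueError('Input "attention_bias" is expected to have 4 dimensions')
    # Broadcasting of dims 0 and 1 is not supported by the WebGPU kernel yet,
    # so every dimension must match exactly.
    if tuple(dims) != (batch_size, num_heads, sequence_length, total_sequence_length):
        raise ValueError('Expect "attention_bias" shape '
                         '(batch_size, num_heads, sequence_length, total_sequence_length)')

check_attention_bias((2, 8, 4, 6), batch_size=2, num_heads=8,
                     sequence_length=4, total_sequence_length=6)
```
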
diff --git a/js/web/test/data/ops/multihead-attention.jsonc b/js/web/test/data/ops/multihead-attention.jsonc
index 6ce6a5e0a8ce6..ed937a22c0b84 100644
--- a/js/web/test/data/ops/multihead-attention.jsonc
+++ b/js/web/test/data/ops/multihead-attention.jsonc
@@ -228,7 +228,7 @@
"data": null,
"type": "int32"
},
- // RelativePositionBias
+ // AttentionBias
{
"data": null,
"type": "float32"
@@ -293,7 +293,7 @@
"data": null,
"type": "int32"
},
- // RelativePositionBias
+ // AttentionBias
{
"data": null,
"type": "float32"
@@ -322,7 +322,7 @@
]
},
{
- "name": "MultiHeadAttention Basic, one head and head-size=1 with optional RelativePositionBias, pastKey, pastValue inputs and optional presentKey, presentValue outputs",
+ "name": "MultiHeadAttention Basic, one head and head-size=1 with optional AttentionBias, pastKey, pastValue inputs and optional presentKey, presentValue outputs",
"operator": "MultiHeadAttention",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
@@ -358,7 +358,7 @@
"data": null,
"type": "int32"
},
- // RelativePositionBias
+ // AttentionBias
{
"data": null,
"type": "float32"
@@ -397,7 +397,7 @@
]
},
{
- "name": "MultiHeadAttention Basic, one head and head-size=4 with RelativePositionBias, pastKey, pastValue inputs and optional presentKey, presentValue outputs",
+ "name": "MultiHeadAttention Basic, one head and head-size=4 with attentionBias, pastKey, pastValue inputs and optional presentKey, presentValue outputs",
"operator": "MultiHeadAttention",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
@@ -433,7 +433,7 @@
"data": null,
"type": "int32"
},
- // RelativePositionBias
+ // AttentionBias
{
"data": null,
"type": "float32"
@@ -474,7 +474,7 @@
]
},
{
- "name": "MultiHeadAttention Basic, one head and head-size=1 with relativePositionBias, pastKey and pastValue",
+ "name": "MultiHeadAttention Basic, one head and head-size=1 with attentionBias, pastKey and pastValue",
"operator": "MultiHeadAttention",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
@@ -510,7 +510,7 @@
"data": null,
"type": "int32"
},
- // RelativePositionBias
+ // AttentionBias
{
"data": [10, 20],
"dims": [1, 1, 1, 2],
@@ -540,7 +540,7 @@
]
},
{
- "name": "MultiHeadAttention Basic, one head and head-size=4 with relativePositionBias, and pastValue",
+ "name": "MultiHeadAttention Basic, one head and head-size=4 with attentionBias, and pastValue",
"operator": "MultiHeadAttention",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
@@ -576,7 +576,7 @@
"data": null,
"type": "int32"
},
- // RelativePositionBias
+ // AttentionBias
{
"data": [100, 200],
"dims": [1, 1, 1, 2],
@@ -642,7 +642,7 @@
"data": null,
"type": "int32"
},
- // RelativePositionBias
+ // AttentionBias
{
"data": null,
"type": "float32"
@@ -717,7 +717,7 @@
"data": null,
"type": "int32"
},
- // RelativePositionBias
+ // AttentionBias
{
"data": null,
"type": "float32"
@@ -767,7 +767,7 @@
]
},
{
- "name": "MultiHeadAttention Basic, one head and head-size one with RelativePositionBias, pastKey, pastValue, presentKey and presentValue",
+ "name": "MultiHeadAttention Basic, one head and head-size one with attentionBias, pastKey, pastValue, presentKey and presentValue",
"operator": "MultiHeadAttention",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
@@ -803,7 +803,7 @@
"data": null,
"type": "int32"
},
- // RelativePositionBias
+ // AttentionBias
{
"data": [10, 20],
"dims": [1, 1, 1, 2],
@@ -843,7 +843,7 @@
]
},
{
- "name": "MultiHeadAttention Basic, one head and head-size=4 with RelativePositionBias, PastKey, PastValue inputs and PresentKey and PresentValue outputs",
+ "name": "MultiHeadAttention Basic, one head and head-size=4 with attentionBias, PastKey, PastValue inputs and PresentKey and PresentValue outputs",
"operator": "MultiHeadAttention",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
@@ -879,7 +879,7 @@
"data": null,
"type": "int32"
},
- // RelativePositionBias
+ // AttentionBias
{
"data": [100, 200],
"dims": [1, 1, 1, 2],
@@ -957,7 +957,7 @@
"data": null,
"type": "int32"
},
- // RelativePositionBias
+ // AttentionBias
{
"data": [10, 20],
"dims": [1, 1, 1, 2],
@@ -1033,7 +1033,7 @@
"data": null,
"type": "int32"
},
- // RelativePositionBias
+ // AttentionBias
{
"data": [50, 100],
"dims": [1, 1, 1, 2],
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc
index 768676259aa14..ad14fb8258656 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc
@@ -198,7 +198,7 @@ Status Attention::Compute(OpKernelContext* context) const {
const Tensor* mask_index = context->Input(3);
const Tensor* past = context->Input(4);
- const Tensor* relative_position_bias = context->Input(5);
+ const Tensor* attention_bias = context->Input(5);
const TensorShape& weights_shape = (weights ? weights->Shape() : weight_shape_);
@@ -208,7 +208,7 @@ Status Attention::Compute(OpKernelContext* context) const {
bias->Shape(),
mask_index,
past,
- relative_position_bias,
+ attention_bias,
¶meters));
if (parameters.do_rotary) {
@@ -338,7 +338,7 @@ Status Attention::Compute(OpKernelContext* context) const {
output, nullptr /* present_key */, nullptr /* present_value */,
batch_size, sequence_length, sequence_length,
parameters.head_size, parameters.v_head_size, parameters.v_hidden_size,
- relative_position_bias, context);
+ attention_bias, context);
}
} // namespace contrib
} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc
index f7d8fedc734e4..52dcb990ab67f 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc
@@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include "contrib_ops/cpu/bert/attention_base.h"
+#include "contrib_ops/cpu/bert/multihead_attention_helper.h"
#include "core/providers/common.h"
namespace onnxruntime {
@@ -12,7 +13,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
const TensorShape& bias_shape,
const Tensor*& mask_index,
const Tensor* past,
- const Tensor* relative_position_bias,
+ const Tensor* attention_bias,
void* parameters,
const Tensor* past_seq_len) const {
// Abbreviation and Meanings:
@@ -37,7 +38,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
// bias (Q/K/V) : (D + D + D_v)
// mask_index : see below
// past (K/V) : (2, B, N, P, H) or NULL
- // relative_position_bias : (B, N, S, T) or NULL
+ // attention_bias : (B or 1, N or 1, S, T) or NULL
// For mask_index, the following shapes are supported:
// NULL, (B, 1), (1, 1)
@@ -49,9 +50,9 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
// When a model is pruned (like some attention heads are removed in Q/K/V), input_hidden_size could be larger
// than hidden dimension of Q, K and V.
- if (past != nullptr && relative_position_bias != nullptr) {
- // past is used on GPT-2 model with past state, we don't have a case for relative position bias yet
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Attention cannot have both past and relative_position_bias");
+ if (past != nullptr && attention_bias != nullptr) {
+ // past is used on GPT-2 model with past state, we don't have a case for attention bias yet
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Attention cannot have both past and attention_bias");
}
const auto& dims = input_shape.GetDims();
@@ -191,39 +192,12 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
}
}
- bool broadcast_res_pos_bias = false;
- if (relative_position_bias != nullptr) {
- const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims();
+ gsl::span<const int64_t> attention_bias_dims;
+ if (attention_bias != nullptr) {
+ attention_bias_dims = attention_bias->Shape().GetDims();
- if (relative_position_bias_dims.size() != 4) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'relative_position_bias' is expected to have 4 dimensions, got ",
- relative_position_bias_dims.size());
- }
-
- if (relative_position_bias_dims[0] != batch_size && relative_position_bias_dims[0] != 1) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'relative_position_bias' dimension 0 should be same as batch_size or 1, got ",
- relative_position_bias_dims[0]);
- }
- if (relative_position_bias_dims[0] == 1) {
- broadcast_res_pos_bias = true;
- }
- if (relative_position_bias_dims[1] != num_heads_) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'relative_position_bias' dimension 1 should be same as number of heads, got ",
- relative_position_bias_dims[1]);
- }
- if (relative_position_bias_dims[2] != sequence_length) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'relative_position_bias' dimension 2 should be same as sequence_length, got ",
- relative_position_bias_dims[2]);
- }
- if (relative_position_bias_dims[3] != total_sequence_length) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'relative_position_bias' dimension 3 should be same as total_sequence_length, got ",
- relative_position_bias_dims[3]);
- }
+ ORT_RETURN_IF_ERROR(multihead_attention_helper::CheckAttentionBias(
+ attention_bias_dims, batch_size, num_heads_, sequence_length, total_sequence_length));
}
if (past != nullptr && past_present_share_buffer_) {
@@ -257,7 +231,8 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
output_parameters->mask_filter_value = mask_filter_value_;
output_parameters->scale = scale_;
output_parameters->mask_type = mask_type;
- output_parameters->broadcast_res_pos_bias = broadcast_res_pos_bias;
+ output_parameters->broadcast_attn_bias_dim_0 = attention_bias_dims.size() > 0 && attention_bias_dims[0] == 1;
+ output_parameters->broadcast_attn_bias_dim_1 = attention_bias_dims.size() > 1 && attention_bias_dims[1] == 1;
output_parameters->qkv_format = Q_K_V_BNSH;
}
@@ -329,7 +304,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
const TensorShape& bias_shape,
const Tensor*& mask_index,
const Tensor* past,
- const Tensor* relative_position_bias,
+ const Tensor* attention_bias,
void* parameters,
const int max_threads_per_block,
const Tensor* past_seq_len) const {
@@ -337,7 +312,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block);
}
- return CheckInputs(input_shape, weights_shape, bias_shape, mask_index, past, relative_position_bias, parameters, past_seq_len);
+ return CheckInputs(input_shape, weights_shape, bias_shape, mask_index, past, attention_bias, parameters, past_seq_len);
}
Tensor* AttentionBase::GetPresent(OpKernelContext* context,
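
> The `CheckInputs` change above replaces the single `broadcast_res_pos_bias` flag with two flags derived from the first two dimensions of `attention_bias`. A minimal Python sketch of that derivation (illustrative only, not the C++ code):

```python
# Sketch: derive the two broadcast flags from the attention_bias shape.
def broadcast_flags(attention_bias_dims):
    # attention_bias has shape (B or 1, N or 1, S, T); an empty tuple means no bias input.
    dim0 = len(attention_bias_dims) > 0 and attention_bias_dims[0] == 1
    dim1 = len(attention_bias_dims) > 1 and attention_bias_dims[1] == 1
    return dim0, dim1

assert broadcast_flags(()) == (False, False)           # no attention_bias input
assert broadcast_flags((1, 8, 4, 6)) == (True, False)  # broadcast over batch only
assert broadcast_flags((2, 1, 4, 6)) == (False, True)  # broadcast over heads only
```
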
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_base.h
index a6782daa58f1a..05756cd54d842 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.h
@@ -18,7 +18,7 @@ class AttentionBase {
const TensorShape& bias_shape,
const Tensor*& mask_index, // Dummy mask of shape (1 or batch_size, 1) will be updated to nullptr.
const Tensor* past,
- const Tensor* relative_position_bias,
+ const Tensor* attention_bias,
void* parameters,
const int max_threads_per_block, // for CUDA
const Tensor* past_seq_len = nullptr) const;
@@ -63,7 +63,7 @@ class AttentionBase {
const TensorShape& bias_shape,
const Tensor*& mask_index, // Dummy mask of shape (1 or batch_size, 1) will be updated to nullptr.
const Tensor* past,
- const Tensor* relative_position_bias,
+ const Tensor* attention_bias,
void* parameters,
const Tensor* past_seq_len = nullptr) const;
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h
index 88127387d08ea..5a5899166f5ba 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -2,6 +2,7 @@
// Licensed under the MIT License.
#pragma once
+#include
namespace onnxruntime {
namespace contrib {
@@ -68,7 +69,8 @@ struct AttentionParameters {
bool is_unidirectional;
bool past_present_share_buffer;
bool do_rotary;
- bool broadcast_res_pos_bias;
+ bool broadcast_attn_bias_dim_0;
+ bool broadcast_attn_bias_dim_1;
float mask_filter_value;
float scale;
bool use_tf32;
@@ -88,8 +90,8 @@ struct PackedAttentionParameters {
int num_heads;
float scale;
int token_count;
- bool has_relative_position_bias;
- bool broadcast_res_pos_bias;
+ bool broadcast_attn_bias_dim_0;
+ bool broadcast_attn_bias_dim_1;
bool use_tf32;
};
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
index dd52001c2ac6b..ae2eaf0204026 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h
@@ -19,23 +19,23 @@ class AttentionCPUBase : public AttentionBase {
: AttentionBase(info, require_same_hidden_size) {}
template <typename T>
- Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH
- const T* K, // K data with shape BxNxLxH
- const T* V, // V value with size BxNxLxH_v
- const Tensor* mask_index, // mask index. nullptr if no mask or its size is B
- const Tensor* past, // past state
- const Tensor* past_key, // past K input tensor (if not using past state)
- const Tensor* past_value, // past V input tensor (if not using past state)
- Tensor* output, // output tensor
- Tensor* present_key, // present K output tensor (if separating present KV)
- Tensor* present_value, // present V output tensor (if separating present KV)
- int batch_size, // batch size (B)
- int sequence_length, // sequence length of Q (S)
- int kv_sequence_length, // sequence length of K or V (L)
- int qk_head_size, // head size of Q or K (H)
- int v_head_size, // head size of V (H_v)
- int v_hidden_size, // hidden size of V (D_v)
- const Tensor* relative_position_bias, // bias addition in QK. Its size is BxNxSxT
+ Status ApplyAttention(const T* Q, // Q data with shape BxNxSxH
+ const T* K, // K data with shape BxNxLxH
+ const T* V, // V value with size BxNxLxH_v
+ const Tensor* mask_index, // mask index. nullptr if no mask or its size is B
+ const Tensor* past, // past state
+ const Tensor* past_key, // past K input tensor (if not using past state)
+ const Tensor* past_value, // past V input tensor (if not using past state)
+ Tensor* output, // output tensor
+ Tensor* present_key, // present K output tensor (if separating present KV)
+ Tensor* present_value, // present V output tensor (if separating present KV)
+ int batch_size, // batch size (B)
+ int sequence_length, // sequence length of Q (S)
+ int kv_sequence_length, // sequence length of K or V (L)
+ int qk_head_size, // head size of Q or K (H)
+ int v_head_size, // head size of V (H_v)
+ int v_hidden_size, // hidden size of V (D_v)
+ const Tensor* attn_bias, // additive bias applied on scaled QK.
OpKernelContext* context) const {
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
@@ -66,10 +66,14 @@ class AttentionCPUBase : public AttentionBase {
gsl::span<const int64_t> mask_index_dims = mask_index != nullptr
? mask_index->Shape().GetDims()
: gsl::span<const int64_t>{};
+ DUMP_CPU_TENSOR_INIT();
+ DUMP_CPU_TENSOR("Mask", mask_index_data, mask_index_dims);
+
if (mask_data != nullptr) {
+ // Convert mask from boolean (0/1) to float (mask_filter_value/0.0f).
+ // Merge padding mask with causal mask, and broadcast to 3D (BxSxT).
PrepareMask(mask_index_data, mask_index_dims, static_cast<T*>(mask_data),
causal, batch_size, sequence_length, past_sequence_length, mask_filter_value_);
- DUMP_CPU_TENSOR_INIT();
DUMP_CPU_TENSOR("Mask3D", static_cast(mask_data), batch_size, sequence_length, total_sequence_length);
}
@@ -82,10 +86,8 @@ class AttentionCPUBase : public AttentionBase {
const T* past_value_data = past_value != nullptr ? past_value->Data() : nullptr;
T* present_value_data = present_value != nullptr ? present_value->MutableData() : nullptr;
- const T* relative_position_bias_data = nullptr;
- if (relative_position_bias != nullptr) {
- relative_position_bias_data = relative_position_bias->Data();
- }
+ const T* attn_bias_data = (attn_bias != nullptr) ? attn_bias->Data() : nullptr;
+ auto attn_bias_dims = (attn_bias != nullptr) ? attn_bias->Shape().GetDims() : gsl::span<const int64_t>{};
// Compute the attention score.
size_t bytes = SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * total_sequence_length * sizeof(T);
@@ -95,7 +97,7 @@ class AttentionCPUBase : public AttentionBase {
static_cast<T*>(mask_data),
batch_size, sequence_length, kv_sequence_length, past_sequence_length,
qk_head_size == 0 ? v_head_size : qk_head_size, past_data, past_key_data,
- present_data, present_key_data, tp, scale, relative_position_bias_data);
+ present_data, present_key_data, tp, scale, attn_bias_data, attn_bias_dims);
// Compute the attentionScore * Value: out_tmp(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
auto out_tmp_data =
@@ -115,22 +117,23 @@ class AttentionCPUBase : public AttentionBase {
// 1 x mask_data(B, N, S, T)
// attention_probs(B, N, S, T) = Softmax(attention_probs)
template <typename T>
- void ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT
- const T* Q, // Q data. Its size is BxNxSxH
- const T* K, // k data. Its size is BxNxLxH
- T* mask_data, // buffer for mask data.
- int batch_size, // batch size of self-attention
- int sequence_length, // sequence length of self-attention (S)
- int kv_sequence_length, // sequence length of cross-attention (L)
- int past_sequence_length, // sequence length of past state
- int head_size, // head size of self-attention
- const T* past, // past state
- const T* past_key, // past key only (if not using past state)
- T* present, // present state
- T* present_key, // present key only (if not using present state)
- ThreadPool* tp, // thread pool
- float scale, // scale factor
- const T* relative_position_bias_data // bias addition matrix with shape BxNxSxT
+ void ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT
+ const T* Q, // Q data. Its size is BxNxSxH
+ const T* K, // k data. Its size is BxNxLxH
+ T* mask_data, // buffer for mask data.
+ int batch_size, // batch size of self-attention
+ int sequence_length, // sequence length of self-attention (S)
+ int kv_sequence_length, // sequence length of cross-attention (L)
+ int past_sequence_length, // sequence length of past state
+ int head_size, // head size of self-attention
+ const T* past, // past state
+ const T* past_key, // past key only (if not using past state)
+ T* present, // present state
+ T* present_key, // present key only (if not using present state)
+ ThreadPool* tp, // thread pool
+ float scale, // scale factor
+ const T* attn_bias_data, // attention bias
+ gsl::span<const int64_t> attn_bias_dims // attention bias shape
) const {
const int total_sequence_length = past_sequence_length + kv_sequence_length; // T = P + L
const size_t past_chunk_length = static_cast<size_t>(past_sequence_length) * head_size; // P x H
@@ -138,14 +141,20 @@ class AttentionCPUBase : public AttentionBase {
const size_t kv_input_chunk_length = static_cast<size_t>(kv_sequence_length) * head_size; // L x H
const size_t present_chunk_length = past_chunk_length + kv_input_chunk_length; // T x H
+ DUMP_CPU_TENSOR_INIT();
+ DUMP_CPU_TENSOR("Q", Q, batch_size, num_heads_, sequence_length, head_size);
+ DUMP_CPU_TENSOR("K", K, batch_size, num_heads_, total_sequence_length, head_size);
+ DUMP_CPU_TENSOR("Attn_Bias", attn_bias_data, attn_bias_dims);
+
{
const int loop_len = batch_size * num_heads_;
const float alpha = scale;
TensorOpCost unit_cost;
- const ptrdiff_t probs_matrix_bytes = SafeInt<ptrdiff_t>(sequence_length) * total_sequence_length * sizeof(T);
+ const ptrdiff_t probs_matrix_size = SafeInt<ptrdiff_t>(sequence_length) * total_sequence_length;
+ const ptrdiff_t probs_matrix_bytes = probs_matrix_size * sizeof(T);
unit_cost.compute_cycles =
- static_cast(SafeInt(2) * sequence_length * head_size * total_sequence_length);
+ static_cast(SafeInt(2) * head_size * probs_matrix_size);
unit_cost.bytes_loaded = static_cast((sequence_length + total_sequence_length) * head_size * sizeof(T));
unit_cost.bytes_stored = static_cast(probs_matrix_bytes);
@@ -160,8 +169,8 @@ class AttentionCPUBase : public AttentionBase {
unit_cost.bytes_stored += bytes_to_copy_key;
}
- if (relative_position_bias_data != nullptr) {
-      unit_cost.compute_cycles += static_cast<double>(sequence_length * total_sequence_length);
+ if (attn_bias_data != nullptr) {
+      unit_cost.compute_cycles += static_cast<double>(probs_matrix_size);
unit_cost.bytes_loaded += probs_matrix_bytes * 2;
unit_cost.bytes_stored += probs_matrix_bytes;
}
@@ -169,13 +178,34 @@ class AttentionCPUBase : public AttentionBase {
ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
for (std::ptrdiff_t i = begin; i != end; ++i) {
const int batch_index = static_cast<int>(i) / num_heads_;
+        const std::ptrdiff_t head_index = i % static_cast<std::ptrdiff_t>(num_heads_);
+
+        const ptrdiff_t output_offset = SafeInt<ptrdiff_t>(i) * probs_matrix_size;
+        const ptrdiff_t mask_offset = SafeInt<ptrdiff_t>(batch_index) * probs_matrix_size;
-        const ptrdiff_t output_offset = SafeInt<ptrdiff_t>(i) * sequence_length * total_sequence_length;
-        const ptrdiff_t mask_offset = SafeInt<ptrdiff_t>(batch_index) * sequence_length * total_sequence_length;
T* output = attention_probs + output_offset;
- // Broadcast mask data: (Bx)SxT -> (BxNx)SxT
- if (mask_data != nullptr) {
+ if (attn_bias_data != nullptr) {
+ // Attention bias has shape (B or 1, N or 1, S, T)
+ // Here we handle the broadcast of batch_size and num_heads dimensions.
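+          // For example, with a bias of shape (1, N, S, T) only the head term below contributes,
+          // and with (1, 1, S, T) the same S x T matrix is reused for every batch and head.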
+ ptrdiff_t attn_bias_offset = 0;
+ if (attn_bias_dims[0] != 1) {
+            attn_bias_offset += SafeInt<ptrdiff_t>(batch_index) * num_heads_ * probs_matrix_size;
+ }
+ if (attn_bias_dims[1] != 1) {
+ attn_bias_offset += head_index * probs_matrix_size;
+ }
+
+ memcpy(output, attn_bias_data + attn_bias_offset, probs_matrix_bytes);
+
+ if (mask_data != nullptr) {
+ // This can be optimized with vectorized add using MlasAddFloat32x4.
+ for (ptrdiff_t j = 0; j < probs_matrix_size; j++) {
+ output[j] += mask_data[mask_offset + j];
+ }
+ }
+ } else if (mask_data != nullptr) {
+ // Broadcast mask data: (Bx)SxT -> (BxNx)SxT
memcpy(output, mask_data + mask_offset, probs_matrix_bytes);
}
@@ -193,20 +223,13 @@ class AttentionCPUBase : public AttentionBase {
// B: K' (B x N x) T x H (B x N x) H x T H x T
// C: attention_probs (B x N x) S x T (B x N x) S x T S x T
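+          // When the mask or attention bias has been copied into output above, beta is 1.0f below so the
+          // GEMM of Q x K' accumulates onto it; otherwise beta is 0.0f and output is simply overwritten.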
math::Gemm<T, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_sequence_length, head_size, alpha,
- Q + q_input_chunk_length * i, k, mask_data != nullptr ? 1.0f : 0.0f, output,
- nullptr);
-
- if (relative_position_bias_data != nullptr) {
- for (int j = 0; j < sequence_length * total_sequence_length; j++) {
- output[j] += relative_position_bias_data[output_offset + j];
- }
- }
+ Q + q_input_chunk_length * i, k,
+ (mask_data != nullptr || attn_bias_data != nullptr) ? 1.0f : 0.0f,
+ output, nullptr);
}
});
}
- DUMP_CPU_TENSOR_INIT();
- DUMP_CPU_TENSOR("Q", Q, batch_size, num_heads_, sequence_length, head_size);
DUMP_CPU_TENSOR("QK (scaled)", attention_probs, batch_size, num_heads_, sequence_length, total_sequence_length);
// attention_probs(B, N, S, T) = Softmax(attention_probs)
diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
index 0d77376779230..ca818f09c4b1e 100644
--- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
@@ -57,7 +57,7 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const {
const Tensor* value = context->Input<Tensor>(2);
const Tensor* bias = context->Input<Tensor>(3);
const Tensor* key_padding_mask = context->Input<Tensor>(4);
-  const Tensor* extra_add_qk = context->Input<Tensor>(5);
+  const Tensor* attn_bias = context->Input<Tensor>(5);
const Tensor* past_key = context->Input<Tensor>(6);
const Tensor* past_value = context->Input<Tensor>(7);
@@ -75,7 +75,7 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const {
value,
bias,
key_padding_mask,
- extra_add_qk,
+ attn_bias,
past_key,
past_value,
nullptr,
@@ -135,7 +135,7 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const {
value->Data<T>(),
key_padding_mask, nullptr /* past */, past_key, past_value, output, present_k, present_v,
batch_size, q_sequence_length, kv_sequence_length,
- qk_head_size, v_head_size, v_hidden_size, extra_add_qk, context);
+ qk_head_size, v_head_size, v_hidden_size, attn_bias, context);
}
OrtValue K;
@@ -149,7 +149,7 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const {
!disable_flash_ &&
!is_unidirectional_ &&
key_padding_mask == nullptr &&
- extra_add_qk == nullptr &&
+ attn_bias == nullptr &&
past_key == nullptr &&
past_value == nullptr &&
present_k == nullptr &&
@@ -215,7 +215,7 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const {
V.GetMutable<Tensor>()->MutableData<T>(),
key_padding_mask, nullptr /* past */, past_key, past_value, output, present_k, present_v,
batch_size, q_sequence_length, kv_sequence_length,
- qk_head_size, v_head_size, v_hidden_size, extra_add_qk, context);
+ qk_head_size, v_head_size, v_hidden_size, attn_bias, context);
}
} // namespace contrib
} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
index cfb8d36843777..0cfe90963c334 100644
--- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
+++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
@@ -179,39 +179,35 @@ Status CheckPast(const T* past_key, const T* past_value, const T* past_seq_len,
return Status::OK();
}
-template <typename T>
-Status CheckRelativePositionBias(
- const T* relative_position_bias, int batch_size, int num_heads, int sequence_length, int total_sequence_length,
- bool& broadcast_res_pos_bias) {
- const auto& relative_position_bias_dims = relative_position_bias->Shape().GetDims();
-
- if (relative_position_bias_dims.size() != 4) {
+inline Status CheckAttentionBias(
+    const gsl::span<const int64_t>& attention_bias_dims,
+ int64_t batch_size, int64_t num_heads, int64_t sequence_length, int64_t total_sequence_length) {
+ if (attention_bias_dims.size() != 4) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'relative_position_bias' is expected to have 4 dimensions, got ",
- relative_position_bias_dims.size());
+ "Input 'attention_bias' is expected to have 4 dimensions, got ",
+ attention_bias_dims.size());
}
- if (relative_position_bias_dims[0] != batch_size && relative_position_bias_dims[0] != 1) {
+
+ if (attention_bias_dims[0] != batch_size && attention_bias_dims[0] != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'relative_position_bias' dimension 0 should be batch_size or 1, got ",
- relative_position_bias_dims[0]);
+ "Input 'attention_bias' dimension 0 should be batch_size or 1, got ",
+ attention_bias_dims[0]);
}
- if (relative_position_bias_dims[0] == 1) {
- broadcast_res_pos_bias = true;
- }
- if (relative_position_bias_dims[1] != num_heads) {
+
+ if (attention_bias_dims[1] != num_heads && attention_bias_dims[1] != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'relative_position_bias' dimension 1 should be same as number of heads, got ",
- relative_position_bias_dims[1]);
+ "Input 'attention_bias' dimension 1 should be same as number of heads or 1, got ",
+ attention_bias_dims[1]);
}
- if (relative_position_bias_dims[2] != sequence_length) {
+ if (attention_bias_dims[2] != sequence_length) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'relative_position_bias' dimension 2 should be same as sequence_length, got ",
- relative_position_bias_dims[2]);
+ "Input 'attention_bias' dimension 2 should be same as sequence_length, got ",
+ attention_bias_dims[2]);
}
- if (relative_position_bias_dims[3] != total_sequence_length) {
+ if (attention_bias_dims[3] != total_sequence_length) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
- "Input 'relative_position_bias' dimension 3 should be same as total_sequence_length, got ",
- relative_position_bias_dims[3]);
+ "Input 'attention_bias' dimension 3 should be same as total_sequence_length, got ",
+ attention_bias_dims[3]);
}
return Status::OK();
}
@@ -243,7 +239,7 @@ Status CheckInputs(const T* query,
const T* value,
const T* bias,
const T* key_padding_mask,
- const T* relative_position_bias,
+ const T* attention_bias,
const T* past_key,
const T* past_value,
const T* past_seq_len,
@@ -258,13 +254,15 @@ Status CheckInputs(const T* query,
// Notations:
// B: batch_size
// N: num_heads
- // H: head_size (V might have different head size than Q and K)
- // D: hidden_size = N * H
+ // H: head_size of Q and K.
+ // H_v: head_size of V.
+ // D: hidden_size of Q and K, where D = N * H
+ // D_v: hidden_size of V, where D_v = N * H_v
// S: q_sequence_length
- // P: past_sequence_length
+ // P: past_sequence_length of kv cache
// L: kv_sequence_length
// T: total_sequence_length = P + L
- // M: max_sequence_length
+ // M: max_sequence_length of kv cache when past and present share buffer
// ---------------------------------------------------------------
// MultiHeadAttention inputs:
// ---------------------------------------------------------------
@@ -275,7 +273,7 @@ Status CheckInputs(const T* query,
// Q_K_V_BSNH_BNSH_BNSH - cross attention (kv cache is not used, L == T, D == D_v):
// query (Q) : (B, S, D)
// key (K) : (B, N, L, H)
- // value (V) : (B, N, L, H)
+ // value (V) : (B, N, L, H_v)
// Q_KV_BSNH_BSN2H - packed kv (kv cache is not used, bias is not allowed for packed kv):
// query (Q) : (B, S, D)
// key (K/V) : (B, L, N, 2, H)
@@ -288,7 +286,7 @@ Status CheckInputs(const T* query,
// Other inputs:
// bias (Q/K/V) : None or (D + D + D_v)
// key_padding_mask (K/V) : (B) or (3 * B + 2) or (B, T) or (B, S, T)
- // relative_position_bias : (B, N, S, T) or (1, N, S, T)
+ // attention_bias : (B, N, S, T), (1, N, S, T), (B, 1, S, T) or (1, 1, S, T)
// past_key : (B, N, P, H) or None. Past state is only allowed for Q_K_V_BSNH.
// past_value : (B, N, P, H) or None. Past state is only allowed for Q_K_V_BSNH.
// ---------------------------------------------------------------
@@ -298,7 +296,7 @@ Status CheckInputs(const T* query,
// query (Q) : (B, S, D)
// key (K) : (B, L, D)
// value (V) : (B, L, D)
- // Q_K_V_BSNH_BNSH_BNSH - cross attention (kv cache and relative_position_bias are not used. L == T):
+ // Q_K_V_BSNH_BNSH_BNSH - cross attention (kv cache and attention_bias are not used. L == T):
// query (Q) : (B, S, D)
// key (K) : (B, N, L, H)
// value (V) : (B, N, L, H)
@@ -310,7 +308,7 @@ Status CheckInputs(const T* query,
// Other inputs:
// bias (Q/K/V) : None or (3 * D)
// key_padding_mask (K/V) : None or (B, T)
- // relative_position_bias : (1, N, S, T), or (B, N, S, T) where only 1 x N x S x T data is used in CUDA.
+ // attention_bias : (1, N, S, T), or (B, N, S, T) where only 1 x N x S x T data is used in CUDA.
//
// The following inputs are not used in cross attention (so they are None for cross attention):
// past_key : (B, N, P, H), or (B, N, M, H) when past_present_share_buffer is True.
@@ -401,10 +399,11 @@ Status CheckInputs(const T* query,
}
}
- bool broadcast_res_pos_bias = false;
- if (relative_position_bias != nullptr) {
- ORT_RETURN_IF_ERROR(CheckRelativePositionBias(
- relative_position_bias, batch_size, num_heads, sequence_length, total_sequence_length, broadcast_res_pos_bias));
+  gsl::span<const int64_t> attention_bias_dims;
+ if (attention_bias != nullptr) {
+ attention_bias_dims = attention_bias->Shape().GetDims();
+ ORT_RETURN_IF_ERROR(CheckAttentionBias(
+ attention_bias_dims, batch_size, num_heads, sequence_length, total_sequence_length));
}
assert(qkv_format != UNKNOWN);
@@ -428,7 +427,8 @@ Status CheckInputs(const T* query,
output_parameters->mask_filter_value = mask_filter_value;
output_parameters->mask_type = mask_type;
output_parameters->scale = scale;
- output_parameters->broadcast_res_pos_bias = broadcast_res_pos_bias;
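+    // A size of 1 in the attention bias means that dimension is broadcast (dim 0: batch, dim 1: heads).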
+ output_parameters->broadcast_attn_bias_dim_0 = attention_bias_dims.size() > 0 && attention_bias_dims[0] == 1;
+ output_parameters->broadcast_attn_bias_dim_1 = attention_bias_dims.size() > 1 && attention_bias_dims[1] == 1;
output_parameters->qkv_format = qkv_format;
}
@@ -441,7 +441,7 @@ Status CheckInputs(const T* query,
const T* value,
const T* bias,
const T* key_padding_mask,
- const T* relative_position_bias,
+ const T* attention_bias,
const T* past_key,
const T* past_value,
const T* past_seq_len,
@@ -457,7 +457,7 @@ Status CheckInputs(const T* query,
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads should be no larger than ", max_threads_per_block);
}
- return CheckInputs(query, key, value, bias, key_padding_mask, relative_position_bias, past_key, past_value,
+ return CheckInputs(query, key, value, bias, key_padding_mask, attention_bias, past_key, past_value,
past_seq_len, parameters, num_heads, mask_filter_value, scale, is_unidirectional,
past_present_share_buffer, operator_type);
}
diff --git a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc
index 6201b892a89b0..2c897f183164f 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc
+++ b/onnxruntime/contrib_ops/cpu/quantization/attention_quant.cc
@@ -160,7 +160,7 @@ Status QAttention::Compute(OpKernelContext* context) const {
bias->Shape(),
mask_index,
past_tensor,
- nullptr, // relative_position_bias
+ nullptr, // attention_bias
nullptr // parameters
));
diff --git a/onnxruntime/contrib_ops/cpu/utils/console_dumper.h b/onnxruntime/contrib_ops/cpu/utils/console_dumper.h
index 2782a59d4326d..12cbc5049a02a 100644
--- a/onnxruntime/contrib_ops/cpu/utils/console_dumper.h
+++ b/onnxruntime/contrib_ops/cpu/utils/console_dumper.h
@@ -32,6 +32,11 @@ class IConsoleDumper {
virtual void Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2, int dim3) const = 0;
virtual void Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2, int dim3) const = 0;
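+  // Print a tensor given its shape as a span of dims (1-D to 5-D; see PrintTensorByDims below).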
+  virtual void Print(const char* name, const int32_t* tensor, gsl::span<const int64_t>& dims) const = 0;
+  virtual void Print(const char* name, const int64_t* tensor, gsl::span<const int64_t>& dims) const = 0;
+  virtual void Print(const char* name, const float* tensor, gsl::span<const int64_t>& dims) const = 0;
+  virtual void Print(const char* name, const MLFloat16* tensor, gsl::span<const int64_t>& dims) const = 0;
+
virtual void Print(const char* name, const Tensor& value) const = 0;
virtual void Print(const char* name, const OrtValue& value) const = 0;
virtual void Print(const char* name, int index, bool end_line) const = 0;
@@ -43,5 +48,38 @@ class IConsoleDumper {
bool is_enabled_;
};
+template <typename TConsoleDumper, typename T>
+void PrintTensorByDims(const TConsoleDumper* dumper,
+                       const char* name,
+                       const T* tensor,
+                       gsl::span<const int64_t>& dims) {
+ if (dumper->IsEnabled() && (tensor == nullptr || dims.size() == 0)) {
+ std::cout << std::string(name) << " is None" << std::endl;
+ return;
+ }
+
+ auto num_dims = dims.size();
+ if (num_dims == 1) {
+    dumper->Print(name, tensor, 1, static_cast<int>(dims[0]));
+  } else if (num_dims == 2) {
+    dumper->Print(name, tensor, static_cast<int>(dims[0]), static_cast<int>(dims[1]));
+  } else if (num_dims == 3) {
+    dumper->Print(name, tensor, static_cast<int>(dims[0]), static_cast<int>(dims[1]), static_cast<int>(dims[2]));
+  } else if (num_dims == 4) {
+    dumper->Print(name, tensor,
+                  static_cast<int>(dims[0]),
+                  static_cast<int>(dims[1]),
+                  static_cast<int>(dims[2]),
+                  static_cast<int>(dims[3]));
+ } else if (num_dims == 5) {
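+    // Fold the first two dimensions together so the 4-argument printer can be reused.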
+ dumper->Print(name, tensor,
+                  static_cast<int>(dims[0]) * static_cast<int>(dims[1]),
+                  static_cast<int>(dims[2]),
+                  static_cast<int>(dims[3]),
+                  static_cast<int>(dims[4]));
+ } else {
+ ORT_ENFORCE(false, "Unsupported tensor dims");
+ }
+}
} // namespace contrib
} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/utils/dump_tensor.cc b/onnxruntime/contrib_ops/cpu/utils/dump_tensor.cc
index 87a9cd3965763..7755f9505d99d 100644
--- a/onnxruntime/contrib_ops/cpu/utils/dump_tensor.cc
+++ b/onnxruntime/contrib_ops/cpu/utils/dump_tensor.cc
@@ -246,7 +246,24 @@ void CpuTensorConsoleDumper::Print(const char* name, const std::string& value, b
}
}
+void CpuTensorConsoleDumper::Print(const char* name, const int32_t* tensor, gsl::span<const int64_t>& dims) const {
+ PrintTensorByDims(this, name, tensor, dims);
+}
+
+void CpuTensorConsoleDumper::Print(const char* name, const int64_t* tensor, gsl::span<const int64_t>& dims) const {
+ PrintTensorByDims(this, name, tensor, dims);
+}
+
+void CpuTensorConsoleDumper::Print(const char* name, const float* tensor, gsl::span<const int64_t>& dims) const {
+ PrintTensorByDims(this, name, tensor, dims);
+}
+
+void CpuTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, gsl::span<const int64_t>& dims) const {
+ PrintTensorByDims(this, name, tensor, dims);
+}
+
#else
+
CpuTensorConsoleDumper::CpuTensorConsoleDumper() {
}
@@ -303,6 +320,18 @@ void CpuTensorConsoleDumper::Print(const char*, int, bool) const {
void CpuTensorConsoleDumper::Print(const char*, const std::string&, bool) const {
}
+
+void CpuTensorConsoleDumper::Print(const char*, const int32_t*, gsl::span<const int64_t>&) const {
+}
+
+void CpuTensorConsoleDumper::Print(const char*, const int64_t*, gsl::span<const int64_t>&) const {
+}
+
+void CpuTensorConsoleDumper::Print(const char*, const float*, gsl::span<const int64_t>&) const {
+}
+
+void CpuTensorConsoleDumper::Print(const char*, const MLFloat16*, gsl::span<const int64_t>&) const {
+}
#endif
} // namespace contrib
diff --git a/onnxruntime/contrib_ops/cpu/utils/dump_tensor.h b/onnxruntime/contrib_ops/cpu/utils/dump_tensor.h
index f102eae6ec709..6fc4dfd4a0671 100644
--- a/onnxruntime/contrib_ops/cpu/utils/dump_tensor.h
+++ b/onnxruntime/contrib_ops/cpu/utils/dump_tensor.h
@@ -30,6 +30,11 @@ class CpuTensorConsoleDumper : public IConsoleDumper {
void Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2, int dim3) const override;
void Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2, int dim3) const override;
+  void Print(const char* name, const int32_t* tensor, gsl::span<const int64_t>& dims) const override;
+  void Print(const char* name, const int64_t* tensor, gsl::span<const int64_t>& dims) const override;
+  void Print(const char* name, const float* tensor, gsl::span<const int64_t>& dims) const override;
+  void Print(const char* name, const MLFloat16* tensor, gsl::span<const int64_t>& dims) const override;
+
void Print(const char* name, const Tensor& value) const override;
void Print(const char* name, const OrtValue& value) const override;
void Print(const char* name, int index, bool end_line) const override;
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc
index 5c0989bced70c..1d1416995a673 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc
@@ -59,7 +59,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const {
const Tensor* bias = context->Input<Tensor>(2);
const Tensor* mask_index = context->Input<Tensor>(3);
const Tensor* past = context->Input<Tensor>(kPastInputIndex);
-  const Tensor* relative_position_bias = context->Input<Tensor>(5);
+  const Tensor* attention_bias = context->Input<Tensor>(5);
const Tensor* past_seq_len = context->Input<Tensor>(kPastSequenceLengthInputIndex);
auto& device_prop = GetDeviceProp();
@@ -74,7 +74,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const {
bias != nullptr ? bias->Shape() : bias_shape,
mask_index,
past,
- relative_position_bias,
+ attention_bias,
&parameters,
device_prop.maxThreadsPerBlock,
past_seq_len));
@@ -104,7 +104,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const {
#if USE_FLASH_ATTENTION
bool use_flash_attention = !disable_flash_attention_ &&
- (nullptr == relative_position_bias) &&
+ (nullptr == attention_bias) &&
nullptr == past &&
nullptr == present &&
parameters.hidden_size == parameters.v_hidden_size &&
@@ -146,7 +146,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const {
// where past state is empty.
bool is_mask_2d_key_padding = parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING;
bool use_causal_fused_runner = (nullptr == mask_index || is_mask_1d_seq_len || is_mask_2d_key_padding) &&
- nullptr == relative_position_bias &&
+ nullptr == attention_bias &&
parameters.past_sequence_length == 0 &&
parameters.hidden_size == parameters.v_hidden_size &&
FusedMHARunnerFP16v2::IsSupported(sm, parameters.head_size, sequence_length,
@@ -169,7 +169,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const {
(nullptr == mask_index || is_mask_1d_seq_len) &&
nullptr == past &&
nullptr == present &&
- nullptr == relative_position_bias &&
+ nullptr == attention_bias &&
parameters.hidden_size == parameters.v_hidden_size &&
FusedMHARunnerFP16v2::IsSupported(sm, parameters.head_size, sequence_length,
enable_trt_flash_attention_, false);
@@ -201,12 +201,9 @@ Status Attention::ComputeInternal(OpKernelContext* context) const {
nullptr == present &&
(nullptr == mask_index || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) &&
(sizeof(T) == 2 || parameters.sequence_length >= this->kernel_options_->MinSeqLenForEfficientAttentionFp32()) &&
+ (nullptr == attention_bias || parameters.sequence_length % (4 * sizeof(T)) == 0) &&
has_memory_efficient_attention(sm, sizeof(T) == 2, parameters.head_size, parameters.v_head_size);
- if (use_memory_efficient_attention) {
- bool is_good_for_rpb = relative_position_bias != nullptr && parameters.sequence_length % (4 * sizeof(T)) == 0;
- use_memory_efficient_attention = (nullptr == relative_position_bias || is_good_for_rpb);
- }
#else
constexpr bool use_memory_efficient_attention = false;
#endif
@@ -277,8 +274,8 @@ Status Attention::ComputeInternal(OpKernelContext* context) const {
if (nullptr != past) {
data.past = reinterpret_cast<const CudaT*>(past->Data<T>());
}
-  if (nullptr != relative_position_bias) {
-    data.relative_position_bias = reinterpret_cast<const CudaT*>(relative_position_bias->Data<T>());
+  if (nullptr != attention_bias) {
+    data.attention_bias = reinterpret_cast<const CudaT*>(attention_bias->Data<T>());
}
data.has_qkv_workspace = true;
data.workspace = reinterpret_cast<CudaT*>(work_space.get());
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
index f9eabe27d97e4..28e2b7b28764b 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -290,7 +290,7 @@ Status FlashAttention(
assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH ||
data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH);
assert(nullptr == data.mask_index);
- assert(nullptr == data.relative_position_bias);
+ assert(nullptr == data.attention_bias);
assert(parameters.head_size == parameters.v_head_size);
constexpr bool is_bf16 = false;
@@ -332,6 +332,8 @@ Status EfficientAttention(
// Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query.
assert(data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH ||
data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH_BNSH_BNSH);
+ assert(parameters.mask_type == AttentionMaskType::MASK_NONE ||
+ parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START);
MemoryEfficientAttentionParams p;
p.sm = device_prop.major * 10 + device_prop.minor;
@@ -345,22 +347,25 @@ Status EfficientAttention(
p.v_head_size = parameters.v_head_size;
p.causal = parameters.is_unidirectional;
p.scale = scale;
-    p.seqlen_k_ptr = nullptr == data.mask_index
-                         ? nullptr
-                         : const_cast<int32_t*>(reinterpret_cast<const int32_t*>(data.mask_index));
-    p.seqstart_q_ptr = nullptr == data.mask_index
-                           ? nullptr
-                           : const_cast<int32_t*>(reinterpret_cast<const int32_t*>(
-                                 data.mask_index + parameters.batch_size));
-    p.seqstart_k_ptr = nullptr == data.mask_index
-                           ? nullptr
-                           : const_cast<int32_t*>(reinterpret_cast<const int32_t*>(
-                                 data.mask_index + 2 * parameters.batch_size + 1));
+
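+  // The 1D mask (length 3 * batch_size + 2) packs seqlen_k (B), cumulative q sequence starts (B + 1)
+  // and cumulative k sequence starts (B + 1) back to back, so the start pointers are simple offsets.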
+ if (nullptr == data.mask_index) {
+ p.seqlen_k_ptr = nullptr;
+ p.seqstart_q_ptr = nullptr;
+ p.seqstart_k_ptr = nullptr;
+ } else {
+    p.seqlen_k_ptr = const_cast<int32_t*>(reinterpret_cast<const int32_t*>(data.mask_index));
+ p.seqstart_q_ptr = p.seqlen_k_ptr + parameters.batch_size;
+ p.seqstart_k_ptr = p.seqlen_k_ptr + 2 * parameters.batch_size + 1;
+ }
+
p.query = data.q;
p.key = data.k;
p.value = data.v;
- p.attn_bias = nullptr == data.relative_position_bias ? nullptr : data.relative_position_bias;
- p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias;
+
+ p.attn_bias = (nullptr == data.attention_bias) ? nullptr : data.attention_bias;
+ p.broadcast_attn_bias_dim_0 = parameters.broadcast_attn_bias_dim_0;
+ p.broadcast_attn_bias_dim_1 = parameters.broadcast_attn_bias_dim_1;
+
p.output = data.output;
p.is_kv_bsnh = data.qkv_format == AttentionQkvFormat::Q_K_V_BSNH;
p.workspace = MemoryEfficientAttentionParams::need_workspace(parameters.v_head_size, sizeof(T) == sizeof(float))
@@ -415,6 +420,12 @@ Status UnfusedAttention(
const int present_size_per_batch_k = present_sequence_length * qk_head_size;
const int present_size_per_batch_v = present_sequence_length * v_head_size;
+ DUMP_TENSOR_INIT();
+ DUMP_TENSOR_D("q", data.q, batch_size, num_heads, sequence_length, qk_head_size);
+ DUMP_TENSOR_D("k", data.k, batch_size, num_heads, total_sequence_length, qk_head_size);
+ DUMP_TENSOR_D("v", data.v, batch_size, num_heads, total_sequence_length, v_head_size);
+ DUMP_TENSOR_D("mask_index", mask_index, mask_index_dims);
+
CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
cublas, CUBLAS_OP_T, CUBLAS_OP_N,
total_sequence_length, sequence_length, qk_head_size,
@@ -423,7 +434,6 @@ Status UnfusedAttention(
&zero, data.scratch, total_sequence_length, sequence_length * total_sequence_length, batches,
device_prop, parameters.use_tf32));
- DUMP_TENSOR_INIT();
DUMP_TENSOR_D("QK", data.scratch, batch_size, num_heads, sequence_length, total_sequence_length);
constexpr size_t element_size = sizeof(T);
@@ -431,6 +441,9 @@ Status UnfusedAttention(
sequence_length, total_sequence_length);
T* scratch2 = data.scratch + (bytes / element_size);
+ const bool broadcast_attn_bias_dim_0 = parameters.broadcast_attn_bias_dim_0;
+ const bool broadcast_attn_bias_dim_1 = parameters.broadcast_attn_bias_dim_1;
+
// Apply softmax and store result R to scratch2: BxNxSxT
if (use_raw_attention_mask) { // 2d, 3d or 4d attention mask
const int mask_dimension = static_cast<int>(mask_index_dims.size());
@@ -444,7 +457,7 @@ Status UnfusedAttention(
ORT_RETURN_IF_ERROR(
ComputeSoftmaxWithRawMask(
ort_stream, total_sequence_length, sequence_length, batch_size, num_heads,
- mask_index, nullptr, data.relative_position_bias, parameters.broadcast_res_pos_bias,
+ mask_index, nullptr, data.attention_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1,
data.scratch, scratch2, parameters.is_unidirectional, scale, mask_dimension,
parameters.max_sequence_length, use_persistent_softmax, persistent_softmax_workspace,
parameters.mask_filter_value));
@@ -454,17 +467,17 @@ Status UnfusedAttention(
const int* mask_start = (mask_index_dims[0] > batch_size) ? mask_index + batch_size : nullptr;
ORT_RETURN_IF_ERROR(ComputeSoftmaxWithMask1D(
stream, total_sequence_length, sequence_length, batch_size, num_heads,
- mask_index, mask_start, data.relative_position_bias, parameters.broadcast_res_pos_bias,
+ mask_index, mask_start, data.attention_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1,
data.scratch, scratch2, parameters.is_unidirectional));
} else { // no mask
ORT_RETURN_IF_ERROR(
ComputeSoftmax(
- stream, total_sequence_length, sequence_length, batch_size, num_heads, data.relative_position_bias,
- parameters.broadcast_res_pos_bias, data.scratch, scratch2, parameters.is_unidirectional));
+ stream, total_sequence_length, sequence_length, batch_size, num_heads,
+ data.attention_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1,
+ data.scratch, scratch2, parameters.is_unidirectional));
}
DUMP_TENSOR_D("Softmax", scratch2, batch_size, num_heads, sequence_length, total_sequence_length);
- DUMP_TENSOR_D("V", data.v, batch_size, num_heads, sequence_length, v_head_size);
// compute R*V (as V*R), and store in temp_output (space used by Q): BxNxSxH_v
T* temp_output = data.q;
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
index fad353dcfeb07..a6760f84e69f3 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h
@@ -69,7 +69,7 @@ struct AttentionData {
const T* past = nullptr;
const T* past_key = nullptr;
const T* past_value = nullptr;
- const T* relative_position_bias = nullptr;
+ const T* attention_bias = nullptr;
bool has_qkv_workspace = false;
T* workspace = nullptr;
@@ -115,7 +115,7 @@ struct AttentionData {
<< ", fused_runner=" << (fused_runner != nullptr)
<< ", fused_cross=" << (fused_cross_attention_kernel != nullptr)
<< ", bias=" << (bias != nullptr)
- << ", attn_bias=" << (relative_position_bias != nullptr)
+ << ", attn_bias=" << (attention_bias != nullptr)
<< ", mask_dims=" << mask_index_dims.size()
<< ", has_qkv_workspace=" << has_qkv_workspace
<< ", workspace=" << workspace_bytes
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
index 05c592ec61059..575e65ebef0e9 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
@@ -77,18 +77,22 @@ void DumpInputs(contrib::AttentionParameters& parameters, AttentionData& data
DUMP_TENSOR_D("V_bias", data.bias + 2 * num_heads * qk_head_size, num_heads, v_head_size);
}
- if (data.relative_position_bias != nullptr) {
- DUMP_TENSOR_D("relative_position_bias", data.relative_position_bias,
- parameters.broadcast_res_pos_bias ? 1 : batch_size,
- num_heads, sequence_length, kv_sequence_length);
+ if (data.attention_bias != nullptr) {
+ DUMP_TENSOR_D("attention_bias", data.attention_bias,
+ parameters.broadcast_attn_bias_dim_0 ? 1 : batch_size,
+ parameters.broadcast_attn_bias_dim_1 ? 1 : num_heads,
+ sequence_length,
+ parameters.total_sequence_length);
}
if (data.mask_index != nullptr) {
if (parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING) {
- DUMP_TENSOR_D("mask", data.mask_index, batch_size, parameters.total_sequence_length);
+ DUMP_TENSOR_D("mask (2D)", data.mask_index, batch_size, parameters.total_sequence_length);
}
if (parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) {
- DUMP_TENSOR_D("mask", data.mask_index, 3 * batch_size + 2, 1);
+ DUMP_TENSOR_D("mask (seqlen_k)", data.mask_index, 1, batch_size);
+ DUMP_TENSOR_D("mask (cu_seqlen_q)", data.mask_index + batch_size, 1, batch_size + 1);
+ DUMP_TENSOR_D("mask (cu_seqlen_k)", data.mask_index + 2 * batch_size + 1, 1, batch_size + 1);
}
}
}
@@ -258,7 +262,7 @@ Status PrepareQkv_MHA_NoPast(contrib::AttentionParameters& parameters,
if (data.fused_cross_attention_kernel != nullptr) {
assert(qk_head_size == v_head_size);
- assert(data.relative_position_bias == nullptr);
+ assert(data.attention_bias == nullptr);
assert(data.mask_index == nullptr);
assert(parameters.hidden_size == parameters.v_hidden_size);
@@ -290,7 +294,7 @@ Status PrepareQkv_MHA_NoPast(contrib::AttentionParameters& parameters,
#endif
else if (data.fused_runner != nullptr) {
assert(qk_head_size == v_head_size);
- assert(data.relative_position_bias == nullptr);
+ assert(data.attention_bias == nullptr);
// Query (BxSxNxH), Key (BxSxNxH), Value (BxSxNxH) => Q: BxSxNx(H + H + H)
LaunchAddBiasTransposeTrt(
@@ -524,7 +528,7 @@ Status PrepareQkv_MHA_PackedQKV(contrib::AttentionParameters& parameters,
true, v_head_size, qkv_add_bias, 3);
data.qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
} else if (nullptr != data.fused_runner) {
- assert(nullptr == data.relative_position_bias);
+ assert(nullptr == data.attention_bias);
if (data.bias == nullptr) {
// When there is no bias, we can directly use the original packed QKV input.
// Need revisit this when we add support for causal.
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu
index 01ea02f48d3ab..52f94247a8b2b 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.cu
@@ -29,12 +29,45 @@ namespace onnxruntime {
namespace contrib {
namespace attention_softmax_cuda {
-template <typename T, unsigned TPB>
-__device__ inline void Softmax(const int all_sequence_length,
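+// DISPATCH_BIAS turns the runtime attn_bias != nullptr check into the compile-time constant HAS_BIAS so
+// the softmax kernels can skip the bias add without a per-element branch; it also declares the common
+// launch grid of (num_heads * sequence_length, batch_size, 1) used by the launches in the dispatched lambda.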
+#define DISPATCH_BIAS(attn_bias, HAS_BIAS, ...) \
+ [&] { \
+ const dim3 grid(num_heads* sequence_length, batch_size, 1); \
+ if (attn_bias != nullptr) { \
+ constexpr static bool HAS_BIAS = true; \
+ return __VA_ARGS__(); \
+ } else { \
+ constexpr static bool HAS_BIAS = false; \
+ return __VA_ARGS__(); \
+ } \
+ }()
+
+// Macro to declare variables:
+// offset: offset in input/output
+// bias_offset: offset in attn_bias
+// b: batch index
+// s: sequence index
+// grid size is (num_heads * sequence_length, batch_size, 1)
+// input and output shape is (batch_size, num_heads, sequence_length, total_sequence_length)
+// bias shape is (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)
+#define DECLARE_SOFTMAX_VARS() \
+ [[maybe_unused]] const int s = blockIdx.x % sequence_length; \
+ const int b = blockIdx.y; \
+  int64_t offset = static_cast<int64_t>(b * gridDim.x + blockIdx.x) * static_cast<int64_t>(total_sequence_length); \
+  [[maybe_unused]] int64_t bias_offset = 0; \
+  if constexpr (HAS_BIAS) { \
+    const int j = (broadcast_attn_bias_dim_0 ? 0 : (b * gridDim.x)) + (broadcast_attn_bias_dim_1 ? s : blockIdx.x); \
+    bias_offset = static_cast<int64_t>(j) * static_cast<int64_t>(total_sequence_length); \
+ }
+
+// This kernel is for non causal, attention mask 1D or None, and total_sequence_length > 1024.
+template <typename T, unsigned TPB, bool HAS_BIAS>
+__device__ inline void Softmax(const int total_sequence_length,
+ const int sequence_length,
const int valid_end,
const int valid_start,
- const T* rel_pos_bias,
- const bool broadcast_rel_pos_bias,
+ const T* attn_bias,
+ const bool broadcast_attn_bias_dim_0,
+ const bool broadcast_attn_bias_dim_1,
const T* input,
T* output) {
using BlockReduce = cub::BlockReduce<float, TPB>;
@@ -45,28 +78,22 @@ __device__ inline void Softmax(const int all_sequence_length,
float thread_data_max(-CUDART_INF_F);
- const bool no_rpb = (rel_pos_bias == nullptr);
+ DECLARE_SOFTMAX_VARS();
// e^x is represented as infinity if x is large enough, like 100.f.
// Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough.
// a math transform as below is leveraged to get a stable softmax:
// e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max))
- const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length;
- const int size_per_batch = gridDim.x * all_sequence_length;
for (int i = threadIdx.x; i < valid_end; i += TPB) {
if (i >= valid_start) {
- const int index = offset + i;
- float input_at_idx = no_rpb
- ? float(input[index])
- : float(input[index] + (broadcast_rel_pos_bias
- ? rel_pos_bias[index % size_per_batch]
- : rel_pos_bias[index]));
- if (thread_data_max < input_at_idx) {
- thread_data_max = input_at_idx;
+ float input_data = HAS_BIAS
+ ? float(input[offset + i]) + float(attn_bias[bias_offset + i])
+ : float(input[offset + i]);
+ if (thread_data_max < input_data) {
+ thread_data_max = input_data;
}
}
}
-
const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, cub::Max());
// Store max value
@@ -78,9 +105,11 @@ __device__ inline void Softmax(const int all_sequence_length,
float thread_data_sum(0.f);
for (int i = threadIdx.x; i < valid_end; i += TPB) {
if (i >= valid_start) {
- const int index = offset + i;
- float val = no_rpb ? input[index] : input[index] + rel_pos_bias[index % size_per_batch];
- thread_data_sum += expf(val - max_block);
+ float input_data = HAS_BIAS
+ ? float(input[offset + i]) + float(attn_bias[bias_offset + i])
+ : float(input[offset + i]);
+
+ thread_data_sum += expf(input_data - max_block);
}
}
@@ -90,21 +119,25 @@ __device__ inline void Softmax(const int all_sequence_length,
}
__syncthreads();
- for (int i = threadIdx.x; i < all_sequence_length; i += TPB) {
+ for (int i = threadIdx.x; i < total_sequence_length; i += TPB) {
const int index = offset + i;
- float input_at_idx = no_rpb ? float(input[index]) : float(input[index] + rel_pos_bias[index % size_per_batch]);
- const float val = (i >= valid_start && i < valid_end) ? expf(input_at_idx - max_block) * sum_reverse_block : 0.f;
+ float input_data = HAS_BIAS
+ ? float(input[index]) + float(attn_bias[bias_offset + i])
+ : float(input[index]);
+ const float val = (i >= valid_start && i < valid_end) ? expf(input_data - max_block) * sum_reverse_block : 0.f;
output[index] = T(val);
}
}
-template <typename T, unsigned TPB>
-__device__ inline void SoftmaxSmall(const int all_sequence_length,
+// This kernel is for causal or not, attention mask 1D or None, and total_sequence_length <= 1024.
+template <typename T, unsigned TPB, bool HAS_BIAS>
+__device__ inline void SoftmaxSmall(const int total_sequence_length,
const int sequence_length,
const int valid_end,
const int valid_start,
- const T* rel_pos_bias,
- const bool broadcast_rel_pos_bias,
+ const T* attn_bias,
+ const bool broadcast_attn_bias_dim_0,
+ const bool broadcast_attn_bias_dim_1,
const T* input,
T* output,
bool causal) {
@@ -114,34 +147,30 @@ __device__ inline void SoftmaxSmall(const int all_sequence_length,
__shared__ float sum_reverse_block;
__shared__ float max_block;
- // Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S;
- const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length;
+ DECLARE_SOFTMAX_VARS();
+
const int index = offset + threadIdx.x;
// Update end position for causal.
int end = valid_end;
if (causal) {
- const int end_causal = all_sequence_length - sequence_length + (blockIdx.x % sequence_length) + 1;
+ const int end_causal = total_sequence_length - sequence_length + s + 1;
if (end_causal < end) {
end = end_causal;
}
}
const bool is_valid = (threadIdx.x >= valid_start && threadIdx.x < end);
+ float input_data = is_valid ? (HAS_BIAS
+ ? float(input[index]) + float(attn_bias[bias_offset + threadIdx.x])
+ : float(input[index]))
+ : float(-CUDART_INF_F);
// e^x is represented as infinity if x is large enough, like 100.f.
// Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough.
// a math transform as below is leveraged to get a stable softmax:
// e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max))
- const bool no_rpb = (rel_pos_bias == nullptr);
- const int size_per_batch = gridDim.x * all_sequence_length;
- float input_data = no_rpb
- ? float(input[index])
- : float(input[index] + (broadcast_rel_pos_bias
- ? rel_pos_bias[index % size_per_batch]
- : rel_pos_bias[index]));
- float thread_data_max = is_valid ? input_data : float(-CUDART_INF_F);
- const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, cub::Max(), end);
+ const auto max = BlockReduce(tmp_storage).Reduce(input_data, cub::Max(), end);
// Store max value
if (threadIdx.x == 0) {
@@ -162,23 +191,25 @@ __device__ inline void SoftmaxSmall(const int all_sequence_length,
}
__syncthreads();
- // threadIdx.x might be larger than all_sequence_length due to alignment to 32x.
- if (threadIdx.x < all_sequence_length) {
+ // threadIdx.x might be larger than total_sequence_length due to alignment to 32x.
+ if (threadIdx.x < total_sequence_length) {
output[index] = is_valid ? T(thread_data_exp * sum_reverse_block) : T(0.f);
}
}
-template <typename T, unsigned TPB>
-__global__ void SoftmaxLargeKernel(const int all_sequence_length,
+// This kernel is for causal or not, attention mask 1D or None, and total_sequence_length > 1024.
+template <typename T, unsigned TPB, bool HAS_BIAS>
+__global__ void SoftmaxLargeKernel(const int total_sequence_length,
const int sequence_length,
const int valid_end,
const int valid_start,
- const T* rel_pos_bias,
- const bool broadcast_rel_pos_bias,
+ const T* attn_bias,
+ const bool broadcast_attn_bias_dim_0,
+ const bool broadcast_attn_bias_dim_1,
const T* input,
T* output,
bool causal) {
- extern __shared__ float cached_data[]; // float[all_sequence_length]
+ extern __shared__ float cached_data[]; // float[total_sequence_length]
using BlockReduce = cub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmp_storage;
@@ -186,36 +217,26 @@ __global__ void SoftmaxLargeKernel(const int all_sequence_length,
__shared__ float sum_reverse_block;
__shared__ float max_block;
+ DECLARE_SOFTMAX_VARS();
+
// Update end position for causal.
int end = valid_end;
if (causal) {
- int end_causal = all_sequence_length - sequence_length + (blockIdx.x % sequence_length) + 1;
+ int end_causal = total_sequence_length - sequence_length + s + 1;
if (end_causal < end) {
end = end_causal;
}
}
- // Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S;
- const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length;
- const int size_per_batch = gridDim.x * all_sequence_length;
-
float thread_data_max = -CUDART_INF_F;
- for (int seq_idx = threadIdx.x; seq_idx < all_sequence_length; seq_idx += TPB) {
- const int index = offset + seq_idx;
- const bool is_valid = (seq_idx >= valid_start && seq_idx < end);
-
- // e^x is represented as infinity if x is large enough, like 100.f.
- // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough.
- // a math transform as below is leveraged to get a stable softmax:
- // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max))
- float input_data = is_valid
- ? (rel_pos_bias
- ? float(input[index] + (broadcast_rel_pos_bias
- ? rel_pos_bias[index % size_per_batch]
- : rel_pos_bias[index]))
- : float(input[index]))
- : float(-CUDART_INF_F);
- cached_data[seq_idx] = input_data;
+ for (int i = threadIdx.x; i < total_sequence_length; i += TPB) {
+ const int index = offset + i;
+ const bool is_valid = (i >= valid_start && i < end);
+ float input_data = is_valid ? (HAS_BIAS
+ ? float(input[index]) + float(attn_bias[bias_offset + i])
+ : float(input[index]))
+ : float(-CUDART_INF_F);
+ cached_data[i] = input_data;
thread_data_max = max(thread_data_max, input_data);
}
const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, cub::Max(), end);
@@ -227,10 +248,10 @@ __global__ void SoftmaxLargeKernel(const int all_sequence_length,
__syncthreads();
float thread_data_exp(0.f);
- for (int seq_idx = threadIdx.x; seq_idx < all_sequence_length; seq_idx += TPB) {
- const bool is_valid = (seq_idx >= valid_start && seq_idx < end);
- cached_data[seq_idx] = is_valid ? expf(cached_data[seq_idx] - max_block) : 0.0f;
- thread_data_exp += cached_data[seq_idx];
+ for (int i = threadIdx.x; i < total_sequence_length; i += TPB) {
+ const bool is_valid = (i >= valid_start && i < end);
+ cached_data[i] = is_valid ? expf(cached_data[i] - max_block) : 0.0f;
+ thread_data_exp += cached_data[i];
}
const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, cub::Sum(), end);
@@ -240,20 +261,22 @@ __global__ void SoftmaxLargeKernel(const int all_sequence_length,
}
__syncthreads();
- // threadIdx.x might be larger than all_sequence_length due to alignment to 32x.
- for (int seq_idx = threadIdx.x; seq_idx < all_sequence_length; seq_idx += TPB) {
- const bool is_valid = (seq_idx >= valid_start && seq_idx < end);
- output[offset + seq_idx] = is_valid ? T(cached_data[seq_idx] * sum_reverse_block) : T(0.f);
+ // threadIdx.x might be larger than total_sequence_length due to alignment to 32x.
+ for (int i = threadIdx.x; i < total_sequence_length; i += TPB) {
+ const bool is_valid = (i >= valid_start && i < end);
+ output[offset + i] = is_valid ? T(cached_data[i] * sum_reverse_block) : T(0.f);
}
}
-template <typename T, unsigned TPB>
-__global__ void SoftmaxWithRawMaskLargeKernel(const int all_sequence_length,
+// This kernel is for causal or not, raw attention mask (2D, 3D or 4D) and total_sequence_length > 1024.
+template <typename T, unsigned TPB, bool HAS_BIAS>
+__global__ void SoftmaxWithRawMaskLargeKernel(const int total_sequence_length,
const int sequence_length,
const int* attention_mask, // 2D, 3D or 4D attention mask
const bool* key_padding_mask,
- const T* rel_pos_bias,
- const bool broadcast_rel_pos_bias,
+ const T* attn_bias,
+ const bool broadcast_attn_bias_dim_0,
+ const bool broadcast_attn_bias_dim_1,
const T* input,
T* output,
const bool causal,
@@ -262,7 +285,7 @@ __global__ void SoftmaxWithRawMaskLargeKernel(const int all_sequence_length,
const int max_sequence_length,
const bool skip_softmax,
const float mask_filter_value) {
- extern __shared__ float cached_data[]; // float[all_sequence_length]
+ extern __shared__ float cached_data[]; // float[total_sequence_length]
using BlockReduce = cub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmp_storage;
@@ -271,37 +294,30 @@ __global__ void SoftmaxWithRawMaskLargeKernel(const int all_sequence_length,
__shared__ float max_block;
float max_thread_data = -CUDART_INF_F;
- const int size_per_batch = gridDim.x * all_sequence_length;
-
- // Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S;
- int base_index = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length;
- for (int seq_idx = threadIdx.x; seq_idx < all_sequence_length; seq_idx += TPB) {
- float thread_data = -CUDART_INF_F;
- int index = base_index + seq_idx;
- if (rel_pos_bias == nullptr) {
- thread_data = float(input[index]) * rsqrt_head_size;
- } else {
- T rel_pos_bias_value = broadcast_rel_pos_bias ? rel_pos_bias[index % size_per_batch] : rel_pos_bias[index];
- thread_data = float(input[index] + rel_pos_bias_value) * rsqrt_head_size;
- }
- const int sequence_index = blockIdx.x % sequence_length;
+ DECLARE_SOFTMAX_VARS();
+
+ for (int i = threadIdx.x; i < total_sequence_length; i += TPB) {
+ int index = offset + i;
+ float input_data = HAS_BIAS
+ ? float(input[index]) + float(attn_bias[bias_offset + i])
+ : float(input[index]);
+ float thread_data = input_data * rsqrt_head_size;
if (causal) {
- int from_index = all_sequence_length - sequence_length + sequence_index; // offset in all sequence length.
- if (seq_idx > from_index) {
+ int from_index = total_sequence_length - sequence_length + s; // offset in total sequence length.
+ if (i > from_index) {
thread_data = -CUDART_INF_F;
}
}
int mask_offset = 0;
- const int batch_index = blockIdx.y;
if (mask_dimension == 2) {
- mask_offset = batch_index * all_sequence_length + seq_idx;
+ mask_offset = b * total_sequence_length + i;
} else if (mask_dimension == 3) {
- mask_offset = (batch_index * sequence_length + sequence_index) * all_sequence_length + seq_idx;
+ mask_offset = (b * sequence_length + s) * total_sequence_length + i;
} else if (mask_dimension == 4) {
- int from_index = all_sequence_length - sequence_length + sequence_index;
- mask_offset = (batch_index * max_sequence_length + from_index) * max_sequence_length + seq_idx;
+ int from_index = total_sequence_length - sequence_length + s;
+ mask_offset = (b * max_sequence_length + from_index) * max_sequence_length + i;
}
if (nullptr == key_padding_mask) {
@@ -318,7 +334,7 @@ __global__ void SoftmaxWithRawMaskLargeKernel(const int all_sequence_length,
if (skip_softmax) {
output[index] = T(thread_data);
}
- cached_data[seq_idx] = thread_data;
+ cached_data[i] = thread_data;
max_thread_data = max(max_thread_data, thread_data);
}
@@ -326,7 +342,7 @@ __global__ void SoftmaxWithRawMaskLargeKernel(const int all_sequence_length,
return;
}
- const float max = BlockReduce(tmp_storage).Reduce(max_thread_data, cub::Max(), all_sequence_length);
+ const float max = BlockReduce(tmp_storage).Reduce(max_thread_data, cub::Max(), total_sequence_length);
// Store max value
if (threadIdx.x == 0) {
@@ -335,9 +351,9 @@ __global__ void SoftmaxWithRawMaskLargeKernel(const int all_sequence_length,
__syncthreads();
float sum_thread_data_exp = 0.0f;
- for (int seq_idx = threadIdx.x; seq_idx < all_sequence_length; seq_idx += TPB) {
- auto ev = expf(cached_data[seq_idx] - max_block);
- cached_data[seq_idx] = ev;
+ for (int i = threadIdx.x; i < total_sequence_length; i += TPB) {
+ auto ev = expf(cached_data[i] - max_block);
+ cached_data[i] = ev;
sum_thread_data_exp += ev;
}
const auto sum = BlockReduce(tmp_storage).Reduce(sum_thread_data_exp, cub::Sum(), TPB);
@@ -348,18 +364,20 @@ __global__ void SoftmaxWithRawMaskLargeKernel(const int all_sequence_length,
}
__syncthreads();
- for (int seq_idx = threadIdx.x; seq_idx < all_sequence_length; seq_idx += TPB) {
- output[base_index + seq_idx] = T(cached_data[seq_idx] * sum_reverse_block);
+ for (int i = threadIdx.x; i < total_sequence_length; i += TPB) {
+ output[offset + i] = T(cached_data[i] * sum_reverse_block);
}
}
-template <typename T, unsigned TPB>
-__device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length,
+// This kernel is for causal or not, raw attention mask (2D, 3D or 4D), and total_sequence_length <= 1024.
+template <typename T, unsigned TPB, bool HAS_BIAS>
+__device__ inline void SoftmaxWithRawMaskSmall(const int total_sequence_length,
const int sequence_length,
const int* attention_mask, // 2D, 3D or 4D attention mask
const bool* key_padding_mask,
- const T* rel_pos_bias,
- const bool broadcast_rel_pos_bias,
+ const T* attn_bias,
+ const bool broadcast_attn_bias_dim_0,
+ const bool broadcast_attn_bias_dim_1,
const T* input,
T* output,
const bool causal,
@@ -374,31 +392,29 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length,
__shared__ float sum_reverse_block;
__shared__ float max_block;
- // Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S;
- int index = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length + threadIdx.x;
- const int size_per_batch = gridDim.x * all_sequence_length;
+ DECLARE_SOFTMAX_VARS();
+
+ int64_t index = offset + threadIdx.x;
float thread_data = -CUDART_INF_F;
- if (threadIdx.x < all_sequence_length) {
+ if (threadIdx.x < total_sequence_length) {
thread_data = float(input[index]) * rsqrt_head_size;
- const int sequence_index = blockIdx.x % sequence_length;
if (causal) {
- int from_index = all_sequence_length - sequence_length + sequence_index; // offset in all sequence length.
+ int from_index = total_sequence_length - sequence_length + s; // offset in total sequence length.
if (threadIdx.x > from_index) {
thread_data = -CUDART_INF_F;
}
}
int mask_offset = 0;
- const int batch_index = blockIdx.y;
if (mask_dimension == 2) {
- mask_offset = batch_index * all_sequence_length + threadIdx.x;
+ mask_offset = b * total_sequence_length + threadIdx.x;
} else if (mask_dimension == 3) {
- mask_offset = (batch_index * sequence_length + sequence_index) * all_sequence_length + threadIdx.x;
+ mask_offset = (b * sequence_length + s) * total_sequence_length + threadIdx.x;
} else if (mask_dimension == 4) {
- int from_index = all_sequence_length - sequence_length + sequence_index;
- mask_offset = (batch_index * max_sequence_length + from_index) * max_sequence_length + threadIdx.x;
+ int from_index = total_sequence_length - sequence_length + s;
+ mask_offset = (b * max_sequence_length + from_index) * max_sequence_length + threadIdx.x;
}
if (nullptr == key_padding_mask) {
@@ -412,20 +428,19 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length,
}
}
- if (rel_pos_bias != nullptr) {
- float bias = broadcast_rel_pos_bias ? float(rel_pos_bias[index % size_per_batch]) : float(rel_pos_bias[index]);
- thread_data += bias;
+ if (HAS_BIAS) {
+ thread_data += float(attn_bias[bias_offset + threadIdx.x]);
}
}
if (skip_softmax) {
- if (threadIdx.x < all_sequence_length) {
+ if (threadIdx.x < total_sequence_length) {
output[index] = T(thread_data);
}
return;
}
- const float max = BlockReduce(tmp_storage).Reduce(thread_data, cub::Max(), all_sequence_length);
+ const float max = BlockReduce(tmp_storage).Reduce(thread_data, cub::Max(), total_sequence_length);
// Store max value
if (threadIdx.x == 0) {
@@ -433,8 +448,8 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length,
}
__syncthreads();
- float thread_data_exp = threadIdx.x < all_sequence_length ? expf(thread_data - max_block) : 0.0f;
- const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, cub::Sum(), all_sequence_length);
+ float thread_data_exp = threadIdx.x < total_sequence_length ? expf(thread_data - max_block) : 0.0f;
+ const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, cub::Sum(), total_sequence_length);
// Store value of 1.0/sum
if (threadIdx.x == 0) {
@@ -442,84 +457,97 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length,
}
__syncthreads();
- if (threadIdx.x < all_sequence_length) {
+ if (threadIdx.x < total_sequence_length) {
output[index] = T(thread_data_exp * sum_reverse_block);
}
}
-template <typename T, unsigned TPB>
-__global__ void SoftmaxKernelSmall(const int all_sequence_length,
+template <typename T, unsigned TPB, bool HAS_BIAS>
+__global__ void SoftmaxKernelSmall(const int total_sequence_length,
const int sequence_length,
- const T* rel_pos_bias,
- const bool broadcast_rel_pos_bias,
+ const T* attn_bias,
+ const bool broadcast_attn_bias_dim_0,
+ const bool broadcast_attn_bias_dim_1,
const T* input,
T* output,
bool causal) {
-  SoftmaxSmall<T, TPB>(all_sequence_length, sequence_length, all_sequence_length, 0,
- rel_pos_bias, broadcast_rel_pos_bias, input, output, causal);
+  SoftmaxSmall<T, TPB, HAS_BIAS>(total_sequence_length, sequence_length, total_sequence_length, 0,
+ attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal);
}
-template <typename T, unsigned TPB>
-__global__ void SoftmaxKernel(const int all_sequence_length,
- const T* rel_pos_bias,
- const bool broadcast_rel_pos_bias,
+template <typename T, unsigned TPB, bool HAS_BIAS>
+__global__ void SoftmaxKernel(const int total_sequence_length,
+ const int sequence_length,
+ const T* attn_bias,
+ const bool broadcast_attn_bias_dim_0,
+ const bool broadcast_attn_bias_dim_1,
const T* input,
T* output) {
-  Softmax<T, TPB>(all_sequence_length, all_sequence_length, 0,
- rel_pos_bias, broadcast_rel_pos_bias, input, output);
+  Softmax<T, TPB, HAS_BIAS>(total_sequence_length, sequence_length, total_sequence_length, 0,
+ attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output);
}
template <typename T>
-Status ComputeSoftmax(cudaStream_t stream, const int all_sequence_length, const int sequence_length,
- const int batch_size, const int num_heads, const T* rel_pos_bias,
- const bool broadcast_rel_pos_bias, T* input, T* output, bool causal) {
- const dim3 grid(sequence_length * num_heads, batch_size, 1);
- if (all_sequence_length <= 32) {
- const int blockSize = 32;
-    SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(
- all_sequence_length, sequence_length, rel_pos_bias, broadcast_rel_pos_bias, input, output, causal);
- } else if (all_sequence_length <= 64) {
- const int blockSize = 64;
-    SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(
- all_sequence_length, sequence_length, rel_pos_bias, broadcast_rel_pos_bias, input, output, causal);
- } else if (all_sequence_length <= 128) {
- const int blockSize = 128;
-    SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(
- all_sequence_length, sequence_length, rel_pos_bias, broadcast_rel_pos_bias, input, output, causal);
- } else if (all_sequence_length <= 256) {
- const int blockSize = 256;
-    SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(
- all_sequence_length, sequence_length, rel_pos_bias, broadcast_rel_pos_bias, input, output, causal);
- } else if (all_sequence_length <= 512) {
- const int blockSize = 512;
-    SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(
- all_sequence_length, sequence_length, rel_pos_bias, broadcast_rel_pos_bias, input, output, causal);
- } else if (all_sequence_length <= 1024) {
- const int blockSize = 1024;
-    SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(
- all_sequence_length, sequence_length, rel_pos_bias, broadcast_rel_pos_bias, input, output, causal);
- } else if (!causal) {
- const int blockSize = 1024;
-    SoftmaxKernel<T, blockSize><<<grid, blockSize, 0, stream>>>(
- all_sequence_length, rel_pos_bias, broadcast_rel_pos_bias, input, output);
- } else {
- const int blockSize = 256;
- const int sh_bytes = sizeof(float) * all_sequence_length;
-    SoftmaxLargeKernel<T, blockSize><<<grid, blockSize, sh_bytes, stream>>>(
- all_sequence_length, sequence_length, all_sequence_length, 0, rel_pos_bias, broadcast_rel_pos_bias,
- input, output, true);
- }
-
+Status ComputeSoftmax(cudaStream_t stream, const int total_sequence_length, const int sequence_length,
+ const int batch_size, const int num_heads, const T* attn_bias,
+ const bool broadcast_attn_bias_dim_0, const bool broadcast_attn_bias_dim_1,
+ T* input, T* output, bool causal) {
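+  // Pick the smallest block size (32 to 1024) that covers total_sequence_length; beyond 1024 fall back
+  // to the looping kernels (SoftmaxKernel for non-causal, SoftmaxLargeKernel with shared memory for causal).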
+ DISPATCH_BIAS(attn_bias, HAS_BIAS, [&] {
+ if (total_sequence_length <= 32) {
+ const int blockSize = 32;
+      SoftmaxKernelSmall<T, blockSize, HAS_BIAS><<<grid, blockSize, 0, stream>>>(
+ total_sequence_length, sequence_length,
+ attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal);
+ } else if (total_sequence_length <= 64) {
+ const int blockSize = 64;
+      SoftmaxKernelSmall<T, blockSize, HAS_BIAS><<<grid, blockSize, 0, stream>>>(
+ total_sequence_length, sequence_length,
+ attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal);
+ } else if (total_sequence_length <= 128) {
+ const int blockSize = 128;
+ SoftmaxKernelSmall<T, blockSize, HAS_BIAS><<<grid, blockSize, 0, stream>>>(
+ total_sequence_length, sequence_length,
+ attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal);
+ } else if (total_sequence_length <= 256) {
+ const int blockSize = 256;
+ SoftmaxKernelSmall<T, blockSize, HAS_BIAS><<<grid, blockSize, 0, stream>>>(
+ total_sequence_length, sequence_length,
+ attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal);
+ } else if (total_sequence_length <= 512) {
+ const int blockSize = 512;
+ SoftmaxKernelSmall<T, blockSize, HAS_BIAS><<<grid, blockSize, 0, stream>>>(
+ total_sequence_length, sequence_length,
+ attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal);
+ } else if (total_sequence_length <= 1024) {
+ const int blockSize = 1024;
+ SoftmaxKernelSmall<T, blockSize, HAS_BIAS><<<grid, blockSize, 0, stream>>>(
+ total_sequence_length, sequence_length,
+ attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal);
+ } else if (!causal) {
+ const int blockSize = 1024;
+ SoftmaxKernel<T, blockSize, HAS_BIAS><<<grid, blockSize, 0, stream>>>(
+ total_sequence_length, sequence_length,
+ attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output);
+ } else {
+ const int blockSize = 256;
+ const int sh_bytes = sizeof(float) * total_sequence_length;
+ SoftmaxLargeKernel<T, blockSize, HAS_BIAS><<<grid, blockSize, sh_bytes, stream>>>(
+ total_sequence_length, sequence_length, total_sequence_length, 0, attn_bias,
+ broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1,
+ input, output, true);
+ }
+ });
return CUDA_CALL(cudaGetLastError());
}
-template <typename T, unsigned TPB>
-__global__ void MaskedSoftmaxKernelSmall(const int all_sequence_length,
+template <typename T, unsigned TPB, bool HAS_BIAS>
+__global__ void MaskedSoftmaxKernelSmall(const int total_sequence_length,
const int sequence_length,
const int* mask_end,
const int* mask_start,
- const T* rel_pos_bias,
- const bool broadcast_rel_pos_bias,
+ const T* attn_bias,
+ const bool broadcast_attn_bias_dim_0,
+ const bool broadcast_attn_bias_dim_1,
const T* input,
T* output,
bool causal) {
@@ -529,25 +557,27 @@ __global__ void MaskedSoftmaxKernelSmall(const int all_sequence_length,
if (threadIdx.x == 0) {
const int batch = blockIdx.y;
start_position = mask_start != nullptr ? max(0, mask_start[batch]) : 0;
- end_position = min(all_sequence_length, mask_end[batch]);
+ end_position = min(total_sequence_length, mask_end[batch]);
// Attend to no word has same effect as attend to all words. This is added to get parity with CPU result.
if (start_position >= end_position) {
start_position = 0;
- end_position = all_sequence_length;
+ end_position = total_sequence_length;
}
}
__syncthreads();
- SoftmaxSmall<T, TPB>(all_sequence_length, sequence_length, end_position, start_position,
- rel_pos_bias, broadcast_rel_pos_bias, input, output, causal);
+ SoftmaxSmall<T, TPB, HAS_BIAS>(total_sequence_length, sequence_length, end_position, start_position,
+ attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output, causal);
}
-template <typename T, unsigned TPB>
-__device__ inline void SoftmaxSmallPacked(const int sequence_length,
+template <typename T, unsigned TPB, bool HAS_BIAS>
+__device__ inline void SoftmaxSmallPacked(const int total_sequence_length,
+ const int sequence_length,
const int end,
- const T* rel_pos_bias,
- const bool broadcast_rel_pos_bias,
+ const T* attn_bias,
+ const bool broadcast_attn_bias_dim_0,
+ const bool broadcast_attn_bias_dim_1,
const T* input,
T* output) {
using BlockReduce = cub::BlockReduce<float, TPB>;
@@ -556,23 +586,13 @@ __device__ inline void SoftmaxSmallPacked(const int sequence_length,
__shared__ float sum_reverse_block;
__shared__ float max_block;
- // Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S;
- const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * sequence_length;
- const int index = offset + threadIdx.x;
+ DECLARE_SOFTMAX_VARS();
+
+ int64_t index = offset + threadIdx.x;
bool is_valid = threadIdx.x < end;
- // e^x is represented as infinity if x is large enough, like 100.f.
- // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough.
- // a math transform as below is leveraged to get a stable softmax:
- // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max))
- const bool no_rpb = (rel_pos_bias == nullptr);
- const int size_per_batch = gridDim.x * sequence_length;
- float input_data = no_rpb
- ? float(input[index])
- : float(input[index] + (broadcast_rel_pos_bias
- ? rel_pos_bias[index % size_per_batch]
- : rel_pos_bias[index]));
+ float input_data = HAS_BIAS ? float(input[index]) + float(attn_bias[bias_offset + threadIdx.x]) : float(input[index]);
float thread_data_max = is_valid ? input_data : float(-CUDART_INF_F);
const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, cub::Max(), end);
@@ -596,16 +616,20 @@ __device__ inline void SoftmaxSmallPacked(const int sequence_length,
}
__syncthreads();
- // threadIdx.x might be larger than all_sequence_length due to alignment to 32x.
+ // threadIdx.x might be larger than total_sequence_length due to alignment to 32x.
if (threadIdx.x < sequence_length) {
output[index] = T(thread_data_exp * sum_reverse_block);
}
}
-template <typename T, unsigned TPB>
+template <typename T, unsigned TPB, bool HAS_BIAS>
__global__ void SoftmaxKernelSmallWithCumSeqLen(const T* input,
- const T* rel_pos_bias, const bool broadcast_rel_pos_bias,
- const int* cum_seq_length, const int sequence_length,
+ const T* attn_bias,
+ const bool broadcast_attn_bias_dim_0,
+ const bool broadcast_attn_bias_dim_1,
+ const int* cum_seq_length,
+ const int total_sequence_length,
+ const int sequence_length,
T* output) {
__shared__ int end_position;
@@ -615,15 +639,18 @@ __global__ void SoftmaxKernelSmallWithCumSeqLen(const T* input,
}
__syncthreads();
- SoftmaxSmallPacked<T, TPB>(sequence_length, end_position,
- rel_pos_bias, broadcast_rel_pos_bias,
- input, output);
+ SoftmaxSmallPacked<T, TPB, HAS_BIAS>(total_sequence_length, sequence_length, end_position,
+ attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output);
}
-template <typename T, unsigned TPB>
+template <typename T, unsigned TPB, bool HAS_BIAS>
__global__ void SoftmaxKernelWithCumSeqLen(const T* input,
- const T* rel_pos_bias, const bool broadcast_rel_pos_bias,
- const int* cum_seq_length, const int sequence_length,
+ const T* attn_bias,
+ const bool broadcast_attn_bias_dim_0,
+ const bool broadcast_attn_bias_dim_1,
+ const int* cum_seq_length,
+ const int total_sequence_length,
+ const int sequence_length,
T* output) {
__shared__ int end_position;
@@ -633,16 +660,19 @@ __global__ void SoftmaxKernelWithCumSeqLen(const T* input,
}
__syncthreads();
- Softmax<T, TPB>(sequence_length, end_position, 0 /*start_position*/,
- rel_pos_bias, broadcast_rel_pos_bias, input, output);
+ constexpr int start_position = 0;
+ Softmax<T, TPB, HAS_BIAS>(total_sequence_length, sequence_length, end_position, start_position,
+ attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output);
}
-template <typename T, unsigned TPB>
-__global__ void MaskedSoftmaxKernel(const int all_sequence_length,
+template <typename T, unsigned TPB, bool HAS_BIAS>
+__global__ void MaskedSoftmaxKernel(const int total_sequence_length,
+ const int sequence_length,
const int* mask_end,
const int* mask_start,
- const T* rel_pos_bias,
- const bool broadcast_rel_pos_bias,
+ const T* attn_bias,
+ const bool broadcast_attn_bias_dim_0,
+ const bool broadcast_attn_bias_dim_1,
const T* input, T* output) {
__shared__ int start_position;
__shared__ int end_position;
@@ -650,27 +680,28 @@ __global__ void MaskedSoftmaxKernel(const int all_sequence_length,
if (threadIdx.x == 0) {
const int batch = blockIdx.y;
start_position = mask_start != nullptr ? max(0, mask_start[batch]) : 0;
- end_position = min(all_sequence_length, mask_end[batch]);
+ end_position = min(total_sequence_length, mask_end[batch]);
// Attend to no word has same effect as attend to all words. This is added to get parity with CPU result.
if (start_position >= end_position) {
start_position = 0;
- end_position = all_sequence_length;
+ end_position = total_sequence_length;
}
}
__syncthreads();
- Softmax<T, TPB>(all_sequence_length, end_position, start_position,
- rel_pos_bias, broadcast_rel_pos_bias, input, output);
+ Softmax<T, TPB, HAS_BIAS>(total_sequence_length, sequence_length, end_position, start_position,
+ attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output);
}
-template <typename T, unsigned TPB>
-__global__ void SoftmaxWithRawMaskSmallKernel(const int all_sequence_length,
+template <typename T, unsigned TPB, bool HAS_BIAS>
+__global__ void SoftmaxWithRawMaskSmallKernel(const int total_sequence_length,
const int sequence_length,
const int* attention_mask,
const bool* key_padding_mask,
- const T* rel_pos_bias,
- const bool broadcast_rel_pos_bias,
+ const T* attn_bias,
+ const bool broadcast_attn_bias_dim_0,
+ const bool broadcast_attn_bias_dim_1,
const T* input,
T* output,
const bool causal,
@@ -679,9 +710,9 @@ __global__ void SoftmaxWithRawMaskSmallKernel(const int all_sequence_length,
const int max_sequence_length,
const bool skip_softmax,
const float mask_filter_value) {
- SoftmaxWithRawMaskSmall<T, TPB>(
- all_sequence_length, sequence_length,
- attention_mask, key_padding_mask, rel_pos_bias, broadcast_rel_pos_bias, input, output,
+ SoftmaxWithRawMaskSmall<T, TPB, HAS_BIAS>(
+ total_sequence_length, sequence_length, attention_mask, key_padding_mask,
+ attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1, input, output,
causal, rsqrt_head_size, mask_dimension, max_sequence_length,
skip_softmax, mask_filter_value);
}
@@ -689,107 +720,120 @@ __global__ void SoftmaxWithRawMaskSmallKernel(const int all_sequence_length,
template <typename T>
Status ComputeSoftmaxWithCumSeqLength(
const T* input,
- const T* rel_pos_bias,
- const bool broadcast_rel_pos_bias,
+ const T* attn_bias,
+ const bool broadcast_attn_bias_dim_0,
+ const bool broadcast_attn_bias_dim_1,
const int32_t* cum_seq_length,
const int batch_size,
const int sequence_length,
+ const int total_sequence_length,
const int num_heads,
T* output, cudaStream_t stream) {
- const dim3 grid(sequence_length * num_heads, batch_size, 1);
-
- if (sequence_length <= 32) {
- const int blockSize = 32;
- SoftmaxKernelSmallWithCumSeqLen<T, blockSize>
- <<<grid, blockSize, 0, stream>>>(input, rel_pos_bias, broadcast_rel_pos_bias,
- cum_seq_length, sequence_length, output);
-
- } else if (sequence_length <= 64) {
- const int blockSize = 64;
- SoftmaxKernelSmallWithCumSeqLen<T, blockSize>
- <<<grid, blockSize, 0, stream>>>(input, rel_pos_bias, broadcast_rel_pos_bias,
- cum_seq_length, sequence_length, output);
- } else if (sequence_length <= 128) {
- const int blockSize = 128;
- SoftmaxKernelSmallWithCumSeqLen<T, blockSize>
- <<<grid, blockSize, 0, stream>>>(input, rel_pos_bias, broadcast_rel_pos_bias,
- cum_seq_length, sequence_length, output);
- } else if (sequence_length <= 256) {
- const int blockSize = 256;
- SoftmaxKernelSmallWithCumSeqLen<T, blockSize>
- <<<grid, blockSize, 0, stream>>>(input, rel_pos_bias, broadcast_rel_pos_bias,
- cum_seq_length, sequence_length, output);
- } else if (sequence_length <= 512) {
- const int blockSize = 512;
- SoftmaxKernelSmallWithCumSeqLen<T, blockSize>
- <<<grid, blockSize, 0, stream>>>(input, rel_pos_bias, broadcast_rel_pos_bias,
- cum_seq_length, sequence_length, output);
- } else if (sequence_length <= 1024) {
- const int blockSize = 1024;
- SoftmaxKernelSmallWithCumSeqLen<T, blockSize>
- <<<grid, blockSize, 0, stream>>>(input, rel_pos_bias, broadcast_rel_pos_bias,
- cum_seq_length, sequence_length, output);
- } else {
- SoftmaxKernelWithCumSeqLen<T, 1024>
- <<<grid, 1024, 0, stream>>>(input, rel_pos_bias, broadcast_rel_pos_bias,
- cum_seq_length, sequence_length, output);
- }
+ DISPATCH_BIAS(attn_bias, HAS_BIAS, [&] {
+ if (sequence_length <= 32) {
+ const int blockSize = 32;
+ SoftmaxKernelSmallWithCumSeqLen<T, blockSize, HAS_BIAS>
+ <<<grid, blockSize, 0, stream>>>(input, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1,
+ cum_seq_length, total_sequence_length, sequence_length, output);
+ } else if (sequence_length <= 64) {
+ const int blockSize = 64;
+ SoftmaxKernelSmallWithCumSeqLen<T, blockSize, HAS_BIAS>
+ <<<grid, blockSize, 0, stream>>>(input, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1,
+ cum_seq_length, total_sequence_length, sequence_length, output);
+ } else if (sequence_length <= 128) {
+ const int blockSize = 128;
+ SoftmaxKernelSmallWithCumSeqLen<T, blockSize, HAS_BIAS>
+ <<<grid, blockSize, 0, stream>>>(input, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1,
+ cum_seq_length, total_sequence_length, sequence_length, output);
+ } else if (sequence_length <= 256) {
+ const int blockSize = 256;
+ SoftmaxKernelSmallWithCumSeqLen<T, blockSize, HAS_BIAS>
+ <<<grid, blockSize, 0, stream>>>(input, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1,
+ cum_seq_length, total_sequence_length, sequence_length, output);
+ } else if (sequence_length <= 512) {
+ const int blockSize = 512;
+ SoftmaxKernelSmallWithCumSeqLen<T, blockSize, HAS_BIAS>
+ <<<grid, blockSize, 0, stream>>>(input, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1,
+ cum_seq_length, total_sequence_length, sequence_length, output);
+ } else if (sequence_length <= 1024) {
+ const int blockSize = 1024;
+ SoftmaxKernelSmallWithCumSeqLen<T, blockSize, HAS_BIAS>
+ <<<grid, blockSize, 0, stream>>>(input, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1,
+ cum_seq_length, total_sequence_length, sequence_length, output);
+ } else {
+ const int blockSize = 1024;
+ SoftmaxKernelWithCumSeqLen<T, blockSize, HAS_BIAS>
+ <<<grid, blockSize, 0, stream>>>(input, attn_bias, broadcast_attn_bias_dim_0, broadcast_attn_bias_dim_1,
+ cum_seq_length, total_sequence_length, sequence_length, output);
+ }
+ });
return CUDA_CALL(cudaGetLastError());
}
template <typename T>
Status ComputeSoftmaxWithMask1D(cudaStream_t stream,
- const int all_sequence_length,
+ const int total_sequence_length,
const int sequence_length,
const int batch_size,
const int num_heads,
const int* mask_index,
const int* mask_start,
- const T* rel_pos_bias,
- const bool broadcast_rel_pos_bias,
+ const T* attn_bias,
+ const bool broadcast_attn_bias_dim_0,
+ const bool broadcast_attn_bias_dim_1,
const T* input,
T* output,
const bool causal) {
- const dim3 grid(sequence_length * num_heads, batch_size, 1);
-
- if (all_sequence_length <= 32) {
- const int blockSize = 32;
- MaskedSoftmaxKernelSmall<T, blockSize>
- <<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, mask_index, mask_start,
- rel_pos_bias, broadcast_rel_pos_bias, input, output, causal);
- } else if (all_sequence_length <= 64) {
- const int blockSize = 64;
- MaskedSoftmaxKernelSmall<T, blockSize>
- <<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, mask_index, mask_start,
- rel_pos_bias, broadcast_rel_pos_bias, input, output, causal);
- } else if (all_sequence_length <= 128) {
- const int blockSize = 128;
- MaskedSoftmaxKernelSmall<T, blockSize>
- <<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, mask_index, mask_start,
- rel_pos_bias, broadcast_rel_pos_bias, input, output, causal);
- } else if (all_sequence_length <= 256) {
- const int blockSize = 256;
- MaskedSoftmaxKernelSmall<T, blockSize>
- <<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, mask_index, mask_start,
- rel_pos_bias, broadcast_rel_pos_bias, input, output, causal);
- } else if (all_sequence_length <= 512) {
- const int blockSize = 512;
- MaskedSoftmaxKernelSmall<T, blockSize>
- <<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, mask_index, mask_start,
- rel_pos_bias, broadcast_rel_pos_bias, input, output, causal);
- } else if (all_sequence_length <= 1024) {
- const int blockSize = 1024;
- MaskedSoftmaxKernelSmall<T, blockSize>