diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 22e82443167f6..45c0e6f822ce9 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -3031,8 +3031,6 @@ This version of the operator has been available since version 1 of the 'com.micr
Number of attention heads
scale : float
Custom scale will be used if specified. Default value is 1/sqrt(head_size)
-
unidirectional : int
-
Whether every token can only attend to previous tokens. Default value is 0.
#### Inputs (1 - 8) @@ -5023,10 +5021,6 @@ This version of the operator has been available since version 1 of the 'com.micr
interleaved : int
Rotate using interleaved pattern. Default value is 0 (False).
-
num_heads : int
-
Number of attention heads. Default value is 0. Must use with rotary_embedding_dim
-
rotary_embedding_dim : int
-
Rotary embedding dimension. Default value is 0.
scale : float
Custom scale will be used if specified. Default value is 1.0
@@ -5039,9 +5033,9 @@ This version of the operator has been available since version 1 of the 'com.micr
position_ids : M
1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)
cos_cache : T
-
2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)
+
2D tensor with shape (max_sequence_length, head_size / 2).
sin_cache : T
-
2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)
+
2D tensor with shape (max_sequence_length, head_size / 2).
#### Outputs @@ -5054,7 +5048,7 @@ This version of the operator has been available since version 1 of the 'com.micr #### Type Constraints
-
T : tensor(float), tensor(float16), tensor(bfloat16)
+
T : tensor(float), tensor(float16)
Constrain input and output types to float tensors.
M : tensor(int64)
Constrain input and output types to integer tensors
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 9a2a7ac89bbb3..1be895dba3841 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -868,7 +868,7 @@ Do not modify directly.* |RemovePadding|*in* input:**T**
*in* sequence_token_count:**M**
*out* output:**T**
*out* token_offset:**M**
*out* cumulated_seq_len:**M**
*out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)| |RestorePadding|*in* input:**T**
*in* token_offset:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |Rfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| -|RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(bfloat16), tensor(float), tensor(float16)| +|RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(float), tensor(float16)| |Sampling|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*in* presence_mask:**I**
*in* seed:**I**
*out* sequences:**I**
*out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)| |SkipGroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*in* skip:**T**
*in* bias:**T**
*out* Y:**T**
*out* S:**T**|1+|**T** = tensor(float), tensor(float16)| |SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)| diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc index 768676259aa14..4711ccf487cc8 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc @@ -211,12 +211,6 @@ Status Attention::Compute(OpKernelContext* context) const { relative_position_bias, ¶meters)); - if (parameters.do_rotary) { - ORT_NOT_IMPLEMENTED( - "Rotary embedding is not supported in Attention CPU kernel. \ - Please fuse the model with MHA + RotaryEmbedding."); - } - const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; const int input_hidden_size = parameters.input_hidden_size; diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc index eb25d0fd7cc1e..694c40bf3eda6 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc @@ -40,7 +40,6 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) : OpKernel(i num_heads_ = static_cast(num_heads); mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); - is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; } // Reshape Q/K/V from BxSxD to BxSxNxH @@ -284,9 +283,8 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const { nullptr, ¶meters, num_heads_, - mask_filter_value_, scale, - is_unidirectional_, + mask_filter_value_, past_present_share_buffer, false)); diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h index fb7da78a5c0a5..4c86b777e9842 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h @@ -18,7 +18,6 @@ class MultiHeadAttention final : public OpKernel, public AttentionCPUBase { protected: int num_heads_; // number of attention heads float mask_filter_value_; - bool is_unidirectional_; }; } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h index c91f5b601b4e9..00e82c9844b3d 100644 --- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h @@ -25,7 +25,6 @@ Status CheckInputs(const T* query, int num_heads, float mask_filter_value, float scale, - bool is_unidirectional, bool past_present_share_buffer, bool dmmha_packing) { // key_padding_mask (K/V) : (B) or (2*B + 1) or (B, L) or None @@ -316,7 +315,7 @@ Status CheckInputs(const T* query, output_parameters->head_size = hidden_size / num_heads; output_parameters->v_head_size = v_hidden_size / num_heads; output_parameters->num_heads = num_heads; - output_parameters->is_unidirectional = is_unidirectional; + output_parameters->is_unidirectional = false; output_parameters->past_present_share_buffer = past_present_share_buffer; output_parameters->mask_filter_value = mask_filter_value; output_parameters->mask_type = mask_type; @@ -343,7 +342,6 @@ Status CheckInputs(const T* query, int num_heads, float mask_filter_value, float scale, - bool is_unidirectional, bool past_present_share_buffer, bool dmmha_packing, int max_threads_per_block) { @@ -352,8 +350,8 @@ Status CheckInputs(const T* query, } return CheckInputs(query, key, value, bias, key_padding_mask, relative_position_bias, past_key, past_value, - past_seq_len, parameters, num_heads, mask_filter_value, scale, is_unidirectional, - past_present_share_buffer, dmmha_packing); + past_seq_len, parameters, num_heads, mask_filter_value, scale, past_present_share_buffer, + dmmha_packing); } } // namespace multihead_attention_helper diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc index aa8b5b5f608fa..47f462d75fcc4 100644 --- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc @@ -27,13 +27,7 @@ ONNX_OPERATOR_TYPED_KERNEL_EX( template RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) { scale = info.GetAttrOrDefault("scale", 1.0); - rotary_embedding_dim = static_cast(info.GetAttrOrDefault("rotary_embedding_dim", 0)); - num_heads = static_cast(info.GetAttrOrDefault("num_heads", 0)); interleaved = (info.GetAttrOrDefault("interleaved", 0) == 1); - - if (rotary_embedding_dim > 0) { - ORT_ENFORCE(num_heads > 0, "num_heads must be provided if rotary_embedding_dim is specified"); - } } template @@ -48,8 +42,6 @@ Status RotaryEmbedding::Compute(OpKernelContext* context) const { position_ids, cos_cache, sin_cache, - num_heads, - rotary_embedding_dim, ¶meters)); Tensor* output = context->Output(0, input->Shape()); @@ -67,66 +59,61 @@ Status RotaryEmbedding::Compute(OpKernelContext* context) const { const int batch_size = parameters.batch_size; const int sequence_length = parameters.sequence_length; - const int n_heads = parameters.num_heads; + const int num_heads = parameters.num_heads; const int head_size = parameters.head_size; const int position_ids_format = parameters.position_ids_format; - const int rotary_emb_dim = parameters.rotary_embedding_dim; - const int half_rotary_emb_dim = rotary_emb_dim / 2; - + const int half_head_size = head_size / 2; // Default input tensor shape is [batch, seq_len, hidden_size] int head_stride = head_size; - int seq_stride = n_heads * head_stride; + int seq_stride = num_heads * head_stride; int batch_stride = sequence_length * seq_stride; if (parameters.transposed) { - // Transposed input tensor shape is [batch, n_heads, seq_len, head_size] + // Transposed input tensor shape is [batch, num_heads, seq_len, head_size] seq_stride = head_size; head_stride = sequence_length * seq_stride; - batch_stride = n_heads * head_stride; + batch_stride = num_heads * head_stride; } AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); auto* tp = context->GetOperatorThreadPool(); - const int loop_len = batch_size * sequence_length * n_heads; - const double cost = static_cast(rotary_emb_dim); + const int loop_len = batch_size * sequence_length * num_heads; + const double cost = static_cast(head_size); ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) { for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) { - const int b = static_cast((ptr / n_heads) / sequence_length); - const int s = static_cast((ptr / n_heads) % sequence_length); - const int n = static_cast(ptr % n_heads); + const int b = static_cast((ptr / num_heads) / sequence_length); + const int s = static_cast((ptr / num_heads) % sequence_length); + const int n = static_cast(ptr % num_heads); const int block_offset = b * batch_stride + s * seq_stride + n * head_stride; const T* input_data = input_src + block_offset; T* output_data = output_dest + block_offset; - // Cache is (M, H/2) or (M, rotary_embedding_dim/2) + // Cache is (M, H/2) const int position_id = (position_ids_format == 0) ? static_cast(pos_ids_data[0]) + s : static_cast(pos_ids_data[b * sequence_length + s]); - const int cache_offset = position_id * half_rotary_emb_dim; + const int cache_offset = position_id * half_head_size; const T* cos_data = cos_cache_data + cache_offset; const T* sin_data = sin_cache_data + cache_offset; int cache_idx = 0; T sign = 0; int j = 0; - for (int i = 0; i < rotary_emb_dim; i++) { + for (int i = 0; i < head_size; i++) { if (interleaved) { - cache_idx = (i / 2) % half_rotary_emb_dim; + cache_idx = (i / 2) % half_head_size; sign = (i % 2 == 0) ? static_cast(-1) : static_cast(1); j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign } else { - cache_idx = i % half_rotary_emb_dim; - sign = (i < half_rotary_emb_dim) ? static_cast(-1) : static_cast(1); - j = (i + half_rotary_emb_dim) % rotary_emb_dim; + cache_idx = i % half_head_size; + sign = (i < half_head_size) ? static_cast(-1) : static_cast(1); + j = (i + half_head_size) % head_size; } output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx]; } - for (int i = rotary_emb_dim; i < head_size; i++) { - output_data[i] = input_data[i]; - } } }); diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h index 4e32424a22b6c..be834a66cdc69 100644 --- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h @@ -16,8 +16,6 @@ class RotaryEmbedding final : public OpKernel { protected: float scale; - int num_heads; - int rotary_embedding_dim; bool interleaved; }; diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h index dcbb36d1c4a3c..7b2e8289f7b06 100644 --- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h @@ -11,15 +11,14 @@ namespace rotary_embedding_helper { // Parameters deduced from node attributes and inputs/outputs. struct RotaryParameters { - int batch_size; // Batch size used by input - int sequence_length; // Sequence length used by input - int hidden_size; // Hidden size used by input - int head_size; // Head size - int rotary_embedding_dim; // Rotary embedding dimension. - int num_heads; // num_heads = hidden_size / head_size - int max_sequence_length; // Sequence length used by cos/sin cache - int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length) - bool transposed; // Whether the input tensor has been transposed into (batch, num_heads, seq_len, hidden) + int batch_size; // Batch size used by input + int sequence_length; // Sequence length used by input + int hidden_size; // Hidden size used by input + int head_size; // Head size used by cos/sin cache * 2 + int num_heads; // num_heads = hidden_size / head_size + int max_sequence_length; // Sequence length used by cos/sin cache + int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length) + bool transposed; // Whether the input tensor has been transposed into (batch, num_heads, seq_len, hidden) }; template @@ -27,13 +26,11 @@ Status CheckInputs(const T* input, const T* position_ids, const T* cos_cache, const T* sin_cache, - int num_heads, - int rotary_embedding_dim, void* parameters) { // input : (batch_size, sequence_length, hidden_size) // position ids : (1) or (batch_size, sequence_length) - // cos cache : (max_sequence_length, rotary_embedding_dim / 2) - // sin cache : (max_sequence_length, rotary_embedding_dim / 2) + // cos cache : (max_sequence_length, head_size / 2) + // sin cache : (max_sequence_length, head_size / 2) // Check input const auto& input_dims = input->Shape().GetDims(); @@ -63,12 +60,6 @@ Status CheckInputs(const T* input, "the same shape"); } - // Check num_heads and rotary_embedding_dim - if (rotary_embedding_dim > 0 && num_heads == 0) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads must be provided if rotary_embedding_dim is ", - "specified"); - } - // Get attributes from inputs int batch_size = static_cast(input_dims[0]); int sequence_length = static_cast(input_dims[1]); @@ -82,13 +73,8 @@ Status CheckInputs(const T* input, transposed = true; } int max_sequence_length = static_cast(cos_cache_dims[0]); - int head_size = rotary_embedding_dim == 0 ? static_cast(cos_cache_dims[1]) * 2 - : static_cast(hidden_size / num_heads); - if (rotary_embedding_dim > 0 && rotary_embedding_dim > head_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "rotary_embedding_dim must be less than or equal to ", - "head_size"); - } - + int head_size = static_cast(cos_cache_dims[1]) * 2; + int num_heads = hidden_size / head_size; int position_ids_format = -1; // Check position_ids input shapes @@ -105,15 +91,23 @@ Status CheckInputs(const T* input, } else { position_ids_format = 0; } - // Check cos_cache input shapes if (max_sequence_length != static_cast(cos_cache_dims[0])) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 0 should be same as ", "max_sequence_length, got ", cos_cache_dims[0]); } - if ((head_size / 2) != static_cast(cos_cache_dims[1]) && (rotary_embedding_dim > 0 && (rotary_embedding_dim / 2) != static_cast(cos_cache_dims[1]))) { + if ((head_size / 2) != static_cast(cos_cache_dims[1])) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 1 should be same as ", - "head_size / 2 or rotary_embedding_dim / 2, got ", cos_cache_dims[1]); + "head_size / 2, got ", cos_cache_dims[1]); + } + // Check sin_cache input shapes + if (max_sequence_length != static_cast(sin_cache_dims[0])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' dimension 0 should be same as ", + "max_sequence_length, got ", sin_cache_dims[0]); + } + if ((head_size / 2) != static_cast(sin_cache_dims[1])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' dimension 1 should be same as ", + "head_size / 2, got ", sin_cache_dims[1]); } // Set rotary parameters @@ -123,11 +117,10 @@ Status CheckInputs(const T* input, output_parameters->sequence_length = sequence_length; output_parameters->hidden_size = hidden_size; output_parameters->head_size = head_size; - output_parameters->num_heads = num_heads > 0 ? num_heads : static_cast(hidden_size / head_size); + output_parameters->num_heads = num_heads; output_parameters->max_sequence_length = max_sequence_length; output_parameters->position_ids_format = position_ids_format; output_parameters->transposed = transposed; - output_parameters->rotary_embedding_dim = rotary_embedding_dim > 0 ? rotary_embedding_dim : head_size; } return Status::OK(); diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc index f978f50c6851f..ebd66d8c6528e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc @@ -44,8 +44,6 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); scale_ = info.GetAttrOrDefault("scale", 0.0f); - is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; - ORT_ENFORCE(!is_unidirectional_, "Unidirectional MHA does not support CUDA kernel. Consider using Attention or GQA instead."); disable_fused_self_attention_ = sizeof(T) != 2 || ParseEnvironmentVariableWithDefault(attention::kDisableFusedSelfAttention, false); @@ -107,7 +105,6 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { num_heads_, mask_filter_value_, scale_, - is_unidirectional_, false, // past_present_share_buffer false, // dmmha_packing device_prop.maxThreadsPerBlock)); diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h index 86a32c92ce003..c162f7133cc1c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h @@ -25,7 +25,6 @@ class MultiHeadAttention final : public CudaKernel { int num_heads_; // number of attention heads float mask_filter_value_; float scale_; - bool is_unidirectional_; bool disable_fused_self_attention_; bool enable_trt_flash_attention_; bool disable_fused_cross_attention_; diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc index 9de7ba3885c3c..2d12e975d88d7 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc @@ -29,13 +29,10 @@ namespace cuda { REGISTER_KERNEL_TYPED(float) REGISTER_KERNEL_TYPED(MLFloat16) -REGISTER_KERNEL_TYPED(BFloat16) template RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : CudaKernel(info) { scale = info.GetAttrOrDefault("scale", 1.0); - rotary_embedding_dim = static_cast(info.GetAttrOrDefault("rotary_embedding_dim", 0)); - num_heads = static_cast(info.GetAttrOrDefault("num_heads", 0)); interleaved = (info.GetAttrOrDefault("interleaved", 0) == 1); } @@ -51,8 +48,6 @@ Status RotaryEmbedding::ComputeInternal(OpKernelContext* context) const { position_ids, cos_cache, sin_cache, - num_heads, - rotary_embedding_dim, ¶meters)); Tensor* output = context->Output(0, input->Shape()); @@ -76,7 +71,6 @@ Status RotaryEmbedding::ComputeInternal(OpKernelContext* context) const { parameters.sequence_length, parameters.num_heads, parameters.head_size, - parameters.rotary_embedding_dim, parameters.max_sequence_length, parameters.position_ids_format, interleaved, diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h index d52f61d670444..6dab2ad56749e 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h @@ -19,8 +19,6 @@ class RotaryEmbedding final : public CudaKernel { protected: float scale; - int num_heads; - int rotary_embedding_dim; bool interleaved; }; diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu index c6637041f05bd..e1b83bd8caf54 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu @@ -26,7 +26,6 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH const int sequence_length, const int num_heads, const int head_size, - const int rotary_embedding_dim, const int position_ids_format, const bool interleaved, const int batch_stride, @@ -34,33 +33,24 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH const int head_stride) { // B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length // Use .x in innermost loop to access global memory efficiently - + const int b = blockIdx.z; const int s = blockIdx.y; const int n = blockIdx.x; const int i = threadIdx.x; - if (i >= head_size) { - return; - } - const int block_offset = b * batch_stride + s * seq_stride + n * head_stride; const T* input_data = input + block_offset; T* output_data = output + block_offset; - if (i >= rotary_embedding_dim) { - output_data[i] = input_data[i]; - return; - } - // Cache is (M, H/2) - const int half_rotary_embedding_dim = rotary_embedding_dim / 2; + const int half_head_size = head_size / 2; const int position_id = (position_ids_format == 0) ? \ static_cast(position_ids[0]) + s \ : static_cast(position_ids[b * sequence_length + s]); - const int cache_offset = position_id * half_rotary_embedding_dim; + const int cache_offset = position_id * half_head_size; const T* cos_data = cos_cache + cache_offset; const T* sin_data = sin_cache + cache_offset; @@ -68,13 +58,13 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH T sign = 0; int j = 0; if (interleaved) { - cache_idx = (i / 2) % half_rotary_embedding_dim; + cache_idx = (i / 2) % half_head_size; sign = (i % 2 == 0) ? -1 : 1; j = (i % 2 == 0) ? i+1 : i-1; // i - sign } else { - cache_idx = i % half_rotary_embedding_dim; - sign = (i < half_rotary_embedding_dim) ? -1 : 1; - j = (i + half_rotary_embedding_dim) % rotary_embedding_dim; + cache_idx = i % half_head_size; + sign = (i < half_head_size) ? -1 : 1; + j = (i + half_head_size) % head_size; } output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx]; } @@ -92,23 +82,20 @@ Status LaunchRotaryEmbeddingKernel( const int sequence_length, const int num_heads, const int head_size, - const int rotary_embedding_dim, const int max_sequence_length, const int position_ids_format, const bool interleaved, const int max_threads_per_block, const bool transposed) { + + constexpr int smem_size = 0; + const dim3 grid(num_heads, sequence_length, batch_size); + const dim3 block(head_size, 1, 1); + // Note: Current implementation assumes head_size <= max_threads_per_block // because head_size is currently large for LLaMA-2. For smaller head_size // and num_heads values, we can create a block as `block(num_heads, head_size, 1)` // instead. This will require kernel changes to support. - ORT_ENFORCE(head_size <= max_threads_per_block, - "Rotary embedding dim must be <= max_threads_per_block"); - - int tpb = (head_size + 31)/32*32; - - const dim3 block(tpb); - const dim3 grid(num_heads, sequence_length, batch_size); // Default input tensor shape is [batch, seq, hidden_size] int head_stride = head_size; @@ -122,9 +109,10 @@ Status LaunchRotaryEmbeddingKernel( } assert(head_size <= max_threads_per_block); - RotaryEmbeddingBSNH<<>>( - output, input, cos_cache, sin_cache, position_ids, sequence_length, num_heads, head_size, - rotary_embedding_dim, position_ids_format, interleaved, batch_stride, seq_stride, head_stride + RotaryEmbeddingBSNH<<>>( + output, input, cos_cache, sin_cache, position_ids, + sequence_length, num_heads, head_size, position_ids_format, interleaved, + batch_stride, seq_stride, head_stride ); return CUDA_CALL(cudaGetLastError()); @@ -141,7 +129,6 @@ template Status LaunchRotaryEmbeddingKernel( const int sequence_length, const int num_heads, const int head_size, - const int rotary_embedding_dim, const int max_sequence_length, const int position_ids_format, const bool interleaved, @@ -159,25 +146,6 @@ template Status LaunchRotaryEmbeddingKernel( const int sequence_length, const int num_heads, const int head_size, - const int rotary_embedding_dim, - const int max_sequence_length, - const int position_ids_format, - const bool interleaved, - const int max_threads_per_block, - const bool transposed); - -template Status LaunchRotaryEmbeddingKernel( - cudaStream_t stream, - BFloat16* output, - const BFloat16* input, - const int64_t* position_ids, - const BFloat16* cos_cache, - const BFloat16* sin_cache, - const int batch_size, - const int sequence_length, - const int num_heads, - const int head_size, - const int rotary_embedding_dim, const int max_sequence_length, const int position_ids_format, const bool interleaved, diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h index 36300fe7a660f..ee1ccc43dcbff 100644 --- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h +++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h @@ -21,7 +21,6 @@ Status LaunchRotaryEmbeddingKernel( const int sequence_length, const int num_heads, const int head_size, - const int rotary_embedding_dim, const int max_sequence_length, const int position_ids_format, const bool interleaved, diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index fa73950c9c6f5..34b44694a5fcc 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -98,7 +98,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ParametricSoftplus); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RotaryEmbedding); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding); -class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, RotaryEmbedding); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Sampling); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh); @@ -300,7 +299,6 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 7f34647f1faef..0317ffcfb0e31 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -927,10 +927,6 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Custom scale will be used if specified. Default value is 1/sqrt(head_size)", AttributeProto::FLOAT, OPTIONAL_VALUE) - .Attr("unidirectional", - "Whether every token can only attend to previous tokens. Default value is 0.", - AttributeProto::INT, - static_cast(0)) .Input(0, "query", "Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape (batch_size, kv_sequence_length, num_heads, 3, head_size)", @@ -1149,14 +1145,6 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Rotate using interleaved pattern. Default value is 0 (False).", AttributeProto::INT, OPTIONAL_VALUE) - .Attr("rotary_embedding_dim", - "Rotary embedding dimension. Default value is 0.", - AttributeProto::INT, - OPTIONAL_VALUE) - .Attr("num_heads", - "Number of attention heads. Default value is 0. Must use with rotary_embedding_dim", - AttributeProto::INT, - OPTIONAL_VALUE) .Input(0, "input", "3D tensor with shape (batch_size, sequence_length, hidden_size) or 4D with shape (batch_size, num_heads, sequence_length, head_size)", @@ -1167,17 +1155,17 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "M") .Input(2, "cos_cache", - "2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)", + "2D tensor with shape (max_sequence_length, head_size / 2).", "T") .Input(3, "sin_cache", - "2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)", + "2D tensor with shape (max_sequence_length, head_size / 2).", "T") .Output(0, "output", "tensor with same shape as input.", "T") - .TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float tensors.") + .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.") .TypeConstraint("M", {"tensor(int64)"}, "Constrain input and output types to integer tensors") .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { propagateElemTypeFromInputToOutput(ctx, 0, 0); diff --git a/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc index e64de0e6da16a..55f01bf0d3f1d 100644 --- a/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc +++ b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc @@ -11,14 +11,6 @@ namespace onnxruntime { namespace test { -namespace { -enum class TensorType { - kFloat, - kFloat16, - kBFloat16 -}; -} // anonymous namespace - static void RunTest( const std::vector& input_data, const std::vector& position_ids, @@ -28,11 +20,10 @@ static void RunTest( int batch_size, int sequence_length, int head_size, - int rotary_embedding_dim, int num_heads, int max_sequence_length, int64_t interleaved, - TensorType tensor_type, + bool use_float16, bool disable_cpu, bool disable_cuda, bool disable_dml) { @@ -45,9 +36,7 @@ static void RunTest( int hidden_size = num_heads * head_size; std::vector input_dims = {batch_size, sequence_length, hidden_size}; std::vector pos_dims; - std::vector cache_dims = {max_sequence_length, rotary_embedding_dim > 0 - ? rotary_embedding_dim / 2 - : head_size / 2}; + std::vector cache_dims = {max_sequence_length, head_size / 2}; assert(hidden_size != 0 && head_size != 0 && num_heads != 0 && max_sequence_length != 0); assert(max_sequence_length >= sequence_length); @@ -60,10 +49,7 @@ static void RunTest( std::string op_type = "RotaryEmbedding"; std::vector> execution_providers; - int min_cuda_architecture = (tensor_type == TensorType::kBFloat16) - ? 800 - : (tensor_type == TensorType::kFloat16) ? 530 - : 0; + int min_cuda_architecture = use_float16 ? 530 : 0; bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); bool enable_dml = (nullptr != DefaultDmlExecutionProvider().get()) && !disable_dml; @@ -73,7 +59,7 @@ static void RunTest( if (enable_dml && !disable_dml) { execution_providers.push_back(DefaultDmlExecutionProvider()); } - if (tensor_type == TensorType::kFloat && !disable_cpu) { + if (!use_float16 && !disable_cpu) { execution_providers.push_back(DefaultCpuExecutionProvider()); } if (execution_providers.size() == 0) { @@ -84,36 +70,20 @@ static void RunTest( OpTester test(op_type.c_str(), 1, onnxruntime::kMSDomain); test.AddAttribute("interleaved", interleaved); - if (rotary_embedding_dim > 0) { - test.AddAttribute("rotary_embedding_dim", rotary_embedding_dim); - test.AddAttribute("num_heads", num_heads); - } - - if (tensor_type == TensorType::kFloat) { + if (!use_float16) { test.AddInput("input", input_dims, input_data); test.AddInput("position_ids", pos_dims, position_ids); test.AddInput("cos_cache", cache_dims, cos_cache); test.AddInput("sin_cache", cache_dims, sin_cache); test.AddOutput("output", input_dims, output_data); - } else if (tensor_type == TensorType::kFloat16) { + } else { test.AddInput("input", input_dims, ToFloat16(input_data)); test.AddInput("position_ids", pos_dims, position_ids); test.AddInput("cos_cache", cache_dims, ToFloat16(cos_cache)); test.AddInput("sin_cache", cache_dims, ToFloat16(sin_cache)); test.AddOutput("output", input_dims, ToFloat16(output_data)); - } else { - test.AddInput("input", input_dims, FloatsToBFloat16s(input_data)); - test.AddInput("position_ids", pos_dims, position_ids); - test.AddInput("cos_cache", cache_dims, FloatsToBFloat16s(cos_cache)); - test.AddInput("sin_cache", cache_dims, FloatsToBFloat16s(sin_cache)); - test.AddOutput("output", input_dims, FloatsToBFloat16s(output_data)); - } - if (tensor_type == TensorType::kBFloat16) { - test.SetOutputAbsErr("output", 0.03f); - } else { - test.SetOutputAbsErr("output", 0.002f); } - + test.SetOutputAbsErr("output", 0.002f); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } @@ -125,12 +95,10 @@ static void RunTests(const std::vector& input_data, int batch_size, int sequence_length, int head_size = 0, - int rotary_embedding_dim = 0, int num_heads = 0, int max_sequence_length = 0, int64_t interleaved = 0, - bool use_float16 = true, - bool disable_dml = false) { + bool use_float16 = true) { // FP32 test for CPU RunTest(input_data, position_ids, @@ -140,11 +108,10 @@ static void RunTests(const std::vector& input_data, batch_size, sequence_length, head_size, - rotary_embedding_dim, num_heads, max_sequence_length, interleaved, - TensorType::kFloat, + false, /* use_fp16 */ false, /* disable_cpu */ true, /* disable_cuda */ true /* disable_dml */); @@ -158,14 +125,13 @@ static void RunTests(const std::vector& input_data, batch_size, sequence_length, head_size, - rotary_embedding_dim, num_heads, max_sequence_length, interleaved, - TensorType::kFloat, + false, /* use_fp16 */ false, /* disable_cpu */ false, /* disable_cuda */ - disable_dml || false /* disable_dml */); + false /* disable_dml */); // FP16 test for CUDA and DML if (use_float16) { @@ -177,31 +143,13 @@ static void RunTests(const std::vector& input_data, batch_size, sequence_length, head_size, - rotary_embedding_dim, num_heads, max_sequence_length, interleaved, - TensorType::kFloat16, + true, /* use_fp16 */ true, /* disable_cpu */ false, /* disable_cuda*/ - disable_dml || false /* disable_dml */); - - // RunTest(input_data, - // position_ids, - // cos_cache, - // sin_cache, - // output_data, - // batch_size, - // sequence_length, - // head_size, - // rotary_embedding_dim, - // num_heads, - // max_sequence_length, - // interleaved, - // TensorType::kBFloat16, - // true, /* disable_cpu */ - // false, /* disable_cuda*/ - // false /* disable_dml */); + false /* disable_dml */); } } @@ -211,7 +159,6 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_Interleaved_SmallData_LlamaMSFT) { int sequence_length = 3; int num_heads = 2; int head_size = 4; - int rotary_embedding_dim = 0; int max_sequence_length = 8; int64_t interleaved = 1; // true @@ -243,7 +190,6 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_Interleaved_SmallData_LlamaMSFT) { batch_size, sequence_length, head_size, - rotary_embedding_dim, num_heads, max_sequence_length, interleaved); @@ -255,7 +201,6 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_Interleaved_LargeData_LlamaMSFT) { int sequence_length = 8; int num_heads = 4; int head_size = 6; - int rotary_embedding_dim = 0; int max_sequence_length = 16; int64_t interleaved = 1; // true @@ -443,7 +388,6 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_Interleaved_LargeData_LlamaMSFT) { batch_size, sequence_length, head_size, - rotary_embedding_dim, num_heads, max_sequence_length, interleaved); @@ -455,7 +399,6 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_NotInterleaved_LargeData_LlamaMSFT) { int sequence_length = 8; int num_heads = 4; int head_size = 6; - int rotary_embedding_dim = 0; int max_sequence_length = 16; int64_t interleaved = 0; // false @@ -643,7 +586,6 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_NotInterleaved_LargeData_LlamaMSFT) { batch_size, sequence_length, head_size, - rotary_embedding_dim, num_heads, max_sequence_length, interleaved); @@ -655,7 +597,6 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_NotInterleaved_SmallData_LlamaMSFT) { int sequence_length = 2; int num_heads = 3; int head_size = 6; - int rotary_embedding_dim = 0; int max_sequence_length = 4; int64_t interleaved = 0; // false @@ -691,52 +632,10 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_NotInterleaved_SmallData_LlamaMSFT) { batch_size, sequence_length, head_size, - rotary_embedding_dim, num_heads, max_sequence_length, interleaved); } -TEST(RotaryEmbeddingTest, RotaryEmbedding_CustomRotaryDim_SmallData_Phi) { - int batch_size = 1; - int sequence_length = 2; - int num_heads = 1; - int head_size = 6; - int rotary_embedding_dim = 4; - int max_sequence_length = 2; - int64_t interleaved = 0; // false - - std::vector input_data = { - -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, 1.1676f, 1.0076f, -0.7529f, - -0.2250f, -0.4327f, -1.5071f, -0.4586f}; - - std::vector position_ids = {0, 1}; - - std::vector cos_cache = { - 1.0000f, 1.0000f, 1.0000f, 0.5403f}; - - std::vector sin_cache = { - 0.0000f, 0.0000f, 0.0000f, 0.8415f}; - - std::vector output_data = { - -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, 1.1676f, 1.0076f, -0.0427f, - -0.2250f, -0.8673f, -1.5071f, -0.4586f}; - - RunTests(input_data, - position_ids, - cos_cache, - sin_cache, - output_data, - batch_size, - sequence_length, - head_size, - rotary_embedding_dim, - num_heads, - max_sequence_length, - interleaved, - true, /*use_fp16*/ - true /*disable_dml*/); -} - } // namespace test } // namespace onnxruntime