diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 22e82443167f6..45c0e6f822ce9 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -3031,8 +3031,6 @@ This version of the operator has been available since version 1 of the 'com.micr
Number of attention heads
scale : float
Custom scale will be used if specified. Default value is 1/sqrt(head_size)
-unidirectional : int
-Whether every token can only attend to previous tokens. Default value is 0.
#### Inputs (1 - 8)
@@ -5023,10 +5021,6 @@ This version of the operator has been available since version 1 of the 'com.micr
- interleaved : int
- Rotate using interleaved pattern. Default value is 0 (False).
-- num_heads : int
-- Number of attention heads. Default value is 0. Must use with rotary_embedding_dim
-- rotary_embedding_dim : int
-- Rotary embedding dimension. Default value is 0.
- scale : float
- Custom scale will be used if specified. Default value is 1.0
@@ -5039,9 +5033,9 @@ This version of the operator has been available since version 1 of the 'com.micr
position_ids : M
1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)
cos_cache : T
-2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)
+2D tensor with shape (max_sequence_length, head_size / 2).
sin_cache : T
-2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)
+2D tensor with shape (max_sequence_length, head_size / 2).
#### Outputs
@@ -5054,7 +5048,7 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Type Constraints
-- T : tensor(float), tensor(float16), tensor(bfloat16)
+- T : tensor(float), tensor(float16)
- Constrain input and output types to float tensors.
- M : tensor(int64)
- Constrain input and output types to integer tensors
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 9a2a7ac89bbb3..1be895dba3841 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -868,7 +868,7 @@ Do not modify directly.*
|RemovePadding|*in* input:**T**
*in* sequence_token_count:**M**
*out* output:**T**
*out* token_offset:**M**
*out* cumulated_seq_len:**M**
*out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)|
|RestorePadding|*in* input:**T**
*in* token_offset:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|Rfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
-|RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(bfloat16), tensor(float), tensor(float16)|
+|RotaryEmbedding|*in* input:**T**
*in* position_ids:**M**
*in* cos_cache:**T**
*in* sin_cache:**T**
*out* output:**T**|1+|**M** = tensor(int64)
**T** = tensor(float), tensor(float16)|
|Sampling|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*in* presence_mask:**I**
*in* seed:**I**
*out* sequences:**I**
*out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipGroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*in* skip:**T**
*in* bias:**T**
*out* Y:**T**
*out* S:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc
index 768676259aa14..4711ccf487cc8 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc
@@ -211,12 +211,6 @@ Status Attention::Compute(OpKernelContext* context) const {
relative_position_bias,
¶meters));
- if (parameters.do_rotary) {
- ORT_NOT_IMPLEMENTED(
- "Rotary embedding is not supported in Attention CPU kernel. \
- Please fuse the model with MHA + RotaryEmbedding.");
- }
-
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int input_hidden_size = parameters.input_hidden_size;
diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
index eb25d0fd7cc1e..694c40bf3eda6 100644
--- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
@@ -40,7 +40,6 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) : OpKernel(i
num_heads_ = static_cast(num_heads);
mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f);
- is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1;
}
// Reshape Q/K/V from BxSxD to BxSxNxH
@@ -284,9 +283,8 @@ Status MultiHeadAttention::Compute(OpKernelContext* context) const {
nullptr,
¶meters,
num_heads_,
- mask_filter_value_,
scale,
- is_unidirectional_,
+ mask_filter_value_,
past_present_share_buffer,
false));
diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h
index fb7da78a5c0a5..4c86b777e9842 100644
--- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h
+++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h
@@ -18,7 +18,6 @@ class MultiHeadAttention final : public OpKernel, public AttentionCPUBase {
protected:
int num_heads_; // number of attention heads
float mask_filter_value_;
- bool is_unidirectional_;
};
} // namespace contrib
diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
index c91f5b601b4e9..00e82c9844b3d 100644
--- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
+++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
@@ -25,7 +25,6 @@ Status CheckInputs(const T* query,
int num_heads,
float mask_filter_value,
float scale,
- bool is_unidirectional,
bool past_present_share_buffer,
bool dmmha_packing) {
// key_padding_mask (K/V) : (B) or (2*B + 1) or (B, L) or None
@@ -316,7 +315,7 @@ Status CheckInputs(const T* query,
output_parameters->head_size = hidden_size / num_heads;
output_parameters->v_head_size = v_hidden_size / num_heads;
output_parameters->num_heads = num_heads;
- output_parameters->is_unidirectional = is_unidirectional;
+ output_parameters->is_unidirectional = false;
output_parameters->past_present_share_buffer = past_present_share_buffer;
output_parameters->mask_filter_value = mask_filter_value;
output_parameters->mask_type = mask_type;
@@ -343,7 +342,6 @@ Status CheckInputs(const T* query,
int num_heads,
float mask_filter_value,
float scale,
- bool is_unidirectional,
bool past_present_share_buffer,
bool dmmha_packing,
int max_threads_per_block) {
@@ -352,8 +350,8 @@ Status CheckInputs(const T* query,
}
return CheckInputs(query, key, value, bias, key_padding_mask, relative_position_bias, past_key, past_value,
- past_seq_len, parameters, num_heads, mask_filter_value, scale, is_unidirectional,
- past_present_share_buffer, dmmha_packing);
+ past_seq_len, parameters, num_heads, mask_filter_value, scale, past_present_share_buffer,
+ dmmha_packing);
}
} // namespace multihead_attention_helper
diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
index aa8b5b5f608fa..47f462d75fcc4 100644
--- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
@@ -27,13 +27,7 @@ ONNX_OPERATOR_TYPED_KERNEL_EX(
template
RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) {
scale = info.GetAttrOrDefault("scale", 1.0);
- rotary_embedding_dim = static_cast(info.GetAttrOrDefault("rotary_embedding_dim", 0));
- num_heads = static_cast(info.GetAttrOrDefault("num_heads", 0));
interleaved = (info.GetAttrOrDefault("interleaved", 0) == 1);
-
- if (rotary_embedding_dim > 0) {
- ORT_ENFORCE(num_heads > 0, "num_heads must be provided if rotary_embedding_dim is specified");
- }
}
template
@@ -48,8 +42,6 @@ Status RotaryEmbedding::Compute(OpKernelContext* context) const {
position_ids,
cos_cache,
sin_cache,
- num_heads,
- rotary_embedding_dim,
¶meters));
Tensor* output = context->Output(0, input->Shape());
@@ -67,66 +59,61 @@ Status RotaryEmbedding::Compute(OpKernelContext* context) const {
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
- const int n_heads = parameters.num_heads;
+ const int num_heads = parameters.num_heads;
const int head_size = parameters.head_size;
const int position_ids_format = parameters.position_ids_format;
- const int rotary_emb_dim = parameters.rotary_embedding_dim;
- const int half_rotary_emb_dim = rotary_emb_dim / 2;
-
+ const int half_head_size = head_size / 2;
// Default input tensor shape is [batch, seq_len, hidden_size]
int head_stride = head_size;
- int seq_stride = n_heads * head_stride;
+ int seq_stride = num_heads * head_stride;
int batch_stride = sequence_length * seq_stride;
if (parameters.transposed) {
- // Transposed input tensor shape is [batch, n_heads, seq_len, head_size]
+ // Transposed input tensor shape is [batch, num_heads, seq_len, head_size]
seq_stride = head_size;
head_stride = sequence_length * seq_stride;
- batch_stride = n_heads * head_stride;
+ batch_stride = num_heads * head_stride;
}
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
auto* tp = context->GetOperatorThreadPool();
- const int loop_len = batch_size * sequence_length * n_heads;
- const double cost = static_cast(rotary_emb_dim);
+ const int loop_len = batch_size * sequence_length * num_heads;
+ const double cost = static_cast(head_size);
ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) {
- const int b = static_cast((ptr / n_heads) / sequence_length);
- const int s = static_cast((ptr / n_heads) % sequence_length);
- const int n = static_cast(ptr % n_heads);
+ const int b = static_cast((ptr / num_heads) / sequence_length);
+ const int s = static_cast((ptr / num_heads) % sequence_length);
+ const int n = static_cast(ptr % num_heads);
const int block_offset = b * batch_stride + s * seq_stride + n * head_stride;
const T* input_data = input_src + block_offset;
T* output_data = output_dest + block_offset;
- // Cache is (M, H/2) or (M, rotary_embedding_dim/2)
+ // Cache is (M, H/2)
const int position_id = (position_ids_format == 0)
? static_cast(pos_ids_data[0]) + s
: static_cast(pos_ids_data[b * sequence_length + s]);
- const int cache_offset = position_id * half_rotary_emb_dim;
+ const int cache_offset = position_id * half_head_size;
const T* cos_data = cos_cache_data + cache_offset;
const T* sin_data = sin_cache_data + cache_offset;
int cache_idx = 0;
T sign = 0;
int j = 0;
- for (int i = 0; i < rotary_emb_dim; i++) {
+ for (int i = 0; i < head_size; i++) {
if (interleaved) {
- cache_idx = (i / 2) % half_rotary_emb_dim;
+ cache_idx = (i / 2) % half_head_size;
sign = (i % 2 == 0) ? static_cast(-1) : static_cast(1);
j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign
} else {
- cache_idx = i % half_rotary_emb_dim;
- sign = (i < half_rotary_emb_dim) ? static_cast(-1) : static_cast(1);
- j = (i + half_rotary_emb_dim) % rotary_emb_dim;
+ cache_idx = i % half_head_size;
+ sign = (i < half_head_size) ? static_cast(-1) : static_cast(1);
+ j = (i + half_head_size) % head_size;
}
output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx];
}
- for (int i = rotary_emb_dim; i < head_size; i++) {
- output_data[i] = input_data[i];
- }
}
});
diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h
index 4e32424a22b6c..be834a66cdc69 100644
--- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h
+++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h
@@ -16,8 +16,6 @@ class RotaryEmbedding final : public OpKernel {
protected:
float scale;
- int num_heads;
- int rotary_embedding_dim;
bool interleaved;
};
diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h
index dcbb36d1c4a3c..7b2e8289f7b06 100644
--- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h
+++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h
@@ -11,15 +11,14 @@ namespace rotary_embedding_helper {
// Parameters deduced from node attributes and inputs/outputs.
struct RotaryParameters {
- int batch_size; // Batch size used by input
- int sequence_length; // Sequence length used by input
- int hidden_size; // Hidden size used by input
- int head_size; // Head size
- int rotary_embedding_dim; // Rotary embedding dimension.
- int num_heads; // num_heads = hidden_size / head_size
- int max_sequence_length; // Sequence length used by cos/sin cache
- int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length)
- bool transposed; // Whether the input tensor has been transposed into (batch, num_heads, seq_len, hidden)
+ int batch_size; // Batch size used by input
+ int sequence_length; // Sequence length used by input
+ int hidden_size; // Hidden size used by input
+ int head_size; // Head size used by cos/sin cache * 2
+ int num_heads; // num_heads = hidden_size / head_size
+ int max_sequence_length; // Sequence length used by cos/sin cache
+ int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length)
+ bool transposed; // Whether the input tensor has been transposed into (batch, num_heads, seq_len, hidden)
};
template
@@ -27,13 +26,11 @@ Status CheckInputs(const T* input,
const T* position_ids,
const T* cos_cache,
const T* sin_cache,
- int num_heads,
- int rotary_embedding_dim,
void* parameters) {
// input : (batch_size, sequence_length, hidden_size)
// position ids : (1) or (batch_size, sequence_length)
- // cos cache : (max_sequence_length, rotary_embedding_dim / 2)
- // sin cache : (max_sequence_length, rotary_embedding_dim / 2)
+ // cos cache : (max_sequence_length, head_size / 2)
+ // sin cache : (max_sequence_length, head_size / 2)
// Check input
const auto& input_dims = input->Shape().GetDims();
@@ -63,12 +60,6 @@ Status CheckInputs(const T* input,
"the same shape");
}
- // Check num_heads and rotary_embedding_dim
- if (rotary_embedding_dim > 0 && num_heads == 0) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads must be provided if rotary_embedding_dim is ",
- "specified");
- }
-
// Get attributes from inputs
int batch_size = static_cast(input_dims[0]);
int sequence_length = static_cast(input_dims[1]);
@@ -82,13 +73,8 @@ Status CheckInputs(const T* input,
transposed = true;
}
int max_sequence_length = static_cast(cos_cache_dims[0]);
- int head_size = rotary_embedding_dim == 0 ? static_cast(cos_cache_dims[1]) * 2
- : static_cast(hidden_size / num_heads);
- if (rotary_embedding_dim > 0 && rotary_embedding_dim > head_size) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "rotary_embedding_dim must be less than or equal to ",
- "head_size");
- }
-
+ int head_size = static_cast(cos_cache_dims[1]) * 2;
+ int num_heads = hidden_size / head_size;
int position_ids_format = -1;
// Check position_ids input shapes
@@ -105,15 +91,23 @@ Status CheckInputs(const T* input,
} else {
position_ids_format = 0;
}
-
// Check cos_cache input shapes
if (max_sequence_length != static_cast(cos_cache_dims[0])) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 0 should be same as ",
"max_sequence_length, got ", cos_cache_dims[0]);
}
- if ((head_size / 2) != static_cast(cos_cache_dims[1]) && (rotary_embedding_dim > 0 && (rotary_embedding_dim / 2) != static_cast(cos_cache_dims[1]))) {
+ if ((head_size / 2) != static_cast(cos_cache_dims[1])) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 1 should be same as ",
- "head_size / 2 or rotary_embedding_dim / 2, got ", cos_cache_dims[1]);
+ "head_size / 2, got ", cos_cache_dims[1]);
+ }
+ // Check sin_cache input shapes
+ if (max_sequence_length != static_cast(sin_cache_dims[0])) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' dimension 0 should be same as ",
+ "max_sequence_length, got ", sin_cache_dims[0]);
+ }
+ if ((head_size / 2) != static_cast(sin_cache_dims[1])) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' dimension 1 should be same as ",
+ "head_size / 2, got ", sin_cache_dims[1]);
}
// Set rotary parameters
@@ -123,11 +117,10 @@ Status CheckInputs(const T* input,
output_parameters->sequence_length = sequence_length;
output_parameters->hidden_size = hidden_size;
output_parameters->head_size = head_size;
- output_parameters->num_heads = num_heads > 0 ? num_heads : static_cast(hidden_size / head_size);
+ output_parameters->num_heads = num_heads;
output_parameters->max_sequence_length = max_sequence_length;
output_parameters->position_ids_format = position_ids_format;
output_parameters->transposed = transposed;
- output_parameters->rotary_embedding_dim = rotary_embedding_dim > 0 ? rotary_embedding_dim : head_size;
}
return Status::OK();
diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
index f978f50c6851f..ebd66d8c6528e 100644
--- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
@@ -44,8 +44,6 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info)
mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f);
scale_ = info.GetAttrOrDefault("scale", 0.0f);
- is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1;
- ORT_ENFORCE(!is_unidirectional_, "Unidirectional MHA does not support CUDA kernel. Consider using Attention or GQA instead.");
disable_fused_self_attention_ = sizeof(T) != 2 ||
ParseEnvironmentVariableWithDefault(attention::kDisableFusedSelfAttention, false);
@@ -107,7 +105,6 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const {
num_heads_,
mask_filter_value_,
scale_,
- is_unidirectional_,
false, // past_present_share_buffer
false, // dmmha_packing
device_prop.maxThreadsPerBlock));
diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h
index 86a32c92ce003..c162f7133cc1c 100644
--- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h
+++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h
@@ -25,7 +25,6 @@ class MultiHeadAttention final : public CudaKernel {
int num_heads_; // number of attention heads
float mask_filter_value_;
float scale_;
- bool is_unidirectional_;
bool disable_fused_self_attention_;
bool enable_trt_flash_attention_;
bool disable_fused_cross_attention_;
diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc
index 9de7ba3885c3c..2d12e975d88d7 100644
--- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc
@@ -29,13 +29,10 @@ namespace cuda {
REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
-REGISTER_KERNEL_TYPED(BFloat16)
template
RotaryEmbedding::RotaryEmbedding(const OpKernelInfo& info) : CudaKernel(info) {
scale = info.GetAttrOrDefault("scale", 1.0);
- rotary_embedding_dim = static_cast(info.GetAttrOrDefault("rotary_embedding_dim", 0));
- num_heads = static_cast(info.GetAttrOrDefault("num_heads", 0));
interleaved = (info.GetAttrOrDefault("interleaved", 0) == 1);
}
@@ -51,8 +48,6 @@ Status RotaryEmbedding::ComputeInternal(OpKernelContext* context) const {
position_ids,
cos_cache,
sin_cache,
- num_heads,
- rotary_embedding_dim,
¶meters));
Tensor* output = context->Output(0, input->Shape());
@@ -76,7 +71,6 @@ Status RotaryEmbedding::ComputeInternal(OpKernelContext* context) const {
parameters.sequence_length,
parameters.num_heads,
parameters.head_size,
- parameters.rotary_embedding_dim,
parameters.max_sequence_length,
parameters.position_ids_format,
interleaved,
diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h
index d52f61d670444..6dab2ad56749e 100644
--- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h
+++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h
@@ -19,8 +19,6 @@ class RotaryEmbedding final : public CudaKernel {
protected:
float scale;
- int num_heads;
- int rotary_embedding_dim;
bool interleaved;
};
diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu
index c6637041f05bd..e1b83bd8caf54 100644
--- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu
@@ -26,7 +26,6 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH
const int sequence_length,
const int num_heads,
const int head_size,
- const int rotary_embedding_dim,
const int position_ids_format,
const bool interleaved,
const int batch_stride,
@@ -34,33 +33,24 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH
const int head_stride) {
// B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length
// Use .x in innermost loop to access global memory efficiently
-
+
const int b = blockIdx.z;
const int s = blockIdx.y;
const int n = blockIdx.x;
const int i = threadIdx.x;
- if (i >= head_size) {
- return;
- }
-
const int block_offset = b * batch_stride + s * seq_stride + n * head_stride;
const T* input_data = input + block_offset;
T* output_data = output + block_offset;
- if (i >= rotary_embedding_dim) {
- output_data[i] = input_data[i];
- return;
- }
-
// Cache is (M, H/2)
- const int half_rotary_embedding_dim = rotary_embedding_dim / 2;
+ const int half_head_size = head_size / 2;
const int position_id = (position_ids_format == 0) ? \
static_cast(position_ids[0]) + s \
: static_cast(position_ids[b * sequence_length + s]);
- const int cache_offset = position_id * half_rotary_embedding_dim;
+ const int cache_offset = position_id * half_head_size;
const T* cos_data = cos_cache + cache_offset;
const T* sin_data = sin_cache + cache_offset;
@@ -68,13 +58,13 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH
T sign = 0;
int j = 0;
if (interleaved) {
- cache_idx = (i / 2) % half_rotary_embedding_dim;
+ cache_idx = (i / 2) % half_head_size;
sign = (i % 2 == 0) ? -1 : 1;
j = (i % 2 == 0) ? i+1 : i-1; // i - sign
} else {
- cache_idx = i % half_rotary_embedding_dim;
- sign = (i < half_rotary_embedding_dim) ? -1 : 1;
- j = (i + half_rotary_embedding_dim) % rotary_embedding_dim;
+ cache_idx = i % half_head_size;
+ sign = (i < half_head_size) ? -1 : 1;
+ j = (i + half_head_size) % head_size;
}
output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx];
}
@@ -92,23 +82,20 @@ Status LaunchRotaryEmbeddingKernel(
const int sequence_length,
const int num_heads,
const int head_size,
- const int rotary_embedding_dim,
const int max_sequence_length,
const int position_ids_format,
const bool interleaved,
const int max_threads_per_block,
const bool transposed) {
+
+ constexpr int smem_size = 0;
+ const dim3 grid(num_heads, sequence_length, batch_size);
+ const dim3 block(head_size, 1, 1);
+
// Note: Current implementation assumes head_size <= max_threads_per_block
// because head_size is currently large for LLaMA-2. For smaller head_size
// and num_heads values, we can create a block as `block(num_heads, head_size, 1)`
// instead. This will require kernel changes to support.
- ORT_ENFORCE(head_size <= max_threads_per_block,
- "Rotary embedding dim must be <= max_threads_per_block");
-
- int tpb = (head_size + 31)/32*32;
-
- const dim3 block(tpb);
- const dim3 grid(num_heads, sequence_length, batch_size);
// Default input tensor shape is [batch, seq, hidden_size]
int head_stride = head_size;
@@ -122,9 +109,10 @@ Status LaunchRotaryEmbeddingKernel(
}
assert(head_size <= max_threads_per_block);
- RotaryEmbeddingBSNH<<>>(
- output, input, cos_cache, sin_cache, position_ids, sequence_length, num_heads, head_size,
- rotary_embedding_dim, position_ids_format, interleaved, batch_stride, seq_stride, head_stride
+ RotaryEmbeddingBSNH<<>>(
+ output, input, cos_cache, sin_cache, position_ids,
+ sequence_length, num_heads, head_size, position_ids_format, interleaved,
+ batch_stride, seq_stride, head_stride
);
return CUDA_CALL(cudaGetLastError());
@@ -141,7 +129,6 @@ template Status LaunchRotaryEmbeddingKernel(
const int sequence_length,
const int num_heads,
const int head_size,
- const int rotary_embedding_dim,
const int max_sequence_length,
const int position_ids_format,
const bool interleaved,
@@ -159,25 +146,6 @@ template Status LaunchRotaryEmbeddingKernel(
const int sequence_length,
const int num_heads,
const int head_size,
- const int rotary_embedding_dim,
- const int max_sequence_length,
- const int position_ids_format,
- const bool interleaved,
- const int max_threads_per_block,
- const bool transposed);
-
-template Status LaunchRotaryEmbeddingKernel(
- cudaStream_t stream,
- BFloat16* output,
- const BFloat16* input,
- const int64_t* position_ids,
- const BFloat16* cos_cache,
- const BFloat16* sin_cache,
- const int batch_size,
- const int sequence_length,
- const int num_heads,
- const int head_size,
- const int rotary_embedding_dim,
const int max_sequence_length,
const int position_ids_format,
const bool interleaved,
diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h
index 36300fe7a660f..ee1ccc43dcbff 100644
--- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h
+++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h
@@ -21,7 +21,6 @@ Status LaunchRotaryEmbeddingKernel(
const int sequence_length,
const int num_heads,
const int head_size,
- const int rotary_embedding_dim,
const int max_sequence_length,
const int position_ids_format,
const bool interleaved,
diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
index fa73950c9c6f5..34b44694a5fcc 100644
--- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -98,7 +98,6 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ParametricSoftplus);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RotaryEmbedding);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding);
-class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, RotaryEmbedding);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Sampling);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh);
@@ -300,7 +299,6 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
- BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
index 7f34647f1faef..0317ffcfb0e31 100644
--- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -927,10 +927,6 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"Custom scale will be used if specified. Default value is 1/sqrt(head_size)",
AttributeProto::FLOAT,
OPTIONAL_VALUE)
- .Attr("unidirectional",
- "Whether every token can only attend to previous tokens. Default value is 0.",
- AttributeProto::INT,
- static_cast(0))
.Input(0,
"query",
"Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape (batch_size, kv_sequence_length, num_heads, 3, head_size)",
@@ -1149,14 +1145,6 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"Rotate using interleaved pattern. Default value is 0 (False).",
AttributeProto::INT,
OPTIONAL_VALUE)
- .Attr("rotary_embedding_dim",
- "Rotary embedding dimension. Default value is 0.",
- AttributeProto::INT,
- OPTIONAL_VALUE)
- .Attr("num_heads",
- "Number of attention heads. Default value is 0. Must use with rotary_embedding_dim",
- AttributeProto::INT,
- OPTIONAL_VALUE)
.Input(0,
"input",
"3D tensor with shape (batch_size, sequence_length, hidden_size) or 4D with shape (batch_size, num_heads, sequence_length, head_size)",
@@ -1167,17 +1155,17 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"M")
.Input(2,
"cos_cache",
- "2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)",
+ "2D tensor with shape (max_sequence_length, head_size / 2).",
"T")
.Input(3,
"sin_cache",
- "2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)",
+ "2D tensor with shape (max_sequence_length, head_size / 2).",
"T")
.Output(0,
"output",
"tensor with same shape as input.",
"T")
- .TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float tensors.")
+ .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.")
.TypeConstraint("M", {"tensor(int64)"}, "Constrain input and output types to integer tensors")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 0, 0);
diff --git a/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc
index e64de0e6da16a..55f01bf0d3f1d 100644
--- a/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc
+++ b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc
@@ -11,14 +11,6 @@
namespace onnxruntime {
namespace test {
-namespace {
-enum class TensorType {
- kFloat,
- kFloat16,
- kBFloat16
-};
-} // anonymous namespace
-
static void RunTest(
const std::vector& input_data,
const std::vector& position_ids,
@@ -28,11 +20,10 @@ static void RunTest(
int batch_size,
int sequence_length,
int head_size,
- int rotary_embedding_dim,
int num_heads,
int max_sequence_length,
int64_t interleaved,
- TensorType tensor_type,
+ bool use_float16,
bool disable_cpu,
bool disable_cuda,
bool disable_dml) {
@@ -45,9 +36,7 @@ static void RunTest(
int hidden_size = num_heads * head_size;
std::vector input_dims = {batch_size, sequence_length, hidden_size};
std::vector pos_dims;
- std::vector cache_dims = {max_sequence_length, rotary_embedding_dim > 0
- ? rotary_embedding_dim / 2
- : head_size / 2};
+ std::vector cache_dims = {max_sequence_length, head_size / 2};
assert(hidden_size != 0 && head_size != 0 && num_heads != 0 && max_sequence_length != 0);
assert(max_sequence_length >= sequence_length);
@@ -60,10 +49,7 @@ static void RunTest(
std::string op_type = "RotaryEmbedding";
std::vector> execution_providers;
- int min_cuda_architecture = (tensor_type == TensorType::kBFloat16)
- ? 800
- : (tensor_type == TensorType::kFloat16) ? 530
- : 0;
+ int min_cuda_architecture = use_float16 ? 530 : 0;
bool enable_cuda = HasCudaEnvironment(min_cuda_architecture);
bool enable_dml = (nullptr != DefaultDmlExecutionProvider().get()) && !disable_dml;
@@ -73,7 +59,7 @@ static void RunTest(
if (enable_dml && !disable_dml) {
execution_providers.push_back(DefaultDmlExecutionProvider());
}
- if (tensor_type == TensorType::kFloat && !disable_cpu) {
+ if (!use_float16 && !disable_cpu) {
execution_providers.push_back(DefaultCpuExecutionProvider());
}
if (execution_providers.size() == 0) {
@@ -84,36 +70,20 @@ static void RunTest(
OpTester test(op_type.c_str(), 1, onnxruntime::kMSDomain);
test.AddAttribute("interleaved", interleaved);
- if (rotary_embedding_dim > 0) {
- test.AddAttribute("rotary_embedding_dim", rotary_embedding_dim);
- test.AddAttribute("num_heads", num_heads);
- }
-
- if (tensor_type == TensorType::kFloat) {
+ if (!use_float16) {
test.AddInput("input", input_dims, input_data);
test.AddInput("position_ids", pos_dims, position_ids);
test.AddInput("cos_cache", cache_dims, cos_cache);
test.AddInput("sin_cache", cache_dims, sin_cache);
test.AddOutput("output", input_dims, output_data);
- } else if (tensor_type == TensorType::kFloat16) {
+ } else {
test.AddInput("input", input_dims, ToFloat16(input_data));
test.AddInput("position_ids", pos_dims, position_ids);
test.AddInput("cos_cache", cache_dims, ToFloat16(cos_cache));
test.AddInput("sin_cache", cache_dims, ToFloat16(sin_cache));
test.AddOutput("output", input_dims, ToFloat16(output_data));
- } else {
- test.AddInput("input", input_dims, FloatsToBFloat16s(input_data));
- test.AddInput("position_ids", pos_dims, position_ids);
- test.AddInput("cos_cache", cache_dims, FloatsToBFloat16s(cos_cache));
- test.AddInput("sin_cache", cache_dims, FloatsToBFloat16s(sin_cache));
- test.AddOutput("output", input_dims, FloatsToBFloat16s(output_data));
- }
- if (tensor_type == TensorType::kBFloat16) {
- test.SetOutputAbsErr("output", 0.03f);
- } else {
- test.SetOutputAbsErr("output", 0.002f);
}
-
+ test.SetOutputAbsErr("output", 0.002f);
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
@@ -125,12 +95,10 @@ static void RunTests(const std::vector& input_data,
int batch_size,
int sequence_length,
int head_size = 0,
- int rotary_embedding_dim = 0,
int num_heads = 0,
int max_sequence_length = 0,
int64_t interleaved = 0,
- bool use_float16 = true,
- bool disable_dml = false) {
+ bool use_float16 = true) {
// FP32 test for CPU
RunTest(input_data,
position_ids,
@@ -140,11 +108,10 @@ static void RunTests(const std::vector& input_data,
batch_size,
sequence_length,
head_size,
- rotary_embedding_dim,
num_heads,
max_sequence_length,
interleaved,
- TensorType::kFloat,
+ false, /* use_fp16 */
false, /* disable_cpu */
true, /* disable_cuda */
true /* disable_dml */);
@@ -158,14 +125,13 @@ static void RunTests(const std::vector& input_data,
batch_size,
sequence_length,
head_size,
- rotary_embedding_dim,
num_heads,
max_sequence_length,
interleaved,
- TensorType::kFloat,
+ false, /* use_fp16 */
false, /* disable_cpu */
false, /* disable_cuda */
- disable_dml || false /* disable_dml */);
+ false /* disable_dml */);
// FP16 test for CUDA and DML
if (use_float16) {
@@ -177,31 +143,13 @@ static void RunTests(const std::vector& input_data,
batch_size,
sequence_length,
head_size,
- rotary_embedding_dim,
num_heads,
max_sequence_length,
interleaved,
- TensorType::kFloat16,
+ true, /* use_fp16 */
true, /* disable_cpu */
false, /* disable_cuda*/
- disable_dml || false /* disable_dml */);
-
- // RunTest(input_data,
- // position_ids,
- // cos_cache,
- // sin_cache,
- // output_data,
- // batch_size,
- // sequence_length,
- // head_size,
- // rotary_embedding_dim,
- // num_heads,
- // max_sequence_length,
- // interleaved,
- // TensorType::kBFloat16,
- // true, /* disable_cpu */
- // false, /* disable_cuda*/
- // false /* disable_dml */);
+ false /* disable_dml */);
}
}
@@ -211,7 +159,6 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_Interleaved_SmallData_LlamaMSFT) {
int sequence_length = 3;
int num_heads = 2;
int head_size = 4;
- int rotary_embedding_dim = 0;
int max_sequence_length = 8;
int64_t interleaved = 1; // true
@@ -243,7 +190,6 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_Interleaved_SmallData_LlamaMSFT) {
batch_size,
sequence_length,
head_size,
- rotary_embedding_dim,
num_heads,
max_sequence_length,
interleaved);
@@ -255,7 +201,6 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_Interleaved_LargeData_LlamaMSFT) {
int sequence_length = 8;
int num_heads = 4;
int head_size = 6;
- int rotary_embedding_dim = 0;
int max_sequence_length = 16;
int64_t interleaved = 1; // true
@@ -443,7 +388,6 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_Interleaved_LargeData_LlamaMSFT) {
batch_size,
sequence_length,
head_size,
- rotary_embedding_dim,
num_heads,
max_sequence_length,
interleaved);
@@ -455,7 +399,6 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_NotInterleaved_LargeData_LlamaMSFT) {
int sequence_length = 8;
int num_heads = 4;
int head_size = 6;
- int rotary_embedding_dim = 0;
int max_sequence_length = 16;
int64_t interleaved = 0; // false
@@ -643,7 +586,6 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_NotInterleaved_LargeData_LlamaMSFT) {
batch_size,
sequence_length,
head_size,
- rotary_embedding_dim,
num_heads,
max_sequence_length,
interleaved);
@@ -655,7 +597,6 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_NotInterleaved_SmallData_LlamaMSFT) {
int sequence_length = 2;
int num_heads = 3;
int head_size = 6;
- int rotary_embedding_dim = 0;
int max_sequence_length = 4;
int64_t interleaved = 0; // false
@@ -691,52 +632,10 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_NotInterleaved_SmallData_LlamaMSFT) {
batch_size,
sequence_length,
head_size,
- rotary_embedding_dim,
num_heads,
max_sequence_length,
interleaved);
}
-TEST(RotaryEmbeddingTest, RotaryEmbedding_CustomRotaryDim_SmallData_Phi) {
- int batch_size = 1;
- int sequence_length = 2;
- int num_heads = 1;
- int head_size = 6;
- int rotary_embedding_dim = 4;
- int max_sequence_length = 2;
- int64_t interleaved = 0; // false
-
- std::vector input_data = {
- -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, 1.1676f, 1.0076f, -0.7529f,
- -0.2250f, -0.4327f, -1.5071f, -0.4586f};
-
- std::vector position_ids = {0, 1};
-
- std::vector cos_cache = {
- 1.0000f, 1.0000f, 1.0000f, 0.5403f};
-
- std::vector sin_cache = {
- 0.0000f, 0.0000f, 0.0000f, 0.8415f};
-
- std::vector output_data = {
- -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, 1.1676f, 1.0076f, -0.0427f,
- -0.2250f, -0.8673f, -1.5071f, -0.4586f};
-
- RunTests(input_data,
- position_ids,
- cos_cache,
- sin_cache,
- output_data,
- batch_size,
- sequence_length,
- head_size,
- rotary_embedding_dim,
- num_heads,
- max_sequence_length,
- interleaved,
- true, /*use_fp16*/
- true /*disable_dml*/);
-}
-
} // namespace test
} // namespace onnxruntime