diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 45c0e6f822ce9..22e82443167f6 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -3031,6 +3031,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Number of attention heads
scale : float
Custom scale will be used if specified. Default value is 1/sqrt(head_size)
+unidirectional : int
+Whether every token can only attend to previous tokens. Default value is 0.
#### Inputs (1 - 8)
@@ -5021,6 +5023,10 @@ This version of the operator has been available since version 1 of the 'com.micr
- interleaved : int
- Rotate using interleaved pattern. Default value is 0 (False).
+- num_heads : int
+- Number of attention heads. Default value is 0. Must be provided when rotary_embedding_dim is specified.
+- rotary_embedding_dim : int
+- Rotary embedding dimension. Default value is 0.
- scale : float
- Custom scale will be used if specified. Default value is 1.0
@@ -5033,9 +5039,9 @@ This version of the operator has been available since version 1 of the 'com.micr
position_ids : M
1D tensor with shape (1) or 2D tensor with shape (batch_size, sequence_length)
cos_cache : T
-2D tensor with shape (max_sequence_length, head_size / 2).
+2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)
sin_cache : T
-2D tensor with shape (max_sequence_length, head_size / 2).
+2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)
#### Outputs
@@ -5048,7 +5054,7 @@ This version of the operator has been available since version 1 of the 'com.micr
#### Type Constraints
-- T : tensor(float), tensor(float16)
+- T : tensor(float), tensor(float16), tensor(bfloat16)
- Constrain input and output types to float tensors.
- M : tensor(int64)
- Constrain input and output types to integer tensors
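For context on the cache shapes above: the cos/sin caches are precomputed outside the op. Below is a minimal sketch, assuming the common RoPE base of 10000 and a hypothetical helper name `BuildRotaryCaches`, that produces the `(max_sequence_length, rotary_embedding_dim / 2)` layout the updated schema accepts.

```cpp
#include <cmath>
#include <vector>

// Sketch only: builds cos/sin caches with shape (max_sequence_length, rotary_dim / 2),
// the layout the updated RotaryEmbedding schema accepts. The base of 10000.0f is the
// common RoPE default and is an assumption here, not something the op prescribes.
void BuildRotaryCaches(int max_sequence_length, int rotary_dim,
                       std::vector<float>& cos_cache, std::vector<float>& sin_cache) {
  const int half_dim = rotary_dim / 2;
  cos_cache.assign(static_cast<size_t>(max_sequence_length) * half_dim, 0.0f);
  sin_cache.assign(static_cast<size_t>(max_sequence_length) * half_dim, 0.0f);
  for (int pos = 0; pos < max_sequence_length; ++pos) {
    for (int i = 0; i < half_dim; ++i) {
      const float inv_freq = 1.0f / std::pow(10000.0f, (2.0f * i) / rotary_dim);
      cos_cache[static_cast<size_t>(pos) * half_dim + i] = std::cos(pos * inv_freq);
      sin_cache[static_cast<size_t>(pos) * half_dim + i] = std::sin(pos * inv_freq);
    }
  }
}
```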
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 394bd7ad2abae..9ecc58bee0725 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -868,7 +868,7 @@ Do not modify directly.*
|RemovePadding|*in* input:**T**<br> *in* sequence_token_count:**M**<br> *out* output:**T**<br> *out* token_offset:**M**<br> *out* cumulated_seq_len:**M**<br> *out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)|
|RestorePadding|*in* input:**T**<br> *in* token_offset:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|Rfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
-|RotaryEmbedding|*in* input:**T**<br> *in* position_ids:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**|1+|**M** = tensor(int64)<br> **T** = tensor(float), tensor(float16)|
+|RotaryEmbedding|*in* input:**T**<br> *in* position_ids:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**|1+|**M** = tensor(int64)<br> **T** = tensor(bfloat16), tensor(float), tensor(float16)|
|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *in* seed:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipGroupNorm|*in* X:**T**<br> *in* gamma:**M**<br> *in* beta:**M**<br> *in* skip:**T**<br> *in* bias:**T**<br> *out* Y:**T**<br> *out* S:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc
index 4711ccf487cc8..768676259aa14 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc
@@ -211,6 +211,12 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
relative_position_bias,
&parameters));
+ if (parameters.do_rotary) {
+ ORT_NOT_IMPLEMENTED(
+ "Rotary embedding is not supported in Attention CPU kernel. \
+ Please fuse the model with MHA + RotaryEmbedding.");
+ }
+
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
const int input_hidden_size = parameters.input_hidden_size;
diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
index 694c40bf3eda6..eb25d0fd7cc1e 100644
--- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.cc
@@ -40,6 +40,7 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) : OpKernel(i
num_heads_ = static_cast<int>(num_heads);
mask_filter_value_ = info.GetAttrOrDefault<float>("mask_filter_value", -10000.0f);
+ is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
}
// Reshape Q/K/V from BxSxD to BxSxNxH
@@ -283,8 +284,9 @@ Status MultiHeadAttention<T>::Compute(OpKernelContext* context) const {
nullptr,
&parameters,
num_heads_,
- scale,
mask_filter_value_,
+ scale,
+ is_unidirectional_,
past_present_share_buffer,
false));
diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h
index 4c86b777e9842..fb7da78a5c0a5 100644
--- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h
+++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention.h
@@ -18,6 +18,7 @@ class MultiHeadAttention final : public OpKernel, public AttentionCPUBase {
protected:
int num_heads_; // number of attention heads
float mask_filter_value_;
+ bool is_unidirectional_;
};
} // namespace contrib
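The new `unidirectional` attribute flows into `parameters.is_unidirectional` (see the helper change below) and selects causal masking. As a rough illustration of the semantics only, not the kernel's actual code path, a causal mask keeps each query position from attending to later key positions:

```cpp
#include <vector>

// Illustration only: what is_unidirectional == true means for an S x S score matrix.
// Each query position q may attend only to key positions k <= q; later positions are
// pushed toward -inf (here, mask_filter_value) before softmax.
void ApplyCausalMask(std::vector<float>& scores, int sequence_length, float mask_filter_value) {
  for (int q = 0; q < sequence_length; ++q) {
    for (int k = q + 1; k < sequence_length; ++k) {
      scores[static_cast<size_t>(q) * sequence_length + k] = mask_filter_value;  // e.g. -10000.0f
    }
  }
}
```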
diff --git a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
index 00e82c9844b3d..c91f5b601b4e9 100644
--- a/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
+++ b/onnxruntime/contrib_ops/cpu/bert/multihead_attention_helper.h
@@ -25,6 +25,7 @@ Status CheckInputs(const T* query,
int num_heads,
float mask_filter_value,
float scale,
+ bool is_unidirectional,
bool past_present_share_buffer,
bool dmmha_packing) {
// key_padding_mask (K/V) : (B) or (2*B + 1) or (B, L) or None
@@ -315,7 +316,7 @@ Status CheckInputs(const T* query,
output_parameters->head_size = hidden_size / num_heads;
output_parameters->v_head_size = v_hidden_size / num_heads;
output_parameters->num_heads = num_heads;
- output_parameters->is_unidirectional = false;
+ output_parameters->is_unidirectional = is_unidirectional;
output_parameters->past_present_share_buffer = past_present_share_buffer;
output_parameters->mask_filter_value = mask_filter_value;
output_parameters->mask_type = mask_type;
@@ -342,6 +343,7 @@ Status CheckInputs(const T* query,
int num_heads,
float mask_filter_value,
float scale,
+ bool is_unidirectional,
bool past_present_share_buffer,
bool dmmha_packing,
int max_threads_per_block) {
@@ -350,8 +352,8 @@ Status CheckInputs(const T* query,
}
return CheckInputs(query, key, value, bias, key_padding_mask, relative_position_bias, past_key, past_value,
- past_seq_len, parameters, num_heads, mask_filter_value, scale, past_present_share_buffer,
- dmmha_packing);
+ past_seq_len, parameters, num_heads, mask_filter_value, scale, is_unidirectional,
+ past_present_share_buffer, dmmha_packing);
}
} // namespace multihead_attention_helper
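Since `is_unidirectional` is inserted between `scale` and `past_present_share_buffer`, and `mask_filter_value`/`scale` are adjacent same-typed floats, call sites must be updated positionally (which is what the reorder in multihead_attention.cc above does). A small self-contained sketch, with hypothetical names, of why such swaps compile silently and therefore deserve attention in review:

```cpp
#include <cassert>

// Hypothetical stand-in for the helper's parameter plumbing. Two adjacent float
// parameters (mask_filter_value, scale) compile fine even if a caller swaps them,
// which is why the positional reorder in multihead_attention.cc above matters.
struct ParamsLike {
  float mask_filter_value;
  float scale;
  bool is_unidirectional;
};

void CheckInputsLike(ParamsLike& out, float mask_filter_value, float scale, bool is_unidirectional) {
  out.mask_filter_value = mask_filter_value;
  out.scale = scale;
  out.is_unidirectional = is_unidirectional;
}

int main() {
  ParamsLike p{};
  // Correct order per the updated signature: mask_filter_value first, then scale.
  CheckInputsLike(p, /*mask_filter_value*/ -10000.0f, /*scale*/ 0.125f, /*is_unidirectional*/ true);
  assert(p.scale == 0.125f && p.is_unidirectional);
  return 0;
}
```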
diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
index 47f462d75fcc4..aa8b5b5f608fa 100644
--- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.cc
@@ -27,7 +27,13 @@ ONNX_OPERATOR_TYPED_KERNEL_EX(
template <typename T>
RotaryEmbedding<T>::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) {
scale = info.GetAttrOrDefault<float>("scale", 1.0);
+ rotary_embedding_dim = static_cast<int>(info.GetAttrOrDefault<int64_t>("rotary_embedding_dim", 0));
+ num_heads = static_cast<int>(info.GetAttrOrDefault<int64_t>("num_heads", 0));
interleaved = (info.GetAttrOrDefault<int64_t>("interleaved", 0) == 1);
+
+ if (rotary_embedding_dim > 0) {
+ ORT_ENFORCE(num_heads > 0, "num_heads must be provided if rotary_embedding_dim is specified");
+ }
}
template <typename T>
@@ -42,6 +48,8 @@ Status RotaryEmbedding<T>::Compute(OpKernelContext* context) const {
position_ids,
cos_cache,
sin_cache,
+ num_heads,
+ rotary_embedding_dim,
&parameters));
Tensor* output = context->Output(0, input->Shape());
@@ -59,61 +67,66 @@ Status RotaryEmbedding<T>::Compute(OpKernelContext* context) const {
const int batch_size = parameters.batch_size;
const int sequence_length = parameters.sequence_length;
- const int num_heads = parameters.num_heads;
+ const int n_heads = parameters.num_heads;
const int head_size = parameters.head_size;
const int position_ids_format = parameters.position_ids_format;
- const int half_head_size = head_size / 2;
+ const int rotary_emb_dim = parameters.rotary_embedding_dim;
+ const int half_rotary_emb_dim = rotary_emb_dim / 2;
+
// Default input tensor shape is [batch, seq_len, hidden_size]
int head_stride = head_size;
- int seq_stride = num_heads * head_stride;
+ int seq_stride = n_heads * head_stride;
int batch_stride = sequence_length * seq_stride;
if (parameters.transposed) {
- // Transposed input tensor shape is [batch, num_heads, seq_len, head_size]
+ // Transposed input tensor shape is [batch, n_heads, seq_len, head_size]
seq_stride = head_size;
head_stride = sequence_length * seq_stride;
- batch_stride = num_heads * head_stride;
+ batch_stride = n_heads * head_stride;
}
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
auto* tp = context->GetOperatorThreadPool();
- const int loop_len = batch_size * sequence_length * num_heads;
- const double cost = static_cast(head_size);
+ const int loop_len = batch_size * sequence_length * n_heads;
+ const double cost = static_cast(rotary_emb_dim);
ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
for (std::ptrdiff_t ptr = begin; ptr != end; ++ptr) {
- const int b = static_cast<int>((ptr / num_heads) / sequence_length);
- const int s = static_cast<int>((ptr / num_heads) % sequence_length);
- const int n = static_cast<int>(ptr % num_heads);
+ const int b = static_cast<int>((ptr / n_heads) / sequence_length);
+ const int s = static_cast<int>((ptr / n_heads) % sequence_length);
+ const int n = static_cast<int>(ptr % n_heads);
const int block_offset = b * batch_stride + s * seq_stride + n * head_stride;
const T* input_data = input_src + block_offset;
T* output_data = output_dest + block_offset;
- // Cache is (M, H/2)
+ // Cache is (M, H/2) or (M, rotary_embedding_dim/2)
const int position_id = (position_ids_format == 0)
? static_cast<int>(pos_ids_data[0]) + s
: static_cast<int>(pos_ids_data[b * sequence_length + s]);
- const int cache_offset = position_id * half_head_size;
+ const int cache_offset = position_id * half_rotary_emb_dim;
const T* cos_data = cos_cache_data + cache_offset;
const T* sin_data = sin_cache_data + cache_offset;
int cache_idx = 0;
T sign = 0;
int j = 0;
- for (int i = 0; i < head_size; i++) {
+ for (int i = 0; i < rotary_emb_dim; i++) {
if (interleaved) {
- cache_idx = (i / 2) % half_head_size;
+ cache_idx = (i / 2) % half_rotary_emb_dim;
sign = (i % 2 == 0) ? static_cast<T>(-1) : static_cast<T>(1);
j = (i % 2 == 0) ? i + 1 : i - 1; // i - sign
} else {
- cache_idx = i % half_head_size;
- sign = (i < half_head_size) ? static_cast<T>(-1) : static_cast<T>(1);
- j = (i + half_head_size) % head_size;
+ cache_idx = i % half_rotary_emb_dim;
+ sign = (i < half_rotary_emb_dim) ? static_cast<T>(-1) : static_cast<T>(1);
+ j = (i + half_rotary_emb_dim) % rotary_emb_dim;
}
output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx];
}
+ for (int i = rotary_emb_dim; i < head_size; i++) {
+ output_data[i] = input_data[i];
+ }
}
});
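The rewritten loop rotates only the first `rotary_embedding_dim` elements of each head and copies the remainder through. A standalone sketch of the same non-interleaved math for a single head (hypothetical helper, not the kernel itself):

```cpp
#include <vector>

// Reference sketch of the partial, non-interleaved rotation applied per head:
// elements [0, rotary_dim) are rotated in half-split pairs (i, (i + rotary_dim / 2) % rotary_dim),
// elements [rotary_dim, head_size) pass through untouched.
std::vector<float> RotateHead(const std::vector<float>& x,        // head_size values
                              const std::vector<float>& cos_row,  // rotary_dim / 2 values
                              const std::vector<float>& sin_row,  // rotary_dim / 2 values
                              int rotary_dim) {
  const int half = rotary_dim / 2;
  std::vector<float> out(x);  // tail [rotary_dim, head_size) stays equal to the input
  for (int i = 0; i < rotary_dim; ++i) {
    const int cache_idx = i % half;
    const float sign = (i < half) ? -1.0f : 1.0f;
    const int j = (i + half) % rotary_dim;
    out[i] = x[i] * cos_row[cache_idx] + sign * x[j] * sin_row[cache_idx];
  }
  return out;
}
```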
diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h
index be834a66cdc69..4e32424a22b6c 100644
--- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h
+++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding.h
@@ -16,6 +16,8 @@ class RotaryEmbedding final : public OpKernel {
protected:
float scale;
+ int num_heads;
+ int rotary_embedding_dim;
bool interleaved;
};
diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h
index 7b2e8289f7b06..dcbb36d1c4a3c 100644
--- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h
+++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h
@@ -11,14 +11,15 @@ namespace rotary_embedding_helper {
// Parameters deduced from node attributes and inputs/outputs.
struct RotaryParameters {
- int batch_size; // Batch size used by input
- int sequence_length; // Sequence length used by input
- int hidden_size; // Hidden size used by input
- int head_size; // Head size used by cos/sin cache * 2
- int num_heads; // num_heads = hidden_size / head_size
- int max_sequence_length; // Sequence length used by cos/sin cache
- int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length)
- bool transposed; // Whether the input tensor has been transposed into (batch, num_heads, seq_len, hidden)
+ int batch_size; // Batch size used by input
+ int sequence_length; // Sequence length used by input
+ int hidden_size; // Hidden size used by input
+ int head_size; // Head size
+ int rotary_embedding_dim; // Rotary embedding dimension.
+ int num_heads; // num_heads = hidden_size / head_size
+ int max_sequence_length; // Sequence length used by cos/sin cache
+ int position_ids_format; // Format of position ids - 0 is (1), 1 is (batch_size, sequence_length)
+ bool transposed; // Whether the input tensor has been transposed into (batch, num_heads, seq_len, hidden)
};
template <typename T>
@@ -26,11 +27,13 @@ Status CheckInputs(const T* input,
const T* position_ids,
const T* cos_cache,
const T* sin_cache,
+ int num_heads,
+ int rotary_embedding_dim,
void* parameters) {
// input : (batch_size, sequence_length, hidden_size)
// position ids : (1) or (batch_size, sequence_length)
- // cos cache : (max_sequence_length, head_size / 2)
- // sin cache : (max_sequence_length, head_size / 2)
+ // cos cache : (max_sequence_length, rotary_embedding_dim / 2)
+ // sin cache : (max_sequence_length, rotary_embedding_dim / 2)
// Check input
const auto& input_dims = input->Shape().GetDims();
@@ -60,6 +63,12 @@ Status CheckInputs(const T* input,
"the same shape");
}
+ // Check num_heads and rotary_embedding_dim
+ if (rotary_embedding_dim > 0 && num_heads == 0) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "num_heads must be provided if rotary_embedding_dim is ",
+ "specified");
+ }
+
// Get attributes from inputs
int batch_size = static_cast<int>(input_dims[0]);
int sequence_length = static_cast<int>(input_dims[1]);
@@ -73,8 +82,13 @@ Status CheckInputs(const T* input,
transposed = true;
}
int max_sequence_length = static_cast<int>(cos_cache_dims[0]);
- int head_size = static_cast<int>(cos_cache_dims[1]) * 2;
- int num_heads = hidden_size / head_size;
+ int head_size = rotary_embedding_dim == 0 ? static_cast<int>(cos_cache_dims[1]) * 2
+ : static_cast<int>(hidden_size / num_heads);
+ if (rotary_embedding_dim > 0 && rotary_embedding_dim > head_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "rotary_embedding_dim must be less than or equal to ",
+ "head_size");
+ }
+
int position_ids_format = -1;
// Check position_ids input shapes
@@ -91,23 +105,15 @@ Status CheckInputs(const T* input,
} else {
position_ids_format = 0;
}
+
// Check cos_cache input shapes
if (max_sequence_length != static_cast<int>(cos_cache_dims[0])) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 0 should be same as ",
"max_sequence_length, got ", cos_cache_dims[0]);
}
- if ((head_size / 2) != static_cast<int>(cos_cache_dims[1])) {
+ if ((head_size / 2) != static_cast<int>(cos_cache_dims[1]) && (rotary_embedding_dim > 0 && (rotary_embedding_dim / 2) != static_cast<int>(cos_cache_dims[1]))) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'cos_cache' dimension 1 should be same as ",
- "head_size / 2, got ", cos_cache_dims[1]);
- }
- // Check sin_cache input shapes
- if (max_sequence_length != static_cast<int>(sin_cache_dims[0])) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' dimension 0 should be same as ",
- "max_sequence_length, got ", sin_cache_dims[0]);
- }
- if ((head_size / 2) != static_cast<int>(sin_cache_dims[1])) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'sin_cache' dimension 1 should be same as ",
- "head_size / 2, got ", sin_cache_dims[1]);
+ "head_size / 2 or rotary_embedding_dim / 2, got ", cos_cache_dims[1]);
}
// Set rotary parameters
@@ -117,10 +123,11 @@ Status CheckInputs(const T* input,
output_parameters->sequence_length = sequence_length;
output_parameters->hidden_size = hidden_size;
output_parameters->head_size = head_size;
- output_parameters->num_heads = num_heads;
+ output_parameters->num_heads = num_heads > 0 ? num_heads : static_cast<int>(hidden_size / head_size);
output_parameters->max_sequence_length = max_sequence_length;
output_parameters->position_ids_format = position_ids_format;
output_parameters->transposed = transposed;
+ output_parameters->rotary_embedding_dim = rotary_embedding_dim > 0 ? rotary_embedding_dim : head_size;
}
return Status::OK();
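In short: with `rotary_embedding_dim == 0` the head size is inferred from the cache width and `num_heads` from `hidden_size / head_size`; otherwise `num_heads` must be supplied, `head_size` becomes `hidden_size / num_heads`, and the rotary dimension may not exceed it. A condensed sketch of that decision (hypothetical free function; exceptions stand in for `ORT_MAKE_STATUS`):

```cpp
#include <stdexcept>

struct DeducedRotaryParams {
  int head_size;
  int num_heads;
  int rotary_embedding_dim;
};

// Condensed restatement of the deduction in CheckInputs above. cache_width is
// dimension 1 of cos_cache / sin_cache.
DeducedRotaryParams DeduceRotaryParams(int hidden_size, int cache_width,
                                       int num_heads_attr, int rotary_embedding_dim_attr) {
  if (rotary_embedding_dim_attr > 0 && num_heads_attr == 0) {
    throw std::invalid_argument("num_heads must be provided if rotary_embedding_dim is specified");
  }
  DeducedRotaryParams d{};
  d.head_size = (rotary_embedding_dim_attr == 0) ? cache_width * 2
                                                 : hidden_size / num_heads_attr;
  if (rotary_embedding_dim_attr > d.head_size) {
    throw std::invalid_argument("rotary_embedding_dim must be <= head_size");
  }
  d.num_heads = (num_heads_attr > 0) ? num_heads_attr : hidden_size / d.head_size;
  d.rotary_embedding_dim = (rotary_embedding_dim_attr > 0) ? rotary_embedding_dim_attr : d.head_size;
  return d;
}
```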
diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
index ebd66d8c6528e..f978f50c6851f 100644
--- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
@@ -44,6 +44,8 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info)
mask_filter_value_ = info.GetAttrOrDefault<float>("mask_filter_value", -10000.0f);
scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
+ is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
+ ORT_ENFORCE(!is_unidirectional_, "Unidirectional MHA is not supported by the CUDA kernel. Consider using Attention or GQA instead.");
disable_fused_self_attention_ = sizeof(T) != 2 ||
ParseEnvironmentVariableWithDefault(attention::kDisableFusedSelfAttention, false);
@@ -105,6 +107,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
num_heads_,
mask_filter_value_,
scale_,
+ is_unidirectional_,
false, // past_present_share_buffer
false, // dmmha_packing
device_prop.maxThreadsPerBlock));
diff --git a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h
index c162f7133cc1c..86a32c92ce003 100644
--- a/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h
+++ b/onnxruntime/contrib_ops/cuda/bert/multihead_attention.h
@@ -25,6 +25,7 @@ class MultiHeadAttention final : public CudaKernel {
int num_heads_; // number of attention heads
float mask_filter_value_;
float scale_;
+ bool is_unidirectional_;
bool disable_fused_self_attention_;
bool enable_trt_flash_attention_;
bool disable_fused_cross_attention_;
diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc
index 2d12e975d88d7..9de7ba3885c3c 100644
--- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc
+++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.cc
@@ -29,10 +29,13 @@ namespace cuda {
REGISTER_KERNEL_TYPED(float)
REGISTER_KERNEL_TYPED(MLFloat16)
+REGISTER_KERNEL_TYPED(BFloat16)
template <typename T>
RotaryEmbedding<T>::RotaryEmbedding(const OpKernelInfo& info) : CudaKernel(info) {
scale = info.GetAttrOrDefault<float>("scale", 1.0);
+ rotary_embedding_dim = static_cast<int>(info.GetAttrOrDefault<int64_t>("rotary_embedding_dim", 0));
+ num_heads = static_cast<int>(info.GetAttrOrDefault<int64_t>("num_heads", 0));
interleaved = (info.GetAttrOrDefault<int64_t>("interleaved", 0) == 1);
}
@@ -48,6 +51,8 @@ Status RotaryEmbedding<T>::ComputeInternal(OpKernelContext* context) const {
position_ids,
cos_cache,
sin_cache,
+ num_heads,
+ rotary_embedding_dim,
&parameters));
Tensor* output = context->Output(0, input->Shape());
@@ -71,6 +76,7 @@ Status RotaryEmbedding<T>::ComputeInternal(OpKernelContext* context) const {
parameters.sequence_length,
parameters.num_heads,
parameters.head_size,
+ parameters.rotary_embedding_dim,
parameters.max_sequence_length,
parameters.position_ids_format,
interleaved,
diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h
index 6dab2ad56749e..d52f61d670444 100644
--- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h
+++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding.h
@@ -19,6 +19,8 @@ class RotaryEmbedding final : public CudaKernel {
protected:
float scale;
+ int num_heads;
+ int rotary_embedding_dim;
bool interleaved;
};
diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu
index e1b83bd8caf54..c6637041f05bd 100644
--- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.cu
@@ -26,6 +26,7 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH
const int sequence_length,
const int num_heads,
const int head_size,
+ const int rotary_embedding_dim,
const int position_ids_format,
const bool interleaved,
const int batch_stride,
@@ -33,24 +34,33 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH
const int head_stride) {
// B = batch size, S = sequence length, N = num heads, H = head size, M = max sequence length
// Use .x in innermost loop to access global memory efficiently
-
+
const int b = blockIdx.z;
const int s = blockIdx.y;
const int n = blockIdx.x;
const int i = threadIdx.x;
+ if (i >= head_size) {
+ return;
+ }
+
const int block_offset = b * batch_stride + s * seq_stride + n * head_stride;
const T* input_data = input + block_offset;
T* output_data = output + block_offset;
+ if (i >= rotary_embedding_dim) {
+ output_data[i] = input_data[i];
+ return;
+ }
+
// Cache is (M, H/2)
- const int half_head_size = head_size / 2;
+ const int half_rotary_embedding_dim = rotary_embedding_dim / 2;
const int position_id = (position_ids_format == 0) ? \
static_cast<int>(position_ids[0]) + s \
: static_cast<int>(position_ids[b * sequence_length + s]);
- const int cache_offset = position_id * half_head_size;
+ const int cache_offset = position_id * half_rotary_embedding_dim;
const T* cos_data = cos_cache + cache_offset;
const T* sin_data = sin_cache + cache_offset;
@@ -58,13 +68,13 @@ __global__ void RotaryEmbeddingBSNH(T* output, // BxSxNxH
T sign = 0;
int j = 0;
if (interleaved) {
- cache_idx = (i / 2) % half_head_size;
+ cache_idx = (i / 2) % half_rotary_embedding_dim;
sign = (i % 2 == 0) ? -1 : 1;
j = (i % 2 == 0) ? i+1 : i-1; // i - sign
} else {
- cache_idx = i % half_head_size;
- sign = (i < half_head_size) ? -1 : 1;
- j = (i + half_head_size) % head_size;
+ cache_idx = i % half_rotary_embedding_dim;
+ sign = (i < half_rotary_embedding_dim) ? -1 : 1;
+ j = (i + half_rotary_embedding_dim) % rotary_embedding_dim;
}
output_data[i] = input_data[i] * cos_data[cache_idx] + sign * input_data[j] * sin_data[cache_idx];
}
@@ -82,20 +92,23 @@ Status LaunchRotaryEmbeddingKernel(
const int sequence_length,
const int num_heads,
const int head_size,
+ const int rotary_embedding_dim,
const int max_sequence_length,
const int position_ids_format,
const bool interleaved,
const int max_threads_per_block,
const bool transposed) {
-
- constexpr int smem_size = 0;
- const dim3 grid(num_heads, sequence_length, batch_size);
- const dim3 block(head_size, 1, 1);
-
// Note: Current implementation assumes head_size <= max_threads_per_block
// because head_size is currently large for LLaMA-2. For smaller head_size
// and num_heads values, we can create a block as `block(num_heads, head_size, 1)`
// instead. This will require kernel changes to support.
+ ORT_ENFORCE(head_size <= max_threads_per_block,
+ "head_size must be <= max_threads_per_block");
+
+ int tpb = (head_size + 31)/32*32;
+
+ const dim3 block(tpb);
+ const dim3 grid(num_heads, sequence_length, batch_size);
// Default input tensor shape is [batch, seq, hidden_size]
int head_stride = head_size;
@@ -109,10 +122,9 @@ Status LaunchRotaryEmbeddingKernel(
}
assert(head_size <= max_threads_per_block);
- RotaryEmbeddingBSNH<<<grid, block, smem_size, stream>>>(
- output, input, cos_cache, sin_cache, position_ids,
- sequence_length, num_heads, head_size, position_ids_format, interleaved,
- batch_stride, seq_stride, head_stride
+ RotaryEmbeddingBSNH<<<grid, block, 0, stream>>>(
+ output, input, cos_cache, sin_cache, position_ids, sequence_length, num_heads, head_size,
+ rotary_embedding_dim, position_ids_format, interleaved, batch_stride, seq_stride, head_stride
);
return CUDA_CALL(cudaGetLastError());
@@ -129,6 +141,7 @@ template Status LaunchRotaryEmbeddingKernel<float>(
const int sequence_length,
const int num_heads,
const int head_size,
+ const int rotary_embedding_dim,
const int max_sequence_length,
const int position_ids_format,
const bool interleaved,
@@ -146,6 +159,25 @@ template Status LaunchRotaryEmbeddingKernel<half>(
const int sequence_length,
const int num_heads,
const int head_size,
+ const int rotary_embedding_dim,
+ const int max_sequence_length,
+ const int position_ids_format,
+ const bool interleaved,
+ const int max_threads_per_block,
+ const bool transposed);
+
+template Status LaunchRotaryEmbeddingKernel<BFloat16>(
+ cudaStream_t stream,
+ BFloat16* output,
+ const BFloat16* input,
+ const int64_t* position_ids,
+ const BFloat16* cos_cache,
+ const BFloat16* sin_cache,
+ const int batch_size,
+ const int sequence_length,
+ const int num_heads,
+ const int head_size,
+ const int rotary_embedding_dim,
const int max_sequence_length,
const int position_ids_format,
const bool interleaved,
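The launch change rounds threads-per-block up to a warp multiple of `head_size` and relies on the early-exit guards added to the kernel. A host-side sketch of that configuration (illustrative only, hypothetical helper name):

```cpp
#include <cassert>

// Illustration of the launch configuration above: threads-per-block is head_size rounded
// up to a multiple of 32 (a warp); threads with i >= head_size return immediately, and
// threads with rotary_embedding_dim <= i < head_size only copy input to output.
struct LaunchDims {
  int block_x;
  int grid_x;  // num_heads
  int grid_y;  // sequence_length
  int grid_z;  // batch_size
};

LaunchDims RotaryLaunchDims(int batch_size, int sequence_length, int num_heads,
                            int head_size, int max_threads_per_block) {
  assert(head_size <= max_threads_per_block);  // mirrors the ORT_ENFORCE above
  LaunchDims d{};
  d.block_x = (head_size + 31) / 32 * 32;
  d.grid_x = num_heads;
  d.grid_y = sequence_length;
  d.grid_z = batch_size;
  return d;
}
```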
diff --git a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h
index ee1ccc43dcbff..36300fe7a660f 100644
--- a/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h
+++ b/onnxruntime/contrib_ops/cuda/bert/rotary_embedding_impl.h
@@ -21,6 +21,7 @@ Status LaunchRotaryEmbeddingKernel(
const int sequence_length,
const int num_heads,
const int head_size,
+ const int rotary_embedding_dim,
const int max_sequence_length,
const int position_ids_format,
const bool interleaved,
diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
index 34b44694a5fcc..fa73950c9c6f5 100644
--- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -98,6 +98,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ParametricSoftplus);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, RotaryEmbedding);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, RotaryEmbedding);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, RotaryEmbedding);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Sampling);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh);
@@ -299,6 +300,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, BFloat16, RotaryEmbedding)>,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
index 0317ffcfb0e31..7f34647f1faef 100644
--- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -927,6 +927,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"Custom scale will be used if specified. Default value is 1/sqrt(head_size)",
AttributeProto::FLOAT,
OPTIONAL_VALUE)
+ .Attr("unidirectional",
+ "Whether every token can only attend to previous tokens. Default value is 0.",
+ AttributeProto::INT,
+ static_cast<int64_t>(0))
.Input(0,
"query",
"Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape (batch_size, kv_sequence_length, num_heads, 3, head_size)",
@@ -1145,6 +1149,14 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"Rotate using interleaved pattern. Default value is 0 (False).",
AttributeProto::INT,
OPTIONAL_VALUE)
+ .Attr("rotary_embedding_dim",
+ "Rotary embedding dimension. Default value is 0.",
+ AttributeProto::INT,
+ OPTIONAL_VALUE)
+ .Attr("num_heads",
+ "Number of attention heads. Default value is 0. Must be provided when rotary_embedding_dim is specified.",
+ AttributeProto::INT,
+ OPTIONAL_VALUE)
.Input(0,
"input",
"3D tensor with shape (batch_size, sequence_length, hidden_size) or 4D with shape (batch_size, num_heads, sequence_length, head_size)",
@@ -1155,17 +1167,17 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"M")
.Input(2,
"cos_cache",
- "2D tensor with shape (max_sequence_length, head_size / 2).",
+ "2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)",
"T")
.Input(3,
"sin_cache",
- "2D tensor with shape (max_sequence_length, head_size / 2).",
+ "2D tensor with shape (max_sequence_length, head_size / 2) or (max_sequence_length, rotary_embedding_dim / 2)",
"T")
.Output(0,
"output",
"tensor with same shape as input.",
"T")
- .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.")
+ .TypeConstraint("T", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain input and output types to float tensors.")
.TypeConstraint("M", {"tensor(int64)"}, "Constrain input and output types to integer tensors")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
propagateElemTypeFromInputToOutput(ctx, 0, 0);
diff --git a/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc
index 55f01bf0d3f1d..e64de0e6da16a 100644
--- a/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc
+++ b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc
@@ -11,6 +11,14 @@
namespace onnxruntime {
namespace test {
+namespace {
+enum class TensorType {
+ kFloat,
+ kFloat16,
+ kBFloat16
+};
+} // anonymous namespace
+
static void RunTest(
const std::vector<float>& input_data,
const std::vector<int64_t>& position_ids,
@@ -20,10 +28,11 @@ static void RunTest(
int batch_size,
int sequence_length,
int head_size,
+ int rotary_embedding_dim,
int num_heads,
int max_sequence_length,
int64_t interleaved,
- bool use_float16,
+ TensorType tensor_type,
bool disable_cpu,
bool disable_cuda,
bool disable_dml) {
@@ -36,7 +45,9 @@ static void RunTest(
int hidden_size = num_heads * head_size;
std::vector<int64_t> input_dims = {batch_size, sequence_length, hidden_size};
std::vector<int64_t> pos_dims;
- std::vector<int64_t> cache_dims = {max_sequence_length, head_size / 2};
+ std::vector<int64_t> cache_dims = {max_sequence_length, rotary_embedding_dim > 0
+ ? rotary_embedding_dim / 2
+ : head_size / 2};
assert(hidden_size != 0 && head_size != 0 && num_heads != 0 && max_sequence_length != 0);
assert(max_sequence_length >= sequence_length);
@@ -49,7 +60,10 @@ static void RunTest(
std::string op_type = "RotaryEmbedding";
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
- int min_cuda_architecture = use_float16 ? 530 : 0;
+ int min_cuda_architecture = (tensor_type == TensorType::kBFloat16)
+ ? 800
+ : (tensor_type == TensorType::kFloat16) ? 530
+ : 0;
bool enable_cuda = HasCudaEnvironment(min_cuda_architecture);
bool enable_dml = (nullptr != DefaultDmlExecutionProvider().get()) && !disable_dml;
@@ -59,7 +73,7 @@ static void RunTest(
if (enable_dml && !disable_dml) {
execution_providers.push_back(DefaultDmlExecutionProvider());
}
- if (!use_float16 && !disable_cpu) {
+ if (tensor_type == TensorType::kFloat && !disable_cpu) {
execution_providers.push_back(DefaultCpuExecutionProvider());
}
if (execution_providers.size() == 0) {
@@ -70,20 +84,36 @@ static void RunTest(
OpTester test(op_type.c_str(), 1, onnxruntime::kMSDomain);
test.AddAttribute<int64_t>("interleaved", interleaved);
- if (!use_float16) {
+ if (rotary_embedding_dim > 0) {
+ test.AddAttribute<int64_t>("rotary_embedding_dim", rotary_embedding_dim);
+ test.AddAttribute<int64_t>("num_heads", num_heads);
+ }
+
+ if (tensor_type == TensorType::kFloat) {
test.AddInput<float>("input", input_dims, input_data);
test.AddInput<int64_t>("position_ids", pos_dims, position_ids);
test.AddInput<float>("cos_cache", cache_dims, cos_cache);
test.AddInput<float>("sin_cache", cache_dims, sin_cache);
test.AddOutput<float>("output", input_dims, output_data);
- } else {
+ } else if (tensor_type == TensorType::kFloat16) {
test.AddInput<MLFloat16>("input", input_dims, ToFloat16(input_data));
test.AddInput<int64_t>("position_ids", pos_dims, position_ids);
test.AddInput<MLFloat16>("cos_cache", cache_dims, ToFloat16(cos_cache));
test.AddInput<MLFloat16>("sin_cache", cache_dims, ToFloat16(sin_cache));
test.AddOutput<MLFloat16>("output", input_dims, ToFloat16(output_data));
+ } else {
+ test.AddInput<BFloat16>("input", input_dims, FloatsToBFloat16s(input_data));
+ test.AddInput<int64_t>("position_ids", pos_dims, position_ids);
+ test.AddInput<BFloat16>("cos_cache", cache_dims, FloatsToBFloat16s(cos_cache));
+ test.AddInput<BFloat16>("sin_cache", cache_dims, FloatsToBFloat16s(sin_cache));
+ test.AddOutput<BFloat16>("output", input_dims, FloatsToBFloat16s(output_data));
+ }
+ if (tensor_type == TensorType::kBFloat16) {
+ test.SetOutputAbsErr("output", 0.03f);
+ } else {
+ test.SetOutputAbsErr("output", 0.002f);
}
- test.SetOutputAbsErr("output", 0.002f);
+
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
@@ -95,10 +125,12 @@ static void RunTests(const std::vector<float>& input_data,
int batch_size,
int sequence_length,
int head_size = 0,
+ int rotary_embedding_dim = 0,
int num_heads = 0,
int max_sequence_length = 0,
int64_t interleaved = 0,
- bool use_float16 = true) {
+ bool use_float16 = true,
+ bool disable_dml = false) {
// FP32 test for CPU
RunTest(input_data,
position_ids,
@@ -108,10 +140,11 @@ static void RunTests(const std::vector<float>& input_data,
batch_size,
sequence_length,
head_size,
+ rotary_embedding_dim,
num_heads,
max_sequence_length,
interleaved,
- false, /* use_fp16 */
+ TensorType::kFloat,
false, /* disable_cpu */
true, /* disable_cuda */
true /* disable_dml */);
@@ -125,13 +158,14 @@ static void RunTests(const std::vector<float>& input_data,
batch_size,
sequence_length,
head_size,
+ rotary_embedding_dim,
num_heads,
max_sequence_length,
interleaved,
- false, /* use_fp16 */
+ TensorType::kFloat,
false, /* disable_cpu */
false, /* disable_cuda */
- false /* disable_dml */);
+ disable_dml || false /* disable_dml */);
// FP16 test for CUDA and DML
if (use_float16) {
@@ -143,13 +177,31 @@ static void RunTests(const std::vector<float>& input_data,
batch_size,
sequence_length,
head_size,
+ rotary_embedding_dim,
num_heads,
max_sequence_length,
interleaved,
- true, /* use_fp16 */
+ TensorType::kFloat16,
true, /* disable_cpu */
false, /* disable_cuda*/
- false /* disable_dml */);
+ disable_dml || false /* disable_dml */);
+
+ // RunTest(input_data,
+ // position_ids,
+ // cos_cache,
+ // sin_cache,
+ // output_data,
+ // batch_size,
+ // sequence_length,
+ // head_size,
+ // rotary_embedding_dim,
+ // num_heads,
+ // max_sequence_length,
+ // interleaved,
+ // TensorType::kBFloat16,
+ // true, /* disable_cpu */
+ // false, /* disable_cuda*/
+ // false /* disable_dml */);
}
}
@@ -159,6 +211,7 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_Interleaved_SmallData_LlamaMSFT) {
int sequence_length = 3;
int num_heads = 2;
int head_size = 4;
+ int rotary_embedding_dim = 0;
int max_sequence_length = 8;
int64_t interleaved = 1; // true
@@ -190,6 +243,7 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_Interleaved_SmallData_LlamaMSFT) {
batch_size,
sequence_length,
head_size,
+ rotary_embedding_dim,
num_heads,
max_sequence_length,
interleaved);
@@ -201,6 +255,7 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_Interleaved_LargeData_LlamaMSFT) {
int sequence_length = 8;
int num_heads = 4;
int head_size = 6;
+ int rotary_embedding_dim = 0;
int max_sequence_length = 16;
int64_t interleaved = 1; // true
@@ -388,6 +443,7 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_Interleaved_LargeData_LlamaMSFT) {
batch_size,
sequence_length,
head_size,
+ rotary_embedding_dim,
num_heads,
max_sequence_length,
interleaved);
@@ -399,6 +455,7 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_NotInterleaved_LargeData_LlamaMSFT) {
int sequence_length = 8;
int num_heads = 4;
int head_size = 6;
+ int rotary_embedding_dim = 0;
int max_sequence_length = 16;
int64_t interleaved = 0; // false
@@ -586,6 +643,7 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_NotInterleaved_LargeData_LlamaMSFT) {
batch_size,
sequence_length,
head_size,
+ rotary_embedding_dim,
num_heads,
max_sequence_length,
interleaved);
@@ -597,6 +655,7 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_NotInterleaved_SmallData_LlamaMSFT) {
int sequence_length = 2;
int num_heads = 3;
int head_size = 6;
+ int rotary_embedding_dim = 0;
int max_sequence_length = 4;
int64_t interleaved = 0; // false
@@ -632,10 +691,52 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_NotInterleaved_SmallData_LlamaMSFT) {
batch_size,
sequence_length,
head_size,
+ rotary_embedding_dim,
num_heads,
max_sequence_length,
interleaved);
}
+TEST(RotaryEmbeddingTest, RotaryEmbedding_CustomRotaryDim_SmallData_Phi) {
+ int batch_size = 1;
+ int sequence_length = 2;
+ int num_heads = 1;
+ int head_size = 6;
+ int rotary_embedding_dim = 4;
+ int max_sequence_length = 2;
+ int64_t interleaved = 0; // false
+
+ std::vector<float> input_data = {
+ -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, 1.1676f, 1.0076f, -0.7529f,
+ -0.2250f, -0.4327f, -1.5071f, -0.4586f};
+
+ std::vector<int64_t> position_ids = {0, 1};
+
+ std::vector<float> cos_cache = {
+ 1.0000f, 1.0000f, 1.0000f, 0.5403f};
+
+ std::vector<float> sin_cache = {
+ 0.0000f, 0.0000f, 0.0000f, 0.8415f};
+
+ std::vector<float> output_data = {
+ -1.0408f, 0.9166f, -1.3042f, -1.1097f, -1.2188f, 1.1676f, 1.0076f, -0.0427f,
+ -0.2250f, -0.8673f, -1.5071f, -0.4586f};
+
+ RunTests(input_data,
+ position_ids,
+ cos_cache,
+ sin_cache,
+ output_data,
+ batch_size,
+ sequence_length,
+ head_size,
+ rotary_embedding_dim,
+ num_heads,
+ max_sequence_length,
+ interleaved,
+ true, /*use_fp16*/
+ true /*disable_dml*/);
+}
+
} // namespace test
} // namespace onnxruntime
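The expected values in the new `RotaryEmbedding_CustomRotaryDim_SmallData_Phi` test can be reproduced by hand: with `rotary_embedding_dim = 4`, only the first four of the six head elements are rotated and the last two pass through. A small standalone check for the second token (position 1), reusing the non-interleaved formula from the CPU kernel (values match the expected output up to rounding):

```cpp
#include <cstdio>

// Reproduces the second-token expected values of the CustomRotaryDim test above.
// x is that token's single head (head_size = 6); cos/sin rows are for position 1 with
// rotary_dim = 4, so two cache entries per row. Elements 4 and 5 pass through.
int main() {
  const float x[6] = {1.0076f, -0.7529f, -0.2250f, -0.4327f, -1.5071f, -0.4586f};
  const float cos_row[2] = {1.0000f, 0.5403f};
  const float sin_row[2] = {0.0000f, 0.8415f};
  const int rotary_dim = 4;
  const int half = rotary_dim / 2;
  float out[6] = {x[0], x[1], x[2], x[3], x[4], x[5]};
  for (int i = 0; i < rotary_dim; ++i) {
    const int cache_idx = i % half;
    const float sign = (i < half) ? -1.0f : 1.0f;
    const int j = (i + half) % rotary_dim;
    out[i] = x[i] * cos_row[cache_idx] + sign * x[j] * sin_row[cache_idx];
  }
  // Prints approximately: 1.0076 -0.0427 -0.2250 -0.8673 -1.5071 -0.4586
  for (int i = 0; i < 6; ++i) std::printf("%.4f ", out[i]);
  std::printf("\n");
  return 0;
}
```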