From 1c2dca95d813e3bf7a2b59a70fcedae9c84bed7d Mon Sep 17 00:00:00 2001
From: Ye Wang <52801275+wangyems@users.noreply.github.com>
Date: Wed, 3 Jan 2024 04:38:33 +0000
Subject: [PATCH] pass rotary embedding to attention op (#18846)
### Description

Add a `rotary_embedding_dim` attribute to the `com.microsoft.Attention` contrib op and pass it through `AttentionParameters` to the CUDA `AddBiasTranspose` rotary path, so rotary position embedding can be applied to only the first `rotary_embedding_dim` elements of each head. The value is limited to 32, 64 or 128 and defaults to `head_size` when unset (or set to 0). The GPT-NeoX parity test is extended to cover partial rotary dimensions (32 and 64) and an additional hidden size (960).
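For reference, a minimal sketch (not part of this patch) of how the new attribute might be set when building the contrib `Attention` node with the ONNX `helper` API; it mirrors the parity test, and the tensor names and shapes below are illustrative only:

```python
# Sketch only: a com.microsoft Attention node that applies rotary embedding
# to the first 32 elements of each 64-element head (hidden_size 768, 12 heads).
from onnx import TensorProto, helper

num_heads = 12
hidden_size = 768            # head_size = hidden_size / num_heads = 64
rotary_embedding_dim = 32    # limited to 32, 64 or 128; defaults to head_size if omitted

node = helper.make_node(
    "Attention",
    ["input", "weight", "bias"],
    ["output"],
    "attention_partial_rotary",
    num_heads=num_heads,
    unidirectional=1,
    do_rotary=1,
    rotary_embedding_dim=rotary_embedding_dim,  # attribute added by this change
    domain="com.microsoft",
)

graph = helper.make_graph(
    [node],
    "attention_rotary_example",
    [
        helper.make_tensor_value_info("input", TensorProto.FLOAT, ["batch", "seq", hidden_size]),
        helper.make_tensor_value_info("weight", TensorProto.FLOAT, [hidden_size, 3 * hidden_size]),
        helper.make_tensor_value_info("bias", TensorProto.FLOAT, [3 * hidden_size]),
    ],
    [helper.make_tensor_value_info("output", TensorProto.FLOAT, ["batch", "seq", hidden_size])],
)
model = helper.make_model(
    graph,
    opset_imports=[helper.make_opsetid("", 17), helper.make_opsetid("com.microsoft", 1)],
)
```

As in the parity test, exercising the rotary path currently requires the CUDA execution provider on sm >= 530.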
### Motivation and Context

Models such as GPT-NeoX variants apply rotary embedding to only part of each head (`rotary_ndims < head_size`); exposing the dimension on the Attention op lets them use the fused CUDA attention path.
---
docs/ContribOperators.md | 2 ++
.../contrib_ops/cpu/bert/attention_base.cc | 1 +
.../contrib_ops/cpu/bert/attention_base.h | 2 ++
.../contrib_ops/cpu/bert/attention_common.h | 1 +
.../cuda/bert/add_bias_transpose.cu | 19 +++++-----
.../cuda/bert/add_bias_transpose.h | 2 +-
.../cuda/bert/attention_prepare_qkv.cu | 3 +-
.../core/graph/contrib_ops/bert_defs.cc | 4 +++
.../test_parity_neox_attention.py | 36 +++++++++++--------
9 files changed, 45 insertions(+), 25 deletions(-)
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 131db5d8d9b37..38fceef67de25 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -155,6 +155,8 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Corresponding past and present are same tensor, its size is (2, batch_size, num_heads, max_sequence_length, head_size)</dd>
<dt><tt>qkv_hidden_sizes</tt> : list of ints</dt>
<dd>Hidden dimension of Q, K, V: hidden_size, hidden_size and v_hidden_size</dd>
+<dt><tt>rotary_embedding_dim</tt> : int</dt>
+<dd>Dimension of rotary embedding. Limited to 32, 64 or 128. Default value is head_size</dd>
<dt><tt>scale</tt> : float</dt>
<dd>Custom scale will be used if specified. Default value is 1/sqrt(head_size)</dd>
<dt><tt>unidirectional</tt> : int</dt>
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc
index 5d224bdc2235f..515a967aa2386 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc
@@ -253,6 +253,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
output_parameters->is_unidirectional = is_unidirectional_;
output_parameters->past_present_share_buffer = (past_present_share_buffer_ != 0 && past != nullptr);
output_parameters->do_rotary = do_rotary_;
+ output_parameters->rotary_embedding = rotary_embedding_ == 0 ? (int)(output_parameters->head_size) : rotary_embedding_;
output_parameters->mask_filter_value = mask_filter_value_;
output_parameters->scale = scale_;
output_parameters->mask_type = mask_type;
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_base.h
index 5ee40c4b98664..a6782daa58f1a 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_base.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.h
@@ -38,6 +38,7 @@ class AttentionBase {
is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
do_rotary_ = info.GetAttrOrDefault<int64_t>("do_rotary", 0) == 1;
+ rotary_embedding_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("rotary_embedding_dim", 0));
mask_filter_value_ = info.GetAttrOrDefault<float>("mask_filter_value", -10000.0f);
scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);
@@ -72,6 +73,7 @@ class AttentionBase {
bool require_same_hidden_size_; // whether the implementation supports different hidden sizes of Q/K/V.
bool past_present_share_buffer_; // whether or not the past (if used) and present tensor share the same buffer
bool do_rotary_; // whether or not to use rotary embeddings
+ int rotary_embedding_; // rotary embedding dimension
float mask_filter_value_; // the value to be used for filtered out positions
float scale_; // the scale to be used for softmax
};
diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h
index a7f83469a768d..c9ed23895b60c 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h
+++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -56,6 +56,7 @@ struct AttentionParameters {
int v_head_size; // hidden size per head of V
int num_heads;
int num_splits;
+ int rotary_embedding;
bool is_unidirectional;
bool past_present_share_buffer;
bool do_rotary;
diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu
index 626e4c0b87a3c..1ea2540db486f 100644
--- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu
@@ -640,7 +640,7 @@ void InvokeAddBiasTranspose(
cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
const T* input, const T* biases, T* output, T* qkv_add_bias, const int v_head_size, int total_matrix_count,
- bool do_rotary = false, int past_sequence_length = 0) {
+ bool do_rotary = false, int rotary_embedding = 0, int past_sequence_length = 0) {
assert(num_heads <= max_threads_per_block);
if (do_rotary) {
@@ -650,20 +650,20 @@ void InvokeAddBiasTranspose(
if (format != 1 && format != 2 && format != 3) {
ORT_THROW("format must be 1, 2 or 3 for rotary attention");
}
- if (qk_head_size != 64 && qk_head_size != 128) {
- ORT_THROW("qk_head_size must be 64 or 128 for rotary attention");
+ if (rotary_embedding != 32 && rotary_embedding != 64 && rotary_embedding != 128) {
+ ORT_THROW("rotary_embedding must be 32, 64 or 128 for rotary attention");
}
if (v_head_size != -1 && qk_head_size != v_head_size) {
ORT_THROW("qk_head_size must be equal to v_head_size for rotary attention");
}
const int step = past_sequence_length == 0 ? sequence_length : past_sequence_length;
- size_t smem_size = 2 * qk_head_size * sizeof(T);
+ size_t smem_size = 2 * rotary_embedding * sizeof(T);
const dim3 grid(sequence_length, num_heads, batch_size);
const dim3 block((qk_head_size / 2 + 31) / 32 * 32, 1, 1);
AddBiasTransposeQKV<T><<<grid, block, smem_size, stream>>>(total_matrix_count, input, biases, output,
- qkv_add_bias, qk_head_size, qk_head_size,
+ qkv_add_bias, rotary_embedding, qk_head_size,
step, format);
#else
ORT_THROW("Rotary Attention is supported on sm >= 530. Current sm is", __CUDA_ARCH__);
@@ -727,7 +727,7 @@ void LaunchAddBiasTranspose(
cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
const half* input, const half* biases, half* output, bool enable_half4, const int v_head_size,
- half* qkv_add_bias, int total_matrix_count, bool do_rotary, int past_sequence_length) {
+ half* qkv_add_bias, int total_matrix_count, bool do_rotary, int rotary_embedding, int past_sequence_length) {
total_matrix_count = std::max(num_matrices, total_matrix_count);
if (enable_half4 && 0 == (qk_head_size % 4) && (v_head_size == -1 || 0 == (v_head_size % 4)) && !do_rotary) {
const int H = qk_head_size / 4;
@@ -753,7 +753,7 @@ void LaunchAddBiasTranspose(
InvokeAddBiasTranspose(
stream, num_matrices, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size, input, biases, output,
- qkv_add_bias, v_head_size, total_matrix_count, do_rotary, past_sequence_length);
+ qkv_add_bias, v_head_size, total_matrix_count, do_rotary, rotary_embedding, past_sequence_length);
}
}
@@ -763,7 +763,7 @@ void LaunchAddBiasTranspose(
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
const float* input, const float* biases, float* output, bool /*enable_half4*/,
const int v_head_size, float* qkv_add_bias, int total_matrix_count, bool do_rotary,
- int past_sequence_length) {
+ int rotary_embedding, int past_sequence_length) {
total_matrix_count = std::max(num_matrices, total_matrix_count);
if (0 == (qk_head_size % 4) && (v_head_size == -1 || 0 == (v_head_size % 4)) && !do_rotary) {
const int H = qk_head_size / 4;
@@ -789,7 +789,8 @@ void LaunchAddBiasTranspose(
InvokeAddBiasTranspose(
stream, num_matrices, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size, input, biases, output,
- qkv_add_bias, v_head_size, total_matrix_count, do_rotary, past_sequence_length);
+ qkv_add_bias, v_head_size, total_matrix_count, do_rotary, rotary_embedding,
+ past_sequence_length);
}
}
diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h
index d903267c99a01..efc31db43bcdb 100644
--- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h
+++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h
@@ -33,7 +33,7 @@ void LaunchAddBiasTranspose(
cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
const T* input, const T* biases, T* output, bool enable_half4, const int v_head_size, T* qkv_add_bias = nullptr,
- int total_matrix_count = -1, bool do_rotary = false, int past_sequence_length = 0);
+ int total_matrix_count = -1, bool do_rotary = false, int rotary_embedding = 0, int past_sequence_length = 0);
// Add (bias) and Transpose for separated inputs of Q, K and V, and output Trt format.
// For self attention:
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
index 5c65a30918ece..a513d9e8d2211 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
@@ -65,7 +65,8 @@ Status PrepareQkv_Attention(contrib::AttentionParameters& parameters,
LaunchAddBiasTranspose(stream, matrix_to_transpose, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size,
data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias,
- 3, parameters.do_rotary, parameters.past_sequence_length);
+ 3, parameters.do_rotary, parameters.rotary_embedding,
+ parameters.past_sequence_length);
}
return Status::OK();
}
diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
index ea67218b5c927..f8f63650615fd 100644
--- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -333,6 +333,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"Whether to use rotary position embedding. Default value is 0.",
AttributeProto::INT,
OPTIONAL_VALUE)
+ .Attr("rotary_embedding_dim",
+ "Dimension of rotary embedding. Limited to 32, 64 or 128. Default value is head_size",
+ AttributeProto::INT,
+ OPTIONAL_VALUE)
.Attr("mask_filter_value",
"The value to be filled in the attention mask. Default value is -10000.0f",
AttributeProto::FLOAT,
diff --git a/onnxruntime/test/python/transformers/test_parity_neox_attention.py b/onnxruntime/test/python/transformers/test_parity_neox_attention.py
index 8c8e871a854b0..a98bb623beaea 100644
--- a/onnxruntime/test/python/transformers/test_parity_neox_attention.py
+++ b/onnxruntime/test/python/transformers/test_parity_neox_attention.py
@@ -29,6 +29,7 @@ def create_neox_attention_graph(
qkv_weight,
qkv_bias,
num_heads,
+ rotary_embedding,
):
nodes = [
helper.make_node(
@@ -43,6 +44,7 @@ def create_neox_attention_graph(
num_heads=num_heads,
unidirectional=1,
do_rotary=1,
+ rotary_embedding_dim=rotary_embedding,
domain="com.microsoft",
),
]
@@ -174,13 +176,13 @@ def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
class GPTNeoXAttention(nn.Module):
- def __init__(self, batch_size, seq_len, num_head, hidden_size, past_seq_len=0):
+ def __init__(self, batch_size, seq_len, num_head, hidden_size, past_seq_len=0, rotary_ndims=64):
super().__init__()
self.do_rotary = True
self.num_attention_heads = num_head
self.hidden_size = hidden_size
self.head_size = self.hidden_size // self.num_attention_heads
- self.rotary_ndims = int(self.head_size)
+ self.rotary_ndims = rotary_ndims
max_positions = 2048
self.register_buffer(
"bias",
@@ -197,6 +199,7 @@ def __init__(self, batch_size, seq_len, num_head, hidden_size, past_seq_len=0):
# self.query_key_value.bias.data.copy_(torch.tensor(np.zeros((3 * hidden_size))))
if past_seq_len > 0:
+ assert self.rotary_ndims == self.head_size
self.onnx_graph = create_neox_decoder_masked_self_attention_graph(
batch_size,
seq_len,
@@ -220,6 +223,7 @@ def __init__(self, batch_size, seq_len, num_head, hidden_size, past_seq_len=0):
.transpose(0, 1),
self.query_key_value.bias.reshape(self.num_attention_heads, 3, -1).transpose(0, 1).reshape(-1),
self.num_attention_heads,
+ self.rotary_ndims,
)
@classmethod
@@ -422,17 +426,21 @@ def test_gpt_neox_attention(self):
for batch_size in [1, 2, 4, 8]:
for seq_len in [32, 128, 512, 1024, 2048]:
for num_head in [12]:
- for hidden_size in [768]:
- attn = GPTNeoXAttention(batch_size, seq_len, num_head, hidden_size)
-
- hidden_states = torch.normal(mean=0.5, std=0.1, size=(batch_size, seq_len, hidden_size)).to(
- torch.float32
- )
-
- torch_output = attn.torch_forward(hidden_states)
- ort_output = attn.onnx_forward(hidden_states)
- if ort_output is not None:
- assert torch.allclose(torch_output, ort_output, atol=1e-4)
+ for rotary_ndims in [32, 64]:
+ for hidden_size in [768, 960]:
+ attn = GPTNeoXAttention(batch_size, seq_len, num_head, hidden_size, 0, rotary_ndims)
+
+ hidden_states = torch.normal(mean=0.5, std=0.1, size=(batch_size, seq_len, hidden_size)).to(
+ torch.float32
+ )
+
+ torch_output = attn.torch_forward(hidden_states)
+ ort_output = attn.onnx_forward(hidden_states)
+ if ort_output is not None:
+ assert torch.allclose(torch_output, ort_output, atol=1e-3)
+ print(
+ f"Passed: test_gpt_neox_attention: {batch_size}, {seq_len}, {num_head}, {hidden_size}, {rotary_ndims}"
+ )
def test_gpt_neox_decoder_masked_self_attention(self):
for batch_size in [1, 2, 4, 8]:
@@ -466,7 +474,7 @@ def test_gpt_neox_decoder_masked_self_attention(self):
hidden_states, attention_mask=attention_mask, layer_past=layer_past
)
if ort_output is not None:
- assert torch.allclose(torch_output, ort_output, atol=1e-4)
+ assert torch.allclose(torch_output, ort_output, atol=1e-3)
if __name__ == "__main__":