From e3d6d03c74d006a20cede2583b3a5545a6514623 Mon Sep 17 00:00:00 2001 From: Your Date: Fri, 15 Dec 2023 21:29:08 +0000 Subject: [PATCH 1/8] init --- .../contrib_ops/cpu/bert/attention_base.cc | 1 + .../contrib_ops/cpu/bert/attention_base.h | 2 ++ .../contrib_ops/cpu/bert/attention_common.h | 1 + .../cuda/bert/add_bias_transpose.cu | 19 ++++++++++--------- .../cuda/bert/add_bias_transpose.h | 2 +- .../cuda/bert/attention_prepare_qkv.cu | 3 ++- .../core/graph/contrib_ops/bert_defs.cc | 4 ++++ .../test_parity_neox_attention.py | 12 ++++++++---- 8 files changed, 29 insertions(+), 15 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc index 5d224bdc2235f..515a967aa2386 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_base.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.cc @@ -253,6 +253,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, output_parameters->is_unidirectional = is_unidirectional_; output_parameters->past_present_share_buffer = (past_present_share_buffer_ != 0 && past != nullptr); output_parameters->do_rotary = do_rotary_; + output_parameters->rotary_embedding = rotary_embedding_ == 0 ? (int)(output_parameters->head_size) : rotary_embedding_; output_parameters->mask_filter_value = mask_filter_value_; output_parameters->scale = scale_; output_parameters->mask_type = mask_type; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_base.h index 5ee40c4b98664..254a5e161ce04 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.h @@ -38,6 +38,7 @@ class AttentionBase { is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1; + rotary_embedding_ = info.GetAttrOrDefault("rotary_embedding", 0); mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); scale_ = info.GetAttrOrDefault("scale", 0.0f); @@ -72,6 +73,7 @@ class AttentionBase { bool require_same_hidden_size_; // whether the implementation supports different hidden sizes of Q/K/V. 
bool past_present_share_buffer_; // whether or not the past (if used) and present tensor share the same buffer bool do_rotary_; // whether or not to use rotary embeddings + int rotary_embedding_; // rotary embedding dimension float mask_filter_value_; // the value to be used for filtered out positions float scale_; // the scale to be used for softmax }; diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_common.h b/onnxruntime/contrib_ops/cpu/bert/attention_common.h index a7f83469a768d..c9ed23895b60c 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_common.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_common.h @@ -56,6 +56,7 @@ struct AttentionParameters { int v_head_size; // hidden size per head of V int num_heads; int num_splits; + int rotary_embedding; bool is_unidirectional; bool past_present_share_buffer; bool do_rotary; diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu index 626e4c0b87a3c..88732f2aeb2ed 100644 --- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu +++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu @@ -640,7 +640,7 @@ void InvokeAddBiasTranspose( cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block, const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size, const T* input, const T* biases, T* output, T* qkv_add_bias, const int v_head_size, int total_matrix_count, - bool do_rotary = false, int past_sequence_length = 0) { + bool do_rotary = false, int rotary_embedding = 0, int past_sequence_length = 0) { assert(num_heads <= max_threads_per_block); if (do_rotary) { @@ -650,20 +650,20 @@ void InvokeAddBiasTranspose( if (format != 1 && format != 2 && format != 3) { ORT_THROW("format must be 1, 2 or 3 for rotary attention"); } - if (qk_head_size != 64 && qk_head_size != 128) { - ORT_THROW("qk_head_size must be 64 or 128 for rotary attention"); + if (rotary_embedding != 64 && rotary_embedding != 128) { + ORT_THROW("rotary_embedding must be 64 or 128 for rotary attention"); } if (v_head_size != -1 && qk_head_size != v_head_size) { ORT_THROW("qk_head_size must be equal to v_head_size for rotary attention"); } const int step = past_sequence_length == 0 ? sequence_length : past_sequence_length; - size_t smem_size = 2 * qk_head_size * sizeof(T); + size_t smem_size = 2 * rotary_embedding * sizeof(T); const dim3 grid(sequence_length, num_heads, batch_size); const dim3 block((qk_head_size / 2 + 31) / 32 * 32, 1, 1); AddBiasTransposeQKV<<>>(total_matrix_count, input, biases, output, - qkv_add_bias, qk_head_size, qk_head_size, + qkv_add_bias, rotary_embedding, qk_head_size, step, format); #else ORT_THROW("Rotary Attention is supported on sm >= 530. 
Current sm is", __CUDA_ARCH__); @@ -727,7 +727,7 @@ void LaunchAddBiasTranspose( cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block, const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size, const half* input, const half* biases, half* output, bool enable_half4, const int v_head_size, - half* qkv_add_bias, int total_matrix_count, bool do_rotary, int past_sequence_length) { + half* qkv_add_bias, int total_matrix_count, bool do_rotary, int rotary_embedding, int past_sequence_length) { total_matrix_count = std::max(num_matrices, total_matrix_count); if (enable_half4 && 0 == (qk_head_size % 4) && (v_head_size == -1 || 0 == (v_head_size % 4)) && !do_rotary) { const int H = qk_head_size / 4; @@ -753,7 +753,7 @@ void LaunchAddBiasTranspose( InvokeAddBiasTranspose( stream, num_matrices, format, max_threads_per_block, batch_size, sequence_length, num_heads, qk_head_size, input, biases, output, - qkv_add_bias, v_head_size, total_matrix_count, do_rotary, past_sequence_length); + qkv_add_bias, v_head_size, total_matrix_count, do_rotary, rotary_embedding, past_sequence_length); } } @@ -763,7 +763,7 @@ void LaunchAddBiasTranspose( const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size, const float* input, const float* biases, float* output, bool /*enable_half4*/, const int v_head_size, float* qkv_add_bias, int total_matrix_count, bool do_rotary, - int past_sequence_length) { + int rotary_embedding, int past_sequence_length) { total_matrix_count = std::max(num_matrices, total_matrix_count); if (0 == (qk_head_size % 4) && (v_head_size == -1 || 0 == (v_head_size % 4)) && !do_rotary) { const int H = qk_head_size / 4; @@ -789,7 +789,8 @@ void LaunchAddBiasTranspose( InvokeAddBiasTranspose( stream, num_matrices, format, max_threads_per_block, batch_size, sequence_length, num_heads, qk_head_size, input, biases, output, - qkv_add_bias, v_head_size, total_matrix_count, do_rotary, past_sequence_length); + qkv_add_bias, v_head_size, total_matrix_count, do_rotary, rotary_embedding, + past_sequence_length); } } diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h index d903267c99a01..efc31db43bcdb 100644 --- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h +++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h @@ -33,7 +33,7 @@ void LaunchAddBiasTranspose( cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block, const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size, const T* input, const T* biases, T* output, bool enable_half4, const int v_head_size, T* qkv_add_bias = nullptr, - int total_matrix_count = -1, bool do_rotary = false, int past_sequence_length = 0); + int total_matrix_count = -1, bool do_rotary = false, int rotary_embedding = 0, int past_sequence_length = 0); // Add (bias) and Transpose for separated inputs of Q, K and V, and output Trt format. 
// For self attention: diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu index 5c65a30918ece..a513d9e8d2211 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu @@ -65,7 +65,8 @@ Status PrepareQkv_Attention(contrib::AttentionParameters& parameters, LaunchAddBiasTranspose(stream, matrix_to_transpose, format, max_threads_per_block, batch_size, sequence_length, num_heads, qk_head_size, data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias, - 3, parameters.do_rotary, parameters.past_sequence_length); + 3, parameters.do_rotary, parameters.rotary_embedding, + parameters.past_sequence_length); } return Status::OK(); } diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index ea67218b5c927..553d8712430c0 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -333,6 +333,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Whether to use rotary position embedding. Default value is 0.", AttributeProto::INT, OPTIONAL_VALUE) + .Attr("rotary_embedding", + "Dimention of rotary embedding. Limited to 32 or 64. Default value is head_size / 2", + AttributeProto::INT, + OPTIONAL_VALUE) .Attr("mask_filter_value", "The value to be filled in the attention mask. Default value is -10000.0f", AttributeProto::FLOAT, diff --git a/onnxruntime/test/python/transformers/test_parity_neox_attention.py b/onnxruntime/test/python/transformers/test_parity_neox_attention.py index 8c8e871a854b0..fbd3b2b3a0171 100644 --- a/onnxruntime/test/python/transformers/test_parity_neox_attention.py +++ b/onnxruntime/test/python/transformers/test_parity_neox_attention.py @@ -43,6 +43,7 @@ def create_neox_attention_graph( num_heads=num_heads, unidirectional=1, do_rotary=1, + rotary_embedding = 64, domain="com.microsoft", ), ] @@ -180,7 +181,7 @@ def __init__(self, batch_size, seq_len, num_head, hidden_size, past_seq_len=0): self.num_attention_heads = num_head self.hidden_size = hidden_size self.head_size = self.hidden_size // self.num_attention_heads - self.rotary_ndims = int(self.head_size) + self.rotary_ndims = 64 max_positions = 2048 self.register_buffer( "bias", @@ -422,7 +423,7 @@ def test_gpt_neox_attention(self): for batch_size in [1, 2, 4, 8]: for seq_len in [32, 128, 512, 1024, 2048]: for num_head in [12]: - for hidden_size in [768]: + for hidden_size in [768, 960]: attn = GPTNeoXAttention(batch_size, seq_len, num_head, hidden_size) hidden_states = torch.normal(mean=0.5, std=0.1, size=(batch_size, seq_len, hidden_size)).to( @@ -432,7 +433,10 @@ def test_gpt_neox_attention(self): torch_output = attn.torch_forward(hidden_states) ort_output = attn.onnx_forward(hidden_states) if ort_output is not None: - assert torch.allclose(torch_output, ort_output, atol=1e-4) + assert torch.allclose(torch_output, ort_output, atol=1e-3) + print( + f"Passed: test_gpt_neox_attention: {batch_size}, {seq_len}, {num_head}, {hidden_size}" + ) def test_gpt_neox_decoder_masked_self_attention(self): for batch_size in [1, 2, 4, 8]: @@ -466,7 +470,7 @@ def test_gpt_neox_decoder_masked_self_attention(self): hidden_states, attention_mask=attention_mask, layer_past=layer_past ) if ort_output is not None: - assert torch.allclose(torch_output, ort_output, atol=1e-4) + assert torch.allclose(torch_output, ort_output, atol=1e-3) if __name__ == "__main__": From 
d3bc5e6cee7d01316b5ec7c0515ed61b3e44ef5c Mon Sep 17 00:00:00 2001 From: Your Date: Fri, 15 Dec 2023 21:30:42 +0000 Subject: [PATCH 2/8] fix annotation --- onnxruntime/core/graph/contrib_ops/bert_defs.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 553d8712430c0..678efa205cd3f 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -334,7 +334,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( AttributeProto::INT, OPTIONAL_VALUE) .Attr("rotary_embedding", - "Dimention of rotary embedding. Limited to 32 or 64. Default value is head_size / 2", + "Dimention of rotary embedding. Limited to 32 or 64. Default value is head_size", AttributeProto::INT, OPTIONAL_VALUE) .Attr("mask_filter_value", From 25805837606e61e022bf9e5474cea69de074c182 Mon Sep 17 00:00:00 2001 From: Your Date: Fri, 15 Dec 2023 22:19:07 +0000 Subject: [PATCH 3/8] add dim=32 and refine test --- .../cuda/bert/add_bias_transpose.cu | 2 +- .../test_parity_neox_attention.py | 34 +++++++++++-------- 2 files changed, 20 insertions(+), 16 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu index 88732f2aeb2ed..603236b421630 100644 --- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu +++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu @@ -650,7 +650,7 @@ void InvokeAddBiasTranspose( if (format != 1 && format != 2 && format != 3) { ORT_THROW("format must be 1, 2 or 3 for rotary attention"); } - if (rotary_embedding != 64 && rotary_embedding != 128) { + if (rotary_embedding != 32 && rotary_embedding != 64 && rotary_embedding != 128) { ORT_THROW("rotary_embedding must be 64 or 128 for rotary attention"); } if (v_head_size != -1 && qk_head_size != v_head_size) { diff --git a/onnxruntime/test/python/transformers/test_parity_neox_attention.py b/onnxruntime/test/python/transformers/test_parity_neox_attention.py index fbd3b2b3a0171..a98bb623beaea 100644 --- a/onnxruntime/test/python/transformers/test_parity_neox_attention.py +++ b/onnxruntime/test/python/transformers/test_parity_neox_attention.py @@ -29,6 +29,7 @@ def create_neox_attention_graph( qkv_weight, qkv_bias, num_heads, + rotary_embedding, ): nodes = [ helper.make_node( @@ -43,7 +44,7 @@ def create_neox_attention_graph( num_heads=num_heads, unidirectional=1, do_rotary=1, - rotary_embedding = 64, + rotary_embedding=rotary_embedding, domain="com.microsoft", ), ] @@ -175,13 +176,13 @@ def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0): class GPTNeoXAttention(nn.Module): - def __init__(self, batch_size, seq_len, num_head, hidden_size, past_seq_len=0): + def __init__(self, batch_size, seq_len, num_head, hidden_size, past_seq_len=0, rotary_ndims=64): super().__init__() self.do_rotary = True self.num_attention_heads = num_head self.hidden_size = hidden_size self.head_size = self.hidden_size // self.num_attention_heads - self.rotary_ndims = 64 + self.rotary_ndims = rotary_ndims max_positions = 2048 self.register_buffer( "bias", @@ -198,6 +199,7 @@ def __init__(self, batch_size, seq_len, num_head, hidden_size, past_seq_len=0): # self.query_key_value.bias.data.copy_(torch.tensor(np.zeros((3 * hidden_size)))) if past_seq_len > 0: + assert self.rotary_ndims == self.head_size self.onnx_graph = create_neox_decoder_masked_self_attention_graph( batch_size, seq_len, @@ -221,6 +223,7 @@ def __init__(self, batch_size, seq_len, 
num_head, hidden_size, past_seq_len=0): .transpose(0, 1), self.query_key_value.bias.reshape(self.num_attention_heads, 3, -1).transpose(0, 1).reshape(-1), self.num_attention_heads, + self.rotary_ndims, ) @classmethod @@ -423,21 +426,22 @@ def test_gpt_neox_attention(self): for batch_size in [1, 2, 4, 8]: for seq_len in [32, 128, 512, 1024, 2048]: for num_head in [12]: - for hidden_size in [768, 960]: - attn = GPTNeoXAttention(batch_size, seq_len, num_head, hidden_size) + for rotary_ndims in [32, 64]: + for hidden_size in [768, 960]: + attn = GPTNeoXAttention(batch_size, seq_len, num_head, hidden_size, 0, rotary_ndims) - hidden_states = torch.normal(mean=0.5, std=0.1, size=(batch_size, seq_len, hidden_size)).to( - torch.float32 - ) - - torch_output = attn.torch_forward(hidden_states) - ort_output = attn.onnx_forward(hidden_states) - if ort_output is not None: - assert torch.allclose(torch_output, ort_output, atol=1e-3) - print( - f"Passed: test_gpt_neox_attention: {batch_size}, {seq_len}, {num_head}, {hidden_size}" + hidden_states = torch.normal(mean=0.5, std=0.1, size=(batch_size, seq_len, hidden_size)).to( + torch.float32 ) + torch_output = attn.torch_forward(hidden_states) + ort_output = attn.onnx_forward(hidden_states) + if ort_output is not None: + assert torch.allclose(torch_output, ort_output, atol=1e-3) + print( + f"Passed: test_gpt_neox_attention: {batch_size}, {seq_len}, {num_head}, {hidden_size}, {rotary_ndims}" + ) + def test_gpt_neox_decoder_masked_self_attention(self): for batch_size in [1, 2, 4, 8]: for past_seq_len in [1, 4, 32, 128, 512, 1024]: From c604e675f62a308fa618da6cd7dd7c72b185c86a Mon Sep 17 00:00:00 2001 From: Your Date: Mon, 18 Dec 2023 18:43:59 +0000 Subject: [PATCH 4/8] update --- onnxruntime/contrib_ops/cpu/bert/attention_base.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_base.h index 254a5e161ce04..e1a6829839995 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.h @@ -38,7 +38,7 @@ class AttentionBase { is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1; - rotary_embedding_ = info.GetAttrOrDefault("rotary_embedding", 0); + rotary_embedding_ = static_cast(info.GetAttrOrDefault("rotary_embedding", 0)); mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); scale_ = info.GetAttrOrDefault("scale", 0.0f); From 3d2ac02c9690b964d2abc036f4d2f7eab16d5381 Mon Sep 17 00:00:00 2001 From: Your Date: Tue, 19 Dec 2023 21:36:07 +0000 Subject: [PATCH 5/8] update docs --- docs/ContribOperators.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index e5b43ddba8cc7..ed84af41381bf 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -155,6 +155,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Corresponding past and present are same tensor, its size is (2, batch_size, num_heads, max_sequence_length, head_size)
qkv_hidden_sizes : list of ints
Hidden dimension of Q, K, V: hidden_size, hidden_size and v_hidden_size
+rotary_embedding : int
+Dimention of rotary embedding. Limited to 32 or 64. Default value is head_size
scale : float
Custom scale will be used if specified. Default value is 1/sqrt(head_size)
unidirectional : int
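For reference, the parity test in this series exercises the new attribute by building the fused com.microsoft Attention node directly. Below is a minimal, self-contained sketch along the same lines; the zero-valued initializers, symbolic shapes, and node/graph names are purely illustrative, and note that the attribute is renamed to `rotary_embedding_dim` later in the series (patch 6/8).

```python
import numpy as np
from onnx import TensorProto, helper, numpy_helper

num_heads, head_size = 12, 64
hidden_size = num_heads * head_size  # 768, as in the parity test

# Fused Attention node with rotary position embedding enabled.
# `rotary_embedding` rotates the first 64 channels of every head.
node = helper.make_node(
    "Attention",
    ["input", "weight", "bias"],
    ["output"],
    "AttentionWithRotary_0",
    num_heads=num_heads,
    unidirectional=1,
    do_rotary=1,
    rotary_embedding=64,
    domain="com.microsoft",
)

# Placeholder QKV projection weights; a real model would use trained values.
weight = numpy_helper.from_array(
    np.zeros((hidden_size, 3 * hidden_size), dtype=np.float32), name="weight")
bias = numpy_helper.from_array(
    np.zeros(3 * hidden_size, dtype=np.float32), name="bias")

graph = helper.make_graph(
    [node],
    "attention_rotary_sketch",
    [helper.make_tensor_value_info("input", TensorProto.FLOAT, ["batch", "seq", hidden_size])],
    [helper.make_tensor_value_info("output", TensorProto.FLOAT, ["batch", "seq", hidden_size])],
    initializer=[weight, bias],
)
model = helper.make_model(
    graph,
    opset_imports=[helper.make_opsetid("", 17), helper.make_opsetid("com.microsoft", 1)],
)
```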
From dfeec46ceca0103b0fc1b747b741015d5676c7a2 Mon Sep 17 00:00:00 2001 From: Your Date: Wed, 27 Dec 2023 23:56:36 +0000 Subject: [PATCH 6/8] update --- onnxruntime/contrib_ops/cpu/bert/attention_base.h | 2 +- onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu | 2 +- onnxruntime/core/graph/contrib_ops/bert_defs.cc | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_base.h index e1a6829839995..a6782daa58f1a 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.h @@ -38,7 +38,7 @@ class AttentionBase { is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; do_rotary_ = info.GetAttrOrDefault("do_rotary", 0) == 1; - rotary_embedding_ = static_cast(info.GetAttrOrDefault("rotary_embedding", 0)); + rotary_embedding_ = static_cast(info.GetAttrOrDefault("rotary_embedding_dim", 0)); mask_filter_value_ = info.GetAttrOrDefault("mask_filter_value", -10000.0f); scale_ = info.GetAttrOrDefault("scale", 0.0f); diff --git a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu index 603236b421630..1ea2540db486f 100644 --- a/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu +++ b/onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu @@ -651,7 +651,7 @@ void InvokeAddBiasTranspose( ORT_THROW("format must be 1, 2 or 3 for rotary attention"); } if (rotary_embedding != 32 && rotary_embedding != 64 && rotary_embedding != 128) { - ORT_THROW("rotary_embedding must be 64 or 128 for rotary attention"); + ORT_THROW("rotary_embedding must be 32, 64 or 128 for rotary attention"); } if (v_head_size != -1 && qk_head_size != v_head_size) { ORT_THROW("qk_head_size must be equal to v_head_size for rotary attention"); diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 678efa205cd3f..42f259d5b4746 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -333,8 +333,8 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "Whether to use rotary position embedding. Default value is 0.", AttributeProto::INT, OPTIONAL_VALUE) - .Attr("rotary_embedding", - "Dimention of rotary embedding. Limited to 32 or 64. Default value is head_size", + .Attr("rotary_embedding_dim", + "Dimention of rotary embedding. Limited to 32, 64 or 128. Default value is head_size", AttributeProto::INT, OPTIONAL_VALUE) .Attr("mask_filter_value", From 1eeeb98bf4910f01ec58f904fd647e88122d6b6f Mon Sep 17 00:00:00 2001 From: Your Date: Wed, 27 Dec 2023 23:58:57 +0000 Subject: [PATCH 7/8] review comments --- onnxruntime/core/graph/contrib_ops/bert_defs.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc index 42f259d5b4746..f8f63650615fd 100644 --- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc @@ -334,7 +334,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( AttributeProto::INT, OPTIONAL_VALUE) .Attr("rotary_embedding_dim", - "Dimention of rotary embedding. Limited to 32, 64 or 128. Default value is head_size", + "Dimension of rotary embedding. Limited to 32, 64 or 128. 
Default value is head_size", AttributeProto::INT, OPTIONAL_VALUE) .Attr("mask_filter_value", From 9fc215fb748836c4da9cfe4b3287962f457a6066 Mon Sep 17 00:00:00 2001 From: Ye Wang <52801275+wangyems@users.noreply.github.com> Date: Wed, 27 Dec 2023 15:59:55 -0800 Subject: [PATCH 8/8] update docs --- docs/ContribOperators.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 3c53ba5454b36..38fceef67de25 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -155,8 +155,8 @@ This version of the operator has been available since version 1 of the 'com.micr
Corresponding past and present are same tensor, its size is (2, batch_size, num_heads, max_sequence_length, head_size)
qkv_hidden_sizes : list of ints
Hidden dimension of Q, K, V: hidden_size, hidden_size and v_hidden_size
-rotary_embedding : int
-Dimention of rotary embedding. Limited to 32 or 64. Default value is head_size
+rotary_embedding_dim : int
+Dimension of rotary embedding. Limited to 32, 64 or 128. Default value is head_size
scale : float
Custom scale will be used if specified. Default value is 1/sqrt(head_size)
unidirectional : int
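For context on what the new `rotary_embedding_dim` attribute controls: the parity test's torch reference path applies rotary position embedding only to the first `rotary_ndims` channels of each head and passes the remaining channels through unchanged. A condensed sketch of that reference computation, following the test's `rotate_half`/`apply_rotary_pos_emb` helpers, is below; the cos/sin tables are assumed to be precomputed and broadcastable to [batch, num_heads, seq_len, rotary_dim], and the function name here is illustrative rather than taken from the patch.

```python
import torch

def rotate_half(x):
    # Swap the two halves of the rotary channels, negating the second half.
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((-x2, x1), dim=-1)

def apply_partial_rotary(query, key, cos, sin, rotary_dim):
    # query/key: [batch, num_heads, seq_len, head_size]
    # Only the first `rotary_dim` channels are rotated; the rest pass through.
    q_rot, q_pass = query[..., :rotary_dim], query[..., rotary_dim:]
    k_rot, k_pass = key[..., :rotary_dim], key[..., rotary_dim:]
    q_rot = q_rot * cos + rotate_half(q_rot) * sin
    k_rot = k_rot * cos + rotate_half(k_rot) * sin
    return torch.cat((q_rot, q_pass), dim=-1), torch.cat((k_rot, k_pass), dim=-1)
```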