pass rotary embedding to attention op #18846

Merged
merged 11 commits on Jan 3, 2024
2 changes: 2 additions & 0 deletions docs/ContribOperators.md
@@ -155,6 +155,8 @@
<dd>Corresponding past and present are the same tensor; its size is (2, batch_size, num_heads, max_sequence_length, head_size)</dd>
<dt><tt>qkv_hidden_sizes</tt> : list of ints</dt>
<dd>Hidden dimension of Q, K, V: hidden_size, hidden_size and v_hidden_size</dd>
<dt><tt>rotary_embedding</tt> : int</dt>
<dd>Dimension of rotary embedding. Limited to 32 or 64. Default value is head_size</dd>

<dt><tt>scale</tt> : float</dt>
<dd>Custom scale will be used if specified. Default value is 1/sqrt(head_size)</dd>
<dt><tt>unidirectional</tt> : int</dt>
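As a quick illustration of the new attribute (not part of this diff), here is a minimal sketch of building a com.microsoft Attention node with do_rotary and rotary_embedding using the ONNX Python helpers, following the same pattern as the parity test later in this PR. The tensor names, shapes, and opset versions below are assumptions made for the example only.

import onnx
from onnx import TensorProto, helper

# Hypothetical sizes: batch=2, seq_len=32, hidden_size=768, 12 heads of size 64.
# rotary_embedding=32 rotates only the first 32 dimensions of each head.
node = helper.make_node(
    "Attention",
    inputs=["input", "weight", "bias"],
    outputs=["output"],
    num_heads=12,
    unidirectional=1,
    do_rotary=1,
    rotary_embedding=32,
    domain="com.microsoft",
)

graph = helper.make_graph(
    [node],
    "attention_with_rotary",
    [
        helper.make_tensor_value_info("input", TensorProto.FLOAT, [2, 32, 768]),
        helper.make_tensor_value_info("weight", TensorProto.FLOAT, [768, 3 * 768]),
        helper.make_tensor_value_info("bias", TensorProto.FLOAT, [3 * 768]),
    ],
    [helper.make_tensor_value_info("output", TensorProto.FLOAT, [2, 32, 768])],
)

# Opset versions here are illustrative assumptions.
model = helper.make_model(
    graph,
    opset_imports=[helper.make_opsetid("", 17), helper.make_opsetid("com.microsoft", 1)],
)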
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_base.cc
@@ -253,6 +253,7 @@
output_parameters->is_unidirectional = is_unidirectional_;
output_parameters->past_present_share_buffer = (past_present_share_buffer_ != 0 && past != nullptr);
output_parameters->do_rotary = do_rotary_;
output_parameters->rotary_embedding =
    rotary_embedding_ == 0 ? static_cast<int>(output_parameters->head_size) : rotary_embedding_;

output_parameters->mask_filter_value = mask_filter_value_;
output_parameters->scale = scale_;
output_parameters->mask_type = mask_type;
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_base.h
@@ -38,6 +38,7 @@ class AttentionBase {

is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
do_rotary_ = info.GetAttrOrDefault<int64_t>("do_rotary", 0) == 1;
rotary_embedding_ = static_cast<int>(info.GetAttrOrDefault<int64_t>("rotary_embedding", 0));
mask_filter_value_ = info.GetAttrOrDefault<float>("mask_filter_value", -10000.0f);
scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);

@@ -72,6 +73,7 @@ class AttentionBase {
bool require_same_hidden_size_; // whether the implementation supports different hidden sizes of Q/K/V.
bool past_present_share_buffer_; // whether or not the past (if used) and present tensor share the same buffer
bool do_rotary_; // whether or not to use rotary embeddings
int rotary_embedding_; // rotary embedding dimension
float mask_filter_value_; // the value to be used for filtered out positions
float scale_; // the scale to be used for softmax
};
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
@@ -56,6 +56,7 @@ struct AttentionParameters {
int v_head_size; // hidden size per head of V
int num_heads;
int num_splits;
int rotary_embedding;
bool is_unidirectional;
bool past_present_share_buffer;
bool do_rotary;
19 changes: 10 additions & 9 deletions onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu
@@ -640,7 +640,7 @@ void InvokeAddBiasTranspose(
cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
const T* input, const T* biases, T* output, T* qkv_add_bias, const int v_head_size, int total_matrix_count,
bool do_rotary = false, int past_sequence_length = 0) {
bool do_rotary = false, int rotary_embedding = 0, int past_sequence_length = 0) {
assert(num_heads <= max_threads_per_block);

if (do_rotary) {
@@ -650,20 +650,20 @@
if (format != 1 && format != 2 && format != 3) {
ORT_THROW("format must be 1, 2 or 3 for rotary attention");
}
if (qk_head_size != 64 && qk_head_size != 128) {
ORT_THROW("qk_head_size must be 64 or 128 for rotary attention");
if (rotary_embedding != 32 && rotary_embedding != 64 && rotary_embedding != 128) {
ORT_THROW("rotary_embedding must be 64 or 128 for rotary attention");
}
if (v_head_size != -1 && qk_head_size != v_head_size) {
ORT_THROW("qk_head_size must be equal to v_head_size for rotary attention");
}

const int step = past_sequence_length == 0 ? sequence_length : past_sequence_length;
size_t smem_size = 2 * qk_head_size * sizeof(T);
size_t smem_size = 2 * rotary_embedding * sizeof(T);

const dim3 grid(sequence_length, num_heads, batch_size);
const dim3 block((qk_head_size / 2 + 31) / 32 * 32, 1, 1);
AddBiasTransposeQKV<T><<<grid, block, smem_size, stream>>>(total_matrix_count, input, biases, output,
qkv_add_bias, qk_head_size, qk_head_size,
qkv_add_bias, rotary_embedding, qk_head_size,
step, format);
#else
ORT_THROW("Rotary Attention is supported on sm >= 530. Current sm is", __CUDA_ARCH__);
@@ -727,7 +727,7 @@ void LaunchAddBiasTranspose(
cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
const half* input, const half* biases, half* output, bool enable_half4, const int v_head_size,
half* qkv_add_bias, int total_matrix_count, bool do_rotary, int past_sequence_length) {
half* qkv_add_bias, int total_matrix_count, bool do_rotary, int rotary_embedding, int past_sequence_length) {
total_matrix_count = std::max(num_matrices, total_matrix_count);
if (enable_half4 && 0 == (qk_head_size % 4) && (v_head_size == -1 || 0 == (v_head_size % 4)) && !do_rotary) {
const int H = qk_head_size / 4;
@@ -753,7 +753,7 @@
InvokeAddBiasTranspose<half>(
stream, num_matrices, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size, input, biases, output,
qkv_add_bias, v_head_size, total_matrix_count, do_rotary, past_sequence_length);
qkv_add_bias, v_head_size, total_matrix_count, do_rotary, rotary_embedding, past_sequence_length);
}
}

@@ -763,7 +763,7 @@
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
const float* input, const float* biases, float* output, bool /*enable_half4*/,
const int v_head_size, float* qkv_add_bias, int total_matrix_count, bool do_rotary,
int past_sequence_length) {
int rotary_embedding, int past_sequence_length) {
total_matrix_count = std::max(num_matrices, total_matrix_count);
if (0 == (qk_head_size % 4) && (v_head_size == -1 || 0 == (v_head_size % 4)) && !do_rotary) {
const int H = qk_head_size / 4;
@@ -789,7 +789,8 @@
InvokeAddBiasTranspose<float>(
stream, num_matrices, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size, input, biases, output,
qkv_add_bias, v_head_size, total_matrix_count, do_rotary, past_sequence_length);
qkv_add_bias, v_head_size, total_matrix_count, do_rotary, rotary_embedding,
past_sequence_length);
}
}

2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h
@@ -33,7 +33,7 @@ void LaunchAddBiasTranspose(
cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
const T* input, const T* biases, T* output, bool enable_half4, const int v_head_size, T* qkv_add_bias = nullptr,
int total_matrix_count = -1, bool do_rotary = false, int past_sequence_length = 0);
int total_matrix_count = -1, bool do_rotary = false, int rotary_embedding = 0, int past_sequence_length = 0);

// Add (bias) and Transpose for separated inputs of Q, K and V, and output Trt format.
// For self attention:
3 changes: 2 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
@@ -65,7 +65,8 @@ Status PrepareQkv_Attention(contrib::AttentionParameters& parameters,
LaunchAddBiasTranspose(stream, matrix_to_transpose, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size,
data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias,
3, parameters.do_rotary, parameters.past_sequence_length);
3, parameters.do_rotary, parameters.rotary_embedding,
parameters.past_sequence_length);
}
return Status::OK();
}
4 changes: 4 additions & 0 deletions onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -333,6 +333,10 @@
"Whether to use rotary position embedding. Default value is 0.",
AttributeProto::INT,
OPTIONAL_VALUE)
.Attr("rotary_embedding",
"Dimention of rotary embedding. Limited to 32 or 64. Default value is head_size",

AttributeProto::INT,
OPTIONAL_VALUE)
.Attr("mask_filter_value",
"The value to be filled in the attention mask. Default value is -10000.0f",
AttributeProto::FLOAT,
36 changes: 22 additions & 14 deletions onnxruntime/test/python/transformers/test_parity_neox_attention.py
@@ -29,6 +29,7 @@ def create_neox_attention_graph(
qkv_weight,
qkv_bias,
num_heads,
rotary_embedding,
):
nodes = [
helper.make_node(
@@ -43,6 +44,7 @@
num_heads=num_heads,
unidirectional=1,
do_rotary=1,
rotary_embedding=rotary_embedding,
domain="com.microsoft",
),
]
@@ -174,13 +176,13 @@ def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):


class GPTNeoXAttention(nn.Module):
def __init__(self, batch_size, seq_len, num_head, hidden_size, past_seq_len=0):
def __init__(self, batch_size, seq_len, num_head, hidden_size, past_seq_len=0, rotary_ndims=64):
super().__init__()
self.do_rotary = True
self.num_attention_heads = num_head
self.hidden_size = hidden_size
self.head_size = self.hidden_size // self.num_attention_heads
self.rotary_ndims = int(self.head_size)
self.rotary_ndims = rotary_ndims
max_positions = 2048
self.register_buffer(
"bias",
Expand All @@ -197,6 +199,7 @@ def __init__(self, batch_size, seq_len, num_head, hidden_size, past_seq_len=0):
# self.query_key_value.bias.data.copy_(torch.tensor(np.zeros((3 * hidden_size))))

if past_seq_len > 0:
assert self.rotary_ndims == self.head_size
self.onnx_graph = create_neox_decoder_masked_self_attention_graph(
batch_size,
seq_len,
Expand All @@ -220,6 +223,7 @@ def __init__(self, batch_size, seq_len, num_head, hidden_size, past_seq_len=0):
.transpose(0, 1),
self.query_key_value.bias.reshape(self.num_attention_heads, 3, -1).transpose(0, 1).reshape(-1),
self.num_attention_heads,
self.rotary_ndims,
)

@classmethod
@@ -422,17 +426,21 @@ def test_gpt_neox_attention(self):
for batch_size in [1, 2, 4, 8]:
for seq_len in [32, 128, 512, 1024, 2048]:
for num_head in [12]:
for hidden_size in [768]:
attn = GPTNeoXAttention(batch_size, seq_len, num_head, hidden_size)

hidden_states = torch.normal(mean=0.5, std=0.1, size=(batch_size, seq_len, hidden_size)).to(
torch.float32
)

torch_output = attn.torch_forward(hidden_states)
ort_output = attn.onnx_forward(hidden_states)
if ort_output is not None:
assert torch.allclose(torch_output, ort_output, atol=1e-4)
for rotary_ndims in [32, 64]:
for hidden_size in [768, 960]:
attn = GPTNeoXAttention(batch_size, seq_len, num_head, hidden_size, 0, rotary_ndims)

hidden_states = torch.normal(mean=0.5, std=0.1, size=(batch_size, seq_len, hidden_size)).to(
torch.float32
)

torch_output = attn.torch_forward(hidden_states)
ort_output = attn.onnx_forward(hidden_states)
if ort_output is not None:
assert torch.allclose(torch_output, ort_output, atol=1e-3)
print(
f"Passed: test_gpt_neox_attention: {batch_size}, {seq_len}, {num_head}, {hidden_size}, {rotary_ndims}"
)

def test_gpt_neox_decoder_masked_self_attention(self):
for batch_size in [1, 2, 4, 8]:
Expand Down Expand Up @@ -466,7 +474,7 @@ def test_gpt_neox_decoder_masked_self_attention(self):
hidden_states, attention_mask=attention_mask, layer_past=layer_past
)
if ort_output is not None:
assert torch.allclose(torch_output, ort_output, atol=1e-4)
assert torch.allclose(torch_output, ort_output, atol=1e-3)


if __name__ == "__main__":
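For background on the rotary_ndims parameter exercised in the test above: only the first rotary_ndims dimensions of each head are rotated, and the remaining dimensions pass through unchanged. Below is a minimal PyTorch sketch of that partial rotation (not part of this PR); the shapes and helper names are illustrative assumptions rather than the repository's code.

import torch

def apply_partial_rotary(x, cos, sin, rotary_ndims):
    # x: (batch, num_heads, seq_len, head_size); cos/sin: (1, 1, seq_len, rotary_ndims)
    x_rot, x_pass = x[..., :rotary_ndims], x[..., rotary_ndims:]
    # rotate_half: swap the two halves of the rotary slice and negate the second
    half = rotary_ndims // 2
    rotated = torch.cat((-x_rot[..., half:], x_rot[..., :half]), dim=-1)
    x_rot = x_rot * cos + rotated * sin
    # non-rotated tail of each head is passed through untouched
    return torch.cat((x_rot, x_pass), dim=-1)

# Example: 12 heads of size 64, rotating only the first 32 dims of each head.
batch, heads, seq, head_size, rotary_ndims = 2, 12, 8, 64, 32
x = torch.randn(batch, heads, seq, head_size)
inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_ndims, 2).float() / rotary_ndims))
t = torch.arange(seq).float()
freqs = torch.einsum("i,j->ij", t, inv_freq)          # (seq, rotary_ndims / 2)
emb = torch.cat((freqs, freqs), dim=-1)               # (seq, rotary_ndims)
cos, sin = emb.cos()[None, None], emb.sin()[None, None]
out = apply_partial_rotary(x, cos, sin, rotary_ndims)
print(out.shape)  # torch.Size([2, 12, 8, 64])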