Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
wangyems committed Dec 15, 2023
1 parent 8f7b89b commit e3d6d03
Show file tree
Hide file tree
Showing 8 changed files with 29 additions and 15 deletions.
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
output_parameters->is_unidirectional = is_unidirectional_;
output_parameters->past_present_share_buffer = (past_present_share_buffer_ != 0 && past != nullptr);
output_parameters->do_rotary = do_rotary_;
output_parameters->rotary_embedding = rotary_embedding_ == 0 ? (int)(output_parameters->head_size) : rotary_embedding_;
output_parameters->mask_filter_value = mask_filter_value_;
output_parameters->scale = scale_;
output_parameters->mask_type = mask_type;
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class AttentionBase {

is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
do_rotary_ = info.GetAttrOrDefault<int64_t>("do_rotary", 0) == 1;
rotary_embedding_ = info.GetAttrOrDefault<int64_t>("rotary_embedding", 0);
mask_filter_value_ = info.GetAttrOrDefault<float>("mask_filter_value", -10000.0f);
scale_ = info.GetAttrOrDefault<float>("scale", 0.0f);

Expand Down Expand Up @@ -72,6 +73,7 @@ class AttentionBase {
bool require_same_hidden_size_; // whether the implementation supports different hidden sizes of Q/K/V.
bool past_present_share_buffer_; // whether or not the past (if used) and present tensor share the same buffer
bool do_rotary_; // whether or not to use rotary embeddings
int rotary_embedding_; // rotary embedding dimension
float mask_filter_value_; // the value to be used for filtered out positions
float scale_; // the scale to be used for softmax
};
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/contrib_ops/cpu/bert/attention_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ struct AttentionParameters {
int v_head_size; // hidden size per head of V
int num_heads;
int num_splits;
int rotary_embedding;
bool is_unidirectional;
bool past_present_share_buffer;
bool do_rotary;
Expand Down
19 changes: 10 additions & 9 deletions onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.cu
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,7 @@ void InvokeAddBiasTranspose(
cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
const T* input, const T* biases, T* output, T* qkv_add_bias, const int v_head_size, int total_matrix_count,
bool do_rotary = false, int past_sequence_length = 0) {
bool do_rotary = false, int rotary_embedding = 0, int past_sequence_length = 0) {
assert(num_heads <= max_threads_per_block);

if (do_rotary) {
Expand All @@ -650,20 +650,20 @@ void InvokeAddBiasTranspose(
if (format != 1 && format != 2 && format != 3) {
ORT_THROW("format must be 1, 2 or 3 for rotary attention");
}
if (qk_head_size != 64 && qk_head_size != 128) {
ORT_THROW("qk_head_size must be 64 or 128 for rotary attention");
if (rotary_embedding != 64 && rotary_embedding != 128) {
ORT_THROW("rotary_embedding must be 64 or 128 for rotary attention");
}
if (v_head_size != -1 && qk_head_size != v_head_size) {
ORT_THROW("qk_head_size must be equal to v_head_size for rotary attention");
}

const int step = past_sequence_length == 0 ? sequence_length : past_sequence_length;
size_t smem_size = 2 * qk_head_size * sizeof(T);
size_t smem_size = 2 * rotary_embedding * sizeof(T);

const dim3 grid(sequence_length, num_heads, batch_size);
const dim3 block((qk_head_size / 2 + 31) / 32 * 32, 1, 1);
AddBiasTransposeQKV<T><<<grid, block, smem_size, stream>>>(total_matrix_count, input, biases, output,
qkv_add_bias, qk_head_size, qk_head_size,
qkv_add_bias, rotary_embedding, qk_head_size,
step, format);
#else
ORT_THROW("Rotary Attention is supported on sm >= 530. Current sm is", __CUDA_ARCH__);
Expand Down Expand Up @@ -727,7 +727,7 @@ void LaunchAddBiasTranspose(
cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
const half* input, const half* biases, half* output, bool enable_half4, const int v_head_size,
half* qkv_add_bias, int total_matrix_count, bool do_rotary, int past_sequence_length) {
half* qkv_add_bias, int total_matrix_count, bool do_rotary, int rotary_embedding, int past_sequence_length) {
total_matrix_count = std::max(num_matrices, total_matrix_count);
if (enable_half4 && 0 == (qk_head_size % 4) && (v_head_size == -1 || 0 == (v_head_size % 4)) && !do_rotary) {
const int H = qk_head_size / 4;
Expand All @@ -753,7 +753,7 @@ void LaunchAddBiasTranspose(
InvokeAddBiasTranspose<half>(
stream, num_matrices, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size, input, biases, output,
qkv_add_bias, v_head_size, total_matrix_count, do_rotary, past_sequence_length);
qkv_add_bias, v_head_size, total_matrix_count, do_rotary, rotary_embedding, past_sequence_length);
}
}

Expand All @@ -763,7 +763,7 @@ void LaunchAddBiasTranspose(
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
const float* input, const float* biases, float* output, bool /*enable_half4*/,
const int v_head_size, float* qkv_add_bias, int total_matrix_count, bool do_rotary,
int past_sequence_length) {
int rotary_embedding, int past_sequence_length) {
total_matrix_count = std::max(num_matrices, total_matrix_count);
if (0 == (qk_head_size % 4) && (v_head_size == -1 || 0 == (v_head_size % 4)) && !do_rotary) {
const int H = qk_head_size / 4;
Expand All @@ -789,7 +789,8 @@ void LaunchAddBiasTranspose(
InvokeAddBiasTranspose<float>(
stream, num_matrices, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size, input, biases, output,
qkv_add_bias, v_head_size, total_matrix_count, do_rotary, past_sequence_length);
qkv_add_bias, v_head_size, total_matrix_count, do_rotary, rotary_embedding,
past_sequence_length);
}
}

Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/bert/add_bias_transpose.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ void LaunchAddBiasTranspose(
cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
const T* input, const T* biases, T* output, bool enable_half4, const int v_head_size, T* qkv_add_bias = nullptr,
int total_matrix_count = -1, bool do_rotary = false, int past_sequence_length = 0);
int total_matrix_count = -1, bool do_rotary = false, int rotary_embedding = 0, int past_sequence_length = 0);

// Add (bias) and Transpose for separated inputs of Q, K and V, and output Trt format.
// For self attention:
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/contrib_ops/cuda/bert/attention_prepare_qkv.cu
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ Status PrepareQkv_Attention(contrib::AttentionParameters& parameters,
LaunchAddBiasTranspose(stream, matrix_to_transpose, format, max_threads_per_block,
batch_size, sequence_length, num_heads, qk_head_size,
data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias,
3, parameters.do_rotary, parameters.past_sequence_length);
3, parameters.do_rotary, parameters.rotary_embedding,
parameters.past_sequence_length);
}
return Status::OK();
}
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/graph/contrib_ops/bert_defs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,10 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"Whether to use rotary position embedding. Default value is 0.",
AttributeProto::INT,
OPTIONAL_VALUE)
.Attr("rotary_embedding",
"Dimention of rotary embedding. Limited to 32 or 64. Default value is head_size / 2",
AttributeProto::INT,
OPTIONAL_VALUE)
.Attr("mask_filter_value",
"The value to be filled in the attention mask. Default value is -10000.0f",
AttributeProto::FLOAT,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def create_neox_attention_graph(
num_heads=num_heads,
unidirectional=1,
do_rotary=1,
rotary_embedding = 64,
domain="com.microsoft",
),
]
Expand Down Expand Up @@ -180,7 +181,7 @@ def __init__(self, batch_size, seq_len, num_head, hidden_size, past_seq_len=0):
self.num_attention_heads = num_head
self.hidden_size = hidden_size
self.head_size = self.hidden_size // self.num_attention_heads
self.rotary_ndims = int(self.head_size)
self.rotary_ndims = 64
max_positions = 2048
self.register_buffer(
"bias",
Expand Down Expand Up @@ -422,7 +423,7 @@ def test_gpt_neox_attention(self):
for batch_size in [1, 2, 4, 8]:
for seq_len in [32, 128, 512, 1024, 2048]:
for num_head in [12]:
for hidden_size in [768]:
for hidden_size in [768, 960]:
attn = GPTNeoXAttention(batch_size, seq_len, num_head, hidden_size)

hidden_states = torch.normal(mean=0.5, std=0.1, size=(batch_size, seq_len, hidden_size)).to(
Expand All @@ -432,7 +433,10 @@ def test_gpt_neox_attention(self):
torch_output = attn.torch_forward(hidden_states)
ort_output = attn.onnx_forward(hidden_states)
if ort_output is not None:
assert torch.allclose(torch_output, ort_output, atol=1e-4)
assert torch.allclose(torch_output, ort_output, atol=1e-3)
print(
f"Passed: test_gpt_neox_attention: {batch_size}, {seq_len}, {num_head}, {hidden_size}"
)

def test_gpt_neox_decoder_masked_self_attention(self):
for batch_size in [1, 2, 4, 8]:
Expand Down Expand Up @@ -466,7 +470,7 @@ def test_gpt_neox_decoder_masked_self_attention(self):
hidden_states, attention_mask=attention_mask, layer_past=layer_past
)
if ort_output is not None:
assert torch.allclose(torch_output, ort_output, atol=1e-4)
assert torch.allclose(torch_output, ort_output, atol=1e-3)


if __name__ == "__main__":
Expand Down

0 comments on commit e3d6d03

Please sign in to comment.