GQA MLFloat16 cpu (#22102)
### Description

Add MLFloat16 (fp16) support to the GroupQueryAttention (GQA) contrib op on the CPU execution provider. The CPU kernel now accepts fp16 query, key, value, and KV-cache tensors; internally, fp16 inputs are converted to fp32 with the MLAS conversion helpers, the QK^T GEMM, softmax, and probs-times-V GEMM all run in fp32 scratch buffers, and the final result is converted back to fp16. The RotaryEmbedding CPU kernel gains fp16 support as well, and the operator kernel documentation is updated accordingly.

### Motivation and Context

Enables fp16 GroupQueryAttention models to run on the CPU execution provider, which previously registered this kernel for fp32 only.
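The core numeric pattern is sketched below in a self-contained form: half-precision inputs are widened to fp32, the GEMM accumulates in fp32, and the result is rounded back to half once at the end. This is a minimal illustration only; `Eigen::half` stands in for onnxruntime's `MLFloat16`, and the helper names (`HalfToFloat`, `FloatToHalf`, `GemmFp32`, `HalfGemmViaFp32`) are made up for the sketch rather than ORT/MLAS APIs.

```cpp
#include <Eigen/Core>  // assumes Eigen is available; Eigen::half stands in for MLFloat16
#include <cstddef>
#include <vector>

using half_t = Eigen::half;

// Widen a half buffer to fp32 (the role MlasConvertHalfToFloatBuffer plays in the kernel).
void HalfToFloat(const half_t* src, float* dst, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) dst[i] = static_cast<float>(src[i]);
}

// Narrow an fp32 buffer back to half (the role MlasConvertFloatToHalfBuffer plays).
void FloatToHalf(const float* src, half_t* dst, std::size_t n) {
  for (std::size_t i = 0; i < n; ++i) dst[i] = static_cast<half_t>(src[i]);
}

// C(M, N) = A(M, K) * B(N, K)^T, accumulated in fp32 -- the same shape as the Q x K' GEMM.
void GemmFp32(const float* a, const float* b, float* c,
              std::size_t M, std::size_t N, std::size_t K) {
  for (std::size_t m = 0; m < M; ++m) {
    for (std::size_t n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (std::size_t k = 0; k < K; ++k) acc += a[m * K + k] * b[n * K + k];
      c[m * N + n] = acc;
    }
  }
}

// End-to-end pattern: upcast once, compute in fp32, downcast the result once.
void HalfGemmViaFp32(const half_t* a, const half_t* b, half_t* c,
                     std::size_t M, std::size_t N, std::size_t K) {
  std::vector<float> a32(M * K), b32(N * K), c32(M * N);
  HalfToFloat(a, a32.data(), a32.size());
  HalfToFloat(b, b32.data(), b32.size());
  GemmFp32(a32.data(), b32.data(), c32.data(), M, N, K);
  FloatToHalf(c32.data(), c, c32.size());
}
```

The actual kernel uses MlasConvertHalfToFloatBuffer / MlasConvertFloatToHalfBuffer and math::GemmEx<float, ThreadPool> for these steps, as shown in the gqa_attention_base.h diff below.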

---------

Co-authored-by: Your Name <[email protected]>
wangyems and Your Name authored Sep 24, 2024
1 parent 5fa4505 commit 6cc06ad
Showing 7 changed files with 229 additions and 135 deletions.
docs/OperatorKernels.md (4 changes: 2 additions & 2 deletions)
@@ -482,7 +482,7 @@ Do not modify directly.*
|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float)|
|GridSample|*in* X:**T1**<br> *in* Grid:**T1**<br> *out* Y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(float)|
|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float)|
|GroupQueryAttention|*in* query:**T**<br> *in* key:**T**<br> *in* value:**T**<br> *in* past_key:**T**<br> *in* past_value:**T**<br> *in* seqlens_k:**M**<br> *in* total_sequence_length:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**<br> *out* present_key:**T**<br> *out* present_value:**T**|1+|**M** = tensor(int32)<br/> **T** = tensor(float), tensor(float16)|
|Inverse|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|MatMulBnb4|*in* A:**T1**<br> *in* B:**T2**<br> *in* absmax:**T1**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)|
|MatMulFpQ4|*in* A:**T1**<br> *in* B:**T2**<br> *in* B_shape:**T3**<br> *out* Y:**T1**|1+|**T1** = tensor(float)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(int64)|
@@ -508,7 +508,7 @@ Do not modify directly.*
|QuantizeLinear|*in* x:**T1**<br> *in* y_scale:**T1**<br> *in* y_zero_point:**T2**<br> *out* y:**T2**|1+|**T1** = tensor(float)<br/> **T2** = tensor(int16), tensor(int4), tensor(int8), tensor(uint16), tensor(uint4), tensor(uint8)|
|QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|Range|*in* start:**T**<br> *in* limit:**T**<br> *in* delta:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
|RotaryEmbedding|*in* input:**T**<br> *in* position_ids:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**|1+|**M** = tensor(int64)<br/> **T** = tensor(float)|
|RotaryEmbedding|*in* input:**T**<br> *in* position_ids:**M**<br> *in* cos_cache:**T**<br> *in* sin_cache:**T**<br> *out* output:**T**|1+|**M** = tensor(int64)<br/> **T** = tensor(float), tensor(float16)|
|SampleOp|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *in* seed:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_bias_sum:**T**|1+|**T** = tensor(double), tensor(float)|
onnxruntime/contrib_ops/cpu/bert/attention_utils.cc (20 changes: 14 additions & 6 deletions)
@@ -48,13 +48,13 @@ Status AddBiasTranspose(const Tensor* qkv, // Input: Q/K/V dat
constexpr size_t element_size = sizeof(T);
ProcessBroadcastSpanFuncs add_funcs{
[](BroadcastHelper& per_iter_bh) {
per_iter_bh.OutputEigen<T>() = per_iter_bh.ScalarInput0<T>() + per_iter_bh.EigenInput1<T>().array();
per_iter_bh.OutputEigen<float>() = per_iter_bh.ScalarInput0<float>() + per_iter_bh.EigenInput1<float>().array();
},
[](BroadcastHelper& per_iter_bh) {
per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>().array() + per_iter_bh.ScalarInput1<T>();
per_iter_bh.OutputEigen<float>() = per_iter_bh.EigenInput0<float>().array() + per_iter_bh.ScalarInput1<float>();
},
[](BroadcastHelper& per_iter_bh) {
per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>() + per_iter_bh.EigenInput1<T>();
per_iter_bh.OutputEigen<float>() = per_iter_bh.EigenInput0<float>() + per_iter_bh.EigenInput1<float>();
}}; // For element-wise add

// Allocate space for output of Q(BS, D) + bias(D)
@@ -132,13 +132,13 @@ Status AddBiasReshape(const Tensor* qkv, // Input: Q/K/V data - query is
constexpr size_t element_size = sizeof(T);
ProcessBroadcastSpanFuncs add_funcs{
[](BroadcastHelper& per_iter_bh) {
per_iter_bh.OutputEigen<T>() = per_iter_bh.ScalarInput0<T>() + per_iter_bh.EigenInput1<T>().array();
per_iter_bh.OutputEigen<float>() = per_iter_bh.ScalarInput0<float>() + per_iter_bh.EigenInput1<float>().array();
},
[](BroadcastHelper& per_iter_bh) {
per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>().array() + per_iter_bh.ScalarInput1<T>();
per_iter_bh.OutputEigen<float>() = per_iter_bh.EigenInput0<float>().array() + per_iter_bh.ScalarInput1<float>();
},
[](BroadcastHelper& per_iter_bh) {
per_iter_bh.OutputEigen<T>() = per_iter_bh.EigenInput0<T>() + per_iter_bh.EigenInput1<T>();
per_iter_bh.OutputEigen<float>() = per_iter_bh.EigenInput0<float>() + per_iter_bh.EigenInput1<float>();
}}; // For element-wise add

// Get Q's bias from combined bias
Expand Down Expand Up @@ -219,6 +219,10 @@ template Status MaybeTransposeToBNSHAndAddBias<float>(OpKernelContext* context,
int batch_size, int num_heads, int sequence_length, int head_size,
const Tensor* in, const Tensor* bias, int bias_offset, OrtValue& out);

template Status MaybeTransposeToBNSHAndAddBias<MLFloat16>(OpKernelContext* context, AllocatorPtr allocator,
int batch_size, int num_heads, int sequence_length, int head_size,
const Tensor* in, const Tensor* bias, int bias_offset, OrtValue& out);

template <typename T>
Status MaybeTransposeToBNSH(AllocatorPtr allocator,
int batch_size, int num_heads, int sequence_length, int head_size,
@@ -242,5 +246,9 @@ template Status MaybeTransposeToBNSH<float>(AllocatorPtr allocator,
int batch_size, int num_heads, int sequence_length, int head_size,
const Tensor* in, OrtValue& out);

template Status MaybeTransposeToBNSH<MLFloat16>(AllocatorPtr allocator,
int batch_size, int num_heads, int sequence_length, int head_size,
const Tensor* in, OrtValue& out);

} // namespace contrib
} // namespace onnxruntime
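For context on the templated helpers instantiated above: judging by its name and parameters, MaybeTransposeToBNSH rearranges a (batch, sequence, heads, head_size) tensor into (batch, heads, sequence, head_size) so that each head's sequence is contiguous for the per-head attention GEMMs. A rough, hypothetical sketch of that layout change (plain float elements and an illustrative helper name; the real function's body is not shown in this diff):

```cpp
#include <cstddef>

// Hypothetical helper: rearrange (B, S, N, H) -> (B, N, S, H).
void TransposeBSNHToBNSH(const float* in, float* out,
                         std::size_t B, std::size_t S, std::size_t N, std::size_t H) {
  for (std::size_t b = 0; b < B; ++b)
    for (std::size_t s = 0; s < S; ++s)
      for (std::size_t n = 0; n < N; ++n)
        for (std::size_t h = 0; h < H; ++h)
          out[((b * N + n) * S + s) * H + h] = in[((b * S + s) * N + n) * H + h];
}
```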
onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h (96 changes: 75 additions & 21 deletions)
@@ -75,7 +75,7 @@ class GQAAttentionBase {
int seqlen_present_kv_cache = static_cast<int>(present_key->Shape().GetDims()[2]);

// Compute the attention score.
size_t bytes = SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * seqlen_present_kv_cache * sizeof(T);
size_t bytes = SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * seqlen_present_kv_cache * sizeof(float);
auto attention_probs = allocator->Alloc(bytes);
BufferUniquePtr scratch_buffer(attention_probs, BufferDeleter(allocator));

@@ -87,16 +87,17 @@ class GQAAttentionBase {
bool past_present_share_buffer = past_key_data == present_key_data && past_value_data == present_value_data;

const T* k = packed_qkv ? Q + num_heads_ * sequence_length * head_size : K;
ComputeAttentionProbs<T>(static_cast<T*>(attention_probs), Q, k, seqlens_k->Data<int32_t>(), batch_size,
ComputeAttentionProbs<T>(static_cast<float*>(attention_probs), Q, k, seqlens_k->Data<int32_t>(), batch_size,
sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size, past_key_data,
present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp);
present_key_data, past_present_share_buffer, packed_qkv, is_prompt, tp, allocator);

// Compute the attentionScore * Value: out(B, N, S, H_v) = attention_probs(B, N, S, T) x V(B, N, T, H_v)
const T* v = packed_qkv ? Q + (num_heads_ + kv_num_heads_) * sequence_length * head_size : V;
ComputeVxAttentionScore(output->MutableData<T>(), static_cast<T*>(attention_probs), v, seqlens_k->Data<int32_t>(),
ComputeVxAttentionScore(output->MutableData<T>(), static_cast<float*>(attention_probs), v,
seqlens_k->Data<int32_t>(),
batch_size, sequence_length, seqlen_past_kv_cache, seqlen_present_kv_cache, head_size,
hidden_size, past_value_data, present_value_data, past_present_share_buffer, packed_qkv,
is_prompt, tp);
is_prompt, tp, allocator);

return Status::OK();
}
@@ -106,7 +107,7 @@ class GQAAttentionBase {
// attention_probs(B, N, S, T) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, T, H -> B, N, H, T)
// attention_probs(B, N, S, T) = Softmax(attention_probs)
template <typename T>
void ComputeAttentionProbs(T* attention_probs, // output buffer with size BxNxSxT
void ComputeAttentionProbs(float* attention_probs, // output buffer with size BxNxSxT
const T* Q, // Q data. Its size is BxNxSxH
const T* K, // k data. Its size is BxNxLxH
const int32_t* seqlens_k, // total - 1 sequence lengths tensor
@@ -120,7 +121,8 @@
const bool past_present_share_buffer, // whether present key and value share the same buffer
const bool packed_qkv, // whether Q, K, V are packed
const bool is_prompt, // whether it is prompt
ThreadPool* tp) const { // thread pool
ThreadPool* tp, // thread pool
AllocatorPtr allocator) const { // allocator for temporary buffer
const ptrdiff_t packed_batch_stride =
packed_qkv ? SafeInt<ptrdiff_t>(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size
: SafeInt<ptrdiff_t>(0);
@@ -131,7 +133,9 @@
const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size; // T x H

if (!past_present_share_buffer) {
memset(present_key, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
memset((void*)present_key,
0,
batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
}

const size_t loop_len = batch_size * num_heads_;
@@ -164,7 +168,7 @@
const size_t past_chunk_length = past_seqlen * head_size;

const ptrdiff_t output_offset = SafeInt<ptrdiff_t>(i) * sequence_length * present_buffer_sequence_length;
T* output = attention_probs + output_offset;
float* output = attention_probs + output_offset;

const T* k;
if (packed_qkv) {
@@ -190,12 +194,28 @@ class GQAAttentionBase {
q = Q + q_input_chunk_length * i;
}

math::GemmEx<T, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q,
static_cast<int>(head_size), k, static_cast<int>(head_size), 0.0f /*bata*/, output,
static_cast<int>(present_buffer_sequence_length), nullptr);
if constexpr (std::is_same<T, float>::value) {
math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q,
static_cast<int>(head_size), k, static_cast<int>(head_size), 0.0f /*bata*/,
output, static_cast<int>(present_buffer_sequence_length), nullptr);
} else {
size_t bytes = head_size * (sequence_length + total_seqlen) * sizeof(float);
auto q_k_fp32 = allocator->Alloc(bytes);
BufferUniquePtr scratch_buffer(q_k_fp32, BufferDeleter(allocator));

float* q_fp32 = static_cast<float*>(q_k_fp32);
MlasConvertHalfToFloatBuffer(q, q_fp32, head_size * sequence_length);

float* k_fp32 = q_fp32 + head_size * sequence_length;
MlasConvertHalfToFloatBuffer(k, k_fp32, head_size * total_seqlen);

math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q_fp32,
static_cast<int>(head_size), k_fp32, static_cast<int>(head_size), 0.0f /*bata*/,
output, static_cast<int>(present_buffer_sequence_length), nullptr);
}

// compute Softmax
T* output_softmax = output;
float* output_softmax = output;
for (size_t seq = 0; seq < sequence_length; seq++) {
size_t seq_causal_length = past_seqlen + seq + 1;
if (local_window_size_ > 0 && seq_causal_length > static_cast<size_t>(local_window_size_) + 1) {
@@ -237,7 +257,7 @@

template <typename T>
void ComputeVxAttentionScore(T* output, // buffer for the result with size BxSxNxH
const T* attention_probs, // Attention probs with size BxNxSxT
const float* attention_probs, // Attention probs with size BxNxSxT
const T* V, // V value with size BxN_kvxSxH
const int32_t* seqlens_k, // total - 1 sequence lengths tensor
const size_t batch_size, // batch size
@@ -251,7 +271,8 @@ class GQAAttentionBase {
const bool past_present_share_buffer, // whether present key and value share the same buffer
const bool packed_qkv, // whether Q, K, V are packed
const bool is_prompt, // whether it is prompt
ThreadPool* tp) const {
ThreadPool* tp,
AllocatorPtr allocator) const {
const ptrdiff_t packed_batch_stride =
packed_qkv ? SafeInt<ptrdiff_t>(num_heads_ + 2 * kv_num_heads_) * sequence_length * head_size
: SafeInt<ptrdiff_t>(0);
@@ -261,7 +282,9 @@
const size_t present_buff_chunk_length = present_buffer_sequence_length * head_size; // T x H

if (!past_present_share_buffer) {
memset(present_value, 0, batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
memset((void*)present_value,
0,
batch_size * kv_num_heads_ * present_buffer_sequence_length * head_size * sizeof(T));
}

const size_t loop_len = batch_size * num_heads_;
@@ -285,6 +308,13 @@
unit_cost.bytes_loaded += bytes_to_copy_trans_all;
unit_cost.bytes_stored += bytes_to_copy_trans_all;

size_t output_fp32_bytes = 0;
if constexpr (std::is_same<T, MLFloat16>::value) {
output_fp32_bytes = SafeInt<size_t>(sequence_length) * batch_size * num_heads_ * head_size * sizeof(float);
}
auto output_fp32 = allocator->Alloc(output_fp32_bytes);
BufferUniquePtr scratch_buffer(output_fp32, BufferDeleter(allocator));

ThreadPool::TryParallelFor(tp, loop_len, unit_cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
for (std::ptrdiff_t i = begin; i != end; ++i) {
const size_t batch_index = i / num_heads_;
Expand All @@ -305,15 +335,39 @@ class GQAAttentionBase {
i / kv_num_heads_factor);
}

T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size;
ptrdiff_t attention_probs_offset = SafeInt<ptrdiff_t>(sequence_length) * present_buffer_sequence_length * i;

math::GemmEx<T, ThreadPool>(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen, 1.f, /*alpha*/
attention_probs + attention_probs_offset,
static_cast<int>(present_buffer_sequence_length), v, static_cast<int>(head_size),
0.0f /*beta*/, output_current, static_cast<int>(hidden_size), nullptr);
if constexpr (std::is_same<T, float>::value) {
T* output_current = output + (batch_index * sequence_length * num_heads_ + head_index) * head_size;
math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen,
1.f, /*alpha*/ attention_probs + attention_probs_offset,
static_cast<int>(present_buffer_sequence_length), v,
static_cast<int>(head_size), 0.0f /*beta*/, output_current,
static_cast<int>(hidden_size), nullptr);
} else {
size_t bytes = head_size * total_seqlen * sizeof(float);
auto v_fp32 = allocator->Alloc(bytes);
BufferUniquePtr scratch_buffer(v_fp32, BufferDeleter(allocator));

float* v_fp32_ptr = static_cast<float*>(v_fp32);
MlasConvertHalfToFloatBuffer(v, v_fp32_ptr, head_size * total_seqlen);

float* output_fp32_current = static_cast<float*>(output_fp32) +
(batch_index * sequence_length * num_heads_ + head_index) * head_size;
math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasNoTrans, sequence_length, head_size, total_seqlen,
1.f, /*alpha*/ attention_probs + attention_probs_offset,
static_cast<int>(present_buffer_sequence_length), v_fp32_ptr,
static_cast<int>(head_size), 0.0f /*beta*/, output_fp32_current,
static_cast<int>(hidden_size), nullptr);
}
}
});

if constexpr (std::is_same<T, MLFloat16>::value) {
MlasConvertFloatToHalfBuffer(static_cast<float*>(output_fp32),
output,
SafeInt<size_t>(sequence_length) * batch_size * num_heads_ * head_size);
}
}
};

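One piece of the fp32 pipeline above that is easy to restate is the causal softmax over each row of attention_probs: row `s` may attend only to the first `past_seqlen + s + 1` key positions (the `seq_causal_length` in the hunks). A hedged, self-contained sketch follows; the kernel's exact masking, local-window handling, and stabilization details are truncated in the diff, so the max-subtraction and zero-fill here are assumptions, and the function name is illustrative.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>

// Row-wise causal softmax in fp32: `probs` is a seq_len x total_len row-major
// score matrix; row `s` may attend only to the first `past_len + s + 1` columns.
void CausalSoftmaxInPlace(float* probs, std::size_t seq_len,
                          std::size_t past_len, std::size_t total_len) {
  for (std::size_t s = 0; s < seq_len; ++s) {
    float* row = probs + s * total_len;
    const std::size_t causal_len = past_len + s + 1;  // mirrors seq_causal_length above
    float row_max = row[0];
    for (std::size_t t = 1; t < causal_len; ++t) row_max = std::max(row_max, row[t]);
    float sum = 0.0f;  // numerically stable softmax over the visible prefix
    for (std::size_t t = 0; t < causal_len; ++t) {
      row[t] = std::exp(row[t] - row_max);
      sum += row[t];
    }
    for (std::size_t t = 0; t < causal_len; ++t) row[t] /= sum;
    for (std::size_t t = causal_len; t < total_len; ++t) row[t] = 0.0f;  // masked future positions
  }
}
```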
(Diffs for the remaining 4 changed files are not shown here.)
