broadcast attn bias in decoder masked mha
tianleiwu committed Aug 12, 2024
1 parent 91284f0 commit c76f294
Showing 5 changed files with 32 additions and 22 deletions.
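
This commit renames the relative_position_bias input to attention_bias across the decoder masked attention operators and teaches the CUDA decoding kernel to broadcast a 4D bias of shape [batch_size or 1, num_heads or 1, sequence_length, total_sequence_length] over its batch and head dimensions. The sketch below is a minimal host-side restatement of the offset arithmetic added to the kernel further down; the function name AttentionBiasOffset and its signature are illustrative and not part of ONNX Runtime.

#include <cstdint>
#include <vector>

// Illustrative sketch, not ONNX Runtime code: element offset of the bias slice
// for one (batch, head) pair, skipping any dimension of size 1 (broadcast),
// in the spirit of the attn_bias_offset logic added to the kernel below.
int64_t AttentionBiasOffset(const std::vector<int64_t>& bias_dims,  // 4D dims of the bias tensor
                            int64_t batch_index,
                            int64_t head_index,
                            int64_t num_heads,
                            int64_t sequence_length,
                            int64_t total_sequence_length) {
  int64_t offset = 0;
  if (bias_dims.size() == 4) {
    if (bias_dims[0] > 1) {  // bias carries a real batch dimension
      offset += batch_index * num_heads * sequence_length * total_sequence_length;
    }
    if (bias_dims[1] > 1) {  // bias carries a real head dimension
      offset += head_index * sequence_length * total_sequence_length;
    }
  }
  // The bias added to the logit of key position t is then bias_data[offset + t].
  return offset;
}

For example, with 8 heads, a decoding step of sequence_length 1, total_sequence_length 128 and a bias of shape [1, 8, 1, 128], only the head dimension contributes, so head 3 reads from offset 3 * 1 * 128 = 384 regardless of the batch index.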
@@ -60,7 +60,7 @@ Status DecoderMaskedMultiHeadAttention<T1, T2>::ComputeInternal(OpKernelContext*
const Tensor* key = context->Input<Tensor>(1);
const Tensor* value = context->Input<Tensor>(2);
const Tensor* mask_index = context->Input<Tensor>(3);
- const Tensor* relative_position_bias = context->Input<Tensor>(4);
+ const Tensor* attention_bias = context->Input<Tensor>(4);
const Tensor* past_key = context->Input<Tensor>(kPastInputIndex);
const Tensor* past_value = context->Input<Tensor>(kPastInputIndex + 1);
const Tensor* past_seq_len = context->Input<Tensor>(kPastSequenceLengthInputIndex);
@@ -80,7 +80,7 @@ Status DecoderMaskedMultiHeadAttention<T1, T2>::ComputeInternal(OpKernelContext*
value,
bias,
mask_index,
- relative_position_bias,
+ attention_bias,
past_key,
past_value,
past_seq_len,
@@ -141,16 +141,16 @@ Status DecoderMaskedMultiHeadAttention<T1, T2>::ComputeInternal(OpKernelContext*
// Update the q buffers
parameters.q = const_cast<T1*>(query->Data<T1>());

- // Update the relative position bias for self attention
- if (relative_position_bias != nullptr) {
- parameters.relative_attention_bias = const_cast<T1*>(relative_position_bias->Data<T1>());
+ // Update the attention bias for self attention
+ if (attention_bias != nullptr) {
+ parameters.attention_bias = const_cast<T1*>(attention_bias->Data<T1>());
}

// Decoder cross-attention
if (past_key == nullptr && present_key == nullptr) {
- if (relative_position_bias != nullptr) {
+ if (attention_bias != nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"DecoderMaskedMultiHeadAttention does not support relative position bias for cross-attention");
"DecoderMaskedMultiHeadAttention does not support attention bias for cross-attention");
}

parameters.is_cross_attention = true;
@@ -45,7 +45,7 @@ Status DecoderMaskedSelfAttention<T1, T2>::ComputeInternal(OpKernelContext* cont
const Tensor* bias = context->Input<Tensor>(2);
const Tensor* mask_index = context->Input<Tensor>(3);
const Tensor* past = context->Input<Tensor>(kPastInputIndex);
- const Tensor* relative_position_bias = context->Input<Tensor>(5);
+ const Tensor* attention_bias = context->Input<Tensor>(5);
const Tensor* past_seq_len = context->Input<Tensor>(kPastSequenceLengthInputIndex);
const Tensor* beam_width = context->Input<Tensor>(kBeamWidthInputIndex);
const Tensor* cache_indir = context->Input<Tensor>(kCacheIndirectionInputIndex);
@@ -61,7 +61,7 @@ Status DecoderMaskedSelfAttention<T1, T2>::ComputeInternal(OpKernelContext* cont
bias->Shape(),
mask_index,
past,
- relative_position_bias,
+ attention_bias,
&parameters,
device_prop.maxThreadsPerBlock,
past_seq_len));
@@ -85,8 +85,8 @@ Status DecoderMaskedSelfAttention<T1, T2>::ComputeInternal(OpKernelContext* cont
}

// TODO(hasesh): If there is a need, we will support this later
- if (relative_position_bias != nullptr) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "DecoderMaskedSelfAttention does not support relative position bias currently");
+ if (attention_bias != nullptr) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "DecoderMaskedSelfAttention does not support attention bias currently");
}

// TODO(hasesh): Support more mask types. Currently, it only supports the HuggingFace GreedySearch/BeamSearch pattern.
@@ -154,6 +154,18 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio
// The offset in the Q and K buffer also accounts for the batch.
int qk_offset = qkv_base_offset + tidx * QK_VEC_SIZE;

+ // The offset of attention bias for current head.
+ int64_t attn_bias_offset = 0;
+ if (params.attention_bias != nullptr && params.attention_bias_dims.size() == 4) {
+ // Support broadcasting the first and second dimensions of attention bias.
+ if (params.attention_bias_dims[0] > 1) {
+ attn_bias_offset = static_cast<int64_t>(bbi) * params.num_heads * params.sequence_length * params.total_sequence_length;
+ }
+ if (params.attention_bias_dims[1] > 1) {
+ attn_bias_offset += static_cast<int64_t>(hi) * params.sequence_length * params.total_sequence_length;
+ }
+ }
+
// Trigger the loads from the Q and K buffers.
Qk_vec_k q;
zero(q);
@@ -286,9 +298,8 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio
if (tidx == 0) {
// Normalize qk.
qk *= inv_sqrt_dh;
- if (params.relative_attention_bias != nullptr) {
- qk = add_vec(qk,
- reinterpret_cast<T*>(params.relative_attention_bias)[hi * params.sequence_length * params.total_sequence_length + tlength]);
+ if (params.attention_bias != nullptr) {
+ qk = add_vec(qk, reinterpret_cast<T*>(params.attention_bias)[attn_bias_offset + tlength]);
}
qk_max = qk;
qk_smem[tlength] = qk;
@@ -386,9 +397,8 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio

// Store the product to shared memory. There's one qk value per timestep. Update the max.
if (ti < tlength && tidx % THREADS_PER_KEY == 0) {
- if (params.relative_attention_bias != nullptr) {
- qk = add_vec(qk,
- reinterpret_cast<T*>(params.relative_attention_bias)[hi * params.sequence_length * params.total_sequence_length + ti]);
+ if (params.attention_bias != nullptr) {
+ qk = add_vec(qk, reinterpret_cast<T*>(params.attention_bias)[attn_bias_offset + ti]);
}
qk_max = fmaxf(qk_max, qk);
qk_smem[ti] = qk;
@@ -479,9 +489,9 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiHeadAttentio
#pragma unroll
for (int k_unroll = 0; k_unroll < K_CACHE_DATA_LOAD_UNROLL; ++k_unroll) {
if (time_bounds_cond[k_unroll] && (tidx % THREADS_PER_KEY == 0)) {
- if (params.relative_attention_bias != nullptr) {
+ if (params.attention_bias != nullptr) {
qk[k_unroll] = add_vec(qk[k_unroll],
- reinterpret_cast<T*>(params.relative_attention_bias)[hi * params.sequence_length * params.total_sequence_length + time_step[k_unroll]]);
+ reinterpret_cast<T*>(params.attention_bias)[attn_bias_offset + time_step[k_unroll]]);
}
qk_max = fmaxf(qk_max, qk[k_unroll]);
qk_smem[time_step[k_unroll]] = qk[k_unroll];
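
The three kernel hunks above apply the bias identically: once attn_bias_offset is known for the current (batch, head) pair, the value for key position t is read at attn_bias_offset + t and added to the raw qk score before the softmax reductions. Below is a CPU reference of that step with illustrative names; ApplyAttentionBias, logits, and bias_data are assumptions, not ONNX Runtime identifiers.

#include <cstdint>

// Illustrative CPU mirror of qk = add_vec(qk, attention_bias[attn_bias_offset + t]);
// not ONNX Runtime code.
void ApplyAttentionBias(float* logits,             // [total_sequence_length] raw qk scores
                        const float* bias_data,    // flattened 4D attention bias
                        int64_t attn_bias_offset,  // broadcast-aware offset (see sketch above)
                        int total_sequence_length) {
  for (int t = 0; t < total_sequence_length; ++t) {
    logits[t] += bias_data[attn_bias_offset + t];
  }
}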
@@ -37,7 +37,7 @@ struct DecoderMaskedMultiHeadAttentionParams : AttentionParameters {
void* v = nullptr;
void* v_bias = nullptr;

- void* relative_attention_bias = nullptr;
+ void* attention_bias = nullptr;

void* k_cache = nullptr;
void* v_cache = nullptr;
@@ -68,4 +68,4 @@ void mmha_launch_kernel(const DecoderMaskedMultiHeadAttentionParams& params, cud
} // namespace cuda

} // namespace contrib
- } // namespace onnxruntime
+ } // namespace onnxruntime
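
The header above only shows the renamed pointer, but the kernel also reads params.attention_bias_dims, so a launcher has to populate both. A hedged sketch of that wiring follows, assuming a vector-like dims member; the struct and helper names are stand-ins, and the real DecoderMaskedMultiHeadAttentionParams member types may differ.

#include <cstdint>
#include <vector>

// Stand-in types for illustration only; not the actual ONNX Runtime definitions.
struct BiasParamsSketch {
  void* attention_bias = nullptr;
  std::vector<int64_t> attention_bias_dims;
};

// Hands a 4D bias (any of [B,N,S,T], [1,N,S,T], [B,1,S,T], [1,1,S,T]) to the kernel
// parameters; dimensions of size 1 are broadcast by the kernel.
void SetAttentionBias(BiasParamsSketch& params, void* bias_data, const std::vector<int64_t>& dims) {
  params.attention_bias = bias_data;
  params.attention_bias_dims = dims;
}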
2 changes: 1 addition & 1 deletion onnxruntime/contrib_ops/cuda/bert/multihead_attention.cc
@@ -249,7 +249,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
fused_cross_attention_kernel == nullptr &&
!disable_memory_efficient_attention_ &&
is_long_sequence &&
- // Check whether the relative position bias alignment is good for memory efficient attention.
+ // Check whether the attention bias alignment is good for memory efficient attention.
(attention_bias == nullptr || parameters.sequence_length % (4 * sizeof(T)) == 0) &&
(nullptr == key_padding_mask || parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) &&
has_memory_efficient_attention(sm, std::is_same<T, MLFloat16>::value,
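
The alignment comment above guards the memory efficient attention path: when an attention bias is supplied, sequence_length must be a multiple of 4 * sizeof(T) elements. A hedged restatement of that predicate (the helper name is illustrative, not ONNX Runtime code):

// For T = MLFloat16 (2 bytes) the sequence length must be a multiple of 8;
// for T = float (4 bytes), a multiple of 16. Illustrative helper only.
template <typename T>
bool AttentionBiasAlignmentOk(const void* attention_bias, int sequence_length) {
  return attention_bias == nullptr ||
         (sequence_length % (4 * static_cast<int>(sizeof(T))) == 0);
}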
