From 3096988dc091735d9ddbcf43055c62805723bc9e Mon Sep 17 00:00:00 2001 From: Ye Wang <52801275+wangyems@users.noreply.github.com> Date: Tue, 23 Jan 2024 14:18:17 -0800 Subject: [PATCH 1/5] Update multihead_attention.cu --- onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu index 6f98312e4067d..cd08d3cf1163e 100644 --- a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu @@ -68,6 +68,7 @@ MultiHeadAttention::MultiHeadAttention(const OpKernelInfo& info) scale_ = info.GetAttrOrDefault("scale", 0.0f); past_present_share_buffer_ = info.GetAttrOrDefault("past_present_share_buffer", 0LL) != 0LL; + is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; using HipT = typename ToHipType::MappedType; using AttentionTunableOp = GemmSoftmaxGemmPermuteTunableOp; From 19ffe27ce6d489f5e42e13549a4e0aeb397bbdc3 Mon Sep 17 00:00:00 2001 From: Ye Wang <52801275+wangyems@users.noreply.github.com> Date: Tue, 23 Jan 2024 14:19:35 -0800 Subject: [PATCH 2/5] Update multihead_attention.h --- onnxruntime/contrib_ops/rocm/bert/multihead_attention.h | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h index 84d8b76bbfebe..1d676d7a7bcac 100644 --- a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h +++ b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.h @@ -25,6 +25,7 @@ class MultiHeadAttention final : public RocmKernel { float mask_filter_value_; float scale_; bool past_present_share_buffer_{false}; + bool is_unidirectional_{false}; // type-erased GemmSoftmaxGemmPermuteTunableOp, the reason for this is: // 1. We don't want to include the cuh file where GemmSoftmaxGemmPermuteTunableOp is defined. From 67be4068d9efffb3f1f4b01e17c1b2de6d3180d7 Mon Sep 17 00:00:00 2001 From: Ye Wang <52801275+wangyems@users.noreply.github.com> Date: Tue, 23 Jan 2024 14:21:34 -0800 Subject: [PATCH 3/5] Update multihead_attention.cu --- onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu index cd08d3cf1163e..a8d7bb2672b4b 100644 --- a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu @@ -123,7 +123,7 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { key_padding_mask, relative_position_bias, past_key, past_value, past_seq_len, &attn, - num_heads_, mask_filter_value_, scale_, + num_heads_, is_unidirectional_, mask_filter_value_, scale_, past_present_share_buffer_, false, device_prop.maxThreadsPerBlock)); if (attn_type_ == kDecoderMaskedMultiHeadAttention && attn.sequence_length != 1) { From 63efcf4e1bc3cfd2d5cb89cc0353c5c8a53f28bd Mon Sep 17 00:00:00 2001 From: Ye Wang <52801275+wangyems@users.noreply.github.com> Date: Tue, 23 Jan 2024 15:06:01 -0800 Subject: [PATCH 4/5] Update multihead_attention.cu --- onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu index a8d7bb2672b4b..ea5cac9e30387 100644 --- a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu @@ -123,7 +123,8 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { key_padding_mask, relative_position_bias, past_key, past_value, past_seq_len, &attn, - num_heads_, is_unidirectional_, mask_filter_value_, scale_, + num_heads_, false, /*is_unidirectional_*/ + mask_filter_value_, scale_, past_present_share_buffer_, false, device_prop.maxThreadsPerBlock)); if (attn_type_ == kDecoderMaskedMultiHeadAttention && attn.sequence_length != 1) { From a8e0dce83952b955a88d3077618b070db77daf21 Mon Sep 17 00:00:00 2001 From: Ye Wang <52801275+wangyems@users.noreply.github.com> Date: Tue, 23 Jan 2024 16:28:37 -0800 Subject: [PATCH 5/5] Update multihead_attention.cu --- onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu index ea5cac9e30387..09e7d61b71db9 100644 --- a/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu +++ b/onnxruntime/contrib_ops/rocm/bert/multihead_attention.cu @@ -122,9 +122,8 @@ Status MultiHeadAttention::ComputeInternal(OpKernelContext* context) const { query, key, value, bias, key_padding_mask, relative_position_bias, past_key, past_value, past_seq_len, - &attn, - num_heads_, false, /*is_unidirectional_*/ - mask_filter_value_, scale_, + &attn, num_heads_, + mask_filter_value_, scale_, false, /*is_unidirectional_*/ past_present_share_buffer_, false, device_prop.maxThreadsPerBlock)); if (attn_type_ == kDecoderMaskedMultiHeadAttention && attn.sequence_length != 1) {