From b7cb71b1a125e39e1daadd5dab698a0e13607993 Mon Sep 17 00:00:00 2001 From: peixuanzuo Date: Wed, 27 Dec 2023 08:33:29 +0000 Subject: [PATCH] update --- ...ed_gemm_softmax_gemm_permute_pipelines.cuh | 27 +++++++++++-------- .../kernels/gemm_softmax_gemm_permute_test.py | 3 +-- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh index 3b6cfb6904491..691463ebd2259 100644 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh @@ -733,7 +733,7 @@ class GemmSoftmaxGemmPermuteTunableOp : public tunable::TunableOp -auto GetArg(const U& impl, const V& invoker, const GemmSoftmaxGemmPermuteParams* params) { +auto GetArgAndRunInvoker(const U& impl, const V& invoker, const GemmSoftmaxGemmPermuteParams* params) { using Nop = ck::tensor_operation::element_wise::PassThrough; using Acc0ElementOp = internal::PreSoftmaxAttentionScoreOp; @@ -857,7 +857,7 @@ auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( params->attention->is_unidirectional, "unidirectional attention is not supported with MaskingSpecMaskDisabled"); - return GetArg(impl, invoker, params); + return GetArgAndRunInvoker(impl, invoker, params); }; ret.emplace_back(std::make_pair(std::move(type_string), std::move(op))); } @@ -875,7 +875,7 @@ auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() { params->attention->sequence_length != params->attention->total_sequence_length, "seqence_length != total_seqence_length is not supported with MaskingSpecMaskOutUpperTriangle"); - return GetArg(impl, invoker, params); + return GetArgAndRunInvoker(impl, invoker, params); }; ret.emplace_back(std::make_pair(std::move(type_string), std::move(op))); } @@ -890,17 +890,22 @@ GemmSoftmaxGemmPermuteTunableOp::GemmSoftmaxGemmPermuteTunableOp() { return GemmSoftmaxGemmPermuteGenericPipeline::Run(params, false); }); -#define GET_CK_IMPLEMENTATION(USE_BIAS, USE_MASK) \ - for (auto&& [_, op] : GetCKGemmSoftmaxGemmPermuteTypeStringAndOps()) { \ - this->RegisterOp(std::move(op)); \ +#ifdef USE_COMPOSABLE_KERNEL + for (auto&& [_, op] : GetCKGemmSoftmaxGemmPermuteTypeStringAndOps()) { + this->RegisterOp(std::move(op)); } -#ifdef USE_COMPOSABLE_KERNEL - GET_CK_IMPLEMENTATION(false, false); - GET_CK_IMPLEMENTATION(true, false); - GET_CK_IMPLEMENTATION(false, true); - GET_CK_IMPLEMENTATION(true, true); + for (auto&& [_, op] : GetCKGemmSoftmaxGemmPermuteTypeStringAndOps()) { + this->RegisterOp(std::move(op)); + } + for (auto&& [_, op] : GetCKGemmSoftmaxGemmPermuteTypeStringAndOps()) { + this->RegisterOp(std::move(op)); + } + + for (auto&& [_, op] : GetCKGemmSoftmaxGemmPermuteTypeStringAndOps()) { + this->RegisterOp(std::move(op)); + } #endif } diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_softmax_gemm_permute_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_softmax_gemm_permute_test.py index 2e88eff33ee1e..2b506d95adc86 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_softmax_gemm_permute_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_softmax_gemm_permute_test.py @@ -172,8 +172,7 @@ def _test_gemm_softmax_gemm_permute( pre_softmax_attn_scores = pre_softmax_attn_scores * scale if attn_bias is not None: pre_softmax_attn_scores = pre_softmax_attn_scores + attn_bias - converted_mask = None - causal_mask = None + correct_causal_mask = np.full((seqlen, total_seqlen), 1) if attn_mask is not None: filter_value = -10000.0