diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh
index 3b6cfb6904491..691463ebd2259 100644
--- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh
+++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh
@@ -733,7 +733,7 @@ class GemmSoftmaxGemmPermuteTunableOp : public tunable::TunableOp<GemmSoftmaxGemmPermuteParams<T>> {
 template <typename T, bool USE_BIAS, bool USE_MASK, typename U, typename V>
-auto GetArg(const U& impl, const V& invoker, const GemmSoftmaxGemmPermuteParams<T>* params) {
+auto GetArgAndRunInvoker(const U& impl, const V& invoker, const GemmSoftmaxGemmPermuteParams<T>* params) {
   using Nop = ck::tensor_operation::element_wise::PassThrough;
   using Acc0ElementOp = internal::PreSoftmaxAttentionScoreOp;
@@ -857,7 +857,7 @@ auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() {
       TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
           params->attention->is_unidirectional, "unidirectional attention is not supported with MaskingSpecMaskDisabled");
-      return GetArg<T, USE_BIAS, USE_MASK>(impl, invoker, params);
+      return GetArgAndRunInvoker<T, USE_BIAS, USE_MASK>(impl, invoker, params);
     };
     ret.emplace_back(std::make_pair(std::move(type_string), std::move(op)));
   }
@@ -875,7 +875,7 @@ auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() {
       TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
           params->attention->sequence_length != params->attention->total_sequence_length,
           "seqence_length != total_seqence_length is not supported with MaskingSpecMaskOutUpperTriangle");
-      return GetArg<T, USE_BIAS, USE_MASK>(impl, invoker, params);
+      return GetArgAndRunInvoker<T, USE_BIAS, USE_MASK>(impl, invoker, params);
     };
     ret.emplace_back(std::make_pair(std::move(type_string), std::move(op)));
   }
@@ -890,17 +890,22 @@ GemmSoftmaxGemmPermuteTunableOp<T>::GemmSoftmaxGemmPermuteTunableOp() {
     return GemmSoftmaxGemmPermuteGenericPipeline<T>::Run(params, false);
   });
 
-#define GET_CK_IMPLEMENTATION(USE_BIAS, USE_MASK)                                               \
-  for (auto&& [_, op] : GetCKGemmSoftmaxGemmPermuteTypeStringAndOps<T, USE_BIAS, USE_MASK>()) { \
-    this->RegisterOp(std::move(op));                                                            \
+#ifdef USE_COMPOSABLE_KERNEL
+  for (auto&& [_, op] : GetCKGemmSoftmaxGemmPermuteTypeStringAndOps<T, /*USE_BIAS=*/false, /*USE_MASK=*/false>()) {
+    this->RegisterOp(std::move(op));
   }
-#ifdef USE_COMPOSABLE_KERNEL
-  GET_CK_IMPLEMENTATION(false, false);
-  GET_CK_IMPLEMENTATION(true, false);
-  GET_CK_IMPLEMENTATION(false, true);
-  GET_CK_IMPLEMENTATION(true, true);
+  for (auto&& [_, op] : GetCKGemmSoftmaxGemmPermuteTypeStringAndOps<T, /*USE_BIAS=*/true, /*USE_MASK=*/false>()) {
+    this->RegisterOp(std::move(op));
+  }
+  for (auto&& [_, op] : GetCKGemmSoftmaxGemmPermuteTypeStringAndOps<T, /*USE_BIAS=*/false, /*USE_MASK=*/true>()) {
+    this->RegisterOp(std::move(op));
+  }
+
+  for (auto&& [_, op] : GetCKGemmSoftmaxGemmPermuteTypeStringAndOps<T, /*USE_BIAS=*/true, /*USE_MASK=*/true>()) {
+    this->RegisterOp(std::move(op));
+  }
 #endif
 }
diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_softmax_gemm_permute_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_softmax_gemm_permute_test.py
index 2e88eff33ee1e..2b506d95adc86 100644
--- a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_softmax_gemm_permute_test.py
+++ b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_softmax_gemm_permute_test.py
@@ -172,8 +172,7 @@ def _test_gemm_softmax_gemm_permute(
     pre_softmax_attn_scores = pre_softmax_attn_scores * scale
     if attn_bias is not None:
         pre_softmax_attn_scores = pre_softmax_attn_scores + attn_bias
-    converted_mask = None
-    causal_mask = None
+    correct_causal_mask = np.full((seqlen, total_seqlen), 1)
     if attn_mask is not None:
         filter_value = -10000.0
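
Note on the test hunk: the reference path now starts `correct_causal_mask` as an all-ones `(seqlen, total_seqlen)` grid rather than a `None` placeholder, with the causal restriction presumably applied further down in the test. For readers unfamiliar with the right-aligned causal convention this kind of reference typically encodes, here is a minimal NumPy sketch; it is an illustration under that assumption only, and `causal_keep` / `additive_mask` are hypothetical names, not identifiers from the test: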
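    import numpy as np

    seqlen, total_seqlen = 4, 6  # query length vs. past + present key length
    filter_value = -10000.0  # same large-negative constant visible in the hunk

    # Right-aligned causal rule (assumed): query i may attend to key j when
    # j <= i + (total_seqlen - seqlen), i.e. all past keys plus itself.
    causal_keep = np.tril(np.ones((seqlen, total_seqlen)), k=total_seqlen - seqlen)
    additive_mask = (1 - causal_keep) * filter_value  # added to pre-softmax scores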
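Adding `filter_value` to the masked-out positions before the softmax drives their attention weights toward zero, which is the same effect the CK `MaskingSpecMaskOutUpperTriangle` specialization produces in-kernel.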