diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh index 0599318a4022d..be8508670e4b1 100644 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh @@ -31,7 +31,7 @@ using MaskingSpecialization = ck::tensor_operation::device::MaskingSpecializatio using PassThrough = ck::tensor_operation::element_wise::PassThrough; -using ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute; // the interface +using ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute; // the interface using ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle; // the implementation static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; @@ -141,6 +141,35 @@ std::vector, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>(); +template <> +std::vector, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>>> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>(); + +// fp16, biased, non-masked +template <> +std::vector, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>>> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>(); + +// fp16, biased, fp16 masked, basically, two bias +template <> +std::vector, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>>> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>(); + } // namespace internal } // namespace rocm } // namespace contrib diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu index 181e47f012c99..2e32a6594d164 100644 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16.cu @@ -32,6 +32,27 @@ GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< return instances; } +using NonBiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute< + 2, 1, 1, 1, 1, + F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>; + +template <> +std::vector> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() { + std::vector> instances; + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, + device_batched_gemm_softmax_gemm_permute_instances< + 2, 1, 1, 1, 1, + F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, + MaskingSpecialization::MaskOutUpperTriangle>{}); + + return instances; +} + } // namespace internal } // namespace rocm } // namespace contrib diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu index 1577bdf397fa5..91da8d9e1f9a8 100644 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased.cu @@ -32,6 +32,27 @@ GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< return instances; } +using BiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute< + 2, 1, 1, 1, 1, + F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>; + +template <> +std::vector> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() { + std::vector> instances; + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, + device_batched_gemm_softmax_gemm_permute_instances< + 2, 1, 1, 1, 1, + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, + MaskingSpecialization::MaskOutUpperTriangle>{}); + + return instances; +} + } // namespace internal } // namespace rocm } // namespace contrib diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu index 14de59234356b..b08123be18977 100644 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl_fp16_biased_biased.cu @@ -32,6 +32,27 @@ GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< return instances; } +using BiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute< + 2, 1, 1, 1, 1, + F16, F16, F16, F16, ck::Tuple, ck::Tuple<>, + PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough, + MaskingSpecialization::MaskOutUpperTriangle>; + +template <> +std::vector> +GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() { + std::vector> instances; + ck::tensor_operation::device::instance::add_device_operation_instances( + instances, + device_batched_gemm_softmax_gemm_permute_instances< + 2, 1, 1, 1, 1, + F16, ck::Tuple, F32, PreSoftmaxAttentionScoreOp, + MaskingSpecialization::MaskOutUpperTriangle>{}); + + return instances; +} + } // namespace internal } // namespace rocm } // namespace contrib diff --git a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh index 78983ac95e672..54dda4bfa6d2c 100644 --- a/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh +++ b/onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh @@ -732,122 +732,154 @@ class GemmSoftmaxGemmPermuteTunableOp : public tunable::TunableOp -auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() { +template +auto GetArgAndRunInvoker(const U& impl, const V& invoker, const GemmSoftmaxGemmPermuteParams* params) { constexpr const int kNumBiasBuffer = static_cast(USE_BIAS) + static_cast(USE_MASK); using Nop = ck::tensor_operation::element_wise::PassThrough; using Acc0ElementOp = internal::PreSoftmaxAttentionScoreOp; + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + !GemmSoftmaxGemmPermuteTunableOp::IsSupportedMode(params->attention), + "attention mode is not supported, got ", params->attention->mode); + if constexpr (USE_BIAS) { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->bias_buffer == nullptr, "biased version only support input with bias"); + } else { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->bias_buffer != nullptr, "non-biased version only support input without bias"); + } + if constexpr (USE_MASK) { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + !GemmSoftmaxGemmPermuteTunableOp::IsSupportedMaskType(params->attention), + "mask type is not supported, got ", params->attention->mask_type); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->mask_index_buffer == nullptr, "masked version only support input with mask"); + } else { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->mask_index_buffer != nullptr, "non-masked version only support input without mask"); + } + + auto attn = params->attention; + const int& G0 = attn->batch_size; + const int& G1 = attn->num_heads; + const int& M = attn->sequence_length; + const int& N = attn->total_sequence_length; + const int& K = attn->head_size; + const int& O = attn->v_head_size; + { + auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch(); + ORT_ENFORCE(M == m && N == n && K == k && O == o && G0 * G1 == batch, "semantic mismatch"); + } + + auto [qs, ks, vs] = GetQkvStrides(attn); + std::vector q_buffer_lengths = {G0, G1, M, K}; + std::vector q_buffer_strides = qs.template ForBNSHCoord>(); + std::vector k_buffer_lengths = {G0, G1, N, K}; + std::vector k_buffer_strides = ks.template ForBNSHCoord>(); + std::vector v_buffer_lengths = {G0, G1, O, N}; + std::vector v_buffer_strides = vs.template ForBNHSCoord>(); + std::vector out_buffer_lengths = {G0, G1, M, O}; + std::vector out_buffer_strides = {M * G1 * O, O, G1 * O, 1}; // permute 0213 + + std::array bias_buffers{}; + std::array, kNumBiasBuffer> bias_lengths{}; + std::array, kNumBiasBuffer> bias_strides{}; + if constexpr (USE_BIAS) { + bias_buffers[0] = const_cast(params->bias_buffer); + bias_lengths[0] = {G0, G1, M, N}; // BN(G0*G1), S(M), T(N) + bias_strides[0] = {G1 * M * N, M * N, N, 1}; + } + if constexpr (USE_MASK) { + bias_buffers[kNumBiasBuffer - 1] = params->workspace_buffer; + bias_lengths[kNumBiasBuffer - 1] = {G0, G1, M, N}; // BN(G0*G1), S(M), T(N) + if (params->mask_index_dims.size() == 2) { // [B,T] + bias_strides[kNumBiasBuffer - 1] = {N, 0, 0, 1}; + } else if (params->mask_index_dims.size() == 3) { // [B,S,T] + bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1}; + } else if (params->mask_index_dims.size() == 4) { // [B,1,max_seq_len,max_seq_len] -->convert--> [B,S,T] + bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1}; + } else { + ORT_ENFORCE(false, "Unreachable"); + } + } + + auto arg = impl->MakeArgumentPointer( + params->q_buffer, params->k_buffer, params->v_buffer, params->out_buffer, + bias_buffers, // Gemm1 bias, as attention mask + {}, // Gemm2 bias + q_buffer_lengths, q_buffer_strides, + k_buffer_lengths, k_buffer_strides, + v_buffer_lengths, v_buffer_strides, + out_buffer_lengths, out_buffer_strides, + bias_lengths, bias_strides, + {}, + {}, + Nop{}, + Nop{}, + Acc0ElementOp{params->scale}, + Nop{}, + Nop{}); + + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), + impl->GetTypeString(), " does not support the params"); + + if constexpr (USE_MASK) { + ORT_RETURN_IF_ERROR(GemmSoftmaxGemmPermuteTunableOp::LaunchConvertToFilledMaskValue(params)); + } + + invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); + return Status::OK(); +} + +template +auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() { using CKDataType = typename CKDataTypeAdaptor::type; using D0DataType = typename ck::detail::tuple_concat< std::conditional_t, ck::Tuple<>>, std::conditional_t, ck::Tuple<>>>::type; - constexpr static auto MaskingSpec = + constexpr static auto MaskingSpecMaskDisabled = ck::tensor_operation::device::MaskingSpecialization::MaskDisabled; + constexpr static auto MaskingSpecMaskOutUpperTriangle = + ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle; + + std::vector>>> + ret; - std::vector>>> ret; for (auto&& impl : internal::GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< - CKDataType, D0DataType, internal::F32, internal::PreSoftmaxAttentionScoreOp, MaskingSpec>()) { + CKDataType, D0DataType, internal::F32, internal::PreSoftmaxAttentionScoreOp, MaskingSpecMaskDisabled>()) { auto type_string = impl->GetTypeString(); auto invoker = impl->MakeInvokerPointer(); auto op = [impl = std::move(impl), invoker = std::move(invoker)]( const GemmSoftmaxGemmPermuteParams* params) -> Status { TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !GemmSoftmaxGemmPermuteTunableOp::IsSupportedMode(params->attention), - "attention mode is not supported, got ", params->attention->mode); - if constexpr (USE_BIAS) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->bias_buffer == nullptr, "biased version only support input with bias"); - } else { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->bias_buffer != nullptr, "non-biased version only support input without bias"); - } - if constexpr (USE_MASK) { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - !GemmSoftmaxGemmPermuteTunableOp::IsSupportedMaskType(params->attention), - "mask type is not supported, got ", params->attention->mask_type); - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->mask_index_buffer == nullptr, "masked version only support input with mask"); - } else { - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( - params->mask_index_buffer != nullptr, "non-masked version only support input without mask"); - } + params->attention->is_unidirectional, "unidirectional attention is not supported with MaskingSpecMaskDisabled"); - auto attn = params->attention; - const int& G0 = attn->batch_size; - const int& G1 = attn->num_heads; - const int& M = attn->sequence_length; - const int& N = attn->total_sequence_length; - const int& K = attn->head_size; - const int& O = attn->v_head_size; - { - auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch(); - ORT_ENFORCE(M == m && N == n && K == k && O == o && G0 * G1 == batch, "semantic mismatch"); - } + return GetArgAndRunInvoker(impl, invoker, params); + }; + ret.emplace_back(std::make_pair(std::move(type_string), std::move(op))); + } - auto [qs, ks, vs] = GetQkvStrides(attn); - std::vector q_buffer_lengths = {G0, G1, M, K}; - std::vector q_buffer_strides = qs.template ForBNSHCoord>(); - std::vector k_buffer_lengths = {G0, G1, N, K}; - std::vector k_buffer_strides = ks.template ForBNSHCoord>(); - std::vector v_buffer_lengths = {G0, G1, O, N}; - std::vector v_buffer_strides = vs.template ForBNHSCoord>(); - std::vector out_buffer_lengths = {G0, G1, M, O}; - std::vector out_buffer_strides = {M * G1 * O, O, G1 * O, 1}; // permute 0213 - - std::array bias_buffers{}; - std::array, kNumBiasBuffer> bias_lengths{}; - std::array, kNumBiasBuffer> bias_strides{}; - if constexpr (USE_BIAS) { - bias_buffers[0] = const_cast(params->bias_buffer); - bias_lengths[0] = {G0, G1, M, N}; // BN(G0*G1), S(M), T(N) - bias_strides[0] = {G1 * M * N, M * N, N, 1}; - } - if constexpr (USE_MASK) { - bias_buffers[kNumBiasBuffer - 1] = params->workspace_buffer; - bias_lengths[kNumBiasBuffer - 1] = {G0, G1, M, N}; // BN(G0*G1), S(M), T(N) - if (params->mask_index_dims.size() == 2) { // [B,T] - bias_strides[kNumBiasBuffer - 1] = {N, 0, 0, 1}; - } else if (params->mask_index_dims.size() == 3) { // [B,S,T] - bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1}; - } else if (params->mask_index_dims.size() == 4) { // [B,1,max_seq_len,max_seq_len] -->convert--> [B,S,T] - bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1}; - } else { - ORT_ENFORCE(false, "Unreachable"); - } - } + for (auto&& impl : internal::GetDeviceBatchedGemmSoftmaxGemmPermuteInstances< + CKDataType, D0DataType, internal::F32, internal::PreSoftmaxAttentionScoreOp, MaskingSpecMaskOutUpperTriangle>()) { + auto type_string = impl->GetTypeString(); - auto arg = impl->MakeArgumentPointer( - params->q_buffer, params->k_buffer, params->v_buffer, params->out_buffer, - bias_buffers, // Gemm1 bias, as attention mask - {}, // Gemm2 bias - q_buffer_lengths, q_buffer_strides, - k_buffer_lengths, k_buffer_strides, - v_buffer_lengths, v_buffer_strides, - out_buffer_lengths, out_buffer_strides, - bias_lengths, bias_strides, - {}, - {}, - Nop{}, - Nop{}, - Acc0ElementOp{params->scale}, - Nop{}, - Nop{}); - - TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()), - impl->GetTypeString(), " does not support the params"); - - if constexpr (USE_MASK) { - ORT_RETURN_IF_ERROR(GemmSoftmaxGemmPermuteTunableOp::LaunchConvertToFilledMaskValue(params)); - } - invoker->Run(arg.get(), StreamConfig{params->StreamHandle()}); - return Status::OK(); + auto invoker = impl->MakeInvokerPointer(); + auto op = [impl = std::move(impl), invoker = std::move(invoker)]( + const GemmSoftmaxGemmPermuteParams* params) -> Status { + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + !params->attention->is_unidirectional, "bidirectional attention is not supported with MaskingSpecMaskOutUpperTriangle"); + TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF( + params->attention->sequence_length != params->attention->total_sequence_length, + "seqence_length != total_seqence_length is not supported with MaskingSpecMaskOutUpperTriangle"); + + return GetArgAndRunInvoker(impl, invoker, params); }; ret.emplace_back(std::make_pair(std::move(type_string), std::move(op))); } + return ret; } #endif // USE_COMPOSABLE_KERNEL diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_softmax_gemm_permute_test.py b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_softmax_gemm_permute_test.py index 6e1e431842a56..802d924c27b62 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/gemm_softmax_gemm_permute_test.py +++ b/onnxruntime/python/tools/kernel_explorer/kernels/gemm_softmax_gemm_permute_test.py @@ -44,6 +44,7 @@ def get_ck_binding_name(dtype, biased: bool, masked: bool): num_heads = [8, 12] head_sizes = [64] biaseds = [False, True] +causals = [False] mask_dims = [0, 2, 3, 4] @@ -81,8 +82,57 @@ def maybe_pack_q_k_v_bnsh_for_device_on_host(q, k, v, dtype, qkv_format): raise NotImplementedError +def _make_causal_mask( + seqence_length, + total_sequence_length, + dtype: np.dtype, +): + """ + Make causal mask used for Attention with attribute unidirectional == 1. + The mask is a upper triangular matrix with shape [sequence_length, total_sequence_length]. + Putting a 1 indicates that the token at this position should be masked. + For Example: + sequence_length = 5, total_sequence_length = 5, + mask: [[0. 1. 1. 1. 1.] + [0. 0. 1. 1. 1.] + [0. 0. 0. 1. 1.] + [0. 0. 0. 0. 1.] + [0. 0. 0. 0. 0.]] + seqence_length = 5, total_seqence_length = 3, + mask: [[1. 1. 1.] + [1. 1. 1.] + [0. 1. 1.] + [0. 0. 1.] + [0. 0. 0.]] + seqence_length = 5, total_seqence_length = 7, + mask: [[0. 0. 0. 1. 1. 1. 1.] + [0. 0. 0. 0. 1. 1. 1.] + [0. 0. 0. 0. 0. 1. 1.] + [0. 0. 0. 0. 0. 0. 1.] + [0. 0. 0. 0. 0. 0. 0.]] + """ + mask = np.full((seqence_length, seqence_length), 1) + mask_cond = np.arange(mask.shape[-1]) + mask = np.where(mask_cond < (mask_cond + 1).reshape(mask.shape[-1], 1), 0, mask) + + mask = mask.astype(dtype) + + if total_sequence_length - seqence_length > 0: + mask = np.concatenate( + [np.zeros((seqence_length, total_sequence_length - seqence_length), dtype=dtype), mask], axis=-1 + ) + + if total_sequence_length - seqence_length < 0: + mask = mask[:, -total_sequence_length:] + + correct_mask = np.full((seqence_length, total_sequence_length), 1) + for i in range(seqence_length): + correct_mask[i][:] = sum(mask[i]) != total_sequence_length + return mask, correct_mask + + def _test_gemm_softmax_gemm_permute( - f, dtype, batch, seqlen, total_seqlen, num_heads, head_size, biased, mask_dim, scale, qkv_format + f, dtype, batch, seqlen, total_seqlen, num_heads, head_size, biased, mask_dim, scale, causal, qkv_format ): v_head_size = head_size q_shape = [batch, num_heads, seqlen, head_size] @@ -123,6 +173,8 @@ def _test_gemm_softmax_gemm_permute( pre_softmax_attn_scores = pre_softmax_attn_scores * scale if attn_bias is not None: pre_softmax_attn_scores = pre_softmax_attn_scores + attn_bias + + correct_causal_mask = np.full((seqlen, total_seqlen), 1) if attn_mask is not None: filter_value = -10000.0 if mask_dim == 4: @@ -131,7 +183,18 @@ def _test_gemm_softmax_gemm_permute( else: converted_mask = (1 - attn_mask.reshape(mask_shape_broadcasted)) * filter_value pre_softmax_attn_scores = pre_softmax_attn_scores + converted_mask + if causal: + filter_value = np.finfo(dtype).min + causal_mask, correct_causal_mask = _make_causal_mask(seqlen, total_seqlen, pre_softmax_attn_scores.dtype) + causal_mask = np.broadcast_to(causal_mask, pre_softmax_attn_scores.shape) * filter_value + pre_softmax_attn_scores = pre_softmax_attn_scores + causal_mask attn_scores = softmax(pre_softmax_attn_scores, axis=-1) + + # apply mask to attn_scores to correct softmax result, in c++ implementation, if all values in a row are masked, + # the softmax result in this row will be filled with 0. + correct_causal_mask = np.broadcast_to(correct_causal_mask, pre_softmax_attn_scores.shape) + attn_scores = attn_scores * correct_causal_mask + attn = matmul(attn_scores, v) ref = np.swapaxes(attn, 2, 1) # permute 0213 @@ -154,6 +217,7 @@ def _test_gemm_softmax_gemm_permute( head_size, mask_dim, scale, + causal, qkv_format, dev_q, dev_k, @@ -202,12 +266,26 @@ def _test_gemm_softmax_gemm_permute( @pytest.mark.parametrize("total_seqlen", total_seqlens) @pytest.mark.parametrize("seqlen", seqlens) @pytest.mark.parametrize("batch", [16]) +@pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("dtype", ["float16", "float32"]) -def test_gemm_softmax_gemm_permute_generic(dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, mask_dim): +def test_gemm_softmax_gemm_permute_generic( + dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, causal, mask_dim +): f = getattr(ke, "GemmSoftmaxGemmPermuteGeneric_" + dtype_to_suffix(dtype)) scale = 1.0 / np.sqrt(head_size) _test_gemm_softmax_gemm_permute( - f, dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, mask_dim, scale, ke.qkv_format.Q_K_V_BNSH + f, + dtype, + batch, + seqlen, + total_seqlen, + nhead, + head_size, + biased, + mask_dim, + scale, + causal, + ke.qkv_format.Q_K_V_BNSH, ) @@ -218,14 +296,26 @@ def test_gemm_softmax_gemm_permute_generic(dtype, batch, seqlen, total_seqlen, n @pytest.mark.parametrize("total_seqlen", [128]) @pytest.mark.parametrize("seqlen", [64]) @pytest.mark.parametrize("batch", [16]) +@pytest.mark.parametrize("causal", [True, False]) @pytest.mark.parametrize("dtype", ["float16", "float32"]) def test_gemm_softmax_gemm_permute_generic_nested_tunable( - dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, mask_dim + dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, causal, mask_dim ): f = getattr(ke, "GemmSoftmaxGemmPermuteGenericNestedTunable_" + dtype_to_suffix(dtype)) scale = 1.0 / np.sqrt(head_size) _test_gemm_softmax_gemm_permute( - f, dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, mask_dim, scale, ke.qkv_format.Q_K_V_BNSH + f, + dtype, + batch, + seqlen, + total_seqlen, + nhead, + head_size, + biased, + mask_dim, + scale, + causal, + ke.qkv_format.Q_K_V_BNSH, ) @@ -237,12 +327,24 @@ def test_gemm_softmax_gemm_permute_generic_nested_tunable( @pytest.mark.parametrize("total_seqlen", total_seqlens) @pytest.mark.parametrize("seqlen", seqlens) @pytest.mark.parametrize("batch", batches) +@pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("dtype", dtypes) -def test_gemm_softmax_gemm_permute_ck(dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, mask_dim): +def test_gemm_softmax_gemm_permute_ck(dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, causal, mask_dim): f = getattr(ke, get_ck_binding_name(dtype, biased, mask_dim != 0)) scale = 1.0 / np.sqrt(head_size) _test_gemm_softmax_gemm_permute( - f, dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, mask_dim, scale, ke.qkv_format.Q_K_V_BNSH + f, + dtype, + batch, + seqlen, + total_seqlen, + nhead, + head_size, + biased, + mask_dim, + scale, + causal, + ke.qkv_format.Q_K_V_BNSH, ) @@ -253,12 +355,26 @@ def test_gemm_softmax_gemm_permute_ck(dtype, batch, seqlen, total_seqlen, nhead, @pytest.mark.parametrize("total_seqlen", [128]) @pytest.mark.parametrize("seqlen", [64]) @pytest.mark.parametrize("batch", [16]) +@pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("dtype", ["float16"]) -def test_gemm_softmax_gemm_permute_tunable(dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, mask_dim): +def test_gemm_softmax_gemm_permute_tunable( + dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, causal, mask_dim +): f = getattr(ke, "GemmSoftmaxGemmPermuteTunable_" + dtype_to_suffix(dtype)) scale = 1.0 / np.sqrt(head_size) _test_gemm_softmax_gemm_permute( - f, dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, mask_dim, scale, ke.qkv_format.Q_K_V_BNSH + f, + dtype, + batch, + seqlen, + total_seqlen, + nhead, + head_size, + biased, + mask_dim, + scale, + causal, + ke.qkv_format.Q_K_V_BNSH, ) @@ -278,16 +394,17 @@ def test_gemm_softmax_gemm_permute_tunable(dtype, batch, seqlen, total_seqlen, n @pytest.mark.skipif(not ke.is_composable_kernel_available(), reason="ck is not enabled") @pytest.mark.parametrize("mask_dim", [0], ids=get_mask_dim_id) @pytest.mark.parametrize("biased", [False], ids=get_biased_id) +@pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("batch, seqlen, total_seqlen, nhead, head_size, qkv_format_name", stabel_diffusion_configs) @pytest.mark.parametrize("dtype", dtypes) def test_gemm_softmax_gemm_permute_ck_sd( - dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, mask_dim, qkv_format_name + dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, causal, mask_dim, qkv_format_name ): qkv_format = getattr(ke.qkv_format, qkv_format_name) f = getattr(ke, get_ck_binding_name(dtype, biased, mask_dim != 0)) scale = 1.0 / np.sqrt(head_size) _test_gemm_softmax_gemm_permute( - f, dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, mask_dim, scale, qkv_format + f, dtype, batch, seqlen, total_seqlen, nhead, head_size, biased, mask_dim, scale, causal, qkv_format ) @@ -316,7 +433,7 @@ def report(self): def profile_gemm_softmax_gemm_permute_func( - f, dtype, batch, seqlen, total_seqlen, num_heads, head_size, biased, mask_dim, scale, qkv_format + f, dtype, batch, seqlen, total_seqlen, num_heads, head_size, biased, mask_dim, scale, causal, qkv_format ): v_head_size = head_size q_shape = [batch, num_heads, seqlen, head_size] @@ -369,6 +486,7 @@ def profile_gemm_softmax_gemm_permute_func( head_size, mask_dim, scale, + causal, qkv_format, dev_q, dev_k, @@ -402,10 +520,10 @@ def profile_gemm_softmax_gemm_permute_func( def profile_with_args( - dtype, batch, seqlen, total_seqlen, num_heads, head_size, biased, mask_dim, scale, qkv_format, *, sort=False + dtype, batch, seqlen, total_seqlen, num_heads, head_size, biased, causal, mask_dim, scale, qkv_format, *, sort=False ): with ke.benchmark(sort): - args = (dtype, batch, seqlen, total_seqlen, num_heads, head_size, biased, mask_dim, scale, qkv_format) + args = (dtype, batch, seqlen, total_seqlen, num_heads, head_size, biased, mask_dim, scale, causal, qkv_format) if qkv_format == ke.qkv_format.Q_K_V_BNSH: profile_gemm_softmax_gemm_permute_func( getattr(ke, "GemmSoftmaxGemmPermuteGeneric_" + dtype_to_suffix(dtype)), *args @@ -429,6 +547,7 @@ def profile(): nhead, head_size, biased=False, + causal=False, mask_dim=0, qkv_format=getattr(ke.qkv_format, qkv_format_name), scale=0.125, @@ -436,7 +555,7 @@ def profile(): ) print() - for args in product(dtypes, batches, seqlens, total_seqlens, num_heads, head_sizes, biaseds, mask_dims): + for args in product(dtypes, batches, seqlens, total_seqlens, num_heads, head_sizes, biaseds, causals, mask_dims): profile_with_args(*args, qkv_format=ke.qkv_format.Q_K_V_BNSH, scale=0.125, sort=True) print() @@ -455,6 +574,7 @@ def profile(): group.add_argument("head_size", type=int) group.add_argument("biased", type=int, choices=[0, 1], default=0) group.add_argument("mask_dim", type=int, choices=[0, 2, 3, 4], default=2, help="0 for mask disabled") + group.add_argument("causal", type=int, choices=[0, 1], default=0) group.add_argument("--scale", type=float, default=None, help="default to 1.0/sqrt(head_size)") group.add_argument( "--qkv_format", @@ -471,6 +591,7 @@ def profile(): profile() else: args = parser.parse_args() + print(args) profile_with_args( args.dtype, args.batch, @@ -479,6 +600,7 @@ def profile(): args.num_heads, args.head_size, args.biased, + args.causal, args.mask_dim, args.scale, getattr(ke.qkv_format, args.qkv_format), diff --git a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_softmax_gemm_permute.cu b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_softmax_gemm_permute.cu index 5e60bad776d4a..7068fc8fd0ebc 100644 --- a/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_softmax_gemm_permute.cu +++ b/onnxruntime/python/tools/kernel_explorer/kernels/rocm/gemm_softmax_gemm_permute.cu @@ -28,6 +28,7 @@ class IGemmSoftmaxGemmPermuteKernelExplorer : public IKernelExplorer { int64_t head_size, int64_t mask_dim, double scale, + bool causal, contrib::AttentionQkvFormat qkv_format, DeviceArray& Q, std::optional& K, @@ -51,7 +52,7 @@ class IGemmSoftmaxGemmPermuteKernelExplorer : public IKernelExplorer { attn_.v_hidden_size = attn_.hidden_size; // Q,K,V hidden size must agree now attn_.v_head_size = attn_.head_size; // Q,K,V hidden size must agree now attn_.num_heads = num_heads; - attn_.is_unidirectional = false; + attn_.is_unidirectional = causal; attn_.past_present_share_buffer = false; attn_.do_rotary = false; attn_.mask_filter_value = -10000.0f; @@ -148,6 +149,7 @@ class GemmSoftmaxGemmPermuteGeneric : public IGemmSoftmaxGemmPermuteKernelExplor int64_t head_size, int64_t mask_dim, double scale, + bool causal, contrib::AttentionQkvFormat qkv_format, DeviceArray& Q, std::optional& K, @@ -156,7 +158,7 @@ class GemmSoftmaxGemmPermuteGeneric : public IGemmSoftmaxGemmPermuteKernelExplor std::optional& attn_mask, DeviceArray& out) : IGemmSoftmaxGemmPermuteKernelExplorer(batch, seqlen, total_seqlen, max_seqlen, - num_heads, head_size, mask_dim, scale, qkv_format, + num_heads, head_size, mask_dim, scale, causal, qkv_format, Q, K, V, attn_bias, attn_mask, out) { this->SetWorkspace(GemmSoftmaxGemmPermuteGenericPipeline::GetWorkspaceNumBytes(&this->attn_)); } @@ -187,6 +189,7 @@ class GemmSoftmaxGemmPermuteGenericNestedTunable : public GemmSoftmaxGemmPermute int64_t head_size, int64_t mask_dim, double scale, + bool causal, contrib::AttentionQkvFormat qkv_format, DeviceArray& Q, std::optional& K, @@ -195,7 +198,7 @@ class GemmSoftmaxGemmPermuteGenericNestedTunable : public GemmSoftmaxGemmPermute std::optional& attn_mask, DeviceArray& out) : GemmSoftmaxGemmPermuteGeneric(batch, seqlen, total_seqlen, max_seqlen, - num_heads, head_size, mask_dim, scale, qkv_format, + num_heads, head_size, mask_dim, scale, causal, qkv_format, Q, K, V, attn_bias, attn_mask, out) { this->params_.TuningContext()->EnableTunableOpAndTuning(); } @@ -214,6 +217,7 @@ class GemmSoftmaxGemmPermuteCK : public IGemmSoftmaxGemmPermuteKernelExplorer int64_t head_size, int64_t mask_dim, double scale, + bool causal, contrib::AttentionQkvFormat qkv_format, DeviceArray& Q, std::optional& K, @@ -222,7 +226,7 @@ class GemmSoftmaxGemmPermuteCK : public IGemmSoftmaxGemmPermuteKernelExplorer std::optional& attn_mask, DeviceArray& out) : IGemmSoftmaxGemmPermuteKernelExplorer(batch, seqlen, total_seqlen, max_seqlen, - num_heads, head_size, mask_dim, scale, qkv_format, + num_heads, head_size, mask_dim, scale, causal, qkv_format, Q, K, V, attn_bias, attn_mask, out) { this->SetWorkspace(GemmSoftmaxGemmPermuteTunableOp::GetWorkspaceNumBytes(&this->attn_)); @@ -275,6 +279,7 @@ class GemmSoftmaxGemmPermuteTunable : public IGemmSoftmaxGemmPermuteKernelExplor int64_t head_size, int64_t mask_dim, double scale, + bool causal, contrib::AttentionQkvFormat qkv_format, DeviceArray& Q, std::optional& K, @@ -283,7 +288,7 @@ class GemmSoftmaxGemmPermuteTunable : public IGemmSoftmaxGemmPermuteKernelExplor std::optional& attn_mask, DeviceArray& out) : IGemmSoftmaxGemmPermuteKernelExplorer(batch, seqlen, total_seqlen, max_seqlen, - num_heads, head_size, mask_dim, scale, qkv_format, + num_heads, head_size, mask_dim, scale, causal, qkv_format, Q, K, V, attn_bias, attn_mask, out) { this->SetWorkspace(std::max( GemmSoftmaxGemmPermuteGenericPipeline::GetWorkspaceNumBytes(&this->attn_), @@ -311,7 +316,7 @@ class GemmSoftmaxGemmPermuteTunable : public IGemmSoftmaxGemmPermuteKernelExplor #define REGISTER_COMMON(name, type, ...) \ py::class_>(m, name) \ .def(py::init, int64_t, int64_t, int64_t, \ - float, contrib::AttentionQkvFormat, \ + float, bool, contrib::AttentionQkvFormat, \ DeviceArray&, \ std::optional&, \ std::optional&, \