[ROCm] CK implementation support causal mask #18943

Merged (5 commits, Feb 2, 2024)

Changes from all commits
onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_ck_impl/impl.cuh

@@ -31,7 +31,7 @@

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute;              // the interface
using ck::tensor_operation::device::DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle; // the implementation

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
@@ -141,6 +141,35 @@
GetDeviceBatchedGemmSoftmaxGemmPermuteInstances<
F16, ck::Tuple<F16, F16>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskDisabled>();

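// fp16, non-biased, non-masked (causal)
// Template arguments: dimension counts <NumDimG=2, NumDimM=NumDimN=NumDimK=NumDimO=1>,
// the Q/K/V/output data types, the Gemm0/Gemm1 bias type tuples, the five
// element-wise ops, and the masking specialization.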
template <>
std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
2, 1, 1, 1, 1,
F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>,
PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>
GetDeviceBatchedGemmSoftmaxGemmPermuteInstances<
F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>();

// fp16, biased, non-masked
template <>
std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
2, 1, 1, 1, 1,
F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>,
PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>
GetDeviceBatchedGemmSoftmaxGemmPermuteInstances<
F16, ck::Tuple<F16>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>();

// fp16, biased, fp16 masked; effectively two bias tensors
template <>
std::vector<std::unique_ptr<DeviceBatchedGemmSoftmaxGemmPermute<
2, 1, 1, 1, 1,
F16, F16, F16, F16, ck::Tuple<F16, F16>, ck::Tuple<>,
PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>>>
GetDeviceBatchedGemmSoftmaxGemmPermuteInstances<
F16, ck::Tuple<F16, F16>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>();

} // namespace internal
} // namespace rocm
} // namespace contrib
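For orientation, a minimal sketch of a hypothetical driver (not part of this PR; the outer namespace qualification is assumed, and the F16/F32/PreSoftmaxAttentionScoreOp aliases are taken from the internal namespace shown above) that enumerates the causal instance list declared above:

#include <iostream>

namespace ort_rocm = onnxruntime::contrib::rocm::internal;  // assumed outer namespace

void ListCausalInstances() {  // hypothetical helper
  // Each element is one tiling/pipeline variant of the fused
  // Gemm0 -> softmax -> Gemm1 -> permute kernel with the causal
  // (upper-triangle) mask compiled in.
  auto instances = ort_rocm::GetDeviceBatchedGemmSoftmaxGemmPermuteInstances<
      ort_rocm::F16, ck::Tuple<>, ort_rocm::F32,
      ort_rocm::PreSoftmaxAttentionScoreOp,
      ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle>();
  for (auto& impl : instances) {
    std::cout << impl->GetTypeString() << "\n";  // the CK instance's tuning signature
  }
}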
@@ -32,6 +32,27 @@ GetDeviceBatchedGemmSoftmaxGemmPermuteInstances<
return instances;
}

using NonBiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute<
2, 1, 1, 1, 1,
F16, F16, F16, F16, ck::Tuple<>, ck::Tuple<>,
PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>;

template <>
std::vector<std::unique_ptr<NonBiasedNonmaskedCausal>>
GetDeviceBatchedGemmSoftmaxGemmPermuteInstances<
F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() {
std::vector<std::unique_ptr<NonBiasedNonmaskedCausal>> instances;
ck::tensor_operation::device::instance::add_device_operation_instances(
instances,
device_batched_gemm_softmax_gemm_permute_instances<
2, 1, 1, 1, 1,
F16, ck::Tuple<>, F32, PreSoftmaxAttentionScoreOp,
MaskingSpecialization::MaskOutUpperTriangle>{});

return instances;
}

} // namespace internal
} // namespace rocm
} // namespace contrib
@@ -32,6 +32,27 @@ GetDeviceBatchedGemmSoftmaxGemmPermuteInstances<
return instances;
}

using BiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute<
2, 1, 1, 1, 1,
F16, F16, F16, F16, ck::Tuple<F16>, ck::Tuple<>,
PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>;

template <>
std::vector<std::unique_ptr<BiasedNonmaskedCausal>>
GetDeviceBatchedGemmSoftmaxGemmPermuteInstances<
F16, ck::Tuple<F16>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() {
std::vector<std::unique_ptr<BiasedNonmaskedCausal>> instances;
ck::tensor_operation::device::instance::add_device_operation_instances(
instances,
device_batched_gemm_softmax_gemm_permute_instances<
2, 1, 1, 1, 1,
F16, ck::Tuple<F16>, F32, PreSoftmaxAttentionScoreOp,
MaskingSpecialization::MaskOutUpperTriangle>{});

return instances;
}

} // namespace internal
} // namespace rocm
} // namespace contrib
@@ -32,6 +32,27 @@ GetDeviceBatchedGemmSoftmaxGemmPermuteInstances<
return instances;
}
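
// Despite the reused alias name, this variant carries two Gemm0 bias tensors:
// the attention bias plus the converted fp16 mask.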

using BiasedNonmaskedCausal = DeviceBatchedGemmSoftmaxGemmPermute<
2, 1, 1, 1, 1,
F16, F16, F16, F16, ck::Tuple<F16, F16>, ck::Tuple<>,
PassThrough, PassThrough, PreSoftmaxAttentionScoreOp, PassThrough, PassThrough,
MaskingSpecialization::MaskOutUpperTriangle>;

template <>
std::vector<std::unique_ptr<BiasedNonmaskedCausal>>
GetDeviceBatchedGemmSoftmaxGemmPermuteInstances<
F16, ck::Tuple<F16, F16>, F32, PreSoftmaxAttentionScoreOp, MaskingSpecialization::MaskOutUpperTriangle>() {
std::vector<std::unique_ptr<BiasedNonmaskedCausal>> instances;
ck::tensor_operation::device::instance::add_device_operation_instances(
instances,
device_batched_gemm_softmax_gemm_permute_instances<
2, 1, 1, 1, 1,
F16, ck::Tuple<F16, F16>, F32, PreSoftmaxAttentionScoreOp,
MaskingSpecialization::MaskOutUpperTriangle>{});

return instances;
}

} // namespace internal
} // namespace rocm
} // namespace contrib
onnxruntime/contrib_ops/rocm/bert/batched_gemm_softmax_gemm_permute_pipelines.cuh

@@ -732,122 +732,154 @@

#ifdef USE_COMPOSABLE_KERNEL

template <typename T, bool USE_BIAS, bool USE_MASK>
auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() {
template <typename U, typename V, typename T, bool USE_BIAS, bool USE_MASK>
auto GetArgAndRunInvoker(const U& impl, const V& invoker, const GemmSoftmaxGemmPermuteParams<T>* params) {
constexpr const int kNumBiasBuffer = static_cast<int>(USE_BIAS) + static_cast<int>(USE_MASK);

using Nop = ck::tensor_operation::element_wise::PassThrough;
using Acc0ElementOp = internal::PreSoftmaxAttentionScoreOp;

TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
!GemmSoftmaxGemmPermuteTunableOp<T>::IsSupportedMode(params->attention),
"attention mode is not supported, got ", params->attention->mode);
if constexpr (USE_BIAS) {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
params->bias_buffer == nullptr, "biased version only support input with bias");
} else {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
params->bias_buffer != nullptr, "non-biased version only support input without bias");
}
if constexpr (USE_MASK) {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
!GemmSoftmaxGemmPermuteTunableOp<T>::IsSupportedMaskType(params->attention),
"mask type is not supported, got ", params->attention->mask_type);
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
params->mask_index_buffer == nullptr, "masked version only support input with mask");
} else {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
params->mask_index_buffer != nullptr, "non-masked version only support input without mask");
}

auto attn = params->attention;
const int& G0 = attn->batch_size;
const int& G1 = attn->num_heads;
const int& M = attn->sequence_length;
const int& N = attn->total_sequence_length;
const int& K = attn->head_size;
const int& O = attn->v_head_size;
{
auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch();
ORT_ENFORCE(M == m && N == n && K == k && O == o && G0 * G1 == batch, "semantic mismatch");
}
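// Shape summary: Gemm0 forms Q*K^T -> [G0, G1, M, N] attention scores,
// softmax normalizes over N, Gemm1 applies V -> [G0, G1, M, O], and the
// result is written out permuted (0213), i.e. in BSNH layout.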

auto [qs, ks, vs] = GetQkvStrides(attn);
std::vector<ck::index_t> q_buffer_lengths = {G0, G1, M, K};
std::vector<ck::index_t> q_buffer_strides = qs.template ForBNSHCoord<std::vector<ck::index_t>>();
std::vector<ck::index_t> k_buffer_lengths = {G0, G1, N, K};
std::vector<ck::index_t> k_buffer_strides = ks.template ForBNSHCoord<std::vector<ck::index_t>>();
std::vector<ck::index_t> v_buffer_lengths = {G0, G1, O, N};
std::vector<ck::index_t> v_buffer_strides = vs.template ForBNHSCoord<std::vector<ck::index_t>>();
std::vector<ck::index_t> out_buffer_lengths = {G0, G1, M, O};
std::vector<ck::index_t> out_buffer_strides = {M * G1 * O, O, G1 * O, 1}; // permute 0213

std::array<void*, kNumBiasBuffer> bias_buffers{};
std::array<std::vector<ck::index_t>, kNumBiasBuffer> bias_lengths{};
std::array<std::vector<ck::index_t>, kNumBiasBuffer> bias_strides{};
if constexpr (USE_BIAS) {
bias_buffers[0] = const_cast<T*>(params->bias_buffer);
bias_lengths[0] = {G0, G1, M, N}; // BN(G0*G1), S(M), T(N)
bias_strides[0] = {G1 * M * N, M * N, N, 1};
}
if constexpr (USE_MASK) {
bias_buffers[kNumBiasBuffer - 1] = params->workspace_buffer;
bias_lengths[kNumBiasBuffer - 1] = {G0, G1, M, N}; // BN(G0*G1), S(M), T(N)
if (params->mask_index_dims.size() == 2) { // [B,T]
bias_strides[kNumBiasBuffer - 1] = {N, 0, 0, 1};
} else if (params->mask_index_dims.size() == 3) { // [B,S,T]
bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1};
} else if (params->mask_index_dims.size() == 4) { // [B,1,max_seq_len,max_seq_len] -->convert--> [B,S,T]
bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1};
} else {
ORT_ENFORCE(false, "Unreachable");
}
}
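// A stride of 0 along a dimension broadcasts the mask across it; e.g.
// {N, 0, 0, 1} reuses one [B, T] mask row for every head (G1) and every
// query position (M).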

auto arg = impl->MakeArgumentPointer(
params->q_buffer, params->k_buffer, params->v_buffer, params->out_buffer,
bias_buffers, // Gemm1 bias, as attention mask
{}, // Gemm2 bias
q_buffer_lengths, q_buffer_strides,
k_buffer_lengths, k_buffer_strides,
v_buffer_lengths, v_buffer_strides,
out_buffer_lengths, out_buffer_strides,
bias_lengths, bias_strides,
{},
{},
Nop{},
Nop{},
Acc0ElementOp{params->scale},
Nop{},
Nop{});

TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()),
impl->GetTypeString(), " does not support the params");

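// The integer mask is first expanded into an additive fp16 mask in the
// workspace buffer (large negative values at masked positions), which the
// kernel then consumes as the extra Gemm0 bias wired up above.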
if constexpr (USE_MASK) {
ORT_RETURN_IF_ERROR(GemmSoftmaxGemmPermuteTunableOp<T>::LaunchConvertToFilledMaskValue(params));
}

invoker->Run(arg.get(), StreamConfig{params->StreamHandle()});
return Status::OK();
}

template <typename T, bool USE_BIAS, bool USE_MASK>
auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() {
using CKDataType = typename CKDataTypeAdaptor<T>::type;
using D0DataType = typename ck::detail::tuple_concat<
std::conditional_t<USE_BIAS, ck::Tuple<CKDataType>, ck::Tuple<>>,
std::conditional_t<USE_MASK, ck::Tuple<CKDataType>, ck::Tuple<>>>::type;
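// D0DataType enumerates the extra Gemm0 inputs: one fp16 tensor per enabled
// feature, i.e. 0, 1, or 2 tuple elements depending on USE_BIAS / USE_MASK.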

constexpr static auto MaskingSpec =
constexpr static auto MaskingSpecMaskDisabled =
ck::tensor_operation::device::MaskingSpecialization::MaskDisabled;
constexpr static auto MaskingSpecMaskOutUpperTriangle =
ck::tensor_operation::device::MaskingSpecialization::MaskOutUpperTriangle;
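// Two instance families are enumerated below: MaskDisabled serves
// bidirectional attention, MaskOutUpperTriangle serves causal
// (unidirectional) attention.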

std::vector<std::pair<std::string, Op<GemmSoftmaxGemmPermuteParams<T>>>> ret;
for (auto&& impl : internal::GetDeviceBatchedGemmSoftmaxGemmPermuteInstances<
CKDataType, D0DataType, internal::F32, internal::PreSoftmaxAttentionScoreOp, MaskingSpec>()) {
CKDataType, D0DataType, internal::F32, internal::PreSoftmaxAttentionScoreOp, MaskingSpecMaskDisabled>()) {
auto type_string = impl->GetTypeString();

auto invoker = impl->MakeInvokerPointer();
auto op = [impl = std::move(impl), invoker = std::move(invoker)](
const GemmSoftmaxGemmPermuteParams<T>* params) -> Status {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
!GemmSoftmaxGemmPermuteTunableOp<T>::IsSupportedMode(params->attention),
"attention mode is not supported, got ", params->attention->mode);
if constexpr (USE_BIAS) {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
params->bias_buffer == nullptr, "biased version only support input with bias");
} else {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
params->bias_buffer != nullptr, "non-biased version only support input without bias");
}
if constexpr (USE_MASK) {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
!GemmSoftmaxGemmPermuteTunableOp<T>::IsSupportedMaskType(params->attention),
"mask type is not supported, got ", params->attention->mask_type);
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
params->mask_index_buffer == nullptr, "masked version only support input with mask");
} else {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
params->mask_index_buffer != nullptr, "non-masked version only support input without mask");
}
params->attention->is_unidirectional, "unidirectional attention is not supported with MaskingSpecMaskDisabled");
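// (MaskDisabled instances cannot apply a causal mask themselves, so
// unidirectional attention is instead served by the MaskOutUpperTriangle
// instances enumerated in the second loop below.)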

auto attn = params->attention;
const int& G0 = attn->batch_size;
const int& G1 = attn->num_heads;
const int& M = attn->sequence_length;
const int& N = attn->total_sequence_length;
const int& K = attn->head_size;
const int& O = attn->v_head_size;
{
auto [m, n, k, o, batch] = params->GetGemmsMNKOBatch();
ORT_ENFORCE(M == m && N == n && K == k && O == o && G0 * G1 == batch, "semantic mismatch");
}
return GetArgAndRunInvoker<decltype(impl), decltype(invoker), T, USE_BIAS, USE_MASK>(impl, invoker, params);
};
ret.emplace_back(std::make_pair(std::move(type_string), std::move(op)));
}

auto [qs, ks, vs] = GetQkvStrides(attn);
std::vector<ck::index_t> q_buffer_lengths = {G0, G1, M, K};
std::vector<ck::index_t> q_buffer_strides = qs.template ForBNSHCoord<std::vector<ck::index_t>>();
std::vector<ck::index_t> k_buffer_lengths = {G0, G1, N, K};
std::vector<ck::index_t> k_buffer_strides = ks.template ForBNSHCoord<std::vector<ck::index_t>>();
std::vector<ck::index_t> v_buffer_lengths = {G0, G1, O, N};
std::vector<ck::index_t> v_buffer_strides = vs.template ForBNHSCoord<std::vector<ck::index_t>>();
std::vector<ck::index_t> out_buffer_lengths = {G0, G1, M, O};
std::vector<ck::index_t> out_buffer_strides = {M * G1 * O, O, G1 * O, 1}; // permute 0213

std::array<void*, kNumBiasBuffer> bias_buffers{};
std::array<std::vector<ck::index_t>, kNumBiasBuffer> bias_lengths{};
std::array<std::vector<ck::index_t>, kNumBiasBuffer> bias_strides{};
if constexpr (USE_BIAS) {
bias_buffers[0] = const_cast<T*>(params->bias_buffer);
bias_lengths[0] = {G0, G1, M, N}; // BN(G0*G1), S(M), T(N)
bias_strides[0] = {G1 * M * N, M * N, N, 1};
}
if constexpr (USE_MASK) {
bias_buffers[kNumBiasBuffer - 1] = params->workspace_buffer;
bias_lengths[kNumBiasBuffer - 1] = {G0, G1, M, N}; // BN(G0*G1), S(M), T(N)
if (params->mask_index_dims.size() == 2) { // [B,T]
bias_strides[kNumBiasBuffer - 1] = {N, 0, 0, 1};
} else if (params->mask_index_dims.size() == 3) { // [B,S,T]
bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1};
} else if (params->mask_index_dims.size() == 4) { // [B,1,max_seq_len,max_seq_len] -->convert--> [B,S,T]
bias_strides[kNumBiasBuffer - 1] = {M * N, 0, N, 1};
} else {
ORT_ENFORCE(false, "Unreachable");
}
}
for (auto&& impl : internal::GetDeviceBatchedGemmSoftmaxGemmPermuteInstances<
CKDataType, D0DataType, internal::F32, internal::PreSoftmaxAttentionScoreOp, MaskingSpecMaskOutUpperTriangle>()) {

auto type_string = impl->GetTypeString();

auto arg = impl->MakeArgumentPointer(
params->q_buffer, params->k_buffer, params->v_buffer, params->out_buffer,
bias_buffers, // Gemm1 bias, as attention mask
{}, // Gemm2 bias
q_buffer_lengths, q_buffer_strides,
k_buffer_lengths, k_buffer_strides,
v_buffer_lengths, v_buffer_strides,
out_buffer_lengths, out_buffer_strides,
bias_lengths, bias_strides,
{},
{},
Nop{},
Nop{},
Acc0ElementOp{params->scale},
Nop{},
Nop{});

TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()),
impl->GetTypeString(), " does not support the params");

if constexpr (USE_MASK) {
ORT_RETURN_IF_ERROR(GemmSoftmaxGemmPermuteTunableOp<T>::LaunchConvertToFilledMaskValue(params));
}
invoker->Run(arg.get(), StreamConfig{params->StreamHandle()});
return Status::OK();
auto invoker = impl->MakeInvokerPointer();
auto op = [impl = std::move(impl), invoker = std::move(invoker)](
const GemmSoftmaxGemmPermuteParams<T>* params) -> Status {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
!params->attention->is_unidirectional, "bidirectional attention is not supported with MaskingSpecMaskOutUpperTriangle");

TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
params->attention->sequence_length != params->attention->total_sequence_length,
"sequence_length != total_sequence_length is not supported with MaskingSpecMaskOutUpperTriangle");

return GetArgAndRunInvoker<decltype(impl), decltype(invoker), T, USE_BIAS, USE_MASK>(impl, invoker, params);
};
ret.emplace_back(std::make_pair(std::move(type_string), std::move(op)));
}

return ret;
}
#endif // USE_COMPOSABLE_KERNEL
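For orientation, a minimal sketch (hypothetical; every name except GetCKGemmSoftmaxGemmPermuteTypeStringAndOps, Op, and GemmSoftmaxGemmPermuteParams is illustrative) of how the (type string, op) pairs returned above might be collected by an autotuning wrapper:

#include <string>
#include <utility>
#include <vector>

// Hypothetical consumer: gather every CK variant for one bias/mask
// combination so an autotuner can time them all and keep the fastest
// one that reports support for the runtime params.
template <typename T>
struct CausalAttentionOpRegistry {
  CausalAttentionOpRegistry() {
    for (auto&& [name, op] :
         GetCKGemmSoftmaxGemmPermuteTypeStringAndOps<T, /*USE_BIAS=*/false, /*USE_MASK=*/false>()) {
      names_.emplace_back(std::move(name));
      ops_.emplace_back(std::move(op));
    }
  }
  std::vector<std::string> names_;
  std::vector<Op<GemmSoftmaxGemmPermuteParams<T>>> ops_;
};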