
Commit

update
PeixuanZuo committed Feb 1, 2024
1 parent bc5ed7a commit b7cb71b
Showing 2 changed files with 17 additions and 13 deletions.
@@ -733,7 +733,7 @@ class GemmSoftmaxGemmPermuteTunableOp : public tunable::TunableOp<GemmSoftmaxGem
 #ifdef USE_COMPOSABLE_KERNEL
 template <typename U, typename V, typename T, bool USE_BIAS, bool USE_MASK>
-auto GetArg(const U& impl, const V& invoker, const GemmSoftmaxGemmPermuteParams<T>* params) {
+auto GetArgAndRunInvoker(const U& impl, const V& invoker, const GemmSoftmaxGemmPermuteParams<T>* params) {
   using Nop = ck::tensor_operation::element_wise::PassThrough;
   using Acc0ElementOp = internal::PreSoftmaxAttentionScoreOp;
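The rename makes the helper's name match its behavior: it does not merely build the composable_kernel argument, it also checks whether the argument is supported and then immediately runs the invoker. Below is a minimal self-contained sketch of that build-check-run pattern; the toy types, member signatures, and return values are assumptions for illustration only, not the onnxruntime or CK API.

#include <iostream>
#include <memory>
#include <string>

// Toy stand-ins for the CK device op ("impl") and its invoker; the real
// objects come from composable_kernel factories and take many more params.
struct ToyArgument { int m; int n; };

struct ToyImpl {
  std::unique_ptr<ToyArgument> MakeArgumentPointer(int m, int n) const {
    return std::make_unique<ToyArgument>(ToyArgument{m, n});
  }
  bool IsSupportedArgument(const ToyArgument* arg) const {
    return arg->m % 2 == 0;  // pretend only even M is supported
  }
};

struct ToyInvoker {
  void Run(const ToyArgument* arg) const {
    std::cout << "launching kernel for " << arg->m << "x" << arg->n << "\n";
  }
};

// Mirrors the renamed helper's shape: build the argument, reject
// unsupported shapes, then run the invoker. The old name GetArg described
// only the first step; the new name covers the whole sequence.
template <typename Impl, typename Invoker>
std::string GetArgAndRunInvoker(const Impl& impl, const Invoker& invoker, int m, int n) {
  auto arg = impl.MakeArgumentPointer(m, n);
  if (!impl.IsSupportedArgument(arg.get())) {
    return "unsupported argument";  // the real code returns a tunable-op status
  }
  invoker.Run(arg.get());
  return "OK";
}

int main() {
  ToyImpl impl;
  ToyInvoker invoker;
  std::cout << GetArgAndRunInvoker(impl, invoker, 128, 64) << "\n";
  std::cout << GetArgAndRunInvoker(impl, invoker, 127, 64) << "\n";
}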
@@ -857,7 +857,7 @@ auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() {
       TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
           params->attention->is_unidirectional, "unidirectional attention is not supported with MaskingSpecMaskDisabled");
-      return GetArg<decltype(impl), decltype(invoker), T, USE_BIAS, USE_MASK>(impl, invoker, params);
+      return GetArgAndRunInvoker<decltype(impl), decltype(invoker), T, USE_BIAS, USE_MASK>(impl, invoker, params);
     };
     ret.emplace_back(std::make_pair(std::move(type_string), std::move(op)));
   }
@@ -875,7 +875,7 @@ auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() {
           params->attention->sequence_length != params->attention->total_sequence_length,
           "seqence_length != total_seqence_length is not supported with MaskingSpecMaskOutUpperTriangle");
-      return GetArg<decltype(impl), decltype(invoker), T, USE_BIAS, USE_MASK>(impl, invoker, params);
+      return GetArgAndRunInvoker<decltype(impl), decltype(invoker), T, USE_BIAS, USE_MASK>(impl, invoker, params);
     };
     ret.emplace_back(std::make_pair(std::move(type_string), std::move(op)));
   }
@@ -890,17 +890,22 @@ GemmSoftmaxGemmPermuteTunableOp<T>::GemmSoftmaxGemmPermuteTunableOp() {
     return GemmSoftmaxGemmPermuteGenericPipeline<T>::Run(params, false);
   });
-#define GET_CK_IMPLEMENTATION(USE_BIAS, USE_MASK)                                               \
-  for (auto&& [_, op] : GetCKGemmSoftmaxGemmPermuteTypeStringAndOps<T, USE_BIAS, USE_MASK>()) { \
-    this->RegisterOp(std::move(op));                                                            \
-  }
 #ifdef USE_COMPOSABLE_KERNEL
-  GET_CK_IMPLEMENTATION(false, false);
-  GET_CK_IMPLEMENTATION(true, false);
-  GET_CK_IMPLEMENTATION(false, true);
-  GET_CK_IMPLEMENTATION(true, true);
+  for (auto&& [_, op] : GetCKGemmSoftmaxGemmPermuteTypeStringAndOps<T, /*USE_BIAS=*/false, /*USE_MASK=*/false>()) {
+    this->RegisterOp(std::move(op));
+  }
+  for (auto&& [_, op] : GetCKGemmSoftmaxGemmPermuteTypeStringAndOps<T, /*USE_BIAS=*/true, /*USE_MASK=*/false>()) {
+    this->RegisterOp(std::move(op));
+  }
+  for (auto&& [_, op] : GetCKGemmSoftmaxGemmPermuteTypeStringAndOps<T, /*USE_BIAS=*/false, /*USE_MASK=*/true>()) {
+    this->RegisterOp(std::move(op));
+  }
+  for (auto&& [_, op] : GetCKGemmSoftmaxGemmPermuteTypeStringAndOps<T, /*USE_BIAS=*/true, /*USE_MASK=*/true>()) {
+    this->RegisterOp(std::move(op));
+  }
 #endif
 }
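Expanding the GET_CK_IMPLEMENTATION macro into four explicit loops trades brevity for call sites that are steppable in a debugger and self-documenting through the /*USE_BIAS=*/ and /*USE_MASK=*/ annotations. The commit message does not state a motivation; purely as background, one well-known hazard of function-like macros around templated code, which plain code avoids, is that the preprocessor splits arguments on every top-level comma:

#include <map>

// A function-like macro treats each top-level comma as an argument
// separator, so a template-id with multiple parameters cannot be
// passed to it directly.
#define DECLARE_DEFAULTED(T) T value{}

int main() {
  // DECLARE_DEFAULTED(std::map<int, int>);  // error: macro expects 1 argument, got 2
  using IntMap = std::map<int, int>;  // an alias hides the comma from the preprocessor
  DECLARE_DEFAULTED(IntMap);
  return value.empty() ? 0 : 1;
}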
@@ -172,8 +172,7 @@ def _test_gemm_softmax_gemm_permute(
     pre_softmax_attn_scores = pre_softmax_attn_scores * scale
     if attn_bias is not None:
         pre_softmax_attn_scores = pre_softmax_attn_scores + attn_bias
-    converted_mask = None
-    causal_mask = None
 
+    correct_causal_mask = np.full((seqlen, total_seqlen), 1)
     if attn_mask is not None:
         filter_value = -10000.0
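For context on the reference mask the test builds: np.full((seqlen, total_seqlen), 1) starts from an "every key visible" mask that the causal branch then restricts. A minimal sketch of how such a causal restriction is commonly applied follows, reusing the filter_value visible in the hunk; the concrete values and the np.tril/np.where steps are illustrative assumptions, not code copied from the test file.

import numpy as np

# Illustrative values; in the test these come from the parameterization.
seqlen, total_seqlen = 4, 6
filter_value = -10000.0

# Start from "every key visible", as the added line does.
correct_causal_mask = np.full((seqlen, total_seqlen), 1)
# Restrict to causal visibility: query i may attend to key j only while
# j <= i + (total_seqlen - seqlen), i.e. keep the lower triangle shifted
# by the number of past (cached) key positions.
correct_causal_mask = np.tril(correct_causal_mask, k=total_seqlen - seqlen)

# Masked-out scores are pushed to a large negative value so that softmax
# assigns them (near-)zero probability.
scores = np.zeros((seqlen, total_seqlen))
masked_scores = np.where(correct_causal_mask == 1, scores, filter_value)
print(masked_scores)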
