Improve tunable verbose log #17328

Merged · 1 commit · Oct 31, 2023
@@ -838,7 +838,7 @@ auto GetCKGemmSoftmaxGemmPermuteTypeStringAndOps() {
                                          Nop{});

     TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()),
-                                              impl->GetTypeString(), " does not support ", params->Signature());
+                                              impl->GetTypeString(), " does not support the params");

     if constexpr (USE_MASK) {
       ORT_RETURN_IF_ERROR(GemmSoftmaxGemmPermuteTunableOp<T>::LaunchConvertToFilledMaskValue(params));
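For context, TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF is the tunable-op early-return helper: when the condition holds, it returns an "unsupported argument" Status whose message is built from the remaining arguments. The sketch below only illustrates that pattern, assuming an ORT_MAKE_STATUS-style variadic message; the real macro is defined alongside the tunable framework and may differ in detail. The point of this PR is that dropping params->Signature() from the message arguments avoids constructing the signature string for every rejected candidate; the op and params signatures are instead logged once per tuning pass by FindFastestImpl (see the tunable.h hunks below).

// Sketch only — not the literal onnxruntime definition.
// When `condition` is true, return a Status whose category/code mark the
// argument as unsupported, with the variadic arguments joined into the message.
#define TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(condition, ...)      \
  do {                                                                 \
    if (condition) {                                                   \
      return ORT_MAKE_STATUS(NONE, INVALID_ARGUMENT, __VA_ARGS__);     \
    }                                                                  \
  } while (0)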
onnxruntime/contrib_ops/rocm/bert/gemm_fast_gelu_ck.cuh — 4 additions & 4 deletions

@@ -58,7 +58,7 @@ auto GetCKGemmAddFastGeluTypeStringAndOps() {
     auto zero = ToHipType<T>::FromFloat(0.0f);
     TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
         params->alpha != one || params->beta != zero || params->bias == nullptr,
-        impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias != nullptr", params->Signature());
+        impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias != nullptr");

     auto nop = Nop{};
     auto addfastgelu = AddFastGelu{};
@@ -67,7 +67,7 @@ auto GetCKGemmAddFastGeluTypeStringAndOps() {
         params->lda, params->ldb, std::array<ck::index_t, 1>{0}, params->ldc,
         nop, nop, addfastgelu);
     TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()),
-                                              impl->GetTypeString(), " does not support ", params->Signature());
+                                              impl->GetTypeString(), " does not support the params");
     invoker->Run(arg.get(), StreamConfig{params->StreamHandle()});
     return Status::OK();
   };
@@ -95,7 +95,7 @@ auto GetCKGemmFastGeluTypeStringAndOps() {

     TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
         params->alpha != one || params->beta != zero || params->bias != nullptr,
-        impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias == nullptr", params->Signature());
+        impl->GetTypeString(), " only supports alpha == 1 and beta == 0 and bias == nullptr");

     auto nop = Nop{};
     auto fastgelu = FastGelu{};
@@ -108,7 +108,7 @@ auto GetCKGemmFastGeluTypeStringAndOps() {
         params->ldc,
         nop, nop, fastgelu);
     TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()),
-                                              impl->GetTypeString(), " does not support ", params->Signature());
+                                              impl->GetTypeString(), " does not support the params");
     invoker->Run(arg.get(), StreamConfig{params->StreamHandle()});
     return Status::OK();
   };
onnxruntime/contrib_ops/rocm/diffusion/group_norm_ck.cuh — 1 addition & 1 deletion

@@ -79,7 +79,7 @@ auto GetCKGroupNormNHWCTypeStringAndOps() {
                                          nullptr,
                                          activation);
     TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()),
-                                              impl->GetTypeString(), " does not support ", params->Signature());
+                                              impl->GetTypeString(), " does not support the params");
     invoker->Run(arg.get(), StreamConfig{params->StreamHandle()});
     return Status::OK();
   };
onnxruntime/core/framework/tunable.h — 21 additions & 18 deletions

@@ -232,14 +232,15 @@ class TunableOp {
     return timer.Duration() / num_iter;
   }

-  static bool IsSupported(Op<ParamsT>& op, const ParamsT* param) {
-    Status status = op.IsSupported(param);
+  // Filter the Status: only OK and TUNABLE_OP_UNSUPPORTED pass through; any other error status is thrown
+  // and processed by onnxruntime. We return the Status to avoid constructing the op and params signature strings.
+  static Status IsSupported(Op<ParamsT>& op, const ParamsT* params) {
+    Status status = op.IsSupported(params);
     if (status.Category() == common::StatusCategory::NONE && status.Code() == common::StatusCode::INVALID_ARGUMENT) {
-      LOGS_DEFAULT(VERBOSE) << "unsupported reason: " << status.ErrorMessage();
-      return false;
+      return status;
     }
     ORT_THROW_IF_ERROR(status);
-    return true;
+    return status;
   }

 protected:
@@ -250,40 +251,42 @@
   int FindFastestImpl(const ParamsT* params, const std::vector<Op<ParamsT>>& candidates) {
     ITuningContext* ctx = params->TuningContext();
     auto op_sig = Signature();
-    auto param_sig = params->Signature();
-    LOGS_DEFAULT(VERBOSE) << "FindFastestImpl for " << op_sig << '(' << param_sig << ')';
-    auto min_time = std::numeric_limits<double>::infinity();
+    auto params_sig = params->Signature();
+    LOGS_DEFAULT(VERBOSE) << "finding fastest for " << op_sig << '(' << params_sig << ')';
+    auto min_duration_ms = std::numeric_limits<double>::infinity();
     int id = -1;

     constexpr const int max_tuning_iter = 100;
     constexpr const int approx_num_iter = 3;

     for (size_t i = 0; i < candidates.size(); i++) {
       auto& candidate = const_cast<Op<ParamsT>&>(candidates[i]);
-      if (!IsSupported(candidate, params)) {
-        LOGS_DEFAULT(VERBOSE) << "FindFastestImpl found unsupported " << op_sig << '(' << param_sig << ") id=" << i;
+      auto status = IsSupported(candidate, params);
+      if (!status.IsOK()) {
+        LOGS_DEFAULT(VERBOSE) << "├──unsupported id=" << i << ", " << op_sig << '(' << params_sig << ")";
+        LOGS_DEFAULT(VERBOSE) << "│   reason: " << status.ErrorMessage();
         continue;
       }

       WarmUp(candidate, params);

       auto approx_duration = Profile(candidate, params, approx_num_iter);
-      if (approx_duration > 2 * min_time) {
-        LOGS_DEFAULT(VERBOSE) << "FindFastestImpl skip slow instance " << op_sig << '(' << param_sig << ") id=" << i;
+      if (approx_duration > 2 * min_duration_ms) {
+        LOGS_DEFAULT(VERBOSE) << "├──skip slow instance id=" << i;
         continue;
       }
       int tuning_iter = std::max(1, int(std::min(double(max_tuning_iter), ctx->GetMaxTuningDurationMs() / approx_duration)));

-      LOGS_DEFAULT(VERBOSE) << "FindFastestImpl run instance " << op_sig << '(' << param_sig << ") id=" << i << " " << tuning_iter << " times.";
-
-      auto time = Profile(candidate, params, tuning_iter);
-      if (time < min_time) {
-        min_time = time;
+      auto duration_ms = Profile(candidate, params, tuning_iter);
+      if (duration_ms < min_duration_ms) {
+        LOGS_DEFAULT(VERBOSE) << "├──found better instance, new best id=" << i << ", old id=" << id << ". "
+                              << duration_ms << "ms, " << tuning_iter << " iters.";
+        min_duration_ms = duration_ms;
         id = static_cast<int>(i);
       }
     }
     ORT_ENFORCE(id >= 0, "Could not find viable op");
-    LOGS_DEFAULT(VERBOSE) << "FindFastestImpl for " << op_sig << '(' << param_sig << ") found fastest with id=" << id;
+    LOGS_DEFAULT(VERBOSE) << "└──found fastest with id=" << id << " for " << op_sig << '(' << params_sig << ")";
     std::this_thread::sleep_for(std::chrono::milliseconds(50));
     return id;
   }
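Note that the per-candidate iteration count is bounded by the tuning budget: tuning_iter = max(1, min(max_tuning_iter, GetMaxTuningDurationMs() / approx_duration)). For example, assuming a 25 ms budget and a 3-iteration estimate of 0.5 ms per call, the candidate is profiled for 50 iterations, while a 2 ms-per-call candidate gets only 12. With the new messages, a tuning pass emits a tree-shaped VERBOSE trace roughly like the excerpt below (log prefixes omitted; the op/params signatures, ids, and timings are made up purely for illustration):

finding fastest for GemmTunableOp<half>(NN_256_512_1024)
├──unsupported id=0, GemmTunableOp<half>(NN_256_512_1024)
│   reason: DeviceGemmXdl<...> does not support the params
├──skip slow instance id=1
├──found better instance, new best id=2, old id=-1. 0.045ms, 100 iters.
├──found better instance, new best id=5, old id=2. 0.031ms, 100 iters.
└──found fastest with id=5 for GemmTunableOp<half>(NN_256_512_1024)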
onnxruntime/core/providers/rocm/math/softmax_ck.cuh — 1 addition & 1 deletion

@@ -58,7 +58,7 @@ auto GetCKSoftmaxTypeStringAndOps() {
     auto arg = impl->MakeArgumentPointer(in_lengths, in_strides, reduce_dims, alpha, beta,
                                          params->input, params->output, nop, nop);
     TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()),
-                                              impl->GetTypeString(), " does not support ", params->Signature());
+                                              impl->GetTypeString(), " does not support the params");
     invoker->Run(arg.get(), StreamConfig{params->StreamHandle()});
     return Status::OK();
   };
onnxruntime/core/providers/rocm/tunable/gemm_ck.cuh — 3 additions & 3 deletions

@@ -61,7 +61,7 @@ auto GetCKGemmTypeStringAndOps() {
                                          params->lda, params->ldb, params->ldc,
                                          nop, nop, nop);
     TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()),
-                                              impl->GetTypeString(), " does not support ", params->Signature());
+                                              impl->GetTypeString(), " does not support the params");
     invoker->Run(arg.get(), StreamConfig{params->StreamHandle()});
     return Status::OK();
   };
@@ -164,7 +164,7 @@ auto GetCKStridedBatchedGemmTypeStringAndOps() {
     auto zero = ToHipType<T>::FromFloat(0.0f);
     TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
         params->alpha != one || params->beta != zero,
-        impl->GetTypeString(), " only supports alpha == 1 and beta == 0", params->Signature());
+        impl->GetTypeString(), " only supports alpha == 1 and beta == 0");

     auto nop = Nop{};
     auto arg = impl->MakeArgumentPointer(params->a, params->b, params->c,
@@ -174,7 +174,7 @@ auto GetCKStridedBatchedGemmTypeStringAndOps() {
                                          params->batch,
                                          nop, nop, nop);
     TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(!impl->IsSupportedArgument(arg.get()),
-                                              impl->GetTypeString(), " does not support ", params->Signature());
+                                              impl->GetTypeString(), " does not support the params");
     invoker->Run(arg.get(), StreamConfig{params->StreamHandle()});
     return Status::OK();
   };
onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h — 1 addition & 1 deletion

@@ -221,7 +221,7 @@ auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationTyp

     TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
         status != HIPBLAS_STATUS_SUCCESS,
-        "[hipBLASLt] Solution #", i, " failed: algo ", algo_index, " not supported (", params->Signature(), ")");
+        "[hipBLASLt] Solution #", i, " failed: algo ", algo_index, " not supported");

     IAllocatorUniquePtr<void> workspace_buffer;
     if (workspace_size > 0) {
onnxruntime/core/providers/rocm/tunable/gemm_rocblas.h — 3 additions & 6 deletions

@@ -168,8 +168,7 @@ auto GetRocBlasGemmTypeStringAndOps() {

     TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
         status != rocblas_status_success,
-        "[rocBLAS] Solution #", i, " (original ", solution, ") failed: ", rocblas_status_to_string(status),
-        " (", params->Signature(), ")");
+        "[rocBLAS] Solution #", i, " (original ", solution, ") failed: ", rocblas_status_to_string(status));

     return Status::OK();
   };
@@ -238,8 +237,7 @@ auto GetRocBlasBatchedGemmTypeStringAndOps() {

     TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
         status != rocblas_status_success,
-        "[rocBLAS] Solution #", i, " (original ", solution, ") failed: ", rocblas_status_to_string(status),
-        " (", params->Signature(), ")");
+        "[rocBLAS] Solution #", i, " (original ", solution, ") failed: ", rocblas_status_to_string(status));

     return Status::OK();
   };
@@ -308,8 +306,7 @@ auto GetRocBlasStridedBatchedGemmTypeStringAndOps() {

     TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
         status != rocblas_status_success,
-        "[rocBLAS] Solution #", i, " (original ", solution, ") failed: ", rocblas_status_to_string(status),
-        " (", params->Signature(), ")");
+        "[rocBLAS] Solution #", i, " (original ", solution, ") failed: ", rocblas_status_to_string(status));

     return Status::OK();
   };