Skip to content

Commit

Permalink
Add support for MIGraphX Exhaustive tune flag in MIGraphX EP (#46)
Browse files Browse the repository at this point in the history
* Add support for MIGraphX Exhaustive tune flag in MIGraphX EP

Enable exhaustive tune by either python interface of environment env in bash

* Apply lintrunner pass

* Fix compile errors.

* Lintrunner pass
  • Loading branch information
TedThemistokleous authored and Ted Themistokleous committed Aug 2, 2024
1 parent d0a6f57 commit 843fead
Show file tree
Hide file tree
Showing 8 changed files with 40 additions and 2 deletions.
1 change: 1 addition & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -621,6 +621,7 @@ typedef struct OrtMIGraphXProviderOptions {
const char* migraphx_save_model_path; // migraphx model path name
int migraphx_load_compiled_model; // migraphx int8 cal table. Default 0 = false, noznero = true
const char* migraphx_load_model_path; // migraphx model path name
int migraphx_exhaustive_tune; // migraphx tuned compile Default 0 = false, nonzero = true
} OrtMIGraphXProviderOptions;

/** \brief OpenVINO Provider Options
Expand Down
15 changes: 15 additions & 0 deletions onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,12 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv
dump_model_ops_ = (std::stoi(dump_model_ops_env) == 0 ? false : true);
}

// Allow for exhaustive tune during compile
const std::string exhaustive_tune_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kExhaustiveTune);
if (!exhaustive_tune_env.empty()) {
exhaustive_tune_ = (std::stoi(exhaustive_tune_env) == 0 ? false : true);
}

metadef_id_generator_ = ModelMetadefIdGenerator::Create();

LOGS_DEFAULT(VERBOSE) << "[MIGraphX EP] MIGraphX provider options: "
Expand All @@ -188,6 +194,7 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv
<< ", migraphx_int8_enable: " << int8_enable_
<< ", migraphx_int8_enable: " << int8_enable_
<< ", dump_model_ops: " << dump_model_ops_
<< ", exhaustive_tune: " << exhaustive_tune_
<< ", migraphx_int8_calibration_cache_name: " << int8_calibration_cache_name_
<< ", int8_calibration_cache_available: " << int8_calibration_cache_available_
<< ", use_native_migraphx_calibration_table: " << int8_use_native_migraphx_calibration_table_
Expand Down Expand Up @@ -1191,6 +1198,10 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&

migraphx::compile_options co;
co.set_fast_math(false);
co.set_exhaustive_tune_flag(false);
if (exhaustive_tune_) {
co.set_exhaustive_tune_flag(true);
}
LOGS_DEFAULT(INFO) << "Model Compile: Begin" << std::endl;
prog.compile(t_, co);
LOGS_DEFAULT(INFO) << "Model Compile: Complete" << std::endl;
Expand Down Expand Up @@ -1350,6 +1361,10 @@ Status MIGraphXExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>&
LOGS_DEFAULT(INFO) << "Model Compile: Begin" << std::endl;
migraphx::compile_options co;
co.set_fast_math(false);
co.set_exhaustive_tune_flag(false);
if (exhaustive_tune_) {
co.set_exhaustive_tune_flag(true);
}
prog.compile(t, co);

save_compiled_model(prog, mgx_state->save_compiled_mode, mgx_state->save_compiled_path);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ static const char kSaveCompiledModel[] = "ORT_MIGRAPHX_SAVE_COMPILED_MODEL";
static const char kSavedModelPath[] = "ORT_MIGRAPHX_SAVE_COMPILE_PATH";
static const char kLoadCompiledModel[] = "ORT_MIGRAPHX_LOAD_COMPILED_MODEL";
static const char kLoadModelPath[] = "ORT_MIGRAPHX_LOAD_COMPILE_PATH";
static const char kExhaustiveTune[] = "ORT_MIGRAPHX_EXHAUSTIVE_TUNE";

}; // namespace migraphx_env_vars

Expand All @@ -49,6 +50,7 @@ struct MIGraphXFuncState {
bool load_compiled_mode = false;
std::string load_compiled_path;
bool dump_model_ops = false;
bool exhaustive_tune = false;
};

// Logical device representation.
Expand Down Expand Up @@ -100,6 +102,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider {
migraphx::target t_;
OrtMutex mgx_mu_;
hipStream_t stream_ = nullptr;
bool exhaustive_tune_ = false;

std::unordered_map<std::string, migraphx::program> map_progs_;
std::unordered_map<std::string, std::string> map_onnx_string_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ constexpr const char* kSaveCompiledModel = "migx_save_compiled_model";
constexpr const char* kSaveModelPath = "migx_save_model_name";
constexpr const char* kLoadCompiledModel = "migx_load_compiled_model";
constexpr const char* kLoadModelPath = "migx_load_model_name";
constexpr const char* kExhaustiveTune = "migx_exhaustive_tune";

} // namespace provider_option_names
} // namespace migraphx
Expand All @@ -45,6 +46,7 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions
.AddAssignmentToReference(migraphx::provider_option_names::kInt8Enable, info.int8_enable)
.AddAssignmentToReference(migraphx::provider_option_names::kSaveCompiledModel, info.save_compiled_model)
.AddAssignmentToReference(migraphx::provider_option_names::kLoadCompiledModel, info.load_compiled_model)
.AddAssignmentToReference(migraphx::provider_option_names::kExhaustiveTune, info.exhaustive_tune)
.Parse(options));

return info;
Expand All @@ -57,6 +59,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXE
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)},
{migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.save_compiled_model)},
{migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.load_compiled_model)},
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.exhaustive_tune)},
};
return options;
}
Expand All @@ -68,6 +71,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGrap
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)},
{migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.migraphx_save_compiled_model)},
{migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.migraphx_load_compiled_model)},
{migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_exhaustive_tune)},
};
return options;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ struct MIGraphXExecutionProviderInfo {
std::string save_model_file{"./compiled_model.mxr"};
bool load_compiled_model{true};
std::string load_model_file{"./compiled_model.mxr"};
bool exhaustive_tune{false};

static MIGraphXExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
static ProviderOptions ToProviderOptions(const MIGraphXExecutionProviderInfo& info);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ struct MIGraphX_Provider : Provider {
info.device_id = static_cast<OrtDevice::DeviceId>(options.device_id);
info.target_device = "gpu";
info.fp16_enable = options.migraphx_fp16_enable;
info.exhaustive_tune = options.migraphx_exhaustive_tune;
info.int8_enable = options.migraphx_int8_enable;
info.int8_calibration_table_name = "";
if (options.migraphx_int8_calibration_table_name != nullptr) {
Expand All @@ -85,6 +86,7 @@ struct MIGraphX_Provider : Provider {
migx_options.device_id = internal_options.device_id;
migx_options.migraphx_fp16_enable = internal_options.fp16_enable;
migx_options.migraphx_int8_enable = internal_options.int8_enable;
migx_options.migraphx_exhaustive_tune = internal_options.exhaustive_tune;

char* dest = nullptr;
auto str_size = internal_options.int8_calibration_table_name.size();
Expand Down
13 changes: 12 additions & 1 deletion onnxruntime/python/onnxruntime_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -844,7 +844,8 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
1,
"./compiled_model.mxr",
1,
"./compiled_model.mxr"};
"./compiled_model.mxr",
1};
for (auto option : it->second) {
if (option.first == "device_id") {
if (!option.second.empty()) {
Expand Down Expand Up @@ -929,6 +930,16 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
"[ERROR] [MIGraphX] The value for the key 'migx_load_model_name' should be a "
"file name i.e. 'compiled_model.mxr'.\n");
}
} else if (option.first == "migraphx_exhaustive_tune") {
if (option.second == "True" || option.second == "true") {
params.migraphx_exhaustive_tune = true;
} else if (option.second == "False" || option.second == "false") {
params.migraphx_exhaustive_tune = false;
} else {
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migraphx_exhaustive_tune' should be"
" 'True' or 'False'. Default value is 'False'.\n");
}
} else {
ORT_THROW("Invalid MIGraphX EP option: ", option.first);
}
Expand Down
3 changes: 2 additions & 1 deletion onnxruntime/test/util/default_providers.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ std::unique_ptr<IExecutionProvider> DefaultMIGraphXExecutionProvider() {
1,
"./compiled_model.mxr",
1,
"./compiled_model.mxr"};
"./compiled_model.mxr",
1};
return MIGraphXProviderFactoryCreator::Create(&params)->CreateProvider();
#else
return nullptr;
Expand Down

0 comments on commit 843fead

Please sign in to comment.