diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 234574503c4b2..4674db42fb1c9 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -621,6 +621,7 @@ typedef struct OrtMIGraphXProviderOptions { const char* migraphx_save_model_path; // migraphx model path name int migraphx_load_compiled_model; // migraphx int8 cal table. Default 0 = false, noznero = true const char* migraphx_load_model_path; // migraphx model path name + bool migraphx_exhaustive_tune; // migraphx tuned compile Default = false } OrtMIGraphXProviderOptions; /** \brief OpenVINO Provider Options diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index be9f1bd681883..90dfa49c73c9a 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -182,6 +182,12 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv dump_model_ops_ = (std::stoi(dump_model_ops_env) == 0 ? false : true); } + // Allow for exhaustive tune during compile + const std::string exhaustive_tune_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kExhaustiveTune); + if (!exhaustive_tune_env.empty()) { + exhaustive_tune_ = (std::stoi(exhaustive_tune_env) == 0 ? false : true); + } + metadef_id_generator_ = ModelMetadefIdGenerator::Create(); LOGS_DEFAULT(VERBOSE) << "[MIGraphX EP] MIGraphX provider options: " @@ -190,6 +196,7 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv << ", migraphx_int8_enable: " << int8_enable_ << ", migraphx_int8_enable: " << int8_enable_ << ", dump_model_ops: " << dump_model_ops_ + << ", exhaustive_tune: " << exhaustive_tune_ << ", migraphx_int8_calibration_cache_name: " << int8_calibration_cache_name_ << ", int8_calibration_cache_available: " << int8_calibration_cache_available_ << ", use_native_migraphx_calibration_table: " << int8_use_native_migraphx_calibration_table_ @@ -1181,6 +1188,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& migraphx::compile_options co; co.set_fast_math(false); + co.set_exhaustive_tune_flag(exhaustive_tune_); LOGS_DEFAULT(INFO) << "Model Compile: Begin" << std::endl; prog.compile(t_, co); LOGS_DEFAULT(INFO) << "Model Compile: Complete" << std::endl; @@ -1345,6 +1353,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& LOGS_DEFAULT(INFO) << "Model Compile: Begin" << std::endl; migraphx::compile_options co; co.set_fast_math(false); + co.set_exhaustive_tune_flag(exhaustive_tune_); prog.compile(t, co); save_compiled_model(prog, mgx_state->save_compiled_mode, mgx_state->save_compiled_path); diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index 21b582de8f86e..21679d1f6f151 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -26,6 +26,7 @@ static const char kSaveCompiledModel[] = "ORT_MIGRAPHX_SAVE_COMPILED_MODEL"; static const char kSavedModelPath[] = "ORT_MIGRAPHX_SAVE_COMPILE_PATH"; static const char kLoadCompiledModel[] = "ORT_MIGRAPHX_LOAD_COMPILED_MODEL"; static const char kLoadModelPath[] = "ORT_MIGRAPHX_LOAD_COMPILE_PATH"; +static const char kExhaustiveTune[] = "ORT_MIGRAPHX_EXHAUSTIVE_TUNE"; }; // namespace migraphx_env_vars @@ -50,6 +51,7 @@ struct MIGraphXFuncState { bool load_compiled_mode = false; std::string load_compiled_path; bool dump_model_ops = false; + bool exhaustive_tune = false; }; // Logical device representation. @@ -101,6 +103,7 @@ class MIGraphXExecutionProvider : public IExecutionProvider { migraphx::target t_; OrtMutex mgx_mu_; hipStream_t stream_ = nullptr; + bool exhaustive_tune_ = false; mutable std::filesystem::path model_path_; std::unordered_map map_progs_; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc index 2a135b7324f3a..1f9a47d3ad87d 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc @@ -21,6 +21,7 @@ constexpr const char* kSaveCompiledModel = "migx_save_compiled_model"; constexpr const char* kSaveModelPath = "migx_save_model_name"; constexpr const char* kLoadCompiledModel = "migx_load_compiled_model"; constexpr const char* kLoadModelPath = "migx_load_model_name"; +constexpr const char* kExhaustiveTune = "migx_exhaustive_tune"; } // namespace provider_option_names } // namespace migraphx @@ -45,6 +46,7 @@ MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions .AddAssignmentToReference(migraphx::provider_option_names::kInt8Enable, info.int8_enable) .AddAssignmentToReference(migraphx::provider_option_names::kSaveCompiledModel, info.save_compiled_model) .AddAssignmentToReference(migraphx::provider_option_names::kLoadCompiledModel, info.load_compiled_model) + .AddAssignmentToReference(migraphx::provider_option_names::kExhaustiveTune, info.exhaustive_tune) .Parse(options)); return info; @@ -57,6 +59,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXE {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)}, {migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.save_compiled_model)}, {migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.load_compiled_model)}, + {migraphx::provider_option_names::kExhaustiveTune, MakeStringWithClassicLocale(info.exhaustive_tune)}, }; return options; } @@ -68,6 +71,7 @@ ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGrap {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)}, {migraphx::provider_option_names::kSaveCompiledModel, MakeStringWithClassicLocale(info.migraphx_save_compiled_model)}, {migraphx::provider_option_names::kLoadCompiledModel, MakeStringWithClassicLocale(info.migraphx_load_compiled_model)}, + {migraphx::provider_option_names::kExhaustiveTune, MakeStringWithClassicLocale(info.migraphx_exhaustive_tune)}, }; return options; } diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h index 68d5d9af98ea4..b8bf86580f03d 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h @@ -23,6 +23,7 @@ struct MIGraphXExecutionProviderInfo { std::string save_model_file{"./compiled_model.mxr"}; bool load_compiled_model{true}; std::string load_model_file{"./compiled_model.mxr"}; + bool exhaustive_tune{false}; static MIGraphXExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); static ProviderOptions ToProviderOptions(const MIGraphXExecutionProviderInfo& info); diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc index 6d199930116e8..7b192b657b7cc 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc @@ -60,6 +60,7 @@ struct MIGraphX_Provider : Provider { info.device_id = static_cast(options.device_id); info.target_device = "gpu"; info.fp16_enable = options.migraphx_fp16_enable; + info.exhaustive_tune = options.migraphx_exhaustive_tune; info.int8_enable = options.migraphx_int8_enable; info.int8_calibration_table_name = ""; if (options.migraphx_int8_calibration_table_name != nullptr) { @@ -85,6 +86,7 @@ struct MIGraphX_Provider : Provider { migx_options.device_id = internal_options.device_id; migx_options.migraphx_fp16_enable = internal_options.fp16_enable; migx_options.migraphx_int8_enable = internal_options.int8_enable; + migx_options.migraphx_exhaustive_tune = internal_options.exhaustive_tune; char* dest = nullptr; auto str_size = internal_options.int8_calibration_table_name.size(); diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index ffcd339c0ca3a..47b8d75f22aea 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -844,7 +844,8 @@ std::unique_ptr CreateExecutionProviderInstance( 1, "./compiled_model.mxr", 1, - "./compiled_model.mxr"}; + "./compiled_model.mxr", + 1}; for (auto option : it->second) { if (option.first == "device_id") { if (!option.second.empty()) { @@ -929,6 +930,16 @@ std::unique_ptr CreateExecutionProviderInstance( "[ERROR] [MIGraphX] The value for the key 'migx_load_model_name' should be a " "file name i.e. 'compiled_model.mxr'.\n"); } + } else if (option.first == "migraphx_exhaustive_tune") { + if (option.second == "True" || option.second == "true") { + params.migraphx_exhaustive_tune = true; + } else if (option.second == "False" || option.second == "false") { + params.migraphx_exhaustive_tune = false; + } else { + ORT_THROW( + "[ERROR] [MIGraphX] The value for the key 'migraphx_exhaustive_tune' should be" + " 'True' or 'False'. Default value is 'False'.\n"); + } } else { ORT_THROW("Invalid MIGraphX EP option: ", option.first); } diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 312aa86277994..1feba20e32bbb 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -80,7 +80,8 @@ std::unique_ptr DefaultMIGraphXExecutionProvider() { 1, "./compiled_model.mxr", 1, - "./compiled_model.mxr"}; + "./compiled_model.mxr", + 1}; return MIGraphXProviderFactoryCreator::Create(¶ms)->CreateProvider(); #else return nullptr;