Skip to content

Commit

Permalink
Lintrunner pass
Browse files Browse the repository at this point in the history
  • Loading branch information
TedThemistokleous committed Oct 23, 2023
1 parent 487a27a commit 1524da4
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 27 deletions.
11 changes: 5 additions & 6 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -598,12 +598,11 @@ typedef struct OrtTensorRTProviderOptions {
* \see OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX
*/
typedef struct OrtMIGraphXProviderOptions {

int device_id; // hip device id.
int migraphx_fp16_enable; // enable MIGraphX FP16 precision. Default 0 = false, nonzero = true
int migraphx_int8_enable; // enable MIGraphX INT8 precision. Default 0 = false, nonzero = true
int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, noznero = true
const char* migraphx_int8_calibration_table_name; // MIGraphx INT8 calibration table name
int device_id; // hip device id.
int migraphx_fp16_enable; // enable MIGraphX FP16 precision. Default 0 = false, nonzero = true
int migraphx_int8_enable; // enable MIGraphX INT8 precision. Default 0 = false, nonzero = true
int migraphx_use_native_calibration_table; // MIGraphx INT8 cal table. Default 0 = false, noznero = true
const char* migraphx_int8_calibration_table_name; // MIGraphx INT8 calibration table name
} OrtMIGraphXProviderOptions;

/** \brief OpenVINO Provider Options
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,7 @@ std::shared_ptr<KernelRegistry> MIGraphXExecutionProvider::GetKernelRegistry() c
}

MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info)
: IExecutionProvider{onnxruntime::kMIGraphXExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT,
info.device_id), true}, device_id_(info.device_id) {
: IExecutionProvider{onnxruntime::kMIGraphXExecutionProvider, OrtDevice(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, info.device_id), true}, device_id_(info.device_id) {
InitProviderOrtApi();
// Set GPU device to be used
HIP_CALL_THROW(hipSetDevice(device_id_));
Expand All @@ -123,7 +122,7 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv

if (int8_enable_) {
const std::string int8_calibration_cache_name_env =
onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8CalibrationTableName);
onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8CalibrationTableName);
if (!int8_calibration_cache_name_env.empty()) {
int8_calibration_cache_name_ = int8_calibration_cache_name_env;
}
Expand All @@ -134,10 +133,10 @@ MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProv
}

const std::string int8_use_native_migraphx_calibration_table_env =
onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8UseNativeMIGraphXCalibrationTable);
onnxruntime::GetEnvironmentVar(migraphx_env_vars::kINT8UseNativeMIGraphXCalibrationTable);
if (!int8_use_native_migraphx_calibration_table_env.empty()) {
int8_use_native_migraphx_calibration_table_ =
(std::stoi(int8_use_native_migraphx_calibration_table_env) == 0 ? false : true);
(std::stoi(int8_use_native_migraphx_calibration_table_env) == 0 ? false : true);
}
}

Expand Down Expand Up @@ -934,7 +933,8 @@ GetUnsupportedNodeIndices(const GraphViewer& graph_viewer,
&graph_viewer](const onnxruntime::NodeArg& node_arg, bool is_input) {
if(is_input && graph_viewer.GetAllInitializedTensors().count(node_arg.Name())) {
mgx_required_initializers.insert(node_arg.Name());
} }, true);
} },
true);
} else {
unsupported_nodes_idx.push_back(node_idx);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ float ConvertSinglePrecisionIEEE754ToFloat(uint64_t input) {
bool ReadDynamicRange(const std::string file_name,
const bool is_calibration_table,
std::unordered_map<std::string,
float>& dynamic_range_map) {
float>& dynamic_range_map) {
std::ifstream infile(file_name, std::ios::binary | std::ios::in);
if (!infile) {
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ struct MIGraphX_Provider : Provider {
info.target_device = "gpu";
info.fp16_enable = options.migraphx_fp16_enable;
info.int8_enable = options.migraphx_int8_enable;
info.int8_calibration_table_name = options.migraphx_int8_calibration_table_name == nullptr ?
"" : options.migraphx_int8_calibration_table_name;
info.int8_calibration_table_name = options.migraphx_int8_calibration_table_name == nullptr ? "" : options.migraphx_int8_calibration_table_name;
info.int8_use_native_calibration_table = options.migraphx_use_native_calibration_table != 0;
return std::make_shared<MIGraphXProviderFactory>(info);
}
Expand Down
28 changes: 16 additions & 12 deletions onnxruntime/python/onnxruntime_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -722,10 +722,10 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
#endif
} else if (type == kMIGraphXExecutionProvider) {
#ifdef USE_MIGRAPHX
std::string calibration_table;
auto it = provider_options_map.find(type);
if (it != provider_options_map.end()) {
OrtMIGraphXProviderOptions params{
std::string calibration_table;
auto it = provider_options_map.find(type);
if (it != provider_options_map.end()) {
OrtMIGraphXProviderOptions params{
0,
0,
0,
Expand All @@ -744,7 +744,8 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
} else if (option.second == "False" || option.second == "false") {
params.migraphx_fp16_enable = false;
} else {
ORT_THROW("[ERROR] [MIGraphX] The value for the key 'trt_fp16_enable' should be \
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'trt_fp16_enable' should be \
'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "migraphx_int8_enable") {
Expand All @@ -753,15 +754,17 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
} else if (option.second == "False" || option.second == "false") {
params.migraphx_int8_enable = false;
} else {
ORT_THROW("[ERROR] [MIGraphX] The value for the key 'migx_int8_enable' should be \
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migx_int8_enable' should be \
'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "migraphx_int8_calibration_table_name") {
if (!option.second.empty()) {
calibration_table = option.second;
params.migraphx_int8_calibration_table_name = calibration_table.c_str();
} else {
ORT_THROW("[ERROR] [MIGraphX] The value for the key 'migx_int8_calibration_table_name' should be a \
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migx_int8_calibration_table_name' should be a \
file name i.e. 'cal_table'.\n");
}
} else if (option.first == "migraphx_use_native_calibration_table") {
Expand All @@ -770,20 +773,21 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
} else if (option.second == "False" || option.second == "false") {
params.migraphx_use_native_calibration_table = false;
} else {
ORT_THROW("[ERROR] [MIGraphX] The value for the key 'migx_int8_use_native_calibration_table' should be \
ORT_THROW(
"[ERROR] [MIGraphX] The value for the key 'migx_int8_use_native_calibration_table' should be \
'True' or 'False'. Default value is 'False'.\n");
}
} else {
ORT_THROW("Invalid MIGraphX EP option: ", option.first);
}
}
if (std::shared_ptr<IExecutionProviderFactory> migraphx_provider_factory =
onnxruntime::MIGraphXProviderFactoryCreator::Create(&params)) {
onnxruntime::MIGraphXProviderFactoryCreator::Create(&params)) {
return migraphx_provider_factory->CreateProvider();
}
} else {
if (std::shared_ptr<IExecutionProviderFactory> migraphx_provider_factory =
onnxruntime::MIGraphXProviderFactoryCreator::Create(cuda_device_id)) {
onnxruntime::MIGraphXProviderFactoryCreator::Create(cuda_device_id)) {
return migraphx_provider_factory->CreateProvider();
}
}
Expand Down Expand Up @@ -816,8 +820,8 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
LOGS_DEFAULT(WARNING) << "Failed to create "
<< type
<< ". Please reference
https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#requirements
to ensure all dependencies are met.";
https : // onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html#requirements
to ensure all dependencies are met.";
#endif
} else if (type == kRocmExecutionProvider) {
#ifdef USE_ROCM
Expand Down

0 comments on commit 1524da4

Please sign in to comment.