Skip to content

Commit

Permalink
Fixing linting/formatting and more issues.
Browse files Browse the repository at this point in the history
  • Loading branch information
TedThemistokleous committed Oct 20, 2023
1 parent ea8fa7d commit 35225d1
Show file tree
Hide file tree
Showing 8 changed files with 202 additions and 51 deletions.
1 change: 1 addition & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,7 @@ typedef struct OrtTensorRTProviderOptions {
* \see OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX
*/
typedef struct OrtMIGraphXProviderOptions {

int device_id; // hip device id.
int migraphx_fp16_enable; // enable MIGraphX FP16 precision. Default 0 = false, nonzero = true
int migraphx_int8_enable; // enable MIGraphX INT8 precision. Default 0 = false, nonzero = true
Expand Down
191 changes: 162 additions & 29 deletions onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "migraphx_execution_provider_info.h"

#include <map>
#include <unordered_map>
#include "migraphx_inc.h"
// TODO: find a better way to share this
// #include "core/providers/cuda/rocm_stream_handle.h"
Expand All @@ -18,12 +19,12 @@
namespace onnxruntime {

namespace migraphx_env_vars {
static const std::string kFP16Enable = "ORT_MIGRAPHX_FP16_ENABLE";
static const std::string kINT8Enable = "ORT_MIGRAPHX_INT8_ENABLE";
static const std::string dumpModelOps = "ORT_MIGRAPHX_DUMP_MODEL_OPS";
static const std::string kINT8CalibrationTableName = "ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME";
static const std::string kCachePath = "ORT_MIGRAPHX_CACHE_PATH";
static const std::string kINT8UseNativeMIGraphXCalibrationTable = "ORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE";
static const char kFP16Enable[] = "ORT_MIGRAPHX_FP16_ENABLE";
static const char kINT8Enable[] = "ORT_MIGRAPHX_INT8_ENABLE";
static const char string dumpModelOps[] = "ORT_MIGRAPHX_DUMP_MODEL_OPS";
static const char kINT8CalibrationTableName[] = "ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME";
static const char kCachePath[] = "ORT_MIGRAPHX_CACHE_PATH";
static const char kINT8UseNativeMIGraphXCalibrationTable[] = "ORT_MIGRAPHX_INT8_USE_NATIVE_CALIBRATION_TABLE";
}; // namespace migraphx_env_vars

// Information to construct kernel function state.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#pragma once

#include <limits>
#include <string>

#include "core/framework/ortdevice.h"
#include "core/framework/provider_options.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
#include <string>
#include <iostream>
#include <filesystem>
#include <memory>
#include "flatbuffers/idl.h"
#include "ort_trt_int8_cal_table.fbs.h"
#include "core/providers/migraphx/ort_trt_int8_cal_table.fbs.h"
#include "core/session/onnxruntime_cxx_api.h"
#include "core/framework/execution_provider.h"
#include "core/common/path_string.h"
Expand Down Expand Up @@ -112,7 +113,10 @@ bool canEvalShapeGeneral(const GraphViewer& graph, const Node* node, std::vector
return true;
}

bool canEvalNodeArgument(const GraphViewer& graph, const Node* node, std::vector<std::size_t> indices, std::vector<NodeIndex>& input_nodes) {
bool canEvalNodeArgument(const GraphViewer& graph,
const Node* node,
std::vector<std::size_t> indices,
std::vector<NodeIndex>& input_nodes) {
input_nodes.clear();
std::vector<const Node*> in_nodes;
for (auto nit = node->InputNodesBegin(); nit != node->InputNodesEnd(); ++nit) {
Expand Down Expand Up @@ -148,7 +152,7 @@ bool canEvalNodeArgument(const GraphViewer& graph, const Node* node, std::vector
return true;
}

float ConvertSinglePrecisionIEEE754ToFloat(unsigned long input) {
float ConvertSinglePrecisionIEEE754ToFloat(uint64_t input) {
int s = (input >> 31) & 0x01;
int e = ((input & 0x7f800000) >> 23) - 127;
int p = -1;
Expand Down Expand Up @@ -180,7 +184,10 @@ float ConvertSinglePrecisionIEEE754ToFloat(unsigned long input) {
* Taken from the tensorRT EP to allow MIGraphX EP to reuse calibration tables for existing models
*
*/
bool ReadDynamicRange(const std::string file_name, const bool is_calibration_table, std::unordered_map<std::string, float>& dynamic_range_map) {
bool ReadDynamicRange(const std::string file_name,
const bool is_calibration_table,
std::unordered_map<std::string,
float>& dynamic_range_map) {
std::ifstream infile(file_name, std::ios::binary | std::ios::in);
if (!infile) {
return false;
Expand All @@ -202,7 +209,7 @@ bool ReadDynamicRange(const std::string file_name, const bool is_calibration_tab
std::getline(in_line, str, delim);
std::string tensor_name = str;
std::getline(in_line, str, delim);
unsigned long scale_int = std::strtoul(str.c_str(), nullptr, 16);
uint64_t scale_int = std::strtoul(str.c_str(), nullptr, 16);
float scale_float = ConvertSinglePrecisionIEEE754ToFloat(scale_int);
float dynamic_range = scale_float * 127.0f;
dynamic_range_map[tensor_name] = dynamic_range;
Expand All @@ -219,7 +226,7 @@ bool ReadDynamicRange(const std::string file_name, const bool is_calibration_tab
std::unique_ptr<char[]> data{new char[length]};
infile.read((char*)data.get(), length);
infile.close();
auto flat_table = flatbuffers::GetRoot<CalTableFlatBuffers::TrtTable>((const uint8_t*)data.get());
auto flat_table = flatbuffers::GetRoot<CalTableFlatBuffers::TrtTable>(reinterpret_cast<char*>(data.get()));
auto flat_dict = flat_table->dict();
for (size_t i = 0, end = flat_dict->size(); i < end; ++i) {
flatbuffers::uoffset_t idx = static_cast<flatbuffers::uoffset_t>(i);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ struct MIGraphX_Provider : Provider {
info.target_device = "gpu";
info.fp16_enable = options.migraphx_fp16_enable;
info.int8_enable = options.migraphx_int8_enable;
info.int8_calibration_table_name = options.migraphx_int8_calibration_table_name == nullptr ? "" : options.migraphx_int8_calibration_table_name;
info.int8_calibration_table_name = options.migraphx_int8_calibration_table_name == nullptr ?
"" : options.migraphx_int8_calibration_table_name;
info.int8_use_native_calibration_table = options.migraphx_use_native_calibration_table != 0;
return std::make_shared<MIGraphXProviderFactory>(info);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
// automatically generated by the FlatBuffers compiler, do not modify

#ifndef FLATBUFFERS_GENERATED_ORTTRTINT8CALTABLE_CALTABLEFLATBUFFERS_H_
#define FLATBUFFERS_GENERATED_ORTTRTINT8CALTABLE_CALTABLEFLATBUFFERS_H_
#ifndef ONNXRUNTIME_CORE_PROVIDERS_MIGRAPHX_ORTTRTINT8CALTABLE_CALTABLEFLATBUFFERS_H_
#define ONNXRUNTIME_CORE_PROVIDERS_MIGRAPHX_ORTTRTINT8CALTABLE_CALTABLEFLATBUFFERS_H_

#include "flatbuffers/flatbuffers.h"
#include <vector>

namespace CalTableFlatBuffers {

Expand Down Expand Up @@ -141,4 +142,4 @@ inline flatbuffers::Offset<TrtTable> CreateTrtTableDirect(

} // namespace CalTableFlatBuffers

#endif // FLATBUFFERS_GENERATED_ORTTRTINT8CALTABLE_CALTABLEFLATBUFFERS_H_
#endif // ONNXRUNTIME_CORE_PROVIDERS_MIGRAPHX_ORTTRTINT8CALTABLE_CALTABLEFLATBUFFERS_H_
18 changes: 12 additions & 6 deletions onnxruntime/python/onnxruntime_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -744,40 +744,46 @@ std::unique_ptr<IExecutionProvider> CreateExecutionProviderInstance(
} else if (option.second == "False" || option.second == "false") {
params.migraphx_fp16_enable = false;
} else {
ORT_THROW("[ERROR] [MIGraphX] The value for the key 'trt_fp16_enable' should be 'True' or 'False'. Default value is 'False'.\n");
ORT_THROW("[ERROR] [MIGraphX] The value for the key 'trt_fp16_enable' should be
'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "migraphx_int8_enable") {
if (option.second == "True" || option.second == "true") {
params.migraphx_int8_enable = true;
} else if (option.second == "False" || option.second == "false") {
params.migraphx_int8_enable = false;
} else {
ORT_THROW("[ERROR] [MIGraphX] The value for the key 'migx_int8_enable' should be 'True' or 'False'. Default value is 'False'.\n");
ORT_THROW("[ERROR] [MIGraphX] The value for the key 'migx_int8_enable' should be
'True' or 'False'. Default value is 'False'.\n");
}
} else if (option.first == "migraphx_int8_calibration_table_name") {
if (!option.second.empty()) {
calibration_table = option.second;
params.migraphx_int8_calibration_table_name = calibration_table.c_str();
} else {
ORT_THROW("[ERROR] [MIGraphX] The value for the key 'migx_int8_calibration_table_name' should be a file name i.e. 'cal_table'.\n");
ORT_THROW("[ERROR] [MIGraphX] The value for the key 'migx_int8_calibration_table_name' should be a
file name i.e. 'cal_table'.\n");
}
} else if (option.first == "migraphx_use_native_calibration_table") {
if (option.second == "True" || option.second == "true") {
params.migraphx_use_native_calibration_table = true;
} else if (option.second == "False" || option.second == "false") {
params.migraphx_use_native_calibration_table = false;
} else {
ORT_THROW("[ERROR] [MIGraphX] The value for the key 'migx_int8_use_native_calibration_table' should be 'True' or 'False'. Default value is 'False'.\n");
ORT_THROW("[ERROR] [MIGraphX] The value for the key 'migx_int8_use_native_calibration_table' should be
'True' or 'False'. Default value is 'False'.\n");
}
} else {
ORT_THROW("Invalid MIGraphX EP option: ", option.first);
}
}
if (std::shared_ptr<IExecutionProviderFactory> migraphx_provider_factory = onnxruntime::MIGraphXProviderFactoryCreator::Create(&params)) {
if (std::shared_ptr<IExecutionProviderFactory> migraphx_provider_factory =
onnxruntime::MIGraphXProviderFactoryCreator::Create(&params)) {
return migraphx_provider_factory->CreateProvider();
}
} else {
if (std::shared_ptr<IExecutionProviderFactory> migraphx_provider_factory = onnxruntime::MIGraphXProviderFactoryCreator::Create(cuda_device_id)) {
if (std::shared_ptr<IExecutionProviderFactory> migraphx_provider_factory =
onnxruntime::MIGraphXProviderFactoryCreator::Create(cuda_device_id)) {
return migraphx_provider_factory->CreateProvider();
}
}
Expand Down

0 comments on commit 35225d1

Please sign in to comment.