Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refine ep plugin c++ wrapper #23131

Merged
merged 1 commit into from
Dec 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions include/onnxruntime/core/session/onnxruntime_c_api_ep.h
Original file line number Diff line number Diff line change
Expand Up @@ -389,13 +389,20 @@ ORT_API2_STATUS(OrtGraph_ReleaseGraph, const OrtGraph* graph);
/** \brief Release the graph viewer instance.
*
* NOTE!!: Invoke this function after the use of OrtGraph_GetSubGraph. As OrtGraph_GetSubGraph allocates model instead of
* graph, this API releases graph's owning_model explicitly which in turn will release the graph
* (because graph is hosted in an unique_ptr in Model class)
* graph, release_model should be set to true so that this API explicitly releases the graph's owning_model, which in
* turn releases the graph (because the graph is held in a unique_ptr inside the Model class)
*
* \param[in] graph The graph to release
* \param[in] release_model If true, release the model as well, otherwise only release the graph viewer instance
*
*/
ORT_API2_STATUS(OrtGraph_ReleaseGraphViewer, const OrtGraphViewer* graph, bool release_model);

/** \brief Release the graph viewer array.
*
* NOTE!!: Invoke this function after the use of OrtNode_GetSubgraphs.
*/
ORT_API2_STATUS(OrtGraph_ReleaseGraphViewer, const OrtGraphViewer* graph);
ORT_API2_STATUS(OrtGraph_ReleaseGraphViewerArray, const OrtGraphViewer** graph_array, size_t num_graphs);

/** \brief Check whether two graphs actually point to the same graph.
*
Expand Down
50 changes: 39 additions & 11 deletions include/onnxruntime/core/session/onnxruntime_cxx_api_ep.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,19 @@
namespace Ort {
namespace PluginEP {

using VoidPtr = std::unique_ptr<void, std::function<void(void*)>>;

struct TensorRef {
explicit TensorRef(OrtTensorRef*);
~TensorRef();
const std::vector<int64_t> GetShape();
const ONNXTensorElementDataType GetTensorElementType();
const char* GetData();
size_t GetDataLen();
private:

Check warning on line 20 in include/onnxruntime/core/session/onnxruntime_cxx_api_ep.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 private: should be indented +1 space inside struct TensorRef [whitespace/indent] [3] Raw Output: include/onnxruntime/core/session/onnxruntime_cxx_api_ep.h:20: private: should be indented +1 space inside struct TensorRef [whitespace/indent] [3]
OrtTensorRef* tensor_;
};

struct ValueInfoRef {
explicit ValueInfoRef(OrtValueInfoRef*);
~ValueInfoRef();
Expand All @@ -18,8 +31,17 @@
};

struct Graph {
explicit Graph(const OrtGraphViewer*);
const OrtGraphViewer* GetGraph() { return graph_; }
explicit Graph(const OrtGraph*);
const OrtGraph* GetGraph() { return graph_; }
void DumpOnnxModel(const std::filesystem::path& onnx_model_path);
private:

Check warning on line 37 in include/onnxruntime/core/session/onnxruntime_cxx_api_ep.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 private: should be indented +1 space inside struct Graph [whitespace/indent] [3] Raw Output: include/onnxruntime/core/session/onnxruntime_cxx_api_ep.h:37: private: should be indented +1 space inside struct Graph [whitespace/indent] [3]
const OrtGraph* graph_;
};
using GraphPtr = std::unique_ptr<PluginEP::Graph, std::function<void(PluginEP::Graph*)>>;

struct GraphViewer {
explicit GraphViewer(const OrtGraphViewer*);
const OrtGraphViewer* GetGraphViewer() { return graph_; }
const char* GetName();
bool IsConstantInitializer(const char* name, bool check_outer_scope);
const std::vector<size_t> GetNodesIndexInTopologicalOrder(int execution_order);
Expand All @@ -37,17 +59,25 @@
size_t GetOutputSize();
std::string GetIthOutputName(size_t i);
int32_t GetIthOutputElemType(size_t i);
// std::shared_ptr<TensorRef> GetInitializerTensor(const char* initializer_name);
std::shared_ptr<TensorRef> GetInitializerTensor(const char* initializer_name);
std::shared_ptr<ValueInfoRef> GetValueInfo(const char* name);
// void SerializeToArray(void** data, size_t* data_size);
// void DumpOnnxModel(const std::filesystem::path& onnx_model_path);
// CreateOrUpdateEpCtxGraph();
std::shared_ptr<Graph> GetSubGraph(std::vector<size_t> node_indices);
// bool IsSameGraph(const Graph& other);
std::pair<VoidPtr, size_t> SerializeToArray();

Check warning on line 64 in include/onnxruntime/core/session/onnxruntime_cxx_api_ep.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <utility> for pair<> [build/include_what_you_use] [4] Raw Output: include/onnxruntime/core/session/onnxruntime_cxx_api_ep.h:64: Add #include <utility> for pair<> [build/include_what_you_use] [4]
GraphPtr CreateOrUpdateEpCtxGraph(const char* node_name,
const int64_t main_context,
const int64_t embed_mode,
const char* cache_path,
char* cache_data,
size_t size,
const char* const* extra_attr_keys,
const char* const* extra_attr_values,
size_t extra_attr_num);
GraphViewerPtr GetSubGraph(std::vector<size_t> node_indices);
bool IsSameGraph(GraphViewer& other);

private:
const OrtGraphViewer* graph_;
};
using GraphViewerPtr = std::unique_ptr<PluginEP::GraphViewer, std::function<void(PluginEP::GraphViewer*)>>;

Check warning on line 80 in include/onnxruntime/core/session/onnxruntime_cxx_api_ep.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <memory> for unique_ptr<> [build/include_what_you_use] [4] Raw Output: include/onnxruntime/core/session/onnxruntime_cxx_api_ep.h:80: Add #include <memory> for unique_ptr<> [build/include_what_you_use] [4]

struct Node {
explicit Node(const OrtNode*);
Expand All @@ -74,12 +104,10 @@
int64_t GetAttributeIthInt(std::string attribute_name, size_t i);
float GetAttributeIthFloat(std::string attribute_name, size_t i);
const std::string GetAttributeIthStr(std::string attribute_name, size_t i);
// GetAttributeIthStrWithSize
const std::string GetAttributeStr(std::string attribute_name);
// GetAttributeStrWithSize
int64_t GetAttributeInt(std::string attribute_name);
float GetAttributeFloat(std::string attribute_name);
// GetSubgraphs
// TODO: add GetSubgraphs wrapper here

Check warning on line 110 in include/onnxruntime/core/session/onnxruntime_cxx_api_ep.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2] Raw Output: include/onnxruntime/core/session/onnxruntime_cxx_api_ep.h:110: Missing username in TODO; it should look like "// TODO(my_username): Stuff." [readability/todo] [2]
private:
const OrtNode* node_;
};
Expand Down
131 changes: 99 additions & 32 deletions include/onnxruntime/core/session/onnxruntime_cxx_inline_ep.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,29 @@

// Graph API function table for ORT_API_VERSION, looked up once when this variable is initialized; the wrappers
// below make all of their graph API calls through it.
// NOTE(review): `static` at namespace scope in a header gives every including translation unit its own copy, and the
// initializer runs during static initialization — confirm GetApi() is usable that early.
static const OrtGraphApi* ort_graph_api = GetApi().GetGraphApi(ORT_API_VERSION);

inline TensorRef::TensorRef(OrtTensorRef* tensor) : tensor_(tensor) {}

// Hands the wrapped tensor reference back to the graph API for release.
inline TensorRef::~TensorRef() {
  OrtTensorRef* const held = tensor_;
  ort_graph_api->OrtGraph_ReleaseInitializerTensor(held);
}

// Copies the tensor's `shape_len` dimensions, starting at `shape`, into a new vector.
inline const std::vector<int64_t> TensorRef::GetShape() {
  const auto* dims_begin = tensor_->shape;
  const auto* dims_end = dims_begin + tensor_->shape_len;
  return std::vector<int64_t>(dims_begin, dims_end);
}

// Returns the `data_type` field of the wrapped tensor reference.
inline const ONNXTensorElementDataType TensorRef::GetTensorElementType() {
  const auto element_type = tensor_->data_type;
  return element_type;
}

// Returns the `data` pointer of the wrapped tensor reference; nothing is copied.
inline const char* TensorRef::GetData() {
  const auto* raw_bytes = tensor_->data;
  return raw_bytes;
}

// Returns the `data_len` field of the wrapped tensor reference.
inline size_t TensorRef::GetDataLen() {
  const auto length = tensor_->data_len;
  return length;
}

inline ValueInfoRef::ValueInfoRef(OrtValueInfoRef* value_info) : value_info_(value_info) {}

inline ValueInfoRef::~ValueInfoRef() {
Expand All @@ -23,46 +46,52 @@
return value_info_->data_type;
}

inline Graph::Graph(const OrtGraphViewer* graph) : graph_(graph) {}
inline Graph::Graph(const OrtGraph* graph) : graph_(graph) {}

// Asks the graph API to dump this graph as an ONNX model at `onnx_model_path`.
// Throws via ThrowOnError if the API call reports an error.
inline void Graph::DumpOnnxModel(const std::filesystem::path& onnx_model_path) {
  const auto* native_path = onnx_model_path.c_str();
  ThrowOnError(ort_graph_api->OrtGraph_DumpOnnxModel(graph_, native_path));
}

inline const char* Graph::GetName() {
inline GraphViewer::GraphViewer(const OrtGraphViewer* graph) : graph_(graph) {}

inline const char* GraphViewer::GetName() {
const char* graph_name = nullptr;
ThrowOnError(ort_graph_api->OrtGraph_GetName(graph_, &graph_name));
return graph_name;
}

inline bool Graph::IsConstantInitializer(const char* name, bool check_outer_scope) {
// Reports whether `name` is a constant initializer according to OrtGraph_IsConstantInitializer;
// `check_outer_scope` is forwarded to the API unchanged.
// Throws via ThrowOnError if the API call reports an error.
inline bool GraphViewer::IsConstantInitializer(const char* name, bool check_outer_scope) {
  bool result = false;
  auto* const status = ort_graph_api->OrtGraph_IsConstantInitializer(graph_, name, check_outer_scope, &result);
  ThrowOnError(status);
  return result;
}

inline const std::vector<size_t> Graph::GetNodesIndexInTopologicalOrder(int execution_order) {
inline const std::vector<size_t> GraphViewer::GetNodesIndexInTopologicalOrder(int execution_order) {
const size_t* nodes_index = nullptr;
size_t nodes_count = 0;
ThrowOnError(ort_graph_api->OrtGraph_GetNodesIndexInTopologicalOrder(graph_, execution_order, nodes_index, nodes_count));
ThrowOnError(ort_graph_api->OrtGraph_GetNodesIndexInTopologicalOrder(graph_, execution_order, &nodes_index, &nodes_count));

Check warning on line 72 in include/onnxruntime/core/session/onnxruntime_cxx_inline_ep.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: include/onnxruntime/core/session/onnxruntime_cxx_inline_ep.h:72: Lines should be <= 120 characters long [whitespace/line_length] [2]
return std::vector<size_t>(nodes_index, nodes_index + nodes_count);
}

inline bool Graph::IsSubgraph() {
// Returns the flag reported by OrtGraph_IsSubgraph for this graph.
// Throws via ThrowOnError if the API call reports an error.
inline bool GraphViewer::IsSubgraph() {
  bool result = false;
  auto* const status = ort_graph_api->OrtGraph_IsSubgraph(graph_, &result);
  ThrowOnError(status);
  return result;
}

inline std::shared_ptr<Node> Graph::GetParenNode() {
inline std::shared_ptr<PluginEP::Node> GraphViewer::GetParenNode() {
const OrtNode* parent_node = nullptr;
ThrowOnError(ort_graph_api->OrtGraph_GetParenNode(graph_, &parent_node));
return std::make_shared<Node>(parent_node);
return std::make_shared<PluginEP::Node>(parent_node);
}

inline std::filesystem::path Graph::GetModelPath() {
// Returns a copy of the model path reported by OrtGraph_GetModelPath.
// The API hands the path back as an opaque pointer, which is read here as a std::filesystem::path.
// Throws via ThrowOnError if the API call reports an error.
inline std::filesystem::path GraphViewer::GetModelPath() {
  const void* opaque_path = nullptr;
  ThrowOnError(ort_graph_api->OrtGraph_GetModelPath(graph_, &opaque_path));
  const auto& path_ref = *reinterpret_cast<const std::filesystem::path*>(opaque_path);
  return path_ref;
}

inline std::vector<std::string> Graph::GetRequiredInputs() {
inline std::vector<std::string> GraphViewer::GetRequiredInputs() {
const char** required_inputs = nullptr;
size_t required_inputs_count = 0;
ThrowOnError(ort_graph_api->OrtGraph_GetRequiredInputs(graph_, &required_inputs, &required_inputs_count));
Expand All @@ -78,7 +107,7 @@
return ret;
}

inline std::vector<std::string> Graph::GetAllInputs() {
inline std::vector<std::string> GraphViewer::GetAllInputs() {
const char** all_inputs = nullptr;
size_t all_inputs_count = 0;
ThrowOnError(ort_graph_api->OrtGraph_GetAllInputs(graph_, &all_inputs, &all_inputs_count));
Expand All @@ -94,7 +123,7 @@
return ret;
}

inline std::vector<std::string> Graph::GetAllInitializers() {
inline std::vector<std::string> GraphViewer::GetAllInitializers() {
const char** all_initializers = nullptr;
size_t all_initializers_count = 0;
ThrowOnError(ort_graph_api->OrtGraph_GetAllInitializers(graph_, &all_initializers, &all_initializers_count));
Expand All @@ -110,13 +139,13 @@
return ret;
}

inline Ort::PluginEP::Node Graph::GetOrtNode(size_t node_index) {
// Looks up the node at `node_index` through OrtGraph_GetOrtNode and wraps it.
// Throws via ThrowOnError if the API call reports an error.
inline Ort::PluginEP::Node GraphViewer::GetOrtNode(size_t node_index) {
  const OrtNode* raw_node = nullptr;
  auto* const status = ort_graph_api->OrtGraph_GetOrtNode(graph_, node_index, &raw_node);
  ThrowOnError(status);
  return Ort::PluginEP::Node(raw_node);
}

inline std::vector<Ort::PluginEP::Node> Graph::GetNodesConsumingInput(const char* input_name) {
inline std::vector<Ort::PluginEP::Node> GraphViewer::GetNodesConsumingInput(const char* input_name) {
const OrtNode** consumers = nullptr;
size_t consumer_count = 0;
ThrowOnError(ort_graph_api->OrtGraph_GetNodesConsumingInput(graph_, input_name, &consumers, &consumer_count));
Expand All @@ -129,64 +158,102 @@
return ret;
}

inline Ort::PluginEP::Node Graph::GetNodeProducingOutput(const char* output_name) {
// Looks up the node that produces `output_name` through OrtGraph_GetNodeProducingOutput and wraps it.
// Throws via ThrowOnError if the API call reports an error.
inline Ort::PluginEP::Node GraphViewer::GetNodeProducingOutput(const char* output_name) {
  const OrtNode* producer = nullptr;
  auto* const status = ort_graph_api->OrtGraph_GetNodeProducingOutput(graph_, output_name, &producer);
  ThrowOnError(status);
  return Ort::PluginEP::Node(producer);
}

inline int Graph::NumberOfNodes() {
// Returns the node count reported by OrtGraph_NumberOfNodes.
// Throws via ThrowOnError if the API call reports an error.
inline int GraphViewer::NumberOfNodes() {
  int count = 0;
  auto* const status = ort_graph_api->OrtGraph_NumberOfNodes(graph_, &count);
  ThrowOnError(status);
  return count;
}

inline int Graph::MaxNodeIndex() {
// Returns the value reported by OrtGraph_MaxNodeIndex.
// Throws via ThrowOnError if the API call reports an error.
inline int GraphViewer::MaxNodeIndex() {
  int highest_index = 0;
  auto* const status = ort_graph_api->OrtGraph_MaxNodeIndex(graph_, &highest_index);
  ThrowOnError(status);
  return highest_index;
}

inline size_t Graph::GetOutputSize() {
// Returns the output count reported by OrtGraph_GetOutputSize.
// Throws via ThrowOnError if the API call reports an error.
inline size_t GraphViewer::GetOutputSize() {
  size_t count = 0;
  auto* const status = ort_graph_api->OrtGraph_GetOutputSize(graph_, &count);
  ThrowOnError(status);
  return count;
}

inline std::string Graph::GetIthOutputName(size_t i) {
// Returns the name of the i-th graph output as reported by OrtGraph_GetIthOutputName.
// Throws via ThrowOnError if the API call reports an error.
// If the call succeeds but leaves the name pointer null, an empty string is returned.
inline std::string GraphViewer::GetIthOutputName(size_t i) {
  const char* output_name = nullptr;
  ThrowOnError(ort_graph_api->OrtGraph_GetIthOutputName(graph_, i, &output_name));
  // Constructing std::string from a null pointer is undefined behaviour, so guard before converting.
  return output_name != nullptr ? std::string(output_name) : std::string();
}

inline int32_t Graph::GetIthOutputElemType(size_t i) {
// Returns the element type of the i-th graph output as reported by OrtGraph_GetIthOutputElemType.
// Throws via ThrowOnError if the API call reports an error.
inline int32_t GraphViewer::GetIthOutputElemType(size_t i) {
  int32_t element_type = 0;
  auto* const status = ort_graph_api->OrtGraph_GetIthOutputElemType(graph_, i, &element_type);
  ThrowOnError(status);
  return element_type;
}

inline std::shared_ptr<ValueInfoRef> Graph::GetValueInfo(const char* name) {
inline std::shared_ptr<TensorRef> GraphViewer::GetInitializerTensor(const char* initializer_name) {
OrtTensorRef* tensor = nullptr;
ThrowOnError(ort_graph_api->OrtGraph_GetInitializerTensor(graph_, initializer_name, &tensor));
return std::make_shared<TensorRef>(tensor);
}

inline std::shared_ptr<ValueInfoRef> GraphViewer::GetValueInfo(const char* name) {
OrtValueInfoRef* value_info = nullptr;
ThrowOnError(ort_graph_api->OrtGraph_GetValueInfo(graph_, name, &value_info));
return std::make_shared<ValueInfoRef>(value_info);
}

// inline void Graph::DumpOnnxModel(const std::filesystem::path& onnx_model_path) {
// ThrowOnError(ort_graph_api->OrtGraph_DumpOnnxModel(graph_->GetGraph(), onnx_model_path.c_str()));
// }
// Serializes the graph through OrtGraph_SerializeToArray.
// Returns the serialized buffer together with its length. The buffer is held by a VoidPtr whose deleter frees it
// with the graph API's OrtFreeMem.
// Throws via ThrowOnError if the API call reports an error.
inline std::pair<VoidPtr, size_t> GraphViewer::SerializeToArray() {
  void* serialized_data = nullptr;
  size_t serialized_data_len = 0;
  ThrowOnError(ort_graph_api->OrtGraph_SerializeToArray(graph_, &serialized_data, &serialized_data_len));
  // Kept as a separate statement so the return line stays within the 120-column limit.
  const auto free_buffer = [](void* ptr) { ort_graph_api->OrtFreeMem(ptr); };
  return std::make_pair(VoidPtr(serialized_data, free_buffer), serialized_data_len);
}

// Creates or updates the EP context graph for this graph viewer.
// Every argument is forwarded unchanged to OrtGraph_CreateOrUpdateEpCtxGraph.
// Returns a GraphPtr that owns the resulting OrtGraph; destroying it calls OrtGraph_ReleaseGraph and frees the
// C++ wrapper object.
// Throws via ThrowOnError if the create/update call reports an error.
inline GraphPtr GraphViewer::CreateOrUpdateEpCtxGraph(const char* node_name,
                                                      const int64_t main_context,
                                                      const int64_t embed_mode,
                                                      const char* cache_path,
                                                      char* cache_data,
                                                      size_t size,
                                                      const char* const* extra_attr_keys,
                                                      const char* const* extra_attr_values,
                                                      size_t extra_attr_num) {
  OrtGraph* graph = nullptr;
  ThrowOnError(ort_graph_api->OrtGraph_CreateOrUpdateEpCtxGraph(graph_,
                                                               node_name,
                                                               main_context,
                                                               embed_mode,
                                                               cache_path,
                                                               cache_data,
                                                               size,
                                                               extra_attr_keys,
                                                               extra_attr_values,
                                                               extra_attr_num,
                                                               &graph));
  auto release_fn = [](Graph* wrapper) {
    // The deleter runs while a unique_ptr is being destroyed, so it must not throw: a failed release only has
    // its status object freed instead of being passed to ThrowOnError.
    if (OrtStatus* status = ort_graph_api->OrtGraph_ReleaseGraph(wrapper->GetGraph())) {
      GetApi().ReleaseStatus(status);
    }
    // A custom deleter replaces the default `delete`, so the wrapper allocated below has to be freed here.
    delete wrapper;
  };
  return GraphPtr(new Graph(graph), release_fn);
}

inline std::shared_ptr<Graph> Graph::GetSubGraph(std::vector<size_t> node_indices) {
// Builds a sub-graph viewer containing the nodes in `node_indices` through OrtGraph_GetSubGraph.
// Returns a GraphViewerPtr that owns the sub-graph; destroying it calls OrtGraph_ReleaseGraphViewer with
// release_model = true and frees the C++ wrapper object.
// Throws via ThrowOnError if the sub-graph call reports an error.
inline GraphViewerPtr GraphViewer::GetSubGraph(std::vector<size_t> node_indices) {
  const OrtGraphViewer* subgraph = nullptr;
  ThrowOnError(ort_graph_api->OrtGraph_GetSubGraph(graph_, node_indices.size(), node_indices.data(), &subgraph));
  auto release_fn = [](GraphViewer* wrapper) {
    // The deleter runs while a unique_ptr is being destroyed, so it must not throw: a failed release only has
    // its status object freed instead of being passed to ThrowOnError.
    if (OrtStatus* status = ort_graph_api->OrtGraph_ReleaseGraphViewer(wrapper->GetGraphViewer(), true)) {
      GetApi().ReleaseStatus(status);
    }
    // A custom deleter replaces the default `delete`, so the wrapper allocated below has to be freed here.
    delete wrapper;
  };
  return GraphViewerPtr(new GraphViewer(subgraph), release_fn);
}

// inline bool Graph::IsSameGraph(const Graph& other) {
// bool is_same = false;
// ThrowOnError(ort_graph_api->OrtGraph_IsSameGraph(graph_, other.GetGraph(), &is_same));
// return is_same;
// }
// Returns the result of OrtGraph_IsSameGraph for this viewer and `other`.
// Throws via ThrowOnError if the API call reports an error.
inline bool GraphViewer::IsSameGraph(GraphViewer& other) {
  bool same = false;
  const OrtGraphViewer* const rhs = other.GetGraphViewer();
  ThrowOnError(ort_graph_api->OrtGraph_IsSameGraph(graph_, rhs, &same));
  return same;
}

inline Node::Node(const OrtNode* node) : node_(node) {}

Expand Down
15 changes: 13 additions & 2 deletions onnxruntime/core/session/onnxruntime_c_api_ep.cc
Original file line number Diff line number Diff line change
Expand Up @@ -775,15 +775,25 @@
return nullptr;
}

ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_ReleaseGraphViewer, const OrtGraphViewer* graph) {
// Releases a graph viewer instance. A null `graph` is a no-op.
// When `release_model` is true, the model that owns the viewed graph is deleted as well; otherwise only the
// viewer object itself is deleted. Always returns nullptr (there is no failure path).
ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_ReleaseGraphViewer, const OrtGraphViewer* graph, bool release_model) {
  if (graph) {
    const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast<const ::onnxruntime::GraphViewer*>(graph);
    // The owning model is deleted only on request and exactly once; an additional unconditional delete here
    // would free it twice when release_model is true and ignore the flag when it is false.
    if (release_model) {
      delete &(graph_viewer->GetGraph()).GetModel();
    }
    delete graph_viewer;
  }
  return nullptr;
}
}

// Releases every graph viewer in `graph_viewers` (without releasing the owning models) and then frees the array
// itself. Always returns nullptr.
ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_ReleaseGraphViewerArray, const OrtGraphViewer** graph_viewers, size_t num_graphs) {
  for (size_t i = 0; i < num_graphs; i++) {
    // OrtGraph_ReleaseGraphViewer has no failure path and always returns nullptr, so the status is discarded
    // explicitly rather than silently.
    static_cast<void>(OrtGraph_ReleaseGraphViewer(graph_viewers[i], false));
  }
  delete[] graph_viewers;
  return nullptr;
}

ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_IsSameGraph, const OrtGraphViewer* graph1, const OrtGraphViewer* graph2, bool* is_same) {
const ::onnxruntime::GraphViewer* graph_viewer1 = reinterpret_cast<const ::onnxruntime::GraphViewer*>(graph1);
const ::onnxruntime::GraphViewer* graph_viewer2 = reinterpret_cast<const ::onnxruntime::GraphViewer*>(graph2);
Expand Down Expand Up @@ -1021,6 +1031,7 @@
&OrtGraphApis::OrtGraph_GetSubGraph,
&OrtGraphApis::OrtGraph_ReleaseGraph,
&OrtGraphApis::OrtGraph_ReleaseGraphViewer,
&OrtGraphApis::OrtGraph_ReleaseGraphViewerArray,
&OrtGraphApis::OrtGraph_IsSameGraph,
&OrtGraphApis::OrtNode_GetName,
&OrtGraphApis::OrtNode_GetDescription,
Expand Down
Loading
Loading