From a6ebeebc0eff4301ad99efd06d1cfae57ac2795a Mon Sep 17 00:00:00 2001 From: guyang3532 Date: Tue, 17 Dec 2024 08:06:57 +0000 Subject: [PATCH] refine ep plugin c++ wrapper --- .../core/session/onnxruntime_c_api_ep.h | 13 +- .../core/session/onnxruntime_cxx_api_ep.h | 50 +++++-- .../core/session/onnxruntime_cxx_inline_ep.h | 131 +++++++++++++----- .../core/session/onnxruntime_c_api_ep.cc | 15 +- onnxruntime/core/session/ort_apis_ep.h | 4 +- .../tensorRTEp/tensorrt_execution_provider.cc | 4 +- 6 files changed, 167 insertions(+), 50 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api_ep.h b/include/onnxruntime/core/session/onnxruntime_c_api_ep.h index 77da384ac6671..8a59eaebfd210 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api_ep.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api_ep.h @@ -389,13 +389,20 @@ ORT_API2_STATUS(OrtGraph_ReleaseGraph, const OrtGraph* graph); /** \brief Release the graph viewer instance. * * NOTE!!: Invoke this function after the use of OrtGraph_GetSubGraph. As OrtGraph_GetSubGraph allocates model instead of - * graph, this API releases graph's owning_model explicitly which in turn will release the graph - * (because graph is hosted in an unique_ptr in Model class) + * graph, release_model should be set to true so that this API releases the graph's owning_model explicitly, which in turn will + * release the graph (because the graph is hosted in a unique_ptr in the Model class) * * \param[in] graph The graph to release + * \param[in] release_model If true, release the model as well; otherwise only release the graph viewer instance + * + */ +ORT_API2_STATUS(OrtGraph_ReleaseGraphViewer, const OrtGraphViewer* graph, bool release_model); + +/** \brief Release the graph viewer array. * + * NOTE!!: Invoke this function after the use of OrtNode_GetSubgraphs. 
*/ -ORT_API2_STATUS(OrtGraph_ReleaseGraphViewer, const OrtGraphViewer* graph); +ORT_API2_STATUS(OrtGraph_ReleaseGraphViewerArray, const OrtGraphViewer** graph_array, size_t num_graphs); /** \brief Check are two graph actually pointing to the same graph. * diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api_ep.h b/include/onnxruntime/core/session/onnxruntime_cxx_api_ep.h index 335eb88a060b4..7d206d16dbb3d 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api_ep.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api_ep.h @@ -8,6 +8,19 @@ namespace Ort { namespace PluginEP { +using VoidPtr = std::unique_ptr>; + +struct TensorRef { +explicit TensorRef(OrtTensorRef*); +~TensorRef(); +const std::vector GetShape(); +const ONNXTensorElementDataType GetTensorElementType(); +const char* GetData(); +size_t GetDataLen(); +private: +OrtTensorRef* tensor_; +}; + struct ValueInfoRef { explicit ValueInfoRef(OrtValueInfoRef*); ~ValueInfoRef(); @@ -18,8 +31,17 @@ OrtValueInfoRef* value_info_; }; struct Graph { -explicit Graph(const OrtGraphViewer*); -const OrtGraphViewer* GetGraph() { return graph_; } +explicit Graph(const OrtGraph*); +const OrtGraph* GetGraph() { return graph_; } +void DumpOnnxModel(const std::filesystem::path& onnx_model_path); +private: +const OrtGraph* graph_; +}; +using GraphPtr = std::unique_ptr>; + +struct GraphViewer { +explicit GraphViewer(const OrtGraphViewer*); +const OrtGraphViewer* GetGraphViewer() { return graph_; } const char* GetName(); bool IsConstantInitializer(const char* name, bool check_outer_scope); const std::vector GetNodesIndexInTopologicalOrder(int execution_order); @@ -37,17 +59,25 @@ int MaxNodeIndex(); size_t GetOutputSize(); std::string GetIthOutputName(size_t i); int32_t GetIthOutputElemType(size_t i); -// std::shared_ptr GetInitializerTensor(const char* initializer_name); +std::shared_ptr GetInitializerTensor(const char* initializer_name); std::shared_ptr GetValueInfo(const char* name); -// void 
SerializeToArray(void** data, size_t* data_size); -// void DumpOnnxModel(const std::filesystem::path& onnx_model_path); -// CreateOrUpdateEpCtxGraph(); -std::shared_ptr GetSubGraph(std::vector node_indices); -// bool IsSameGraph(const Graph& other); +std::pair SerializeToArray(); +GraphPtr CreateOrUpdateEpCtxGraph(const char* node_name, + const int64_t main_context, + const int64_t embed_mode, + const char* cache_path, + char* cache_data, + size_t size, + const char* const* extra_attr_keys, + const char* const* extra_attr_values, + size_t extra_attr_num); +GraphViewerPtr GetSubGraph(std::vector node_indices); +bool IsSameGraph(GraphViewer& other); private: const OrtGraphViewer* graph_; }; +using GraphViewerPtr = std::unique_ptr>; struct Node { explicit Node(const OrtNode*); @@ -74,12 +104,10 @@ int GetAttributeStringSize(std::string attribute_name); int64_t GetAttributeIthInt(std::string attribute_name, size_t i); float GetAttributeIthFloat(std::string attribute_name, size_t i); const std::string GetAttributeIthStr(std::string attribute_name, size_t i); -// GetAttributeIthStrWithSize const std::string GetAttributeStr(std::string attribute_name); -// GetAttributeStrWithSize int64_t GetAttributeInt(std::string attribute_name); float GetAttributeFloat(std::string attribute_name); -// GetSubgraphs +// TODO: add GetSubgraphs wrapper here private: const OrtNode* node_; }; diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline_ep.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline_ep.h index 21f2834fb20d6..8c5e9b7978ca9 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline_ep.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline_ep.h @@ -8,6 +8,29 @@ namespace PluginEP { static const OrtGraphApi* ort_graph_api = GetApi().GetGraphApi(ORT_API_VERSION); +inline TensorRef::TensorRef(OrtTensorRef* tensor) : tensor_(tensor) {} + +inline TensorRef::~TensorRef() { + ort_graph_api->OrtGraph_ReleaseInitializerTensor(tensor_); +} + 
+inline const std::vector TensorRef::GetShape() { + std::vector shape(tensor_->shape, tensor_->shape + tensor_->shape_len); + return shape; +} + +inline const ONNXTensorElementDataType TensorRef::GetTensorElementType() { + return tensor_->data_type; +} + +inline const char* TensorRef::GetData() { + return tensor_->data; +} + +inline size_t TensorRef::GetDataLen() { + return tensor_->data_len; +} + inline ValueInfoRef::ValueInfoRef(OrtValueInfoRef* value_info) : value_info_(value_info) {} inline ValueInfoRef::~ValueInfoRef() { @@ -23,46 +46,52 @@ inline const ONNXTensorElementDataType ValueInfoRef::GetTensorElementType() { return value_info_->data_type; } -inline Graph::Graph(const OrtGraphViewer* graph) : graph_(graph) {} +inline Graph::Graph(const OrtGraph* graph) : graph_(graph) {} + +inline void Graph::DumpOnnxModel(const std::filesystem::path& onnx_model_path) { + ThrowOnError(ort_graph_api->OrtGraph_DumpOnnxModel(graph_, onnx_model_path.c_str())); +} -inline const char* Graph::GetName() { +inline GraphViewer::GraphViewer(const OrtGraphViewer* graph) : graph_(graph) {} + +inline const char* GraphViewer::GetName() { const char* graph_name = nullptr; ThrowOnError(ort_graph_api->OrtGraph_GetName(graph_, &graph_name)); return graph_name; } -inline bool Graph::IsConstantInitializer(const char* name, bool check_outer_scope) { +inline bool GraphViewer::IsConstantInitializer(const char* name, bool check_outer_scope) { bool is_initializer = false; ThrowOnError(ort_graph_api->OrtGraph_IsConstantInitializer(graph_, name, check_outer_scope, &is_initializer)); return is_initializer; } -inline const std::vector Graph::GetNodesIndexInTopologicalOrder(int execution_order) { +inline const std::vector GraphViewer::GetNodesIndexInTopologicalOrder(int execution_order) { const size_t* nodes_index = nullptr; size_t nodes_count = 0; - ThrowOnError(ort_graph_api->OrtGraph_GetNodesIndexInTopologicalOrder(graph_, execution_order, nodes_index, nodes_count)); + 
ThrowOnError(ort_graph_api->OrtGraph_GetNodesIndexInTopologicalOrder(graph_, execution_order, &nodes_index, &nodes_count)); return std::vector(nodes_index, nodes_index + nodes_count); } -inline bool Graph::IsSubgraph() { +inline bool GraphViewer::IsSubgraph() { bool is_subgraph = false; ThrowOnError(ort_graph_api->OrtGraph_IsSubgraph(graph_, &is_subgraph)); return is_subgraph; } -inline std::shared_ptr Graph::GetParenNode() { +inline std::shared_ptr GraphViewer::GetParenNode() { const OrtNode* parent_node = nullptr; ThrowOnError(ort_graph_api->OrtGraph_GetParenNode(graph_, &parent_node)); - return std::make_shared(parent_node); + return std::make_shared(parent_node); } -inline std::filesystem::path Graph::GetModelPath() { +inline std::filesystem::path GraphViewer::GetModelPath() { const void* model_path = nullptr; ThrowOnError(ort_graph_api->OrtGraph_GetModelPath(graph_, &model_path)); return *reinterpret_cast(model_path); } -inline std::vector Graph::GetRequiredInputs() { +inline std::vector GraphViewer::GetRequiredInputs() { const char** required_inputs = nullptr; size_t required_inputs_count = 0; ThrowOnError(ort_graph_api->OrtGraph_GetRequiredInputs(graph_, &required_inputs, &required_inputs_count)); @@ -78,7 +107,7 @@ inline std::vector Graph::GetRequiredInputs() { return ret; } -inline std::vector Graph::GetAllInputs() { +inline std::vector GraphViewer::GetAllInputs() { const char** all_inputs = nullptr; size_t all_inputs_count = 0; ThrowOnError(ort_graph_api->OrtGraph_GetAllInputs(graph_, &all_inputs, &all_inputs_count)); @@ -94,7 +123,7 @@ inline std::vector Graph::GetAllInputs() { return ret; } -inline std::vector Graph::GetAllInitializers() { +inline std::vector GraphViewer::GetAllInitializers() { const char** all_initializers = nullptr; size_t all_initializers_count = 0; ThrowOnError(ort_graph_api->OrtGraph_GetAllInitializers(graph_, &all_initializers, &all_initializers_count)); @@ -110,13 +139,13 @@ inline std::vector Graph::GetAllInitializers() { 
return ret; } -inline Ort::PluginEP::Node Graph::GetOrtNode(size_t node_index) { +inline Ort::PluginEP::Node GraphViewer::GetOrtNode(size_t node_index) { const OrtNode* node = nullptr; ThrowOnError(ort_graph_api->OrtGraph_GetOrtNode(graph_, node_index, &node)); return Ort::PluginEP::Node(node); } -inline std::vector Graph::GetNodesConsumingInput(const char* input_name) { +inline std::vector GraphViewer::GetNodesConsumingInput(const char* input_name) { const OrtNode** consumers = nullptr; size_t consumer_count = 0; ThrowOnError(ort_graph_api->OrtGraph_GetNodesConsumingInput(graph_, input_name, &consumers, &consumer_count)); @@ -129,64 +158,102 @@ inline std::vector Graph::GetNodesConsumingInput(const char return ret; } -inline Ort::PluginEP::Node Graph::GetNodeProducingOutput(const char* output_name) { +inline Ort::PluginEP::Node GraphViewer::GetNodeProducingOutput(const char* output_name) { const OrtNode* node = nullptr; ThrowOnError(ort_graph_api->OrtGraph_GetNodeProducingOutput(graph_, output_name, &node)); return Ort::PluginEP::Node(node); } -inline int Graph::NumberOfNodes() { +inline int GraphViewer::NumberOfNodes() { int num_nodes = 0; ThrowOnError(ort_graph_api->OrtGraph_NumberOfNodes(graph_, &num_nodes)); return num_nodes; } -inline int Graph::MaxNodeIndex() { +inline int GraphViewer::MaxNodeIndex() { int max_node_index = 0; ThrowOnError(ort_graph_api->OrtGraph_MaxNodeIndex(graph_, &max_node_index)); return max_node_index; } -inline size_t Graph::GetOutputSize() { +inline size_t GraphViewer::GetOutputSize() { size_t output_size = 0; ThrowOnError(ort_graph_api->OrtGraph_GetOutputSize(graph_, &output_size)); return output_size; } -inline std::string Graph::GetIthOutputName(size_t i) { +inline std::string GraphViewer::GetIthOutputName(size_t i) { const char* output_name = nullptr; ThrowOnError(ort_graph_api->OrtGraph_GetIthOutputName(graph_, i, &output_name)); return std::string(output_name); } -inline int32_t Graph::GetIthOutputElemType(size_t i) { +inline 
int32_t GraphViewer::GetIthOutputElemType(size_t i) { int32_t elem_type = 0; ThrowOnError(ort_graph_api->OrtGraph_GetIthOutputElemType(graph_, i, &elem_type)); return elem_type; } -inline std::shared_ptr Graph::GetValueInfo(const char* name) { +inline std::shared_ptr GraphViewer::GetInitializerTensor(const char* initializer_name) { + OrtTensorRef* tensor = nullptr; + ThrowOnError(ort_graph_api->OrtGraph_GetInitializerTensor(graph_, initializer_name, &tensor)); + return std::make_shared(tensor); +} + +inline std::shared_ptr GraphViewer::GetValueInfo(const char* name) { OrtValueInfoRef* value_info = nullptr; ThrowOnError(ort_graph_api->OrtGraph_GetValueInfo(graph_, name, &value_info)); return std::make_shared(value_info); } -// inline void Graph::DumpOnnxModel(const std::filesystem::path& onnx_model_path) { -// ThrowOnError(ort_graph_api->OrtGraph_DumpOnnxModel(graph_->GetGraph(), onnx_model_path.c_str())); -// } +inline std::pair GraphViewer::SerializeToArray() { + void* serialized_data = nullptr; + size_t serialized_data_len = 0; + ThrowOnError(ort_graph_api->OrtGraph_SerializeToArray(graph_, &serialized_data, &serialized_data_len)); + return std::make_pair(VoidPtr(serialized_data, [](void* ptr) { ort_graph_api->OrtFreeMem(ptr); }), serialized_data_len); +} + +inline GraphPtr GraphViewer::CreateOrUpdateEpCtxGraph(const char* node_name, + const int64_t main_context, + const int64_t embed_mode, + const char* cache_path, + char* cache_data, + size_t size, + const char* const* extra_attr_keys, + const char* const* extra_attr_values, + size_t extra_attr_num) { + OrtGraph* graph = nullptr; + ThrowOnError(ort_graph_api->OrtGraph_CreateOrUpdateEpCtxGraph(graph_, + node_name, + main_context, + embed_mode, + cache_path, + cache_data, + size, + extra_attr_keys, + extra_attr_values, + extra_attr_num, + &graph)); + auto release_fn = [](Graph* graph) { + ThrowOnError(ort_graph_api->OrtGraph_ReleaseGraph(graph->GetGraph())); + }; + return std::unique_ptr(new Graph(graph), 
release_fn); +} -inline std::shared_ptr Graph::GetSubGraph(std::vector node_indices) { +inline GraphViewerPtr GraphViewer::GetSubGraph(std::vector node_indices) { const OrtGraphViewer* subgraph = nullptr; ThrowOnError(ort_graph_api->OrtGraph_GetSubGraph(graph_, node_indices.size(), node_indices.data(), &subgraph)); - // TODO:yang if should release subgraph in the decstructor of Graph? - return std::make_shared(subgraph); + auto release_fn = [](GraphViewer* graph) { + ThrowOnError(ort_graph_api->OrtGraph_ReleaseGraphViewer(graph->GetGraphViewer(), true)); + }; + return std::unique_ptr(new GraphViewer(subgraph), release_fn); } -// inline bool Graph::IsSameGraph(const Graph& other) { -// bool is_same = false; -// ThrowOnError(ort_graph_api->OrtGraph_IsSameGraph(graph_, other.GetGraph(), &is_same)); -// return is_same; -// } +inline bool GraphViewer::IsSameGraph(GraphViewer& other) { + bool is_same = false; + ThrowOnError(ort_graph_api->OrtGraph_IsSameGraph(graph_, other.GetGraphViewer(), &is_same)); + return is_same; +} inline Node::Node(const OrtNode* node) : node_(node) {} diff --git a/onnxruntime/core/session/onnxruntime_c_api_ep.cc b/onnxruntime/core/session/onnxruntime_c_api_ep.cc index b52c1cae9046e..96dbd78438f56 100644 --- a/onnxruntime/core/session/onnxruntime_c_api_ep.cc +++ b/onnxruntime/core/session/onnxruntime_c_api_ep.cc @@ -775,15 +775,25 @@ ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_ReleaseGraph, const OrtGraph* ort_gra return nullptr; } -ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_ReleaseGraphViewer, const OrtGraphViewer* graph) { +ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_ReleaseGraphViewer, const OrtGraphViewer* graph, bool release_model) { if (graph) { const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); - delete &(graph_viewer->GetGraph()).GetModel(); + if (release_model) { + delete &(graph_viewer->GetGraph()).GetModel(); + } delete graph_viewer; } return nullptr; } 
+ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_ReleaseGraphViewerArray, const OrtGraphViewer** graph_viewers, size_t num_graphs) { + for (size_t i = 0; i < num_graphs; i++) { + OrtGraph_ReleaseGraphViewer(graph_viewers[i], false); + } + delete[] graph_viewers; + return nullptr; +} + ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_IsSameGraph, const OrtGraphViewer* graph1, const OrtGraphViewer* graph2, bool* is_same) { const ::onnxruntime::GraphViewer* graph_viewer1 = reinterpret_cast(graph1); const ::onnxruntime::GraphViewer* graph_viewer2 = reinterpret_cast(graph2); @@ -1021,6 +1031,7 @@ static constexpr OrtGraphApi ort_graph_api = { &OrtGraphApis::OrtGraph_GetSubGraph, &OrtGraphApis::OrtGraph_ReleaseGraph, &OrtGraphApis::OrtGraph_ReleaseGraphViewer, + &OrtGraphApis::OrtGraph_ReleaseGraphViewerArray, &OrtGraphApis::OrtGraph_IsSameGraph, &OrtGraphApis::OrtNode_GetName, &OrtGraphApis::OrtNode_GetDescription, diff --git a/onnxruntime/core/session/ort_apis_ep.h b/onnxruntime/core/session/ort_apis_ep.h index dd6ae746db564..33c940242cc61 100644 --- a/onnxruntime/core/session/ort_apis_ep.h +++ b/onnxruntime/core/session/ort_apis_ep.h @@ -70,7 +70,9 @@ ORT_API_STATUS_IMPL(OrtGraph_GetSubGraph, const OrtGraphViewer* graph, const int ORT_API_STATUS_IMPL(OrtGraph_ReleaseGraph, const OrtGraph* graph); -ORT_API_STATUS_IMPL(OrtGraph_ReleaseGraphViewer, const OrtGraphViewer* graph); +ORT_API_STATUS_IMPL(OrtGraph_ReleaseGraphViewer, const OrtGraphViewer* graph, bool release_model); + +ORT_API_STATUS_IMPL(OrtGraph_ReleaseGraphViewerArray, const OrtGraphViewer** graph_viewers, size_t num_graphs); ORT_API_STATUS_IMPL(OrtGraph_IsSameGraph, const OrtGraphViewer* graph1, const OrtGraphViewer* graph2, _Out_ bool* is_same); diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc index 76bcd5350b0d1..365ee37f22fcd 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.cc +++ b/samples/tensorRTEp/tensorrt_execution_provider.cc @@ 
-1399,6 +1399,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const break; } } + graph_api_->OrtGraph_ReleaseGraphViewerArray(subgraphs, subgraph_count); if (!all_subgraphs_are_supported) { // if not all its subgraphs are supported, we need to exclude this control flow op continue; @@ -1499,6 +1500,7 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const break; } + graph_api_->OrtGraph_ReleaseGraphViewerArray(subgraphs, subgraph_count); } if (all_subgraphs_are_supported) { @@ -3750,7 +3752,7 @@ SubGraphCollection_t TensorrtExecutionProvider::GetSupportedList(SubGraphCollect } nodes_list_output.push_back(next_nodes_list[i]); } - graph_api_->OrtGraph_ReleaseGraphViewer(sub_graph_viewer); + graph_api_->OrtGraph_ReleaseGraphViewer(sub_graph_viewer, true); } } }