From a4079443a963603e1feb22c12fda309aeb3279f0 Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Thu, 26 Sep 2024 23:42:11 +0000 Subject: [PATCH] sample code to separate graph C API to different files --- .../core/session/onnxruntime_c_api.h | 5 +++++ .../core/session/onnxruntime_c_api_ep.h | 7 +++++++ onnxruntime/core/session/onnxruntime_c_api.cc | 8 ++++++++ .../core/session/onnxruntime_c_api_ep.cc | 18 ++++++++++++++++++ onnxruntime/core/session/ort_apis.h | 2 ++ onnxruntime/core/session/ort_apis_ep.h | 6 ++++++ samples/c_test/test.cpp | 12 ++++++++---- .../tensorRTEp/tensorrt_execution_provider.cc | 3 +++ .../tensorRTEp/tensorrt_execution_provider.h | 2 +- 9 files changed, 58 insertions(+), 5 deletions(-) create mode 100644 include/onnxruntime/core/session/onnxruntime_c_api_ep.h create mode 100644 onnxruntime/core/session/onnxruntime_c_api_ep.cc create mode 100644 onnxruntime/core/session/ort_apis_ep.h diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index e6c0ea485d38f..b504dddcf62cf 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -673,6 +673,9 @@ typedef struct OrtApi OrtApi; struct OrtTrainingApi; typedef struct OrtTrainingApi OrtTrainingApi; +struct OrtGraphApi; +typedef struct OrtGraphApi OrtGraphApi; + /** \brief The helper interface to get the right version of OrtApi * * Get a pointer to this structure through ::OrtGetApiBase @@ -4876,6 +4879,8 @@ struct OrtApi { ORT_API2_STATUS(AddTypeConstraint, _In_ OrtTypeConstraints* type_constraints, _In_ const char* type_symbol, ONNXTensorElementDataType type); ORT_CLASS_RELEASE(TypeConstraints); + + const OrtGraphApi*(ORT_API_CALL* GetGraphApi)(uint32_t version)NO_EXCEPTION; }; // struct OrtApi /* diff --git a/include/onnxruntime/core/session/onnxruntime_c_api_ep.h b/include/onnxruntime/core/session/onnxruntime_c_api_ep.h new file mode 100644 index 0000000000000..20a4860cab163 --- /dev/null +++ b/include/onnxruntime/core/session/onnxruntime_c_api_ep.h @@ -0,0 +1,7 @@ +#pragma once +#include "onnxruntime_c_api.h" + +struct OrtGraphApi { +ORT_API2_STATUS(OrtGraph_PlaceHolder, const OrtGraphViewer* graph, _Out_ int* out); +}; +typedef struct OrtGraphApi OrtGraphApi; diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 9b2d50718677c..7836c8024e76e 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -44,6 +44,8 @@ #include "core/framework/provider_factory_adapter.h" #include "core/framework/kernel_registry.h" #include "core/framework/ort_type_constraints.h" +#include "onnxruntime_c_api_ep.h" +#include "ort_apis_ep.h" #ifdef USE_CUDA #include "core/providers/cuda/cuda_provider_factory.h" @@ -2906,6 +2908,11 @@ ORT_API(void, OrtApis::ReleaseTypeConstraints, OrtTypeConstraints* type_constrai delete type_constraints; } +ORT_API(const OrtGraphApi*, OrtApis::GetGraphApi, uint32_t version) { + //if (version >= xx && version <= ORT_API_VERSION) + return OrtGraphApis::GetGraphApi(version); +} + static constexpr OrtApiBase ort_api_base = { &OrtApis::GetApi, &OrtApis::GetVersionString}; @@ -3345,6 +3352,7 @@ static constexpr OrtApi ort_api_1_to_19 = { &OrtApis::CreateOrtTypeConstraints, &OrtApis::AddTypeConstraint, &OrtApis::ReleaseTypeConstraints, + &OrtApis::GetGraphApi, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/onnxruntime_c_api_ep.cc b/onnxruntime/core/session/onnxruntime_c_api_ep.cc new file mode 100644 index 0000000000000..5c89348188faf --- /dev/null +++ b/onnxruntime/core/session/onnxruntime_c_api_ep.cc @@ -0,0 +1,18 @@ +#include "core/session/onnxruntime_c_api_ep.h" +#include "ort_apis_ep.h" +#include "core/graph/graph_viewer.h" + +ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_PlaceHolder, const OrtGraphViewer* graph, _Out_ int* out) { + const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast(graph); + *out = graph_viewer->NumberOfNodes(); + return nullptr; +} + +static constexpr OrtGraphApi ort_graph_api = { + &OrtGraphApis::OrtGraph_PlaceHolder, +}; + +ORT_API(const OrtGraphApi*, OrtGraphApis::GetGraphApi, uint32_t) { + // No constraints on the API version yet. + return &ort_graph_api; +} diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 0b16aae0fa253..bf149d4daca4d 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -644,4 +644,6 @@ ORT_API_STATUS_IMPL(CreateOrtTypeConstraints, _Outptr_ OrtTypeConstraints** type ORT_API_STATUS_IMPL(AddTypeConstraint, _In_ OrtTypeConstraints* type_constraints, _In_ const char* type_symbol, ONNXTensorElementDataType type); ORT_API(void, ReleaseTypeConstraints, _In_ OrtTypeConstraints* type_constraints); + +ORT_API(const OrtGraphApi*, GetGraphApi, uint32_t version); } // namespace OrtApis diff --git a/onnxruntime/core/session/ort_apis_ep.h b/onnxruntime/core/session/ort_apis_ep.h new file mode 100644 index 0000000000000..ef0af223504f5 --- /dev/null +++ b/onnxruntime/core/session/ort_apis_ep.h @@ -0,0 +1,6 @@ +#pragma once + +namespace OrtGraphApis { +ORT_API(const OrtGraphApi*, GetGraphApi, uint32_t version); +ORT_API_STATUS_IMPL(OrtGraph_PlaceHolder, const OrtGraphViewer* graph, _Out_ int* out); +} diff --git a/samples/c_test/test.cpp b/samples/c_test/test.cpp index c89e209b56d09..362c329039581 100644 --- a/samples/c_test/test.cpp +++ b/samples/c_test/test.cpp @@ -196,9 +196,10 @@ void RunFastRcnn(const OrtApi* g_ort, OrtEnv* p_env, OrtSessionOptions* so) { // for (size_t i = 0; i < 4; i++) std::cout<CreateSession(p_env, "/home/leca/models/tinyyolov3/yolov3-tiny.onnx", so, &session)); + if (!strcmp(model, "tyolo")) THROW_ON_ERROR(g_ort->CreateSession(p_env, "/home/leca/models/tinyyolov3/yolov3-tiny.onnx", so, &session)); + else if (!strcmp(model, "yolo")) THROW_ON_ERROR(g_ort->CreateSession(p_env, "/home/leca/models/yolov3/yolov3.onnx", so, &session)); OrtMemoryInfo* memory_info = nullptr; THROW_ON_ERROR(g_ort->CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &memory_info)); @@ -212,6 +213,9 @@ void RunTinyYolov3(OrtEnv* p_env, OrtSessionOptions* so) { THROW_ON_ERROR(g_ort->CreateTensorWithDataAsOrtValue(memory_info, input_data, input_len, input_shape, sizeof(input_shape)/sizeof(input_shape[0]), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &input_tensors[0])); float input2[2] = {375, 500}; + if (!strcmp(model, "yolo")) { + input2[0] = 506, input2[1] = 640; + } const size_t input2_len = 8; // 2 * sizeof(float) const int64_t input2_shape[] = {1, 2}; THROW_ON_ERROR(g_ort->CreateTensorWithDataAsOrtValue(memory_info, input2, input2_len, input2_shape, sizeof(input2_shape)/sizeof(input2_shape[0]), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &input_tensors[1])); @@ -255,8 +259,8 @@ int main(int argc, char *argv[]) { RunResnet18v1_7(g_ort, p_env, so); } else if (!strcmp(argv[2], "rcnn")) { RunFastRcnn(g_ort, p_env, so); - } else if (!strcmp(argv[2], "tyolo")) { - RunTinyYolov3(p_env, so); + } else if (!strcmp(argv[2], "tyolo") || !strcmp(argv[2], "yolo")) { + RunTinyYolov3(p_env, so, argv[2]); } g_ort->ReleaseEnv(p_env); diff --git a/samples/tensorRTEp/tensorrt_execution_provider.cc b/samples/tensorRTEp/tensorrt_execution_provider.cc index 28b15b3f73514..9b676a0aee76b 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.cc +++ b/samples/tensorRTEp/tensorrt_execution_provider.cc @@ -1280,6 +1280,9 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const OrtExecutionProvider::GetCapability = [](const OrtExecutionProvider* this_, const OrtGraphViewer* graph, size_t* cnt, OrtIndexedSubGraph*** indexed_sub_graph) { const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION); const TensorrtExecutionProvider* p = static_cast(this_); + const OrtGraphApi* g_ort_graph_api = api->GetGraphApi(ORT_API_VERSION); + int num_nodes = 0; + g_ort_graph_api->OrtGraph_PlaceHolder(graph, &num_nodes); // Get ModelPath const std::filesystem::path* model_path = nullptr; api->OrtGraph_GetModelPath(graph, (const void**)&model_path); diff --git a/samples/tensorRTEp/tensorrt_execution_provider.h b/samples/tensorRTEp/tensorrt_execution_provider.h index f35c8d4316e84..56feecea84e4e 100644 --- a/samples/tensorRTEp/tensorrt_execution_provider.h +++ b/samples/tensorRTEp/tensorrt_execution_provider.h @@ -2,7 +2,7 @@ #include #include #include -#include "core/session/onnxruntime_c_api.h" +#include "core/session/onnxruntime_c_api_ep.h" #include "core/framework/provider_options.h" #include "tensorrt_execution_provider_info.h" #include "nv_includes.h"