Skip to content

Commit

Permalink
sample code to separate graph C API to different files
Browse files Browse the repository at this point in the history
  • Loading branch information
jslhcl committed Sep 26, 2024
1 parent 1d7b2df commit a407944
Show file tree
Hide file tree
Showing 9 changed files with 58 additions and 5 deletions.
5 changes: 5 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,9 @@ typedef struct OrtApi OrtApi;
struct OrtTrainingApi;
typedef struct OrtTrainingApi OrtTrainingApi;

struct OrtGraphApi;
typedef struct OrtGraphApi OrtGraphApi;

/** \brief The helper interface to get the right version of OrtApi
*
* Get a pointer to this structure through ::OrtGetApiBase
Expand Down Expand Up @@ -4876,6 +4879,8 @@ struct OrtApi {
ORT_API2_STATUS(AddTypeConstraint, _In_ OrtTypeConstraints* type_constraints, _In_ const char* type_symbol, ONNXTensorElementDataType type);

ORT_CLASS_RELEASE(TypeConstraints);

const OrtGraphApi*(ORT_API_CALL* GetGraphApi)(uint32_t version)NO_EXCEPTION;
}; // struct OrtApi

/*
Expand Down
7 changes: 7 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api_ep.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#pragma once

Check warning

Code scanning / lintrunner

CLANGFORMAT/format Warning

See https://clang.llvm.org/docs/ClangFormat.html.
Run lintrunner -a to apply this patch.
#include "onnxruntime_c_api.h"

struct OrtGraphApi {
ORT_API2_STATUS(OrtGraph_PlaceHolder, const OrtGraphViewer* graph, _Out_ int* out);
};
typedef struct OrtGraphApi OrtGraphApi;
8 changes: 8 additions & 0 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
#include "core/framework/provider_factory_adapter.h"
#include "core/framework/kernel_registry.h"
#include "core/framework/ort_type_constraints.h"
#include "onnxruntime_c_api_ep.h"
#include "ort_apis_ep.h"

#ifdef USE_CUDA
#include "core/providers/cuda/cuda_provider_factory.h"
Expand Down Expand Up @@ -2906,6 +2908,11 @@ ORT_API(void, OrtApis::ReleaseTypeConstraints, OrtTypeConstraints* type_constrai
delete type_constraints;
}

ORT_API(const OrtGraphApi*, OrtApis::GetGraphApi, uint32_t version) {
//if (version >= xx && version <= ORT_API_VERSION)
return OrtGraphApis::GetGraphApi(version);
}

static constexpr OrtApiBase ort_api_base = {
&OrtApis::GetApi,
&OrtApis::GetVersionString};
Expand Down Expand Up @@ -3345,6 +3352,7 @@ static constexpr OrtApi ort_api_1_to_19 = {
&OrtApis::CreateOrtTypeConstraints,
&OrtApis::AddTypeConstraint,
&OrtApis::ReleaseTypeConstraints,
&OrtApis::GetGraphApi,
};

// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
Expand Down
18 changes: 18 additions & 0 deletions onnxruntime/core/session/onnxruntime_c_api_ep.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#include "core/session/onnxruntime_c_api_ep.h"
#include "ort_apis_ep.h"
#include "core/graph/graph_viewer.h"

ORT_API_STATUS_IMPL(OrtGraphApis::OrtGraph_PlaceHolder, const OrtGraphViewer* graph, _Out_ int* out) {
const ::onnxruntime::GraphViewer* graph_viewer = reinterpret_cast<const ::onnxruntime::GraphViewer*>(graph);
*out = graph_viewer->NumberOfNodes();
return nullptr;
}

static constexpr OrtGraphApi ort_graph_api = {
&OrtGraphApis::OrtGraph_PlaceHolder,
};

ORT_API(const OrtGraphApi*, OrtGraphApis::GetGraphApi, uint32_t) {
// No constraints on the API version yet.
return &ort_graph_api;
}
2 changes: 2 additions & 0 deletions onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -644,4 +644,6 @@ ORT_API_STATUS_IMPL(CreateOrtTypeConstraints, _Outptr_ OrtTypeConstraints** type
ORT_API_STATUS_IMPL(AddTypeConstraint, _In_ OrtTypeConstraints* type_constraints, _In_ const char* type_symbol, ONNXTensorElementDataType type);

ORT_API(void, ReleaseTypeConstraints, _In_ OrtTypeConstraints* type_constraints);

ORT_API(const OrtGraphApi*, GetGraphApi, uint32_t version);
} // namespace OrtApis
6 changes: 6 additions & 0 deletions onnxruntime/core/session/ort_apis_ep.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
#pragma once

Check warning

Code scanning / lintrunner

CLANGFORMAT/format Warning

See https://clang.llvm.org/docs/ClangFormat.html.
Run lintrunner -a to apply this patch.

namespace OrtGraphApis {
ORT_API(const OrtGraphApi*, GetGraphApi, uint32_t version);
ORT_API_STATUS_IMPL(OrtGraph_PlaceHolder, const OrtGraphViewer* graph, _Out_ int* out);
}
12 changes: 8 additions & 4 deletions samples/c_test/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,10 @@ void RunFastRcnn(const OrtApi* g_ort, OrtEnv* p_env, OrtSessionOptions* so) {
// for (size_t i = 0; i < 4; i++) std::cout<<output_tensor_data[i]<<" \n";
}

void RunTinyYolov3(OrtEnv* p_env, OrtSessionOptions* so) {
void RunTinyYolov3(OrtEnv* p_env, OrtSessionOptions* so, const char* model) {
OrtSession* session = nullptr;
THROW_ON_ERROR(g_ort->CreateSession(p_env, "/home/leca/models/tinyyolov3/yolov3-tiny.onnx", so, &session));
if (!strcmp(model, "tyolo")) THROW_ON_ERROR(g_ort->CreateSession(p_env, "/home/leca/models/tinyyolov3/yolov3-tiny.onnx", so, &session));
else if (!strcmp(model, "yolo")) THROW_ON_ERROR(g_ort->CreateSession(p_env, "/home/leca/models/yolov3/yolov3.onnx", so, &session));

OrtMemoryInfo* memory_info = nullptr;
THROW_ON_ERROR(g_ort->CreateCpuMemoryInfo(OrtArenaAllocator, OrtMemTypeDefault, &memory_info));
Expand All @@ -212,6 +213,9 @@ void RunTinyYolov3(OrtEnv* p_env, OrtSessionOptions* so) {
THROW_ON_ERROR(g_ort->CreateTensorWithDataAsOrtValue(memory_info, input_data, input_len, input_shape, sizeof(input_shape)/sizeof(input_shape[0]), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &input_tensors[0]));

float input2[2] = {375, 500};
if (!strcmp(model, "yolo")) {
input2[0] = 506, input2[1] = 640;
}
const size_t input2_len = 8; // 2 * sizeof(float)
const int64_t input2_shape[] = {1, 2};
THROW_ON_ERROR(g_ort->CreateTensorWithDataAsOrtValue(memory_info, input2, input2_len, input2_shape, sizeof(input2_shape)/sizeof(input2_shape[0]), ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, &input_tensors[1]));
Expand Down Expand Up @@ -255,8 +259,8 @@ int main(int argc, char *argv[]) {
RunResnet18v1_7(g_ort, p_env, so);
} else if (!strcmp(argv[2], "rcnn")) {
RunFastRcnn(g_ort, p_env, so);
} else if (!strcmp(argv[2], "tyolo")) {
RunTinyYolov3(p_env, so);
} else if (!strcmp(argv[2], "tyolo") || !strcmp(argv[2], "yolo")) {
RunTinyYolov3(p_env, so, argv[2]);
}

g_ort->ReleaseEnv(p_env);
Expand Down
3 changes: 3 additions & 0 deletions samples/tensorRTEp/tensorrt_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1280,6 +1280,9 @@ TensorrtExecutionProvider::TensorrtExecutionProvider(const char* ep_type, const
OrtExecutionProvider::GetCapability = [](const OrtExecutionProvider* this_, const OrtGraphViewer* graph, size_t* cnt, OrtIndexedSubGraph*** indexed_sub_graph) {
const OrtApi* api = OrtGetApiBase()->GetApi(ORT_API_VERSION);
const TensorrtExecutionProvider* p = static_cast<const TensorrtExecutionProvider*>(this_);
const OrtGraphApi* g_ort_graph_api = api->GetGraphApi(ORT_API_VERSION);
int num_nodes = 0;
g_ort_graph_api->OrtGraph_PlaceHolder(graph, &num_nodes);
// Get ModelPath
const std::filesystem::path* model_path = nullptr;
api->OrtGraph_GetModelPath(graph, (const void**)&model_path);
Expand Down
2 changes: 1 addition & 1 deletion samples/tensorRTEp/tensorrt_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#include <ctime>
#include <string>
#include <unordered_set>
#include "core/session/onnxruntime_c_api.h"
#include "core/session/onnxruntime_c_api_ep.h"
#include "core/framework/provider_options.h"
#include "tensorrt_execution_provider_info.h"
#include "nv_includes.h"
Expand Down

0 comments on commit a407944

Please sign in to comment.