Skip to content

Commit

Permalink
Move up members in Lite Custom Op hierarchy for possible memleaks. (#…
Browse files Browse the repository at this point in the history
…18478)

Move data member in LiteOpFunc to its parent to avoid possible mem
leaks.

---------

Co-authored-by: Randy Shuai <[email protected]>
  • Loading branch information
RandySheriffH and RandyShuai authored Nov 18, 2023
1 parent 9364c05 commit 53917a3
Showing 1 changed file with 30 additions and 17 deletions.
47 changes: 30 additions & 17 deletions include/onnxruntime/core/session/onnxruntime_lite_custom_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,15 @@ struct TensorArray : public ArgBase {

using Variadic = TensorArray;

/*
Note:
OrtLiteCustomOp inherits from OrtCustomOp to bridge tween a custom func/struct and ort core.
The lifetime of an OrtLiteCustomOp instance is managed by customer code, not ort, so:
1. DO NOT cast OrtLiteCustomOp to OrtCustomOp and release since there is no virtual destructor in the hierachy.
2. OrtLiteCustomFunc and OrtLiteCustomStruct, as two sub-structs, can be released in form of OrtLiteCustomOp since all members are kept in the OrtLiteCustomOp,
hence memory could still be recycled properly.
Further, OrtCustomOp is a c struct bearing no v-table, so offspring structs are by design to be of zero virtual functions to maintain cast safety.
*/
struct OrtLiteCustomOp : public OrtCustomOp {
using ConstOptionalFloatTensor = std::optional<const Custom::Tensor<float>&>;
using OptionalFloatTensor = std::optional<Custom::Tensor<float>>;
Expand Down Expand Up @@ -774,10 +783,13 @@ struct OrtLiteCustomOp : public OrtCustomOp {

OrtLiteCustomOp(const char* op_name,
const char* execution_provider,
int start_ver = 1, int end_ver = MAX_CUSTOM_OP_END_VER) : op_name_(op_name),
execution_provider_(execution_provider),
start_ver_(start_ver),
end_ver_(end_ver) {
ShapeInferFn shape_infer_fn,
int start_ver = 1,
int end_ver = MAX_CUSTOM_OP_END_VER) : op_name_(op_name),
execution_provider_(execution_provider),
shape_infer_fn_(shape_infer_fn),
start_ver_(start_ver),
end_ver_(end_ver) {
OrtCustomOp::version = ORT_API_VERSION;

OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast<const OrtLiteCustomOp*>(op)->op_name_.c_str(); };
Expand Down Expand Up @@ -858,8 +870,13 @@ struct OrtLiteCustomOp : public OrtCustomOp {
std::vector<ONNXTensorElementDataType> input_types_;
std::vector<ONNXTensorElementDataType> output_types_;

ShapeInferFn shape_infer_fn_ = {};

int start_ver_ = 1;
int end_ver_ = MAX_CUSTOM_OP_END_VER;

void* compute_fn_ = {};
void* compute_fn_return_status_ = {};
};

//////////////////////////// OrtLiteCustomFunc ////////////////////////////////
Expand Down Expand Up @@ -891,9 +908,8 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp {
ComputeFn compute_fn,
ShapeInferFn shape_infer_fn = {},
int start_ver = 1,
int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver),
compute_fn_(compute_fn),
shape_infer_fn_(shape_infer_fn) {
int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) {
compute_fn_ = reinterpret_cast<void*>(compute_fn);
ParseArgs<Args...>(input_types_, output_types_);

OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
Expand All @@ -905,7 +921,8 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp {

OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
auto kernel = std::make_unique<Kernel>();
kernel->compute_fn_ = static_cast<const MyType*>(this_)->compute_fn_;
auto me = static_cast<const MyType*>(this_);
kernel->compute_fn_ = reinterpret_cast<ComputeFn>(me->compute_fn_);
Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
auto self = static_cast<const OrtLiteCustomFunc*>(this_);
Expand All @@ -931,9 +948,8 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp {
ComputeFnReturnStatus compute_fn_return_status,
ShapeInferFn shape_infer_fn = {},
int start_ver = 1,
int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver),
compute_fn_return_status_(compute_fn_return_status),
shape_infer_fn_(shape_infer_fn) {
int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) {
compute_fn_return_status_ = reinterpret_cast<void*>(compute_fn_return_status);
ParseArgs<Args...>(input_types_, output_types_);

OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
Expand All @@ -945,7 +961,8 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp {

OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
auto kernel = std::make_unique<Kernel>();
kernel->compute_fn_return_status_ = static_cast<const MyType*>(this_)->compute_fn_return_status_;
auto me = static_cast<const MyType*>(this_);
kernel->compute_fn_return_status_ = reinterpret_cast<ComputeFnReturnStatus>(me->compute_fn_return_status_);
Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
auto self = static_cast<const OrtLiteCustomFunc*>(this_);
Expand All @@ -965,10 +982,6 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp {
};
}
}

ComputeFn compute_fn_ = {};
ComputeFnReturnStatus compute_fn_return_status_ = {};
ShapeInferFn shape_infer_fn_ = {};
}; // struct OrtLiteCustomFunc

/////////////////////////// OrtLiteCustomStruct ///////////////////////////
Expand Down Expand Up @@ -1007,7 +1020,7 @@ struct OrtLiteCustomStruct : public OrtLiteCustomOp {
OrtLiteCustomStruct(const char* op_name,
const char* execution_provider,
int start_ver = 1,
int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver) {
int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, {}, start_ver, end_ver) {
SetCompute(&CustomOp::Compute);

OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
Expand Down

0 comments on commit 53917a3

Please sign in to comment.