Move up members in Lite Custom Op hierarchy for possible memleaks. #18478

Merged 4 commits on Nov 18, 2023

Changes from 3 commits
47 changes: 30 additions & 17 deletions include/onnxruntime/core/session/onnxruntime_lite_custom_op.h
@@ -399,6 +399,15 @@

using Variadic = TensorArray;

/*
Note:
OrtLiteCustomOp inherits from OrtCustomOp to bridge between a custom func/struct and ort core.
The lifetime of an OrtLiteCustomOp instance is managed by customer code, not ort, so:
1. DO NOT cast OrtLiteCustomOp to OrtCustomOp and release it, since there is no virtual destructor in the hierarchy.
2. OrtLiteCustomFunc and OrtLiteCustomStruct, the two sub-structs, can be released in the form of OrtLiteCustomOp since all members are kept in OrtLiteCustomOp,
   so memory is still recycled properly.
Finally, OrtCustomOp is a C struct bearing no v-table, so its offspring structs are by design required to have zero virtual functions to maintain cast safety.
*/
struct OrtLiteCustomOp : public OrtCustomOp {
using ConstOptionalFloatTensor = std::optional<const Custom::Tensor<float>&>;
using OptionalFloatTensor = std::optional<Custom::Tensor<float>>;
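To make the ownership rule in the note above concrete, here is a minimal standalone sketch; it is not part of the diff, and BaseOp/FuncOp are hypothetical stand-ins for the ORT structs. All state that needs cleanup lives in the base and the derived struct only fills it in, so customer code can hold and release the op through the base type, which is exactly what moving the members up enables.

#include <memory>
#include <string>

// Hypothetical stand-ins for OrtLiteCustomOp / OrtLiteCustomFunc.
struct BaseOp {
  // Every owned member is kept in the base, mirroring the note's point 2.
  std::string op_name_;
  void* compute_fn_ = nullptr;  // type-erased function pointer; owns nothing
};

struct FuncOp : BaseOp {
  // No members and no virtual functions of its own: the constructor only
  // initializes the base's fields.
  FuncOp(const char* name, void (*fn)()) {
    op_name_ = name;
    compute_fn_ = reinterpret_cast<void*>(fn);
  }
};

void Compute() { /* kernel body elided */ }

int main() {
  // Customer code owns the op and releases it as the base type; since FuncOp
  // adds no state, nothing is left behind when the BaseOp members are destroyed.
  std::unique_ptr<BaseOp> op{new FuncOp("MyOp", &Compute)};
  return 0;
}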
@@ -774,10 +783,13 @@

OrtLiteCustomOp(const char* op_name,
const char* execution_provider,
int start_ver = 1, int end_ver = MAX_CUSTOM_OP_END_VER) : op_name_(op_name),
execution_provider_(execution_provider),
start_ver_(start_ver),
end_ver_(end_ver) {
ShapeInferFn shape_infer_fn,
int start_ver = 1,
int end_ver = MAX_CUSTOM_OP_END_VER) : op_name_(op_name),
execution_provider_(execution_provider),
shape_infer_fn_(shape_infer_fn),
start_ver_(start_ver),
end_ver_(end_ver) {
OrtCustomOp::version = ORT_API_VERSION;

OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast<const OrtLiteCustomOp*>(op)->op_name_.c_str(); };
@@ -858,8 +870,13 @@
std::vector<ONNXTensorElementDataType> input_types_;
std::vector<ONNXTensorElementDataType> output_types_;

ShapeInferFn shape_infer_fn_ = {};

int start_ver_ = 1;
int end_ver_ = MAX_CUSTOM_OP_END_VER;

void* compute_fn_ = {};
void* compute_fn_return_status_ = {};
};

//////////////////////////// OrtLiteCustomFunc ////////////////////////////////
@@ -891,9 +908,8 @@
ComputeFn compute_fn,
ShapeInferFn shape_infer_fn = {},
int start_ver = 1,
int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver),
compute_fn_(compute_fn),
shape_infer_fn_(shape_infer_fn) {
int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) {

compute_fn_ = reinterpret_cast<void*>(compute_fn);
ParseArgs<Args...>(input_types_, output_types_);

OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
@@ -905,7 +921,8 @@

OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
auto kernel = std::make_unique<Kernel>();
kernel->compute_fn_ = static_cast<const MyType*>(this_)->compute_fn_;
auto me = static_cast<const MyType*>(this_);
kernel->compute_fn_ = reinterpret_cast<ComputeFn>(me->compute_fn_);
Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
auto self = static_cast<const OrtLiteCustomFunc*>(this_);
@@ -931,9 +948,8 @@
ComputeFnReturnStatus compute_fn_return_status,
ShapeInferFn shape_infer_fn = {},
int start_ver = 1,
int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver),
compute_fn_return_status_(compute_fn_return_status),
shape_infer_fn_(shape_infer_fn) {
int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) {

compute_fn_return_status_ = reinterpret_cast<void*>(compute_fn_return_status);
ParseArgs<Args...>(input_types_, output_types_);

OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
@@ -945,7 +961,8 @@

OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
auto kernel = std::make_unique<Kernel>();
kernel->compute_fn_return_status_ = static_cast<const MyType*>(this_)->compute_fn_return_status_;
auto me = static_cast<const MyType*>(this_);
kernel->compute_fn_return_status_ = reinterpret_cast<ComputeFnReturnStatus>(me->compute_fn_return_status_);
Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_));
Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_));
auto self = static_cast<const OrtLiteCustomFunc*>(this_);
@@ -965,10 +982,6 @@
};
}
}

ComputeFn compute_fn_ = {};
ComputeFnReturnStatus compute_fn_return_status_ = {};
ShapeInferFn shape_infer_fn_ = {};
}; // struct OrtLiteCustomFunc
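Both constructors above rely on the same type-erasure step: the strongly typed compute callback is stored in the base's void* slot and recovered with a reinterpret_cast to its original pointer type inside CreateKernel. A small standalone sketch of that round trip (Base, Func, and Add are hypothetical names, not the ORT types):

#include <cassert>

struct Base {
  void* compute_fn_ = nullptr;  // type-erased slot shared by all sub-structs
};

template <typename... Args>
struct Func : Base {
  using ComputeFn = void (*)(Args...);

  explicit Func(ComputeFn fn) {
    // A function pointer may be round-tripped through void* as long as it is
    // cast back to the exact same type before it is called.
    compute_fn_ = reinterpret_cast<void*>(fn);
  }

  void Invoke(Args... args) const {
    auto fn = reinterpret_cast<ComputeFn>(compute_fn_);
    assert(fn != nullptr);
    fn(args...);
  }
};

void Add(int a, int b, int* out) { *out = a + b; }

int main() {
  Func<int, int, int*> op{&Add};
  int result = 0;
  op.Invoke(2, 3, &result);
  assert(result == 5);
  return 0;
}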

/////////////////////////// OrtLiteCustomStruct ///////////////////////////
@@ -1007,7 +1020,7 @@
OrtLiteCustomStruct(const char* op_name,
const char* execution_provider,
int start_ver = 1,
int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver) {
int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, {}, start_ver, end_ver) {

SetCompute(&CustomOp::Compute);

OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
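For context, here is a usage sketch of the ownership model this change tightens up, assuming the Ort::Custom::CreateLiteCustomOp factory declared later in this header together with the Ort::CustomOpDomain / Ort::SessionOptions C++ API; the op name "Identity42", the domain name "v1", and KernelOne are illustrative only. The factory returns a raw OrtLiteCustomOp* owned by customer code, and with all members now kept in OrtLiteCustomOp the instance can be held and released through that base type.

#include <memory>
#include "onnxruntime_cxx_api.h"
#include "onnxruntime_lite_custom_op.h"

// Illustrative single-input/single-output kernel: copies x into y.
void KernelOne(const Ort::Custom::Tensor<float>& x, Ort::Custom::Tensor<float>& y) {
  auto* out = y.Allocate(x.Shape());
  for (int64_t i = 0; i < x.NumberOfElement(); ++i) {
    out[i] = x.Data()[i];
  }
}

int main() {
  // Customer code owns the op for at least the lifetime of the session.
  std::unique_ptr<Ort::Custom::OrtLiteCustomOp> op{
      Ort::Custom::CreateLiteCustomOp("Identity42", "CPUExecutionProvider", KernelOne)};

  Ort::CustomOpDomain domain{"v1"};
  domain.Add(op.get());

  Ort::SessionOptions session_options;
  session_options.Add(domain);
  // ... create an Ort::Session with session_options and run as usual ...
  return 0;
}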