From 53917a33536ab8873264c55c9cac8d91d5a8d040 Mon Sep 17 00:00:00 2001 From: RandySheriffH <48490400+RandySheriffH@users.noreply.github.com> Date: Sat, 18 Nov 2023 15:00:54 -0800 Subject: [PATCH] Move up members in Lite Custom Op hierarchy for possible memleaks. (#18478) Move data member in LiteOpFunc to its parent to avoid possible mem leaks. --------- Co-authored-by: Randy Shuai --- .../core/session/onnxruntime_lite_custom_op.h | 47 ++++++++++++------- 1 file changed, 30 insertions(+), 17 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h b/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h index 443710884743a..0c0af16d4e20c 100644 --- a/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h +++ b/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h @@ -399,6 +399,15 @@ struct TensorArray : public ArgBase { using Variadic = TensorArray; +/* +Note: +OrtLiteCustomOp inherits from OrtCustomOp to bridge tween a custom func/struct and ort core. +The lifetime of an OrtLiteCustomOp instance is managed by customer code, not ort, so: +1. DO NOT cast OrtLiteCustomOp to OrtCustomOp and release since there is no virtual destructor in the hierachy. +2. OrtLiteCustomFunc and OrtLiteCustomStruct, as two sub-structs, can be released in form of OrtLiteCustomOp since all members are kept in the OrtLiteCustomOp, + hence memory could still be recycled properly. +Further, OrtCustomOp is a c struct bearing no v-table, so offspring structs are by design to be of zero virtual functions to maintain cast safety. +*/ struct OrtLiteCustomOp : public OrtCustomOp { using ConstOptionalFloatTensor = std::optional&>; using OptionalFloatTensor = std::optional>; @@ -774,10 +783,13 @@ struct OrtLiteCustomOp : public OrtCustomOp { OrtLiteCustomOp(const char* op_name, const char* execution_provider, - int start_ver = 1, int end_ver = MAX_CUSTOM_OP_END_VER) : op_name_(op_name), - execution_provider_(execution_provider), - start_ver_(start_ver), - end_ver_(end_ver) { + ShapeInferFn shape_infer_fn, + int start_ver = 1, + int end_ver = MAX_CUSTOM_OP_END_VER) : op_name_(op_name), + execution_provider_(execution_provider), + shape_infer_fn_(shape_infer_fn), + start_ver_(start_ver), + end_ver_(end_ver) { OrtCustomOp::version = ORT_API_VERSION; OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast(op)->op_name_.c_str(); }; @@ -858,8 +870,13 @@ struct OrtLiteCustomOp : public OrtCustomOp { std::vector input_types_; std::vector output_types_; + ShapeInferFn shape_infer_fn_ = {}; + int start_ver_ = 1; int end_ver_ = MAX_CUSTOM_OP_END_VER; + + void* compute_fn_ = {}; + void* compute_fn_return_status_ = {}; }; //////////////////////////// OrtLiteCustomFunc //////////////////////////////// @@ -891,9 +908,8 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { ComputeFn compute_fn, ShapeInferFn shape_infer_fn = {}, int start_ver = 1, - int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver), - compute_fn_(compute_fn), - shape_infer_fn_(shape_infer_fn) { + int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) { + compute_fn_ = reinterpret_cast(compute_fn); ParseArgs(input_types_, output_types_); OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) { @@ -905,7 +921,8 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { auto kernel = std::make_unique(); - kernel->compute_fn_ = static_cast(this_)->compute_fn_; + auto me = static_cast(this_); + kernel->compute_fn_ = reinterpret_cast(me->compute_fn_); Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_)); Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_)); auto self = static_cast(this_); @@ -931,9 +948,8 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { ComputeFnReturnStatus compute_fn_return_status, ShapeInferFn shape_infer_fn = {}, int start_ver = 1, - int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver), - compute_fn_return_status_(compute_fn_return_status), - shape_infer_fn_(shape_infer_fn) { + int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, shape_infer_fn, start_ver, end_ver) { + compute_fn_return_status_ = reinterpret_cast(compute_fn_return_status); ParseArgs(input_types_, output_types_); OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr { @@ -945,7 +961,8 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) { auto kernel = std::make_unique(); - kernel->compute_fn_return_status_ = static_cast(this_)->compute_fn_return_status_; + auto me = static_cast(this_); + kernel->compute_fn_return_status_ = reinterpret_cast(me->compute_fn_return_status_); Ort::ThrowOnError(ort_api->KernelInfo_GetInputCount(info, &kernel->num_input_)); Ort::ThrowOnError(ort_api->KernelInfo_GetOutputCount(info, &kernel->num_output_)); auto self = static_cast(this_); @@ -965,10 +982,6 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp { }; } } - - ComputeFn compute_fn_ = {}; - ComputeFnReturnStatus compute_fn_return_status_ = {}; - ShapeInferFn shape_infer_fn_ = {}; }; // struct OrtLiteCustomFunc /////////////////////////// OrtLiteCustomStruct /////////////////////////// @@ -1007,7 +1020,7 @@ struct OrtLiteCustomStruct : public OrtLiteCustomOp { OrtLiteCustomStruct(const char* op_name, const char* execution_provider, int start_ver = 1, - int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver) { + int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, {}, start_ver, end_ver) { SetCompute(&CustomOp::Compute); OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {