From 6c39641ea2248f616f1da306fe7b2d4798b321a8 Mon Sep 17 00:00:00 2001 From: RandySheriffH <48490400+RandySheriffH@users.noreply.github.com> Date: Wed, 30 Aug 2023 12:54:17 -0700 Subject: [PATCH] Fix a memleak in RunAsync python (#17326) Release ort value outputs that are created and released from ort::run(...). --------- Co-authored-by: Randy Shuai --- include/onnxruntime/core/session/onnxruntime_c_api.h | 8 ++++++-- .../onnxruntime/core/session/onnxruntime_cxx_api.h | 10 +++++++--- onnxruntime/python/onnxruntime_pybind_state.cc | 12 +++++++++++- 3 files changed, 24 insertions(+), 6 deletions(-) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index bc7792ba4366b..456a11603de65 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -4333,8 +4333,12 @@ struct OrtApi { * \param[in] input_len Number of elements in the input_names and inputs arrays * \param[in] output_names Array of null terminated UTF8 encoded strings of the output names * \param[in] output_names_len Number of elements in the output_names and outputs array - * \param[out] output Array of OrtValue* owned by customers, size to output_names_len. It could simply be an array of nullptr - * The array will be passed back to run_async_callback + * \param[out] output OrtValue* array of size output_names_len. + * On calling RunAsync, output[i] could either be a null or a pointer to a preallocated OrtValue. + * Later, the output array will be passed to run_async_callback with all null(s) filled with valid + * OrtValue pointer(s) allocated by onnxruntime. + * NOTE: it is customer's duty to finally release the output array and each of its member, + * regardless of whether the member (OrtValue*) is allocated by onnxruntime or preallocated by the customer. * \param[in] run_async_callback Callback function on model run completion * \param[in] user_data User data that pass back to run_async_callback */ diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index b9b6676c0072d..47356c3fe3608 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -1073,11 +1073,15 @@ struct SessionImpl : ConstSessionImpl { * * \param[in] run_options * \param[in] input_names Array of null terminated UTF8 encoded strings of the input names - * \param[in] input_values Array of ::OrtValue%s of the input values + * \param[in] input_values Array of Value objects of length input_count * \param[in] input_count Number of elements in the input_names and inputs arrays * \param[in] output_names Array of null terminated UTF8 encoded strings of the output names - * \param[out] output_values Array of ::OrtValue%s owned by customers, size to output_count. It could simply be an array of nullptr - * The array will be passed back to the callback + * \param[out] output_values Array of provided Values to be filled with outputs. + * On calling RunAsync, output_values[i] could either be initialized by a null pointer or a preallocated OrtValue*. + * Later, on invoking the callback, each output_values[i] of null will be filled with an OrtValue* allocated by onnxruntime. + * Then, an OrtValue** pointer will be casted from output_values, and pass to the callback. + * NOTE: it is customer's duty to finally release output_values and each of its member, + * regardless of whether the member (Ort::Value) is allocated by onnxruntime or preallocated by the customer. * \param[in] output_count Number of elements in the output_names and outputs array * \param[in] callback Callback function on model run completion * \param[in] user_data User data that pass back to the callback diff --git a/onnxruntime/python/onnxruntime_pybind_state.cc b/onnxruntime/python/onnxruntime_pybind_state.cc index 5ac20739c486e..82d119894a5d8 100644 --- a/onnxruntime/python/onnxruntime_pybind_state.cc +++ b/onnxruntime/python/onnxruntime_pybind_state.cc @@ -53,6 +53,7 @@ namespace onnxruntime { #endif // _MSC_VER #include +#include #if defined(_MSC_VER) #pragma warning(disable : 4267 4996 4503 4003) @@ -85,7 +86,7 @@ struct AsyncResource { std::vector feed_names; std::vector feed_names_raw; - std::vector fetches_raw; + std::vector fetches_raw; // will be released during destruction std::vector fetch_names; std::vector fetch_names_raw; @@ -106,6 +107,15 @@ struct AsyncResource { fetch_names.reserve(sz); fetch_names_raw.reserve(sz); } + + ~AsyncResource() { + std::for_each(fetches_raw.begin(), fetches_raw.end(), [](const OrtValue* fetch) { + if (fetch) { + std::unique_ptr fetch_recycler(fetch); + } + }); + fetches_raw.clear(); + } }; void AsyncCallback(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr ort_status) {