Skip to content

Commit

Permalink
Fix a memleak in RunAsync python (#17326)
Browse files Browse the repository at this point in the history
Release ort value outputs that are created and released from
ort::run(...).

---------

Co-authored-by: Randy Shuai <[email protected]>
  • Loading branch information
RandySheriffH and RandyShuai authored Aug 30, 2023
1 parent 081c069 commit 6c39641
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 6 deletions.
8 changes: 6 additions & 2 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -4333,8 +4333,12 @@ struct OrtApi {
* \param[in] input_len Number of elements in the input_names and inputs arrays
* \param[in] output_names Array of null terminated UTF8 encoded strings of the output names
* \param[in] output_names_len Number of elements in the output_names and outputs array
* \param[out] output Array of OrtValue* owned by customers, size to output_names_len. It could simply be an array of nullptr
* The array will be passed back to run_async_callback
* \param[out] output OrtValue* array of size output_names_len.
* On calling RunAsync, output[i] could either be a null or a pointer to a preallocated OrtValue.
* Later, the output array will be passed to run_async_callback with all null(s) filled with valid
* OrtValue pointer(s) allocated by onnxruntime.
* NOTE: it is customer's duty to finally release the output array and each of its member,
* regardless of whether the member (OrtValue*) is allocated by onnxruntime or preallocated by the customer.
* \param[in] run_async_callback Callback function on model run completion
* \param[in] user_data User data that pass back to run_async_callback
*/
Expand Down
10 changes: 7 additions & 3 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -1073,11 +1073,15 @@ struct SessionImpl : ConstSessionImpl<T> {
*
* \param[in] run_options
* \param[in] input_names Array of null terminated UTF8 encoded strings of the input names
* \param[in] input_values Array of ::OrtValue%s of the input values
* \param[in] input_values Array of Value objects of length input_count
* \param[in] input_count Number of elements in the input_names and inputs arrays
* \param[in] output_names Array of null terminated UTF8 encoded strings of the output names
* \param[out] output_values Array of ::OrtValue%s owned by customers, size to output_count. It could simply be an array of nullptr
* The array will be passed back to the callback
* \param[out] output_values Array of provided Values to be filled with outputs.
* On calling RunAsync, output_values[i] could either be initialized by a null pointer or a preallocated OrtValue*.
* Later, on invoking the callback, each output_values[i] of null will be filled with an OrtValue* allocated by onnxruntime.
* Then, an OrtValue** pointer will be casted from output_values, and pass to the callback.
* NOTE: it is customer's duty to finally release output_values and each of its member,
* regardless of whether the member (Ort::Value) is allocated by onnxruntime or preallocated by the customer.
* \param[in] output_count Number of elements in the output_names and outputs array
* \param[in] callback Callback function on model run completion
* \param[in] user_data User data that pass back to the callback
Expand Down
12 changes: 11 additions & 1 deletion onnxruntime/python/onnxruntime_pybind_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ namespace onnxruntime {
#endif // _MSC_VER

#include <iterator>
#include <algorithm>

#if defined(_MSC_VER)
#pragma warning(disable : 4267 4996 4503 4003)
Expand Down Expand Up @@ -85,7 +86,7 @@ struct AsyncResource {
std::vector<std::string> feed_names;
std::vector<const char*> feed_names_raw;

std::vector<OrtValue*> fetches_raw;
std::vector<OrtValue*> fetches_raw; // will be released during destruction

std::vector<std::string> fetch_names;
std::vector<const char*> fetch_names_raw;
Expand All @@ -106,6 +107,15 @@ struct AsyncResource {
fetch_names.reserve(sz);
fetch_names_raw.reserve(sz);
}

~AsyncResource() {
std::for_each(fetches_raw.begin(), fetches_raw.end(), [](const OrtValue* fetch) {
if (fetch) {
std::unique_ptr<const OrtValue> fetch_recycler(fetch);
}
});
fetches_raw.clear();
}
};

void AsyncCallback(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr ort_status) {
Expand Down

0 comments on commit 6c39641

Please sign in to comment.