From 224a2fe4a9c22b0c7f5a368b8230f59226a73e35 Mon Sep 17 00:00:00 2001 From: cao lei Date: Fri, 29 Mar 2024 13:49:56 -0700 Subject: [PATCH] add API function GetAliasMap and ReleaseAliasMap in OrtCustomOp (#20145) ### Description Add API function GetAliasMap and ReleaseAliasMap in OrtCustomOp ### Motivation and Context Add API function GetAliasMap and ReleaseAliasMap in OrtCustomOp --- .../core/session/onnxruntime_c_api.h | 4 +++ .../core/session/onnxruntime_cxx_api.h | 2 ++ .../core/session/onnxruntime_lite_custom_op.h | 2 ++ onnxruntime/core/session/custom_ops.cc | 10 ++++++++ onnxruntime/test/shared_lib/test_inference.cc | 25 +++++++++++++++++++ 5 files changed, 43 insertions(+) diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index c4fb0d3a83a67..e40c375cab119 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -4738,6 +4738,10 @@ struct OrtCustomOp { // Release the pointer input_index and output_index allocated from GetMayInplace() function. // If GetMayInplace() is defined, this function MUST be defined as well. 
void(ORT_API_CALL* ReleaseMayInplace)(_Frees_ptr_opt_ int* input_index, _Frees_ptr_opt_ int* output_index); + + // Same calling convention as GetMayInplace() and ReleaseMayInplace(), but each returned (input_index[i], output_index[i]) pair is registered as an alias on the kernel definition. If GetAliasMap() is defined, ReleaseAliasMap() MUST be defined as well. + size_t(ORT_API_CALL* GetAliasMap)(_Out_ int** input_index, _Out_ int** output_index); + void(ORT_API_CALL* ReleaseAliasMap)(_Frees_ptr_opt_ int* input_index, _Frees_ptr_opt_ int* output_index); }; /* diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 5f2e0a470a133..fd0e3490426a7 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -2304,6 +2304,8 @@ struct CustomOpBase : OrtCustomOp { OrtCustomOp::GetMayInplace = nullptr; OrtCustomOp::ReleaseMayInplace = nullptr; + OrtCustomOp::GetAliasMap = nullptr; + OrtCustomOp::ReleaseAliasMap = nullptr; } // Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider diff --git a/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h b/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h index 896893e986e05..ee60f25da115e 100644 --- a/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h +++ b/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h @@ -865,6 +865,8 @@ struct OrtLiteCustomOp : public OrtCustomOp { OrtCustomOp::GetMayInplace = {}; OrtCustomOp::ReleaseMayInplace = {}; + OrtCustomOp::GetAliasMap = {}; + OrtCustomOp::ReleaseAliasMap = {}; } const std::string op_name_; diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index cc23d0822c36e..d0c46142ac060 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -875,6 +875,16 @@ KernelCreateInfo CreateKernelCreateInfo(const std::string& domain, const OrtCust } } + if (op->version >= 18 && op->GetAliasMap != nullptr) { + int* input_index = nullptr; + int* output_index = nullptr; + size_t len = 
op->GetAliasMap(&input_index, &output_index); + if (len > 0) { + for (size_t i = 0; i < len; i++) def_builder.Alias(input_index[i], output_index[i]); + op->ReleaseAliasMap(input_index, output_index); + } + } + KernelCreateFn kernel_create_fn = [op](FuncManager&, const OpKernelInfo& info, std::unique_ptr<OpKernel>& out) -> Status { out = std::make_unique<CustomOpKernel>(info, *op); diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 5dd5fabd26fb4..a7ce8127a7f50 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -4039,6 +4039,20 @@ struct MockGQA : public OrtCustomOp { free(input_index); free(output_index); }; + OrtCustomOp::GetAliasMap = [](int** input_index, int** output_index) { + size_t ret = 2; + *input_index = static_cast<int*>(malloc(ret * sizeof(int))); + (*input_index)[0] = 5; + (*input_index)[1] = 6; + *output_index = static_cast<int*>(malloc(ret * sizeof(int))); + (*output_index)[0] = 7; + (*output_index)[1] = 8; + return ret; + }; + OrtCustomOp::ReleaseAliasMap = [](int* input_index, int* output_index) { + free(input_index); + free(output_index); + }; } }; @@ -4055,4 +4069,15 @@ TEST(CApiTest, OrtCustomOp_GetInPlace) { ASSERT_EQ(output_index[1], 2); ASSERT_EQ(len, static_cast<size_t>(2)); mock_gqa.ReleaseMayInplace(input_index, output_index); + + input_index = output_index = nullptr; + len = mock_gqa.GetAliasMap(&input_index, &output_index); + ASSERT_NE(input_index, nullptr); + ASSERT_NE(output_index, nullptr); + ASSERT_EQ(input_index[0], 5); + ASSERT_EQ(input_index[1], 6); + ASSERT_EQ(output_index[0], 7); + ASSERT_EQ(output_index[1], 8); + ASSERT_EQ(len, static_cast<size_t>(2)); + mock_gqa.ReleaseAliasMap(input_index, output_index); }