diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 2e2a903da27cb..cef50163f68b0 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -4600,6 +4600,16 @@ struct OrtApi { * \snippet{doc} snippets.dox OrtStatus Return Value */ ORT_API2_STATUS(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out); + + /** \brief Get allocator from KernelInfo for a specific memory type. Please use C API ReleaseAllocator to release out object + * + * \param[in] info OrtKernelInfo instance + * \param[in] mem_type OrtMemType object + * \param[out] out A pointer to OrtAllocator + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(KernelInfoGetAllocator, _In_ const OrtKernelInfo* info, _In_ OrtMemType mem_type, _Outptr_ OrtAllocator** out); }; /* @@ -4697,6 +4707,13 @@ struct OrtCustomOp { // Get start range int(ORT_API_CALL* GetStartVersion)(_In_ const struct OrtCustomOp* op); int(ORT_API_CALL* GetEndVersion)(_In_ const struct OrtCustomOp* op); + + // Get the inplace_map that defines which output can reuse which input + // Callers will provide 2 raw int* and pass in their address, this function will fill these 2 arrays + // when return, output (*output_index)[i] may reuse the input (*input_index[i]). + // The return value is the size of these 2 arrays. + // Callers are responsible to delete these 2 arrays after use. + size_t(ORT_API_CALL* GetMayInplace)(_Out_ int** input_index, _Out_ int** output_index); }; /* diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index 9a09c74cd0b3a..6e9d68d259a5d 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -736,6 +736,18 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetLogger, _In_ const OrtKernelInfo* inf }); } +ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAllocator, _In_ const OrtKernelInfo* info, _In_ OrtMemType mem_type, _Outptr_ OrtAllocator** out) { + return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr { + onnxruntime::AllocatorPtr allocator = reinterpret_cast(info)->GetAllocator(mem_type); + if (!allocator) { + return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "No requested allocator available"); + } + auto p = std::make_unique(std::move(allocator)); + *out = p.release(); + return nullptr; + }); +} + ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out) { if (count_or_bytes == 0) { *out = nullptr; diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index 273f94ae5decc..270b3490689c4 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2726,6 +2726,7 @@ static constexpr OrtApi ort_api_1_to_18 = { &OrtApis::SessionOptionsAppendExecutionProvider_OpenVINO_V2, &OrtApis::SessionOptionsAppendExecutionProvider_VitisAI, &OrtApis::KernelContext_GetScratchBuffer, + &OrtApis::KernelInfoGetAllocator, }; // OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase. diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index adb89f7f85444..3591c96234aa3 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -515,4 +515,6 @@ ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_VitisAI, _In_ OrtSessi _In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys); ORT_API_STATUS_IMPL(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out); + +ORT_API_STATUS_IMPL(KernelInfoGetAllocator, _In_ const OrtKernelInfo* info, _In_ OrtMemType mem_type, _Outptr_ OrtAllocator** out); } // namespace OrtApis diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc index 453b5fdd360bf..91453102d406f 100644 --- a/onnxruntime/test/shared_lib/test_inference.cc +++ b/onnxruntime/test/shared_lib/test_inference.cc @@ -4007,3 +4007,34 @@ TEST(CApiTest, RunAsyncFail) { Ort::RunOptions run_options; EXPECT_THROW(session.RunAsync(run_options, input_names, input_tensors, 1, output_names, output_values, 1, CallbackFail, nullptr), std::exception); } + +struct MockGQA : public OrtCustomOp { + MockGQA() { + OrtCustomOp::GetMayInplace = [](int** input_index, int** output_index) { + size_t ret = 2; + *input_index = static_cast(malloc(ret * sizeof(int))); + (*input_index)[0] = 3; + (*input_index)[1] = 4; + *output_index = static_cast(malloc(ret * sizeof(int))); + (*output_index)[0] = 1; + (*output_index)[1] = 2; + return ret; + }; + } +}; + +TEST(CApiTest, OrtCustomOp_GetInPlace) { + MockGQA mock_gqa; + int* input_index = nullptr; + int* output_index = nullptr; + size_t len = mock_gqa.GetMayInplace(&input_index, &output_index); + ASSERT_NE(input_index, nullptr); + ASSERT_NE(output_index, nullptr); + ASSERT_EQ(input_index[0], 3); + ASSERT_EQ(input_index[1], 4); + ASSERT_EQ(output_index[0], 1); + ASSERT_EQ(output_index[1], 2); + ASSERT_EQ(len, static_cast(2)); + free(input_index); + free(output_index); +}