Skip to content

Commit

Permalink
Add 2 C API for ort extension (#19808)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->
Add 2 C API for ORT extension:
- KernelInfo_GetAllocator
- OrtCustomOp::GetMayInplace


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Add 2 C API for ORT extension project, which will leverage these 2 APIs
for GroupQueryAttention custom op.
  • Loading branch information
jslhcl authored Mar 14, 2024
1 parent 409b811 commit 966fa74
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 0 deletions.
17 changes: 17 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -4600,6 +4600,16 @@ struct OrtApi {
* \snippet{doc} snippets.dox OrtStatus Return Value
*/
ORT_API2_STATUS(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out);

/** \brief Get allocator from KernelInfo for a specific memory type. Please use C API ReleaseAllocator to release out object
*
* \param[in] info OrtKernelInfo instance
* \param[in] mem_type OrtMemType object
* \param[out] out A pointer to OrtAllocator
*
* \snippet{doc} snippets.dox OrtStatus Return Value
*/
ORT_API2_STATUS(KernelInfoGetAllocator, _In_ const OrtKernelInfo* info, _In_ OrtMemType mem_type, _Outptr_ OrtAllocator** out);
};

/*
Expand Down Expand Up @@ -4697,6 +4707,13 @@ struct OrtCustomOp {
// Get start range
int(ORT_API_CALL* GetStartVersion)(_In_ const struct OrtCustomOp* op);
int(ORT_API_CALL* GetEndVersion)(_In_ const struct OrtCustomOp* op);

// Get the inplace_map that defines which output can reuse which input
// Callers will provide 2 raw int* and pass in their address, this function will fill these 2 arrays
// when return, output (*output_index)[i] may reuse the input (*input_index[i]).
// The return value is the size of these 2 arrays.
// Callers are responsible to delete these 2 arrays after use.
size_t(ORT_API_CALL* GetMayInplace)(_Out_ int** input_index, _Out_ int** output_index);
};

/*
Expand Down
12 changes: 12 additions & 0 deletions onnxruntime/core/session/custom_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,18 @@ ORT_API_STATUS_IMPL(OrtApis::KernelInfo_GetLogger, _In_ const OrtKernelInfo* inf
});
}

ORT_API_STATUS_IMPL(OrtApis::KernelInfoGetAllocator, _In_ const OrtKernelInfo* info, _In_ OrtMemType mem_type, _Outptr_ OrtAllocator** out) {
return ExecuteIfCustomOpsApiEnabled([&]() -> OrtStatusPtr {
onnxruntime::AllocatorPtr allocator = reinterpret_cast<const onnxruntime::OpKernelInfo*>(info)->GetAllocator(mem_type);
if (!allocator) {
return OrtApis::CreateStatus(ORT_INVALID_ARGUMENT, "No requested allocator available");
}
auto p = std::make_unique<onnxruntime::OrtAllocatorImplWrappingIAllocator>(std::move(allocator));
*out = p.release();
return nullptr;
});
}

ORT_API_STATUS_IMPL(OrtApis::KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out) {
if (count_or_bytes == 0) {
*out = nullptr;
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/session/onnxruntime_c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2726,6 +2726,7 @@ static constexpr OrtApi ort_api_1_to_18 = {
&OrtApis::SessionOptionsAppendExecutionProvider_OpenVINO_V2,
&OrtApis::SessionOptionsAppendExecutionProvider_VitisAI,
&OrtApis::KernelContext_GetScratchBuffer,
&OrtApis::KernelInfoGetAllocator,
};

// OrtApiBase can never change as there is no way to know what version of OrtApiBase is returned by OrtGetApiBase.
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/session/ort_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -515,4 +515,6 @@ ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_VitisAI, _In_ OrtSessi
_In_reads_(num_keys) const char* const* provider_options_values, _In_ size_t num_keys);

ORT_API_STATUS_IMPL(KernelContext_GetScratchBuffer, _In_ const OrtKernelContext* context, _In_ const OrtMemoryInfo* mem_info, _In_ size_t count_or_bytes, _Outptr_ void** out);

ORT_API_STATUS_IMPL(KernelInfoGetAllocator, _In_ const OrtKernelInfo* info, _In_ OrtMemType mem_type, _Outptr_ OrtAllocator** out);
} // namespace OrtApis
31 changes: 31 additions & 0 deletions onnxruntime/test/shared_lib/test_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4007,3 +4007,34 @@ TEST(CApiTest, RunAsyncFail) {
Ort::RunOptions run_options;
EXPECT_THROW(session.RunAsync(run_options, input_names, input_tensors, 1, output_names, output_values, 1, CallbackFail, nullptr), std::exception);
}

struct MockGQA : public OrtCustomOp {
MockGQA() {
OrtCustomOp::GetMayInplace = [](int** input_index, int** output_index) {
size_t ret = 2;
*input_index = static_cast<int*>(malloc(ret * sizeof(int)));
(*input_index)[0] = 3;
(*input_index)[1] = 4;
*output_index = static_cast<int*>(malloc(ret * sizeof(int)));
(*output_index)[0] = 1;
(*output_index)[1] = 2;
return ret;
};
}
};

TEST(CApiTest, OrtCustomOp_GetInPlace) {
MockGQA mock_gqa;
int* input_index = nullptr;
int* output_index = nullptr;
size_t len = mock_gqa.GetMayInplace(&input_index, &output_index);
ASSERT_NE(input_index, nullptr);
ASSERT_NE(output_index, nullptr);
ASSERT_EQ(input_index[0], 3);
ASSERT_EQ(input_index[1], 4);
ASSERT_EQ(output_index[0], 1);
ASSERT_EQ(output_index[1], 2);
ASSERT_EQ(len, static_cast<size_t>(2));
free(input_index);
free(output_index);
}

0 comments on commit 966fa74

Please sign in to comment.