Skip to content

Commit

Permalink
add API function GetAliasMap and ReleaseAliasMap in OrtCustomOp (#20145)
Browse files Browse the repository at this point in the history
### Description
<!-- Describe your changes. -->
Add API function GetAliasMap and ReleaseAliasMap in OrtCustomOp 


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Add API function GetAliasMap and ReleaseAliasMap in OrtCustomOp
  • Loading branch information
jslhcl authored Mar 29, 2024
1 parent 8396845 commit 604b284
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 0 deletions.
4 changes: 4 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -4738,6 +4738,10 @@ struct OrtCustomOp {
// Release the pointer input_index and output_index allocated from GetMayInplace() function.
// If GetMayInplace() is defined, this function MUST be defined as well.
void(ORT_API_CALL* ReleaseMayInplace)(_Frees_ptr_opt_ int* input_index, _Frees_ptr_opt_ int* output_index);

// Same as GetMayInplace() and ReleaseMayInplace()
size_t(ORT_API_CALL* GetAliasMap)(_Out_ int** input_index, _Out_ int** output_index);
void(ORT_API_CALL* ReleaseAliasMap)(_Frees_ptr_opt_ int* input_index, _Frees_ptr_opt_ int* output_index);
};

/*
Expand Down
2 changes: 2 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2304,6 +2304,8 @@ struct CustomOpBase : OrtCustomOp {

OrtCustomOp::GetMayInplace = nullptr;
OrtCustomOp::ReleaseMayInplace = nullptr;
OrtCustomOp::GetAliasMap = nullptr;
OrtCustomOp::ReleaseAliasMap = nullptr;
}

// Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
Expand Down
2 changes: 2 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_lite_custom_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -865,6 +865,8 @@ struct OrtLiteCustomOp : public OrtCustomOp {

OrtCustomOp::GetMayInplace = {};
OrtCustomOp::ReleaseMayInplace = {};
OrtCustomOp::GetAliasMap = {};
OrtCustomOp::ReleaseAliasMap = {};
}

const std::string op_name_;
Expand Down
10 changes: 10 additions & 0 deletions onnxruntime/core/session/custom_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -875,6 +875,16 @@ KernelCreateInfo CreateKernelCreateInfo(const std::string& domain, const OrtCust
}
}

if (op->version >= 18 && op->GetAliasMap != nullptr) {
int* input_index = nullptr;
int* output_index = nullptr;
size_t len = op->GetAliasMap(&input_index, &output_index);
if (len > 0) {
for (size_t i = 0; i < len; i++) def_builder.Alias(input_index[i], output_index[i]);
op->ReleaseAliasMap(input_index, output_index);
}
}

KernelCreateFn kernel_create_fn = [op](FuncManager&, const OpKernelInfo& info,
std::unique_ptr<OpKernel>& out) -> Status {
out = std::make_unique<CustomOpKernel>(info, *op);
Expand Down
25 changes: 25 additions & 0 deletions onnxruntime/test/shared_lib/test_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4039,6 +4039,20 @@ struct MockGQA : public OrtCustomOp {
free(input_index);
free(output_index);
};
OrtCustomOp::GetAliasMap = [](int** input_index, int** output_index) {
size_t ret = 2;
*input_index = static_cast<int*>(malloc(ret * sizeof(int)));
(*input_index)[0] = 5;
(*input_index)[1] = 6;
*output_index = static_cast<int*>(malloc(ret * sizeof(int)));
(*output_index)[0] = 7;
(*output_index)[1] = 8;
return ret;
};
OrtCustomOp::ReleaseAliasMap = [](int* input_index, int* output_index) {
free(input_index);
free(output_index);
};
}
};

Expand All @@ -4055,4 +4069,15 @@ TEST(CApiTest, OrtCustomOp_GetInPlace) {
ASSERT_EQ(output_index[1], 2);
ASSERT_EQ(len, static_cast<size_t>(2));
mock_gqa.ReleaseMayInplace(input_index, output_index);

input_index = output_index = nullptr;
len = mock_gqa.GetAliasMap(&input_index, &output_index);
ASSERT_NE(input_index, nullptr);
ASSERT_NE(output_index, nullptr);
ASSERT_EQ(input_index[0], 5);
ASSERT_EQ(input_index[1], 6);
ASSERT_EQ(output_index[0], 7);
ASSERT_EQ(output_index[1], 8);
ASSERT_EQ(len, static_cast<size_t>(2));
mock_gqa.ReleaseAliasMap(input_index, output_index);
}

0 comments on commit 604b284

Please sign in to comment.