Skip to content

Commit

Permalink
Versioning for custom op (microsoft#18088)
Browse files Browse the repository at this point in the history
Allow custom ops to have versions.

---------

Co-authored-by: Randy Shuai <[email protected]>
  • Loading branch information
RandySheriffH and RandyShuai authored Oct 31, 2023
1 parent 62c7894 commit 2b95e74
Show file tree
Hide file tree
Showing 8 changed files with 148 additions and 37 deletions.
4 changes: 4 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -4605,6 +4605,10 @@ struct OrtCustomOp {
OrtStatusPtr(ORT_API_CALL* KernelComputeV2)(_In_ void* op_kernel, _In_ OrtKernelContext* context);

OrtStatusPtr(ORT_API_CALL* InferOutputShapeFn)(_In_ const struct OrtCustomOp* op, _In_ OrtShapeInferContext*);

// Get start range
int(ORT_API_CALL* GetStartVersion)(_In_ const struct OrtCustomOp* op);
int(ORT_API_CALL* GetEndVersion)(_In_ const struct OrtCustomOp* op);
};

/*
Expand Down
13 changes: 13 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_cxx_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2228,6 +2228,8 @@ struct ShapeInferContext {

using ShapeInferFn = Ort::Status (*)(Ort::ShapeInferContext&);

#define MAX_CUSTOM_OP_END_VER (1UL << 31) - 1

template <typename TOp, typename TKernel, bool WithStatus = false>
struct CustomOpBase : OrtCustomOp {
CustomOpBase() {
Expand Down Expand Up @@ -2280,6 +2282,14 @@ struct CustomOpBase : OrtCustomOp {
}

SetShapeInferFn<TOp>(0);

OrtCustomOp::GetStartVersion = [](const OrtCustomOp* this_) {
return static_cast<const TOp*>(this_)->start_ver_;
};

OrtCustomOp::GetEndVersion = [](const OrtCustomOp* this_) {
return static_cast<const TOp*>(this_)->end_ver_;
};
}

// Default implementation of GetExecutionProviderType that returns nullptr to default to the CPU provider
Expand Down Expand Up @@ -2348,6 +2358,9 @@ struct CustomOpBase : OrtCustomOp {
protected:
// Helper function that returns a map of session config entries specified by CustomOpBase::GetSessionConfigKeys.
void GetSessionConfigs(std::unordered_map<std::string, std::string>& out, ConstSessionOptions options) const;

int start_ver_ = 1;
int end_ver_ = MAX_CUSTOM_OP_END_VER;
};

} // namespace Ort
Expand Down
59 changes: 43 additions & 16 deletions include/onnxruntime/core/session/onnxruntime_lite_custom_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -773,8 +773,11 @@ struct OrtLiteCustomOp : public OrtCustomOp {
PARSE_ARGS(Ort::Float8E5M2FNUZ_t, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ)

OrtLiteCustomOp(const char* op_name,
const char* execution_provider) : op_name_(op_name),
execution_provider_(execution_provider) {
const char* execution_provider,
int start_ver = 1, int end_ver = MAX_CUSTOM_OP_END_VER) : op_name_(op_name),
execution_provider_(execution_provider),
start_ver_(start_ver),
end_ver_(end_ver) {
OrtCustomOp::version = ORT_API_VERSION;

OrtCustomOp::GetName = [](const OrtCustomOp* op) { return static_cast<const OrtLiteCustomOp*>(op)->op_name_.c_str(); };
Expand Down Expand Up @@ -837,13 +840,26 @@ struct OrtLiteCustomOp : public OrtCustomOp {
OrtCustomOp::KernelCompute = {};

OrtCustomOp::InferOutputShapeFn = {};

OrtCustomOp::GetStartVersion = [](const OrtCustomOp* op) {
auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
return self->start_ver_;
};

OrtCustomOp::GetEndVersion = [](const OrtCustomOp* op) {
auto self = reinterpret_cast<const OrtLiteCustomOp*>(op);
return self->end_ver_;
};
}

const std::string op_name_;
const std::string execution_provider_;

std::vector<ONNXTensorElementDataType> input_types_;
std::vector<ONNXTensorElementDataType> output_types_;

int start_ver_ = 1;
int end_ver_ = MAX_CUSTOM_OP_END_VER;
};

//////////////////////////// OrtLiteCustomFunc ////////////////////////////////
Expand Down Expand Up @@ -873,9 +889,11 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp {
OrtLiteCustomFunc(const char* op_name,
const char* execution_provider,
ComputeFn compute_fn,
ShapeInferFn shape_infer_fn = {}) : OrtLiteCustomOp(op_name, execution_provider),
compute_fn_(compute_fn),
shape_infer_fn_(shape_infer_fn) {
ShapeInferFn shape_infer_fn = {},
int start_ver = 1,
int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver),
compute_fn_(compute_fn),
shape_infer_fn_(shape_infer_fn) {
ParseArgs<Args...>(input_types_, output_types_);

OrtCustomOp::KernelCompute = [](void* op_kernel, OrtKernelContext* context) {
Expand Down Expand Up @@ -911,9 +929,11 @@ struct OrtLiteCustomFunc : public OrtLiteCustomOp {
OrtLiteCustomFunc(const char* op_name,
const char* execution_provider,
ComputeFnReturnStatus compute_fn_return_status,
ShapeInferFn shape_infer_fn = {}) : OrtLiteCustomOp(op_name, execution_provider),
compute_fn_return_status_(compute_fn_return_status),
shape_infer_fn_(shape_infer_fn) {
ShapeInferFn shape_infer_fn = {},
int start_ver = 1,
int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver),
compute_fn_return_status_(compute_fn_return_status),
shape_infer_fn_(shape_infer_fn) {
ParseArgs<Args...>(input_types_, output_types_);

OrtCustomOp::KernelComputeV2 = [](void* op_kernel, OrtKernelContext* context) -> OrtStatusPtr {
Expand Down Expand Up @@ -985,8 +1005,9 @@ struct OrtLiteCustomStruct : public OrtLiteCustomOp {
};

OrtLiteCustomStruct(const char* op_name,
const char* execution_provider) : OrtLiteCustomOp(op_name,
execution_provider) {
const char* execution_provider,
int start_ver = 1,
int end_ver = MAX_CUSTOM_OP_END_VER) : OrtLiteCustomOp(op_name, execution_provider, start_ver, end_ver) {
SetCompute(&CustomOp::Compute);

OrtCustomOp::CreateKernel = [](const OrtCustomOp* this_, const OrtApi* ort_api, const OrtKernelInfo* info) {
Expand Down Expand Up @@ -1049,25 +1070,31 @@ template <typename... Args>
OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
const char* execution_provider,
void (*custom_compute_fn)(Args...),
Status (*shape_infer_fn)(ShapeInferContext&) = {}) {
Status (*shape_infer_fn)(ShapeInferContext&) = {},
int start_ver = 1,
int end_ver = MAX_CUSTOM_OP_END_VER) {
using LiteOp = OrtLiteCustomFunc<Args...>;
return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn, shape_infer_fn).release();
return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn, shape_infer_fn, start_ver, end_ver).release();
}

template <typename... Args>
OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
const char* execution_provider,
Status (*custom_compute_fn_v2)(Args...),
Status (*shape_infer_fn)(ShapeInferContext&) = {}) {
Status (*shape_infer_fn)(ShapeInferContext&) = {},
int start_ver = 1,
int end_ver = MAX_CUSTOM_OP_END_VER) {
using LiteOp = OrtLiteCustomFunc<Args...>;
return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn_v2, shape_infer_fn).release();
return std::make_unique<LiteOp>(op_name, execution_provider, custom_compute_fn_v2, shape_infer_fn, start_ver, end_ver).release();
}

template <typename CustomOp>
OrtLiteCustomOp* CreateLiteCustomOp(const char* op_name,
const char* execution_provider) {
const char* execution_provider,
int start_ver = 1,
int end_ver = MAX_CUSTOM_OP_END_VER) {
using LiteOp = OrtLiteCustomStruct<CustomOp>;
return std::make_unique<LiteOp>(op_name, execution_provider).release();
return std::make_unique<LiteOp>(op_name, execution_provider, start_ver, end_ver).release();
}

} // namespace Custom
Expand Down
22 changes: 19 additions & 3 deletions onnxruntime/core/session/custom_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#if !defined(ORT_MINIMAL_BUILD)
static constexpr uint32_t min_ort_version_with_optional_io_support = 8;
static constexpr uint32_t min_ort_version_with_variadic_io_support = 14;
static constexpr uint32_t min_ort_version_with_custom_version = 17;
#endif

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
Expand Down Expand Up @@ -698,8 +699,19 @@ KernelCreateInfo CreateKernelCreateInfo(const std::string& domain, const OrtCust

KernelDefBuilder def_builder;
def_builder.SetName(op->GetName(op))
.SetDomain(domain)
.SinceVersion(1);
.SetDomain(domain);

if (op->version >= min_ort_version_with_custom_version) {
if (op->GetStartVersion && op->GetEndVersion) {
def_builder.SinceVersion(op->GetStartVersion(op), op->GetEndVersion(op));
} else if (op->GetStartVersion) {
def_builder.SinceVersion(op->GetStartVersion(op));
} else {
def_builder.SinceVersion(1);
}
} else {
def_builder.SinceVersion(1);
}

// GetInputMemoryType was introduced in ver 13. This check allows custom ops compiled using older versions
// to work with newer versions (> 12) of the ORT binary.
Expand Down Expand Up @@ -820,7 +832,11 @@ ONNX_NAMESPACE::OpSchema CreateSchema(const std::string& domain, const OrtCustom
schema.TypeConstraint(output_name, DataTypeImpl::ToString(SUPPORTED_TENSOR_TYPES), "all types");
}
schema.SetDomain(domain);
schema.SinceVersion(1);
if (op->version >= min_ort_version_with_custom_version && op->GetStartVersion) {
schema.SinceVersion(op->GetStartVersion(op));
} else {
schema.SinceVersion(1);
}
schema.AllowUncheckedAttributes();

if (op->version >= min_ort_version_with_shape_inference && op->InferOutputShapeFn) {
Expand Down
16 changes: 16 additions & 0 deletions onnxruntime/test/shared_lib/test_inference.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3323,6 +3323,22 @@ TEST(LiteCustomOpTest, CustomFunc) {
ASSERT_TRUE(floats_output[1] == 16);
}

TEST(LiteCustomOpTest, CustomFuncOpsetMismatch) {
Ort::SessionOptions session_options;
session_options.SetIntraOpNumThreads(1);
session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_EXTENDED);
session_options.SetLogSeverityLevel(0);
#if defined(_WIN32)
session_options.RegisterCustomOpsLibrary(ORT_TSTR("custom_op_library.dll"));
#elif defined(__APPLE__)
session_options.RegisterCustomOpsLibrary(ORT_TSTR("libcustom_op_library.dylib"));
#else
session_options.RegisterCustomOpsLibrary(ORT_TSTR("./libcustom_op_library.so"));
#endif

EXPECT_THROW(Ort::Session(*ort_env, TSTR("testdata/fuse_select_filter_opset_8.onnx"), session_options), std::exception);
}

struct Merge {
Merge(const OrtApi* ort_api, const OrtKernelInfo* info) {
int64_t reverse;
Expand Down
37 changes: 21 additions & 16 deletions onnxruntime/test/testdata/custom_op_library/cpu/cpu_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -94,23 +94,28 @@ void Select(const Ort::Custom::Span<int32_t>& indices_in,
}
}

void Filter(const Ort::Custom::Tensor<float>& floats_in,
Ort::Custom::Tensor<float>& floats_out) {
const float* in = floats_in.Data();
auto in_len = floats_in.NumberOfElement();
struct Filter {
Filter(const OrtApi*, const OrtKernelInfo*) {}
Ort::Status Compute(const Ort::Custom::Tensor<float>& floats_in,
Ort::Custom::Tensor<float>& floats_out) {
const float* in = floats_in.Data();
auto in_len = floats_in.NumberOfElement();

std::vector<float> filter_floats;
for (int64_t i = 0; i < in_len; ++i) {
if (in[i] > 1.f) {
filter_floats.push_back(in[i]);
}
}

std::vector<float> filter_floats;
for (int64_t i = 0; i < in_len; ++i) {
if (in[i] > 1.f) {
filter_floats.push_back(in[i]);
float* out = static_cast<float*>(floats_out.Allocate({static_cast<int64_t>(filter_floats.size())}));
for (size_t j = 0; j < filter_floats.size(); ++j) {
out[j] = filter_floats[j];
}
}

float* out = static_cast<float*>(floats_out.Allocate({static_cast<int64_t>(filter_floats.size())}));
for (size_t j = 0; j < filter_floats.size(); ++j) {
out[j] = filter_floats[j];
return Ort::Status{nullptr};
}
}
};

void Box(const Ort::Custom::Tensor<float>* float_in_1,
const Ort::Custom::Tensor<float>* float_in_2,
Expand Down Expand Up @@ -293,9 +298,9 @@ void RegisterOps(Ort::CustomOpDomain& domain) {
static const std::unique_ptr<OrtLiteCustomOp> c_CustomOpTwo{Ort::Custom::CreateLiteCustomOp("CustomOpTwo", "CPUExecutionProvider", KernelTwo)};
static const std::unique_ptr<OrtLiteCustomOp> c_MulTopOpFloat{Ort::Custom::CreateLiteCustomOp("MulTop", "CPUExecutionProvider", MulTop<float>)};
static const std::unique_ptr<OrtLiteCustomOp> c_MulTopOpInt32{Ort::Custom::CreateLiteCustomOp("MulTop", "CPUExecutionProvider", MulTop<int32_t>)};
static const std::unique_ptr<OrtLiteCustomOp> c_Fuse{Ort::Custom::CreateLiteCustomOp("Fuse", "CPUExecutionProvider", Fuse)};
static const std::unique_ptr<OrtLiteCustomOp> c_Fuse{Ort::Custom::CreateLiteCustomOp("Fuse", "CPUExecutionProvider", Fuse, {}, 10, 12)};
static const std::unique_ptr<OrtLiteCustomOp> c_Select{Ort::Custom::CreateLiteCustomOp("Select", "CPUExecutionProvider", Select)};
static const std::unique_ptr<OrtLiteCustomOp> c_Fill{Ort::Custom::CreateLiteCustomOp("Filter", "CPUExecutionProvider", Filter)};
static const std::unique_ptr<OrtLiteCustomOp> c_Filter{Ort::Custom::CreateLiteCustomOp<Filter>("Filter", "CPUExecutionProvider", 15, 17)};
static const std::unique_ptr<OrtLiteCustomOp> c_Box{Ort::Custom::CreateLiteCustomOp("Box", "CPUExecutionProvider", Box)};
static const std::unique_ptr<OrtLiteCustomOp> c_CopyTensorArrayAllVariadic{Ort::Custom::CreateLiteCustomOp("CopyTensorArrayAllVariadic", "CPUExecutionProvider", CopyTensorArrayAllVariadic<float>)};
static const std::unique_ptr<OrtLiteCustomOp> c_CopyTensorArrayCombined{Ort::Custom::CreateLiteCustomOp("CopyTensorArrayCombined", "CPUExecutionProvider", CopyTensorArrayCombined<float>)};
Expand All @@ -314,7 +319,7 @@ void RegisterOps(Ort::CustomOpDomain& domain) {
domain.Add(c_MulTopOpInt32.get());
domain.Add(c_Fuse.get());
domain.Add(c_Select.get());
domain.Add(c_Fill.get());
domain.Add(c_Filter.get());
domain.Add(c_Box.get());
domain.Add(c_CopyTensorArrayAllVariadic.get());
domain.Add(c_CopyTensorArrayCombined.get());
Expand Down
5 changes: 3 additions & 2 deletions onnxruntime/test/testdata/fuse_select_filter.onnx
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
:�
 :�
P
vector_1
vector_2
Expand All @@ -25,4 +25,5 @@ N
���������b&
vector_filtered

���������B
���������B
v2
29 changes: 29 additions & 0 deletions onnxruntime/test/testdata/fuse_select_filter_opset_8.onnx
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
 :�
P
vector_1
vector_2
alpha vector_fused fuse_node"Fuse*
fuse_algo�:v2
4
indicesindices_selected select_node"Select:v2
N
vector_fused
indices_selectedvector_gathered gather_node"GatherElements
;
vector_gatheredvector_filtered filter_node"Filter:v2graphZ
vector_1

���������Z
vector_2

���������Z
alpha

���������Z
indices

���������b&
vector_filtered

���������B
v2

0 comments on commit 2b95e74

Please sign in to comment.