Fix schema type constraint for custom operators (#17497)
### Description
onnxruntime may raise a "type inference failed" error when a custom
operator sets IsHomogeneous to false in its schema. This change makes
sure that the TypeInferenceFunction and the schema type constraints are
aligned to prevent that from happening.
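
Before this fix, a registration like the following minimal sketch could trigger the error at model load time. This is a hedged illustration, not code from this PR: the names are hypothetical and it assumes the onnxruntime 1.14+ C++ custom-op API. The op declares a variadic input whose homogeneity flag is false:

```cpp
#include "onnxruntime_cxx_api.h"

struct NoOpKernel {
  // A real kernel would read the inputs and write the outputs here.
  void Compute(OrtKernelContext* /*context*/) {}
};

struct NonHomogeneousOp : Ort::CustomOpBase<NonHomogeneousOp, NoOpKernel> {
  void* CreateKernel(const OrtApi& /*api*/, const OrtKernelInfo* /*info*/) const {
    return new NoOpKernel();
  }
  const char* GetName() const { return "NonHomogeneousOp"; }

  size_t GetInputTypeCount() const { return 1; }
  ONNXTensorElementDataType GetInputType(size_t /*index*/) const {
    // UNDEFINED means "accept any tensor element type".
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
  }
  OrtCustomOpInputOutputCharacteristic GetInputCharacteristic(size_t /*index*/) const {
    return INPUT_OUTPUT_VARIADIC;
  }
  // IsHomogeneous == false: the variadic input may mix element types.
  bool GetVariadicInputHomogeneity() const { return false; }

  size_t GetOutputTypeCount() const { return 1; }
  ONNXTensorElementDataType GetOutputType(size_t /*index*/) const {
    return ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT;
  }
};
```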

---------

Co-authored-by: Xavier Dupre <[email protected]@orttrainingdev9.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
Co-authored-by: Scott McKay <[email protected]>
3 people authored Jan 4, 2024
1 parent 011b562 commit 889b1ef
Showing 12 changed files with 478 additions and 49 deletions.
32 changes: 32 additions & 0 deletions cmake/onnxruntime_unittests.cmake
@@ -1662,6 +1662,38 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND (NOT onnxruntime_MINIMAL_BUI
${ONNXRUNTIME_CUSTOM_OP_GET_CONST_INPUT_TEST_LIB_LINK_FLAG})
endif()

if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND (NOT onnxruntime_MINIMAL_BUILD OR onnxruntime_MINIMAL_BUILD_CUSTOM_OPS))

file(GLOB_RECURSE custom_op_local_function_test_library_src
"${TEST_SRC_DIR}/testdata/custom_op_local_function/custom_op_local_function.cc"
"${TEST_SRC_DIR}/testdata/custom_op_local_function/custom_op_local_function.h"
"${TEST_SRC_DIR}/testdata/custom_op_local_function/dummy_gemm.cc"
"${TEST_SRC_DIR}/testdata/custom_op_local_function/dummy_gemm.h"
)

onnxruntime_add_shared_library_module(custom_op_local_function ${custom_op_local_function_test_library_src})

onnxruntime_add_include_to_target(custom_op_local_function onnxruntime_common GTest::gtest GTest::gmock)
target_include_directories(custom_op_local_function PRIVATE ${REPO_ROOT}/include/onnxruntime/core/session
${REPO_ROOT}/include/onnxruntime/core/common)

if(UNIX)
if (APPLE)
set(ONNXRUNTIME_CUSTOM_OP_LOCAL_FUNCTION_TEST_LIB_LINK_FLAG "-Xlinker -dead_strip")
else()
string(CONCAT ONNXRUNTIME_CUSTOM_OP_LOCAL_FUNCTION_TEST_LIB_LINK_FLAG
"-Xlinker --version-script=${TEST_SRC_DIR}/testdata/custom_op_local_function/custom_op_local_function.lds "
"-Xlinker --no-undefined -Xlinker --gc-sections -z noexecstack")
endif()
else()
set(ONNXRUNTIME_CUSTOM_OP_LOCAL_FUNCTION_TEST_LIB_LINK_FLAG
"-DEF:${TEST_SRC_DIR}/testdata/custom_op_local_function/custom_op_local_function.def")
endif()

set_property(TARGET custom_op_local_function APPEND_STRING PROPERTY LINK_FLAGS
${ONNXRUNTIME_CUSTOM_OP_LOCAL_FUNCTION_TEST_LIB_LINK_FLAG})
endif()

if (onnxruntime_BUILD_SHARED_LIB AND NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten" AND NOT onnxruntime_MINIMAL_BUILD)
set (onnxruntime_logging_apis_test_SRC
${ONNXRUNTIME_LOGGING_APIS_TEST_SRC_DIR}/test_logging_apis.cc)
8 changes: 7 additions & 1 deletion onnxruntime/core/graph/graph.cc
@@ -2367,8 +2367,14 @@ Status Graph::InferAndVerifyTypeMatch(Node& node, const OpSchema& op, const Reso
inferred_type = existing_type;
} else {
// This should not happen: indicates incompleteness in ONNX inference.
std::stringstream ss;
ss << "index=" << operand_index;
for (auto it = op_formal_parameter.GetTypes().begin(); it != op_formal_parameter.GetTypes().end(); ++it) {
ss << "," << *(*it);
}
Status status(ONNXRUNTIME, onnxruntime::common::StatusCode::FAIL,
"Node (" + node_name + ") output arg (" + output_def->Name() + ") type inference failed");
"Node (" + node_name + ") Op (" + node.OpType() + ") output arg (" +
output_def->Name() + ") type inference failed, inferred types: " + ss.str());
return status;
}

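To illustrate the improved diagnostic, here is a standalone sketch that rebuilds the new error text with made-up node, op, and type names; previously the message stopped after "type inference failed":

```cpp
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

// Mirrors the message construction in the snippet above, using hypothetical values.
int main() {
  const std::string node_name = "node_0", op_type = "CustomGemmFloat", output_name = "Y";
  const std::vector<std::string> allowed_types = {"tensor(float)", "tensor(double)"};
  std::stringstream ss;
  ss << "index=" << 0;
  for (const auto& t : allowed_types) ss << "," << t;
  std::cout << "Node (" << node_name << ") Op (" << op_type << ") output arg ("
            << output_name << ") type inference failed, inferred types: "
            << ss.str() << "\n";
  // Prints a single line:
  // Node (node_0) Op (CustomGemmFloat) output arg (Y) type inference failed, inferred types: index=0,tensor(float),tensor(double)
}
```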
145 changes: 97 additions & 48 deletions onnxruntime/core/session/custom_ops.cc
@@ -5,7 +5,10 @@
#pragma warning(disable : 4267)
#endif

#include <string>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>

#include "core/common/gsl.h"
#include "core/framework/data_types.h"
@@ -755,66 +758,96 @@ KernelCreateInfo CreateKernelCreateInfo(const std::string& domain, const OrtCust
return KernelCreateInfo(def_builder.Build(), kernel_create_fn);
}

ONNX_NAMESPACE::OpSchema CreateSchema(const std::string& domain, const OrtCustomOp* op) {
const size_t input_count = op->GetInputTypeCount(op);
const size_t output_count = op->GetOutputTypeCount(op);
ONNX_NAMESPACE::OpSchema CreateSchema(const std::string& domain, const std::vector<const OrtCustomOp*>& ops) {
// The function registers the first schema assuming all the others are identical
// except for their type constraints.
ORT_ENFORCE(ops.size() > 0, "No kernels to register.");
int undefined = 0;

// Creation of the schema for the first kernel in ops.
const OrtCustomOp* op = *ops.begin();
ONNX_NAMESPACE::OpSchema schema(op->GetName(op), "custom op registered at runtime", 0);

for (size_t i = 0; i < input_count; i++) {
auto create_type_constraint = [&ops, &schema, &undefined](const OrtCustomOp* op, int count, int i, bool is_input) {
onnx::OpSchema::FormalParameterOption option = onnx::OpSchema::FormalParameterOption::Single;
bool is_homogeneous = true;
int min_arity = 1;

// The OrtCustomOp interface did not support the methods to query input/output characteristics before
// ORT API version 8. So, query the relevant methods ONLY from API version 8 onwards.
if (op->version >= min_ort_version_with_optional_io_support) {
const auto characteristic = op->GetInputCharacteristic(op, i);
const auto characteristic = is_input ? op->GetInputCharacteristic(op, i) : op->GetOutputCharacteristic(op, i);

// Support for optional and variadic inputs/outputs was added in versions 8 and 14, respectively.
if (characteristic == OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL) {
option = onnx::OpSchema::FormalParameterOption::Optional;
} else if ((op->version >= min_ort_version_with_variadic_io_support) &&
(characteristic == OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_VARIADIC)) {
ORT_ENFORCE(i == input_count - 1, "Only the last input to a custom op may be marked variadic.");
ORT_ENFORCE(i == count - 1, "Only the last ", (is_input ? "input" : "output"),
" to a custom op may be marked variadic.");
option = onnx::OpSchema::FormalParameterOption::Variadic;
min_arity = op->GetVariadicInputMinArity(op);
is_homogeneous = static_cast<bool>(op->GetVariadicInputHomogeneity(op));
min_arity = is_input ? op->GetVariadicInputMinArity(op) : op->GetVariadicOutputMinArity(op);
is_homogeneous = static_cast<bool>(is_input
? op->GetVariadicInputHomogeneity(op)
: op->GetVariadicOutputHomogeneity(op));
}
}

const auto type = op->GetInputType(op, i);
if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) {
undefined++;
// The loop goes through all operators sharing the same schema to build
// the minimal type constraints covering all of them. All kernels must have
// the same number of inputs / outputs among themselves for the type
// constraints to be built. Any other incompatibility between a kernel
// and the schema is checked by IsCompatible once the schema has been
// created by this method.
std::unordered_set<ONNXTensorElementDataType> all_types;
for (auto o : ops) {
ORT_ENFORCE(static_cast<size_t>(i) != (is_input ? o->GetInputTypeCount(o) : o->GetOutputTypeCount(o)),
"Another version of operator '", schema.Name(),
"'has a different number of ", (is_input ? "inputs" : "outputs"),
". onnxruntime allows the overloading of an operator "
"if all versions have the same number of declared ",
(is_input ? "inputs" : "outputs"), ".");
const auto type = is_input ? o->GetInputType(o, i) : o->GetOutputType(o, i);
if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) {
// If 'type' is undefined, all types are allowed regardless of what other versions of the same operator
// define. In that case, all_types is cleared; an empty set is the convention used by the code following
// this loop to declare every type as a possible type.
all_types.clear();
break;
}
all_types.insert(type);
}
std::string input_name = "Input" + std::to_string(i);
schema.Input(gsl::narrow_cast<int>(i), input_name, "", input_name, option, is_homogeneous, min_arity);
// support all types as input here in schema, and handle the type inference in TypeShapeInference func
schema.TypeConstraint(input_name, DataTypeImpl::ToString(SUPPORTED_TENSOR_TYPES), "all types");
}

for (size_t i = 0; i < output_count; i++) {
onnx::OpSchema::FormalParameterOption option = onnx::OpSchema::FormalParameterOption::Single;
bool is_homogeneous = true;
int min_arity = 1;

// The OrtCustomOp interface did not support the methods to query input/output characteristics before
// ORT API version 8. So, query the relevant methods ONLY from API version 8 onwards.
if (op->version >= min_ort_version_with_optional_io_support) {
const auto characteristic = op->GetOutputCharacteristic(op, i);
std::string prefix = is_input ? "Input" : "Output";
std::string name = prefix + std::to_string(i);
if (is_input) {
schema.Input(gsl::narrow_cast<int>(i), name, "", name, option, is_homogeneous, min_arity);
} else {
schema.Output(gsl::narrow_cast<int>(i), name, "", name, option, is_homogeneous, min_arity);
}

// Support for optional and variadic inputs/outputs was added in versions 8 and 14, respectively.
if (characteristic == OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_OPTIONAL) {
option = onnx::OpSchema::FormalParameterOption::Optional;
} else if ((op->version >= min_ort_version_with_variadic_io_support) &&
(characteristic == OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_VARIADIC)) {
ORT_ENFORCE(i == output_count - 1, "Only the last output to a custom op may be marked variadic.");
option = onnx::OpSchema::FormalParameterOption::Variadic;
min_arity = op->GetVariadicOutputMinArity(op);
is_homogeneous = static_cast<bool>(op->GetVariadicOutputHomogeneity(op));
if (!all_types.empty()) {
// If all_types is not empty, only the types it contains are allowed for this input or output.
std::vector<std::string> types;
for (auto type : all_types) {
const ONNX_NAMESPACE::TypeProto* type_proto =
DataTypeImpl::TensorTypeFromONNXEnum(static_cast<int>(type))->GetTypeProto();
types.push_back(*ONNX_NAMESPACE::Utils::DataTypeUtils::ToType(*type_proto));
}
schema.TypeConstraint(name, types, "defined list of types");
} else {
// all_types is empty. As mentioned in the previous loop, all types are allowed.
schema.TypeConstraint(name, DataTypeImpl::ToString(SUPPORTED_TENSOR_TYPES), "all types");
undefined++;
}
};

const size_t input_count = op->GetInputTypeCount(op);
for (size_t i = 0; i < input_count; i++) {
create_type_constraint(op, static_cast<int>(input_count), static_cast<int>(i), true);
}

const size_t output_count = op->GetOutputTypeCount(op);
for (size_t i = 0; i < output_count; i++) {
const auto type = op->GetOutputType(op, i);
if (ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED == type) {
if (op->GetOutputCharacteristic(op, i) == OrtCustomOpInputOutputCharacteristic::INPUT_OUTPUT_REQUIRED) {
@@ -826,11 +859,9 @@ ONNX_NAMESPACE::OpSchema CreateSchema(const std::string& domain, const OrtCustom
"cannot be inferred without which model loading cannot proceed.");
}
}
std::string output_name = "Output" + std::to_string(i);
schema.Output(gsl::narrow_cast<int>(i), output_name, "", output_name, option, is_homogeneous, min_arity);
// support all types as input here in schema, and handle the type inference in TypeShapeInference func
schema.TypeConstraint(output_name, DataTypeImpl::ToString(SUPPORTED_TENSOR_TYPES), "all types");
create_type_constraint(op, static_cast<int>(output_count), static_cast<int>(i), false);
}

schema.SetDomain(domain);
if (op->version >= min_ort_version_with_custom_version && op->GetStartVersion) {
schema.SinceVersion(op->GetStartVersion(op));
@@ -905,7 +936,7 @@ Status IsCompatible(const ONNX_NAMESPACE::OpSchema& schema, const OrtCustomOp* o
"custom op schemas mismatch, expecting ", i + 1,
i == 0 ? "st" : (i == 1 ? "nd" : "th"),
" output to keep same homogeneity");
ORT_RETURN_IF_NOT(formal_parameter.GetMinArity() == op->GetVariadicInputMinArity(op),
ORT_RETURN_IF_NOT(formal_parameter.GetMinArity() == op->GetVariadicOutputMinArity(op),
"custom op schemas mismatch, expecting ", i + 1,
i == 0 ? "st" : (i == 1 ? "nd" : "th"),
" output to keep same arity");
@@ -994,18 +1025,36 @@ common::Status CreateCustomRegistry(gsl::span<OrtCustomOpDomain* const> op_domai
}
}

// domain_kernels groups all custom operators by name.
std::unordered_map<std::string, std::vector<const OrtCustomOp*>> domain_kernels;
for (const auto* op : domain->custom_ops_) {
// define kernel
auto kernel_create_info = CreateKernelCreateInfo(domain->domain_, op);
kernel_def_map[op->GetName(op)].push_back(kernel_create_info.kernel_def.get());
ORT_RETURN_IF_ERROR(output->RegisterCustomKernel(kernel_create_info));
// define schema
auto schema_map_iter = schema_map.find(op->GetName(op));
if (schema_map_iter == schema_map.end()) {
auto schema = CreateSchema(domain->domain_, op);
schema_map.emplace(schema.Name(), schema);
auto it = domain_kernels.find(op->GetName(op));
if (it == domain_kernels.end()) {
domain_kernels[op->GetName(op)] = {op};
} else {
ORT_RETURN_IF_ERROR(IsCompatible(schema_map_iter->second, op));
domain_kernels[op->GetName(op)].push_back(op);
}
}

// Creation of the schemas, one per unique name.
for (auto& [name, ops] : domain_kernels) {
auto schema = CreateSchema(domain->domain_, ops);
// schema.Name() is equal to ops[0]->GetName(ops[0]) and op->GetName(op) is the value
// used as a key for dictionary domain_kernels, therefore name == schema.Name().
schema_map.emplace(schema.Name(), schema);

// This loop checks that all custom operators sharing the same name are compatible with the defined schema.
for (const auto* op : ops) {
// define kernel
auto kernel_create_info = CreateKernelCreateInfo(domain->domain_, op);
kernel_def_map[op->GetName(op)].push_back(kernel_create_info.kernel_def.get());
ORT_RETURN_IF_ERROR(output->RegisterCustomKernel(kernel_create_info));
// If IsCompatible fails, the custom operators registered under
// 'op->GetName(op)' are not compatible among themselves:
// they must have the same number of inputs and outputs and the same
// characteristics (optional, variadic, ...); only the types may differ.
ORT_RETURN_IF_ERROR(IsCompatible(schema, op));
}
}

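The constraint-merging rule in CreateSchema can be sketched in isolation. This simplified standalone model (not the PR's code) shows how per-kernel types are unioned and how a single UNDEFINED declaration widens the constraint to every tensor type, modelled below as an empty set:

```cpp
#include <iostream>
#include <unordered_set>
#include <vector>
#include "onnxruntime_c_api.h"

// Collect the declared type of one input/output across every kernel sharing
// an op name; an empty result set stands for "all tensor types allowed".
std::unordered_set<int> MergeTypes(const std::vector<ONNXTensorElementDataType>& kernel_types) {
  std::unordered_set<int> all_types;
  for (auto type : kernel_types) {
    if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED) {
      all_types.clear();  // one UNDEFINED kernel widens the constraint to every type
      break;
    }
    all_types.insert(static_cast<int>(type));
  }
  return all_types;
}

int main() {
  // Two kernels overloading one op name on float and double inputs:
  auto merged = MergeTypes({ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
                            ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE});
  std::cout << merged.size() << " allowed types\n";  // prints: 2 allowed types

  // An UNDEFINED declaration from either kernel allows everything:
  merged = MergeTypes({ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
                       ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED});
  std::cout << (merged.empty() ? "all types allowed\n" : "restricted\n");
}
```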
43 changes: 43 additions & 0 deletions onnxruntime/test/shared_lib/test_inference.cc
@@ -1264,6 +1264,49 @@ TEST(CApiTest, test_custom_op_get_const_input) {
}
#endif

#if !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)
#if defined(__ANDROID__)
// Disabled on Android because custom op libraries are not copied to the emulator.
TEST(CApiTest, DISABLED_test_custom_op_local_function) {
#else
TEST(CApiTest, test_custom_op_local_function) {
#endif // defined(__ANDROID__)
const auto* model_path = TSTR("testdata/custom_op_local_function/custom_ops_type_inference_fails_0.onnx");

Ort::MemoryInfo info("Cpu", OrtDeviceAllocator, 0, OrtMemTypeDefault);
std::vector<Ort::Value> ort_inputs;
std::vector<const char*> input_names;

// input 0 (float type)
input_names.emplace_back("X");
std::vector<float> input_0_data = {1.0f, 2.0f, 3.0f, 4.0f};
std::vector<int64_t> input_0_dims = {2, 2};
ort_inputs.emplace_back(
Ort::Value::CreateTensor<float>(info, const_cast<float*>(input_0_data.data()),
input_0_data.size(), input_0_dims.data(), input_0_dims.size()));
const char* output_name = "Y";

const ORTCHAR_T* lib_name;
#if defined(_WIN32)
lib_name = ORT_TSTR("custom_op_local_function.dll");
#elif defined(__APPLE__)
lib_name = ORT_TSTR("libcustom_op_local_function.dylib");
#else
lib_name = ORT_TSTR("./libcustom_op_local_function.so");
#endif

Ort::SessionOptions session_opts;

session_opts.RegisterCustomOpsLibrary(lib_name);

Ort::Session session(*ort_env, model_path, session_opts);
auto default_allocator = std::make_unique<MockedOrtAllocator>();

session.Run(Ort::RunOptions{}, input_names.data(), ort_inputs.data(), ort_inputs.size(),
&output_name, 1);
}
#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS)

#if defined(USE_OPENVINO) && (!defined(ORT_MINIMAL_BUILD) || defined(ORT_MINIMAL_BUILD_CUSTOM_OPS))
TEST(CApiTest, test_custom_op_openvino_wrapper_library) {
// Tests a custom operator that wraps an OpenVINO MNIST model (.xml and .bin files serialized into node attributes).
58 changes: 58 additions & 0 deletions onnxruntime/test/testdata/custom_op_local_function/custom_op_local_function.cc
@@ -0,0 +1,58 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "custom_op_local_function.h"

#include <cmath>
#include <mutex>
#include <utility>
#include <vector>

#include "core/common/common.h"
#include "core/framework/ortdevice.h"
#include "core/framework/ortmemoryinfo.h"
#include "dummy_gemm.h"

static const char* c_OpDomain = "onnx_extented.ortops.tutorial.cpu";

static void AddOrtCustomOpDomainToContainer(Ort::CustomOpDomain&& domain) {
static std::vector<Ort::CustomOpDomain> ort_custom_op_domain_container;
static std::mutex ort_custom_op_domain_mutex;
std::lock_guard<std::mutex> lock(ort_custom_op_domain_mutex);
ort_custom_op_domain_container.push_back(std::move(domain));
}

OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options,
const OrtApiBase* api_base) {
Ort::InitApi(api_base->GetApi(ORT_API_VERSION));
Ort::UnownedSessionOptions session_options(options);

// Instances that remain available until onnxruntime unloads the library.
static Cpu::CustomGemmOp c_CustomGemmFloat(
"CustomGemmFloat", ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
false);
static Cpu::CustomGemmOp c_CustomGemmFloat8E4M3FN(
"CustomGemmFloat8E4M3FN", ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN,
ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT, ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT,
false);
OrtStatus* result = nullptr;

ORT_TRY {
Ort::CustomOpDomain domain{c_OpDomain};

domain.Add(&c_CustomGemmFloat);
domain.Add(&c_CustomGemmFloat8E4M3FN);

session_options.Add(domain);
AddOrtCustomOpDomainToContainer(std::move(domain));
}
ORT_CATCH(const std::exception& e) {
ORT_HANDLE_EXCEPTION([&]() {
Ort::Status status{e};
result = status.release();
});
}

return result;
}
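
For context, an application consumes the exported RegisterCustomOps entry point above through SessionOptions::RegisterCustomOpsLibrary, as the new test does. A minimal hedged sketch; the library and model paths are assumptions, not taken verbatim from this PR:

```cpp
#include "onnxruntime_cxx_api.h"

int main() {
  Ort::Env env;
  Ort::SessionOptions opts;
  // Loads the shared library and invokes its exported RegisterCustomOps.
  opts.RegisterCustomOpsLibrary(ORT_TSTR("./libcustom_op_local_function.so"));
  Ort::Session session(env,
                       ORT_TSTR("testdata/custom_op_local_function/"
                                "custom_ops_type_inference_fails_0.onnx"),
                       opts);
  return 0;
}
```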
3 changes: 3 additions & 0 deletions onnxruntime/test/testdata/custom_op_local_function/custom_op_local_function.def
@@ -0,0 +1,3 @@
LIBRARY "custom_op_local_function.dll"
EXPORTS
RegisterCustomOps @1
15 changes: 15 additions & 0 deletions onnxruntime/test/testdata/custom_op_local_function/custom_op_local_function.h
@@ -0,0 +1,15 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once
#include "onnxruntime_c_api.h"

#ifdef __cplusplus
extern "C" {
#endif

ORT_EXPORT OrtStatus* ORT_API_CALL RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api);

#ifdef __cplusplus
}
#endif