Skip to content

Commit

Permalink
Pass schema to InferOutputTypes.
Browse files Browse the repository at this point in the history
This helps detect optional input/outputs.
Quit inference on variadic inputs/outputs.
  • Loading branch information
yuslepukhin committed Feb 27, 2024
1 parent baeece4 commit dabb564
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -634,9 +634,12 @@ class RunQueue {
// position, these conditions would be indistinguishable); (2) obtain
// consistent snapshot of front_/back_ for Size operation using the
// modification counters.
#pragma warning(push)
#pragma warning(disable: 4324)
ORT_ALIGN_TO_AVOID_FALSE_SHARING std::atomic<unsigned> front_;
ORT_ALIGN_TO_AVOID_FALSE_SHARING std::atomic<unsigned> back_;
ORT_ALIGN_TO_AVOID_FALSE_SHARING Elem array_[kSize];
#pragma warning(pop)

// SizeOrNotEmpty returns current queue size; if NeedSizeEstimate is false,
// only whether the size is 0 is guaranteed to be correct.
Expand Down
166 changes: 117 additions & 49 deletions onnxruntime/core/session/custom_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -620,55 +620,122 @@ Status IsCompatible(const ONNX_NAMESPACE::OpSchema& schema, const OrtCustomOp* o
return Status::OK();
}

void InferOutputTypes(const InlinedVector<const KernelDef*>& kernel_defs,
namespace {
struct CustomOpSchemaAndKernels {
ONNX_NAMESPACE::OpSchema schema;
InlinedVector<const KernelDef*> kernel_defs;
};
} // namespace

void InferOutputTypes(const CustomOpSchemaAndKernels& custom_op_schema_and_kernels,
ONNX_NAMESPACE::InferenceContext& infer_ctx) {
for (const auto& kernel_def : kernel_defs) {
const auto& schema = custom_op_schema_and_kernels.schema;
const auto& kernel_defs = custom_op_schema_and_kernels.kernel_defs;
const auto& inputs = schema.inputs();
const auto node_input_num = infer_ctx.getNumInputs();

const KernelDef* def_selected = nullptr;
bool variadic_input = false;
int32_t output_propagate{0};

for (size_t kernel_index = 0;
kernel_index < kernel_defs.size() && def_selected == nullptr;
++kernel_index) {
const auto* kernel_def = kernel_defs[kernel_index];
const auto& type_constraints = kernel_def->TypeConstraints();
auto num_inputs = infer_ctx.getNumInputs();
bool matched = true;
ONNXTensorElementDataType undef = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
// first, make sure there is a constraint for every input
for (size_t i = 0; i < num_inputs && matched; ++i) {
auto input_name = "Input" + std::to_string(i);
auto input_type = infer_ctx.getInputType(i);
if (input_type) {
auto elem_type = static_cast<ONNXTensorElementDataType>(input_type->tensor_type().elem_type());
auto tc_iter = type_constraints.find(input_name);
if (tc_iter != type_constraints.end()) {
if (tc_iter->second.size() > 1) {
undef = elem_type;
} else if (tc_iter->second.size() != 1 || tc_iter->second[0] != DataTypeImpl::TensorTypeFromONNXEnum(elem_type)) {
matched = false;
}
} else {
matched = false;
def_selected = kernel_def;

for (size_t i = 0; i < node_input_num; ++i) {
const auto input_type = infer_ctx.getInputType(i);

const size_t schema_input_index = (i < inputs.size()) ? i : inputs.size() - 1;
const auto& param = inputs[schema_input_index];
const auto& input_name = param.GetName();
if (input_type == nullptr) {
if (param.GetOption() == ONNX_NAMESPACE::OpSchema::FormalParameterOption::Optional)
continue;

ORT_THROW("[CustomOP type inferencing error]: kernel Input: ", input_name,
" is absent, but not optional. Op : ", schema.Name());
}

variadic_input |= (param.GetOption() == ONNX_NAMESPACE::OpSchema::FormalParameterOption::Variadic);

auto hit = type_constraints.find(input_name);
if (hit != type_constraints.end()) {
const auto& types = hit->second;
// For custom ops kernel constraints are never empty
assert(!types.empty());
if (!std::any_of(types.cbegin(), types.cend(),
[input_type](const DataTypeImpl* type) {
return type->IsCompatible(*input_type);
})) {
def_selected = nullptr;
variadic_input = false;
output_propagate = 0;
break;
}

if (types.size() > 1) {
output_propagate = input_type->tensor_type().elem_type();
}

} else {
matched = false;
}
} // for
// next, ensure that there is a constraint for every output
auto num_outputs = infer_ctx.getNumOutputs();
for (size_t i = 0; i < num_outputs && matched; i++) {
auto output_name = "Output" + std::to_string(i);
auto tc_iter = type_constraints.find(output_name);
if (tc_iter == type_constraints.end() || tc_iter->second.size() < 1) {
matched = false;
ORT_THROW("[CustomOP type inferencing error]: no type constraint found for input: ",
input_name, " Op: ", schema.Name());
}
}
if (matched) {
for (size_t i = 0; i < num_outputs; i++) {
auto output_name = "Output" + std::to_string(i);
auto output_type = infer_ctx.getOutputType(i);
auto tc_iter = type_constraints.find(output_name);
if (tc_iter->second.size() > 1) {
output_type->mutable_tensor_type()->set_elem_type(undef);
} else {
output_type->mutable_tensor_type()->set_elem_type(tc_iter->second[0]->GetTypeProto()->tensor_type().elem_type());
}
}
}

if (def_selected == nullptr) {
ORT_THROW("[CustomOP type inferencing error]: no kernel def matches node inputs for Op: ", schema.Name());
}

if (variadic_input) {
return;
}

const auto& outputs = custom_op_schema_and_kernels.schema.outputs();
const auto node_output_num = infer_ctx.getNumOutputs();
const auto& selected_type_constraints = def_selected->TypeConstraints();

for (size_t i = 0; i < node_output_num; ++i) {
auto output_type = infer_ctx.getOutputType(i);
// Account for variadic outputs
const size_t schema_output_index = (i < outputs.size()) ? i : outputs.size() - 1;
const auto& param = outputs[schema_output_index];
const auto& output_name = param.GetName();

if (output_type == nullptr) {
if (param.GetOption() == ONNX_NAMESPACE::OpSchema::FormalParameterOption::Optional)
continue;

ORT_THROW("[CustomOP type inferencing error]: kernel Output: ", i,
" is absent, but not optional. Op : ", schema.Name());
}

const bool variadic_output = (param.GetOption() == ONNX_NAMESPACE::OpSchema::FormalParameterOption::Variadic);

if (variadic_output) {
break;
}

auto hit = selected_type_constraints.find(output_name);
if (hit != selected_type_constraints.end()) {
const auto& types = hit->second;
assert(!types.empty());

if (types.size() == 1) {
// Use the constraint type
output_type->mutable_tensor_type()->set_elem_type(
types[0]->GetTypeProto()->tensor_type().elem_type());
} else {
output_type->mutable_tensor_type()->set_elem_type(output_propagate);
}
} else {
ORT_THROW("[CustomOP type inferencing error]: no type constraint found for output: ",
output_name, " Op: ", schema.Name());
}
}
}
#endif
Expand All @@ -679,9 +746,6 @@ common::Status CreateCustomRegistry(gsl::span<OrtCustomOpDomain* const> op_domai

for (const auto& domain : op_domains) {
#if !defined(ORT_MINIMAL_BUILD)
std::unordered_map<std::string, ONNX_NAMESPACE::OpSchema> schema_map;
std::unordered_map<std::string, InlinedVector<const KernelDef*>> kernel_def_map;

// Domain is not empty - add it to the DomainToVersion ONNX map
// If domain is empty, it is assumed to be part of the ONNX domain
if (!domain->domain_.empty()) {
Expand All @@ -695,6 +759,9 @@ common::Status CreateCustomRegistry(gsl::span<OrtCustomOpDomain* const> op_domai
}
}

std::unordered_map<std::string, ONNX_NAMESPACE::OpSchema> schema_map;
std::unordered_map<std::string, InlinedVector<const KernelDef*>> kernel_def_map;

for (const auto* op : domain->custom_ops_) {
// define kernel
auto kernel_create_info = CreateKernelCreateInfo(domain->domain_, op);
Expand All @@ -711,11 +778,12 @@ common::Status CreateCustomRegistry(gsl::span<OrtCustomOpDomain* const> op_domai
}

std::vector<ONNX_NAMESPACE::OpSchema> schemas;
for (auto schema_iter : schema_map) {
schemas.push_back(schema_iter.second);
InlinedVector<const KernelDef*> kernel_defs = std::move(kernel_def_map[schema_iter.first]);
ONNX_NAMESPACE::InferenceFunction infer_fn = [kernel_defs](ONNX_NAMESPACE::InferenceContext& infer_ctx) {
InferOutputTypes(kernel_defs, infer_ctx);
for (const auto& [name, schema] : schema_map) {
schemas.push_back(schema);
ONNX_NAMESPACE::InferenceFunction infer_fn = [custom_op_schema_and_kernels =
CustomOpSchemaAndKernels{schema, std::move(kernel_def_map[name])}](
ONNX_NAMESPACE::InferenceContext& infer_ctx) {
InferOutputTypes(custom_op_schema_and_kernels, infer_ctx);
};
schemas.back().TypeAndShapeInferenceFunction(infer_fn);
}
Expand Down

0 comments on commit dabb564

Please sign in to comment.