From dabb5644e725267898efd61e18b085439fa5b6eb Mon Sep 17 00:00:00 2001
From: Dmitri Smirnov
Date: Mon, 26 Feb 2024 17:56:12 -0800
Subject: [PATCH] Pass schema to InferOutputTypes. This helps detect optional
 inputs/outputs. Quit inference on variadic inputs/outputs.

---
 .../platform/EigenNonBlockingThreadPool.h     |   3 +
 onnxruntime/core/session/custom_ops.cc        | 166 ++++++++++++------
 2 files changed, 120 insertions(+), 49 deletions(-)

diff --git a/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h b/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h
index 7f0046d137a64..9c288c9fbdafc 100644
--- a/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h
+++ b/include/onnxruntime/core/platform/EigenNonBlockingThreadPool.h
@@ -634,9 +634,12 @@ class RunQueue {
   // position, these conditions would be indistinguishable); (2) obtain
   // consistent snapshot of front_/back_ for Size operation using the
   // modification counters.
+#pragma warning(push)
+#pragma warning(disable: 4324)
   ORT_ALIGN_TO_AVOID_FALSE_SHARING std::atomic<unsigned> front_;
   ORT_ALIGN_TO_AVOID_FALSE_SHARING std::atomic<unsigned> back_;
   ORT_ALIGN_TO_AVOID_FALSE_SHARING Elem array_[kSize];
+#pragma warning(pop)
 
   // SizeOrNotEmpty returns current queue size; if NeedSizeEstimate is false,
   // only whether the size is 0 is guaranteed to be correct.
diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc
index 0da67a967119d..613d437f254bf 100644
--- a/onnxruntime/core/session/custom_ops.cc
+++ b/onnxruntime/core/session/custom_ops.cc
@@ -620,55 +620,122 @@ Status IsCompatible(const ONNX_NAMESPACE::OpSchema& schema, const OrtCustomOp* o
   return Status::OK();
 }
 
-void InferOutputTypes(const InlinedVector<const KernelDef*>& kernel_defs,
+namespace {
+struct CustomOpSchemaAndKernels {
+  ONNX_NAMESPACE::OpSchema schema;
+  InlinedVector<const KernelDef*> kernel_defs;
+};
+}  // namespace
+
+void InferOutputTypes(const CustomOpSchemaAndKernels& custom_op_schema_and_kernels,
                       ONNX_NAMESPACE::InferenceContext& infer_ctx) {
-  for (const auto& kernel_def : kernel_defs) {
+  const auto& schema = custom_op_schema_and_kernels.schema;
+  const auto& kernel_defs = custom_op_schema_and_kernels.kernel_defs;
+  const auto& inputs = schema.inputs();
+  const auto node_input_num = infer_ctx.getNumInputs();
+
+  const KernelDef* def_selected = nullptr;
+  bool variadic_input = false;
+  int32_t output_propagate{0};
+
+  for (size_t kernel_index = 0;
+       kernel_index < kernel_defs.size() && def_selected == nullptr;
+       ++kernel_index) {
+    const auto* kernel_def = kernel_defs[kernel_index];
     const auto& type_constraints = kernel_def->TypeConstraints();
-    auto num_inputs = infer_ctx.getNumInputs();
-    bool matched = true;
-    ONNXTensorElementDataType undef = ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
-    // first, make sure there is a constraint for every input
-    for (size_t i = 0; i < num_inputs && matched; ++i) {
-      auto input_name = "Input" + std::to_string(i);
-      auto input_type = infer_ctx.getInputType(i);
-      if (input_type) {
-        auto elem_type = static_cast<ONNXTensorElementDataType>(input_type->tensor_type().elem_type());
-        auto tc_iter = type_constraints.find(input_name);
-        if (tc_iter != type_constraints.end()) {
-          if (tc_iter->second.size() > 1) {
-            undef = elem_type;
-          } else if (tc_iter->second.size() != 1 || tc_iter->second[0] != DataTypeImpl::TensorTypeFromONNXEnum(elem_type)) {
-            matched = false;
-          }
-        } else {
-          matched = false;
+    def_selected = kernel_def;
+
+    for (size_t i = 0; i < node_input_num; ++i) {
+      const auto input_type = infer_ctx.getInputType(i);
+
+      const size_t schema_input_index = (i < inputs.size()) ? i : inputs.size() - 1;
+      const auto& param = inputs[schema_input_index];
+      const auto& input_name = param.GetName();
+      if (input_type == nullptr) {
+        if (param.GetOption() == ONNX_NAMESPACE::OpSchema::FormalParameterOption::Optional)
+          continue;
+
+        ORT_THROW("[CustomOP type inferencing error]: kernel Input: ", input_name,
+                  " is absent, but not optional. Op : ", schema.Name());
+      }
+
+      variadic_input |= (param.GetOption() == ONNX_NAMESPACE::OpSchema::FormalParameterOption::Variadic);
+
+      auto hit = type_constraints.find(input_name);
+      if (hit != type_constraints.end()) {
+        const auto& types = hit->second;
+        // For custom ops kernel constraints are never empty
+        assert(!types.empty());
+        if (!std::any_of(types.cbegin(), types.cend(),
+                         [input_type](const DataTypeImpl* type) {
+                           return type->IsCompatible(*input_type);
+                         })) {
+          def_selected = nullptr;
+          variadic_input = false;
+          output_propagate = 0;
+          break;
+        }
+
+        if (types.size() > 1) {
+          output_propagate = input_type->tensor_type().elem_type();
+        }
+
       } else {
-        matched = false;
-      }
-    }  // for
-    // next, ensure that there is a constraint for every output
-    auto num_outputs = infer_ctx.getNumOutputs();
-    for (size_t i = 0; i < num_outputs && matched; i++) {
-      auto output_name = "Output" + std::to_string(i);
-      auto tc_iter = type_constraints.find(output_name);
-      if (tc_iter == type_constraints.end() || tc_iter->second.size() < 1) {
-        matched = false;
+        ORT_THROW("[CustomOP type inferencing error]: no type constraint found for input: ",
+                  input_name, " Op: ", schema.Name());
       }
     }
-    if (matched) {
-      for (size_t i = 0; i < num_outputs; i++) {
-        auto output_name = "Output" + std::to_string(i);
-        auto output_type = infer_ctx.getOutputType(i);
-        auto tc_iter = type_constraints.find(output_name);
-        if (tc_iter->second.size() > 1) {
-          output_type->mutable_tensor_type()->set_elem_type(undef);
-        } else {
-          output_type->mutable_tensor_type()->set_elem_type(tc_iter->second[0]->GetTypeProto()->tensor_type().elem_type());
-        }
-      }
+  }
+
+  if (def_selected == nullptr) {
+    ORT_THROW("[CustomOP type inferencing error]: no kernel def matches node inputs for Op: ", schema.Name());
+  }
+
+  if (variadic_input) {
+    return;
+  }
+
+  const auto& outputs = custom_op_schema_and_kernels.schema.outputs();
+  const auto node_output_num = infer_ctx.getNumOutputs();
+  const auto& selected_type_constraints = def_selected->TypeConstraints();
+
+  for (size_t i = 0; i < node_output_num; ++i) {
+    auto output_type = infer_ctx.getOutputType(i);
+    // Account for variadic outputs
+    const size_t schema_output_index = (i < outputs.size()) ? i : outputs.size() - 1;
+    const auto& param = outputs[schema_output_index];
+    const auto& output_name = param.GetName();
+
+    if (output_type == nullptr) {
+      if (param.GetOption() == ONNX_NAMESPACE::OpSchema::FormalParameterOption::Optional)
+        continue;
+
+      ORT_THROW("[CustomOP type inferencing error]: kernel Output: ", i,
+                " is absent, but not optional. Op : ", schema.Name());
Op : ", schema.Name()); + } + + const bool variadic_output = (param.GetOption() == ONNX_NAMESPACE::OpSchema::FormalParameterOption::Variadic); + + if (variadic_output) { break; } + + auto hit = selected_type_constraints.find(output_name); + if (hit != selected_type_constraints.end()) { + const auto& types = hit->second; + assert(!types.empty()); + + if (types.size() == 1) { + // Use the constraint type + output_type->mutable_tensor_type()->set_elem_type( + types[0]->GetTypeProto()->tensor_type().elem_type()); + } else { + output_type->mutable_tensor_type()->set_elem_type(output_propagate); + } + } else { + ORT_THROW("[CustomOP type inferencing error]: no type constraint found for output: ", + output_name, " Op: ", schema.Name()); + } } } #endif @@ -679,9 +746,6 @@ common::Status CreateCustomRegistry(gsl::span op_domai for (const auto& domain : op_domains) { #if !defined(ORT_MINIMAL_BUILD) - std::unordered_map schema_map; - std::unordered_map> kernel_def_map; - // Domain is not empty - add it to the DomainToVersion ONNX map // If domain is empty, it is assumed to be part of the ONNX domain if (!domain->domain_.empty()) { @@ -695,6 +759,9 @@ common::Status CreateCustomRegistry(gsl::span op_domai } } + std::unordered_map schema_map; + std::unordered_map> kernel_def_map; + for (const auto* op : domain->custom_ops_) { // define kernel auto kernel_create_info = CreateKernelCreateInfo(domain->domain_, op); @@ -711,11 +778,12 @@ common::Status CreateCustomRegistry(gsl::span op_domai } std::vector schemas; - for (auto schema_iter : schema_map) { - schemas.push_back(schema_iter.second); - InlinedVector kernel_defs = std::move(kernel_def_map[schema_iter.first]); - ONNX_NAMESPACE::InferenceFunction infer_fn = [kernel_defs](ONNX_NAMESPACE::InferenceContext& infer_ctx) { - InferOutputTypes(kernel_defs, infer_ctx); + for (const auto& [name, schema] : schema_map) { + schemas.push_back(schema); + ONNX_NAMESPACE::InferenceFunction infer_fn = [custom_op_schema_and_kernels = + CustomOpSchemaAndKernels{schema, std::move(kernel_def_map[name])}]( + ONNX_NAMESPACE::InferenceContext& infer_ctx) { + InferOutputTypes(custom_op_schema_and_kernels, infer_ctx); }; schemas.back().TypeAndShapeInferenceFunction(infer_fn); }