diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h
index aed640ea195d7..2935e116906f5 100644
--- a/onnxruntime/core/providers/shared_library/provider_interfaces.h
+++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h
@@ -29,6 +29,7 @@ enum OperatorStatus : int;
 using DataType = const std::string*;
 using DataTypeSet = std::unordered_set<DataType>;
 using TypeConstraintMap = std::unordered_map<std::string, std::pair<DataTypeSet, std::string>>;
+
 }  // namespace ONNX_NAMESPACE
 
 namespace onnxruntime {
@@ -566,7 +567,9 @@
   virtual int FunctionProto__metadata_props_size(const ONNX_NAMESPACE::FunctionProto* p) = 0;
   virtual ONNX_NAMESPACE::StringStringEntryProto* FunctionProto__add_metadata_props(ONNX_NAMESPACE::FunctionProto* p) = 0;
 
-  virtual void RegisterSchema(const std::string& domain, const OrtCustomOp* op) = 0;
+  // `type` selects the shape-inference function attached to the registered
+  // schema: 1 = xir subgraph ("super_layer"), 2 = FixNeuron, 3 = generic xir op.
+  virtual void RegisterSchema(const std::string& domain, const OrtCustomOp* op, int type) = 0;
   virtual const ONNX_NAMESPACE::OpSchema* GetSchema(const std::string& name, const int maxInclusiveVersion, const std::string& domain) = 0;
   virtual const std::string& OpSchema__inputs__GetName(const ONNX_NAMESPACE::OpSchema* p, const int i) = 0;
   virtual const std::string& OpSchema__inputs__GetTypeStr(const ONNX_NAMESPACE::OpSchema* p, const int i) = 0;
diff --git a/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc b/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc
index ea5687f2691b1..1082c0be79937 100644
--- a/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc
+++ b/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc
@@ -13,7 +13,15 @@ void register_xir_ops(const std::vector<OrtCustomOpDomain*>& domains) {
   for (auto domain : domains) {
     for (auto op : domain->custom_ops_) {
       if (Provider_GetHost()->GetSchema(op->GetName(op), op->GetStartVersion(op), domain->domain_) == nullptr) {
-        Provider_GetHost()->RegisterSchema(domain->domain_, op);
+        // Pick the shape-inference kind from the op name (see RegisterSchema).
+        auto name = op->GetName(op);
+        if ((std::string)name == "super_layer") {
+          Provider_GetHost()->RegisterSchema(domain->domain_, op, 1);
+        } else if ((std::string)name == "FixNeuron") {
+          Provider_GetHost()->RegisterSchema(domain->domain_, op, 2);
+        } else {
+          Provider_GetHost()->RegisterSchema(domain->domain_, op, 3);
+        }
       }
     }
   }
diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc
index 9348c6b3cb1f5..80da485e64b28 100644
--- a/onnxruntime/core/session/provider_bridge_ort.cc
+++ b/onnxruntime/core/session/provider_bridge_ort.cc
@@ -732,13 +732,88 @@
     return elemType;
   }
 
-  void RegisterSchema(const std::string& domain, const OrtCustomOp* op) override {
+  // Infers output elem type and shape for generic xir ops from the op's
+  // `data_type`/`shape` attributes (suffixed `_<idx>` for multi-output ops).
+  static void xir_shape_infer(ONNX_NAMESPACE::InferenceContext& ctx) {
+    auto num_output = ctx.getNumOutputs();
+    if (num_output == 1) {
+      auto* shape = ctx.getAttribute("shape");
+      auto* data_type = ctx.getAttribute("data_type");
+      if (data_type == nullptr) {
+        std::cerr << "Custom op is missing `data_type` attr." << std::endl;
+        return;
+      }
+      int32_t elemType = convert_elem_type(data_type);
+      ONNX_NAMESPACE::updateOutputElemType(ctx, 0, elemType);
+      if (shape != nullptr) {
+        for (auto i = 0; i < shape->ints_size(); ++i) {
+          ONNX_NAMESPACE::getOutputShape(ctx, 0, ONNX_NAMESPACE::TypeProto::kTensorType)->add_dim()->set_dim_value(shape->ints(i));
+        }
+      } else {
+        // set scalar type.
+        ONNX_NAMESPACE::getOutputShape(ctx, 0, ONNX_NAMESPACE::TypeProto::kTensorType)->clear_dim();
+      }
+    } else {
+      for (auto idx = 0u; idx < num_output; idx++) {
+        auto* shape = ctx.getAttribute("shape_" + std::to_string(idx));
+        auto* data_type = ctx.getAttribute("data_type_" + std::to_string(idx));
+        if (shape == nullptr || data_type == nullptr) {
+          // this output is optional
+        } else {
+          int32_t elemType = convert_elem_type(data_type);
+          ONNX_NAMESPACE::updateOutputElemType(ctx, idx, elemType);
+          for (auto i = 0; i < shape->ints_size(); ++i) {
+            ONNX_NAMESPACE::getOutputShape(ctx, idx, ONNX_NAMESPACE::TypeProto::kTensorType)->add_dim()->set_dim_value(shape->ints(i));
+          }
+        }
+      }
+    }
+  }
+
+  static void xir_fixneuron_shape_inference(ONNX_NAMESPACE::InferenceContext& ctx) {
+    ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0);
+    ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, 0, 0);
+  }
+
+  static void xir_subgraph_shape_inference(ONNX_NAMESPACE::InferenceContext& ctx) {
+    auto num_inputs = ctx.getNumInputs();
+
+    // Run inferencing on the subgraph
+    auto* graphInferencer = ctx.getGraphAttributeInferencer("body");
+
+    std::vector<const ONNX_NAMESPACE::TensorProto*> input_data;
+    std::vector<const ONNX_NAMESPACE::TypeProto*> subgraph_input_types;
+    for (size_t i = 0; i < num_inputs; ++i) {
+      input_data.push_back(ctx.getInputData(i));
+      subgraph_input_types.push_back(ctx.getInputType(i));
+    }
+
+    auto output_types = graphInferencer->doInferencing(subgraph_input_types, input_data);
+    for (size_t i = 0, end = output_types.size(); i < end; ++i) {
+      *ctx.getOutputType(i) = *output_types[i];
+    }
+  }
+
+  // `type`: 1 = xir subgraph ("super_layer"), 2 = FixNeuron, 3 = generic xir op.
+  void RegisterSchema(const std::string& domain, const OrtCustomOp* op, int type) override {
     auto& domain_instance = ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance();
     const auto& domain_to_version_map = domain_instance.Map();
     if (domain_to_version_map.find(domain) == domain_to_version_map.end()) {
       domain_instance.AddDomainToVersion(domain, 1, 1000);
     }
     auto schema = CreateSchema(domain, {op});
+    switch (type) {
+      case 1:
+        schema.TypeAndShapeInferenceFunction(xir_subgraph_shape_inference);
+        break;
+      case 2:
+        schema.TypeAndShapeInferenceFunction(xir_fixneuron_shape_inference);
+        break;
+      case 3:
+        schema.TypeAndShapeInferenceFunction(xir_shape_infer);
+        break;
+      default:
+        break;
+    }
     ONNX_NAMESPACE::RegisterSchema(schema, ORT_API_VERSION);
   }
   const ONNX_NAMESPACE::OpSchema* GetSchema(const std::string& name, const int maxInclusiveVersion, const std::string& domain) override {