Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhenze Wang committed Aug 15, 2024
1 parent 7e08ac5 commit 63e8638
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ enum OperatorStatus : int;
// A data type is represented as a pointer to its type string (e.g. "tensor(float)").
// NOTE(review): presumably these pointers are interned by the ONNX type registry so
// pointer identity implies type equality — confirm against the ONNX schema code.
using DataType = const std::string*;
// Set of distinct data types an input/output may take.
using DataTypeSet = std::unordered_set<DataType>;
// Maps a type-constraint name (e.g. "T") to the allowed types plus a
// human-readable description string.
using TypeConstraintMap = std::unordered_map<std::string, std::pair<DataTypeSet, std::string>>;

} // namespace ONNX_NAMESPACE

namespace onnxruntime {
Expand Down Expand Up @@ -566,7 +567,7 @@ struct ProviderHost {
// Accessors over ONNX FunctionProto metadata_props (provider-bridge indirection).
virtual int FunctionProto__metadata_props_size(const ONNX_NAMESPACE::FunctionProto* p) = 0;
virtual ONNX_NAMESPACE::StringStringEntryProto* FunctionProto__add_metadata_props(ONNX_NAMESPACE::FunctionProto* p) = 0;

// NOTE(review): this scrape shows both the pre-change overload (no `type`) and the
// post-change one below; only the 3-argument form is current in this commit.
virtual void RegisterSchema(const std::string& domain, const OrtCustomOp* op) = 0;
// Registers an op schema for `op` under `domain`. `type` selects which shape-inference
// routine the implementation attaches: 1 = subgraph ("super_layer"), 2 = FixNeuron
// passthrough, 3 = generic XIR op (see register_xir_ops.cc and provider_bridge_ort.cc).
virtual void RegisterSchema(const std::string& domain, const OrtCustomOp* op, int type) = 0;
// Returns the registered schema for name/version/domain, or nullptr if none exists.
virtual const ONNX_NAMESPACE::OpSchema* GetSchema(const std::string& name, const int maxInclusiveVersion, const std::string& domain) = 0;
// Name and declared type string of the schema's i-th input.
virtual const std::string& OpSchema__inputs__GetName(const ONNX_NAMESPACE::OpSchema* p, const int i) = 0;
virtual const std::string& OpSchema__inputs__GetTypeStr(const ONNX_NAMESPACE::OpSchema* p, const int i) = 0;
Expand Down
9 changes: 8 additions & 1 deletion onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,14 @@ void register_xir_ops(const std::vector<OrtCustomOpDomain*>& domains) {
// Register every custom op in every domain whose schema is not yet known to
// the ONNX registry, choosing the shape-inference variant by op name.
for (auto domain : domains) {
  for (auto op : domain->custom_ops_) {
    if (Provider_GetHost()->GetSchema(op->GetName(op), op->GetStartVersion(op), domain->domain_) == nullptr) {
      // Build the std::string once instead of casting per comparison.
      const std::string name = op->GetName(op);
      // Shape-inference selector understood by ProviderHost::RegisterSchema:
      //   1 = subgraph inference ("super_layer"),
      //   2 = input-to-output passthrough ("FixNeuron"),
      //   3 = generic XIR shape/data_type attribute inference (default).
      int inference_kind = 3;
      if (name == "super_layer") {
        inference_kind = 1;
      } else if (name == "FixNeuron") {
        inference_kind = 2;
      }
      Provider_GetHost()->RegisterSchema(domain->domain_, op, inference_kind);
    }
  }
}
Expand Down
74 changes: 73 additions & 1 deletion onnxruntime/core/session/provider_bridge_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -732,13 +732,85 @@ struct ProviderHostImpl : ProviderHost {
return elemType;
}

void RegisterSchema(const std::string& domain, const OrtCustomOp* op) override {
// Shape/type inference for generic XIR custom ops.
//
// Single-output ops carry `data_type` (required) and `shape` (optional; absence
// means scalar) attributes. Multi-output ops carry `shape_<i>` / `data_type_<i>`
// per output; an output missing either attribute is treated as optional and left
// uninferred. Fix: the output-shape proto lookup is hoisted out of the dim loops
// instead of being re-resolved on every iteration.
static void xir_shape_infer(ONNX_NAMESPACE::InferenceContext& ctx) {
  auto num_output = ctx.getNumOutputs();
  if (num_output == 1) {
    auto* shape = ctx.getAttribute("shape");
    auto* data_type = ctx.getAttribute("data_type");
    if (data_type == nullptr) {
      std::cerr << "Custom op is missing `data_type` attr." << std::endl;
      return;
    }
    ONNX_NAMESPACE::updateOutputElemType(ctx, 0, convert_elem_type(data_type));
    // Resolve the output tensor-shape proto once, then populate its dims.
    auto* output_shape = ONNX_NAMESPACE::getOutputShape(ctx, 0, ONNX_NAMESPACE::TypeProto::kTensorType);
    if (shape != nullptr) {
      for (auto i = 0; i < shape->ints_size(); ++i) {
        output_shape->add_dim()->set_dim_value(shape->ints(i));
      }
    } else {
      // No `shape` attribute: the output is a scalar (rank 0).
      output_shape->clear_dim();
    }
  } else {
    for (auto idx = 0u; idx < num_output; idx++) {
      auto* shape = ctx.getAttribute("shape_" + std::to_string(idx));
      auto* data_type = ctx.getAttribute("data_type_" + std::to_string(idx));
      if (shape == nullptr || data_type == nullptr) {
        // This output is optional; leave its type/shape uninferred.
        continue;
      }
      ONNX_NAMESPACE::updateOutputElemType(ctx, idx, convert_elem_type(data_type));
      auto* output_shape = ONNX_NAMESPACE::getOutputShape(ctx, idx, ONNX_NAMESPACE::TypeProto::kTensorType);
      for (auto i = 0; i < shape->ints_size(); ++i) {
        output_shape->add_dim()->set_dim_value(shape->ints(i));
      }
    }
  }
}

// Shape/type inference for "FixNeuron" ops: the single output has exactly the
// same element type and shape as the single input (pure passthrough).
static void xir_fixneuron_shape_inference(ONNX_NAMESPACE::InferenceContext& ctx) {
  // Copy element type of input 0 to output 0, then its shape.
  ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0);
  ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, 0, 0);
}

// Shape/type inference for "super_layer" ops: delegates to inference over the
// op's `body` graph attribute, feeding the outer op's input types/data into the
// subgraph and copying the inferred subgraph output types back to the op outputs.
// Fix: guard against a missing `body` attribute (null inferencer) and a null
// output-type slot instead of dereferencing unconditionally.
static void xir_subgraph_shape_inference(ONNX_NAMESPACE::InferenceContext& ctx) {
  auto num_inputs = ctx.getNumInputs();

  // Run inferencing on the subgraph stored in the `body` attribute.
  auto* graphInferencer = ctx.getGraphAttributeInferencer("body");
  if (graphInferencer == nullptr) {
    std::cerr << "Custom op is missing `body` graph attr." << std::endl;
    return;
  }

  std::vector<const ONNX_NAMESPACE::TensorProto*> input_data;
  std::vector<const ONNX_NAMESPACE::TypeProto*> subgraph_input_types;
  input_data.reserve(num_inputs);
  subgraph_input_types.reserve(num_inputs);
  for (size_t i = 0; i < num_inputs; ++i) {
    input_data.push_back(ctx.getInputData(i));
    subgraph_input_types.push_back(ctx.getInputType(i));
  }

  // Propagate each inferred subgraph output type to the corresponding op output.
  auto output_types = graphInferencer->doInferencing(subgraph_input_types, input_data);
  for (size_t i = 0, end = output_types.size(); i < end; ++i) {
    if (auto* output_type = ctx.getOutputType(i)) {
      *output_type = *output_types[i];
    }
  }
}
// Registers the ONNX schema for a VitisAI custom op under `domain`, attaching
// the shape-inference routine selected by `type`:
//   1 -> xir_subgraph_shape_inference ("super_layer")
//   2 -> xir_fixneuron_shape_inference ("FixNeuron")
//   3 -> xir_shape_infer (generic XIR op)
// Any other value registers the schema with no inference function attached.
void RegisterSchema(const std::string& domain, const OrtCustomOp* op, int type) override {
  // The ONNX registry rejects schemas for unknown domains, so make sure the
  // domain exists first, with opset version range [1, 1000].
  auto& version_ranges = ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange::Instance();
  const auto& known_domains = version_ranges.Map();
  if (known_domains.count(domain) == 0) {
    version_ranges.AddDomainToVersion(domain, 1, 1000);
  }
  auto schema = CreateSchema(domain, {op});
  if (type == 1) {
    schema.TypeAndShapeInferenceFunction(xir_subgraph_shape_inference);
  } else if (type == 2) {
    schema.TypeAndShapeInferenceFunction(xir_fixneuron_shape_inference);
  } else if (type == 3) {
    schema.TypeAndShapeInferenceFunction(xir_shape_infer);
  }
  ONNX_NAMESPACE::RegisterSchema(schema, ORT_API_VERSION);
}
const ONNX_NAMESPACE::OpSchema* GetSchema(const std::string& name, const int maxInclusiveVersion, const std::string& domain) override {
Expand Down

0 comments on commit 63e8638

Please sign in to comment.