diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 1bb013d0cdc10..b53e70926cd5d 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -602,9 +602,7 @@ struct ProviderHostImpl : ProviderHost { const ONNX_NAMESPACE::ValueInfoProto& ValueInfoProtos__operator_array(const ONNX_NAMESPACE::ValueInfoProtos* p, int index) override { return (*p)[index]; } - static void xir_shape_infer(ONNX_NAMESPACE::InferenceContext& ctx) { - auto* shape = ctx.getAttribute("shape"); - auto* data_type = ctx.getAttribute("data_type"); + static int32_t convert_elem_type(const ONNX_NAMESPACE::AttributeProto* data_type) { int32_t elemType = 0; if (data_type->s() == "float32") { elemType = ONNX_NAMESPACE::TensorProto_DataType_FLOAT; @@ -650,17 +648,43 @@ struct ProviderHostImpl : ProviderHost { elemType = ONNX_NAMESPACE::TensorProto_DataType_UINT4; } else if (data_type->s() == "int4") { elemType = ONNX_NAMESPACE::TensorProto_DataType_INT4; - } else { - return; } - ONNX_NAMESPACE::updateOutputElemType(ctx, 0, elemType); - if (shape != nullptr) { - for (auto i = 0; i < shape->ints_size(); ++i) { - ONNX_NAMESPACE::getOutputShape(ctx, 0, ONNX_NAMESPACE::TypeProto::kTensorType)->add_dim()->set_dim_value(shape->ints(i)); + return elemType; + } + + static void xir_shape_infer(ONNX_NAMESPACE::InferenceContext& ctx) { + auto num_output = ctx.getNumOutputs(); + if (num_output == 1) { + auto* shape = ctx.getAttribute("shape"); + auto* data_type = ctx.getAttribute("data_type"); + if (data_type == nullptr) { + std::cerr << "Custom op is missing `data_type` attr." << std::endl; + return; + } + int32_t elemType = convert_elem_type(data_type); + ONNX_NAMESPACE::updateOutputElemType(ctx, 0, elemType); + if (shape != nullptr) { + for (auto i = 0; i < shape->ints_size(); ++i) { + ONNX_NAMESPACE::getOutputShape(ctx, 0, ONNX_NAMESPACE::TypeProto::kTensorType)->add_dim()->set_dim_value(shape->ints(i)); + } + } else { + // set scalar type. + ONNX_NAMESPACE::getOutputShape(ctx, 0, ONNX_NAMESPACE::TypeProto::kTensorType)->clear_dim(); } } else { - // set scalar type. - ONNX_NAMESPACE::getOutputShape(ctx, 0, ONNX_NAMESPACE::TypeProto::kTensorType)->clear_dim(); + for (auto idx = 0u; idx < num_output; idx++) { + auto* shape = ctx.getAttribute("shape_" + std::to_string(idx)); + auto* data_type = ctx.getAttribute("data_type_" + std::to_string(idx)); + if (shape == nullptr || data_type == nullptr) { + // this output is optional + } else { + int32_t elemType = convert_elem_type(data_type); + ONNX_NAMESPACE::updateOutputElemType(ctx, idx, elemType); + for (auto i = 0; i < shape->ints_size(); ++i) { + ONNX_NAMESPACE::getOutputShape(ctx, idx, ONNX_NAMESPACE::TypeProto::kTensorType)->add_dim()->set_dim_value(shape->ints(i)); + } + } + } } }