diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 8091fd4cfc2a3..1667686c1dadf 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -2216,7 +2216,7 @@ struct ShapeInferContext { size_t GetInputCount() const { return input_shapes_.size(); } - Status SetOutputShape(size_t indice, const Shape& shape); + Status SetOutputShape(size_t indice, const Shape& shape, ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); int64_t GetAttrInt(const char* attr_name); diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index a732bf169dc7a..a575c22f8775c 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -1996,9 +1996,10 @@ inline ShapeInferContext::ShapeInferContext(const OrtApi* ort_api, } } -inline Status ShapeInferContext::SetOutputShape(size_t indice, const Shape& shape) { +inline Status ShapeInferContext::SetOutputShape(size_t indice, const Shape& shape, ONNXTensorElementDataType type) { OrtTensorTypeAndShapeInfo* info = {}; ORT_CXX_RETURN_ON_API_FAIL(ort_api_->CreateTensorTypeAndShapeInfo(&info)); + ORT_CXX_RETURN_ON_API_FAIL(ort_api_->SetTensorElementType(info, type)); using InfoPtr = std::unique_ptr>; diff --git a/onnxruntime/core/session/custom_ops.cc b/onnxruntime/core/session/custom_ops.cc index d0c46142ac060..f8afd5aaf9310 100644 --- a/onnxruntime/core/session/custom_ops.cc +++ b/onnxruntime/core/session/custom_ops.cc @@ -105,6 +105,7 @@ struct OrtShapeInferContext { } } ONNX_NAMESPACE::updateOutputShape(ctx_, index, shape_proto); + ONNX_NAMESPACE::updateOutputElemType(ctx_, index, info->type); return onnxruntime::Status::OK(); }