From 1bc215e1d1c1e3509a1dd0bc413b1537563dedb5 Mon Sep 17 00:00:00 2001 From: Yiming Hu Date: Thu, 21 Sep 2023 19:22:28 -0700 Subject: [PATCH] [VITISAI] add float16 and bfloat16 support (#17438) ### Description Add float16 and bfloat16 data type support for VitisAI ep ### Motivation and Context The VitisAI EP has added bfloat16 and float16 data type support, so we would like to register these data types on the onnxruntime side to enable them. --------- Signed-off-by: Yiming Hu --- onnxruntime/core/providers/vitisai/README.md | 2 +- onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/onnxruntime/core/providers/vitisai/README.md b/onnxruntime/core/providers/vitisai/README.md index 15e0c804489c5..6ddb58b8d96ae 100644 --- a/onnxruntime/core/providers/vitisai/README.md +++ b/onnxruntime/core/providers/vitisai/README.md @@ -1,4 +1,4 @@ -VitsAI Execution Prividers +VitisAI Execution Provider ============================ diff --git a/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc b/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc index 544e18350635d..ee8dfc6d03d12 100644 --- a/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc +++ b/onnxruntime/core/providers/vitisai/imp/register_xir_ops.cc @@ -34,9 +34,12 @@ static void xir_shape_infer(ONNX_NAMESPACE::InferenceContext& ctx) { updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::INT64); } else if (data_type->s() == "int1") { updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::BOOL); + } else if (data_type->s() == "bfloat16") { + updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::BFLOAT16); + } else if (data_type->s() == "float16") { + updateOutputElemType(ctx, 0, ONNX_NAMESPACE::TensorProto::FLOAT16); } else { - std::cerr << "not supported data_type " << data_type->s(); - abort(); + vai_assert(false, ", not supported data_type: " + data_type->s()); } if (shape != nullptr) { for (auto i = 0; i < shape->ints_size(); 
++i) {