From d1d40fbafdb2935f6f646d7ab595a061ce3bc1b1 Mon Sep 17 00:00:00 2001 From: mingyueliuh <131847423+mingyueliuh@users.noreply.github.com> Date: Sun, 18 Aug 2024 16:51:25 -0400 Subject: [PATCH] [VitisAI][Fix] ShapeInferContext GetAttrxxxs support empty value (#21471) ### Description Bug fix for the ShapeInferContext GetAttrxxxs APIs. Node attribute maybe is empty. ### Motivation and Context If the attr value is empty, the expected result through the interface is empty , but currently, it returns a meaningless {0}. --------- Co-authored-by: mingyue Co-authored-by: Liu Minyue --- .../core/session/onnxruntime_cxx_inline.h | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 9b9dd81a749c0..d3a8cade4d28f 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -2044,6 +2044,9 @@ inline ShapeInferContext::Ints ShapeInferContext::GetAttrInts(const char* attr_n int64_t i = {}; size_t out = {}; // first call to get the bytes needed + // 1. A status == nullptr means that ReadOpAttr was successful. A status != nullptr means failure. + // 2. The ReadOpAttr function should normally be called twice: once to get the needed buffer size (returns a status != nullptr), and a second time to actually read the ints (returns status == null on success). + // 3. This code tries a subtle optimization in the first call to ReadOpAttr. It passes in a buffer (&i) of size 1 just in case there is only 1 int. In this case, status == nullptr and we need to return {i}. auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INTS, &i, sizeof(i), &out); if (status) { size_t num_i = out / sizeof(int64_t); @@ -2051,6 +2054,9 @@ inline ShapeInferContext::Ints ShapeInferContext::GetAttrInts(const char* attr_n Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_INTS, ints.data(), out, &out)); return ints; } else { + if (out == 0u) { + return {}; + } return {i}; } } @@ -2068,6 +2074,9 @@ inline ShapeInferContext::Floats ShapeInferContext::GetAttrFloats(const char* at float f = {}; size_t out = {}; // first call to get the bytes needed + // 1. A status == nullptr means that ReadOpAttr was successful. A status != nullptr means failure. + // 2. The ReadOpAttr function should normally be called twice: once to get the needed buffer size (returns a status != nullptr), and a second time to actually read the ints (returns status == null on success). + // 3. This code tries a subtle optimization in the first call to ReadOpAttr. It passes in a buffer (&i) of size 1 just in case there is only 1 int. In this case, status == nullptr and we need to return {i}. auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOATS, &f, sizeof(f), &out); if (status) { size_t num_f = out / sizeof(float); @@ -2075,6 +2084,9 @@ inline ShapeInferContext::Floats ShapeInferContext::GetAttrFloats(const char* at Ort::ThrowOnError(ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_FLOATS, floats.data(), out, &out)); return floats; } else { + if (out == 0u) { + return {}; + } return {f}; } } @@ -2099,6 +2111,9 @@ inline ShapeInferContext::Strings ShapeInferContext::GetAttrStrings(const char* char c = {}; size_t out = {}; // first call to get the bytes needed + // 1. A status == nullptr means that ReadOpAttr was successful. A status != nullptr means failure. + // 2. The ReadOpAttr function should normally be called twice: once to get the needed buffer size (returns a status != nullptr), and a second time to actually read the ints (returns status == null on success). + // 3. This code tries a subtle optimization in the first call to ReadOpAttr. It passes in a buffer (&i) of size 1 just in case there is only 1 int. In this case, status == nullptr and we need to return {i}. auto status = ort_api_->ReadOpAttr(attr, ORT_OP_ATTR_STRINGS, &c, sizeof(char), &out); if (status) { std::vector chars(out, '\0'); @@ -2115,6 +2130,9 @@ inline ShapeInferContext::Strings ShapeInferContext::GetAttrStrings(const char* } return strings; } else { + if (out == 0u) { + return {}; + } return {std::string{c}}; } }