From bb23a1445e6d01f093b028144eb7b70a09bc6e71 Mon Sep 17 00:00:00 2001 From: rachguo Date: Wed, 22 Nov 2023 11:10:15 -0800 Subject: [PATCH] address pr comments --- .../core/providers/coreml/builders/impl/split_op_builder.cc | 4 ++-- onnxruntime/core/providers/shared/utils/utils.cc | 5 ++--- onnxruntime/core/providers/shared/utils/utils.h | 2 +- onnxruntime/test/providers/cpu/tensor/split_op_test.cc | 2 +- 4 files changed, 6 insertions(+), 7 deletions(-) diff --git a/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc index 0f70258a99cfc..815f68128ffaf 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc @@ -82,7 +82,7 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, coreml_splitnd->set_numsplits(num_outputs); } else { // note: for opset 18+ 'num_outputs' is a required attribute - num_outputs = narrow(helper.Get("num_outputs").value()); + num_outputs = narrow(helper.GetInt("num_outputs").value()); // note: checked in IsOpSupportedImpl that ensures the dim value at splitting axis exists auto split_dim_size = data_shape[HandleNegativeAxis(axis, data_shape.size())]; uint64_t chunk_size = narrow((split_dim_size + num_outputs - 1) / num_outputs); @@ -159,7 +159,7 @@ bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPar } } else { if (node.SinceVersion() >= 18) { - const auto num_outputs = helper.Get("num_outputs"); + const auto num_outputs = helper.GetInt("num_outputs"); if (!num_outputs.has_value()) { LOGS(logger, VERBOSE) << "No 'num_outputs' provided. For split 18+, num_outputs is a required attribute."; return false; diff --git a/onnxruntime/core/providers/shared/utils/utils.cc b/onnxruntime/core/providers/shared/utils/utils.cc index 60133cd4fe3f8..39ea4dd8412bb 100644 --- a/onnxruntime/core/providers/shared/utils/utils.cc +++ b/onnxruntime/core/providers/shared/utils/utils.cc @@ -166,11 +166,10 @@ std::vector NodeAttrHelper::Get(const std::string& key, const std::vector return std::vector{source.cbegin(), source.cend()}; } -std::optional NodeAttrHelper::Get(const std::string& key) const { +std::optional NodeAttrHelper::GetInt(const std::string& key) const { if (!HasAttr(key)) return std::nullopt; - - return SafeInt(node_attributes_.at(key).f()); + return node_attributes_.at(key).i(); } bool NodeAttrHelper::HasAttr(const std::string& key) const { diff --git a/onnxruntime/core/providers/shared/utils/utils.h b/onnxruntime/core/providers/shared/utils/utils.h index 973ef0a1e6ad2..1e93f040711df 100644 --- a/onnxruntime/core/providers/shared/utils/utils.h +++ b/onnxruntime/core/providers/shared/utils/utils.h @@ -58,7 +58,7 @@ class NodeAttrHelper { uint32_t Get(const std::string& key, uint32_t def_val) const; std::vector Get(const std::string& key, const std::vector& def_val) const; - std::optional Get(const std::string& key) const; + std::optional GetInt(const std::string& key) const; bool HasAttr(const std::string& key) const; diff --git a/onnxruntime/test/providers/cpu/tensor/split_op_test.cc b/onnxruntime/test/providers/cpu/tensor/split_op_test.cc index 6d8efbad015fa..7ed08bdc84c35 100644 --- a/onnxruntime/test/providers/cpu/tensor/split_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/split_op_test.cc @@ -127,7 +127,7 @@ void SplitTestAxis0EqualSplit(bool use_opset_13 = false) { // TensorRT parser: Assertion failed: axis != BATCH_DIM {kTensorrtExecutionProvider}, // is_tensorrt_supported false, // expect_failure - use_opset_13 = true); // split_as_input + true); // split_as_input #endif }