Skip to content

Commit

Permalink
address pr comments
Browse files Browse the repository at this point in the history
  • Loading branch information
rachguo authored and rachguo committed Nov 22, 2023
1 parent 7a1408f commit bb23a14
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
coreml_splitnd->set_numsplits(num_outputs);
} else {
// note: for opset 18+ 'num_outputs' is a required attribute
num_outputs = narrow<uint64_t>(helper.Get("num_outputs").value());
num_outputs = narrow<uint64_t>(helper.GetInt("num_outputs").value());
// note: checked in IsOpSupportedImpl that ensures the dim value at splitting axis exists
auto split_dim_size = data_shape[HandleNegativeAxis(axis, data_shape.size())];
uint64_t chunk_size = narrow<uint64_t>((split_dim_size + num_outputs - 1) / num_outputs);
Expand Down Expand Up @@ -159,7 +159,7 @@ bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPar
}
} else {
if (node.SinceVersion() >= 18) {
const auto num_outputs = helper.Get("num_outputs");
const auto num_outputs = helper.GetInt("num_outputs");
if (!num_outputs.has_value()) {
LOGS(logger, VERBOSE) << "No 'num_outputs' provided. For split 18+, num_outputs is a required attribute.";
return false;
Expand Down
5 changes: 2 additions & 3 deletions onnxruntime/core/providers/shared/utils/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -166,11 +166,10 @@ std::vector<float> NodeAttrHelper::Get(const std::string& key, const std::vector
return std::vector<float>{source.cbegin(), source.cend()};
}

std::optional<int32_t> NodeAttrHelper::Get(const std::string& key) const {
std::optional<int64_t> NodeAttrHelper::GetInt(const std::string& key) const {
if (!HasAttr(key))
return std::nullopt;

return SafeInt<int32_t>(node_attributes_.at(key).f());
return node_attributes_.at(key).i();
}

bool NodeAttrHelper::HasAttr(const std::string& key) const {
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/shared/utils/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ class NodeAttrHelper {
uint32_t Get(const std::string& key, uint32_t def_val) const;
std::vector<uint32_t> Get(const std::string& key, const std::vector<uint32_t>& def_val) const;

std::optional<int32_t> Get(const std::string& key) const;
std::optional<int64_t> GetInt(const std::string& key) const;

bool HasAttr(const std::string& key) const;

Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/test/providers/cpu/tensor/split_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ void SplitTestAxis0EqualSplit(bool use_opset_13 = false) {
// TensorRT parser: Assertion failed: axis != BATCH_DIM
{kTensorrtExecutionProvider}, // is_tensorrt_supported
false, // expect_failure
use_opset_13 = true); // split_as_input
true); // split_as_input
#endif
}

Expand Down

0 comments on commit bb23a14

Please sign in to comment.