Skip to content

Commit

Permalink
address pr comments
Browse files Browse the repository at this point in the history
  • Loading branch information
rachguo authored and rachguo committed Nov 21, 2023
1 parent 0d96ebe commit d8d5b29
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 45 deletions.
32 changes: 18 additions & 14 deletions onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
const auto& input_defs = node.InputDefs();

std::vector<int64_t> data_shape;
ORT_RETURN_IF_NOT(GetStaticShape(*node.InputDefs()[0], data_shape, logger), "Failed to get input shape.");
ORT_RETURN_IF_NOT(GetShape(*node.InputDefs()[0], data_shape, logger), "Failed to get input shape.");

NodeAttrHelper helper(node);
const auto axis = helper.Get("axis", 0);
Expand All @@ -71,26 +71,26 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
// if "split" is explicitly provided as an input
const auto& split_tensor = *model_builder.GetInitializerTensors().at(input_defs[1]->Name());
Initializer unpacked_tensor(split_tensor);
auto split_span = unpacked_tensor.DataAsSpan<int64_t>();
auto split_span = unpacked_tensor.DataAsSpan<uint64_t>();
auto split_sizes = split_span.size();
num_outputs = SafeInt<uint64_t>(split_sizes);
num_outputs = narrow<uint64_t>(split_sizes);
for (size_t i = 0; i < split_sizes; i++) {
coreml_splitnd->add_splitsizes(SafeInt<uint64_t>(split_span[i]));
coreml_splitnd->add_splitsizes(split_span[i]);
}
} else if (node.SinceVersion() < 18) {
num_outputs = node.OutputDefs().size();
num_outputs = narrow<uint64_t>(node.OutputDefs().size());
coreml_splitnd->set_numsplits(num_outputs);
} else {
num_outputs = SafeInt<uint64_t>(helper.Get("num_outputs", 2));
num_outputs = static_cast<uint64_t>(helper.Get("num_outputs").value());
auto split_dim_size = data_shape[HandleNegativeAxis(axis, data_shape.size())];
uint64_t chunk_size = SafeInt<uint64_t>((split_dim_size + num_outputs - 1) / num_outputs);
uint64_t chunk_size = narrow<uint64_t>((split_dim_size + num_outputs - 1) / num_outputs);
uint64_t remainder = split_dim_size % chunk_size;
if (remainder) {
// uneven
auto split_sizes = std::vector<uint64_t>(num_outputs, chunk_size);
split_sizes.back() = remainder;
for (size_t i = 0; i < split_sizes.size(); i++) {
coreml_splitnd->add_splitsizes(SafeInt<uint64_t>(split_sizes[i]));
coreml_splitnd->add_splitsizes(split_sizes[i]);
}
} else {
// even
Expand Down Expand Up @@ -120,10 +120,9 @@ bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPar

NodeAttrHelper helper(node);
const auto axis = helper.Get("axis", 0);
const auto num_outputs = helper.Get("num_outputs", -1);

std::vector<int64_t> input_shape;
if (!GetStaticShape(*input_defs[0], input_shape, logger))
if (!GetShape(*input_defs[0], input_shape, logger))
return false;

const auto split_dims_at_axis = input_shape[HandleNegativeAxis(axis, input_shape.size())];
Expand Down Expand Up @@ -154,15 +153,20 @@ bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPar
}
} else {
if (node.SinceVersion() >= 18) {
if (num_outputs < 2) {
const auto num_outputs = helper.Get("num_outputs");
if (!num_outputs.has_value()) {
LOGS(logger, VERBOSE) << "No 'num_outputs' provided. For split 18+, num_outputs is a required attribute.";
return false;
}
if (num_outputs.value() < 2) {
LOGS(logger, VERBOSE) << "Invalid num_outputs. The value can not be lower than 1.\n"
<< "CoreML SplitND requires at least 2 outputs. num_outputs: " << num_outputs;
<< "CoreML SplitND requires at least 2 outputs. num_outputs: " << num_outputs.value();
return false;
}
if (num_outputs != static_cast<int32_t>(node.OutputDefs().size()) || num_outputs > split_dims_at_axis) {
if (num_outputs.value() != static_cast<int32_t>(node.OutputDefs().size()) || num_outputs.value() > split_dims_at_axis) {

Check warning on line 166 in onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc#L166

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc:166:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
LOGS(logger, VERBOSE) << "Invalid num_outputs provided.\n."
<< "The value should be smaller or equal to the size of dimension being split. num_outputs: "

Check warning on line 168 in onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc

View workflow job for this annotation

GitHub Actions / cpplint

[cpplint] onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc#L168

Lines should be <= 120 characters long [whitespace/line_length] [2]
Raw output
onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc:168:  Lines should be <= 120 characters long  [whitespace/line_length] [2]
<< num_outputs;
<< num_outputs.value();
return false;
}
}
Expand Down
7 changes: 7 additions & 0 deletions onnxruntime/core/providers/shared/utils/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,13 @@ std::vector<float> NodeAttrHelper::Get(const std::string& key, const std::vector
return std::vector<float>{source.cbegin(), source.cend()};
}

std::optional<int32_t> NodeAttrHelper::Get(const std::string& key) const {
if (!HasAttr(key))
return std::nullopt;

return SafeInt<int32_t>(node_attributes_.at(key).f());
}

bool NodeAttrHelper::HasAttr(const std::string& key) const {
return Contains(node_attributes_, key);
}
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/shared/utils/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ class NodeAttrHelper {
uint32_t Get(const std::string& key, uint32_t def_val) const;
std::vector<uint32_t> Get(const std::string& key, const std::vector<uint32_t>& def_val) const;

std::optional<int32_t> Get(const std::string& key) const;

bool HasAttr(const std::string& key) const;

private:
Expand Down
46 changes: 15 additions & 31 deletions onnxruntime/test/providers/cpu/tensor/split_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -169,11 +169,11 @@ TEST(SplitOperatorTest, Axis0UnequalSplitFloat) {
{3.f, 4.f,
5.f, 6.f,
7.f, 8.f}});
// TensorRT parser: Assertion failed: axis != BATCH_DIM
#ifdef USE_COREML
RunTest<float>(axis, splits, input, outputs, {kTensorrtExecutionProvider}, false, true);
#endif

// TensorRT parser: Assertion failed: axis != BATCH_DIM
RunTest<float>(axis, splits, input, outputs, {kTensorrtExecutionProvider});
// CoreML EP, etc. requires split to be an input. Same applies to below sets of tests.
RunTest<float>(axis, splits, input, outputs, {kTensorrtExecutionProvider}, false, true);
}

TEST(SplitOperatorTest, Axis0UnequalSplitString) {
Expand All @@ -195,10 +195,8 @@ TEST(SplitOperatorTest, Axis0UnequalSplitString) {
{"c", "d",
"e", "f",
"g", "h"}});
// TensorRT parser: Assertion failed: axis != BATCH_DIM
#ifdef USE_COREML
// TensorRT parser: Assertion failed: axis != BATCH_DIM
RunTest<std::string>(axis, splits, input, outputs, {kTensorrtExecutionProvider}, false, true);
#endif
RunTest<std::string>(axis, splits, input, outputs, {kTensorrtExecutionProvider});
}

Expand All @@ -218,9 +216,7 @@ TEST(SplitOperatorTest, Axis1EqualSplitFloat) {
outputs.push_back({{2, 2},
{3.f, 4.f,
7.f, 8.f}});
#ifdef USE_COREML
RunTest<float>(axis, {}, input, outputs, {kTensorrtExecutionProvider}, false, true);
#endif
RunTest<float>(axis, {}, input, outputs, {kTensorrtExecutionProvider});
}

Expand All @@ -240,9 +236,8 @@ TEST(SplitOperatorTest, Axis1EqualSplitString) {
outputs.push_back({{2, 2},
{"c", "d",
"g", "h"}});
#ifdef USE_COREML

RunTest<std::string>(axis, {}, input, outputs, {kTensorrtExecutionProvider}, false, true);
#endif
RunTest<std::string>(axis, {}, input, outputs, {kTensorrtExecutionProvider});
}

Expand All @@ -264,9 +259,8 @@ TEST(SplitOperatorTest, Axis1UnequalSplitFloat) {
outputs.push_back({{2, 1},
{4.f,
8.f}});
#ifdef USE_COREML

RunTest<float>(axis, splits, input, outputs, {kTensorrtExecutionProvider}, false, true);
#endif
RunTest<float>(axis, splits, input, outputs, {kTensorrtExecutionProvider});
}

Expand All @@ -288,9 +282,8 @@ TEST(SplitOperatorTest, Axis1UnequalSplitString) {
outputs.push_back({{2, 1},
{"d",
"h"}});
#ifdef USE_COREML

RunTest<std::string>(axis, splits, input, outputs, {kTensorrtExecutionProvider}, false, true);
#endif
RunTest<std::string>(axis, splits, input, outputs, {kTensorrtExecutionProvider});
}

Expand Down Expand Up @@ -332,9 +325,8 @@ TEST(SplitOperatorTest, Axis2EqualSplit) {

17.f, 18.f,
23.f, 24.f}});
#ifdef USE_COREML

RunTest<float>(axis, {}, input, outputs, {kTensorrtExecutionProvider}, false, true);
#endif
RunTest<float>(axis, {}, input, outputs, {kTensorrtExecutionProvider});
}

Expand Down Expand Up @@ -366,9 +358,8 @@ TEST(SplitOperatorTest, Axis2UnequalSplit) {

16.f, 17.f, 18.f,
22.f, 23.f, 24.f}});
#ifdef USE_COREML

RunTest<float>(axis, splits, input, outputs, {kTensorrtExecutionProvider}, false, true);
#endif
RunTest<float>(axis, splits, input, outputs, {kTensorrtExecutionProvider});
}

Expand Down Expand Up @@ -401,9 +392,8 @@ TEST(SplitOperatorTest, Axis1SplitMiddleDimensionEqually) {

25.f, 26.f, 27.f, 28.f,
29.f, 30.f, 31.f, 32.f}});
#ifdef USE_COREML

RunTest<float>(axis, {}, input, outputs, {kTensorrtExecutionProvider}, false, true);
#endif
RunTest<float>(axis, {}, input, outputs, {kTensorrtExecutionProvider});
}

Expand All @@ -429,9 +419,8 @@ TEST(SplitOperatorTest, Axis1SplitMiddleDimensionUnequally) {
21.f, 22.f, 23.f, 24.f,
25.f, 26.f, 27.f, 28.f,
29.f, 30.f, 31.f, 32.f}});
#ifdef USE_COREML

RunTest<float>(axis, splits, input, outputs, {kTensorrtExecutionProvider}, false, true);
#endif
RunTest<float>(axis, splits, input, outputs, {kTensorrtExecutionProvider});
}

Expand All @@ -451,9 +440,8 @@ TEST(SplitOperatorTest, NegativeAxis) {
outputs.push_back({{2, 2},
{3.f, 4.f,
7.f, 8.f}});
#ifdef USE_COREML

RunTest<float>(axis, {}, input, outputs, {kTensorrtExecutionProvider}, false, true);
#endif
RunTest<float>(axis, {}, input, outputs, {kTensorrtExecutionProvider});
}

Expand All @@ -469,9 +457,8 @@ TEST(SplitOperatorTest, InvalidAxis) {
7.f, 8.f}};

outputs.push_back({{1}, {0.f}});
#ifdef USE_COREML

RunTest<float>(axis, {}, input, outputs, {}, true, true, -1, true, "Invalid value of attribute 'axis'");
#endif
RunTest<float>(axis, {}, input, outputs, {}, true, false, -1, true, "Invalid value of attribute 'axis'");
}

Expand All @@ -491,10 +478,9 @@ TEST(SplitOperatorTest, SplitAttributeSumTooSmall) {

outputs.push_back({{1, 2}, {1.f, 2.f}});
outputs.push_back({{2, 2}, {3.f, 4.f, 5.f, 6.f}});
#ifdef USE_COREML

RunTest<float>(axis, splits, input, outputs, {kTensorrtExecutionProvider}, true, true, -1, true,
"[ShapeInferenceError] Mismatch between the sum of 'split'");
#endif
RunTest<float>(axis, splits, input, outputs, {kTensorrtExecutionProvider}, true, false, -1, true,
"[ShapeInferenceError] Mismatch between the sum of 'split'"); // TensorRT parser: Assertion failed: axis != BATCH_DIM
}
Expand All @@ -514,10 +500,8 @@ TEST(SplitOperatorTest, InvalidValueInSplitAttribute) {
outputs.push_back({{1, 2}, {1.f, 2.f}});
outputs.push_back({{3, 2}, {3.f, 4.f, 5.f, 6.f, 7.f, 8.f}});

#ifdef USE_COREML
RunTest<float>(axis, splits, input, outputs, {kTensorrtExecutionProvider}, true, true, -1, true,
"[ShapeInferenceError] Mismatch between number of splits");
#endif
RunTest<float>(axis, splits, input, outputs, {kTensorrtExecutionProvider}, true, false, -1, true,
"[ShapeInferenceError] Mismatch between number of splits"); // TensorRT parser: Assertion failed: axis != BATCH_DIM
}
Expand Down

0 comments on commit d8d5b29

Please sign in to comment.