diff --git a/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc index 1ccc9e150da96..bde2291c6809b 100644 --- a/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc +++ b/onnxruntime/core/providers/coreml/builders/impl/split_op_builder.cc @@ -55,7 +55,7 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto& input_defs = node.InputDefs(); std::vector data_shape; - ORT_RETURN_IF_NOT(GetStaticShape(*node.InputDefs()[0], data_shape, logger), "Failed to get input shape."); + ORT_RETURN_IF_NOT(GetShape(*node.InputDefs()[0], data_shape, logger), "Failed to get input shape."); NodeAttrHelper helper(node); const auto axis = helper.Get("axis", 0); @@ -71,26 +71,26 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, // if "split" is explicitly provided as an input const auto& split_tensor = *model_builder.GetInitializerTensors().at(input_defs[1]->Name()); Initializer unpacked_tensor(split_tensor); - auto split_span = unpacked_tensor.DataAsSpan(); + auto split_span = unpacked_tensor.DataAsSpan(); auto split_sizes = split_span.size(); - num_outputs = SafeInt(split_sizes); + num_outputs = narrow(split_sizes); for (size_t i = 0; i < split_sizes; i++) { - coreml_splitnd->add_splitsizes(SafeInt(split_span[i])); + coreml_splitnd->add_splitsizes(split_span[i]); } } else if (node.SinceVersion() < 18) { - num_outputs = node.OutputDefs().size(); + num_outputs = narrow(node.OutputDefs().size()); coreml_splitnd->set_numsplits(num_outputs); } else { - num_outputs = SafeInt(helper.Get("num_outputs", 2)); + num_outputs = static_cast(helper.Get("num_outputs").value()); auto split_dim_size = data_shape[HandleNegativeAxis(axis, data_shape.size())]; - uint64_t chunk_size = SafeInt((split_dim_size + num_outputs - 1) / num_outputs); + uint64_t chunk_size = narrow((split_dim_size + num_outputs - 1) / num_outputs); uint64_t remainder = split_dim_size % chunk_size; if (remainder) { // uneven auto split_sizes = std::vector(num_outputs, chunk_size); split_sizes.back() = remainder; for (size_t i = 0; i < split_sizes.size(); i++) { - coreml_splitnd->add_splitsizes(SafeInt(split_sizes[i])); + coreml_splitnd->add_splitsizes(split_sizes[i]); } } else { // even @@ -120,10 +120,9 @@ bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPar NodeAttrHelper helper(node); const auto axis = helper.Get("axis", 0); - const auto num_outputs = helper.Get("num_outputs", -1); std::vector input_shape; - if (!GetStaticShape(*input_defs[0], input_shape, logger)) + if (!GetShape(*input_defs[0], input_shape, logger)) return false; const auto split_dims_at_axis = input_shape[HandleNegativeAxis(axis, input_shape.size())]; @@ -154,15 +153,20 @@ bool SplitOpBuilder::IsOpSupportedImpl(const Node& node, const OpBuilderInputPar } } else { if (node.SinceVersion() >= 18) { - if (num_outputs < 2) { + const auto num_outputs = helper.Get("num_outputs"); + if (!num_outputs.has_value()) { + LOGS(logger, VERBOSE) << "No 'num_outputs' provided. For split 18+, num_outputs is a required attribute."; + return false; + } + if (num_outputs.value() < 2) { LOGS(logger, VERBOSE) << "Invalid num_outputs. The value can not be lower than 1.\n" - << "CoreML SplitND requires at least 2 outputs. num_outputs: " << num_outputs; + << "CoreML SplitND requires at least 2 outputs. num_outputs: " << num_outputs.value(); return false; } - if (num_outputs != static_cast(node.OutputDefs().size()) || num_outputs > split_dims_at_axis) { + if (num_outputs.value() != static_cast(node.OutputDefs().size()) || num_outputs.value() > split_dims_at_axis) { LOGS(logger, VERBOSE) << "Invalid num_outputs provided.\n." << "The value should be smaller or equal to the size of dimension being split. num_outputs: " - << num_outputs; + << num_outputs.value(); return false; } } diff --git a/onnxruntime/core/providers/shared/utils/utils.cc b/onnxruntime/core/providers/shared/utils/utils.cc index 6b1207d3d16f0..60133cd4fe3f8 100644 --- a/onnxruntime/core/providers/shared/utils/utils.cc +++ b/onnxruntime/core/providers/shared/utils/utils.cc @@ -166,6 +166,13 @@ std::vector NodeAttrHelper::Get(const std::string& key, const std::vector return std::vector{source.cbegin(), source.cend()}; } +std::optional NodeAttrHelper::Get(const std::string& key) const { + if (!HasAttr(key)) + return std::nullopt; + + return SafeInt(node_attributes_.at(key).f()); +} + bool NodeAttrHelper::HasAttr(const std::string& key) const { return Contains(node_attributes_, key); } diff --git a/onnxruntime/core/providers/shared/utils/utils.h b/onnxruntime/core/providers/shared/utils/utils.h index db07938c1897e..0cd4bb62552d8 100644 --- a/onnxruntime/core/providers/shared/utils/utils.h +++ b/onnxruntime/core/providers/shared/utils/utils.h @@ -57,6 +57,8 @@ class NodeAttrHelper { uint32_t Get(const std::string& key, uint32_t def_val) const; std::vector Get(const std::string& key, const std::vector& def_val) const; + std::optional Get(const std::string& key) const; + bool HasAttr(const std::string& key) const; private: diff --git a/onnxruntime/test/providers/cpu/tensor/split_op_test.cc b/onnxruntime/test/providers/cpu/tensor/split_op_test.cc index af58f0770ca5a..f0f7acd5841d6 100644 --- a/onnxruntime/test/providers/cpu/tensor/split_op_test.cc +++ b/onnxruntime/test/providers/cpu/tensor/split_op_test.cc @@ -169,11 +169,11 @@ TEST(SplitOperatorTest, Axis0UnequalSplitFloat) { {3.f, 4.f, 5.f, 6.f, 7.f, 8.f}}); -// TensorRT parser: Assertion failed: axis != BATCH_DIM -#ifdef USE_COREML - RunTest(axis, splits, input, outputs, {kTensorrtExecutionProvider}, false, true); -#endif + + // TensorRT parser: Assertion failed: axis != BATCH_DIM RunTest(axis, splits, input, outputs, {kTensorrtExecutionProvider}); + // CoreML EP, etc. requires split to be an input. Same applies to below sets of tests. + RunTest(axis, splits, input, outputs, {kTensorrtExecutionProvider}, false, true); } TEST(SplitOperatorTest, Axis0UnequalSplitString) { @@ -195,10 +195,8 @@ TEST(SplitOperatorTest, Axis0UnequalSplitString) { {"c", "d", "e", "f", "g", "h"}}); -// TensorRT parser: Assertion failed: axis != BATCH_DIM -#ifdef USE_COREML + // TensorRT parser: Assertion failed: axis != BATCH_DIM RunTest(axis, splits, input, outputs, {kTensorrtExecutionProvider}, false, true); -#endif RunTest(axis, splits, input, outputs, {kTensorrtExecutionProvider}); } @@ -218,9 +216,7 @@ TEST(SplitOperatorTest, Axis1EqualSplitFloat) { outputs.push_back({{2, 2}, {3.f, 4.f, 7.f, 8.f}}); -#ifdef USE_COREML RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider}, false, true); -#endif RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider}); } @@ -240,9 +236,8 @@ TEST(SplitOperatorTest, Axis1EqualSplitString) { outputs.push_back({{2, 2}, {"c", "d", "g", "h"}}); -#ifdef USE_COREML + RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider}, false, true); -#endif RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider}); } @@ -264,9 +259,8 @@ TEST(SplitOperatorTest, Axis1UnequalSplitFloat) { outputs.push_back({{2, 1}, {4.f, 8.f}}); -#ifdef USE_COREML + RunTest(axis, splits, input, outputs, {kTensorrtExecutionProvider}, false, true); -#endif RunTest(axis, splits, input, outputs, {kTensorrtExecutionProvider}); } @@ -288,9 +282,8 @@ TEST(SplitOperatorTest, Axis1UnequalSplitString) { outputs.push_back({{2, 1}, {"d", "h"}}); -#ifdef USE_COREML + RunTest(axis, splits, input, outputs, {kTensorrtExecutionProvider}, false, true); -#endif RunTest(axis, splits, input, outputs, {kTensorrtExecutionProvider}); } @@ -332,9 +325,8 @@ TEST(SplitOperatorTest, Axis2EqualSplit) { 17.f, 18.f, 23.f, 24.f}}); -#ifdef USE_COREML + RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider}, false, true); -#endif RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider}); } @@ -366,9 +358,8 @@ TEST(SplitOperatorTest, Axis2UnequalSplit) { 16.f, 17.f, 18.f, 22.f, 23.f, 24.f}}); -#ifdef USE_COREML + RunTest(axis, splits, input, outputs, {kTensorrtExecutionProvider}, false, true); -#endif RunTest(axis, splits, input, outputs, {kTensorrtExecutionProvider}); } @@ -401,9 +392,8 @@ TEST(SplitOperatorTest, Axis1SplitMiddleDimensionEqually) { 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f}}); -#ifdef USE_COREML + RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider}, false, true); -#endif RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider}); } @@ -429,9 +419,8 @@ TEST(SplitOperatorTest, Axis1SplitMiddleDimensionUnequally) { 21.f, 22.f, 23.f, 24.f, 25.f, 26.f, 27.f, 28.f, 29.f, 30.f, 31.f, 32.f}}); -#ifdef USE_COREML + RunTest(axis, splits, input, outputs, {kTensorrtExecutionProvider}, false, true); -#endif RunTest(axis, splits, input, outputs, {kTensorrtExecutionProvider}); } @@ -451,9 +440,8 @@ TEST(SplitOperatorTest, NegativeAxis) { outputs.push_back({{2, 2}, {3.f, 4.f, 7.f, 8.f}}); -#ifdef USE_COREML + RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider}, false, true); -#endif RunTest(axis, {}, input, outputs, {kTensorrtExecutionProvider}); } @@ -469,9 +457,8 @@ TEST(SplitOperatorTest, InvalidAxis) { 7.f, 8.f}}; outputs.push_back({{1}, {0.f}}); -#ifdef USE_COREML + RunTest(axis, {}, input, outputs, {}, true, true, -1, true, "Invalid value of attribute 'axis'"); -#endif RunTest(axis, {}, input, outputs, {}, true, false, -1, true, "Invalid value of attribute 'axis'"); } @@ -491,10 +478,9 @@ TEST(SplitOperatorTest, SplitAttributeSumTooSmall) { outputs.push_back({{1, 2}, {1.f, 2.f}}); outputs.push_back({{2, 2}, {3.f, 4.f, 5.f, 6.f}}); -#ifdef USE_COREML + RunTest(axis, splits, input, outputs, {kTensorrtExecutionProvider}, true, true, -1, true, "[ShapeInferenceError] Mismatch between the sum of 'split'"); -#endif RunTest(axis, splits, input, outputs, {kTensorrtExecutionProvider}, true, false, -1, true, "[ShapeInferenceError] Mismatch between the sum of 'split'"); // TensorRT parser: Assertion failed: axis != BATCH_DIM } @@ -514,10 +500,8 @@ TEST(SplitOperatorTest, InvalidValueInSplitAttribute) { outputs.push_back({{1, 2}, {1.f, 2.f}}); outputs.push_back({{3, 2}, {3.f, 4.f, 5.f, 6.f, 7.f, 8.f}}); -#ifdef USE_COREML RunTest(axis, splits, input, outputs, {kTensorrtExecutionProvider}, true, true, -1, true, "[ShapeInferenceError] Mismatch between number of splits"); -#endif RunTest(axis, splits, input, outputs, {kTensorrtExecutionProvider}, true, false, -1, true, "[ShapeInferenceError] Mismatch between number of splits"); // TensorRT parser: Assertion failed: axis != BATCH_DIM }