From d2fab1e128a99b56a81b4fce6a651378e0209370 Mon Sep 17 00:00:00 2001 From: zesongw Date: Wed, 20 Mar 2024 11:02:54 +0800 Subject: [PATCH 1/3] [WebNN EP] Support Split before opset13 --- .../webnn/builders/impl/split_op_builder.cc | 25 +++++++++---------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc index 9819e4ce7ac5b..e809e5200da01 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc @@ -28,8 +28,6 @@ class SplitOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - - int GetMinSupportedOpSet(const Node& node) const override; }; // Add operator related. @@ -57,12 +55,18 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, axis = SafeInt(HandleNegativeAxis(axis, rank)); options.set("axis", axis); - if (!GetTensorName(input_defs, 1).empty()) { - // Inputs contains optional 'split' input - std::vector splits; - const auto& initializers(model_builder.GetInitializerTensors()); - const auto& split_tensor = *initializers.at(input_defs[1]->Name()); - ORT_RETURN_IF_NOT(ReadIntArrayFrom1DTensor(split_tensor, splits, logger), "Cannot get split."); + // Inputs contains optional 'split' input or attribute. + if (!GetTensorName(input_defs, 1).empty() || node.SinceVersion() < 13) { + std::vector splits; + // Before opset13, split is an attribute. + if (node.SinceVersion() < 13) { + ORT_RETURN_IF_NOT(helper.HasAttr("split"), "Cannot get split."); + splits = helper.Get("split", std::vector{}); + } else { + const auto& initializers(model_builder.GetInitializerTensors()); + const auto& split_tensor = *initializers.at(input_defs[1]->Name()); + ORT_RETURN_IF_NOT(ReadIntArrayFrom1DTensor(split_tensor, splits, logger), "Cannot get split."); + } output_array = model_builder.GetBuilder().call("split", input, emscripten::val::array(splits), @@ -112,11 +116,6 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, // Operator support related. -int SplitOpBuilder::GetMinSupportedOpSet(const Node& /* node */) const { - // Since opset 13, Split has optional 'split' input. - return 13; -} - bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType /* device_type */, From 085318c8beec55c61fb07fddec0b1451687e5dcd Mon Sep 17 00:00:00 2001 From: zesongw Date: Fri, 22 Mar 2024 15:27:27 +0800 Subject: [PATCH 2/3] Resolve comment --- .../webnn/builders/impl/split_op_builder.cc | 109 ++++++++---------- 1 file changed, 45 insertions(+), 64 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc index e809e5200da01..9c03b6dd51aee 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc @@ -55,59 +55,35 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, axis = SafeInt(HandleNegativeAxis(axis, rank)); options.set("axis", axis); - // Inputs contains optional 'split' input or attribute. - if (!GetTensorName(input_defs, 1).empty() || node.SinceVersion() < 13) { - std::vector splits; - // Before opset13, split is an attribute. - if (node.SinceVersion() < 13) { - ORT_RETURN_IF_NOT(helper.HasAttr("split"), "Cannot get split."); - splits = helper.Get("split", std::vector{}); - } else { - const auto& initializers(model_builder.GetInitializerTensors()); - const auto& split_tensor = *initializers.at(input_defs[1]->Name()); - ORT_RETURN_IF_NOT(ReadIntArrayFrom1DTensor(split_tensor, splits, logger), "Cannot get split."); - } - output_array = model_builder.GetBuilder().call("split", - input, - emscripten::val::array(splits), - options); - ORT_RETURN_IF_NOT(output_array["length"].as() == static_cast(splits.size()), - "The size of outputs must be equal to the size of 'split' input."); + uint32_t split_count = 0; + std::vector splits = helper.Get("split", std::vector{}); + + // Read either the split count or explicit split lengths from the various attributes over opset versions. + if (helper.HasAttr("num_outputs")) { + split_count = helper.Get("num_outputs", 0); + } else if (GetTensorName(input_defs, 1).size()) { + const auto& initializers(model_builder.GetInitializerTensors()); + const auto& split_tensor = *initializers.at(input_defs[1]->Name()); + ORT_RETURN_IF_NOT(ReadIntArrayFrom1DTensor(split_tensor, splits, logger), "Cannot get input for split."); + } else if (!helper.HasAttr("split")) { + split_count = node.OutputDefs().size(); + } + + // Check that the splits evenly divide. + if (split_count > 0 && splits.empty() && input_shape[axis] % split_count != 0) { + // Divide inputs into variable size outputs: + splits.insert(splits.end(), split_count - 1, gsl::narrow(input_shape[axis]) / split_count); + splits.insert(splits.end(), gsl::narrow(input_shape[axis]) % split_count); + } + + if (splits.empty()) { + output_array = model_builder.GetBuilder().call( + "split", input, split_count, options); } else { - if (helper.HasAttr("num_outputs")) { - const int32_t num_outputs = helper.Get("num_outputs", 1); - ORT_RETURN_IF_NOT(num_outputs > 0, "The 'num_outputs' must be a positive integer."); - if (input_shape[axis] % num_outputs == 0) { - // The 'num_outputs' evenly divide the dim value at 'axis' specified. - output_array = model_builder.GetBuilder().call("split", - input, - num_outputs, - options); - } else { - std::vector mapping_split; - mapping_split.insert(mapping_split.begin(), num_outputs - 1, input_shape[axis] / num_outputs); - mapping_split.insert(mapping_split.end(), input_shape[axis] % num_outputs); - std::vector converted_splits = GetVecUint32FromVecInt64(mapping_split); - output_array = model_builder.GetBuilder().call("split", - input, - emscripten::val::array(converted_splits), - options); - } - ORT_RETURN_IF_NOT(output_array["length"].as() == num_outputs, - "The size of outputs must be equal to 'num_outputs'."); - } else { - // w/o 'split' input for opset 13 - // Refer to https://github.com/microsoft/onnxruntime/blob/a7ad859e3ab60bddfcf2fefa96bfcb550f0fc04c/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp#L984-L989 - // split input stream equally across output streams. - const auto& output_defs = node.OutputDefs(); - const size_t output_count = output_defs.size(); - output_array = model_builder.GetBuilder().call("split", - input, static_cast(output_count), - options); - ORT_RETURN_IF_NOT(output_array["length"].as() == output_count, - "The size of outputs must be equal to the count of output nodes."); - } + output_array = model_builder.GetBuilder().call( + "split", input, emscripten::val::array(splits), options); } + for (size_t i = 0, count = output_array["length"].as(); i < count; i++) { model_builder.AddOperand(node.OutputDefs()[i]->Name(), std::move(output_array[i])); } @@ -131,6 +107,7 @@ bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, NodeAttrHelper helper(node); int32_t axis = helper.Get("axis", 0); axis = SafeInt(HandleNegativeAxis(axis, rank)); + std::vector split = helper.Get("split", std::vector{}); const std::string split_name = GetTensorName(input_defs, 1); // Inputs contain optional 'split' input. @@ -140,7 +117,6 @@ bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return false; } // Values should be >= 0. Sum of the values must be equal to the dim value at 'axis' specified. - std::vector split; const auto& split_tensor = *initializers.at(input_defs[1]->Name()); if (split_tensor.data_type() != ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) { LOGS(logger, VERBOSE) << "The type of tensor's element data must be INT64."; @@ -150,18 +126,6 @@ bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, LOGS(logger, VERBOSE) << "Cannot get split."; return false; } - int64_t sum = 0; - for (size_t i = 0; i < split.size(); i++) { - if (split[i] < 0) { - LOGS(logger, VERBOSE) << "Value of split should be greater than or equal to 0."; - return false; - } - sum += split[i]; - } - if (sum != input_shape[axis]) { - LOGS(logger, VERBOSE) << "Sum of the split's values must be equal to the dim value at 'axis' specified."; - return false; - } } else { if (helper.HasAttr("num_outputs")) { // Split has 'num_outputs' attribute when opset is 18. @@ -178,6 +142,23 @@ bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, } } } + + if (split.size()) { + int64_t sum = 0; + // TODO: Allow 0 size dimensions. + // https://github.com/webmachinelearning/webnn/issues/391 + for (size_t i = 0; i < split.size(); i++) { + if (split[i] <= 0) { + LOGS(logger, VERBOSE) << "Value of split should be greater than 0."; + return false; + } + sum += split[i]; + } + if (sum != input_shape[axis]) { + LOGS(logger, VERBOSE) << "Sum of the split's values must be equal to the dim value at 'axis' specified."; + return false; + } + } return true; } From 43f0bd836e85cb5afe8dc230d1e69c1aadf23ac7 Mon Sep 17 00:00:00 2001 From: zesongw Date: Tue, 26 Mar 2024 13:22:18 +0800 Subject: [PATCH 3/3] Resolve comment --- .../providers/webnn/builders/impl/split_op_builder.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc index 9c03b6dd51aee..ea3b8ef384ddc 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc @@ -143,16 +143,16 @@ bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, } } - if (split.size()) { + if (!split.empty()) { int64_t sum = 0; // TODO: Allow 0 size dimensions. // https://github.com/webmachinelearning/webnn/issues/391 - for (size_t i = 0; i < split.size(); i++) { - if (split[i] <= 0) { + for (uint32_t split_value : split) { + if (split_value <= 0) { LOGS(logger, VERBOSE) << "Value of split should be greater than 0."; return false; } - sum += split[i]; + sum += split_value; } if (sum != input_shape[axis]) { LOGS(logger, VERBOSE) << "Sum of the split's values must be equal to the dim value at 'axis' specified.";