diff --git a/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc index 9819e4ce7ac5b..ea3b8ef384ddc 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc @@ -28,8 +28,6 @@ class SplitOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - - int GetMinSupportedOpSet(const Node& node) const override; }; // Add operator related. @@ -57,53 +55,35 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, axis = SafeInt(HandleNegativeAxis(axis, rank)); options.set("axis", axis); - if (!GetTensorName(input_defs, 1).empty()) { - // Inputs contains optional 'split' input - std::vector splits; + uint32_t split_count = 0; + std::vector splits = helper.Get("split", std::vector{}); + + // Read either the split count or explicit split lengths from the various attributes over opset versions. + if (helper.HasAttr("num_outputs")) { + split_count = helper.Get("num_outputs", 0); + } else if (GetTensorName(input_defs, 1).size()) { const auto& initializers(model_builder.GetInitializerTensors()); const auto& split_tensor = *initializers.at(input_defs[1]->Name()); - ORT_RETURN_IF_NOT(ReadIntArrayFrom1DTensor(split_tensor, splits, logger), "Cannot get split."); - output_array = model_builder.GetBuilder().call("split", - input, - emscripten::val::array(splits), - options); - ORT_RETURN_IF_NOT(output_array["length"].as() == static_cast(splits.size()), - "The size of outputs must be equal to the size of 'split' input."); + ORT_RETURN_IF_NOT(ReadIntArrayFrom1DTensor(split_tensor, splits, logger), "Cannot get input for split."); + } else if (!helper.HasAttr("split")) { + split_count = node.OutputDefs().size(); + } + + // Check that the splits evenly divide. + if (split_count > 0 && splits.empty() && input_shape[axis] % split_count != 0) { + // Divide inputs into variable size outputs: + splits.insert(splits.end(), split_count - 1, gsl::narrow(input_shape[axis]) / split_count); + splits.insert(splits.end(), gsl::narrow(input_shape[axis]) % split_count); + } + + if (splits.empty()) { + output_array = model_builder.GetBuilder().call( + "split", input, split_count, options); } else { - if (helper.HasAttr("num_outputs")) { - const int32_t num_outputs = helper.Get("num_outputs", 1); - ORT_RETURN_IF_NOT(num_outputs > 0, "The 'num_outputs' must be a positive integer."); - if (input_shape[axis] % num_outputs == 0) { - // The 'num_outputs' evenly divide the dim value at 'axis' specified. - output_array = model_builder.GetBuilder().call("split", - input, - num_outputs, - options); - } else { - std::vector mapping_split; - mapping_split.insert(mapping_split.begin(), num_outputs - 1, input_shape[axis] / num_outputs); - mapping_split.insert(mapping_split.end(), input_shape[axis] % num_outputs); - std::vector converted_splits = GetVecUint32FromVecInt64(mapping_split); - output_array = model_builder.GetBuilder().call("split", - input, - emscripten::val::array(converted_splits), - options); - } - ORT_RETURN_IF_NOT(output_array["length"].as() == num_outputs, - "The size of outputs must be equal to 'num_outputs'."); - } else { - // w/o 'split' input for opset 13 - // Refer to https://github.com/microsoft/onnxruntime/blob/a7ad859e3ab60bddfcf2fefa96bfcb550f0fc04c/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp#L984-L989 - // split input stream equally across output streams. - const auto& output_defs = node.OutputDefs(); - const size_t output_count = output_defs.size(); - output_array = model_builder.GetBuilder().call("split", - input, static_cast(output_count), - options); - ORT_RETURN_IF_NOT(output_array["length"].as() == output_count, - "The size of outputs must be equal to the count of output nodes."); - } + output_array = model_builder.GetBuilder().call( + "split", input, emscripten::val::array(splits), options); } + for (size_t i = 0, count = output_array["length"].as(); i < count; i++) { model_builder.AddOperand(node.OutputDefs()[i]->Name(), std::move(output_array[i])); } @@ -112,11 +92,6 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, // Operator support related. -int SplitOpBuilder::GetMinSupportedOpSet(const Node& /* node */) const { - // Since opset 13, Split has optional 'split' input. - return 13; -} - bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType /* device_type */, @@ -132,6 +107,7 @@ bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, NodeAttrHelper helper(node); int32_t axis = helper.Get("axis", 0); axis = SafeInt(HandleNegativeAxis(axis, rank)); + std::vector split = helper.Get("split", std::vector{}); const std::string split_name = GetTensorName(input_defs, 1); // Inputs contain optional 'split' input. @@ -141,7 +117,6 @@ bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return false; } // Values should be >= 0. Sum of the values must be equal to the dim value at 'axis' specified. - std::vector split; const auto& split_tensor = *initializers.at(input_defs[1]->Name()); if (split_tensor.data_type() != ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) { LOGS(logger, VERBOSE) << "The type of tensor's element data must be INT64."; @@ -151,18 +126,6 @@ bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, LOGS(logger, VERBOSE) << "Cannot get split."; return false; } - int64_t sum = 0; - for (size_t i = 0; i < split.size(); i++) { - if (split[i] < 0) { - LOGS(logger, VERBOSE) << "Value of split should be greater than or equal to 0."; - return false; - } - sum += split[i]; - } - if (sum != input_shape[axis]) { - LOGS(logger, VERBOSE) << "Sum of the split's values must be equal to the dim value at 'axis' specified."; - return false; - } } else { if (helper.HasAttr("num_outputs")) { // Split has 'num_outputs' attribute when opset is 18. @@ -179,6 +142,23 @@ bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, } } } + + if (!split.empty()) { + int64_t sum = 0; + // TODO: Allow 0 size dimensions. + // https://github.com/webmachinelearning/webnn/issues/391 + for (uint32_t split_value : split) { + if (split_value <= 0) { + LOGS(logger, VERBOSE) << "Value of split should be greater than 0."; + return false; + } + sum += split_value; + } + if (sum != input_shape[axis]) { + LOGS(logger, VERBOSE) << "Sum of the split's values must be equal to the dim value at 'axis' specified."; + return false; + } + } return true; }