Skip to content

Commit

Permalink
[WebNN EP] Support Split before opset13 (microsoft#19988)
Browse files Browse the repository at this point in the history
### Description
Support Split before opset13, where the `split` is an attribute.



### Motivation and Context
Support more models which use the earlier opset.
  • Loading branch information
zesongw authored and Ted Themistokleous committed May 7, 2024
1 parent 0a36e1d commit 5ad41c1
Showing 1 changed file with 43 additions and 63 deletions.
106 changes: 43 additions & 63 deletions onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@ class SplitOpBuilder : public BaseOpBuilder {
private:
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;

int GetMinSupportedOpSet(const Node& node) const override;
};

// Add operator related.
Expand Down Expand Up @@ -57,53 +55,35 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
axis = SafeInt<int32_t>(HandleNegativeAxis(axis, rank));
options.set("axis", axis);

if (!GetTensorName(input_defs, 1).empty()) {
// Inputs contains optional 'split' input
std::vector<int32_t> splits;
uint32_t split_count = 0;
std::vector<uint32_t> splits = helper.Get("split", std::vector<uint32_t>{});

// Read either the split count or explicit split lengths from the various attributes over opset versions.
if (helper.HasAttr("num_outputs")) {
split_count = helper.Get("num_outputs", 0);
} else if (GetTensorName(input_defs, 1).size()) {
const auto& initializers(model_builder.GetInitializerTensors());
const auto& split_tensor = *initializers.at(input_defs[1]->Name());
ORT_RETURN_IF_NOT(ReadIntArrayFrom1DTensor(split_tensor, splits, logger), "Cannot get split.");
output_array = model_builder.GetBuilder().call<emscripten::val>("split",
input,
emscripten::val::array(splits),
options);
ORT_RETURN_IF_NOT(output_array["length"].as<int32_t>() == static_cast<int32_t>(splits.size()),
"The size of outputs must be equal to the size of 'split' input.");
ORT_RETURN_IF_NOT(ReadIntArrayFrom1DTensor(split_tensor, splits, logger), "Cannot get input for split.");
} else if (!helper.HasAttr("split")) {
split_count = node.OutputDefs().size();
}

// Check that the splits evenly divide.
if (split_count > 0 && splits.empty() && input_shape[axis] % split_count != 0) {
// Divide inputs into variable size outputs:
splits.insert(splits.end(), split_count - 1, gsl::narrow<uint32_t>(input_shape[axis]) / split_count);
splits.insert(splits.end(), gsl::narrow<uint32_t>(input_shape[axis]) % split_count);
}

if (splits.empty()) {
output_array = model_builder.GetBuilder().call<emscripten::val>(
"split", input, split_count, options);
} else {
if (helper.HasAttr("num_outputs")) {
const int32_t num_outputs = helper.Get("num_outputs", 1);
ORT_RETURN_IF_NOT(num_outputs > 0, "The 'num_outputs' must be a positive integer.");
if (input_shape[axis] % num_outputs == 0) {
// The 'num_outputs' evenly divide the dim value at 'axis' specified.
output_array = model_builder.GetBuilder().call<emscripten::val>("split",
input,
num_outputs,
options);
} else {
std::vector<int64_t> mapping_split;
mapping_split.insert(mapping_split.begin(), num_outputs - 1, input_shape[axis] / num_outputs);
mapping_split.insert(mapping_split.end(), input_shape[axis] % num_outputs);
std::vector<uint32_t> converted_splits = GetVecUint32FromVecInt64(mapping_split);
output_array = model_builder.GetBuilder().call<emscripten::val>("split",
input,
emscripten::val::array(converted_splits),
options);
}
ORT_RETURN_IF_NOT(output_array["length"].as<int32_t>() == num_outputs,
"The size of outputs must be equal to 'num_outputs'.");
} else {
// w/o 'split' input for opset 13
// Refer to https://github.com/microsoft/onnxruntime/blob/a7ad859e3ab60bddfcf2fefa96bfcb550f0fc04c/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp#L984-L989
// split input stream equally across output streams.
const auto& output_defs = node.OutputDefs();
const size_t output_count = output_defs.size();
output_array = model_builder.GetBuilder().call<emscripten::val>("split",
input, static_cast<int32_t>(output_count),
options);
ORT_RETURN_IF_NOT(output_array["length"].as<size_t>() == output_count,
"The size of outputs must be equal to the count of output nodes.");
}
output_array = model_builder.GetBuilder().call<emscripten::val>(
"split", input, emscripten::val::array(splits), options);
}

for (size_t i = 0, count = output_array["length"].as<size_t>(); i < count; i++) {
model_builder.AddOperand(node.OutputDefs()[i]->Name(), std::move(output_array[i]));
}
Expand All @@ -112,11 +92,6 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,

// Operator support related.

int SplitOpBuilder::GetMinSupportedOpSet(const Node& /* node */) const {
// Since opset 13, Split has optional 'split' input.
return 13;
}

bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
const Node& node,
const WebnnDeviceType /* device_type */,
Expand All @@ -132,6 +107,7 @@ bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
NodeAttrHelper helper(node);
int32_t axis = helper.Get("axis", 0);
axis = SafeInt<int32_t>(HandleNegativeAxis(axis, rank));
std::vector<uint32_t> split = helper.Get("split", std::vector<uint32_t>{});

const std::string split_name = GetTensorName(input_defs, 1);
// Inputs contain optional 'split' input.
Expand All @@ -141,7 +117,6 @@ bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
return false;
}
// Values should be >= 0. Sum of the values must be equal to the dim value at 'axis' specified.
std::vector<int64_t> split;
const auto& split_tensor = *initializers.at(input_defs[1]->Name());
if (split_tensor.data_type() != ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
LOGS(logger, VERBOSE) << "The type of tensor's element data must be INT64.";
Expand All @@ -151,18 +126,6 @@ bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
LOGS(logger, VERBOSE) << "Cannot get split.";
return false;
}
int64_t sum = 0;
for (size_t i = 0; i < split.size(); i++) {
if (split[i] < 0) {
LOGS(logger, VERBOSE) << "Value of split should be greater than or equal to 0.";
return false;
}
sum += split[i];
}
if (sum != input_shape[axis]) {
LOGS(logger, VERBOSE) << "Sum of the split's values must be equal to the dim value at 'axis' specified.";
return false;
}
} else {
if (helper.HasAttr("num_outputs")) {
// Split has 'num_outputs' attribute when opset is 18.
Expand All @@ -179,6 +142,23 @@ bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
}
}
}

if (!split.empty()) {
int64_t sum = 0;
// TODO: Allow 0 size dimensions.
// https://github.com/webmachinelearning/webnn/issues/391
for (uint32_t split_value : split) {
if (split_value <= 0) {
LOGS(logger, VERBOSE) << "Value of split should be greater than 0.";
return false;
}
sum += split_value;
}
if (sum != input_shape[axis]) {
LOGS(logger, VERBOSE) << "Sum of the split's values must be equal to the dim value at 'axis' specified.";
return false;
}
}
return true;
}

Expand Down

0 comments on commit 5ad41c1

Please sign in to comment.