[WebNN EP] Support Split before opset13 #19988

Merged · 3 commits · Mar 26, 2024
Changes from 2 commits
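
This change lets the WebNN EP handle ONNX Split at opsets earlier than 13. The relevant detail is where the split sizes come from: through opset 12 they are carried by a `split` attribute, opset 13 moves them to an optional `split` input (an initializer), opset 18 adds a `num_outputs` attribute, and when no sizes are given the input is divided evenly across the node's declared outputs. The snippet below is a rough standalone sketch of that dispatch, not the PR's code; the names and the uneven-division fallback (remainder goes to the last chunk) are illustrative only.

```cpp
#include <cstdint>
#include <optional>
#include <vector>

// Illustrative sketch (not the PR's code) of deriving a split "plan".
// An empty 'sizes' means "split evenly into 'count' pieces".
struct SplitPlan {
  uint32_t count = 0;
  std::vector<uint32_t> sizes;
};

SplitPlan DeriveSplitPlan(const std::optional<std::vector<uint32_t>>& split_attr,   // opset <= 12
                          const std::optional<std::vector<uint32_t>>& split_input,  // opset >= 13
                          const std::optional<uint32_t>& num_outputs_attr,          // opset >= 18
                          size_t output_count, int64_t dim_on_axis) {
  SplitPlan plan;
  if (split_input) {
    plan.sizes = *split_input;        // explicit sizes from the 'split' input
  } else if (split_attr) {
    plan.sizes = *split_attr;         // explicit sizes from the 'split' attribute
  } else if (num_outputs_attr) {
    plan.count = *num_outputs_attr;   // requested number of outputs
  } else {
    plan.count = static_cast<uint32_t>(output_count);  // split evenly across outputs
  }
  // If a count does not divide the axis dimension evenly, fall back to explicit
  // sizes and give whatever remains to the last chunk.
  if (plan.sizes.empty() && plan.count > 0 && dim_on_axis % plan.count != 0) {
    const uint32_t base = static_cast<uint32_t>(dim_on_axis / plan.count);
    plan.sizes.assign(plan.count - 1, base);
    plan.sizes.push_back(static_cast<uint32_t>(dim_on_axis) - base * (plan.count - 1));
  }
  return plan;
}
```

With a plan like this, the builder only ever needs the two WebNN call shapes that appear in the diff below: split by count, or split by explicit sizes.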
106 changes: 43 additions & 63 deletions onnxruntime/core/providers/webnn/builders/impl/split_op_builder.cc
@@ -28,8 +28,6 @@ class SplitOpBuilder : public BaseOpBuilder {
private:
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;

int GetMinSupportedOpSet(const Node& node) const override;
};

// Add operator related.
@@ -57,53 +55,35 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
axis = SafeInt<int32_t>(HandleNegativeAxis(axis, rank));
options.set("axis", axis);

if (!GetTensorName(input_defs, 1).empty()) {
// Inputs contain the optional 'split' input
std::vector<int32_t> splits;
uint32_t split_count = 0;
std::vector<uint32_t> splits = helper.Get("split", std::vector<uint32_t>{});

// Read either the split count or explicit split lengths from the various attributes over opset versions.
if (helper.HasAttr("num_outputs")) {
split_count = helper.Get("num_outputs", 0);
} else if (GetTensorName(input_defs, 1).size()) {
const auto& initializers(model_builder.GetInitializerTensors());
const auto& split_tensor = *initializers.at(input_defs[1]->Name());
ORT_RETURN_IF_NOT(ReadIntArrayFrom1DTensor(split_tensor, splits, logger), "Cannot get split.");
output_array = model_builder.GetBuilder().call<emscripten::val>("split",
input,
emscripten::val::array(splits),
options);
ORT_RETURN_IF_NOT(output_array["length"].as<int32_t>() == static_cast<int32_t>(splits.size()),
"The size of outputs must be equal to the size of 'split' input.");
ORT_RETURN_IF_NOT(ReadIntArrayFrom1DTensor(split_tensor, splits, logger), "Cannot get input for split.");
} else if (!helper.HasAttr("split")) {
split_count = node.OutputDefs().size();
}

// Check that the splits evenly divide.
if (split_count > 0 && splits.empty() && input_shape[axis] % split_count != 0) {
// Divide inputs into variable size outputs:
splits.insert(splits.end(), split_count - 1, gsl::narrow<uint32_t>(input_shape[axis]) / split_count);
splits.insert(splits.end(), gsl::narrow<uint32_t>(input_shape[axis]) % split_count);
}

if (splits.empty()) {
output_array = model_builder.GetBuilder().call<emscripten::val>(
"split", input, split_count, options);
} else {
if (helper.HasAttr("num_outputs")) {
const int32_t num_outputs = helper.Get("num_outputs", 1);
ORT_RETURN_IF_NOT(num_outputs > 0, "The 'num_outputs' must be a positive integer.");
if (input_shape[axis] % num_outputs == 0) {
// The 'num_outputs' evenly divide the dim value at 'axis' specified.
output_array = model_builder.GetBuilder().call<emscripten::val>("split",
input,
num_outputs,
options);
} else {
std::vector<int64_t> mapping_split;
mapping_split.insert(mapping_split.begin(), num_outputs - 1, input_shape[axis] / num_outputs);
mapping_split.insert(mapping_split.end(), input_shape[axis] % num_outputs);
std::vector<uint32_t> converted_splits = GetVecUint32FromVecInt64(mapping_split);
output_array = model_builder.GetBuilder().call<emscripten::val>("split",
input,
emscripten::val::array(converted_splits),
options);
}
ORT_RETURN_IF_NOT(output_array["length"].as<int32_t>() == num_outputs,
"The size of outputs must be equal to 'num_outputs'.");
} else {
// w/o 'split' input for opset 13
// Refer to https://github.com/microsoft/onnxruntime/blob/a7ad859e3ab60bddfcf2fefa96bfcb550f0fc04c/onnxruntime/core/providers/dml/OperatorAuthorHelper/OperatorHelper.cpp#L984-L989
// split input stream equally across output streams.
const auto& output_defs = node.OutputDefs();
const size_t output_count = output_defs.size();
output_array = model_builder.GetBuilder().call<emscripten::val>("split",
input, static_cast<int32_t>(output_count),
options);
ORT_RETURN_IF_NOT(output_array["length"].as<size_t>() == output_count,
"The size of outputs must be equal to the count of output nodes.");
}
output_array = model_builder.GetBuilder().call<emscripten::val>(
"split", input, emscripten::val::array(splits), options);
}

for (size_t i = 0, count = output_array["length"].as<size_t>(); i < count; i++) {
model_builder.AddOperand(node.OutputDefs()[i]->Name(), std::move(output_array[i]));
}
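
For context, WebNN's `MLGraphBuilder.split()` takes either a single unsigned integer (the number of equal-sized outputs) or a sequence of per-output sizes along the chosen axis, which is why the builder code above reduces to exactly two call shapes. A minimal Embind sketch of those two calls, assuming `builder`, `input`, and `options` are `emscripten::val` handles obtained elsewhere (the function names are illustrative, not from the PR):

```cpp
#include <emscripten/val.h>

#include <cstdint>
#include <vector>

// Sketch only: 'builder' wraps an MLGraphBuilder, 'input' an MLOperand, and
// 'options' a JS object carrying the "axis" option; all are created elsewhere.
emscripten::val SplitEvenly(emscripten::val builder, emscripten::val input,
                            uint32_t count, emscripten::val options) {
  // split(input, 3, {axis: ...}) -> JS array of 3 equal-sized operands
  return builder.call<emscripten::val>("split", input, count, options);
}

emscripten::val SplitBySizes(emscripten::val builder, emscripten::val input,
                             const std::vector<uint32_t>& sizes, emscripten::val options) {
  // split(input, [2, 2, 3], {axis: ...}) -> one operand per listed size
  return builder.call<emscripten::val>("split", input,
                                       emscripten::val::array(sizes), options);
}
```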
@@ -112,11 +92,6 @@ Status SplitOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,

// Operator support related.

int SplitOpBuilder::GetMinSupportedOpSet(const Node& /* node */) const {
// Since opset 13, Split has optional 'split' input.
return 13;
}

bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
const Node& node,
const WebnnDeviceType /* device_type */,
@@ -132,6 +107,7 @@ bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
NodeAttrHelper helper(node);
int32_t axis = helper.Get("axis", 0);
axis = SafeInt<int32_t>(HandleNegativeAxis(axis, rank));
std::vector<uint32_t> split = helper.Get("split", std::vector<uint32_t>{});

const std::string split_name = GetTensorName(input_defs, 1);
// Inputs contain optional 'split' input.
@@ -141,7 +117,6 @@ bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
return false;
}
// Values should be >= 0. Sum of the values must be equal to the dim value at 'axis' specified.
std::vector<int64_t> split;
const auto& split_tensor = *initializers.at(input_defs[1]->Name());
if (split_tensor.data_type() != ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64) {
LOGS(logger, VERBOSE) << "The type of tensor's element data must be INT64.";
@@ -151,18 +126,6 @@ bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
LOGS(logger, VERBOSE) << "Cannot get split.";
return false;
}
int64_t sum = 0;
for (size_t i = 0; i < split.size(); i++) {
if (split[i] < 0) {
LOGS(logger, VERBOSE) << "Value of split should be greater than or equal to 0.";
return false;
}
sum += split[i];
}
if (sum != input_shape[axis]) {
LOGS(logger, VERBOSE) << "Sum of the split's values must be equal to the dim value at 'axis' specified.";
return false;
}
} else {
if (helper.HasAttr("num_outputs")) {
// Split has 'num_outputs' attribute when opset is 18.
@@ -179,6 +142,23 @@ bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
}
}
}

if (split.size()) {
int64_t sum = 0;
// TODO: Allow 0 size dimensions.
// https://github.com/webmachinelearning/webnn/issues/391
for (size_t i = 0; i < split.size(); i++) {
if (split[i] <= 0) {
LOGS(logger, VERBOSE) << "Value of split should be greater than 0.";
return false;
}
sum += split[i];
}
if (sum != input_shape[axis]) {
LOGS(logger, VERBOSE) << "Sum of the split's values must be equal to the dim value at 'axis' specified.";
return false;
}
}
return true;
}
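
The support check above enforces two constraints on an explicit `split` list: every entry must be strictly positive (zero-sized dimensions are not yet allowed, see webmachinelearning/webnn#391), and the entries must sum to the dimension being split. A self-contained sketch of that rule, separate from the PR's helper:

```cpp
#include <cstdint>
#include <vector>

// Returns true when 'split' is a valid partition of 'dim_on_axis':
// all sizes strictly positive and summing to the axis dimension.
bool IsValidSplit(const std::vector<uint32_t>& split, int64_t dim_on_axis) {
  if (split.empty()) return true;  // no explicit sizes, nothing to check here
  int64_t sum = 0;
  for (uint32_t s : split) {
    if (s == 0) return false;  // zero-sized chunks are rejected for now
    sum += s;
  }
  return sum == dim_on_axis;
}
```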
