Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WebNN] Ignore empty optional input tensor #19235

Merged
merged 3 commits into from
Jan 31, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/webnn/builders/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ std::string GetShapeString(std::vector<T>& shape) {
return shape_info.str();
}

inline std::string GetTensorName(const ConstPointerContainer<std::vector<NodeArg *>>& input_defs, const size_t index) {
Honry marked this conversation as resolved.
Show resolved Hide resolved
return (input_defs.size() > index) ? std::string(input_defs[index]->Name()) : "";
}

inline std::vector<uint32_t> GetVecUint32FromVecInt64(const std::vector<int64_t>& int64_vec) {
std::vector<uint32_t> uint32_vec;
uint32_vec.reserve(int64_vec.size());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,10 @@ bool PadOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
return false;
}
for (size_t i = 1; i < input_defs.size(); i++) {
if (!Contains(initializers, input_defs[i]->Name())) {
LOGS(logger, VERBOSE) << "Input [" << input_defs[i]->Name() << "] must be known as initializer";
// Optional tensors (constant_value, axes) can be indicated by an empty name, just ignore it.
const std::string input_name = GetTensorName(input_defs, i);
if (!input_name.empty() && !Contains(initializers, input_name)) {
LOGS(logger, VERBOSE) << "Input [" << input_name << "] must be known as initializer";
return false;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,9 @@ bool ReductionOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializ
return false;

const auto& op_type = node.OpType();
const std::string axes_name = GetTensorName(input_defs, 1);
// If the optional input 'axes' is provided, it must be an initializer.
if (input_defs.size() > 1 && !Contains(initializers, input_defs[1]->Name())) {
if (!axes_name.empty() && !Contains(initializers, axes_name)) {
LOGS(logger, VERBOSE) << "Input axes of " << op_type << " must be a constant";
return false;
}
Expand Down
32 changes: 20 additions & 12 deletions onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,16 +120,17 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
std::vector<float> scales_hw;
std::vector<int32_t> sizes_hw;
std::vector<int32_t> axes;
std::string scales_name = GetTensorName(input_defs, 2);
const bool is_nhwc = model_builder.GetPreferredLayout() == DataLayout::NHWC;
if (input_defs.size() == 3) { // Use scales.
if (!scales_name.empty()) { // Use scales.
ORT_RETURN_IF_NOT(GetResizeScales(initializers, node, scales, logger), "Error getting resize scales");
if (is_nhwc) {
scales_hw = {scales[1], scales[2]};
} else {
scales_hw = {scales[2], scales[3]};
}
options.set("scales", emscripten::val::array(scales_hw));
} else { // We already checked number of inputs in IsOpSupportedImpl.
} else { // Use sizes, we already checked inputs in IsOpSupportedImpl.
std::vector<int64_t> output_sizes;
ORT_RETURN_IF_NOT(GetResizeOutputSizes(initializers, node, output_sizes, logger),
"Error getting resize output_sizes");
Expand Down Expand Up @@ -203,26 +204,31 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers
}

{ // scales and sizes (if present) must be initializers.
if (input_defs.size() < 3) {
LOGS(logger, VERBOSE) << "Input scales or sizes of Resize must be known";
return false;
}
const std::string scales_name = GetTensorName(input_defs, 2);
const std::string sizes_name = GetTensorName(input_defs, 3);

// scales
if (input_defs.size() == 3 && !Contains(initializers, input_defs[2]->Name())) {
// scales (scales may be empty tensor)
bool has_scales = !scales_name.empty();
if ((has_scales && !Contains(initializers, scales_name)) || (!has_scales && node.SinceVersion() == 11)) {
LOGS(logger, VERBOSE) << "Input scales of Resize must be known";
return false;
}

// sizes
if (input_defs.size() > 3 && !Contains(initializers, input_defs[3]->Name())) {
// sizes (sizes may be empty tensor)
bool has_sizes = !sizes_name.empty();
if (has_sizes && !Contains(initializers, sizes_name)) {
LOGS(logger, VERBOSE) << "Input sizes of Resize must be known";
return false;
}

if (has_scales && has_sizes) {
LOGS(logger, VERBOSE) << "Only one of 'scales' and 'sizes' can be specified";
return false;
}

const bool is_nhwc = node.Domain() == kMSInternalNHWCDomain;
// We want to check if the scales or sizes are not trying to resize on N/C channels here.
if (input_defs.size() == 3) { // We are using scales.
if (has_scales) { // We are using scales.
std::vector<float> scales;
if (!GetResizeScales(initializers, node, scales, logger))
return false;
Expand Down Expand Up @@ -251,7 +257,9 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers
LOGS(logger, VERBOSE) << "Resize: scale_w: " << scale_w << " is not a whole number";
return false;
}
} else {
}

if (has_sizes) {
// We are using sizes.
std::vector<int64_t> output_sizes;
if (!GetResizeOutputSizes(initializers, node, output_sizes, logger))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,10 @@ bool SliceOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,

// Inputs: starts, ends, axes, and steps must be constant initializers if present.
for (size_t i = 1; i < input_defs.size(); i++) {
if (!Contains(initializers, input_defs[i]->Name())) {
LOGS(logger, VERBOSE) << "Input [" << input_defs[i]->Name() << "] of " << op_type
// Optional tensors (axes, steps) can be indicated by an empty name, just ignore it.
const std::string input_name = GetTensorName(input_defs, i);
if (!input_name.empty() && !Contains(initializers, input_name)) {
LOGS(logger, VERBOSE) << "Input [" << input_name << "] of " << op_type
<< " [" << name << "] must be known as initializer";
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,9 @@ bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
int32_t axis = helper.Get("axis", 0);
axis = SafeInt<int32_t>(HandleNegativeAxis(axis, rank));

if (input_defs.size() == 2) {
// Inputs contains optional 'split' input
const auto& split_name = input_defs[1]->Name();
const std::string split_name = GetTensorName(input_defs, 1);
// Inputs contain optional 'split' input.
if (!split_name.empty()) {
if (!Contains(initializers, split_name)) {
LOGS(logger, VERBOSE) << "The split must be a constant initializer.";
return false;
Expand Down Expand Up @@ -166,7 +166,7 @@ bool SplitOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
LOGS(logger, VERBOSE) << "Sum of the split's values must be equal to the dim value at 'axis' specified.";
return false;
}
} else if (input_defs.size() == 1) {
} else {
if (helper.HasAttr("num_outputs")) {
// Split has 'num_outputs' attribute when opset is 18.
const int32_t num_outputs = helper.Get("num_outputs", 1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ bool SqueezeUnsqueezeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& in

// Squeeze/Unsqueeze opset 13 uses input 1 as axes, it needs to be an initializer.
if (node.SinceVersion() >= 13) {
if (input_defs.size() > 1) {
const auto& axes_name = input_defs[1]->Name();
const std::string axes_name = GetTensorName(input_defs, 1);
if (!axes_name.empty()) {
if (!Contains(initializers, axes_name)) {
LOGS(logger, ERROR) << "Input axes of " << op_type << " is not present and constant";
return false;
Expand Down
Loading