Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WebNN] Improve data type check of slice op #22988

Merged
merged 3 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 22 additions & 5 deletions onnxruntime/core/providers/webnn/builders/helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -177,14 +177,31 @@
if (!GetWebNNOpType(onnx_op_type, webnn_op_type))
return false;

return IsDataTypeSupportedByWebNNOp(onnx_op_type, webnn_op_type, onnx_data_type, wnn_limits,
webnn_input_output_name, onnx_input_output_name, logger);
}

bool IsDataTypeSupportedByWebNNOp(const std::string& onnx_op_type,
const std::string& webnn_op_type,
const int32_t onnx_data_type,
const emscripten::val& wnn_limits,
const std::string& webnn_input_output_name,
const std::string& onnx_input_output_name,
const logging::Logger& logger) {

Check warning on line 190 in onnxruntime/core/providers/webnn/builders/helper.cc

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webnn/builders/helper.cc:190: Add #include <string> for string [build/include_what_you_use] [4]
if (wnn_limits[webnn_op_type].isUndefined()) {
LOGS(logger, VERBOSE) << "[" << onnx_op_type << "] WebNN op [" << webnn_op_type << "] is not supported for now";
return false;
}
if (wnn_limits[webnn_op_type][webnn_input_output_name].isUndefined()) {
LOGS(logger, VERBOSE) << "[" << onnx_op_type << "] WebNN op [" << webnn_op_type << "] doesn't have parameter ["
<< webnn_input_output_name << "]";
return false;
}
if (!IsSupportedDataType(onnx_data_type, wnn_limits[webnn_op_type][webnn_input_output_name]["dataTypes"])) {
LOGS(logger, VERBOSE) << "[" << onnx_op_type
<< "] " << onnx_input_output_name
<< " type: [" << onnx_data_type
<< "] is not supported for now";
LOGS(logger, VERBOSE) << "[" << onnx_op_type << "] " << onnx_input_output_name << "'s data type: ["
<< onnx_data_type << "] is not supported by WebNN op [" << webnn_op_type << "] for now";
return false;
}

return true;
}

Expand Down
7 changes: 7 additions & 0 deletions onnxruntime/core/providers/webnn/builders/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,13 @@
const std::string& webnn_input_output_name,
const std::string& onnx_input_output_name,
const logging::Logger& logger);
bool IsDataTypeSupportedByWebNNOp(const std::string& onnx_op_type,
const std::string& webnn_op_type,
const int32_t onnx_data_type,
const emscripten::val& wnn_limits,
const std::string& webnn_input_output_name,
const std::string& onnx_input_output_name,
const logging::Logger& logger);

Check warning on line 348 in onnxruntime/core/providers/webnn/builders/helper.h

View workflow job for this annotation

GitHub Actions / Optional Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <string> for string [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/webnn/builders/helper.h:348: Add #include <string> for string [build/include_what_you_use] [4]

bool GetBidirectionalBroadcastShape(std::vector<int64_t>& shape_a,
std::vector<int64_t>& shape_b,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node&
bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType device_type, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
if (!HasSupportedInputs(node, wnn_limits, logger))
if (!HasSupportedInputs(initializers, node, wnn_limits, logger))
return false;

if (!HasSupportedOutputs(node, wnn_limits, logger))
Expand All @@ -41,7 +41,7 @@ bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, cons
return IsOpSupportedImpl(initializers, node, device_type, logger);
}

bool BaseOpBuilder::HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits,
bool BaseOpBuilder::HasSupportedInputs(const InitializedTensorSet& initializers, const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]");
for (const auto* input : node.InputDefs()) {
Expand All @@ -50,10 +50,10 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const emscripten::val&
}
}

return HasSupportedInputsImpl(node, wnn_limits, logger);
return HasSupportedInputsImpl(initializers, node, wnn_limits, logger);
}

bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node,
bool BaseOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& initializers, const Node& node,
const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
// We only check the type of input 0 by default, specific op builder can override this.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class BaseOpBuilder : public IOpBuilder {
return true;
}

virtual bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
virtual bool HasSupportedInputsImpl(const InitializedTensorSet& initializers, const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const;
virtual bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const;
Expand All @@ -53,7 +53,7 @@ class BaseOpBuilder : public IOpBuilder {

private:
bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const;
bool HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;
bool HasSupportedInputs(const InitializedTensorSet& initializers, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;
bool HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ class BinaryOpBuilder : public BaseOpBuilder {
// Operator support related.
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType device_type, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
};

// Add operator related.
Expand Down Expand Up @@ -86,8 +86,8 @@ bool BinaryOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers
return true;
}

bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
bool BinaryOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
int32_t input0_type;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ class CastOpBuilder : public BaseOpBuilder {

// Operator support related.
private:
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
};

// Add operator related.
Expand Down Expand Up @@ -86,8 +86,8 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
}

// Operator support related.
bool CastOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
bool CastOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
int32_t input_type;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ class ConcatOpBuilder : public BaseOpBuilder {
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;

// Operator support related.
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
};

// Add operator related.
Expand Down Expand Up @@ -55,8 +55,8 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
return Status::OK();
}

bool ConcatOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
bool ConcatOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
int32_t input0_type;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ class ConvOpBuilder : public BaseOpBuilder {
private:
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType device_type, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
};

void ConvOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
Expand Down Expand Up @@ -397,8 +397,8 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
return true;
}

bool ConvOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
bool ConvOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
int32_t input0_type; // input data type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ class EinsumOpBuilder : public BaseOpBuilder {
// Operator support related.
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
};

// Helper functions, thanks for DML EP's OperatorHelper.
Expand Down Expand Up @@ -735,8 +735,8 @@ bool EinsumOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ
return true;
}

bool EinsumOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
bool EinsumOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();

const auto& op_type = node.OpType();
Expand Down Expand Up @@ -776,11 +776,11 @@ bool EinsumOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten:
return false;
} else if (recognized_operator_type == RecognizedOperatorType::Pairwise) {
// Map to WebNN's gemm or matmul
return IsDataTypeSupportedByOp("MatMul", input0_type, wnn_limits, "a", "inputs", logger);
return IsDataTypeSupportedByWebNNOp(op_type, "matmul", input0_type, wnn_limits, "a", "inputs", logger);
} else if (recognized_operator_type == RecognizedOperatorType::ReduceSum) {
return IsDataTypeSupportedByOp("ReduceSum", input0_type, wnn_limits, "input", "inputs", logger);
return IsDataTypeSupportedByWebNNOp(op_type, "reduceSum", input0_type, wnn_limits, "input", "inputs", logger);
} else {
return IsDataTypeSupportedByOp("Identity", input0_type, wnn_limits, "input", "inputs", logger);
return IsDataTypeSupportedByWebNNOp(op_type, "identity", input0_type, wnn_limits, "input", "inputs", logger);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ class GatherElementsOpBuilder : public BaseOpBuilder {
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;

// Operator support related.
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
};

// Add operator related.
Expand Down Expand Up @@ -49,7 +49,8 @@ Status GatherElementsOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builde

// Operator support related.

bool GatherElementsOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
bool GatherElementsOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
const auto& data = *node.InputDefs()[0];
const auto& indices = *node.InputDefs()[1];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ class GatherNDOpBuilder : public BaseOpBuilder {
// Operator support related.
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
};

// Add operator related.
Expand Down Expand Up @@ -55,8 +55,8 @@ bool GatherNDOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initial
return true;
}

bool GatherNDOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
bool GatherNDOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& data = *node.InputDefs()[0];
const auto& indices = *node.InputDefs()[1];
const auto& op_type = node.OpType();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ class GatherOpBuilder : public BaseOpBuilder {
// Operator support related.
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
};

// Add operator related.
Expand Down Expand Up @@ -69,8 +69,8 @@ bool GatherOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ
return true;
}

bool GatherOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
bool GatherOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input = *node.InputDefs()[0];
const auto& indices = *node.InputDefs()[1];
const auto& op_type = node.OpType();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ class GemmOpBuilder : public BaseOpBuilder {
private:
bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
};

// Add operator related.
Expand Down Expand Up @@ -215,8 +215,8 @@ bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializer
return true;
}

bool GemmOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
bool GemmOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
int32_t input0_type; // A data type
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ class GruOpBuilder : public BaseOpBuilder {
private:
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType /*device_type*/, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const override;
};
Expand Down Expand Up @@ -187,8 +187,8 @@ bool GruOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, c
return true;
}

bool GruOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
const logging::Logger& logger) const {
bool GruOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
const auto& input_defs = node.InputDefs();
const auto& op_type = node.OpType();
int32_t input_X_type = 0; // input data type
Expand Down
Loading
Loading