From c63dd0234b4e0236b24fabdca005bbeb75ff4eb9 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Sat, 14 Sep 2024 12:36:20 +0800 Subject: [PATCH] [WebNN EP] Use opSupportLimits to dynamically check data type support (#22025) - Remove hard-coded data type checks and use WebNN's opSupportLimits instead - Add HasSupportedOutputsImpl for output data type validation - Get preferred layout info from opSupportLimits - Move Not op to logical_op_builder.cc because it should be there. This avoids the inconsistent input names in `unary_op_builder.cc`. --- .../core/providers/webnn/builders/helper.cc | 61 ++++++++++++++++--- .../core/providers/webnn/builders/helper.h | 43 +++++++++---- .../builders/impl/activation_op_builder.cc | 40 ------------ .../builders/impl/argmax_min_op_builder.cc | 27 -------- .../webnn/builders/impl/base_op_builder.cc | 52 ++++++++-------- .../webnn/builders/impl/base_op_builder.h | 9 ++- .../webnn/builders/impl/binary_op_builder.cc | 36 +++-------- .../webnn/builders/impl/cast_op_builder.cc | 32 +++++----- .../webnn/builders/impl/clip_op_builder.cc | 29 --------- .../webnn/builders/impl/concat_op_builder.cc | 28 +++++++++ .../webnn/builders/impl/conv_op_builder.cc | 35 +++-------- .../webnn/builders/impl/gather_op_builder.cc | 26 +++----- .../webnn/builders/impl/gemm_op_builder.cc | 35 +++-------- .../webnn/builders/impl/gru_op_builder.cc | 40 ++++-------- .../webnn/builders/impl/logical_op_builder.cc | 42 +++++++------ .../webnn/builders/impl/max_min_op_builder.cc | 29 ++++----- .../builders/impl/normalization_op_builder.cc | 35 ++++------- .../webnn/builders/impl/pad_op_builder.cc | 27 -------- .../builders/impl/reduction_op_builder.cc | 52 ---------------- .../webnn/builders/impl/resize_op_builder.cc | 26 -------- .../webnn/builders/impl/shape_op_builder.cc | 27 -------- .../webnn/builders/impl/slice_op_builder.cc | 26 -------- .../webnn/builders/impl/softmax_op_builder.cc | 26 -------- .../webnn/builders/impl/ternary_op_builder.cc | 23 ++----- 
.../builders/impl/transpose_op_builder.cc | 27 -------- .../webnn/builders/impl/unary_op_builder.cc | 43 ------------- .../providers/webnn/builders/model_builder.cc | 7 ++- .../providers/webnn/builders/model_builder.h | 5 +- .../providers/webnn/builders/op_builder.h | 3 +- .../webnn/builders/op_builder_factory.cc | 2 +- .../webnn/webnn_execution_provider.cc | 22 ++++--- .../webnn/webnn_execution_provider.h | 1 + 32 files changed, 281 insertions(+), 635 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index d3c1d06818db2..c4a633fcc92bb 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -45,12 +45,12 @@ bool GetShape(const NodeArg& node_arg, std::vector& shape, const loggin return true; } -bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, - const WebnnDeviceType device_type, const logging::Logger& logger) { +bool IsNodeSupported(const Node& node, const GraphViewer& graph_viewer, const WebnnDeviceType device_type, + const emscripten::val& wnn_limits, const logging::Logger& logger) { const auto& op_builders = GetOpBuilders(); if (Contains(op_builders, node.OpType())) { const auto* op_builder = op_builders.at(node.OpType()); - return op_builder->IsOpSupported(graph_viewer.GetAllInitializedTensors(), node, device_type, logger); + return op_builder->IsOpSupported(graph_viewer.GetAllInitializedTensors(), node, device_type, wnn_limits, logger); } else { return false; } @@ -86,6 +86,7 @@ bool IsInputSupported(const NodeArg& input, const std::string& parent_name, cons std::vector> GetSupportedNodes(const GraphViewer& graph_viewer, const emscripten::val& wnn_builder, const WebnnDeviceType device_type, + const emscripten::val& wnn_limits, const logging::Logger& logger) { std::vector> supported_node_groups; @@ -105,7 +106,7 @@ std::vector> GetSupportedNodes(const GraphViewer& graph_v // Firstly check 
if platform supports the WebNN op. if (CheckSingleOp(node->OpType(), wnn_builder, device_type)) { LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType() << "] is supported by browser"; - supported = IsNodeSupported(*node, graph_viewer, device_type, logger); + supported = IsNodeSupported(*node, graph_viewer, device_type, wnn_limits, logger); } LOGS(logger, VERBOSE) << "Operator type: [" << node->OpType() @@ -130,10 +131,54 @@ std::vector> GetSupportedNodes(const GraphViewer& graph_v return supported_node_groups; } -bool IsSupportedDataType(const int32_t data_type, - const std::unordered_set& supported_data_types) { - return std::find(supported_data_types.begin(), supported_data_types.end(), data_type) != - supported_data_types.end(); +bool AreInputDataTypesSame(const std::string& op_type, + gsl::span input_types, + const logging::Logger& logger) { + for (size_t i = 1; i < input_types.size(); i++) { + if (input_types[0] != input_types[i]) { + LOGS(logger, VERBOSE) << "[" << op_type + << "] Input data types should be the same, but [" + << input_types[0] << "] does not match " + << input_types[i] << "]."; + return false; + } + } + return true; +} + +bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& webnn_supported_data_types) { + auto it = onnx_to_webnn_data_type_map.find(static_cast(onnx_data_type)); + if (it == onnx_to_webnn_data_type_map.end()) + return false; + + std::string webnn_data_type = it->second; + + // Check if WebNN supports the data type. + emscripten::val is_supported = webnn_supported_data_types.call("includes", + emscripten::val(webnn_data_type)); + return is_supported.as(); +} + +// Check if the input or output data type of ONNX node is supported by the WebNN operator. 
+bool IsDataTypeSupportedByOp(const std::string& onnx_op_type, + const int32_t onnx_data_type, + const emscripten::val& wnn_limits, + const std::string& webnn_input_output_name, + const std::string& onnx_input_output_name, + const logging::Logger& logger) { + std::string webnn_op_type; + if (!GetWebNNOpType(onnx_op_type, webnn_op_type)) + return false; + + if (!IsSupportedDataType(onnx_data_type, wnn_limits[webnn_op_type][webnn_input_output_name]["dataTypes"])) { + LOGS(logger, VERBOSE) << "[" << onnx_op_type + << "] " << onnx_input_output_name + << " type: [" << onnx_data_type + << "] is not supported for now"; + return false; + } + + return true; } bool GetBidirectionalBroadcastShape(std::vector& shape_a, diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index b51092619db22..257fcff9ef50c 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -148,6 +148,7 @@ bool IsInputSupported(const NodeArg& node_arg, const std::string& parent_name, c std::vector> GetSupportedNodes(const GraphViewer& graph_viewer, const emscripten::val& wnn_builder, const WebnnDeviceType device_type, + const emscripten::val& wnn_limits, const logging::Logger& logger); static const InlinedHashMap op_map = { {"Abs", "abs"}, @@ -250,20 +251,38 @@ inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn return true; } -static const std::unordered_set webnn_supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_BOOL, - ONNX_NAMESPACE::TensorProto_DataType_INT8, - ONNX_NAMESPACE::TensorProto_DataType_UINT8, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_INT32, - ONNX_NAMESPACE::TensorProto_DataType_INT64, - ONNX_NAMESPACE::TensorProto_DataType_UINT32, - ONNX_NAMESPACE::TensorProto_DataType_UINT64, +inline bool GetWebNNOpType(const std::string& 
op_type, std::string& webnn_op_type) { + auto it = op_map.find(op_type); + // Returns false if the op_type is not listed in the op_map. + if (it == op_map.end()) { + return false; + } + webnn_op_type = it->second; + return true; +} + +static const InlinedHashMap onnx_to_webnn_data_type_map = { + {ONNX_NAMESPACE::TensorProto_DataType_BOOL, "uint8"}, + {ONNX_NAMESPACE::TensorProto_DataType_INT8, "int8"}, + {ONNX_NAMESPACE::TensorProto_DataType_UINT8, "uint8"}, + {ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, "float16"}, + {ONNX_NAMESPACE::TensorProto_DataType_FLOAT, "float32"}, + {ONNX_NAMESPACE::TensorProto_DataType_INT32, "int32"}, + {ONNX_NAMESPACE::TensorProto_DataType_INT64, "int64"}, + {ONNX_NAMESPACE::TensorProto_DataType_UINT32, "uint32"}, + {ONNX_NAMESPACE::TensorProto_DataType_UINT64, "uint64"}, }; -bool IsSupportedDataType(const int32_t data_type, - const std::unordered_set& supported_data_types); +bool AreInputDataTypesSame(const std::string& op_type, + gsl::span input_types, + const logging::Logger& logger); +bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& webnn_supported_data_types); +bool IsDataTypeSupportedByOp(const std::string& onnx_op_type, + const int32_t onnx_data_type, + const emscripten::val& wnn_limits, + const std::string& webnn_input_output_name, + const std::string& onnx_input_output_name, + const logging::Logger& logger); bool GetBidirectionalBroadcastShape(std::vector& shape_a, std::vector& shape_b, diff --git a/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc index 626aaf5c71b74..781ddcb896155 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/activation_op_builder.cc @@ -21,8 +21,6 @@ class ActivationOpBuilder : public BaseOpBuilder { // Operator support related. 
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, WebnnDeviceType device_type, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const override; }; // Add operator related. @@ -94,44 +92,6 @@ bool ActivationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initi return true; } -bool ActivationOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - std::unordered_set supported_data_types; - // WebNN relu op supports float32, float16, int32, int8 input data types. - if (op_type == "Relu") { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - ONNX_NAMESPACE::TensorProto_DataType_INT32, - ONNX_NAMESPACE::TensorProto_DataType_INT8, - }; - // WebNN CPU backend does not support int32 data type for relu. - if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT32); - } - } else { // Others only support float32 and float16. 
- supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - } - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateActivationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { if (op_registrations.op_builder_map.count(op_type) > 0) return; diff --git a/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc index 05f3a742a3775..d61ae1a1f6be7 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc @@ -22,8 +22,6 @@ class ArgMaxMinOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, WebnnDeviceType device_type, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const override; }; // Add operator related. @@ -77,31 +75,6 @@ bool ArgMaxMinOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initia return true; } -bool ArgMaxMinOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - std::unordered_set supported_data_types = webnn_supported_data_types; - // WebNN CPU backend doesn't support int64, uint64 input data types for argMax and argMin. 
- if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT64); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateArgMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { if (op_registrations.op_builder_map.count(op_type) > 0) return; diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc index fa535889299ea..8da255a288f17 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc @@ -38,9 +38,9 @@ bool HasExternalInitializer(const InitializedTensorSet& initializers, const Node Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const { ORT_RETURN_IF_NOT( - IsOpSupported(model_builder.GetInitializerTensors(), node, model_builder.GetWebnnDeviceType(), logger), - "Unsupported operator ", - node.OpType()); + IsOpSupported(model_builder.GetInitializerTensors(), node, model_builder.GetWebnnDeviceType(), + model_builder.GetOpSupportLimits(), logger), + "Unsupported operator ", node.OpType()); ORT_RETURN_IF_ERROR(AddToModelBuilderImpl(model_builder, node, logger)); LOGS(logger, VERBOSE) << "Operator name: [" << node.Name() << "] type: [" << node.OpType() << "] was added"; @@ -50,8 +50,12 @@ Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node& // Operator support related. 
bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, const Node& node, - const WebnnDeviceType device_type, const logging::Logger& logger) const { - if (!HasSupportedInputs(node, device_type, logger)) + const WebnnDeviceType device_type, const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + if (!HasSupportedInputs(node, wnn_limits, logger)) + return false; + + if (!HasSupportedOutputsImpl(node, wnn_limits, logger)) return false; // We do not support external initializers for now. @@ -64,7 +68,7 @@ bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, cons return IsOpSupportedImpl(initializers, node, device_type, logger); } -bool BaseOpBuilder::HasSupportedInputs(const Node& node, const WebnnDeviceType device_type, +bool BaseOpBuilder::HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]"); for (const auto* input : node.InputDefs()) { @@ -73,39 +77,33 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const WebnnDeviceType d } } - // WebNN CPU backend (TFLite) will enable float16 input data type soon, - // temporarily fallback float16 input data type for WebNN CPU. - if (device_type == WebnnDeviceType::CPU) { - const auto& input = *node.InputDefs()[0]; - - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - if (input_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) - return false; - } - - return HasSupportedInputsImpl(node, device_type, logger); + return HasSupportedInputsImpl(node, wnn_limits, logger); } bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node, - const WebnnDeviceType /* device_type */, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { // We only check the type of input 0 by default, specific op builder can override this. 
const auto& input = *node.InputDefs()[0]; - + const auto& op_type = node.OpType(); int32_t input_type; if (!GetType(input, input_type, logger)) return false; - if (!IsSupportedDataType(input_type, webnn_supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << node.OpType() - << "] Input type: [" << input_type - << "] is not supported for now"; + return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "Input", logger); +} + +bool BaseOpBuilder::HasSupportedOutputsImpl(const Node& node, + const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + // We only check the type of output 0 by default, specific op builder can override this. + const auto& output = *node.OutputDefs()[0]; + const auto& op_type = node.OpType(); + int32_t output_type; + if (!GetType(output, output_type, logger)) return false; - } - return true; + return IsDataTypeSupportedByOp(op_type, output_type, wnn_limits, "output", "Output", logger); } bool BaseOpBuilder::HasSupportedOpSet(const Node& node, diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h index 85e38b668cee4..584455f62cb4e 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h @@ -28,7 +28,8 @@ class BaseOpBuilder : public IOpBuilder { // Operator support related. 
public: bool IsOpSupported(const InitializedTensorSet& initializers, const Node& node, - const WebnnDeviceType device_type, const logging::Logger& logger) const override; + const WebnnDeviceType device_type, const emscripten::val& wnn_limits, + const logging::Logger& logger) const override; protected: virtual bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& /* node */, @@ -36,8 +37,10 @@ class BaseOpBuilder : public IOpBuilder { return true; } - virtual bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + virtual bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; + virtual bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const; // ONNX Runtime only *guarantees* support for models stamped // with opset version 7 or above for opset domain 'ai.onnx'. @@ -50,7 +53,7 @@ class BaseOpBuilder : public IOpBuilder { private: bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const; - bool HasSupportedInputs(const Node& node, const WebnnDeviceType device_type, const logging::Logger& logger) const; + bool HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; }; } // namespace webnn diff --git a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc index 555de68cd60fe..af82a01b14de5 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc @@ -22,7 +22,7 @@ class BinaryOpBuilder : public BaseOpBuilder { // Operator support related. 
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType device_type, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -86,7 +86,7 @@ bool BinaryOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers return true; } -bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, +bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); @@ -97,36 +97,14 @@ bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDevice !GetType(*input_defs[1], input1_type, logger)) return false; - std::unordered_set supported_data_types; - // WebNN prelu op only supports float32, float16, int32, int8 input data types. - if (op_type == "Prelu") { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - ONNX_NAMESPACE::TensorProto_DataType_INT32, - ONNX_NAMESPACE::TensorProto_DataType_INT8, - }; - // WebNN CPU backend doesn't support int32 for prelu. 
- if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT32); - } - } else { - supported_data_types = webnn_supported_data_types; - } - if (!IsSupportedDataType(input0_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input0_type - << "] is not supported for now"; + std::array input_types{input0_type, input1_type}; + if (!AreInputDataTypesSame(op_type, input_types, logger)) { return false; } - if (input0_type != input1_type) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input data types should be the same."; - return false; - } - - return true; + std::string webnn_input_name = op_type == "PRelu" ? "input" : "a"; + std::string onnx_input_name = op_type == "PRelu" || op_type == "Pow" ? "X" : "A"; + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, webnn_input_name, onnx_input_name, logger); } void CreateBinaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc index a08e1681a8464..3c4fc822f3d01 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc @@ -21,8 +21,8 @@ class CastOpBuilder : public BaseOpBuilder { // Operator support related. private: - bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, - const WebnnDeviceType device_type, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const override; }; // Add operator related. @@ -80,26 +80,22 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } // Operator support related. 
+bool CastOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& op_type = node.OpType(); + int32_t input_type; -bool CastOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, - const Node& node, - const WebnnDeviceType device_type, - const logging::Logger& logger) const { - NodeAttrHelper helper(node); - // Check cast output type. - const auto to_type = helper.Get("to", ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED); - - // WebNN CPU backend doesn't support casting to uint64 data type. - if (device_type == WebnnDeviceType::CPU && to_type == ONNX_NAMESPACE::TensorProto_DataType_UINT64) { - LOGS(logger, VERBOSE) << "Cast to uint64 is not supported for WebNN CPU backend."; + if (!GetType(*input_defs[0], input_type, logger)) return false; - } - if (!IsSupportedDataType(to_type, webnn_supported_data_types)) { - LOGS(logger, VERBOSE) << "WebNN doesn't support casting to type " << to_type << "."; + + if (!IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "input", logger)) return false; - } - return true; + NodeAttrHelper helper(node); + // Check cast to type. 
+ const auto to_type = helper.Get("to", ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED); + return IsDataTypeSupportedByOp(op_type, to_type, wnn_limits, "output", "to", logger); } void CreateCastOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc index b5c3206072d50..374143c886849 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/clip_op_builder.cc @@ -25,8 +25,6 @@ class ClipOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType device_type, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const override; }; // Add operator related. @@ -94,33 +92,6 @@ bool ClipOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, }; } -bool ClipOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - std::unordered_set supported_data_types = webnn_supported_data_types; - // WebNN CPU backend doesn't support int32, uint32, int64, uint64 input data types for clamp. 
- if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT32); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_INT64); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateClipOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc index dedc76b80e978..48dd6f3beb020 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc @@ -19,6 +19,10 @@ class ConcatOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; + + // Operator support related. + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const override; }; // Add operator related. 
@@ -52,6 +56,30 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } +bool ConcatOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& op_type = node.OpType(); + int32_t input0_type; + + if (!GetType(*input_defs[0], input0_type, logger)) + return false; + + for (size_t i = 1; i < input_defs.size(); i++) { + int32_t input_type; + if (!GetType(*input_defs[i], input_type, logger)) { + return false; + } + + std::array input_types{input0_type, input_type}; + if (!AreInputDataTypesSame(op_type, input_types, logger)) { + return false; + } + } + + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "inputs", "inputs", logger); +} + void CreateConcatOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index 76a8a178678df..35498c2e9b8b7 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -29,7 +29,7 @@ class ConvOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType device_type, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -397,7 +397,7 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return true; } -bool 
ConvOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, +bool ConvOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); @@ -415,35 +415,18 @@ bool ConvOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceTy return false; } - std::unordered_set supported_data_types; - if (op_type == "Conv" || op_type == "ConvTranspose") { - // WebNN conv2d and convTranspose2d only support float32 and float16 input data types. - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - } else if (op_type == "ConvInteger") { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_INT8, - ONNX_NAMESPACE::TensorProto_DataType_UINT8, - }; + InlinedVector input_types = {input0_type, input1_type}; + if (has_input2) { + input_types.push_back(input2_type); } - if (!IsSupportedDataType(input0_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input0_type - << "] is not supported for now"; - return false; + if (has_input3) { + input_types.push_back(input3_type); } - - if (input0_type != input1_type || - (has_input2 && input0_type != input2_type) || - (has_input3 && input0_type != input3_type)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input data types should be the same."; + if (!AreInputDataTypesSame(op_type, input_types, logger)) { return false; } - return true; + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger); } void CreateConvOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc index 23233539d34c7..ae9fe3e3f3bd1 100644 --- 
a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc @@ -22,7 +22,7 @@ class GatherOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -69,29 +69,19 @@ bool GatherOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ return true; } -bool GatherOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, +bool GatherOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input = *node.InputDefs()[0]; + const auto& indices = *node.InputDefs()[1]; const auto& op_type = node.OpType(); int32_t input_type; - if (!GetType(input, input_type, logger)) + int32_t indices_type; + if (!GetType(input, input_type, logger) || + !GetType(indices, indices_type, logger)) return false; - std::unordered_set supported_data_types = webnn_supported_data_types; - // WebNN CPU backend doesn't support uint32, uint64 input data types for gather. 
- if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; + return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "data", logger) && + IsDataTypeSupportedByOp(op_type, indices_type, wnn_limits, "indices", "indices", logger); } void CreateGatherOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc index bd452b118fe3e..30e024792ed42 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc @@ -25,7 +25,7 @@ class GemmOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -215,7 +215,7 @@ bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializer return true; } -bool GemmOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, +bool GemmOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); @@ -233,35 +233,18 @@ bool 
GemmOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceTy return false; } - std::unordered_set supported_data_types; - if (op_type == "Gemm" || op_type == "MatMul") { - // WebNN gemm and matmul only support float32 and float16 input data types. - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - } else if (op_type == "MatMulInteger") { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_INT8, - ONNX_NAMESPACE::TensorProto_DataType_UINT8, - }; + InlinedVector input_types = {input0_type, input1_type}; + if (has_input2) { + input_types.push_back(input2_type); } - if (!IsSupportedDataType(input0_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input0_type - << "] is not supported for now"; - return false; + if (has_input3) { + input_types.push_back(input3_type); } - - if (input0_type != input1_type || - (has_input2 && input0_type != input2_type) || - (has_input3 && input0_type != input3_type)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input data types should be the same."; + if (!AreInputDataTypesSame(op_type, input_types, logger)) { return false; } - return true; + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "A", logger); } void CreateGemmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc index 23cc7f1b11459..c92fe7366d494 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc @@ -26,7 +26,7 @@ class GruOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType /*device_type*/, const logging::Logger& logger) const override; 
- bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -185,7 +185,7 @@ bool GruOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, c return true; } -bool GruOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, +bool GruOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); @@ -208,37 +208,21 @@ bool GruOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceTyp return false; } - std::unordered_set supported_data_types; - if (device_type == WebnnDeviceType::CPU) { - // WebNN CPU backend only support float32 input data type. - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - }; - } else if (device_type == WebnnDeviceType::GPU) { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; + InlinedVector input_types = {input0_type, input1_type, input2_type}; + if (has_input3) { + input_types.push_back(input3_type); } - - if (!IsSupportedDataType(input0_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input0_type - << "] is not supported for now"; - return false; + if (has_input4) { + input_types.push_back(input4_type); } - - if (input0_type != input1_type || - input0_type != input2_type || - (has_input3 && input0_type != input3_type) || - (has_input4 && input0_type != input4_type) || - (has_input5 && input0_type != input5_type)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input data types should be the same."; + if (has_input5) { + input_types.push_back(input5_type); + } + if (!AreInputDataTypesSame(op_type, input_types, logger)) { return false; } 
- return true; + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger); } void CreateGruOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc index 23f3a938fee5e..ea7f70b4598e6 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc @@ -21,7 +21,7 @@ class LogicalOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -29,9 +29,14 @@ class LogicalOpBuilder : public BaseOpBuilder { Status LogicalOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& /* logger */) const { + const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); - emscripten::val input0 = model_builder.GetOperand(node.InputDefs()[0]->Name()); - emscripten::val input1 = model_builder.GetOperand(node.InputDefs()[1]->Name()); + emscripten::val input0 = model_builder.GetOperand(input_defs[0]->Name()); + emscripten::val input1 = emscripten::val::undefined(); + if (input_defs.size() > 1) { + input1 = model_builder.GetOperand(input_defs[1]->Name()); + } + emscripten::val output = emscripten::val::object(); emscripten::val options = emscripten::val::object(); options.set("label", node.Name()); @@ -45,6 +50,8 @@ Status LogicalOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons output = 
model_builder.GetBuilder().call("lesser", input0, input1, options); } else if (op_type == "LessOrEqual") { output = model_builder.GetBuilder().call("lesserOrEqual", input0, input1, options); + } else if (op_type == "Not") { + output = model_builder.GetBuilder().call("logicalNot", input0, options); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "LogicalOpBuilder::AddToModelBuilderImpl, unknown op: ", op_type); @@ -61,7 +68,7 @@ bool LogicalOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali const auto& name = node.Name(); const auto& op_type = node.OpType(); const auto& input_defs = node.InputDefs(); - if (input_defs.size() < 2) { + if (input_defs.size() < 2 && op_type != "Not") { LOGS(logger, VERBOSE) << op_type << " [" << name << "] requires at least 2 inputs, actual: " << input_defs.size(); return false; @@ -69,31 +76,27 @@ bool LogicalOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali return true; } -bool LogicalOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, +bool LogicalOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type; int32_t input1_type; - if (!GetType(*input_defs[0], input0_type, logger) || - !GetType(*input_defs[1], input1_type, logger)) - return false; - - if (!IsSupportedDataType(input0_type, webnn_supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input0_type - << "] is not supported for now"; + if (!GetType(*input_defs[0], input0_type, logger)) return false; - } - if (input0_type != input1_type) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input data types should be the same."; - return false; + if (op_type != "Not") { + if (!GetType(*input_defs[1], input1_type, logger)) + return false; + std::array input_types{input0_type, 
input1_type}; + if (!AreInputDataTypesSame(op_type, input_types, logger)) { + return false; + } } - return true; + std::string onnx_input_name = op_type == "Not" ? "X" : "A"; + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", onnx_input_name, logger); } void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { @@ -107,6 +110,7 @@ void CreateLogicalOpBuilder(const std::string& op_type, OpBuilderRegistrations& "GreaterOrEqual", "Less", "LessOrEqual", + "Not", }; op_registrations.builders.push_back(std::make_unique()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc index 5d88afda7b6a7..e111ca412c6e9 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc @@ -22,7 +22,7 @@ class MaxMinOpBuilder : public BaseOpBuilder { // Operator support related. 
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -87,31 +87,28 @@ bool MaxMinOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ return true; } -bool MaxMinOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, +bool MaxMinOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type; - int32_t input1_type; - if (!GetType(*input_defs[0], input0_type, logger) || - !GetType(*input_defs[1], input1_type, logger)) + if (!GetType(*input_defs[0], input0_type, logger)) return false; - if (!IsSupportedDataType(input0_type, webnn_supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input0_type - << "] is not supported for now"; - return false; - } + for (size_t i = 1; i < input_defs.size(); i++) { + int32_t input_type; + if (!GetType(*input_defs[i], input_type, logger)) { + return false; + } - if (input0_type != input1_type) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input data types should be the same."; - return false; + std::array input_types{input0_type, input_type}; + if (!AreInputDataTypesSame(op_type, input_types, logger)) { + return false; + } } - return true; + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "a", "data_0", logger); } void CreateMaxMinOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc 
b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc index 4d068baf35e72..a3c6b8fdcea9b 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc @@ -25,7 +25,7 @@ class NormalizationOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -182,7 +182,7 @@ bool NormalizationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initi return true; } -bool NormalizationOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, +bool NormalizationOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); @@ -203,30 +203,21 @@ bool NormalizationOpBuilder::HasSupportedInputsImpl(const Node& node, const Webn return false; } - // WebNN batchNormalization, instanceNormalization, layerNormalization - // only support float32 and float16 input data types. 
- std::unordered_set supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - - if (!IsSupportedDataType(input0_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input0_type - << "] is not supported for now"; - return false; + std::vector input_types = {input0_type, input1_type}; + if (has_input2) { + input_types.push_back(input2_type); } - - if (input0_type != input1_type || - (has_input2 && input0_type != input2_type) || - (has_input3 && input0_type != input3_type) || - (has_input4 && input0_type != input4_type)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input data types should be the same."; + if (has_input3) { + input_types.push_back(input3_type); + } + if (has_input4) { + input_types.push_back(input4_type); + } + if (!AreInputDataTypesSame(op_type, input_types, logger)) { return false; } - return true; + return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger); } void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc index 071155a2fb372..d8373a45e4423 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc @@ -28,8 +28,6 @@ class PadOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const override; }; // Add operator related. 
@@ -196,31 +194,6 @@ bool PadOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return true; } // namespace webnn -bool PadOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - std::unordered_set supported_data_types = webnn_supported_data_types; - // WebNN CPU backend doesn't support uint32, uint64 input data types for pad. - if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreatePadOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc index 3e6d4d9820e9a..93ad933d71c34 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/reduction_op_builder.cc @@ -31,8 +31,6 @@ class ReductionOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const override; }; // 
Add operator related. @@ -147,56 +145,6 @@ bool ReductionOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializ return true; } -bool ReductionOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - std::unordered_set supported_data_types; - if (op_type == "ReduceL1" || op_type == "ReduceProd" || - op_type == "ReduceSum" || op_type == "ReduceSumSquare") { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - ONNX_NAMESPACE::TensorProto_DataType_INT32, - ONNX_NAMESPACE::TensorProto_DataType_UINT32, - ONNX_NAMESPACE::TensorProto_DataType_INT64, - ONNX_NAMESPACE::TensorProto_DataType_UINT64, - }; - - if (device_type == WebnnDeviceType::CPU) { - // WebNN CPU backend doesn't support uint32 and uint64 for reduceL1, - // reduceProd, reduceSum and reduceSumSquare. - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } - } else if (op_type == "ReduceL2" || op_type == "ReduceLogSum" || - op_type == "ReduceLogSumExp" || op_type == "ReduceMean") { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - } else { // ReduceMax and ReduceMin - supported_data_types = webnn_supported_data_types; - // WebNN CPU backend doesn't support uint32, uint64 for reduceMax and reduceMin. 
- if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } - } - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateReductionOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { if (op_registrations.op_builder_map.count(op_type) > 0) return; diff --git a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc index 2218c858951d3..9dc79f4f52f46 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc @@ -35,8 +35,6 @@ class ResizeOpBuilder : public BaseOpBuilder { // Resize opset 10- is very different than Resize opset 11+, with many key attributes missing. // We only support Resize opset 11+ here. int GetMinSupportedOpSet(const Node& /* node */) const override { return 11; } - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const override; }; // Helper functions @@ -275,30 +273,6 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers return true; } -bool ResizeOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - // WebNN resample2d op only supports float32 and float16 input data types. 
- std::unordered_set supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateResizeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc index 0eb7dafdffe4d..6b56d2c740f40 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/shape_op_builder.cc @@ -18,11 +18,6 @@ class ShapeOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; - - // Operator support related. - private: - bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, - const WebnnDeviceType device_type, const logging::Logger& logger) const override; }; Status ShapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, @@ -69,28 +64,6 @@ Status ShapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -// Operator support related. 
- -bool ShapeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, - const Node& node, - const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const { - const auto& input_defs = node.InputDefs(); - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; - - int32_t output_type = ONNX_NAMESPACE::TensorProto_DataType_INT64; - if (!IsSupportedDataType(output_type, webnn_supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << node.OpType() - << "] Output type: [" << output_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateShapeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc index bef13841c646c..3f0d633ac888b 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc @@ -29,8 +29,6 @@ class SliceOpBuilder : public BaseOpBuilder { const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; // TODO: Support Slice opset < 10, which uses attributes for starts and ends. int GetMinSupportedOpSet(const Node& /* node */) const override { return 10; } - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const override; }; // Add operator related. 
@@ -166,30 +164,6 @@ bool SliceOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return true; } -bool SliceOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - std::unordered_set supported_data_types = webnn_supported_data_types; - // WebNN CPU backend doesn't support uint64 input data type for slice. - if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateSliceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc index 798cfabae65db..b1b737b114998 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/softmax_op_builder.cc @@ -24,8 +24,6 @@ class SoftmaxOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const override; }; Status SoftmaxOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, @@ -63,30 +61,6 @@ bool 
SoftmaxOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali return true; } -bool SoftmaxOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - // WebNN softmax only supports float32 and float16 input data types. - std::unordered_set supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateSoftmaxOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc index 2ed8330bf25be..4b6cf312074ba 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc @@ -18,7 +18,7 @@ class TernaryOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -46,7 +46,7 @@ Status TernaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons return Status::OK(); } -bool 
TernaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, +bool TernaryOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); @@ -59,27 +59,14 @@ bool TernaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDevic !GetType(*input_defs[2], input2_type, logger)) return false; - std::unordered_set supported_data_types = webnn_supported_data_types; - // WebNN CPU backend doesn't support uint64 X, Y data type for where. - if (device_type == WebnnDeviceType::CPU && op_type == "Where") { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } // ONNX's condition data type is bool which is same as WebNN. // Only need to check X, Y data types. - if (!IsSupportedDataType(input1_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input1_type - << "] is not supported for now"; - return false; - } - - if (input1_type != input2_type) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input X, Y data types should be the same."; + std::array input_types{input1_type, input2_type}; + if (!AreInputDataTypesSame(op_type, input_types, logger)) { return false; } - return true; + return IsDataTypeSupportedByOp(op_type, input1_type, wnn_limits, "trueValue", "X", logger); } void CreateTernaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc index 03c88ad9db88a..3a5e39f7f7a56 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/transpose_op_builder.cc @@ -18,8 +18,6 @@ class TransposeOpBuilder : public BaseOpBuilder { private: Status 
AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const override; }; // Add operator related. @@ -50,31 +48,6 @@ Status TransposeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -bool TransposeOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - std::unordered_set supported_data_types = webnn_supported_data_types; - // WebNN CPU backend doesn't support uint32, uint64 input data types for transpose. - if (device_type == WebnnDeviceType::CPU) { - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT32); - supported_data_types.erase(ONNX_NAMESPACE::TensorProto_DataType_UINT64); - } - - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateTransposeOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc index 061404c8a9ce0..8e64e98445f03 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/unary_op_builder.cc @@ -18,8 +18,6 @@ class UnaryOpBuilder : public BaseOpBuilder { private: Status 
AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const override; }; // Add operator related. @@ -51,8 +49,6 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const output = model_builder.GetBuilder().call("log", input, options); } else if (op_type == "Neg") { output = model_builder.GetBuilder().call("neg", input, options); - } else if (op_type == "Not") { - output = model_builder.GetBuilder().call("logicalNot", input, options); } else if (op_type == "Reciprocal") { output = model_builder.GetBuilder().call("reciprocal", input, options); } else if (op_type == "Sin") { @@ -70,44 +66,6 @@ Status UnaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const return Status::OK(); } -bool UnaryOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, - const logging::Logger& logger) const { - const auto& input = *node.InputDefs()[0]; - const auto& op_type = node.OpType(); - int32_t input_type; - if (!GetType(input, input_type, logger)) - return false; - - std::unordered_set supported_data_types; - if (op_type == "Identity") { - supported_data_types = webnn_supported_data_types; - } else if (op_type == "Abs" || op_type == "Neg") { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - ONNX_NAMESPACE::TensorProto_DataType_INT32, - ONNX_NAMESPACE::TensorProto_DataType_INT8, - }; - } else if (op_type == "Not") { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_BOOL, - }; - } else { // Others only support float32, float16 input data types. 
- supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - } - if (!IsSupportedDataType(input_type, supported_data_types)) { - LOGS(logger, VERBOSE) << "[" << op_type - << "] Input type: [" << input_type - << "] is not supported for now"; - return false; - } - - return true; -} - void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { if (op_registrations.op_builder_map.count(op_type) > 0) return; @@ -123,7 +81,6 @@ void CreateUnaryOpBuilder(const std::string& op_type, OpBuilderRegistrations& op "Identity", "Log", "Neg", - "Not", "Reciprocal", "Sin", "Sqrt", diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index 44bec1fb6fd48..b58bf8233692e 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -21,12 +21,13 @@ namespace webnn { ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, const emscripten::val& context, const DataLayout preferred_layout, - const WebnnDeviceType wnn_device_type) + const WebnnDeviceType wnn_device_type, const emscripten::val& wnn_limits) : graph_viewer_(graph_viewer), logger_(logger), wnn_context_(context), preferred_layout_(preferred_layout), - wnn_device_type_(wnn_device_type) { + wnn_device_type_(wnn_device_type), + wnn_limits_(wnn_limits) { // Create WebNN MLGraphBuilder for each ModelBuilder, because MLGraphBuilder.build() // is only allowed to be called once. 
wnn_builder_ = emscripten::val::global("MLGraphBuilder").new_(context); @@ -102,7 +103,7 @@ Status ModelBuilder::RegisterInitializers() { desc.set("dimensions", emscripten::val::array(dims)); auto data_type = tensor.data_type(); emscripten::val operand = emscripten::val::object(); - if (IsSupportedDataType(data_type, webnn_supported_data_types)) { + if (IsSupportedDataType(data_type, wnn_limits_["constant"]["dataTypes"])) { ORT_RETURN_IF_NOT(SetWebnnDataType(desc, data_type), "Unsupported data type"); auto num_elements = SafeInt(Product(shape)); emscripten::val view = emscripten::val::undefined(); diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.h b/onnxruntime/core/providers/webnn/builders/model_builder.h index 2d686070cdcc1..256337baeba7e 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.h +++ b/onnxruntime/core/providers/webnn/builders/model_builder.h @@ -23,7 +23,7 @@ class ModelBuilder { public: ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, const emscripten::val& context, const DataLayout preferred_layout, - const WebnnDeviceType wnn_device_type); + const WebnnDeviceType wnn_device_type, const emscripten::val& wnn_limits); ~ModelBuilder() = default; Status Compile(std::unique_ptr& model) ORT_MUST_USE_RESULT; @@ -35,6 +35,8 @@ class ModelBuilder { const emscripten::val& GetBuilder() const { return wnn_builder_; } const emscripten::val& GetContext() const { return wnn_context_; } const emscripten::val& GetOperand(const std::string& name) const { return wnn_operands_.at(name); } + const emscripten::val& GetOpSupportLimits() const { return wnn_limits_; } + void AddOperand(const std::string& name, const emscripten::val& operand); const emscripten::val& GetZeroConstant(const std::string& data_type); // Use the buffers to persist WebNN allocated data like transposed weight. 
@@ -66,6 +68,7 @@ class ModelBuilder { emscripten::val wnn_builder_ = emscripten::val::undefined(); DataLayout preferred_layout_; WebnnDeviceType wnn_device_type_; + emscripten::val wnn_limits_ = emscripten::val::undefined(); InlinedHashMap wnn_operands_; std::vector input_names_; std::vector output_names_; diff --git a/onnxruntime/core/providers/webnn/builders/op_builder.h b/onnxruntime/core/providers/webnn/builders/op_builder.h index 6ecc5d1068963..bb69a6a545597 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder.h +++ b/onnxruntime/core/providers/webnn/builders/op_builder.h @@ -29,7 +29,8 @@ class IOpBuilder { public: // Check if an operator is supported. virtual bool IsOpSupported(const InitializedTensorSet& initializers, const Node& node, - const WebnnDeviceType device_type, const logging::Logger& logger) const = 0; + const WebnnDeviceType device_type, const emscripten::val& wnn_limits, + const logging::Logger& logger) const = 0; }; } // namespace webnn diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc index 01761290f07e3..3dc1c7966ae41 100644 --- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc +++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc @@ -25,7 +25,6 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateUnaryOpBuilder("Identity", op_registrations); CreateUnaryOpBuilder("Log", op_registrations); CreateUnaryOpBuilder("Neg", op_registrations); - CreateUnaryOpBuilder("Not", op_registrations); CreateUnaryOpBuilder("Reciprocal", op_registrations); CreateUnaryOpBuilder("Sin", op_registrations); CreateUnaryOpBuilder("Sqrt", op_registrations); @@ -118,6 +117,7 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() { CreateLogicalOpBuilder("GreaterOrEqual", op_registrations); CreateLogicalOpBuilder("Less", op_registrations); CreateLogicalOpBuilder("LessOrEqual", op_registrations); + 
CreateLogicalOpBuilder("Not", op_registrations); } { // Max/Min diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index b918daf838c99..b729623c5d3d8 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -21,10 +21,8 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f : IExecutionProvider{onnxruntime::kWebNNExecutionProvider} { // WebNN EP uses NHWC layout for CPU XNNPACK backend and NCHW for GPU DML backend. if (webnn_device_flags.compare("cpu") == 0) { - preferred_layout_ = DataLayout::NHWC; wnn_device_type_ = webnn::WebnnDeviceType::CPU; } else { - preferred_layout_ = DataLayout::NCHW; if (webnn_device_flags.compare("gpu") == 0) { wnn_device_type_ = webnn::WebnnDeviceType::GPU; } else if (webnn_device_flags.compare("npu") == 0) { @@ -38,6 +36,17 @@ WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_f if (!wnn_context_.as()) { ORT_THROW("Failed to create WebNN context."); } + + // Retrieve the level of support for different WebNN operators. + // This varies across implementations and is obtained via the WebNN's opSupportLimits() function. 
+ // https://www.w3.org/TR/webnn/#api-mlcontext-opsupportlimits + wnn_limits_ = wnn_context_.call("opSupportLimits"); + + if (wnn_limits_["preferredInputLayout"].as().compare("nhwc") == 0) { + preferred_layout_ = DataLayout::NHWC; + } else { + preferred_layout_ = DataLayout::NCHW; + } } WebNNExecutionProvider::~WebNNExecutionProvider() {} @@ -82,7 +91,7 @@ WebNNExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_view ORT_THROW("Failed to create WebNN builder."); } - const auto node_groups = webnn::GetSupportedNodes(graph_viewer, wnn_builder, wnn_device_type_, logger); + const auto node_groups = webnn::GetSupportedNodes(graph_viewer, wnn_builder, wnn_device_type_, wnn_limits_, logger); wnn_builder = emscripten::val::undefined(); if (node_groups.empty()) { @@ -213,7 +222,7 @@ common::Status WebNNExecutionProvider::Compile(const std::vector model; ORT_RETURN_IF_ERROR(builder.Compile(model)); @@ -295,11 +304,6 @@ common::Status WebNNExecutionProvider::Compile(const std::vector