From a3883af7bfede84315abaa94fbc4cc2a0d2b02a3 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Thu, 1 Aug 2024 05:39:21 +0800 Subject: [PATCH] [WebNN EP] Fixed bug in ConvTranspose (#21569) The constraint of ConvTranspose was placed in wrong place. --- .../webnn/builders/impl/conv_op_builder.cc | 34 +++++++++---------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index 22049d2519712..76a8a178678df 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -28,7 +28,7 @@ class ConvOpBuilder : public BaseOpBuilder { // Operator support related. private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, - const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; + const WebnnDeviceType device_type, const logging::Logger& logger) const override; bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; }; @@ -378,6 +378,22 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return false; } + // WebNN CPU backend (TFLite) only supports default dilations and group. + // https://source.chromium.org/chromium/chromium/src/+/main:services/webnn/tflite/graph_builder_tflite.cc;l=1040 + if (device_type == WebnnDeviceType::CPU && op_type == "ConvTranspose") { + NodeAttrHelper helper(node); + const auto dilations = helper.Get("dilations", std::vector{1, 1}); + const auto group = helper.Get("group", 1); + if (dilations[0] != 1 || (dilations.size() > 1 && dilations[1] != 1)) { + LOGS(logger, VERBOSE) << op_type << " for WebNN CPU backend only supports default dilation 1."; + return false; + } + if (group != 1) { + LOGS(logger, VERBOSE) << op_type << " for WebNN CPU backend only supports default group 1."; + return false; + } + } + return true; } @@ -427,22 +443,6 @@ bool ConvOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceTy return false; } - // WebNN CPU backend (TFLite) only supports default dilations and group. - // https://source.chromium.org/chromium/chromium/src/+/main:services/webnn/tflite/graph_builder_tflite.cc;l=1040 - if (device_type == WebnnDeviceType::CPU && op_type == "ConvTranspose") { - NodeAttrHelper helper(node); - const auto dilations = helper.Get("dilations", std::vector{1, 1}); - const auto group = helper.Get("group", 1); - if (dilations[0] != 1 || (dilations.size() > 1 && dilations[1] != 1)) { - LOGS(logger, VERBOSE) << op_type << " for WebNN CPU backend only supports default dilation 1."; - return false; - } - if (group != 1) { - LOGS(logger, VERBOSE) << op_type << " for WebNN CPU backend only supports default group 1."; - return false; - } - } - return true; }