Skip to content

Commit

Permalink
[WebNN EP] Fixed bug in ConvTranspose (microsoft#21569)
Browse files Browse the repository at this point in the history
The constraint of ConvTranspose was placed in wrong place.
  • Loading branch information
Honry authored Jul 31, 2024
1 parent c5f8389 commit a3883af
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class ConvOpBuilder : public BaseOpBuilder {
// Operator support related.
private:
bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
const WebnnDeviceType device_type, const logging::Logger& logger) const override;
bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType /* device_type */,
const logging::Logger& logger) const override;
};
Expand Down Expand Up @@ -378,6 +378,22 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
return false;
}

// WebNN CPU backend (TFLite) only supports default dilations and group.
// https://source.chromium.org/chromium/chromium/src/+/main:services/webnn/tflite/graph_builder_tflite.cc;l=1040
if (device_type == WebnnDeviceType::CPU && op_type == "ConvTranspose") {
NodeAttrHelper helper(node);
const auto dilations = helper.Get("dilations", std::vector<int64_t>{1, 1});
const auto group = helper.Get("group", 1);
if (dilations[0] != 1 || (dilations.size() > 1 && dilations[1] != 1)) {
LOGS(logger, VERBOSE) << op_type << " for WebNN CPU backend only supports default dilation 1.";
return false;
}
if (group != 1) {
LOGS(logger, VERBOSE) << op_type << " for WebNN CPU backend only supports default group 1.";
return false;
}
}

return true;
}

Expand Down Expand Up @@ -427,22 +443,6 @@ bool ConvOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceTy
return false;
}

// WebNN CPU backend (TFLite) only supports default dilations and group.
// https://source.chromium.org/chromium/chromium/src/+/main:services/webnn/tflite/graph_builder_tflite.cc;l=1040
if (device_type == WebnnDeviceType::CPU && op_type == "ConvTranspose") {
NodeAttrHelper helper(node);
const auto dilations = helper.Get("dilations", std::vector<int64_t>{1, 1});
const auto group = helper.Get("group", 1);
if (dilations[0] != 1 || (dilations.size() > 1 && dilations[1] != 1)) {
LOGS(logger, VERBOSE) << op_type << " for WebNN CPU backend only supports default dilation 1.";
return false;
}
if (group != 1) {
LOGS(logger, VERBOSE) << op_type << " for WebNN CPU backend only supports default group 1.";
return false;
}
}

return true;
}

Expand Down

0 comments on commit a3883af

Please sign in to comment.