diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index 8c81b97e552e3..be2e4a042c826 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -69,7 +69,7 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder, std::vector dims; std::vector output_padding{0, 0}; if (helper.HasAttr("output_shape")) { - // Default value of 'output_shape' will be ignore as we already check if it's existed. + // Default value of 'output_shape' will be ignored as we already check if it existed. dims = helper.Get("output_shape", std::vector{-1, -1}); // Extract the height and width. std::vector output_shape; @@ -77,10 +77,6 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder, output_shape = {dims[0], 1}; } else if (dims.size() == 2 && !is_conv1d) { output_shape = dims; - } else if (dims.size() == 3 && is_conv1d) { // ConvTranspose 1d - output_shape = {dims[2], 1}; - } else if (dims.size() == 4 && !is_conv1d) { - output_shape = {dims[2], dims[3]}; } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid output shape"); } @@ -286,13 +282,9 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N std::vector new_shape = GetVecUint32FromVecInt64(input_shape); input = model_builder.GetBuilder().call("reshape", input, emscripten::val::array(new_shape)); - weight_shape.push_back(1); - if (strides.size() == 1) { - strides.push_back(1); - } - if (dilations.size() == 1) { - dilations.push_back(1); - } + weight_shape.resize(4, 1); // Ensure 4D by appending 1's if needed. + strides.resize(2, 1); // Ensure 2D by appending 1's if needed. + dilations.resize(2, 1); // Ensure 2D by appending 1's if needed. if (pads.size() == 2) { pads.insert(pads.begin() + 1, 0); pads.push_back(0); @@ -314,7 +306,7 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N options.set("filterLayout", emscripten::val("ihwo")); } } - } else { // ConvTranspose + } else { // ConvTranspose if (is_nhwc) { options.set("inputLayout", emscripten::val("nhwc")); options.set("filterLayout", emscripten::val("ohwi"));