Skip to content

Commit

Permalink
Address @fdwr's comment
Browse files Browse the repository at this point in the history
  • Loading branch information
Honry committed Jan 22, 2024
1 parent e4eb3d9 commit 165a896
Showing 1 changed file with 5 additions and 13 deletions.
18 changes: 5 additions & 13 deletions onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,18 +69,14 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder,
std::vector<int64_t> dims;
std::vector<int64_t> output_padding{0, 0};
if (helper.HasAttr("output_shape")) {
// Default value of 'output_shape' will be ignore as we already check if it's existed.
// Default value of 'output_shape' will be ignored as we already check if it existed.
dims = helper.Get("output_shape", std::vector<int64_t>{-1, -1});
// Extract the height and width.
std::vector<int64_t> output_shape;
if (dims.size() == 1 && is_conv1d) { // ConvTranspose 1d
output_shape = {dims[0], 1};
} else if (dims.size() == 2 && !is_conv1d) {
output_shape = dims;
} else if (dims.size() == 3 && is_conv1d) { // ConvTranspose 1d
output_shape = {dims[2], 1};
} else if (dims.size() == 4 && !is_conv1d) {
output_shape = {dims[2], dims[3]};
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Invalid output shape");
}
Expand Down Expand Up @@ -286,13 +282,9 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
std::vector<uint32_t> new_shape = GetVecUint32FromVecInt64(input_shape);
input = model_builder.GetBuilder().call<emscripten::val>("reshape", input, emscripten::val::array(new_shape));

weight_shape.push_back(1);
if (strides.size() == 1) {
strides.push_back(1);
}
if (dilations.size() == 1) {
dilations.push_back(1);
}
weight_shape.resize(4, 1); // Ensure 4D by appending 1's if needed.
strides.resize(2, 1); // Ensure 2D by appending 1's if needed.
dilations.resize(2, 1); // Ensure 2D by appending 1's if needed.
if (pads.size() == 2) {
pads.insert(pads.begin() + 1, 0);
pads.push_back(0);
Expand All @@ -314,7 +306,7 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
options.set("filterLayout", emscripten::val("ihwo"));
}
}
} else { // ConvTranspose
} else { // ConvTranspose
if (is_nhwc) {
options.set("inputLayout", emscripten::val("nhwc"));
options.set("filterLayout", emscripten::val("ohwi"));
Expand Down

0 comments on commit 165a896

Please sign in to comment.