Skip to content

Commit

Permalink
[WebNN EP] Remove NHWC preferred layout (#21570)
Browse files Browse the repository at this point in the history
Currently WebNN CPU backend has supported NCHW layout in Chromium, we
can now drop NHWC preferred layout for CPU backend in WebNN EP to
simplify the code.
  • Loading branch information
Honry authored Aug 28, 2024
1 parent bf8855b commit 5911422
Show file tree
Hide file tree
Showing 10 changed files with 39 additions and 296 deletions.
21 changes: 9 additions & 12 deletions onnxruntime/core/providers/webnn/builders/impl/builder_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,9 @@ common::Status ComputeConvPads(const std::vector<int64_t> input_shape,
const std::vector<int64_t>& onnx_strides,
const std::vector<int64_t>& onnx_dilations,
AutoPadType auto_pad_type,
std::vector<int64_t>& pads_out,
bool use_nchw) {
const int64_t input_size_y = use_nchw ? input_shape[2] : input_shape[1];
const int64_t input_size_x = use_nchw ? input_shape[3] : input_shape[2];
std::vector<int64_t>& pads_out) {
const int64_t input_size_y = input_shape[2];
const int64_t input_size_x = input_shape[3];
const int64_t stride_y = onnx_strides[0];
const int64_t stride_x = onnx_strides[1];
const int64_t dilation_y = onnx_dilations[0];
Expand Down Expand Up @@ -54,16 +53,15 @@ common::Status HandleAutoPad(const std::vector<int64_t> input_shape,
const std::vector<int64_t>& onnx_strides,
const std::vector<int64_t>& onnx_dilations,
AutoPadType auto_pad_type,
std::vector<int64_t>& pads_out,
bool use_nchw) {
std::vector<int64_t>& pads_out) {
if (AutoPadType::SAME_UPPER == auto_pad_type) {
ORT_RETURN_IF_ERROR(ComputeConvPads(input_shape, weight_size_y, weight_size_x,
onnx_pads, onnx_strides, onnx_dilations,
AutoPadType::SAME_UPPER, pads_out, use_nchw));
AutoPadType::SAME_UPPER, pads_out));
} else {
ORT_RETURN_IF_ERROR(ComputeConvPads(input_shape, weight_size_y, weight_size_x,
onnx_pads, onnx_strides, onnx_dilations,
AutoPadType::SAME_LOWER, pads_out, use_nchw));
AutoPadType::SAME_LOWER, pads_out));
}
return Status::OK();
}
Expand Down Expand Up @@ -111,10 +109,9 @@ common::Status ComputeConvTransposePadsAndOutputShape(const std::vector<int64_t>
const std::vector<int64_t>& onnx_output_padding,
AutoPadType auto_pad_type,
std::vector<int64_t>& pads_out,
std::vector<int64_t>& output_shape_out,
bool use_nchw) {
const int64_t input_size_y = use_nchw ? input_shape[2] : input_shape[1];
const int64_t input_size_x = use_nchw ? input_shape[3] : input_shape[2];
std::vector<int64_t>& output_shape_out) {
const int64_t input_size_y = input_shape[2];
const int64_t input_size_x = input_shape[3];
const int64_t stride_y = onnx_strides[0];
const int64_t stride_x = onnx_strides[1];
const int64_t dilation_y = onnx_dilations[0];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ common::Status HandleAutoPad(const std::vector<int64_t> input_shape,
const std::vector<int64_t>& onnx_strides,
const std::vector<int64_t>& onnx_dilations,
AutoPadType auto_pad_type,
std::vector<int64_t>& pads_out,
bool use_nchw) ORT_MUST_USE_RESULT;
std::vector<int64_t>& pads_out) ORT_MUST_USE_RESULT;

// Compute pads and output shape for ConvTranspose.
common::Status ComputeConvTransposePadsAndOutputShape(const std::vector<int64_t> input_shape,
Expand All @@ -34,8 +33,7 @@ common::Status ComputeConvTransposePadsAndOutputShape(const std::vector<int64_t>
const std::vector<int64_t>& onnx_output_padding,
AutoPadType auto_pad_type,
std::vector<int64_t>& pads_out,
std::vector<int64_t>& output_shape_out,
bool use_nchw) ORT_MUST_USE_RESULT;
std::vector<int64_t>& output_shape_out) ORT_MUST_USE_RESULT;

} // namespace webnn
} // namespace onnxruntime
171 changes: 12 additions & 159 deletions onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@ namespace webnn {

class ConvOpBuilder : public BaseOpBuilder {
// Add operator related.
public:
void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override;

private:
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
Expand All @@ -33,13 +30,6 @@ class ConvOpBuilder : public BaseOpBuilder {
const logging::Logger& logger) const override;
};

void ConvOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
// skip the weight for conv as we need to transpose for preferred layout NHWC.
if (model_builder.GetPreferredLayout() == DataLayout::NHWC) {
model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); // W
}
}

// Helper functions
common::Status SetConvBaseOptions(ModelBuilder& model_builder,
const Node& node, emscripten::val& options,
Expand All @@ -48,7 +38,6 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder,
const std::vector<int64_t>& strides,
const std::vector<int64_t>& dilations,
std::vector<int64_t>& pads,
const bool is_nhwc,
const bool is_conv1d,
const logging::Logger& logger) {
NodeAttrHelper helper(node);
Expand All @@ -61,7 +50,7 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder,
// Calculate explicit padding for autoPad.
if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) {
ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, weight_shape[2], weight_shape[3],
pads, strides, dilations, auto_pad_type, pads_out, !is_nhwc));
pads, strides, dilations, auto_pad_type, pads_out));
pads = pads_out;
}
} else if (node.OpType() == "ConvTranspose") {
Expand All @@ -82,7 +71,7 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder,
// Otherwise compute the output shape, as well as the pads if the auto_pad attribute is SAME_UPPER/SAME_LOWER.
ORT_RETURN_IF_ERROR(ComputeConvTransposePadsAndOutputShape(input_shape, weight_shape[2], weight_shape[3],
pads, strides, dilations, output_padding,
auto_pad_type, pads_out, output_shape, !is_nhwc));
auto_pad_type, pads_out, output_shape));

if (output_shape[0] != -1 && output_shape[1] != -1) {
options.set("outputSizes", emscripten::val::array(GetVecUint32FromVecInt64(output_shape)));
Expand Down Expand Up @@ -111,89 +100,6 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder,
return Status::OK();
}

// Both depthwise Conv and ConvTranspose share the same logic to add the layout.
Status AddInitializerInNewLayout(ModelBuilder& model_builder,
const std::string& name,
bool is_conv,
bool is_conv1d) {
const auto& tensor = *model_builder.GetInitializerTensors().at(name);
auto data_type = tensor.data_type();

const auto& shape = tensor.dims();
std::vector<uint32_t> dims = GetVecUint32FromVecInt64(std::vector<int64_t>(std::begin(shape), std::end(shape)));

if (is_conv1d) {
// Support conv1d by prepending a 1 size dimension.
dims.push_back(1);
}

const uint8_t* src = nullptr;
Initializer unpacked_tensor(tensor, model_builder.GetGraphViewer().ModelPath());
src = unpacked_tensor.DataAsByteSpan().data();
const auto out_t = dims[0], in_t = dims[1],
h_t = dims[2], w_t = dims[3];
std::vector<uint32_t> dest_shape;
if (is_conv == 1)
dest_shape = {out_t, h_t, w_t, in_t}; // L_0231
else
dest_shape = {in_t, h_t, w_t, out_t}; // L_1230 for depthwise conv and convTranspose weight

SafeInt<size_t> num_elements = SafeInt<size_t>(Product(dest_shape));

size_t element_size{0};
switch (data_type) {
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
element_size = sizeof(uint8_t);
break;
case ONNX_NAMESPACE::TensorProto_DataType_INT8:
element_size = sizeof(int8_t);
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
element_size = sizeof(uint16_t);
break;
case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
element_size = sizeof(float);
break;
default:
break;
}
std::unique_ptr<uint8_t[]> buffer_holder(new uint8_t[element_size * num_elements]);
uint8_t* buffer = buffer_holder.get();

for (uint32_t out = 0; out < out_t; out++) {
for (uint32_t in = 0; in < in_t; in++) {
for (uint32_t h = 0; h < h_t; h++) {
for (uint32_t w = 0; w < w_t; w++) {
auto onnx_idx = out * in_t * h_t * w_t +
in * h_t * w_t +
h * w_t +
w;

uint32_t nnapi_idx;
if (is_conv == 1) { // L_0231
nnapi_idx = out * h_t * w_t * in_t +
h * w_t * in_t +
w * in_t +
in;
} else { // L_1230 for depthwise conv weight
nnapi_idx = in * h_t * w_t * out_t +
h * w_t * out_t +
w * out_t +
out;
}

for (size_t i = 0; i < element_size; i++) {
buffer[element_size * nnapi_idx + i] = src[element_size * onnx_idx + i];
}
}
}
}
}
ORT_RETURN_IF_ERROR(model_builder.AddOperandFromPersistMemoryBuffer(name, buffer, num_elements * element_size,
dest_shape, data_type));
return Status::OK();
}

// Add operator related.

Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
Expand All @@ -203,7 +109,6 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
const auto& op_type = node.OpType();
emscripten::val input = model_builder.GetOperand(input_defs[0]->Name());
emscripten::val output = emscripten::val::object();
const auto& initializers(model_builder.GetInitializerTensors());

std::vector<int64_t> input_shape;
ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input shape");
Expand All @@ -216,19 +121,11 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
auto dilations = helper.Get("dilations", std::vector<int64_t>{1, 1});
auto pads = helper.Get("pads", std::vector<int64_t>{0, 0, 0, 0});

const bool is_nhwc = model_builder.GetPreferredLayout() == DataLayout::NHWC;
const bool is_conv1d = input_shape.size() == 3 && weight_shape.size() == 3;
const bool is_constant_weight = Contains(initializers, weight_name);
// Support conv1d by prepending a 1 or 2 size dimensions.
if (is_conv1d) {
// Reshape input.
if (is_nhwc) {
// For NHWC preferred layout, the input has been transposed.
// For conv1d it is NCD1 -> ND1C, so we need to prepend 1 to the index 2.
input_shape.insert(input_shape.begin() + 2, 1);
} else {
input_shape.push_back(1);
}
input_shape.push_back(1);
std::vector<uint32_t> new_shape = GetVecUint32FromVecInt64(input_shape);
input = model_builder.GetBuilder().call<emscripten::val>("reshape", input, emscripten::val::array(new_shape));

Expand All @@ -244,63 +141,19 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
emscripten::val options = emscripten::val::object();
options.set("label", node.Name());
ORT_RETURN_IF_ERROR(SetConvBaseOptions(
model_builder, node, options, input_shape, weight_shape, strides, dilations, pads, is_nhwc, is_conv1d, logger));
bool depthwise = false;
if (op_type == "Conv" || op_type == "ConvInteger") {
int groups = options["groups"].as<int>();
if (is_nhwc) {
depthwise = (groups == input_shape[3] && groups != 1);
options.set("inputLayout", emscripten::val("nhwc"));
if (is_constant_weight) {
ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, !depthwise, is_conv1d));
}
if (!depthwise) {
options.set("filterLayout", emscripten::val("ohwi"));
} else {
options.set("filterLayout", emscripten::val("ihwo"));
}
}
} else { // ConvTranspose
if (is_nhwc) {
options.set("inputLayout", emscripten::val("nhwc"));
options.set("filterLayout", emscripten::val("ohwi"));
if (is_constant_weight) {
ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, true, is_conv1d));
}
}
}

model_builder, node, options, input_shape, weight_shape, strides, dilations, pads, is_conv1d, logger));
emscripten::val filter = model_builder.GetOperand(weight_name);

if (is_conv1d) {
// Reshape weight to 4D for conv1d.
if (!is_nhwc || !is_constant_weight) {
// The weight_shape has been appended 1's, reshape weight operand.
std::vector<uint32_t> new_shape = GetVecUint32FromVecInt64(weight_shape);
emscripten::val reshape_options = emscripten::val::object();
reshape_options.set("label", node.Name() + "_reshape_filter");
filter = model_builder.GetBuilder().call<emscripten::val>("reshape",
filter,
emscripten::val::array(new_shape),
reshape_options);
}
}

emscripten::val transpose_options = emscripten::val::object();
if (is_nhwc && !is_constant_weight) {
// For NHWC preferred layout, if the weight is input:
// - Transpose it from iohw -> ohwi for convTranspose.
// - Transpose it from oihw -> ihwo for depthwise conv.
// - Transpose it from oihw -> ohwi for conv.
std::vector<uint32_t> perm(4);
if (op_type == "ConvTranspose" || depthwise) {
perm = {1, 2, 3, 0}; // L_1230 for depthwise conv and convTranspose weight
} else {
perm = {0, 2, 3, 1}; // L_0231
}
transpose_options.set("permutation", emscripten::val::array(perm));
transpose_options.set("label", node.Name() + "_transpose_filter");
filter = model_builder.GetBuilder().call<emscripten::val>("transpose", filter, transpose_options);
// The weight_shape has been appended 1's, reshape weight operand.
std::vector<uint32_t> new_shape = GetVecUint32FromVecInt64(weight_shape);
emscripten::val reshape_options = emscripten::val::object();
reshape_options.set("label", node.Name() + "_reshape_filter");
filter = model_builder.GetBuilder().call<emscripten::val>("reshape",
filter,
emscripten::val::array(new_shape),
reshape_options);
}

if (op_type == "Conv") {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,6 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
ORT_RETURN_IF_NOT(input_defs.size() == 5, "BatchNormalization requires five inputs.");
emscripten::val mean = model_builder.GetOperand(input_defs[3]->Name());
emscripten::val variance = model_builder.GetOperand(input_defs[4]->Name());
if (model_builder.GetPreferredLayout() == DataLayout::NHWC) {
options.set("axis", rank - 1);
}

output = model_builder.GetBuilder().call<emscripten::val>("batchNormalization", input, mean, variance, options);
} else if (op_type == "LayerNormalization") {
Expand All @@ -104,9 +101,8 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
std::back_inserter(new_shape),
[](int64_t dim) -> uint32_t { return SafeInt<uint32_t>(dim); });

size_t insertion_offset = (model_builder.GetPreferredLayout() == DataLayout::NHWC) ? 2 : 3;
ptrdiff_t excess_rank = new_shape.size() - webnn_shape_rank;
auto insertion_point = new_shape.begin() + insertion_offset;
auto insertion_point = new_shape.begin() + 3;
if (input_shape.size() < webnn_shape_rank) {
// Pad the shape with extra 1's to satisfy WebNN v1's rank requirements.
new_shape.insert(insertion_point, -excess_rank, 1);
Expand All @@ -125,9 +121,6 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
reshape_input_options);
}

if (model_builder.GetPreferredLayout() == DataLayout::NHWC) {
options.set("layout", emscripten::val("nhwc"));
}
output = model_builder.GetBuilder().call<emscripten::val>("instanceNormalization", input, options);
// Reshape back to the original output shape for 3D input.
if (input_shape.size() != 4) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,7 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
options.set("strides", emscripten::val::array(strides));
const auto dilations = helper.Get("dilations", std::vector<int32_t>{1, 1});
options.set("dilations", emscripten::val::array(dilations));
if (model_builder.GetPreferredLayout() == DataLayout::NHWC) {
options.set("layout", emscripten::val("nhwc"));
} else {
options.set("layout", emscripten::val("nchw"));
}
options.set("layout", emscripten::val("nchw"));

// Add Padding.
// Usually using autopadding is more efficient than using explicit padding.
Expand All @@ -93,8 +89,7 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
helper.Get("strides", std::vector<int64_t>{1, 1}),
helper.Get("dilations", std::vector<int64_t>{1, 1}),
auto_pad_type,
pads_out,
model_builder.GetPreferredLayout() == DataLayout::NCHW));
pads_out));
pads = GetVecUint32FromVecInt64(pads_out);
}
// Permute the ONNX's pads, which is [beginning_height, beginning_width, ending_height, ending_width],
Expand Down
Loading

0 comments on commit 5911422

Please sign in to comment.