From 1e10c1ee4a8373cda4e57d706a1a21508ed7b0d6 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Thu, 29 Aug 2024 09:01:59 +0800 Subject: [PATCH] Revert "[WebNN EP] Remove NHWC preferred layout (#21570)" This reverts commit 59114227fd6cd86f7fa3477291fe0fcfae9afceb. --- .../webnn/builders/impl/builder_utils.cc | 21 ++- .../webnn/builders/impl/builder_utils.h | 6 +- .../webnn/builders/impl/conv_op_builder.cc | 171 ++++++++++++++++-- .../builders/impl/normalization_op_builder.cc | 9 +- .../webnn/builders/impl/pool_op_builder.cc | 9 +- .../webnn/builders/impl/resize_op_builder.cc | 34 +++- .../providers/webnn/builders/model_builder.cc | 62 ++++++- .../providers/webnn/builders/model_builder.h | 13 +- .../webnn/webnn_execution_provider.cc | 6 +- .../webnn/webnn_execution_provider.h | 4 +- 10 files changed, 296 insertions(+), 39 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/impl/builder_utils.cc b/onnxruntime/core/providers/webnn/builders/impl/builder_utils.cc index 594e75042f2ae..113cc3df5438d 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/builder_utils.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/builder_utils.cc @@ -19,9 +19,10 @@ common::Status ComputeConvPads(const std::vector input_shape, const std::vector& onnx_strides, const std::vector& onnx_dilations, AutoPadType auto_pad_type, - std::vector& pads_out) { - const int64_t input_size_y = input_shape[2]; - const int64_t input_size_x = input_shape[3]; + std::vector& pads_out, + bool use_nchw) { + const int64_t input_size_y = use_nchw ? input_shape[2] : input_shape[1]; + const int64_t input_size_x = use_nchw ? input_shape[3] : input_shape[2]; const int64_t stride_y = onnx_strides[0]; const int64_t stride_x = onnx_strides[1]; const int64_t dilation_y = onnx_dilations[0]; @@ -53,15 +54,16 @@ common::Status HandleAutoPad(const std::vector input_shape, const std::vector& onnx_strides, const std::vector& onnx_dilations, AutoPadType auto_pad_type, - std::vector& pads_out) { + std::vector& pads_out, + bool use_nchw) { if (AutoPadType::SAME_UPPER == auto_pad_type) { ORT_RETURN_IF_ERROR(ComputeConvPads(input_shape, weight_size_y, weight_size_x, onnx_pads, onnx_strides, onnx_dilations, - AutoPadType::SAME_UPPER, pads_out)); + AutoPadType::SAME_UPPER, pads_out, use_nchw)); } else { ORT_RETURN_IF_ERROR(ComputeConvPads(input_shape, weight_size_y, weight_size_x, onnx_pads, onnx_strides, onnx_dilations, - AutoPadType::SAME_LOWER, pads_out)); + AutoPadType::SAME_LOWER, pads_out, use_nchw)); } return Status::OK(); } @@ -109,9 +111,10 @@ common::Status ComputeConvTransposePadsAndOutputShape(const std::vector const std::vector& onnx_output_padding, AutoPadType auto_pad_type, std::vector& pads_out, - std::vector& output_shape_out) { - const int64_t input_size_y = input_shape[2]; - const int64_t input_size_x = input_shape[3]; + std::vector& output_shape_out, + bool use_nchw) { + const int64_t input_size_y = use_nchw ? input_shape[2] : input_shape[1]; + const int64_t input_size_x = use_nchw ? input_shape[3] : input_shape[2]; const int64_t stride_y = onnx_strides[0]; const int64_t stride_x = onnx_strides[1]; const int64_t dilation_y = onnx_dilations[0]; diff --git a/onnxruntime/core/providers/webnn/builders/impl/builder_utils.h b/onnxruntime/core/providers/webnn/builders/impl/builder_utils.h index f9f9746d6ed83..5a156c96c4852 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/builder_utils.h +++ b/onnxruntime/core/providers/webnn/builders/impl/builder_utils.h @@ -21,7 +21,8 @@ common::Status HandleAutoPad(const std::vector input_shape, const std::vector& onnx_strides, const std::vector& onnx_dilations, AutoPadType auto_pad_type, - std::vector& pads_out) ORT_MUST_USE_RESULT; + std::vector& pads_out, + bool use_nchw) ORT_MUST_USE_RESULT; // Compute pads and output shape for ConvTranspose. common::Status ComputeConvTransposePadsAndOutputShape(const std::vector input_shape, @@ -33,7 +34,8 @@ common::Status ComputeConvTransposePadsAndOutputShape(const std::vector const std::vector& onnx_output_padding, AutoPadType auto_pad_type, std::vector& pads_out, - std::vector& output_shape_out) ORT_MUST_USE_RESULT; + std::vector& output_shape_out, + bool use_nchw) ORT_MUST_USE_RESULT; } // namespace webnn } // namespace onnxruntime diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index 980c5dcd184c0..76a8a178678df 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -18,6 +18,9 @@ namespace webnn { class ConvOpBuilder : public BaseOpBuilder { // Add operator related. + public: + void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override; + private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; @@ -30,6 +33,13 @@ class ConvOpBuilder : public BaseOpBuilder { const logging::Logger& logger) const override; }; +void ConvOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { + // skip the weight for conv as we need to transpose for preferred layout NHWC. + if (model_builder.GetPreferredLayout() == DataLayout::NHWC) { + model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name()); // W + } +} + // Helper functions common::Status SetConvBaseOptions(ModelBuilder& model_builder, const Node& node, emscripten::val& options, @@ -38,6 +48,7 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder, const std::vector& strides, const std::vector& dilations, std::vector& pads, + const bool is_nhwc, const bool is_conv1d, const logging::Logger& logger) { NodeAttrHelper helper(node); @@ -50,7 +61,7 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder, // Calculate explicit padding for autoPad. if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) { ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, weight_shape[2], weight_shape[3], - pads, strides, dilations, auto_pad_type, pads_out)); + pads, strides, dilations, auto_pad_type, pads_out, !is_nhwc)); pads = pads_out; } } else if (node.OpType() == "ConvTranspose") { @@ -71,7 +82,7 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder, // Otherwise compute the output shape, as well as the pads if the auto_pad attribute is SAME_UPPER/SAME_LOWER. ORT_RETURN_IF_ERROR(ComputeConvTransposePadsAndOutputShape(input_shape, weight_shape[2], weight_shape[3], pads, strides, dilations, output_padding, - auto_pad_type, pads_out, output_shape)); + auto_pad_type, pads_out, output_shape, !is_nhwc)); if (output_shape[0] != -1 && output_shape[1] != -1) { options.set("outputSizes", emscripten::val::array(GetVecUint32FromVecInt64(output_shape))); @@ -100,6 +111,89 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder, return Status::OK(); } +// Both depthwise Conv and ConvTranspose share the same logic to add the layout. +Status AddInitializerInNewLayout(ModelBuilder& model_builder, + const std::string& name, + bool is_conv, + bool is_conv1d) { + const auto& tensor = *model_builder.GetInitializerTensors().at(name); + auto data_type = tensor.data_type(); + + const auto& shape = tensor.dims(); + std::vector dims = GetVecUint32FromVecInt64(std::vector(std::begin(shape), std::end(shape))); + + if (is_conv1d) { + // Support conv1d by prepending a 1 size dimension. + dims.push_back(1); + } + + const uint8_t* src = nullptr; + Initializer unpacked_tensor(tensor, model_builder.GetGraphViewer().ModelPath()); + src = unpacked_tensor.DataAsByteSpan().data(); + const auto out_t = dims[0], in_t = dims[1], + h_t = dims[2], w_t = dims[3]; + std::vector dest_shape; + if (is_conv == 1) + dest_shape = {out_t, h_t, w_t, in_t}; // L_0231 + else + dest_shape = {in_t, h_t, w_t, out_t}; // L_1230 for depthwise conv and convTranspose weight + + SafeInt num_elements = SafeInt(Product(dest_shape)); + + size_t element_size{0}; + switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: + element_size = sizeof(uint8_t); + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT8: + element_size = sizeof(int8_t); + break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + element_size = sizeof(uint16_t); + break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + element_size = sizeof(float); + break; + default: + break; + } + std::unique_ptr buffer_holder(new uint8_t[element_size * num_elements]); + uint8_t* buffer = buffer_holder.get(); + + for (uint32_t out = 0; out < out_t; out++) { + for (uint32_t in = 0; in < in_t; in++) { + for (uint32_t h = 0; h < h_t; h++) { + for (uint32_t w = 0; w < w_t; w++) { + auto onnx_idx = out * in_t * h_t * w_t + + in * h_t * w_t + + h * w_t + + w; + + uint32_t nnapi_idx; + if (is_conv == 1) { // L_0231 + nnapi_idx = out * h_t * w_t * in_t + + h * w_t * in_t + + w * in_t + + in; + } else { // L_1230 for depthwise conv weight + nnapi_idx = in * h_t * w_t * out_t + + h * w_t * out_t + + w * out_t + + out; + } + + for (size_t i = 0; i < element_size; i++) { + buffer[element_size * nnapi_idx + i] = src[element_size * onnx_idx + i]; + } + } + } + } + } + ORT_RETURN_IF_ERROR(model_builder.AddOperandFromPersistMemoryBuffer(name, buffer, num_elements * element_size, + dest_shape, data_type)); + return Status::OK(); +} + // Add operator related. Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, @@ -109,6 +203,7 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N const auto& op_type = node.OpType(); emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); emscripten::val output = emscripten::val::object(); + const auto& initializers(model_builder.GetInitializerTensors()); std::vector input_shape; ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input shape"); @@ -121,11 +216,19 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N auto dilations = helper.Get("dilations", std::vector{1, 1}); auto pads = helper.Get("pads", std::vector{0, 0, 0, 0}); + const bool is_nhwc = model_builder.GetPreferredLayout() == DataLayout::NHWC; const bool is_conv1d = input_shape.size() == 3 && weight_shape.size() == 3; + const bool is_constant_weight = Contains(initializers, weight_name); // Support conv1d by prepending a 1 or 2 size dimensions. if (is_conv1d) { // Reshape input. - input_shape.push_back(1); + if (is_nhwc) { + // For NHWC preferred layout, the input has been transposed. + // For conv1d it is NCD1 -> ND1C, so we need to prepend 1 to the index 2. + input_shape.insert(input_shape.begin() + 2, 1); + } else { + input_shape.push_back(1); + } std::vector new_shape = GetVecUint32FromVecInt64(input_shape); input = model_builder.GetBuilder().call("reshape", input, emscripten::val::array(new_shape)); @@ -141,19 +244,63 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N emscripten::val options = emscripten::val::object(); options.set("label", node.Name()); ORT_RETURN_IF_ERROR(SetConvBaseOptions( - model_builder, node, options, input_shape, weight_shape, strides, dilations, pads, is_conv1d, logger)); + model_builder, node, options, input_shape, weight_shape, strides, dilations, pads, is_nhwc, is_conv1d, logger)); + bool depthwise = false; + if (op_type == "Conv" || op_type == "ConvInteger") { + int groups = options["groups"].as(); + if (is_nhwc) { + depthwise = (groups == input_shape[3] && groups != 1); + options.set("inputLayout", emscripten::val("nhwc")); + if (is_constant_weight) { + ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, !depthwise, is_conv1d)); + } + if (!depthwise) { + options.set("filterLayout", emscripten::val("ohwi")); + } else { + options.set("filterLayout", emscripten::val("ihwo")); + } + } + } else { // ConvTranspose + if (is_nhwc) { + options.set("inputLayout", emscripten::val("nhwc")); + options.set("filterLayout", emscripten::val("ohwi")); + if (is_constant_weight) { + ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, true, is_conv1d)); + } + } + } + emscripten::val filter = model_builder.GetOperand(weight_name); if (is_conv1d) { // Reshape weight to 4D for conv1d. - // The weight_shape has been appended 1's, reshape weight operand. - std::vector new_shape = GetVecUint32FromVecInt64(weight_shape); - emscripten::val reshape_options = emscripten::val::object(); - reshape_options.set("label", node.Name() + "_reshape_filter"); - filter = model_builder.GetBuilder().call("reshape", - filter, - emscripten::val::array(new_shape), - reshape_options); + if (!is_nhwc || !is_constant_weight) { + // The weight_shape has been appended 1's, reshape weight operand. + std::vector new_shape = GetVecUint32FromVecInt64(weight_shape); + emscripten::val reshape_options = emscripten::val::object(); + reshape_options.set("label", node.Name() + "_reshape_filter"); + filter = model_builder.GetBuilder().call("reshape", + filter, + emscripten::val::array(new_shape), + reshape_options); + } + } + + emscripten::val transpose_options = emscripten::val::object(); + if (is_nhwc && !is_constant_weight) { + // For NHWC preferred layout, if the weight is input: + // - Transpose it from iohw -> ohwi for convTranspose. + // - Transpose it from oihw -> ihwo for depthwise conv. + // - Transpose it from oihw -> ohwi for conv. + std::vector perm(4); + if (op_type == "ConvTranspose" || depthwise) { + perm = {1, 2, 3, 0}; // L_1230 for depthwise conv and convTranspose weight + } else { + perm = {0, 2, 3, 1}; // L_0231 + } + transpose_options.set("permutation", emscripten::val::array(perm)); + transpose_options.set("label", node.Name() + "_transpose_filter"); + filter = model_builder.GetBuilder().call("transpose", filter, transpose_options); } if (op_type == "Conv") { diff --git a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc index 347cd11898d25..4d068baf35e72 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc @@ -79,6 +79,9 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder ORT_RETURN_IF_NOT(input_defs.size() == 5, "BatchNormalization requires five inputs."); emscripten::val mean = model_builder.GetOperand(input_defs[3]->Name()); emscripten::val variance = model_builder.GetOperand(input_defs[4]->Name()); + if (model_builder.GetPreferredLayout() == DataLayout::NHWC) { + options.set("axis", rank - 1); + } output = model_builder.GetBuilder().call("batchNormalization", input, mean, variance, options); } else if (op_type == "LayerNormalization") { @@ -101,8 +104,9 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder std::back_inserter(new_shape), [](int64_t dim) -> uint32_t { return SafeInt(dim); }); + size_t insertion_offset = (model_builder.GetPreferredLayout() == DataLayout::NHWC) ? 2 : 3; ptrdiff_t excess_rank = new_shape.size() - webnn_shape_rank; - auto insertion_point = new_shape.begin() + 3; + auto insertion_point = new_shape.begin() + insertion_offset; if (input_shape.size() < webnn_shape_rank) { // Pad the shape with extra 1's to satisfy WebNN v1's rank requirements. new_shape.insert(insertion_point, -excess_rank, 1); @@ -121,6 +125,9 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder reshape_input_options); } + if (model_builder.GetPreferredLayout() == DataLayout::NHWC) { + options.set("layout", emscripten::val("nhwc")); + } output = model_builder.GetBuilder().call("instanceNormalization", input, options); // Reshape back to the original output shape for 3D input. if (input_shape.size() != 4) { diff --git a/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc index 09eb8e79ce1d3..0af62dacedbd5 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/pool_op_builder.cc @@ -70,7 +70,11 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, options.set("strides", emscripten::val::array(strides)); const auto dilations = helper.Get("dilations", std::vector{1, 1}); options.set("dilations", emscripten::val::array(dilations)); - options.set("layout", emscripten::val("nchw")); + if (model_builder.GetPreferredLayout() == DataLayout::NHWC) { + options.set("layout", emscripten::val("nhwc")); + } else { + options.set("layout", emscripten::val("nchw")); + } // Add Padding. // Usually using autopadding is more efficient than using explicit padding. @@ -89,7 +93,8 @@ Status PoolOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, helper.Get("strides", std::vector{1, 1}), helper.Get("dilations", std::vector{1, 1}), auto_pad_type, - pads_out)); + pads_out, + model_builder.GetPreferredLayout() == DataLayout::NCHW)); pads = GetVecUint32FromVecInt64(pads_out); } // Permute the ONNX's pads, which is [beginning_height, beginning_width, ending_height, ending_width], diff --git a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc index 0e211de5a3986..2218c858951d3 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc @@ -120,10 +120,18 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, std::vector scales; std::vector sizes; + std::vector scales_hw; + std::vector sizes_hw; + std::vector axes; std::string scales_name = GetTensorName(input_defs, 2); + const bool is_nhwc = model_builder.GetPreferredLayout() == DataLayout::NHWC; if (!scales_name.empty()) { // Use scales. ORT_RETURN_IF_NOT(GetResizeScales(initializers, node, scales, logger), "Error getting resize scales"); - std::vector scales_hw = {scales[2], scales[3]}; + if (is_nhwc) { + scales_hw = {scales[1], scales[2]}; + } else { + scales_hw = {scales[2], scales[3]}; + } options.set("scales", emscripten::val::array(scales_hw)); } else { // Use sizes, we already checked inputs in IsOpSupportedImpl. std::vector output_sizes; @@ -132,11 +140,19 @@ Status ResizeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, std::transform(output_sizes.cbegin(), output_sizes.cend(), std::back_inserter(sizes), [](int64_t dim) -> int32_t { return SafeInt(dim); }); - std::vector sizes_hw = {sizes[2], sizes[3]}; + if (is_nhwc) { + sizes_hw = {sizes[1], sizes[2]}; + } else { + sizes_hw = {sizes[2], sizes[3]}; + } options.set("sizes", emscripten::val::array(sizes_hw)); } - std::vector axes = {2, 3}; + if (is_nhwc) { + axes = {1, 2}; + } else { + axes = {2, 3}; + } options.set("axes", emscripten::val::array(axes)); emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); @@ -205,6 +221,7 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers return false; } + const bool is_nhwc = node.Domain() == kMSInternalNHWCDomain; // We want to check if the scales or sizes are not trying to resize on N/C channels here. if (has_scales) { // We are using scales. std::vector scales; @@ -212,7 +229,7 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers return false; float scale_n = scales[0]; - float scale_c = scales[1]; + float scale_c = is_nhwc ? scales[3] : scales[1]; if (scale_n != 1.0f || scale_c != 1.0f) { LOGS(logger, VERBOSE) << "Scales of N/C channel should be 1" << "Resize of N/C channels are not supported" @@ -222,8 +239,8 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers // For now we only support upscale, so the scale_h and scale_w should be an integer >= 1. // TODO support ResizeBilinear. - float scale_h = scales[2]; - float scale_w = scales[3]; + float scale_h = is_nhwc ? scales[1] : scales[2]; + float scale_w = is_nhwc ? scales[2] : scales[3]; // Onnx spec requires scale to be a positive float, so we are not checking that here. if (roundf(scale_h) != scale_h) { @@ -244,11 +261,12 @@ bool ResizeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers return false; auto output_size_n = output_sizes[0]; - if (output_size_n != input_shape[0] || output_sizes[1] != input_shape[1]) { + const int c_idx = is_nhwc ? 3 : 1; + if (output_size_n != input_shape[0] || output_sizes[c_idx] != input_shape[c_idx]) { LOGS(logger, VERBOSE) << "Output sizes of N/C chanel should match the input sizes, " << "Resize of N/C channels are not supported" << ", input_size_n, " << input_shape[0] << ", output_size_n, " << output_size_n - << ". input_size_c, " << input_shape[1] << ", output_size_c, " << output_sizes[1]; + << ". input_size_c, " << input_shape[c_idx] << ", output_size_c, " << output_sizes[c_idx]; return false; } } diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.cc b/onnxruntime/core/providers/webnn/builders/model_builder.cc index 02fb8e732b3c7..44bec1fb6fd48 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/model_builder.cc @@ -20,10 +20,12 @@ namespace onnxruntime { namespace webnn { ModelBuilder::ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, - const emscripten::val& context, const WebnnDeviceType wnn_device_type) + const emscripten::val& context, const DataLayout preferred_layout, + const WebnnDeviceType wnn_device_type) : graph_viewer_(graph_viewer), logger_(logger), wnn_context_(context), + preferred_layout_(preferred_layout), wnn_device_type_(wnn_device_type) { // Create WebNN MLGraphBuilder for each ModelBuilder, because MLGraphBuilder.build() // is only allowed to be called once. @@ -252,6 +254,64 @@ Status ModelBuilder::AddOperations() { return Status::OK(); } +Status ModelBuilder::AddOperandFromPersistMemoryBuffer( + const std::string& name, const void* buffer, const size_t size, + const std::vector shape, const int32_t data_type) { + auto persist_buffer = std::make_unique(size); + uint8_t* dest = persist_buffer.get(); + memcpy(dest, buffer, size); + emscripten::val view = emscripten::val::undefined(); + emscripten::val desc = emscripten::val::object(); + ORT_RETURN_IF_NOT(SetWebnnDataType(desc, data_type), "Unsupported data type"); + switch (data_type) { + case ONNX_NAMESPACE::TensorProto_DataType_BOOL: + case ONNX_NAMESPACE::TensorProto_DataType_UINT8: + view = emscripten::val{emscripten::typed_memory_view(size / sizeof(uint8_t), + reinterpret_cast(dest))}; + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT8: + view = emscripten::val{emscripten::typed_memory_view(size / sizeof(int8_t), + reinterpret_cast(dest))}; + break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + view = emscripten::val{emscripten::typed_memory_view(size / sizeof(uint16_t), + reinterpret_cast(dest))}; + break; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + view = emscripten::val{emscripten::typed_memory_view(size / sizeof(float), + reinterpret_cast(dest))}; + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT32: + view = emscripten::val{emscripten::typed_memory_view(size / sizeof(int32_t), + reinterpret_cast(dest))}; + break; + case ONNX_NAMESPACE::TensorProto_DataType_INT64: + view = emscripten::val{emscripten::typed_memory_view(size / sizeof(int64_t), + reinterpret_cast(dest))}; + break; + case ONNX_NAMESPACE::TensorProto_DataType_UINT32: + view = emscripten::val{emscripten::typed_memory_view(size / sizeof(uint32_t), + reinterpret_cast(dest))}; + break; + case ONNX_NAMESPACE::TensorProto_DataType_UINT64: + view = emscripten::val{emscripten::typed_memory_view(size / sizeof(uint64_t), + reinterpret_cast(dest))}; + break; + default: + break; + } + + desc.set("dimensions", emscripten::val::array(shape)); + emscripten::val operand = emscripten::val::object(); + // Wasm memory grow will cause all array buffers reallocation, which will be treated as detached + // buffers in JS side. Simply create a copy to fix it. + operand = wnn_builder_.call("constant", desc, view.call("slice")); + + AddOperand(name, operand); + mem_persist_buffers_.push_back(std::move(persist_buffer)); + return Status::OK(); +} + Status ModelBuilder::RegisterModelOutputs() { for (const auto* node_arg : graph_viewer_.GetOutputs()) { ORT_RETURN_IF_ERROR(RegisterModelInputOutput(*node_arg, false /* is_input */)); diff --git a/onnxruntime/core/providers/webnn/builders/model_builder.h b/onnxruntime/core/providers/webnn/builders/model_builder.h index a954daa855e4a..2d686070cdcc1 100644 --- a/onnxruntime/core/providers/webnn/builders/model_builder.h +++ b/onnxruntime/core/providers/webnn/builders/model_builder.h @@ -22,7 +22,8 @@ class IOpBuilder; class ModelBuilder { public: ModelBuilder(const GraphViewer& graph_viewer, const logging::Logger& logger, - const emscripten::val& context, const WebnnDeviceType wnn_device_type); + const emscripten::val& context, const DataLayout preferred_layout, + const WebnnDeviceType wnn_device_type); ~ModelBuilder() = default; Status Compile(std::unique_ptr& model) ORT_MUST_USE_RESULT; @@ -36,6 +37,15 @@ class ModelBuilder { const emscripten::val& GetOperand(const std::string& name) const { return wnn_operands_.at(name); } void AddOperand(const std::string& name, const emscripten::val& operand); const emscripten::val& GetZeroConstant(const std::string& data_type); + // Use the buffers to persist WebNN allocated data like transposed weight. + // It ensures the validity during inference session. + std::vector> mem_persist_buffers_; + // Add a constant operand (allocate persist buffer and move the ownership to mem_persist_buffers_). + Status AddOperandFromPersistMemoryBuffer( + const std::string& name, const void* buffer, + const size_t size, const std::vector shape, const int32_t data_type); + + DataLayout GetPreferredLayout() const { return preferred_layout_; } WebnnDeviceType GetWebnnDeviceType() const { return wnn_device_type_; } @@ -54,6 +64,7 @@ class ModelBuilder { emscripten::val wnn_context_ = emscripten::val::undefined(); emscripten::val wnn_builder_ = emscripten::val::undefined(); + DataLayout preferred_layout_; WebnnDeviceType wnn_device_type_; InlinedHashMap wnn_operands_; std::vector input_names_; diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc index a6fe00241e55f..b918daf838c99 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.cc +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.cc @@ -19,9 +19,12 @@ namespace onnxruntime { WebNNExecutionProvider::WebNNExecutionProvider(const std::string& webnn_device_flags) : IExecutionProvider{onnxruntime::kWebNNExecutionProvider} { + // WebNN EP uses NHWC layout for CPU XNNPACK backend and NCHW for GPU DML backend. if (webnn_device_flags.compare("cpu") == 0) { + preferred_layout_ = DataLayout::NHWC; wnn_device_type_ = webnn::WebnnDeviceType::CPU; } else { + preferred_layout_ = DataLayout::NCHW; if (webnn_device_flags.compare("gpu") == 0) { wnn_device_type_ = webnn::WebnnDeviceType::GPU; } else if (webnn_device_flags.compare("npu") == 0) { @@ -209,7 +212,8 @@ common::Status WebNNExecutionProvider::Compile(const std::vector model; ORT_RETURN_IF_ERROR(builder.Compile(model)); diff --git a/onnxruntime/core/providers/webnn/webnn_execution_provider.h b/onnxruntime/core/providers/webnn/webnn_execution_provider.h index 1fbc99098e30f..d8c1e90c86cdb 100644 --- a/onnxruntime/core/providers/webnn/webnn_execution_provider.h +++ b/onnxruntime/core/providers/webnn/webnn_execution_provider.h @@ -26,8 +26,7 @@ class WebNNExecutionProvider : public IExecutionProvider { GetCapability(const onnxruntime::GraphViewer& graph_viewer, const IKernelLookup& /*kernel_registries*/) const override; - // WebNN EP uses default NCHW layout for all backends. - DataLayout GetPreferredLayout() const override { return DataLayout::NCHW; } + DataLayout GetPreferredLayout() const override { return preferred_layout_; } // We implement the Compile that takes FusedNodeAndGraph instances. FusionStyle GetFusionStyle() const override { return FusionStyle::FilteredGraphViewer; } @@ -45,6 +44,7 @@ class WebNNExecutionProvider : public IExecutionProvider { private: emscripten::val wnn_context_ = emscripten::val::undefined(); + DataLayout preferred_layout_; webnn::WebnnDeviceType wnn_device_type_; InlinedHashMap> models_; ModelMetadefIdGenerator metadef_id_generator_;