From 88c0cc4006dd757f3c2ca893401c90a0ab993db2 Mon Sep 17 00:00:00 2001
From: zesongw
Date: Tue, 9 Jan 2024 14:02:44 +0800
Subject: [PATCH] [WebNN EP] Update WebNN normalization ops (#18817)

Use batchNormalization, layerNormalization, and instanceNormalization
instead of meanVarianceNormalization to implement the normalization ops,
since meanVarianceNormalization has been removed from the WebNN spec.
Also remove support for GroupNormalization.
---
 .../core/providers/webnn/builders/helper.h    |   7 +-
 .../builders/impl/normalization_op_builder.cc | 141 +++++++-----
 .../webnn/builders/op_builder_factory.cc      |   1 -
 3 files changed, 57 insertions(+), 92 deletions(-)

diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h
index 8b8b85339a87c..5aec81af15761 100644
--- a/onnxruntime/core/providers/webnn/builders/helper.h
+++ b/onnxruntime/core/providers/webnn/builders/helper.h
@@ -139,7 +139,7 @@ static const InlinedHashMap<std::string, WebnnOpInfo> op_map = {
     {"ArgMax", {"argMax", false}},
     {"ArgMin", {"argMin", false}},
     {"AveragePool", {"averagePool2d", true}},
-    {"BatchNormalization", {"meanVarianceNormalization", false}},
+    {"BatchNormalization", {"batchNormalization", false}},
     {"Cast", {"cast", false}},
     {"Ceil", {"ceil", true}},
     {"Clip", {"clamp", true}},
@@ -162,12 +162,11 @@ static const InlinedHashMap<std::string, WebnnOpInfo> op_map = {
     {"GlobalLpPool", {"l2Pool2d", false}},
     {"Greater", {"greater", false}},
     {"GreaterOrEqual", {"greaterOrEqual", false}},
-    {"GroupNormalization", {"meanVarianceNormalization", false}},
     {"HardSigmoid", {"hardSigmoid", false}},
     {"HardSwish", {"hardSwish", true}},
     {"Identity", {"identity", false}},
-    {"InstanceNormalization", {"meanVarianceNormalization", false}},
-    {"LayerNormalization", {"meanVarianceNormalization", false}},
+    {"InstanceNormalization", {"instanceNormalization", false}},
+    {"LayerNormalization", {"layerNormalization", false}},
     {"LeakyRelu", {"leakyRelu", true}},
     {"Less", {"lesser", false}},
     {"LessOrEqual", {"lesserOrEqual", false}},
diff --git a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc
index 756a838cc0c3e..4d2470dfe7deb 100644
--- a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc
+++ b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc
@@ -27,8 +27,6 @@ class NormalizationOpBuilder : public BaseOpBuilder {
                          const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
 };
 
-// All normalization are based on layout NCHW.
-// TODO: add support for NHWC.
 Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
                                                      const Node& node,
                                                      const logging::Logger& logger) const {
@@ -61,49 +59,13 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
     ORT_RETURN_IF_NOT(bias_shape == scale_shape, "The bias' shape should be equal to scale's shape.");
   }
 
-  std::vector<uint32_t> new_scale_shape;
-  if (scale_size < rank) {
-    if (op_type == "BatchNormalization") {
-      scale_shape.insert(scale_shape.begin(), 1);
-      scale_shape.insert(scale_shape.end(), rank - 2, 1);
-    } else if (op_type == "LayerNormalization") {
-      // Align right with leading ones.
-      scale_shape.insert(scale_shape.begin(), rank - scale_size, 1);
-    } else if (op_type == "InstanceNormalization") {
-      // Insert ones before and after the channel dimension.
-      scale_shape.insert(scale_shape.begin(), 1);
-      ORT_RETURN_IF(scale_size != 1 || rank < 2,
-                    "The scale size should be 1 and rank should be at least 2 for InstanceNorm.");
-      scale_shape.insert(scale_shape.end(), rank - scale_size - 1, 1);
-    } else if (op_type == "GroupNormalization") {
-      // The input will be reshaped to 3D later. So just insert ones before the channel and after.
-      scale_shape.insert(scale_shape.begin(), 1);
-      scale_shape.insert(scale_shape.end(), 1);
-    } else {
-      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported normalization op: ", op_type);
-    }
+  emscripten::val scale = model_builder.GetOperand(input_defs[1]->Name());
+  options.set("scale", scale);
 
-    std::transform(scale_shape.cbegin(), scale_shape.cend(),
-                   std::back_inserter(new_scale_shape),
-                   [](int64_t dim) -> uint32_t { return SafeInt<uint32_t>(dim); });
-    emscripten::val reshape_scale = model_builder.GetOperand(input_defs[1]->Name());
-    emscripten::val reshape_output_scale =
-        model_builder.GetBuilder().call<emscripten::val>("reshape", reshape_scale, emscripten::val::array(new_scale_shape));
-    options.set("scale", reshape_output_scale);
-
-    if (input_defs.size() >= 3 && !input_defs[2]->Name().empty()) {
-      // Bias input exists, and bias's shape is the same as scale's shape.
-      emscripten::val reshape_bias = model_builder.GetOperand(input_defs[2]->Name());
-      emscripten::val reshape_output_bias =
-          model_builder.GetBuilder().call<emscripten::val>("reshape", reshape_bias, emscripten::val::array(new_scale_shape));
-      options.set("bias", reshape_output_bias);
-    }
-  } else {
-    options.set("scale", model_builder.GetOperand(input_defs[1]->Name()));
-    if (input_defs.size() >= 3 && !input_defs[2]->Name().empty()) {
-      // Bias input exists, and bias's shape is the same as scale's shape.
-      options.set("bias", model_builder.GetOperand(input_defs[2]->Name()));
-    }
+  if (input_defs.size() >= 3 && !input_defs[2]->Name().empty()) {
+    // Bias input exists, and bias's shape is the same as scale's shape.
+    emscripten::val bias = model_builder.GetOperand(input_defs[2]->Name());
+    options.set("bias", bias);
   }
 
   NodeAttrHelper helper(node);
@@ -114,56 +76,62 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
     ORT_RETURN_IF_NOT(input_defs.size() == 5, "BatchNormalization requires five inputs.");
     emscripten::val mean = model_builder.GetOperand(input_defs[3]->Name());
     emscripten::val variance = model_builder.GetOperand(input_defs[4]->Name());
-    // Enlarge 1-D mean and variance to new scale shape.
-    emscripten::val reshape_mean =
-        model_builder.GetBuilder().call<emscripten::val>("reshape", mean, emscripten::val::array(new_scale_shape));
-    emscripten::val reshape_variance =
-        model_builder.GetBuilder().call<emscripten::val>("reshape", variance, emscripten::val::array(new_scale_shape));
-
-    std::vector<uint32_t> axes = {0};
-    for (uint32_t i = 2; i < rank; i++) {
-      axes.push_back(i);
+    if (model_builder.GetPreferredLayout() == DataLayout::NHWC) {
+      options.set("axis", rank - 1);
     }
-
-    options.set("axes", emscripten::val::array(axes));
-    options.set("mean", reshape_mean);
-    options.set("variance", reshape_variance);
-    output = model_builder.GetBuilder().call<emscripten::val>("meanVarianceNormalization", input, options);
+    output = model_builder.GetBuilder().call<emscripten::val>("batchNormalization", input, mean, variance, options);
   } else if (op_type == "LayerNormalization") {
     int64_t axis = helper.Get("axis", -1);
     axis = HandleNegativeAxis(axis, rank);
     std::vector<uint32_t> axes(rank - SafeInt<uint32_t>(axis));
-    std::iota(axes.begin(), axes.end(), axis);
+    if (model_builder.GetPreferredLayout() == DataLayout::NHWC && axis > 1) {
+      std::iota(axes.begin(), axes.end(), axis - 1);
+    } else {
+      std::iota(axes.begin(), axes.end(), axis);
+    }
     options.set("axes", emscripten::val::array(axes));
-    output = model_builder.GetBuilder().call<emscripten::val>("meanVarianceNormalization", input, options);
+    output = model_builder.GetBuilder().call<emscripten::val>("layerNormalization", input, options);
   } else if (op_type == "InstanceNormalization") {
-    std::vector<uint32_t> axes;
-    for (uint32_t i = 2; i < rank; i++) {
-      axes.emplace_back(i);
+    // The WebNN spec only supports 4D input for instanceNormalization.
+    // 3D input is supported by inserting an extra dimension of size 1;
+    // inputs with rank greater than 4 are folded down to 4D.
+    constexpr size_t webnn_shape_rank = 4;
+    if (input_shape.size() != webnn_shape_rank) {
+      std::vector<uint32_t> new_shape;
+      new_shape.reserve(std::max(input_shape.size(), webnn_shape_rank));
+      std::transform(input_shape.begin(), input_shape.end(),
+                     std::back_inserter(new_shape),
+                     [](int64_t dim) -> uint32_t { return SafeInt<uint32_t>(dim); });
+
+      size_t insertion_offset = (model_builder.GetPreferredLayout() == DataLayout::NHWC) ? 2 : 3;
+      ptrdiff_t excess_rank = new_shape.size() - webnn_shape_rank;
+      auto insertion_point = new_shape.begin() + insertion_offset;
+      if (input_shape.size() < webnn_shape_rank) {
+        // Pad the shape with extra 1's to satisfy WebNN v1's rank requirements.
+        new_shape.insert(insertion_point, -excess_rank, 1);
+      } else {
+        // Fold the extra range to fit within WebNN v1's rank requirements.
+        uint32_t sum = std::accumulate(
+            insertion_point, insertion_point + excess_rank + 1, 1, std::multiplies<uint32_t>());
+        new_shape.erase(insertion_point, insertion_point + excess_rank);
+        *insertion_point = sum;
+      }
+      input = model_builder.GetBuilder().call<emscripten::val>("reshape", input, emscripten::val::array(new_shape));
+    }
+
+    if (model_builder.GetPreferredLayout() == DataLayout::NHWC) {
+      options.set("layout", emscripten::val("nhwc"));
+    }
+    output = model_builder.GetBuilder().call<emscripten::val>("instanceNormalization", input, options);
+    // Reshape back to the original shape when the input was not 4D.
+    if (input_shape.size() != 4) {
+      std::vector<uint32_t> output_shape;
+      std::transform(input_shape.begin(), input_shape.end(),
+                     std::back_inserter(output_shape),
+                     [](int64_t dim) -> uint32_t { return SafeInt<uint32_t>(dim); });
+      output = model_builder.GetBuilder().call<emscripten::val>(
+          "reshape", output, emscripten::val::array(output_shape));
     }
-    options.set("axes", emscripten::val::array(axes));
-    output = model_builder.GetBuilder().call<emscripten::val>("meanVarianceNormalization", input, options);
-  } else if (op_type == "GroupNormalization") {
-    ORT_RETURN_IF_NOT(helper.HasAttr("num_groups"), "GroupNormalization num_group must be provided.");
-    int32_t group_count = helper.Get("num_groups", -1);
-    std::vector<uint32_t> orig_shape, new_shape;
-    std::transform(input_shape.cbegin(), input_shape.cend(),
-                   std::back_inserter(orig_shape),
-                   [](int64_t dim) -> uint32_t { return SafeInt<uint32_t>(dim); });
-    // Add N and Group.
-    ORT_RETURN_IF_NOT(rank >= 2, "Input for GroupNormalization cannot be a scalar or 1D");
-    new_shape.emplace_back(SafeInt<uint32_t>(input_shape[0]));
-    new_shape.emplace_back(SafeInt<uint32_t>(group_count));
-
-    ORT_RETURN_IF_NOT(group_count > 0 && input_shape[1] % group_count == 0,
-                      "GroupNormalization num_group must be divisible by group.");
-    new_shape.emplace_back(SafeInt<uint32_t>(std::reduce(input_shape.begin() + 2, input_shape.end(),
-                                                         input_shape[1] / group_count, std::multiplies<int64_t>())));
-    // Input will be reshaped to (N, group count, channels per group x D1 x D2 ... Dn) and recovered after normalization.
-    options.set("axes", emscripten::val::array(std::vector<uint32_t>{2}));
-    output = model_builder.GetBuilder().call<emscripten::val>("reshape", input, emscripten::val::array(new_shape));
-    output = model_builder.GetBuilder().call<emscripten::val>("meanVarianceNormalization", output, options);
-    output = model_builder.GetBuilder().call<emscripten::val>("reshape", output, emscripten::val::array(orig_shape));
   } else {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported normalization op: ", op_type);
   }
@@ -214,7 +182,6 @@ void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrat
 
   constexpr static std::string_view op_types[] = {
       "BatchNormalization",
-      "GroupNormalization",
      "InstanceNormalization",
      "LayerNormalization",
  };
diff --git a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc
index 463317a4dafda..613771eda71fe 100644
--- a/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc
+++ b/onnxruntime/core/providers/webnn/builders/op_builder_factory.cc
@@ -111,7 +111,6 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {
   {  // Normalization
     CreateNormalizationOpBuilder("BatchNormalization", op_registrations);
-    CreateNormalizationOpBuilder("GroupNormalization", op_registrations);
     CreateNormalizationOpBuilder("InstanceNormalization", op_registrations);
     CreateNormalizationOpBuilder("LayerNormalization", op_registrations);
   }
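
Note (not part of the patch): the NHWC axes adjustment in the LayerNormalization branch can be easy to misread. The standalone sketch below mirrors that logic under the assumption that the layout conversion moves the channel dimension from position 1 to position rank - 1, shifting the trailing dimensions down by one; the helper name LayerNormAxes is hypothetical.

#include <cstdint>
#include <numeric>
#include <vector>

// ONNX LayerNormalization normalizes dims [axis, rank). When the EP
// prefers NHWC and axis > 1, those dims have shifted down by one in the
// converted model (the channel moved to the end), so the WebNN axes
// start at axis - 1 instead of axis.
std::vector<uint32_t> LayerNormAxes(int64_t axis, uint32_t rank, bool nhwc) {
  if (axis < 0) axis += rank;  // Same effect as HandleNegativeAxis.
  std::vector<uint32_t> axes(rank - static_cast<uint32_t>(axis));
  uint32_t start = (nhwc && axis > 1) ? static_cast<uint32_t>(axis) - 1
                                      : static_cast<uint32_t>(axis);
  std::iota(axes.begin(), axes.end(), start);
  return axes;
}

// For example, axis = -1 on a rank-4 input yields {3} in NCHW layout
// and {2} in NHWC layout.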
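
Similarly, here is a standalone sketch of the 4D coercion the InstanceNormalization branch performs before calling instanceNormalization. The function name CoerceTo4d is illustrative; like the builder code above, it assumes the shape has at least (nhwc ? 2 : 3) dimensions.

#include <cstddef>
#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

// WebNN instanceNormalization accepts only 4D tensors, so lower-rank
// shapes are padded with 1's and higher-rank shapes have their excess
// dimensions folded (multiplied together). The insertion point keeps the
// batch and channel dimensions in place for either layout.
std::vector<uint32_t> CoerceTo4d(std::vector<uint32_t> shape, bool nhwc) {
  constexpr size_t kWebnnRank = 4;
  const size_t offset = nhwc ? 2 : 3;
  const ptrdiff_t excess = static_cast<ptrdiff_t>(shape.size()) - kWebnnRank;
  auto point = shape.begin() + offset;
  if (excess < 0) {
    // Pad, e.g. NCHW layout: {N, C, D} -> {N, C, D, 1}.
    shape.insert(point, -excess, 1);
  } else if (excess > 0) {
    // Fold, e.g. NCHW layout: {N, C, D1, D2, D3} -> {N, C, D1, D2 * D3}.
    uint32_t folded = std::accumulate(point, point + excess + 1, 1u,
                                      std::multiplies<uint32_t>());
    point = shape.erase(point, point + excess);
    *point = folded;
  }
  return shape;
}

// CoerceTo4d({2, 3, 5}, /*nhwc=*/false)        -> {2, 3, 5, 1}
// CoerceTo4d({1, 2, 8, 8, 16}, /*nhwc=*/false) -> {1, 2, 8, 128}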