Skip to content

Commit

Permalink
[WebNN EP] Update WebNN normalization ops (microsoft#18817)
Browse files Browse the repository at this point in the history
Use batchNormalization, layerNormalization and instanceNormalization
instead of meanVarianceNormalization to implement the normalization ops.
meanVarianceNormalization has been removed from the WebNN spec.
Remove support for the GroupNormalization op.
  • Loading branch information
zesongw authored and jslap-ubi committed Apr 5, 2024
1 parent 62fe34a commit 88c0cc4
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 92 deletions.
7 changes: 3 additions & 4 deletions onnxruntime/core/providers/webnn/builders/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ static const InlinedHashMap<std::string, WebnnOpInfo> op_map = {
{"ArgMax", {"argMax", false}},
{"ArgMin", {"argMin", false}},
{"AveragePool", {"averagePool2d", true}},
{"BatchNormalization", {"meanVarianceNormalization", false}},
{"BatchNormalization", {"batchNormalization", false}},
{"Cast", {"cast", false}},
{"Ceil", {"ceil", true}},
{"Clip", {"clamp", true}},
Expand All @@ -162,12 +162,11 @@ static const InlinedHashMap<std::string, WebnnOpInfo> op_map = {
{"GlobalLpPool", {"l2Pool2d", false}},
{"Greater", {"greater", false}},
{"GreaterOrEqual", {"greaterOrEqual", false}},
{"GroupNormalization", {"meanVarianceNormalization", false}},
{"HardSigmoid", {"hardSigmoid", false}},
{"HardSwish", {"hardSwish", true}},
{"Identity", {"identity", false}},
{"InstanceNormalization", {"meanVarianceNormalization", false}},
{"LayerNormalization", {"meanVarianceNormalization", false}},
{"InstanceNormalization", {"instanceNormalization", false}},
{"LayerNormalization", {"layerNormalization", false}},
{"LeakyRelu", {"leakyRelu", true}},
{"Less", {"lesser", false}},
{"LessOrEqual", {"lesserOrEqual", false}},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@ class NormalizationOpBuilder : public BaseOpBuilder {
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
};

// All normalization are based on layout NCHW.
// TODO: add support for NHWC.
Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
const Node& node,
const logging::Logger& logger) const {
Expand Down Expand Up @@ -61,49 +59,13 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
ORT_RETURN_IF_NOT(bias_shape == scale_shape, "The bias' shape should be equal to scale's shape.");
}

std::vector<uint32_t> new_scale_shape;
if (scale_size < rank) {
if (op_type == "BatchNormalization") {
scale_shape.insert(scale_shape.begin(), 1);
scale_shape.insert(scale_shape.end(), rank - 2, 1);
} else if (op_type == "LayerNormalization") {
// Align right with leading ones.
scale_shape.insert(scale_shape.begin(), rank - scale_size, 1);
} else if (op_type == "InstanceNormalization") {
// Insert ones before and after the channel dimension.
scale_shape.insert(scale_shape.begin(), 1);
ORT_RETURN_IF(scale_size != 1 || rank < 2,
"The scale size should be 1 and rank should be at least 2 for InstanceNorm.");
scale_shape.insert(scale_shape.end(), rank - scale_size - 1, 1);
} else if (op_type == "GroupNormalization") {
// The input will be reshaped to 3D later. So just insert ones before the channel and after.
scale_shape.insert(scale_shape.begin(), 1);
scale_shape.insert(scale_shape.end(), 1);
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported normalization op: ", op_type);
}
emscripten::val scale = model_builder.GetOperand(input_defs[1]->Name());
options.set("scale", scale);

std::transform(scale_shape.cbegin(), scale_shape.cend(),
std::back_inserter(new_scale_shape),
[](int64_t dim) -> uint32_t { return SafeInt<uint32_t>(dim); });
emscripten::val reshape_scale = model_builder.GetOperand(input_defs[1]->Name());
emscripten::val reshape_output_scale =
model_builder.GetBuilder().call<emscripten::val>("reshape", reshape_scale, emscripten::val::array(new_scale_shape));
options.set("scale", reshape_output_scale);

if (input_defs.size() >= 3 && !input_defs[2]->Name().empty()) {
// Bias input exists, and bias's shape is the same as scale's shape.
emscripten::val reshape_bias = model_builder.GetOperand(input_defs[2]->Name());
emscripten::val reshape_output_bias =
model_builder.GetBuilder().call<emscripten::val>("reshape", reshape_bias, emscripten::val::array(new_scale_shape));
options.set("bias", reshape_output_bias);
}
} else {
options.set("scale", model_builder.GetOperand(input_defs[1]->Name()));
if (input_defs.size() >= 3 && !input_defs[2]->Name().empty()) {
// Bias input exists, and bias's shape is the same as scale's shape.
options.set("bias", model_builder.GetOperand(input_defs[2]->Name()));
}
if (input_defs.size() >= 3 && !input_defs[2]->Name().empty()) {
// Bias input exists, and bias's shape is the same as scale's shape.
emscripten::val bias = model_builder.GetOperand(input_defs[2]->Name());
options.set("bias", bias);
}

NodeAttrHelper helper(node);
Expand All @@ -114,56 +76,62 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
ORT_RETURN_IF_NOT(input_defs.size() == 5, "BatchNormalization requires five inputs.");
emscripten::val mean = model_builder.GetOperand(input_defs[3]->Name());
emscripten::val variance = model_builder.GetOperand(input_defs[4]->Name());
// Enlarge 1-D mean and variance to new scale shape.
emscripten::val reshape_mean =
model_builder.GetBuilder().call<emscripten::val>("reshape", mean, emscripten::val::array(new_scale_shape));
emscripten::val reshape_variance =
model_builder.GetBuilder().call<emscripten::val>("reshape", variance, emscripten::val::array(new_scale_shape));

std::vector<uint32_t> axes = {0};
for (uint32_t i = 2; i < rank; i++) {
axes.push_back(i);
if (model_builder.GetPreferredLayout() == DataLayout::NHWC) {
options.set("axis", rank - 1);
}

options.set("axes", emscripten::val::array(axes));
options.set("mean", reshape_mean);
options.set("variance", reshape_variance);
output = model_builder.GetBuilder().call<emscripten::val>("meanVarianceNormalization", input, options);
output = model_builder.GetBuilder().call<emscripten::val>("batchNormalization", input, mean, variance, options);
} else if (op_type == "LayerNormalization") {
int64_t axis = helper.Get("axis", -1);
axis = HandleNegativeAxis(axis, rank);
std::vector<uint32_t> axes(rank - SafeInt<uint32_t>(axis));
std::iota(axes.begin(), axes.end(), axis);
if (model_builder.GetPreferredLayout() == DataLayout::NHWC && axis > 1) {
std::iota(axes.begin(), axes.end(), axis - 1);
} else {
std::iota(axes.begin(), axes.end(), axis);
}
options.set("axes", emscripten::val::array(axes));
output = model_builder.GetBuilder().call<emscripten::val>("meanVarianceNormalization", input, options);
output = model_builder.GetBuilder().call<emscripten::val>("layerNormalization", input, options);
} else if (op_type == "InstanceNormalization") {
std::vector<uint32_t> axes;
for (uint32_t i = 2; i < rank; i++) {
axes.emplace_back(i);
// WebNN spec only supports 4D input for instanceNormalization.
// Inputs with fewer than 4 dimensions are padded with size-1 dimensions.
// Inputs with more than 4 dimensions have the extra dimensions folded together so the result is 4D.
constexpr size_t webnn_shape_rank = 4;
if (input_shape.size() != webnn_shape_rank) {
std::vector<uint32_t> new_shape;
new_shape.reserve(std::max(input_shape.size(), webnn_shape_rank));
std::transform(input_shape.begin(), input_shape.end(),
std::back_inserter(new_shape),
[](int64_t dim) -> uint32_t { return SafeInt<uint32_t>(dim); });

size_t insertion_offset = (model_builder.GetPreferredLayout() == DataLayout::NHWC) ? 2 : 3;
ptrdiff_t excess_rank = new_shape.size() - webnn_shape_rank;
auto insertion_point = new_shape.begin() + insertion_offset;
if (input_shape.size() < webnn_shape_rank) {
// Pad the shape with extra 1's to satisfy WebNN v1's rank requirements.
new_shape.insert(insertion_point, -excess_rank, 1);
} else {
// Fold the extra range to fit within WebNN v1's rank requirements.
uint32_t sum = std::accumulate(
insertion_point, insertion_point + excess_rank + 1, 1, std::multiplies<uint32_t>());
new_shape.erase(insertion_point, insertion_point + excess_rank);
*insertion_point = sum;
}
input = model_builder.GetBuilder().call<emscripten::val>("reshape", input, emscripten::val::array(new_shape));
}

if (model_builder.GetPreferredLayout() == DataLayout::NHWC) {
options.set("layout", emscripten::val("nhwc"));
}
output = model_builder.GetBuilder().call<emscripten::val>("instanceNormalization", input, options);
// Reshape the output back to the original input shape when the input was not 4D.
if (input_shape.size() != 4) {
std::vector<uint32_t> output_shape;
std::transform(input_shape.begin(), input_shape.end(),
std::back_inserter(output_shape),
[](int64_t dim) -> uint32_t { return SafeInt<uint32_t>(dim); });
output = model_builder.GetBuilder().call<emscripten::val>(
"reshape", output, emscripten::val::array(output_shape));
}
options.set("axes", emscripten::val::array(axes));
output = model_builder.GetBuilder().call<emscripten::val>("meanVarianceNormalization", input, options);
} else if (op_type == "GroupNormalization") {
ORT_RETURN_IF_NOT(helper.HasAttr("num_groups"), "GroupNormalization num_group must be provided.");
int32_t group_count = helper.Get("num_groups", -1);
std::vector<uint32_t> orig_shape, new_shape;
std::transform(input_shape.cbegin(), input_shape.cend(),
std::back_inserter(orig_shape),
[](int64_t dim) -> uint32_t { return SafeInt<uint32_t>(dim); });
// Add N and Group.
ORT_RETURN_IF_NOT(rank >= 2, "Input for GroupNormalization cannot be a scalar or 1D");
new_shape.emplace_back(SafeInt<uint32_t>(input_shape[0]));
new_shape.emplace_back(SafeInt<uint32_t>(group_count));

ORT_RETURN_IF_NOT(group_count > 0 && input_shape[1] % group_count == 0,
"GroupNormalization num_group must be divisible by group.");
new_shape.emplace_back(SafeInt<uint32_t>(std::reduce(input_shape.begin() + 2, input_shape.end(),
input_shape[1] / group_count, std::multiplies<int64_t>())));
// Input will be reshaped to (N, group count, channels per group x D1 x D2 ... Dn) and recovered after normalization.
options.set("axes", emscripten::val::array(std::vector<uint32_t>{2}));
output = model_builder.GetBuilder().call<emscripten::val>("reshape", input, emscripten::val::array(new_shape));
output = model_builder.GetBuilder().call<emscripten::val>("meanVarianceNormalization", output, options);
output = model_builder.GetBuilder().call<emscripten::val>("reshape", output, emscripten::val::array(orig_shape));
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported normalization op: ", op_type);
}
Expand Down Expand Up @@ -214,7 +182,6 @@ void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrat
constexpr static std::string_view op_types[] =
{
"BatchNormalization",
"GroupNormalization",
"InstanceNormalization",
"LayerNormalization",
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,6 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {

{ // Normalization
CreateNormalizationOpBuilder("BatchNormalization", op_registrations);
CreateNormalizationOpBuilder("GroupNormalization", op_registrations);
CreateNormalizationOpBuilder("InstanceNormalization", op_registrations);
CreateNormalizationOpBuilder("LayerNormalization", op_registrations);
}
Expand Down

0 comments on commit 88c0cc4

Please sign in to comment.