[WebNN EP] Update WebNN normalization ops #18817

Merged: 6 commits, Jan 9, 2024
7 changes: 3 additions & 4 deletions onnxruntime/core/providers/webnn/builders/helper.h
@@ -139,7 +139,7 @@ static const InlinedHashMap<std::string, WebnnOpInfo> op_map = {
{"ArgMax", {"argMax", false}},
{"ArgMin", {"argMin", false}},
{"AveragePool", {"averagePool2d", true}},
{"BatchNormalization", {"meanVarianceNormalization", false}},
{"BatchNormalization", {"batchNormalization", false}},
{"Cast", {"cast", false}},
{"Ceil", {"ceil", true}},
{"Clip", {"clamp", true}},
@@ -162,12 +162,11 @@ static const InlinedHashMap<std::string, WebnnOpInfo> op_map = {
{"GlobalLpPool", {"l2Pool2d", false}},
{"Greater", {"greater", false}},
{"GreaterOrEqual", {"greaterOrEqual", false}},
{"GroupNormalization", {"meanVarianceNormalization", false}},
{"HardSigmoid", {"hardSigmoid", false}},
{"HardSwish", {"hardSwish", true}},
{"Identity", {"identity", false}},
{"InstanceNormalization", {"meanVarianceNormalization", false}},
{"LayerNormalization", {"meanVarianceNormalization", false}},
{"InstanceNormalization", {"instanceNormalization", false}},
{"LayerNormalization", {"layerNormalization", false}},
{"LeakyRelu", {"leakyRelu", true}},
{"Less", {"lesser", false}},
{"LessOrEqual", {"lesserOrEqual", false}},
@@ -27,8 +27,6 @@ class NormalizationOpBuilder : public BaseOpBuilder {
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
};

// All normalization are based on layout NCHW.
// TODO: add support for NHWC.
Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
const Node& node,
const logging::Logger& logger) const {
@@ -61,49 +59,13 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
ORT_RETURN_IF_NOT(bias_shape == scale_shape, "The bias' shape should be equal to scale's shape.");
}

std::vector<uint32_t> new_scale_shape;
if (scale_size < rank) {
if (op_type == "BatchNormalization") {
scale_shape.insert(scale_shape.begin(), 1);
scale_shape.insert(scale_shape.end(), rank - 2, 1);
} else if (op_type == "LayerNormalization") {
// Align right with leading ones.
scale_shape.insert(scale_shape.begin(), rank - scale_size, 1);
} else if (op_type == "InstanceNormalization") {
// Insert ones before and after the channel dimension.
scale_shape.insert(scale_shape.begin(), 1);
ORT_RETURN_IF(scale_size != 1 || rank < 2,
"The scale size should be 1 and rank should be at least 2 for InstanceNorm.");
scale_shape.insert(scale_shape.end(), rank - scale_size - 1, 1);
} else if (op_type == "GroupNormalization") {
// The input will be reshaped to 3D later. So just insert ones before the channel and after.
scale_shape.insert(scale_shape.begin(), 1);
scale_shape.insert(scale_shape.end(), 1);
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported normalization op: ", op_type);
}
emscripten::val scale = model_builder.GetOperand(input_defs[1]->Name());
options.set("scale", scale);

std::transform(scale_shape.cbegin(), scale_shape.cend(),
std::back_inserter(new_scale_shape),
[](int64_t dim) -> uint32_t { return SafeInt<uint32_t>(dim); });
emscripten::val reshape_scale = model_builder.GetOperand(input_defs[1]->Name());
emscripten::val reshape_output_scale =
model_builder.GetBuilder().call<emscripten::val>("reshape", reshape_scale, emscripten::val::array(new_scale_shape));
options.set("scale", reshape_output_scale);

if (input_defs.size() >= 3 && !input_defs[2]->Name().empty()) {
// Bias input exists, and bias's shape is the same as scale's shape.
emscripten::val reshape_bias = model_builder.GetOperand(input_defs[2]->Name());
emscripten::val reshape_output_bias =
model_builder.GetBuilder().call<emscripten::val>("reshape", reshape_bias, emscripten::val::array(new_scale_shape));
options.set("bias", reshape_output_bias);
}
} else {
options.set("scale", model_builder.GetOperand(input_defs[1]->Name()));
if (input_defs.size() >= 3 && !input_defs[2]->Name().empty()) {
// Bias input exists, and bias's shape is the same as scale's shape.
options.set("bias", model_builder.GetOperand(input_defs[2]->Name()));
}
if (input_defs.size() >= 3 && !input_defs[2]->Name().empty()) {
// Bias input exists, and bias's shape is the same as scale's shape.
emscripten::val bias = model_builder.GetOperand(input_defs[2]->Name());
options.set("bias", bias);
}

NodeAttrHelper helper(node);
@@ -114,56 +76,62 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
ORT_RETURN_IF_NOT(input_defs.size() == 5, "BatchNormalization requires five inputs.");
emscripten::val mean = model_builder.GetOperand(input_defs[3]->Name());
emscripten::val variance = model_builder.GetOperand(input_defs[4]->Name());
// Enlarge 1-D mean and variance to new scale shape.
emscripten::val reshape_mean =
model_builder.GetBuilder().call<emscripten::val>("reshape", mean, emscripten::val::array(new_scale_shape));
emscripten::val reshape_variance =
model_builder.GetBuilder().call<emscripten::val>("reshape", variance, emscripten::val::array(new_scale_shape));

std::vector<uint32_t> axes = {0};
for (uint32_t i = 2; i < rank; i++) {
axes.push_back(i);
if (model_builder.GetPreferredLayout() == DataLayout::NHWC) {
options.set("axis", rank - 1);
}

options.set("axes", emscripten::val::array(axes));
options.set("mean", reshape_mean);
options.set("variance", reshape_variance);
output = model_builder.GetBuilder().call<emscripten::val>("meanVarianceNormalization", input, options);
output = model_builder.GetBuilder().call<emscripten::val>("batchNormalization", input, mean, variance, options);
} else if (op_type == "LayerNormalization") {
int64_t axis = helper.Get("axis", -1);
axis = HandleNegativeAxis(axis, rank);
std::vector<uint32_t> axes(rank - SafeInt<uint32_t>(axis));
std::iota(axes.begin(), axes.end(), axis);
if (model_builder.GetPreferredLayout() == DataLayout::NHWC && axis > 1) {
std::iota(axes.begin(), axes.end(), axis - 1);
} else {
std::iota(axes.begin(), axes.end(), axis);
}
options.set("axes", emscripten::val::array(axes));
output = model_builder.GetBuilder().call<emscripten::val>("meanVarianceNormalization", input, options);
output = model_builder.GetBuilder().call<emscripten::val>("layerNormalization", input, options);
} else if (op_type == "InstanceNormalization") {
std::vector<uint32_t> axes;
for (uint32_t i = 2; i < rank; i++) {
axes.emplace_back(i);
// WebNN spec only supports 4D input for instanceNormalization.
// Supports 3D input by prepending 1 size dimension.
// For models with dimensions greater than 4, they will be reshaped into 4D.
constexpr size_t webnn_shape_rank = 4;
if (input_shape.size() != webnn_shape_rank) {
std::vector<uint32_t> new_shape;
new_shape.reserve(std::max(input_shape.size(), webnn_shape_rank));
std::transform(input_shape.begin(), input_shape.end(),
std::back_inserter(new_shape),
[](int64_t dim) -> uint32_t { return SafeInt<uint32_t>(dim); });

size_t insertion_offset = (model_builder.GetPreferredLayout() == DataLayout::NHWC) ? 2 : 3;
ptrdiff_t excess_rank = new_shape.size() - webnn_shape_rank;
auto insertion_point = new_shape.begin() + insertion_offset;
if (input_shape.size() < webnn_shape_rank) {
// Pad the shape with extra 1's to satisfy WebNN v1's rank requirements.
new_shape.insert(insertion_point, -excess_rank, 1);
} else {
// Fold the extra range to fit within WebNN v1's rank requirements.
uint32_t sum = std::accumulate(
insertion_point, insertion_point + excess_rank + 1, 1, std::multiplies<uint32_t>());
new_shape.erase(insertion_point, insertion_point + excess_rank);
*insertion_point = sum;
}
input = model_builder.GetBuilder().call<emscripten::val>("reshape", input, emscripten::val::array(new_shape));
}

if (model_builder.GetPreferredLayout() == DataLayout::NHWC) {
options.set("layout", emscripten::val("nhwc"));
}
output = model_builder.GetBuilder().call<emscripten::val>("instanceNormalization", input, options);
// Reshape back to the original output shape for 3D input.
if (input_shape.size() != 4) {
std::vector<uint32_t> output_shape;
std::transform(input_shape.begin(), input_shape.end(),
std::back_inserter(output_shape),
[](int64_t dim) -> uint32_t { return SafeInt<uint32_t>(dim); });
output = model_builder.GetBuilder().call<emscripten::val>(
"reshape", output, emscripten::val::array(output_shape));
}
options.set("axes", emscripten::val::array(axes));
output = model_builder.GetBuilder().call<emscripten::val>("meanVarianceNormalization", input, options);
} else if (op_type == "GroupNormalization") {
ORT_RETURN_IF_NOT(helper.HasAttr("num_groups"), "GroupNormalization num_group must be provided.");
int32_t group_count = helper.Get("num_groups", -1);
std::vector<uint32_t> orig_shape, new_shape;
std::transform(input_shape.cbegin(), input_shape.cend(),
std::back_inserter(orig_shape),
[](int64_t dim) -> uint32_t { return SafeInt<uint32_t>(dim); });
// Add N and Group.
ORT_RETURN_IF_NOT(rank >= 2, "Input for GroupNormalization cannot be a scalar or 1D");
new_shape.emplace_back(SafeInt<uint32_t>(input_shape[0]));
new_shape.emplace_back(SafeInt<uint32_t>(group_count));

ORT_RETURN_IF_NOT(group_count > 0 && input_shape[1] % group_count == 0,
"GroupNormalization num_group must be divisible by group.");
new_shape.emplace_back(SafeInt<uint32_t>(std::reduce(input_shape.begin() + 2, input_shape.end(),
input_shape[1] / group_count, std::multiplies<int64_t>())));
// Input will be reshaped to (N, group count, channels per group x D1 x D2 ... Dn) and recovered after normalization.
options.set("axes", emscripten::val::array(std::vector<uint32_t>{2}));
output = model_builder.GetBuilder().call<emscripten::val>("reshape", input, emscripten::val::array(new_shape));
output = model_builder.GetBuilder().call<emscripten::val>("meanVarianceNormalization", output, options);
output = model_builder.GetBuilder().call<emscripten::val>("reshape", output, emscripten::val::array(orig_shape));
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported normalization op: ", op_type);
}
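
For readers following the new axis handling in this hunk: BatchNormalization only overrides the axis option when the preferred layout is NHWC (setting it to rank - 1, where the channel dimension lives), and LayerNormalization starts its normalized axes one position earlier under NHWC when axis > 1. A minimal standalone sketch of that arithmetic follows; BatchNormAxisOverride and LayerNormAxes are hypothetical helpers standing in for the inline code (and for HandleNegativeAxis/SafeInt), not functions from the ONNX Runtime codebase.

#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// BatchNormalization under NHWC: channels are the last dimension.
uint32_t BatchNormAxisOverride(uint32_t rank) { return rank - 1; }

// LayerNormalization: normalize the trailing dims starting at `axis`,
// shifted left by one when the layout is NHWC and axis > 1 (as in the diff).
std::vector<uint32_t> LayerNormAxes(int64_t axis, uint32_t rank, bool nhwc) {
  if (axis < 0) axis += rank;  // Equivalent of HandleNegativeAxis.
  std::vector<uint32_t> axes(rank - static_cast<uint32_t>(axis));
  const uint32_t start = (nhwc && axis > 1) ? static_cast<uint32_t>(axis) - 1
                                            : static_cast<uint32_t>(axis);
  std::iota(axes.begin(), axes.end(), start);
  return axes;
}

int main() {
  std::cout << BatchNormAxisOverride(4) << "\n";                                   // 3
  for (uint32_t a : LayerNormAxes(-1, 3, /*nhwc=*/false)) std::cout << a << " ";   // 2
  std::cout << "\n";
  for (uint32_t a : LayerNormAxes(2, 4, /*nhwc=*/true)) std::cout << a << " ";     // 1 2
  std::cout << "\n";
  return 0;
}
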
@@ -214,7 +182,6 @@ void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrat
constexpr static std::string_view op_types[] =
{
"BatchNormalization",
"GroupNormalization",
"InstanceNormalization",
"LayerNormalization",
};
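
One more note on the instanceNormalization path in this file: WebNN only accepts 4-D input for instanceNormalization, so the builder now pads lower-rank shapes with 1s and folds the excess trailing dimensions of higher-rank shapes into a single dimension (at an insertion point that depends on the preferred layout), then reshapes the output back to the original shape when the input was not 4-D. The sketch below reproduces only that shape arithmetic; CoerceToRank4 is a hypothetical standalone helper, whereas the real code performs these reshapes on emscripten::val operands through the WebNN graph builder.

#include <cstddef>
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Pad or fold a shape so it has exactly four dimensions, as the diff does
// before calling instanceNormalization.
std::vector<uint32_t> CoerceToRank4(std::vector<uint32_t> shape, bool nhwc) {
  constexpr size_t kWebnnRank = 4;
  if (shape.size() == kWebnnRank) return shape;

  const size_t insertion_offset = nhwc ? 2 : 3;
  const ptrdiff_t excess_rank =
      static_cast<ptrdiff_t>(shape.size()) - static_cast<ptrdiff_t>(kWebnnRank);
  const auto insertion_point = shape.begin() + insertion_offset;

  if (shape.size() < kWebnnRank) {
    // Pad with 1s to reach rank 4.
    shape.insert(insertion_point, -excess_rank, 1u);
  } else {
    // Multiply the excess trailing dims into one dimension at the insertion point.
    const uint32_t folded = std::accumulate(
        insertion_point, insertion_point + excess_rank + 1, 1u, std::multiplies<uint32_t>());
    shape.erase(insertion_point, insertion_point + excess_rank);
    *(shape.begin() + insertion_offset) = folded;
  }
  return shape;
}

int main() {
  for (uint32_t d : CoerceToRank4({2, 8, 16}, /*nhwc=*/false)) std::cout << d << " ";       // 2 8 16 1
  std::cout << "\n";
  for (uint32_t d : CoerceToRank4({1, 3, 4, 5, 6}, /*nhwc=*/false)) std::cout << d << " ";  // 1 3 4 30
  std::cout << "\n";
  return 0;
}
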
@@ -111,7 +111,6 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {

{ // Normalization
CreateNormalizationOpBuilder("BatchNormalization", op_registrations);
CreateNormalizationOpBuilder("GroupNormalization", op_registrations);
CreateNormalizationOpBuilder("InstanceNormalization", op_registrations);
CreateNormalizationOpBuilder("LayerNormalization", op_registrations);
}