[WebNN EP] Update WebNN normalization ops #18817

Merged · 6 commits · Jan 9, 2024
Changes shown below are from 2 of the 6 commits.
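In brief, this PR moves the WebNN EP off the old emulation path, where BatchNormalization, InstanceNormalization, LayerNormalization, and GroupNormalization were all lowered to WebNN's meanVarianceNormalization with hand-built axes and reshapes, and instead calls the dedicated WebNN ops batchNormalization, instanceNormalization, and layerNormalization; GroupNormalization support is dropped. A minimal sketch of the new BatchNormalization lowering, condensed from the diff below (the epsilon line paraphrases attribute-handling code elided from this view):

  // Condensed illustration only, not part of the diff; "input", "scale", "bias",
  // "mean", and "variance" are WebNN operands fetched via model_builder.GetOperand(...).
  emscripten::val options = emscripten::val::object();
  options.set("scale", scale);
  options.set("bias", bias);  // only when the optional bias input is present
  options.set("epsilon", helper.Get("epsilon", 1e-05f));  // from code elided in this view
  if (model_builder.GetPreferredLayout() == DataLayout::NHWC) {
    options.set("axis", rank - 1);  // channels-last input: the channel axis is the last one
  }
  emscripten::val output = model_builder.GetBuilder().call<emscripten::val>(
      "batchNormalization", input, mean, variance, options);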
7 changes: 3 additions & 4 deletions onnxruntime/core/providers/webnn/builders/helper.h
@@ -139,7 +139,7 @@ static const InlinedHashMap<std::string, WebnnOpInfo> op_map = {
{"ArgMax", {"argMax", false}},
{"ArgMin", {"argMin", false}},
{"AveragePool", {"averagePool2d", true}},
{"BatchNormalization", {"meanVarianceNormalization", false}},
{"BatchNormalization", {"batchNormalization", false}},
{"Cast", {"cast", false}},
{"Ceil", {"ceil", true}},
{"Clip", {"clamp", true}},
@@ -162,12 +162,11 @@ static const InlinedHashMap<std::string, WebnnOpInfo> op_map = {
{"GlobalLpPool", {"l2Pool2d", false}},
{"Greater", {"greater", false}},
{"GreaterOrEqual", {"greaterOrEqual", false}},
{"GroupNormalization", {"meanVarianceNormalization", false}},
{"HardSigmoid", {"hardSigmoid", false}},
{"HardSwish", {"hardSwish", true}},
{"Identity", {"identity", false}},
{"InstanceNormalization", {"meanVarianceNormalization", false}},
{"LayerNormalization", {"meanVarianceNormalization", false}},
{"InstanceNormalization", {"instanceNormalization", false}},
{"LayerNormalization", {"layerNormalization", false}},
{"LeakyRelu", {"leakyRelu", true}},
{"Less", {"lesser", false}},
{"LessOrEqual", {"lesserOrEqual", false}},
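For context on the table edited above: each op_map entry pairs an ONNX op type with the WebNN op name it lowers to, plus a boolean flag which, as I read the surrounding header (not shown in this diff), records whether the WebNN CPU backend also supports the op. A hedged sketch of how such a lookup is typically consulted; the field name opName and the calling code are assumptions, not part of this PR:

  // Hypothetical lookup sketch; variable and field names are illustrative.
  const auto it = op_map.find(node.OpType());
  if (it == op_map.end()) {
    return false;  // ONNX op type has no WebNN mapping, so the EP rejects the node
  }
  const std::string_view webnn_op_name = it->second.opName;  // e.g. "batchNormalization"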
@@ -27,8 +27,6 @@ class NormalizationOpBuilder : public BaseOpBuilder {
const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override;
};

// All normalization are based on layout NCHW.
// TODO: add support for NHWC.
Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
const Node& node,
const logging::Logger& logger) const {
@@ -61,49 +59,13 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
ORT_RETURN_IF_NOT(bias_shape == scale_shape, "The bias' shape should be equal to scale's shape.");
}

std::vector<uint32_t> new_scale_shape;
if (scale_size < rank) {
if (op_type == "BatchNormalization") {
scale_shape.insert(scale_shape.begin(), 1);
scale_shape.insert(scale_shape.end(), rank - 2, 1);
} else if (op_type == "LayerNormalization") {
// Align right with leading ones.
scale_shape.insert(scale_shape.begin(), rank - scale_size, 1);
} else if (op_type == "InstanceNormalization") {
// Insert ones before and after the channel dimension.
scale_shape.insert(scale_shape.begin(), 1);
ORT_RETURN_IF(scale_size != 1 || rank < 2,
"The scale size should be 1 and rank should be at least 2 for InstanceNorm.");
scale_shape.insert(scale_shape.end(), rank - scale_size - 1, 1);
} else if (op_type == "GroupNormalization") {
// The input will be reshaped to 3D later. So just insert ones before the channel and after.
scale_shape.insert(scale_shape.begin(), 1);
scale_shape.insert(scale_shape.end(), 1);
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported normalization op: ", op_type);
}
emscripten::val scale = model_builder.GetOperand(input_defs[1]->Name());
options.set("scale", scale);

std::transform(scale_shape.cbegin(), scale_shape.cend(),
std::back_inserter(new_scale_shape),
[](int64_t dim) -> uint32_t { return SafeInt<uint32_t>(dim); });
emscripten::val reshape_scale = model_builder.GetOperand(input_defs[1]->Name());
emscripten::val reshape_output_scale =
model_builder.GetBuilder().call<emscripten::val>("reshape", reshape_scale, emscripten::val::array(new_scale_shape));
options.set("scale", reshape_output_scale);

if (input_defs.size() >= 3 && !input_defs[2]->Name().empty()) {
// Bias input exists, and bias's shape is the same as scale's shape.
emscripten::val reshape_bias = model_builder.GetOperand(input_defs[2]->Name());
emscripten::val reshape_output_bias =
model_builder.GetBuilder().call<emscripten::val>("reshape", reshape_bias, emscripten::val::array(new_scale_shape));
options.set("bias", reshape_output_bias);
}
} else {
options.set("scale", model_builder.GetOperand(input_defs[1]->Name()));
if (input_defs.size() >= 3 && !input_defs[2]->Name().empty()) {
// Bias input exists, and bias's shape is the same as scale's shape.
options.set("bias", model_builder.GetOperand(input_defs[2]->Name()));
}
if (input_defs.size() >= 3 && !input_defs[2]->Name().empty()) {
// Bias input exists, and bias's shape is the same as scale's shape.
emscripten::val bias = model_builder.GetOperand(input_defs[2]->Name());
options.set("bias", bias);
}

NodeAttrHelper helper(node);
@@ -114,56 +76,26 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
ORT_RETURN_IF_NOT(input_defs.size() == 5, "BatchNormalization requires five inputs.");
emscripten::val mean = model_builder.GetOperand(input_defs[3]->Name());
emscripten::val variance = model_builder.GetOperand(input_defs[4]->Name());
// Enlarge 1-D mean and variance to new scale shape.
emscripten::val reshape_mean =
model_builder.GetBuilder().call<emscripten::val>("reshape", mean, emscripten::val::array(new_scale_shape));
emscripten::val reshape_variance =
model_builder.GetBuilder().call<emscripten::val>("reshape", variance, emscripten::val::array(new_scale_shape));

std::vector<uint32_t> axes = {0};
for (uint32_t i = 2; i < rank; i++) {
axes.push_back(i);
if (model_builder.GetPreferredLayout() == DataLayout::NHWC) {
options.set("axis", rank - 1);
}

options.set("axes", emscripten::val::array(axes));
options.set("mean", reshape_mean);
options.set("variance", reshape_variance);
output = model_builder.GetBuilder().call<emscripten::val>("meanVarianceNormalization", input, options);
output = model_builder.GetBuilder().call<emscripten::val>("batchNormalization", input, mean, variance, options);
} else if (op_type == "LayerNormalization") {
int64_t axis = helper.Get("axis", -1);
axis = HandleNegativeAxis(axis, rank);
std::vector<uint32_t> axes(rank - SafeInt<uint32_t>(axis));
std::iota(axes.begin(), axes.end(), axis);
if (model_builder.GetPreferredLayout() == DataLayout::NHWC && axis > 1) {
std::iota(axes.begin(), axes.end(), axis - 1);
} else {
std::iota(axes.begin(), axes.end(), axis);
}
options.set("axes", emscripten::val::array(axes));
output = model_builder.GetBuilder().call<emscripten::val>("meanVarianceNormalization", input, options);
output = model_builder.GetBuilder().call<emscripten::val>("layerNormalization", input, options);
} else if (op_type == "InstanceNormalization") {
std::vector<uint32_t> axes;
for (uint32_t i = 2; i < rank; i++) {
axes.emplace_back(i);
if (model_builder.GetPreferredLayout() == DataLayout::NHWC) {
options.set("layout", emscripten::val("nhwc"));
}
options.set("axes", emscripten::val::array(axes));
output = model_builder.GetBuilder().call<emscripten::val>("meanVarianceNormalization", input, options);
} else if (op_type == "GroupNormalization") {
ORT_RETURN_IF_NOT(helper.HasAttr("num_groups"), "GroupNormalization num_group must be provided.");
int32_t group_count = helper.Get("num_groups", -1);
std::vector<uint32_t> orig_shape, new_shape;
std::transform(input_shape.cbegin(), input_shape.cend(),
std::back_inserter(orig_shape),
[](int64_t dim) -> uint32_t { return SafeInt<uint32_t>(dim); });
// Add N and Group.
ORT_RETURN_IF_NOT(rank >= 2, "Input for GroupNormalization cannot be a scalar or 1D");
new_shape.emplace_back(SafeInt<uint32_t>(input_shape[0]));
new_shape.emplace_back(SafeInt<uint32_t>(group_count));

ORT_RETURN_IF_NOT(group_count > 0 && input_shape[1] % group_count == 0,
"GroupNormalization num_group must be divisible by group.");
new_shape.emplace_back(SafeInt<uint32_t>(std::reduce(input_shape.begin() + 2, input_shape.end(),
input_shape[1] / group_count, std::multiplies<int64_t>())));
// Input will be reshaped to (N, group count, channels per group x D1 x D2 ... Dn) and recovered after normalization.
options.set("axes", emscripten::val::array(std::vector<uint32_t>{2}));
output = model_builder.GetBuilder().call<emscripten::val>("reshape", input, emscripten::val::array(new_shape));
output = model_builder.GetBuilder().call<emscripten::val>("meanVarianceNormalization", output, options);
output = model_builder.GetBuilder().call<emscripten::val>("reshape", output, emscripten::val::array(orig_shape));
output = model_builder.GetBuilder().call<emscripten::val>("instanceNormalization", input, options);
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported normalization op: ", op_type);
}
@@ -204,6 +136,18 @@ bool NormalizationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initi
return false;
}

if (op_type == "InstanceNormalization") {
std::vector<int64_t> input_shape;
if (!GetShape(*input_defs[0], input_shape, logger)) {
LOGS(logger, VERBOSE) << "Cannot get input shape";
return false;
}
const auto rank = input_shape.size();
if (rank != 4) {
LOGS(logger, VERBOSE) << "InstanceNormalization only supports 4D input.";
return false;
}
}
return true;
}

@@ -214,7 +158,6 @@ void CreateNormalizationOpBuilder(const std::string& op_type, OpBuilderRegistrat
constexpr static std::string_view op_types[] =
{
"BatchNormalization",
"GroupNormalization",
"InstanceNormalization",
"LayerNormalization",
};
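Summarizing the builder changes above: each remaining op now lowers to a single dedicated WebNN call, with the preferred NHWC layout handled either by a layout option (InstanceNormalization) or by shifting the normalized axes (LayerNormalization). A condensed sketch of those two branches, paraphrased from the diff and not a complete builder:

  if (op_type == "InstanceNormalization") {
    if (model_builder.GetPreferredLayout() == DataLayout::NHWC) {
      options.set("layout", emscripten::val("nhwc"));
    }
    output = model_builder.GetBuilder().call<emscripten::val>("instanceNormalization", input, options);
  } else if (op_type == "LayerNormalization") {
    int64_t axis = HandleNegativeAxis(helper.Get("axis", -1), rank);
    std::vector<uint32_t> axes(rank - SafeInt<uint32_t>(axis));
    // With a channels-last input the normalized axes start one position earlier once axis > 1.
    std::iota(axes.begin(), axes.end(),
              (model_builder.GetPreferredLayout() == DataLayout::NHWC && axis > 1) ? axis - 1 : axis);
    options.set("axes", emscripten::val::array(axes));
    output = model_builder.GetBuilder().call<emscripten::val>("layerNormalization", input, options);
  }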
@@ -111,7 +111,6 @@ static OpBuilderRegistrations CreateOpBuilderRegistrations() {

{ // Normalization
CreateNormalizationOpBuilder("BatchNormalization", op_registrations);
CreateNormalizationOpBuilder("GroupNormalization", op_registrations);
CreateNormalizationOpBuilder("InstanceNormalization", op_registrations);
CreateNormalizationOpBuilder("LayerNormalization", op_registrations);
}