Skip to content

Commit

Permalink
Support 3D input for instanceNormalization
Browse files Browse the repository at this point in the history
  • Loading branch information
zesongw committed Dec 26, 2023
1 parent d89cecc commit d0e0d3a
Showing 1 changed file with 26 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,34 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder
options.set("axes", emscripten::val::array(axes));
output = model_builder.GetBuilder().call<emscripten::val>("layerNormalization", input, options);
} else if (op_type == "InstanceNormalization") {
// WebNN spec only supports 4D input for instanceNormalization.
// Supports 3D input by prepanding 1 size dimension.
if (input_shape.size() == 3) {
std::vector<uint32_t> new_shape;
std::transform(input_shape.begin(), input_shape.end(),
std::back_inserter(new_shape),
[](int64_t dim) -> uint32_t { return SafeInt<uint32_t>(dim); });
if (model_builder.GetPreferredLayout() == DataLayout::NHWC) {
new_shape.insert(new_shape.begin() + 2, 1);
} else {
new_shape.emplace_back(1);
}
input = model_builder.GetBuilder().call<emscripten::val>("reshape", input, emscripten::val::array(new_shape));
}

if (model_builder.GetPreferredLayout() == DataLayout::NHWC) {
options.set("layout", emscripten::val("nhwc"));
}
output = model_builder.GetBuilder().call<emscripten::val>("instanceNormalization", input, options);
// Reshape back to the original output shape for 3D input.
if (input_shape.size() == 3) {
std::vector<uint32_t> output_shape;
std::transform(input_shape.begin(), input_shape.end(),
std::back_inserter(output_shape),
[](int64_t dim) -> uint32_t { return SafeInt<uint32_t>(dim); });
output = model_builder.GetBuilder().call<emscripten::val>(
"reshape", output, emscripten::val::array(output_shape));
}
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported normalization op: ", op_type);
}
Expand Down Expand Up @@ -143,8 +167,8 @@ bool NormalizationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initi
return false;
}
const auto rank = input_shape.size();
if (rank != 4) {
LOGS(logger, VERBOSE) << "InstanceNormalization only supports 4D input.";
if (rank > 4) {
LOGS(logger, VERBOSE) << "InstanceNormalization only supports up to 4D input.";
return false;
}
}
Expand Down

0 comments on commit d0e0d3a

Please sign in to comment.