From d0e0d3a905e7708fc5960e19e5e839eedd2536c3 Mon Sep 17 00:00:00 2001 From: zesongw Date: Tue, 26 Dec 2023 11:24:54 +0800 Subject: [PATCH] Support 3D input for instanceNormalization --- .../builders/impl/normalization_op_builder.cc | 28 +++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc index 9e0dc94d777d2..763c9011f2659 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc @@ -92,10 +92,34 @@ Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder options.set("axes", emscripten::val::array(axes)); output = model_builder.GetBuilder().call("layerNormalization", input, options); } else if (op_type == "InstanceNormalization") { + // WebNN spec only supports 4D input for instanceNormalization. + // Supports 3D input by prepanding 1 size dimension. + if (input_shape.size() == 3) { + std::vector new_shape; + std::transform(input_shape.begin(), input_shape.end(), + std::back_inserter(new_shape), + [](int64_t dim) -> uint32_t { return SafeInt(dim); }); + if (model_builder.GetPreferredLayout() == DataLayout::NHWC) { + new_shape.insert(new_shape.begin() + 2, 1); + } else { + new_shape.emplace_back(1); + } + input = model_builder.GetBuilder().call("reshape", input, emscripten::val::array(new_shape)); + } + if (model_builder.GetPreferredLayout() == DataLayout::NHWC) { options.set("layout", emscripten::val("nhwc")); } output = model_builder.GetBuilder().call("instanceNormalization", input, options); + // Reshape back to the original output shape for 3D input. + if (input_shape.size() == 3) { + std::vector output_shape; + std::transform(input_shape.begin(), input_shape.end(), + std::back_inserter(output_shape), + [](int64_t dim) -> uint32_t { return SafeInt(dim); }); + output = model_builder.GetBuilder().call( + "reshape", output, emscripten::val::array(output_shape)); + } } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Unsupported normalization op: ", op_type); } @@ -143,8 +167,8 @@ bool NormalizationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initi return false; } const auto rank = input_shape.size(); - if (rank != 4) { - LOGS(logger, VERBOSE) << "InstanceNormalization only supports 4D input."; + if (rank > 4) { + LOGS(logger, VERBOSE) << "InstanceNormalization only supports up to 4D input."; return false; } }