From cd60af03659e65fca05fbcf66e2444c1aaa036b8 Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Fri, 25 Oct 2024 00:35:53 +0800 Subject: [PATCH] [WebNN EP] Allow 0D input/output for Reshape and Expand (#22344) - Allows Expand input be a scalar - Allows Reshape input be a scalar - Allows Reshape to a scalar Fixed #22215 --------- Co-authored-by: Dwayne Robinson --- .../webnn/builders/impl/expand_op_builder.cc | 5 -- .../webnn/builders/impl/reshape_op_builder.cc | 52 +++++++++---------- 2 files changed, 24 insertions(+), 33 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc index c8cea833983b1..5e99551fe6e7d 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc @@ -95,11 +95,6 @@ bool ExpandOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers return false; } - if (input_shape.empty()) { - LOGS(logger, VERBOSE) << "Expand does not support empty input's shape."; - return false; - } - std::vector output_shape; if (!GetBidirectionalBroadcastShape(input_shape, new_shape, output_shape)) { LOGS(logger, VERBOSE) << "The input cannot expand to shape " << GetShapeString(new_shape); diff --git a/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc index a7911683f0355..0a438e98ad737 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc @@ -44,21 +44,25 @@ Status ReshapeOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const auto& input_defs = node.InputDefs(); const auto& initializers(model_builder.GetInitializerTensors()); const auto& target_shape_tensor = *initializers.at(input_defs[1]->Name()); - const int64_t* raw_target_shape = target_shape_tensor.int64_data().empty() - ? reinterpret_cast(target_shape_tensor.raw_data().data()) - : target_shape_tensor.int64_data().data(); + const auto& target_shape_tensor_dims = target_shape_tensor.dims(); + std::vector new_shape; + // Do nothing if target shape is an empty shape, which means converting to a scalar. + if (!target_shape_tensor_dims.empty()) { + const int64_t* raw_target_shape = target_shape_tensor.int64_data().empty() + ? reinterpret_cast(target_shape_tensor.raw_data().data()) + : target_shape_tensor.int64_data().data(); + + const auto size = target_shape_tensor_dims[0]; + TensorShapeVector target_shape{raw_target_shape, raw_target_shape + size}; + std::vector input_shape; + ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); + ReshapeHelper helper(TensorShape(input_shape), target_shape); + std::transform(target_shape.cbegin(), target_shape.cend(), + std::back_inserter(new_shape), + [](int64_t dim) -> uint32_t { return SafeInt(dim); }); + } - const auto size = target_shape_tensor.dims()[0]; - TensorShapeVector target_shape{raw_target_shape, raw_target_shape + size}; - std::vector input_shape; - ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get shape"); - ReshapeHelper helper(TensorShape(input_shape), target_shape); emscripten::val input = model_builder.GetOperand(input_defs[0]->Name()); - std::vector new_shape; - std::transform(target_shape.cbegin(), target_shape.cend(), - std::back_inserter(new_shape), - [](int64_t dim) -> uint32_t { return SafeInt(dim); }); - emscripten::val options = emscripten::val::object(); options.set("label", node.Name()); emscripten::val output = model_builder.GetBuilder().call("reshape", @@ -76,6 +80,11 @@ bool ReshapeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializer const WebnnDeviceType /* device_type */, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); + + std::vector input_shape; + if (!GetShape(*input_defs[0], input_shape, logger)) + return false; + const auto& perm_name = input_defs[1]->Name(); if (!Contains(initializers, perm_name)) { LOGS(logger, VERBOSE) << "New shape of reshape must be a constant initializer"; @@ -92,24 +101,11 @@ bool ReshapeOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializer const int64_t* raw_new_shape = reinterpret_cast(unpacked_tensor.data()); const auto& perm_dims = perm_tensor.dims(); - if (perm_dims.empty() || perm_dims[0] == 0) { - LOGS(logger, VERBOSE) << "New shape of reshape cannot be empty"; - return false; - } - - std::vector input_shape; - if (!GetShape(*input_defs[0], input_shape, logger)) - return false; - - if (input_shape.empty()) { - LOGS(logger, VERBOSE) << "Reshape does not support empty input shape"; - return false; - } // WebNN reshape does not support 0 as dimension. NodeAttrHelper helper(node); - const bool allow_zero = helper.Get("allowzero ", 0) == 1; - if (allow_zero) { + const bool allow_zero = helper.Get("allowzero", 0) == 1; + if (allow_zero && !perm_dims.empty()) { for (int64_t i = 0; i < perm_dims[0]; i++) { if (raw_new_shape[i] == 0) { LOGS_DEFAULT(VERBOSE) << "Reshape doesn't support 0 reshape dimension when allowzero is enabled";