From 1b42e1266a229e46f11f09b941dd44aeddf37f4a Mon Sep 17 00:00:00 2001 From: Wanming Lin Date: Fri, 19 Jul 2024 11:06:53 +0800 Subject: [PATCH] [WebNN EP] Update argMax/argMin to adapt to latest spec WebNN spec recently changes the definition of argMax/argMin: - Remove selectLastIndex option, let backends decide the selected the index. - Move axes option to axis input --- js/web/docs/webnn-operators.md | 4 ++-- .../builders/impl/argmax_min_op_builder.cc | 23 +++---------------- 2 files changed, 5 insertions(+), 22 deletions(-) diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md index 8d077846fa6a4..75652899b5e5e 100644 --- a/js/web/docs/webnn-operators.md +++ b/js/web/docs/webnn-operators.md @@ -13,8 +13,8 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim |:------:|:------:|:------:|:-:|:-:|:------| | Abs | ai.onnx(7-12, 13+) | abs | ✓ | ✓ | | | Add | ai.onnx(7-12, 13, 14+) | add | ✓ | ✓ | | -| ArgMax | ai.onnx(7-10, 11, 12, 13+) | argMax | ✓ | ✓ | WebNN CPU backend only supports 'select_last_index' value is 0 | -| ArgMin | ai.onnx(7-10, 11, 12, 13+) | argMin | ✓ | ✓ | WebNN CPU backend only supports 'select_last_index' value is 0 | +| ArgMax | ai.onnx(7-10, 11, 12, 13+) | argMax | ✓ | ✓ | | +| ArgMin | ai.onnx(7-10, 11, 12, 13+) | argMin | ✓ | ✓ | | | AveragePool | ai.onnx(7-9, 10, 11, 12-18, 19+) | averagePool2d | ✓ | ✓ | Only supports 4-D input, 2-D 'kernel_shape', 'count_include_pad' value is 0 | | BatchNormalization | ai.onnx(7-8, 9-13, 14, 15+) | batchNormalization | ✓ | ✓ | Only supports 'training_mode' value is 0, one output | | Cast | ai.onnx(7-8, 9-12, 13-18, 19-20, 21+) | cast | ✓ | ✓ | WebNN CPU backend doesn't support casting to uint64 data type | diff --git a/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc index 1330a3e354871..1ae63a644a287 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/argmax_min_op_builder.cc @@ -40,28 +40,20 @@ Status ArgMaxMinOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, NodeAttrHelper helper(node); int64_t axis = helper.Get("axis", 0); const auto keep_dims = helper.Get("keepdims", 1); - const auto select_last_index = helper.Get("select_last_index", 0); axis = HandleNegativeAxis(axis, input_rank); - emscripten::val axes = emscripten::val::array(); - axes.call("push", static_cast(axis)); emscripten::val options = emscripten::val::object(); - options.set("axes", axes); options.set("keepDimensions", keep_dims == 1); - options.set("selectLastIndex", select_last_index == 1); - // TODO: use WebNN's opSupportLimits API to check the backend's supported output data types. - // If the backend doesn't support int64 output, we should use default int32 output data type - // then do a type casting (int32 -> int64) for the output. Refer to the CoreML EP for how to - // support int64 output. + // TODO(Honry): check whether int64 output data type is supported by WebNN opSupportLimits() API. options.set("outputDataType", "int64"); emscripten::val output = emscripten::val::object(); const auto& op_type = node.OpType(); if (op_type == "ArgMax") { - output = model_builder.GetBuilder().call("argMax", input, options); + output = model_builder.GetBuilder().call("argMax", input, narrow(axis), options); } else if (op_type == "ArgMin") { - output = model_builder.GetBuilder().call("argMin", input, options); + output = model_builder.GetBuilder().call("argMin", input, narrow(axis), options); } else { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "ArgMaxMinOpBuilder, unknown op: ", op_type); } @@ -81,15 +73,6 @@ bool ArgMaxMinOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initia if (!GetShape(*input_defs[0], input_shape, logger)) return false; - // WebNN CPU backend only supports select_last_index = 0. - if (device_type == WebnnDeviceType::CPU) { - NodeAttrHelper helper(node); - const auto select_last_index = helper.Get("select_last_index", 0); - if (select_last_index) { - LOGS(logger, VERBOSE) << "ArgMax/ArgMin with select_last_index = 1 is not supported on WebNN CPU backend."; - return false; - } - } return true; }