[WebNN EP] Update argMax/argMin to adapt to latest spec
The WebNN spec recently changed the definition of argMax/argMin:
- Remove the selectLastIndex option; let backends decide which index is selected.
- Move the axes option to an axis input (a sketch of the new call shape follows below).
Honry committed Jul 25, 2024
1 parent ae3ec2e commit 1b42e12
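For reference, a minimal sketch of what the updated call looks like through emscripten bindings, in the same style as the EP code in the diff below. It assumes `builder` (an MLGraphBuilder) and `input` (an MLOperand) are already available as emscripten::val handles; the helper name `BuildArgMax` is illustrative and not part of the commit.

#include <emscripten/val.h>

#include <cstdint>

// Illustrative helper (not from this commit): builds an argMax with the new
// WebNN signature, where the reduction axis is a positional argument and the
// options dictionary no longer carries `axes` or `selectLastIndex`.
emscripten::val BuildArgMax(emscripten::val builder, emscripten::val input,
                            uint32_t axis, bool keep_dims) {
  emscripten::val options = emscripten::val::object();
  options.set("keepDimensions", keep_dims);
  options.set("outputDataType", "int64");
  return builder.call<emscripten::val>("argMax", input, axis, options);
}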
Showing 2 changed files with 5 additions and 22 deletions.
4 changes: 2 additions & 2 deletions js/web/docs/webnn-operators.md
@@ -13,8 +13,8 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtime
 |:------:|:------:|:------:|:-:|:-:|:------|
 | Abs | ai.onnx(7-12, 13+) | abs ||| |
 | Add | ai.onnx(7-12, 13, 14+) | add ||| |
-| ArgMax | ai.onnx(7-10, 11, 12, 13+) | argMax ||| WebNN CPU backend only supports 'select_last_index' value is 0 |
-| ArgMin | ai.onnx(7-10, 11, 12, 13+) | argMin ||| WebNN CPU backend only supports 'select_last_index' value is 0 |
+| ArgMax | ai.onnx(7-10, 11, 12, 13+) | argMax ||| |
+| ArgMin | ai.onnx(7-10, 11, 12, 13+) | argMin ||| |
 | AveragePool | ai.onnx(7-9, 10, 11, 12-18, 19+) | averagePool2d ||| Only supports 4-D input, 2-D 'kernel_shape', 'count_include_pad' value is 0 |
 | BatchNormalization | ai.onnx(7-8, 9-13, 14, 15+) | batchNormalization ||| Only supports 'training_mode' value is 0, one output |
 | Cast | ai.onnx(7-8, 9-12, 13-18, 19-20, 21+) | cast ||| WebNN CPU backend doesn't support casting to uint64 data type |
@@ -40,28 +40,20 @@ Status ArgMaxMinOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
   NodeAttrHelper helper(node);
   int64_t axis = helper.Get("axis", 0);
   const auto keep_dims = helper.Get("keepdims", 1);
-  const auto select_last_index = helper.Get("select_last_index", 0);
 
   axis = HandleNegativeAxis(axis, input_rank);
-  emscripten::val axes = emscripten::val::array();
-  axes.call<void>("push", static_cast<uint32_t>(axis));
 
   emscripten::val options = emscripten::val::object();
-  options.set("axes", axes);
   options.set("keepDimensions", keep_dims == 1);
-  options.set("selectLastIndex", select_last_index == 1);
-  // TODO: use WebNN's opSupportLimits API to check the backend's supported output data types.
-  // If the backend doesn't support int64 output, we should use default int32 output data type
-  // then do a type casting (int32 -> int64) for the output. Refer to the CoreML EP for how to
-  // support int64 output.
+  // TODO(Honry): check whether int64 output data type is supported by WebNN opSupportLimits() API.
   options.set("outputDataType", "int64");
   emscripten::val output = emscripten::val::object();
 
   const auto& op_type = node.OpType();
   if (op_type == "ArgMax") {
-    output = model_builder.GetBuilder().call<emscripten::val>("argMax", input, options);
+    output = model_builder.GetBuilder().call<emscripten::val>("argMax", input, narrow<uint32_t>(axis), options);
   } else if (op_type == "ArgMin") {
-    output = model_builder.GetBuilder().call<emscripten::val>("argMin", input, options);
+    output = model_builder.GetBuilder().call<emscripten::val>("argMin", input, narrow<uint32_t>(axis), options);
   } else {
     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "ArgMaxMinOpBuilder, unknown op: ", op_type);
   }
@@ -81,15 +73,6 @@ bool ArgMaxMinOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initia
   if (!GetShape(*input_defs[0], input_shape, logger))
     return false;
 
-  // WebNN CPU backend only supports select_last_index = 0.
-  if (device_type == WebnnDeviceType::CPU) {
-    NodeAttrHelper helper(node);
-    const auto select_last_index = helper.Get("select_last_index", 0);
-    if (select_last_index) {
-      LOGS(logger, VERBOSE) << "ArgMax/ArgMin with select_last_index = 1 is not supported on WebNN CPU backend.";
-      return false;
-    }
-  }
   return true;
 }

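For context on the TODO added in the first hunk above: a hypothetical sketch, not part of this commit, of how an int64-output check could look once the WebNN opSupportLimits() API is consulted. The `webnn_context` handle and the `argMax.output.dataTypes` lookup follow the WebNN MLContext/MLOpSupportLimits definitions; the helper name is illustrative.

#include <emscripten/val.h>

#include <string>

// Hypothetical helper: returns true if the backend reports int64 as a
// supported output data type for argMax via context.opSupportLimits().
bool ArgMaxSupportsInt64Output(const emscripten::val& webnn_context) {
  emscripten::val limits = webnn_context.call<emscripten::val>("opSupportLimits");
  emscripten::val data_types = limits["argMax"]["output"]["dataTypes"];
  const int length = data_types["length"].as<int>();
  for (int i = 0; i < length; ++i) {
    if (data_types[i].as<std::string>() == "int64") {
      return true;
    }
  }
  // Otherwise the EP would fall back to int32 output and cast to int64.
  return false;
}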
