From 90090b01af6a97ac2bf1aa4af9161659aa82d712 Mon Sep 17 00:00:00 2001 From: Marco Castelluccio Date: Tue, 22 Oct 2024 21:45:45 +0000 Subject: [PATCH] Bug 1924561 [wpt PR 48602] - WebNN: Implement `scatterElements` operator in DirectML backend, a=testonly MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Automatic update from web-platform-tests WebNN: Implement `scatterElements` operator in DirectML backend The `scatterElements` operator is proposed by WebML WG [1] for supporting popular transformer-based models. This CL adds the IDL and mojo definitions of scatterElements, and implements it in the DirectML backend by mapping to `DML_OPERATOR_SCATTER` [2]. This CL also adds the `scatterElements` validation and conformance tests into WPT. [1]: https://github.com/webmachinelearning/webnn/issues/375#issuecomment-2292466613 [2]: https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_scatter_operator_desc Bug: 370536101,370538328 Change-Id: Ifb73bed5eb05cb919b106b4aaea5127ec099edb2 Cq-Include-Trybots: luci.chromium.try​:win11-blink-rel, mac14.arm64-blink-rel, mac14-blink-rel, mac15.arm64-blink-rel, mac15-blink-rel, linux-blink-rel Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/5921136 Reviewed-by: Alex Gough Reviewed-by: Weizhong Xia Auto-Submit: ningxin hu Commit-Queue: ningxin hu Commit-Queue: Weizhong Xia Reviewed-by: Rafael Cintron Reviewed-by: Austin Sullivan Cr-Commit-Position: refs/heads/main{#1368312} -- wpt-commits: 59751055ef506c581da667594a2da8dac0c599b3 wpt-pr: 48602 UltraBlame original commit: 442931e42d208661100d2921a200317ef2e5e110 --- .../scatterElements.https.any.js | 91 +++++++++++ .../scatterElements.https.any.js | 150 ++++++++++++++++++ 2 files changed, 241 insertions(+) create mode 100644 testing/web-platform/tests/webnn/conformance_tests/scatterElements.https.any.js create mode 100644 testing/web-platform/tests/webnn/validation_tests/scatterElements.https.any.js diff --git a/testing/web-platform/tests/webnn/conformance_tests/scatterElements.https.any.js b/testing/web-platform/tests/webnn/conformance_tests/scatterElements.https.any.js new file mode 100644 index 000000000000..d0378c0a34b8 --- /dev/null +++ b/testing/web-platform/tests/webnn/conformance_tests/scatterElements.https.any.js @@ -0,0 +1,91 @@ + + + + + + + + +'use strict'; + +const getScatterElementsPrecisionTolerance = () => { + return {metricType: 'ULP', value: 0}; +}; + +const scatterElementsTests = [ + { + 'name': 'Scatter elements along axis 0', + 'graph': { + 'inputs': { + 'input': { + 'data': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + 'descriptor': {shape: [3, 3], dataType: 'float32'} + }, + 'indices': { + 'data': [1, 0, 2, 0, 2, 1], + 'descriptor': {shape: [2, 3], dataType: 'int32'}, + }, + 'updates': { + 'data': [1.0, 1.1, 1.2, 2.0, 2.1, 2.2], + 'descriptor': {shape: [2, 3], dataType: 'float32'} + } + }, + 'operators': [{ + 'name': 'scatterElements', + 'arguments': [ + {'input': 'input'}, {'indices': 'indices'}, {'updates': 'updates'}, + {'options': {'axis': 0}} + ], + 'outputs': 'output' + }], + 'expectedOutputs': { + 'output': { + 'data': [2.0, 1.1, 0.0, 1.0, 0.0, 2.2, 0.0, 2.1, 1.2], + 'descriptor': {shape: [3, 3], dataType: 'float32'} + } + } + } + }, + { + 'name': 'Scatter elements along axis 1', + 'graph': { + 'inputs': { + 'input': { + 'data': [1.0, 2.0, 3.0, 4.0, 5.0], + 'descriptor': {shape: [1, 5], dataType: 'float32'} + }, + 'indices': { + 'data': [1, 3], + 'descriptor': {shape: [1, 2], dataType: 'int32'}, + }, + 'updates': { + 'data': [1.1, 2.1], + 'descriptor': {shape: [1, 2], dataType: 'float32'} + } + }, + 'operators': [{ + 'name': 'scatterElements', + 'arguments': [ + {'input': 'input'}, {'indices': 'indices'}, {'updates': 'updates'}, + {'options': {'axis': 1}} + ], + 'outputs': 'output' + }], + 'expectedOutputs': { + 'output': { + 'data': [1.0, 1.1, 3.0, 2.1, 5.0], + 'descriptor': {shape: [1, 5], dataType: 'float32'} + } + } + } + } +]; + +if (navigator.ml) { + scatterElementsTests.forEach((test) => { + webnn_conformance_test( + buildGraphAndCompute, getScatterElementsPrecisionTolerance, test); + }); +} else { + test(() => assert_implements(navigator.ml, 'missing navigator.ml')); +} diff --git a/testing/web-platform/tests/webnn/validation_tests/scatterElements.https.any.js b/testing/web-platform/tests/webnn/validation_tests/scatterElements.https.any.js new file mode 100644 index 000000000000..863ecad323b3 --- /dev/null +++ b/testing/web-platform/tests/webnn/validation_tests/scatterElements.https.any.js @@ -0,0 +1,150 @@ + + + + + + + +'use strict'; + +const tests = [ + { + name: '[scatterElements] Test scatterElements with default options', + input: {dataType: 'float32', shape: [3, 3]}, + indices: {dataType: 'int32', shape: [2, 3]}, + updates: {dataType: 'float32', shape: [2, 3]}, + output: {dataType: 'float32', shape: [3, 3]} + }, + { + name: '[scatterElements] Test scatterElements with axis = 0', + input: {dataType: 'float32', shape: [3, 3]}, + indices: {dataType: 'int32', shape: [2, 3]}, + updates: {dataType: 'float32', shape: [2, 3]}, + axis: 0, + output: {dataType: 'float32', shape: [3, 3]} + }, + { + name: '[scatterElements] Test scatterElements with axis = 1', + input: {dataType: 'float32', shape: [3, 3]}, + indices: {dataType: 'int32', shape: [3, 2]}, + updates: {dataType: 'float32', shape: [3, 2]}, + axis: 1, + output: {dataType: 'float32', shape: [3, 3]} + }, + { + name: '[scatterElements] Throw if axis is greater than input rank', + input: {dataType: 'float32', shape: [3, 3]}, + indices: {dataType: 'int32', shape: [2, 3]}, + updates: {dataType: 'float32', shape: [2, 3]}, + axis: 2 + }, + { + name: + '[scatterElements] Throw if updates tensor data type is not the same as input data type', + input: {dataType: 'float32', shape: [3, 3]}, + indices: {dataType: 'int32', shape: [2, 3]}, + updates: {dataType: 'float16', shape: [2, 3]}, + }, + { + name: '[scatterElements] Throw if input, indices and updates are scalar', + input: {dataType: 'float32', shape: []}, + indices: {dataType: 'int32', shape: []}, + updates: {dataType: 'float32', shape: []}, + }, + { + name: + '[scatterElements] Throw if indices rank is not the same as input rank', + input: {dataType: 'float32', shape: [3, 3]}, + indices: {dataType: 'int32', shape: [2, 3, 3]}, + updates: {dataType: 'float32', shape: [2, 3, 3]}, + }, + { + name: + '[scatterElements] Throw if indices size is not the same as input size along axis 1', + input: {dataType: 'float32', shape: [3, 3]}, + indices: {dataType: 'int32', shape: [2, 4]}, + updates: {dataType: 'float32', shape: [2, 4]}, + axis: 0 + }, + { + name: + '[scatterElements] Throw if indices size is not the same as input size along axis 0', + input: {dataType: 'float32', shape: [3, 3]}, + indices: {dataType: 'int32', shape: [2, 2]}, + updates: {dataType: 'float32', shape: [2, 2]}, + axis: 1 + }, + { + name: + '[scatterElements] Throw if indices rank is not the same as updates rank', + input: {dataType: 'float32', shape: [3, 3]}, + indices: {dataType: 'int32', shape: [2, 3]}, + updates: {dataType: 'float32', shape: [2, 3, 3]}, + }, + { + name: + '[scatterElements] Throw if indices shape is not the same as updates shape', + input: {dataType: 'float32', shape: [3, 3]}, + indices: {dataType: 'int32', shape: [2, 3]}, + updates: {dataType: 'float32', shape: [2, 4]}, + } +]; + +tests.forEach( + test => promise_test(async t => { + const builder = new MLGraphBuilder(context); + const input = builder.input('input', test.input); + const indices = builder.input('indices', test.indices); + const updates = builder.input('updates', test.updates); + + const options = {}; + if (test.axis) { + options.axis = test.axis; + } + + if (test.output) { + const output = + builder.scatterElements(input, indices, updates, options); + assert_equals(output.dataType(), test.output.dataType); + assert_array_equals(output.shape(), test.output.shape); + } else { + const label = 'a_scatter_elements' + options.label = label; + const regexp = new RegExp('\\[' + label + '\\]'); + assert_throws_with_label( + () => builder.scatterElements(input, indices, updates, options), + regexp); + } + }, test.name)); + +multi_builder_test(async (t, builder, otherBuilder) => { + const input = + otherBuilder.input('input', {dataType: 'float32', shape: [3, 3]}); + const indices = builder.input('indices', {dataType: 'int32', shape: [2, 3]}); + const updates = + builder.input('updates', {dataType: 'float32', shape: [2, 3]}); + + assert_throws_js( + TypeError, () => builder.scatterElements(input, indices, updates)); +}, '[scatterElements] Throw if input is from another builder'); + +multi_builder_test(async (t, builder, otherBuilder) => { + const input = builder.input('input', {dataType: 'float32', shape: [3, 3]}); + const indices = + otherBuilder.input('indices', {dataType: 'int32', shape: [2, 3]}); + const updates = + builder.input('updates', {dataType: 'float32', shape: [2, 3]}); + + assert_throws_js( + TypeError, () => builder.scatterElements(input, indices, updates)); +}, '[scatterElements] Throw if indices is from another builder'); + +multi_builder_test(async (t, builder, otherBuilder) => { + const input = builder.input('input', {dataType: 'float32', shape: [3, 3]}); + const indices = builder.input('indices', {dataType: 'int32', shape: [2, 3]}); + const updates = + otherBuilder.input('updates', {dataType: 'float32', shape: [2, 3]}); + + assert_throws_js( + TypeError, () => builder.scatterElements(input, indices, updates)); +}, '[scatterElements] Throw if updates is from another builder');