forked from kdashg/gecko-cinn
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Bug 1924561 [wpt PR 48602] - WebNN: Implement
scatterElements
opera…
…tor in DirectML backend, a=testonly Automatic update from web-platform-tests WebNN: Implement `scatterElements` operator in DirectML backend The `scatterElements` operator is proposed by WebML WG [1] for supporting popular transformer-based models. This CL adds the IDL and mojo definitions of scatterElements, and implements it in the DirectML backend by mapping to `DML_OPERATOR_SCATTER` [2]. This CL also adds the `scatterElements` validation and conformance tests into WPT. [1]: webmachinelearning/webnn#375 (comment) [2]: https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_scatter_operator_desc Bug: 370536101,370538328 Change-Id: Ifb73bed5eb05cb919b106b4aaea5127ec099edb2 Cq-Include-Trybots: luci.chromium.try:win11-blink-rel, mac14.arm64-blink-rel, mac14-blink-rel, mac15.arm64-blink-rel, mac15-blink-rel, linux-blink-rel Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/5921136 Reviewed-by: Alex Gough <[email protected]> Reviewed-by: Weizhong Xia <[email protected]> Auto-Submit: ningxin hu <[email protected]> Commit-Queue: ningxin hu <[email protected]> Commit-Queue: Weizhong Xia <[email protected]> Reviewed-by: Rafael Cintron <[email protected]> Reviewed-by: Austin Sullivan <[email protected]> Cr-Commit-Position: refs/heads/main@{#1368312} -- wpt-commits: 59751055ef506c581da667594a2da8dac0c599b3 wpt-pr: 48602
- Loading branch information
1 parent
6c58d0f
commit 587ee0b
Showing
2 changed files
with
241 additions
and
0 deletions.
There are no files selected for viewing
91 changes: 91 additions & 0 deletions
91
testing/web-platform/tests/webnn/conformance_tests/scatterElements.https.any.js
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
// META: title=test WebNN API scatterElements operation | ||
// META: global=window,dedicatedworker | ||
// META: variant=?cpu | ||
// META: variant=?gpu | ||
// META: variant=?npu | ||
// META: script=../resources/utils.js | ||
// META: timeout=long | ||
|
||
'use strict'; | ||
|
||
const getScatterElementsPrecisionTolerance = () => { | ||
return {metricType: 'ULP', value: 0}; | ||
}; | ||
|
||
const scatterElementsTests = [ | ||
{ | ||
'name': 'Scatter elements along axis 0', | ||
'graph': { | ||
'inputs': { | ||
'input': { | ||
'data': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], | ||
'descriptor': {shape: [3, 3], dataType: 'float32'} | ||
}, | ||
'indices': { | ||
'data': [1, 0, 2, 0, 2, 1], | ||
'descriptor': {shape: [2, 3], dataType: 'int32'}, | ||
}, | ||
'updates': { | ||
'data': [1.0, 1.1, 1.2, 2.0, 2.1, 2.2], | ||
'descriptor': {shape: [2, 3], dataType: 'float32'} | ||
} | ||
}, | ||
'operators': [{ | ||
'name': 'scatterElements', | ||
'arguments': [ | ||
{'input': 'input'}, {'indices': 'indices'}, {'updates': 'updates'}, | ||
{'options': {'axis': 0}} | ||
], | ||
'outputs': 'output' | ||
}], | ||
'expectedOutputs': { | ||
'output': { | ||
'data': [2.0, 1.1, 0.0, 1.0, 0.0, 2.2, 0.0, 2.1, 1.2], | ||
'descriptor': {shape: [3, 3], dataType: 'float32'} | ||
} | ||
} | ||
} | ||
}, | ||
{ | ||
'name': 'Scatter elements along axis 1', | ||
'graph': { | ||
'inputs': { | ||
'input': { | ||
'data': [1.0, 2.0, 3.0, 4.0, 5.0], | ||
'descriptor': {shape: [1, 5], dataType: 'float32'} | ||
}, | ||
'indices': { | ||
'data': [1, 3], | ||
'descriptor': {shape: [1, 2], dataType: 'int32'}, | ||
}, | ||
'updates': { | ||
'data': [1.1, 2.1], | ||
'descriptor': {shape: [1, 2], dataType: 'float32'} | ||
} | ||
}, | ||
'operators': [{ | ||
'name': 'scatterElements', | ||
'arguments': [ | ||
{'input': 'input'}, {'indices': 'indices'}, {'updates': 'updates'}, | ||
{'options': {'axis': 1}} | ||
], | ||
'outputs': 'output' | ||
}], | ||
'expectedOutputs': { | ||
'output': { | ||
'data': [1.0, 1.1, 3.0, 2.1, 5.0], | ||
'descriptor': {shape: [1, 5], dataType: 'float32'} | ||
} | ||
} | ||
} | ||
} | ||
]; | ||
|
||
if (navigator.ml) { | ||
scatterElementsTests.forEach((test) => { | ||
webnn_conformance_test( | ||
buildGraphAndCompute, getScatterElementsPrecisionTolerance, test); | ||
}); | ||
} else { | ||
test(() => assert_implements(navigator.ml, 'missing navigator.ml')); | ||
} |
150 changes: 150 additions & 0 deletions
150
testing/web-platform/tests/webnn/validation_tests/scatterElements.https.any.js
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
// META: title=validation tests for WebNN API scatterElements operation | ||
// META: global=window,dedicatedworker | ||
// META: variant=?cpu | ||
// META: variant=?gpu | ||
// META: variant=?npu | ||
// META: script=../resources/utils_validation.js | ||
|
||
'use strict'; | ||
|
||
const tests = [ | ||
{ | ||
name: '[scatterElements] Test scatterElements with default options', | ||
input: {dataType: 'float32', shape: [3, 3]}, | ||
indices: {dataType: 'int32', shape: [2, 3]}, | ||
updates: {dataType: 'float32', shape: [2, 3]}, | ||
output: {dataType: 'float32', shape: [3, 3]} | ||
}, | ||
{ | ||
name: '[scatterElements] Test scatterElements with axis = 0', | ||
input: {dataType: 'float32', shape: [3, 3]}, | ||
indices: {dataType: 'int32', shape: [2, 3]}, | ||
updates: {dataType: 'float32', shape: [2, 3]}, | ||
axis: 0, | ||
output: {dataType: 'float32', shape: [3, 3]} | ||
}, | ||
{ | ||
name: '[scatterElements] Test scatterElements with axis = 1', | ||
input: {dataType: 'float32', shape: [3, 3]}, | ||
indices: {dataType: 'int32', shape: [3, 2]}, | ||
updates: {dataType: 'float32', shape: [3, 2]}, | ||
axis: 1, | ||
output: {dataType: 'float32', shape: [3, 3]} | ||
}, | ||
{ | ||
name: '[scatterElements] Throw if axis is greater than input rank', | ||
input: {dataType: 'float32', shape: [3, 3]}, | ||
indices: {dataType: 'int32', shape: [2, 3]}, | ||
updates: {dataType: 'float32', shape: [2, 3]}, | ||
axis: 2 | ||
}, | ||
{ | ||
name: | ||
'[scatterElements] Throw if updates tensor data type is not the same as input data type', | ||
input: {dataType: 'float32', shape: [3, 3]}, | ||
indices: {dataType: 'int32', shape: [2, 3]}, | ||
updates: {dataType: 'float16', shape: [2, 3]}, | ||
}, | ||
{ | ||
name: '[scatterElements] Throw if input, indices and updates are scalar', | ||
input: {dataType: 'float32', shape: []}, | ||
indices: {dataType: 'int32', shape: []}, | ||
updates: {dataType: 'float32', shape: []}, | ||
}, | ||
{ | ||
name: | ||
'[scatterElements] Throw if indices rank is not the same as input rank', | ||
input: {dataType: 'float32', shape: [3, 3]}, | ||
indices: {dataType: 'int32', shape: [2, 3, 3]}, | ||
updates: {dataType: 'float32', shape: [2, 3, 3]}, | ||
}, | ||
{ | ||
name: | ||
'[scatterElements] Throw if indices size is not the same as input size along axis 1', | ||
input: {dataType: 'float32', shape: [3, 3]}, | ||
indices: {dataType: 'int32', shape: [2, 4]}, | ||
updates: {dataType: 'float32', shape: [2, 4]}, | ||
axis: 0 | ||
}, | ||
{ | ||
name: | ||
'[scatterElements] Throw if indices size is not the same as input size along axis 0', | ||
input: {dataType: 'float32', shape: [3, 3]}, | ||
indices: {dataType: 'int32', shape: [2, 2]}, | ||
updates: {dataType: 'float32', shape: [2, 2]}, | ||
axis: 1 | ||
}, | ||
{ | ||
name: | ||
'[scatterElements] Throw if indices rank is not the same as updates rank', | ||
input: {dataType: 'float32', shape: [3, 3]}, | ||
indices: {dataType: 'int32', shape: [2, 3]}, | ||
updates: {dataType: 'float32', shape: [2, 3, 3]}, | ||
}, | ||
{ | ||
name: | ||
'[scatterElements] Throw if indices shape is not the same as updates shape', | ||
input: {dataType: 'float32', shape: [3, 3]}, | ||
indices: {dataType: 'int32', shape: [2, 3]}, | ||
updates: {dataType: 'float32', shape: [2, 4]}, | ||
} | ||
]; | ||
|
||
tests.forEach( | ||
test => promise_test(async t => { | ||
const builder = new MLGraphBuilder(context); | ||
const input = builder.input('input', test.input); | ||
const indices = builder.input('indices', test.indices); | ||
const updates = builder.input('updates', test.updates); | ||
|
||
const options = {}; | ||
if (test.axis) { | ||
options.axis = test.axis; | ||
} | ||
|
||
if (test.output) { | ||
const output = | ||
builder.scatterElements(input, indices, updates, options); | ||
assert_equals(output.dataType(), test.output.dataType); | ||
assert_array_equals(output.shape(), test.output.shape); | ||
} else { | ||
const label = 'a_scatter_elements' | ||
options.label = label; | ||
const regexp = new RegExp('\\[' + label + '\\]'); | ||
assert_throws_with_label( | ||
() => builder.scatterElements(input, indices, updates, options), | ||
regexp); | ||
} | ||
}, test.name)); | ||
|
||
multi_builder_test(async (t, builder, otherBuilder) => { | ||
const input = | ||
otherBuilder.input('input', {dataType: 'float32', shape: [3, 3]}); | ||
const indices = builder.input('indices', {dataType: 'int32', shape: [2, 3]}); | ||
const updates = | ||
builder.input('updates', {dataType: 'float32', shape: [2, 3]}); | ||
|
||
assert_throws_js( | ||
TypeError, () => builder.scatterElements(input, indices, updates)); | ||
}, '[scatterElements] Throw if input is from another builder'); | ||
|
||
multi_builder_test(async (t, builder, otherBuilder) => { | ||
const input = builder.input('input', {dataType: 'float32', shape: [3, 3]}); | ||
const indices = | ||
otherBuilder.input('indices', {dataType: 'int32', shape: [2, 3]}); | ||
const updates = | ||
builder.input('updates', {dataType: 'float32', shape: [2, 3]}); | ||
|
||
assert_throws_js( | ||
TypeError, () => builder.scatterElements(input, indices, updates)); | ||
}, '[scatterElements] Throw if indices is from another builder'); | ||
|
||
multi_builder_test(async (t, builder, otherBuilder) => { | ||
const input = builder.input('input', {dataType: 'float32', shape: [3, 3]}); | ||
const indices = builder.input('indices', {dataType: 'int32', shape: [2, 3]}); | ||
const updates = | ||
otherBuilder.input('updates', {dataType: 'float32', shape: [2, 3]}); | ||
|
||
assert_throws_js( | ||
TypeError, () => builder.scatterElements(input, indices, updates)); | ||
}, '[scatterElements] Throw if updates is from another builder'); |