diff --git a/testing/web-platform/tests/webnn/conformance_tests/scatterElements.https.any.js b/testing/web-platform/tests/webnn/conformance_tests/scatterElements.https.any.js new file mode 100644 index 000000000000..d0378c0a34b8 --- /dev/null +++ b/testing/web-platform/tests/webnn/conformance_tests/scatterElements.https.any.js @@ -0,0 +1,91 @@ + + + + + + + + +'use strict'; + +const getScatterElementsPrecisionTolerance = () => { + return {metricType: 'ULP', value: 0}; +}; + +const scatterElementsTests = [ + { + 'name': 'Scatter elements along axis 0', + 'graph': { + 'inputs': { + 'input': { + 'data': [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], + 'descriptor': {shape: [3, 3], dataType: 'float32'} + }, + 'indices': { + 'data': [1, 0, 2, 0, 2, 1], + 'descriptor': {shape: [2, 3], dataType: 'int32'}, + }, + 'updates': { + 'data': [1.0, 1.1, 1.2, 2.0, 2.1, 2.2], + 'descriptor': {shape: [2, 3], dataType: 'float32'} + } + }, + 'operators': [{ + 'name': 'scatterElements', + 'arguments': [ + {'input': 'input'}, {'indices': 'indices'}, {'updates': 'updates'}, + {'options': {'axis': 0}} + ], + 'outputs': 'output' + }], + 'expectedOutputs': { + 'output': { + 'data': [2.0, 1.1, 0.0, 1.0, 0.0, 2.2, 0.0, 2.1, 1.2], + 'descriptor': {shape: [3, 3], dataType: 'float32'} + } + } + } + }, + { + 'name': 'Scatter elements along axis 1', + 'graph': { + 'inputs': { + 'input': { + 'data': [1.0, 2.0, 3.0, 4.0, 5.0], + 'descriptor': {shape: [1, 5], dataType: 'float32'} + }, + 'indices': { + 'data': [1, 3], + 'descriptor': {shape: [1, 2], dataType: 'int32'}, + }, + 'updates': { + 'data': [1.1, 2.1], + 'descriptor': {shape: [1, 2], dataType: 'float32'} + } + }, + 'operators': [{ + 'name': 'scatterElements', + 'arguments': [ + {'input': 'input'}, {'indices': 'indices'}, {'updates': 'updates'}, + {'options': {'axis': 1}} + ], + 'outputs': 'output' + }], + 'expectedOutputs': { + 'output': { + 'data': [1.0, 1.1, 3.0, 2.1, 5.0], + 'descriptor': {shape: [1, 5], dataType: 'float32'} + } + } + } + } +]; + +if (navigator.ml) { + scatterElementsTests.forEach((test) => { + webnn_conformance_test( + buildGraphAndCompute, getScatterElementsPrecisionTolerance, test); + }); +} else { + test(() => assert_implements(navigator.ml, 'missing navigator.ml')); +} diff --git a/testing/web-platform/tests/webnn/validation_tests/scatterElements.https.any.js b/testing/web-platform/tests/webnn/validation_tests/scatterElements.https.any.js new file mode 100644 index 000000000000..863ecad323b3 --- /dev/null +++ b/testing/web-platform/tests/webnn/validation_tests/scatterElements.https.any.js @@ -0,0 +1,150 @@ + + + + + + + +'use strict'; + +const tests = [ + { + name: '[scatterElements] Test scatterElements with default options', + input: {dataType: 'float32', shape: [3, 3]}, + indices: {dataType: 'int32', shape: [2, 3]}, + updates: {dataType: 'float32', shape: [2, 3]}, + output: {dataType: 'float32', shape: [3, 3]} + }, + { + name: '[scatterElements] Test scatterElements with axis = 0', + input: {dataType: 'float32', shape: [3, 3]}, + indices: {dataType: 'int32', shape: [2, 3]}, + updates: {dataType: 'float32', shape: [2, 3]}, + axis: 0, + output: {dataType: 'float32', shape: [3, 3]} + }, + { + name: '[scatterElements] Test scatterElements with axis = 1', + input: {dataType: 'float32', shape: [3, 3]}, + indices: {dataType: 'int32', shape: [3, 2]}, + updates: {dataType: 'float32', shape: [3, 2]}, + axis: 1, + output: {dataType: 'float32', shape: [3, 3]} + }, + { + name: '[scatterElements] Throw if axis is greater than input rank', + input: {dataType: 'float32', shape: [3, 3]}, + indices: {dataType: 'int32', shape: [2, 3]}, + updates: {dataType: 'float32', shape: [2, 3]}, + axis: 2 + }, + { + name: + '[scatterElements] Throw if updates tensor data type is not the same as input data type', + input: {dataType: 'float32', shape: [3, 3]}, + indices: {dataType: 'int32', shape: [2, 3]}, + updates: {dataType: 'float16', shape: [2, 3]}, + }, + { + name: '[scatterElements] Throw if input, indices and updates are scalar', + input: {dataType: 'float32', shape: []}, + indices: {dataType: 'int32', shape: []}, + updates: {dataType: 'float32', shape: []}, + }, + { + name: + '[scatterElements] Throw if indices rank is not the same as input rank', + input: {dataType: 'float32', shape: [3, 3]}, + indices: {dataType: 'int32', shape: [2, 3, 3]}, + updates: {dataType: 'float32', shape: [2, 3, 3]}, + }, + { + name: + '[scatterElements] Throw if indices size is not the same as input size along axis 1', + input: {dataType: 'float32', shape: [3, 3]}, + indices: {dataType: 'int32', shape: [2, 4]}, + updates: {dataType: 'float32', shape: [2, 4]}, + axis: 0 + }, + { + name: + '[scatterElements] Throw if indices size is not the same as input size along axis 0', + input: {dataType: 'float32', shape: [3, 3]}, + indices: {dataType: 'int32', shape: [2, 2]}, + updates: {dataType: 'float32', shape: [2, 2]}, + axis: 1 + }, + { + name: + '[scatterElements] Throw if indices rank is not the same as updates rank', + input: {dataType: 'float32', shape: [3, 3]}, + indices: {dataType: 'int32', shape: [2, 3]}, + updates: {dataType: 'float32', shape: [2, 3, 3]}, + }, + { + name: + '[scatterElements] Throw if indices shape is not the same as updates shape', + input: {dataType: 'float32', shape: [3, 3]}, + indices: {dataType: 'int32', shape: [2, 3]}, + updates: {dataType: 'float32', shape: [2, 4]}, + } +]; + +tests.forEach( + test => promise_test(async t => { + const builder = new MLGraphBuilder(context); + const input = builder.input('input', test.input); + const indices = builder.input('indices', test.indices); + const updates = builder.input('updates', test.updates); + + const options = {}; + if (test.axis) { + options.axis = test.axis; + } + + if (test.output) { + const output = + builder.scatterElements(input, indices, updates, options); + assert_equals(output.dataType(), test.output.dataType); + assert_array_equals(output.shape(), test.output.shape); + } else { + const label = 'a_scatter_elements' + options.label = label; + const regexp = new RegExp('\\[' + label + '\\]'); + assert_throws_with_label( + () => builder.scatterElements(input, indices, updates, options), + regexp); + } + }, test.name)); + +multi_builder_test(async (t, builder, otherBuilder) => { + const input = + otherBuilder.input('input', {dataType: 'float32', shape: [3, 3]}); + const indices = builder.input('indices', {dataType: 'int32', shape: [2, 3]}); + const updates = + builder.input('updates', {dataType: 'float32', shape: [2, 3]}); + + assert_throws_js( + TypeError, () => builder.scatterElements(input, indices, updates)); +}, '[scatterElements] Throw if input is from another builder'); + +multi_builder_test(async (t, builder, otherBuilder) => { + const input = builder.input('input', {dataType: 'float32', shape: [3, 3]}); + const indices = + otherBuilder.input('indices', {dataType: 'int32', shape: [2, 3]}); + const updates = + builder.input('updates', {dataType: 'float32', shape: [2, 3]}); + + assert_throws_js( + TypeError, () => builder.scatterElements(input, indices, updates)); +}, '[scatterElements] Throw if indices is from another builder'); + +multi_builder_test(async (t, builder, otherBuilder) => { + const input = builder.input('input', {dataType: 'float32', shape: [3, 3]}); + const indices = builder.input('indices', {dataType: 'int32', shape: [2, 3]}); + const updates = + otherBuilder.input('updates', {dataType: 'float32', shape: [2, 3]}); + + assert_throws_js( + TypeError, () => builder.scatterElements(input, indices, updates)); +}, '[scatterElements] Throw if updates is from another builder');