From 6ab9a292c2fceff845ac30c39c9b5e3c042033f9 Mon Sep 17 00:00:00 2001 From: Marco Castelluccio Date: Tue, 22 Oct 2024 21:46:45 +0000 Subject: [PATCH] Bug 1924561 [wpt PR 48602] - WebNN: Implement `scatterElements` operator in DirectML backend, a=testonly MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Automatic update from web-platform-tests WebNN: Implement `scatterElements` operator in DirectML backend The `scatterElements` operator is proposed by WebML WG [1] for supporting popular transformer-based models. This CL adds the IDL and mojo definitions of scatterElements, and implements it in the DirectML backend by mapping to `DML_OPERATOR_SCATTER` [2]. This CL also adds the `scatterElements` validation and conformance tests into WPT. [1]: https://github.com/webmachinelearning/webnn/issues/375#issuecomment-2292466613 [2]: https://learn.microsoft.com/en-us/windows/win32/api/directml/ns-directml-dml_scatter_operator_desc Bug: 370536101,370538328 Change-Id: Ifb73bed5eb05cb919b106b4aaea5127ec099edb2 Cq-Include-Trybots: luci.chromium.try​:win11-blink-rel, mac14.arm64-blink-rel, mac14-blink-rel, mac15.arm64-blink-rel, mac15-blink-rel, linux-blink-rel Reviewed-on: https://chromium-review.googlesource.com/c/chromium/src/+/5921136 Reviewed-by: Alex Gough Reviewed-by: Weizhong Xia Auto-Submit: ningxin hu Commit-Queue: ningxin hu Commit-Queue: Weizhong Xia Reviewed-by: Rafael Cintron Reviewed-by: Austin Sullivan Cr-Commit-Position: refs/heads/main{#1368312} -- wpt-commits: 59751055ef506c581da667594a2da8dac0c599b3 wpt-pr: 48602 UltraBlame original commit: 442931e42d208661100d2921a200317ef2e5e110 --- .../scatterElements.https.any.js | 630 ++++++++ .../scatterElements.https.any.js | 1324 +++++++++++++++++ 2 files changed, 1954 insertions(+) create mode 100644 testing/web-platform/tests/webnn/conformance_tests/scatterElements.https.any.js create mode 100644 testing/web-platform/tests/webnn/validation_tests/scatterElements.https.any.js diff --git a/testing/web-platform/tests/webnn/conformance_tests/scatterElements.https.any.js b/testing/web-platform/tests/webnn/conformance_tests/scatterElements.https.any.js new file mode 100644 index 000000000000..9ece39471162 --- /dev/null +++ b/testing/web-platform/tests/webnn/conformance_tests/scatterElements.https.any.js @@ -0,0 +1,630 @@ +' +use +strict +' +; +const +getScatterElementsPrecisionTolerance += +( +) += +> +{ +return +{ +metricType +: +' +ULP +' +value +: +0 +} +; +} +; +const +scatterElementsTests += +[ +{ +' +name +' +: +' +Scatter +elements +along +axis +0 +' +' +graph +' +: +{ +' +inputs +' +: +{ +' +input +' +: +{ +' +data +' +: +[ +0 +. +0 +0 +. +0 +0 +. +0 +0 +. +0 +0 +. +0 +0 +. +0 +0 +. +0 +0 +. +0 +0 +. +0 +] +' +descriptor +' +: +{ +shape +: +[ +3 +3 +] +dataType +: +' +float32 +' +} +} +' +indices +' +: +{ +' +data +' +: +[ +1 +0 +2 +0 +2 +1 +] +' +descriptor +' +: +{ +shape +: +[ +2 +3 +] +dataType +: +' +int32 +' +} +} +' +updates +' +: +{ +' +data +' +: +[ +1 +. +0 +1 +. +1 +1 +. +2 +2 +. +0 +2 +. +1 +2 +. +2 +] +' +descriptor +' +: +{ +shape +: +[ +2 +3 +] +dataType +: +' +float32 +' +} +} +} +' +operators +' +: +[ +{ +' +name +' +: +' +scatterElements +' +' +arguments +' +: +[ +{ +' +input +' +: +' +input +' +} +{ +' +indices +' +: +' +indices +' +} +{ +' +updates +' +: +' +updates +' +} +{ +' +options +' +: +{ +' +axis +' +: +0 +} +} +] +' +outputs +' +: +' +output +' +} +] +' +expectedOutputs +' +: +{ +' +output +' +: +{ +' +data +' +: +[ +2 +. +0 +1 +. +1 +0 +. +0 +1 +. +0 +0 +. +0 +2 +. +2 +0 +. +0 +2 +. +1 +1 +. +2 +] +' +descriptor +' +: +{ +shape +: +[ +3 +3 +] +dataType +: +' +float32 +' +} +} +} +} +} +{ +' +name +' +: +' +Scatter +elements +along +axis +1 +' +' +graph +' +: +{ +' +inputs +' +: +{ +' +input +' +: +{ +' +data +' +: +[ +1 +. +0 +2 +. +0 +3 +. +0 +4 +. +0 +5 +. +0 +] +' +descriptor +' +: +{ +shape +: +[ +1 +5 +] +dataType +: +' +float32 +' +} +} +' +indices +' +: +{ +' +data +' +: +[ +1 +3 +] +' +descriptor +' +: +{ +shape +: +[ +1 +2 +] +dataType +: +' +int32 +' +} +} +' +updates +' +: +{ +' +data +' +: +[ +1 +. +1 +2 +. +1 +] +' +descriptor +' +: +{ +shape +: +[ +1 +2 +] +dataType +: +' +float32 +' +} +} +} +' +operators +' +: +[ +{ +' +name +' +: +' +scatterElements +' +' +arguments +' +: +[ +{ +' +input +' +: +' +input +' +} +{ +' +indices +' +: +' +indices +' +} +{ +' +updates +' +: +' +updates +' +} +{ +' +options +' +: +{ +' +axis +' +: +1 +} +} +] +' +outputs +' +: +' +output +' +} +] +' +expectedOutputs +' +: +{ +' +output +' +: +{ +' +data +' +: +[ +1 +. +0 +1 +. +1 +3 +. +0 +2 +. +1 +5 +. +0 +] +' +descriptor +' +: +{ +shape +: +[ +1 +5 +] +dataType +: +' +float32 +' +} +} +} +} +} +] +; +if +( +navigator +. +ml +) +{ +scatterElementsTests +. +forEach +( +( +test +) += +> +{ +webnn_conformance_test +( +buildGraphAndCompute +getScatterElementsPrecisionTolerance +test +) +; +} +) +; +} +else +{ +test +( +( +) += +> +assert_implements +( +navigator +. +ml +' +missing +navigator +. +ml +' +) +) +; +} diff --git a/testing/web-platform/tests/webnn/validation_tests/scatterElements.https.any.js b/testing/web-platform/tests/webnn/validation_tests/scatterElements.https.any.js new file mode 100644 index 000000000000..6d5336bd4694 --- /dev/null +++ b/testing/web-platform/tests/webnn/validation_tests/scatterElements.https.any.js @@ -0,0 +1,1324 @@ +' +use +strict +' +; +const +tests += +[ +{ +name +: +' +[ +scatterElements +] +Test +scatterElements +with +default +options +' +input +: +{ +dataType +: +' +float32 +' +shape +: +[ +3 +3 +] +} +indices +: +{ +dataType +: +' +int32 +' +shape +: +[ +2 +3 +] +} +updates +: +{ +dataType +: +' +float32 +' +shape +: +[ +2 +3 +] +} +output +: +{ +dataType +: +' +float32 +' +shape +: +[ +3 +3 +] +} +} +{ +name +: +' +[ +scatterElements +] +Test +scatterElements +with +axis += +0 +' +input +: +{ +dataType +: +' +float32 +' +shape +: +[ +3 +3 +] +} +indices +: +{ +dataType +: +' +int32 +' +shape +: +[ +2 +3 +] +} +updates +: +{ +dataType +: +' +float32 +' +shape +: +[ +2 +3 +] +} +axis +: +0 +output +: +{ +dataType +: +' +float32 +' +shape +: +[ +3 +3 +] +} +} +{ +name +: +' +[ +scatterElements +] +Test +scatterElements +with +axis += +1 +' +input +: +{ +dataType +: +' +float32 +' +shape +: +[ +3 +3 +] +} +indices +: +{ +dataType +: +' +int32 +' +shape +: +[ +3 +2 +] +} +updates +: +{ +dataType +: +' +float32 +' +shape +: +[ +3 +2 +] +} +axis +: +1 +output +: +{ +dataType +: +' +float32 +' +shape +: +[ +3 +3 +] +} +} +{ +name +: +' +[ +scatterElements +] +Throw +if +axis +is +greater +than +input +rank +' +input +: +{ +dataType +: +' +float32 +' +shape +: +[ +3 +3 +] +} +indices +: +{ +dataType +: +' +int32 +' +shape +: +[ +2 +3 +] +} +updates +: +{ +dataType +: +' +float32 +' +shape +: +[ +2 +3 +] +} +axis +: +2 +} +{ +name +: +' +[ +scatterElements +] +Throw +if +updates +tensor +data +type +is +not +the +same +as +input +data +type +' +input +: +{ +dataType +: +' +float32 +' +shape +: +[ +3 +3 +] +} +indices +: +{ +dataType +: +' +int32 +' +shape +: +[ +2 +3 +] +} +updates +: +{ +dataType +: +' +float16 +' +shape +: +[ +2 +3 +] +} +} +{ +name +: +' +[ +scatterElements +] +Throw +if +input +indices +and +updates +are +scalar +' +input +: +{ +dataType +: +' +float32 +' +shape +: +[ +] +} +indices +: +{ +dataType +: +' +int32 +' +shape +: +[ +] +} +updates +: +{ +dataType +: +' +float32 +' +shape +: +[ +] +} +} +{ +name +: +' +[ +scatterElements +] +Throw +if +indices +rank +is +not +the +same +as +input +rank +' +input +: +{ +dataType +: +' +float32 +' +shape +: +[ +3 +3 +] +} +indices +: +{ +dataType +: +' +int32 +' +shape +: +[ +2 +3 +3 +] +} +updates +: +{ +dataType +: +' +float32 +' +shape +: +[ +2 +3 +3 +] +} +} +{ +name +: +' +[ +scatterElements +] +Throw +if +indices +size +is +not +the +same +as +input +size +along +axis +1 +' +input +: +{ +dataType +: +' +float32 +' +shape +: +[ +3 +3 +] +} +indices +: +{ +dataType +: +' +int32 +' +shape +: +[ +2 +4 +] +} +updates +: +{ +dataType +: +' +float32 +' +shape +: +[ +2 +4 +] +} +axis +: +0 +} +{ +name +: +' +[ +scatterElements +] +Throw +if +indices +size +is +not +the +same +as +input +size +along +axis +0 +' +input +: +{ +dataType +: +' +float32 +' +shape +: +[ +3 +3 +] +} +indices +: +{ +dataType +: +' +int32 +' +shape +: +[ +2 +2 +] +} +updates +: +{ +dataType +: +' +float32 +' +shape +: +[ +2 +2 +] +} +axis +: +1 +} +{ +name +: +' +[ +scatterElements +] +Throw +if +indices +rank +is +not +the +same +as +updates +rank +' +input +: +{ +dataType +: +' +float32 +' +shape +: +[ +3 +3 +] +} +indices +: +{ +dataType +: +' +int32 +' +shape +: +[ +2 +3 +] +} +updates +: +{ +dataType +: +' +float32 +' +shape +: +[ +2 +3 +3 +] +} +} +{ +name +: +' +[ +scatterElements +] +Throw +if +indices +shape +is +not +the +same +as +updates +shape +' +input +: +{ +dataType +: +' +float32 +' +shape +: +[ +3 +3 +] +} +indices +: +{ +dataType +: +' +int32 +' +shape +: +[ +2 +3 +] +} +updates +: +{ +dataType +: +' +float32 +' +shape +: +[ +2 +4 +] +} +} +] +; +tests +. +forEach +( +test += +> +promise_test +( +async +t += +> +{ +const +builder += +new +MLGraphBuilder +( +context +) +; +const +input += +builder +. +input +( +' +input +' +test +. +input +) +; +const +indices += +builder +. +input +( +' +indices +' +test +. +indices +) +; +const +updates += +builder +. +input +( +' +updates +' +test +. +updates +) +; +const +options += +{ +} +; +if +( +test +. +axis +) +{ +options +. +axis += +test +. +axis +; +} +if +( +test +. +output +) +{ +const +output += +builder +. +scatterElements +( +input +indices +updates +options +) +; +assert_equals +( +output +. +dataType +( +) +test +. +output +. +dataType +) +; +assert_array_equals +( +output +. +shape +( +) +test +. +output +. +shape +) +; +} +else +{ +const +label += +' +a_scatter_elements +' +options +. +label += +label +; +const +regexp += +new +RegExp +( +' +\ +\ +[ +' ++ +label ++ +' +\ +\ +] +' +) +; +assert_throws_with_label +( +( +) += +> +builder +. +scatterElements +( +input +indices +updates +options +) +regexp +) +; +} +} +test +. +name +) +) +; +multi_builder_test +( +async +( +t +builder +otherBuilder +) += +> +{ +const +input += +otherBuilder +. +input +( +' +input +' +{ +dataType +: +' +float32 +' +shape +: +[ +3 +3 +] +} +) +; +const +indices += +builder +. +input +( +' +indices +' +{ +dataType +: +' +int32 +' +shape +: +[ +2 +3 +] +} +) +; +const +updates += +builder +. +input +( +' +updates +' +{ +dataType +: +' +float32 +' +shape +: +[ +2 +3 +] +} +) +; +assert_throws_js +( +TypeError +( +) += +> +builder +. +scatterElements +( +input +indices +updates +) +) +; +} +' +[ +scatterElements +] +Throw +if +input +is +from +another +builder +' +) +; +multi_builder_test +( +async +( +t +builder +otherBuilder +) += +> +{ +const +input += +builder +. +input +( +' +input +' +{ +dataType +: +' +float32 +' +shape +: +[ +3 +3 +] +} +) +; +const +indices += +otherBuilder +. +input +( +' +indices +' +{ +dataType +: +' +int32 +' +shape +: +[ +2 +3 +] +} +) +; +const +updates += +builder +. +input +( +' +updates +' +{ +dataType +: +' +float32 +' +shape +: +[ +2 +3 +] +} +) +; +assert_throws_js +( +TypeError +( +) += +> +builder +. +scatterElements +( +input +indices +updates +) +) +; +} +' +[ +scatterElements +] +Throw +if +indices +is +from +another +builder +' +) +; +multi_builder_test +( +async +( +t +builder +otherBuilder +) += +> +{ +const +input += +builder +. +input +( +' +input +' +{ +dataType +: +' +float32 +' +shape +: +[ +3 +3 +] +} +) +; +const +indices += +builder +. +input +( +' +indices +' +{ +dataType +: +' +int32 +' +shape +: +[ +2 +3 +] +} +) +; +const +updates += +otherBuilder +. +input +( +' +updates +' +{ +dataType +: +' +float32 +' +shape +: +[ +2 +3 +] +} +) +; +assert_throws_js +( +TypeError +( +) += +> +builder +. +scatterElements +( +input +indices +updates +) +) +; +} +' +[ +scatterElements +] +Throw +if +updates +is +from +another +builder +' +) +;