From 875de3bfa39dfabd47c52ec096dd094b0b6fa95b Mon Sep 17 00:00:00 2001 From: Arthur Islamov Date: Tue, 3 Oct 2023 20:55:56 +0400 Subject: [PATCH] Fixed format, updated docs --- js/web/docs/webgpu-operators.md | 2 + js/web/script/generate-webgpu-operator-md.ts | 2 + js/web/test/data/ops/attention.jsonc | 168 ++++++++++--------- 3 files changed, 92 insertions(+), 80 deletions(-) diff --git a/js/web/docs/webgpu-operators.md b/js/web/docs/webgpu-operators.md index a87a894e3b3c5..d3ab9d4d6c5ff 100644 --- a/js/web/docs/webgpu-operators.md +++ b/js/web/docs/webgpu-operators.md @@ -20,6 +20,7 @@ Do not modify directly.* | Asinh | ai.onnx(9+) | | | Atan | ai.onnx(7+) | | | Atanh | ai.onnx(9+) | | +| Attention | com.microsoft(1+) | need implementing mask and past/present | | AveragePool | ai.onnx(7-9,10,11+); com.ms.internal.nhwc(11+) | need perf optimization; need implementing activation | | Cast | ai.onnx(6-8,9-12,13-18,19+) | | | Ceil | ai.onnx(6-12,13+) | | @@ -58,6 +59,7 @@ Do not modify directly.* | MemcpyFromHost | ai.onnx(1+) | | | MemcpyToHost | ai.onnx(1+) | | | Mul | ai.onnx(7-12,13,14+) | | +| MultiHeadAttention | com.microsoft(1+) | need implementing mask and past/present | | Neg | ai.onnx(6-12,13+) | | | Not | ai.onnx(1+) | | | Pad | ai.onnx(2-10,11-12,13-17,18,19+) | | diff --git a/js/web/script/generate-webgpu-operator-md.ts b/js/web/script/generate-webgpu-operator-md.ts index 7408f17004f5e..eab8175a941bd 100644 --- a/js/web/script/generate-webgpu-operator-md.ts +++ b/js/web/script/generate-webgpu-operator-md.ts @@ -16,6 +16,8 @@ const COMMENTS: Record = { 'Reshape': 'no GPU kernel', 'Shape': 'no GPU kernel; an ORT warning is generated - need to fix', 'Resize': 'CoordinateTransformMode align_corners is not supported with downsampling', + 'Attention': 'need implementing mask and past/present', + 'MultiHeadAttention': 'need implementing mask and past/present', }; /* eslint-disable max-len */ diff --git a/js/web/test/data/ops/attention.jsonc b/js/web/test/data/ops/attention.jsonc index 2e09c173e4cf2..bd4483027cc25 100644 --- a/js/web/test/data/ops/attention.jsonc +++ b/js/web/test/data/ops/attention.jsonc @@ -466,84 +466,92 @@ ] }, { - "name": "Attention Basic 1 head, batch 3", - "operator": "Attention", - "opset": { "domain": "com.microsoft", "version": 1 }, - "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], - "cases": [ - { - "name": "T[0]", - "inputs": [ - { - "data": [ - 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, - 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, - -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652, - -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, - 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, - -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, - -0.25473490357398987, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, - 0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345, - 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709, - 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, - 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, - -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, - -0.26380985975265503, -0.25473490357398987 - ], - "dims": [3, 3, 10], - "type": "float32" - }, - { - "data": [ - 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643, - 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, 0.9005484580993652, - -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, - 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, - -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, - -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, - 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, - 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, - 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, - -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, - -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, -1.9054111242294312, - 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, - 2.1661579608917236, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, - 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, - 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, - 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, - -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, - -0.26380985975265503, -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, - 0.2303, 0.4617, 1.44, -2.22, 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, - 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, - -1.8803634643554688, 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, - -1.0069535970687866, -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, - -0.1792980432510376, -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, - -1.9054111242294312, 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, - -1.8803634643554688, 2.1661579608917236 - ], - "dims": [10, 15], - "type": "float32" - }, - { - "data": [ - -1.5670859813690186, -3.7310283184051514, -2.7460145950317383, 0.8121700286865234, -3.350031852722168, - -1.5735238790512085, -3.7310383319854736, 6.124307632446289, 0.7840213775634766, -0.7250789403915405, - -1.565433382987976, -3.731032371520996, -2.7436347007751465, 1.0472451448440552, -2.7828547954559326 - ], - "dims": [15], - "type": "float32" - } - ], - "outputs": [ - { - "data": [ - -8.011263847351074,-5.7822418212890625,6.016238689422607,0.26747000217437744,-6.992536544799805,-8.011263847351074,-5.7822418212890625,6.016238689422607,0.26747000217437744,-6.992536544799805,-8.011263847351074,-5.7822418212890625,6.016238689422607,0.26747000217437744,-6.992536544799805,1.3541864156723022,-7.813620090484619,-6.758509635925293,7.597365856170654,-13.926229476928711,-1.322464108467102,-7.297357559204102,-0.05962071940302849,6.347561836242676,-5.869992256164551,-1.3616288900375366,-7.28973388671875,0.0386197566986084,6.329389572143555,-5.751791954040527,-2.400698661804199,-14.538958549499512,-7.898950576782227,12.744998931884766,-11.1966552734375,-2.400698661804199,-14.538958549499512,-7.898950576782227,12.744998931884766,-11.1966552734375,1.021930456161499,-2.373898983001709,3.8501391410827637,-0.6108309626579285,-9.256340980529785 - ], - "dims": [3, 3, 5], - "type": "float32" - } - ] - } - ] - } + "name": "Attention Basic 1 head, batch 3", + "operator": "Attention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + { + "data": [ + 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, + -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 0.3367, 0.1288, 0.2345, 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, + 0.5349, 0.8094, 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.3367, 0.1288, 0.2345, + 0.2303, -1.1229, -0.1863, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987 + ], + "dims": [3, 3, 10], + "type": "float32" + }, + { + "data": [ + 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, 3.6643, + 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, 0.9005484580993652, + -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, 2.1661579608917236, + 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, -1.486573576927185, + -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, -0.26380985975265503, + -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, -1.9054111242294312, + 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, 0.2303, 0.4617, 1.44, -2.22, + 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, 0.8701779842376709, + 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, -1.8803634643554688, + 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, -1.0069535970687866, + -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, -0.1792980432510376, + -0.26380985975265503, -0.25473490357398987, 2.2082, -0.638, 0.4617, 0.2674, 0.5349, 0.8094, 0.2345, + 0.2303, 0.4617, 1.44, -2.22, 3.6643, 0.8710794448852539, -1.9054111242294312, 0.9005484580993652, + 0.8701779842376709, 0.9005484580993652, -1.9029953479766846, 0.8710794448852539, -1.9054111242294312, + -1.8803634643554688, 2.1661579608917236, 1.7134947776794434, -1.5250005722045898, 1.6716957092285156, + -1.0069535970687866, -1.486573576927185, -1.328187108039856, -1.297916054725647, -0.8599594831466675, + -0.1792980432510376, -0.26380985975265503, -0.25473490357398987, 2.2082, 0.8710794448852539, + -1.9054111242294312, 0.9005484580993652, 1.9029953479766846, 0.8710794448852539, -1.9054111242294312, + -1.8803634643554688, 2.1661579608917236 + ], + "dims": [10, 15], + "type": "float32" + }, + { + "data": [ + -1.5670859813690186, -3.7310283184051514, -2.7460145950317383, 0.8121700286865234, -3.350031852722168, + -1.5735238790512085, -3.7310383319854736, 6.124307632446289, 0.7840213775634766, -0.7250789403915405, + -1.565433382987976, -3.731032371520996, -2.7436347007751465, 1.0472451448440552, -2.7828547954559326 + ], + "dims": [15], + "type": "float32" + } + ], + "outputs": [ + { + "data": [ + -8.011263847351074, -5.7822418212890625, 6.016238689422607, 0.26747000217437744, -6.992536544799805, + -8.011263847351074, -5.7822418212890625, 6.016238689422607, 0.26747000217437744, -6.992536544799805, + -8.011263847351074, -5.7822418212890625, 6.016238689422607, 0.26747000217437744, -6.992536544799805, + 1.3541864156723022, -7.813620090484619, -6.758509635925293, 7.597365856170654, -13.926229476928711, + -1.322464108467102, -7.297357559204102, -0.05962071940302849, 6.347561836242676, -5.869992256164551, + -1.3616288900375366, -7.28973388671875, 0.0386197566986084, 6.329389572143555, -5.751791954040527, + -2.400698661804199, -14.538958549499512, -7.898950576782227, 12.744998931884766, -11.1966552734375, + -2.400698661804199, -14.538958549499512, -7.898950576782227, 12.744998931884766, -11.1966552734375, + 1.021930456161499, -2.373898983001709, 3.8501391410827637, -0.6108309626579285, -9.256340980529785 + ], + "dims": [3, 3, 5], + "type": "float32" + } + ] + } + ] + } ]