From 21b3cbc3af50aa4f77e1e477451d6b0cbc2b180d Mon Sep 17 00:00:00 2001 From: Satya Kumar Jandhyala Date: Thu, 25 Apr 2024 13:33:46 -0700 Subject: [PATCH] [WIP][JS/WebGPU] Inputs Key and Value could be 4-dims. (#20470) ### Description The Key and Value inputs could be 4-dims ### Motivation and Context --- js/web/lib/wasm/jsep/webgpu/ops/attention.ts | 12 +- .../jsep/webgpu/ops/multihead-attentiion.ts | 2 +- .../test/data/ops/multihead-attention.jsonc | 236 +++++++++++++++++- 3 files changed, 239 insertions(+), 11 deletions(-) diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts index db9bb73e394c7..79b24e9c4d67a 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts @@ -282,12 +282,12 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor })()}; workgroupBarrier(); - var max_value = -3.402823e+38f; + var max_value = f32(-3.402823e+38f); for (var i = 0u; i < ${WG}; i++) { max_value = max(thread_max[i], max_value); } - var sum_vector = ${f32Type}(${0}); + var sum_vector = ${f32Type}(0); for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) { sum_vector += exp(${f32Type}(x[offset + i]) - max_value); } @@ -378,7 +378,6 @@ const createAttentionProbsProgramInfo = {name: 'num_heads', type: 'u32'}, {name: 'alpha', type: dataType as UniformDataElementType} ]; return ` - const beta: ${dataType} = 1.0; const TILE_SIZE = ${TILE_SIZE}u; var tileQ: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>; @@ -426,16 +425,16 @@ const createAttentionProbsProgramInfo = throw new Error(`Unsupported components: ${components}`); } })()}; - output[outputIdx] = sum * uniforms.alpha; + ${(() => { if (relativePositionBiasInput) { return ` let batch = workgroup_id.z / uniforms.num_heads; let head = workgroup_id.z % uniforms.num_heads; var indices = ${relativePositionBiasInput.type.indices}(batch, head, global_id.y, global_id.x); - output[outputIdx] += ${relativePositionBiasInput.getByIndices('indices')};`; + output[outputIdx] = sum * uniforms.alpha + ${relativePositionBiasInput.getByIndices('indices')};`; } - return ''; + return 'output[outputIdx] = sum * uniforms.alpha;'; })()} } }`; @@ -512,7 +511,6 @@ const createVxAttentionScoreProgramInfo = // we need to transpose output from BNSH_v to BSND_v let batchIdx = workgroup_id.z / uniforms.num_heads; let currentBatchHeadNumber = workgroup_id.z % uniforms.num_heads; - let headOffset = (batchIdx * uniforms.M * uniforms.num_heads + currentBatchHeadNumber) * uniforms.N; if (m < uniforms.M && n < uniforms.N) { let outputIdx = batchIdx * uniforms.M *uniforms.v_hidden_size + m * uniforms.v_hidden_size + currentBatchHeadNumber * uniforms.N + n; diff --git a/js/web/lib/wasm/jsep/webgpu/ops/multihead-attentiion.ts b/js/web/lib/wasm/jsep/webgpu/ops/multihead-attentiion.ts index 7c91b97d13f4e..4b18a41ccbeb4 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/multihead-attentiion.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/multihead-attentiion.ts @@ -339,7 +339,7 @@ export const multiHeadAttention = (context: ComputeContext, attributes: Attentio if (kvBNSH) { return applyAttention( - context, Q, key, value, keyPaddingMask, undefined, undefined, undefined, relativePositionBias, params, + context, Q, key, value, keyPaddingMask, undefined, pastKey, pastValue, relativePositionBias, params, attributes); } if (!key || !value) { diff --git a/js/web/test/data/ops/multihead-attention.jsonc b/js/web/test/data/ops/multihead-attention.jsonc index 0bed30747bca9..2c5dd30df9b52 100644 --- a/js/web/test/data/ops/multihead-attention.jsonc +++ b/js/web/test/data/ops/multihead-attention.jsonc @@ -604,7 +604,7 @@ ] }, { - "name": "MultiHeadAttention Basic, 4 heads and head-size=1 with pastKey and pastValue", + "name": "MultiHeadAttention Basic, 4 heads and head-size=1 with pastKey, pastValue, presentKey and presentValue", "operator": "MultiHeadAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [{ "name": "num_heads", "data": 4, "type": "int" }], @@ -765,7 +765,83 @@ ] }, { - "name": "MultiHeadAttention Basic, one head and head-size=4 with RelativePositionBias, PastKey and PastValue", + "name": "MultiHeadAttention Basic, one head and head-size one with RelativePositionBias, pastKey, pastValue, presentKey and presentValue", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + // Q + { + "data": [1.0], + "dims": [1, 1, 1], + "type": "float32" + }, + // K + { + "data": [2.0], + "dims": [1, 1, 1], + "type": "float32" + }, + // V + { + "data": [3.0], + "dims": [1, 1, 1], + "type": "float32" + }, + // Bias + { + "data": null, + "type": "float32" + }, + // Mask + { + "data": null, + "type": "int32" + }, + // RelativePositionBias + { + "data": [10, 20], + "dims": [1, 1, 1, 2], + "type": "float32" + }, + // PastKey + { + "data": [4.0], + "dims": [1, 1, 1, 1], + "type": "float32" + }, + // PastValue + { + "data": [5.0], + "dims": [1, 1, 1, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [3.0006706714630127], + "dims": [1, 1, 1], + "type": "float32" + }, + { + "data": [4, 2], + "dims": [1, 1, 2, 1], + "type": "float32" + }, + { + "data": [5, 3], + "dims": [1, 1, 2, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "MultiHeadAttention Basic, one head and head-size=4 with RelativePositionBias, PastKey, PastValue, PresentKey and PresentValue", "operator": "MultiHeadAttention", "opset": { "domain": "com.microsoft", "version": 1 }, "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], @@ -803,7 +879,7 @@ }, // RelativePositionBias { - "data": [10, 20], + "data": [100, 200], "dims": [1, 1, 1, 2], "type": "float32" }, @@ -821,8 +897,162 @@ } ], "outputs": [ + { + "data": [9, 10, 11, 12], + "dims": [1, 1, 4], + "type": "float32" + }, + // Present key + { + "data": [13, 14, 15, 16, 5, 6, 7, 8], + "dims": [1, 1, 2, 4], + "type": "float32" + }, + // Present value + { + "data": [17, 18, 19, 20, 9, 10, 11, 12], + "dims": [1, 1, 2, 4], + "type": "float32" + } + ] + } + ] + }, + { + "name": "MultiHeadAttention Basic, one head and head-size one with pastKey and pastValue; kvBNSH (4-dim Key and Value, 3-dim Q)", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + // Q + { + "data": [1.0], + "dims": [1, 1, 1], + "type": "float32" + }, + // K + { + "data": [2.0], + "dims": [1, 1, 1, 1], + "type": "float32" + }, + // V + { + "data": [3.0], + "dims": [1, 1, 1, 1], + "type": "float32" + }, + // Bias + { + "data": null, + "type": "float32" + }, + // Mask + { + "data": null, + "type": "int32" + }, + // RelativePositionBias + { + "data": [10, 20], + "dims": [1, 1, 1, 2], + "type": "float32" + }, + // PastKey + { + "data": [4.0], + "dims": [1, 1, 1, 1], + "type": "float32" + }, + // PastValue + { + "data": [5.0], + "dims": [1, 1, 1, 1], + "type": "float32" + } + ], + "outputs": [ + { + "data": [3.0006706714630127], + "dims": [1, 1, 1], + "type": "float32" + }, + { + "data": [4, 2], + "dims": [1, 1, 2, 1], + "type": "float32" + }, + { + "data": [5, 3], + "dims": [1, 1, 2, 1], + "type": "float32" + } + ] + } + ] + }, + { + "name": "MultiHeadAttention Basic, one head and head-size 4 with pastKey and pastValue; Key and Value 4-dims", + "operator": "MultiHeadAttention", + "opset": { "domain": "com.microsoft", "version": 1 }, + "attributes": [{ "name": "num_heads", "data": 1, "type": "int" }], + "cases": [ + { + "name": "T[0]", + "inputs": [ + // Q + { + "data": [1, 2, 3, 4], + "dims": [1, 1, 4], + "type": "float32" + }, + // K + { + "data": [5, 6, 7, 8], + "dims": [1, 1, 1, 4], + "type": "float32" + }, + // V + { + "data": [9, 10, 11, 12], + "dims": [1, 1, 1, 4], + "type": "float32" + }, + // Bias + { + "data": null, + "type": "float32" + }, + // Mask + { + "data": null, + "type": "int32" + }, + // RelativePositionBias + { + "data": [50, 100], + "dims": [1, 1, 1, 2], + "type": "float32" + }, + // PastKey + { + "data": [13, 14, 15, 16], + "dims": [1, 1, 1, 4], + "type": "float32" + }, + // PastValue { "data": [17, 18, 19, 20], + "dims": [1, 1, 1, 4], + "type": "float32" + } + ], + "outputs": [ + { + "data": [9.000362396240234, 10.00036334991455, 11.000362396240234, 12.000362396240234], "dims": [1, 1, 4], "type": "float32" },