[WIP][JS/WebGPU] Inputs Key and Value could be 4-dims. (microsoft#20470)
### Description
The Key and Value inputs to MultiHeadAttention can be 4-dimensional (per-head BNSH layout, `[batch_size, num_heads, kv_sequence_length, head_size]`) as well as the packed 3-dimensional `[batch_size, kv_sequence_length, hidden_size]` layout.
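
As a sketch of the two layouts in play (the shape comments follow the com.microsoft MultiHeadAttention convention; the helper name `isKvBNSH` is illustrative, not the kernel's actual code):

```ts
// Illustrative only: the two Key/Value layouts this change supports.
type Dims = readonly number[];

// 3-dim packed layout: [batch_size, kv_sequence_length, hidden_size]
const key3d: Dims = [1, 2, 8];

// 4-dim per-head BNSH layout: [batch_size, num_heads, kv_sequence_length, head_size]
const key4d: Dims = [1, 4, 2, 2];

// The kvBNSH path applies when Key and Value arrive already split per
// head, so no reshape from the packed layout is needed before attention.
function isKvBNSH(key: Dims, value: Dims): boolean {
  return key.length === 4 && value.length === 4;
}

console.log(isKvBNSH(key4d, key4d), isKvBNSH(key3d, key3d)); // true false
```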


### Motivation and Context
Handle models that supply Key and Value already split per head (BNSH), and fix the kvBNSH path in `multiHeadAttention`, which previously passed `undefined` for pastKey/pastValue instead of forwarding them to `applyAttention`.
satyajandhyala authored Apr 25, 2024
1 parent 2c19db0 commit 21b3cbc
Showing 3 changed files with 239 additions and 11 deletions.
12 changes: 5 additions & 7 deletions js/web/lib/wasm/jsep/webgpu/ops/attention.ts
@@ -282,12 +282,12 @@ const createInPlaceSoftmaxProgramInfo = (_context: ComputeContext, input: Tensor
})()};
workgroupBarrier();
- var max_value = -3.402823e+38f;
+ var max_value = f32(-3.402823e+38f);
for (var i = 0u; i < ${WG}; i++) {
max_value = max(thread_max[i], max_value);
}
- var sum_vector = ${f32Type}(${0});
+ var sum_vector = ${f32Type}(0);
for (var i: u32 = 0; i < uniforms.elements_per_thread && i + local_offset < uniforms.d_comp; i++) {
sum_vector += exp(${f32Type}(x[offset + i]) - max_value);
}
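
For readers skimming the WGSL: a minimal TypeScript sketch of the numerically stable softmax this shader parallelizes (the constant matches the f32 lowest-finite literal `-3.402823e+38f` above; an illustration, not the shader itself):

```ts
// Subtract the running max before exponentiating so exp() cannot overflow.
function stableSoftmax(x: number[]): number[] {
  const maxValue = x.reduce((m, v) => Math.max(m, v), -3.402823e38);
  const exps = x.map((v) => Math.exp(v - maxValue));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
}
```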
@@ -378,7 +378,6 @@ const createAttentionProbsProgramInfo =
{name: 'num_heads', type: 'u32'}, {name: 'alpha', type: dataType as UniformDataElementType}
];
return `
- const beta: ${dataType} = 1.0;
const TILE_SIZE = ${TILE_SIZE}u;
var<workgroup> tileQ: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>;
@@ -426,16 +425,16 @@ const createAttentionProbsProgramInfo =
throw new Error(`Unsupported components: ${components}`);
}
})()};
- output[outputIdx] = sum * uniforms.alpha;
${(() => {
if (relativePositionBiasInput) {
return `
let batch = workgroup_id.z / uniforms.num_heads;
let head = workgroup_id.z % uniforms.num_heads;
var indices = ${relativePositionBiasInput.type.indices}(batch, head, global_id.y, global_id.x);
- output[outputIdx] += ${relativePositionBiasInput.getByIndices('indices')};`;
+ output[outputIdx] = sum * uniforms.alpha + ${relativePositionBiasInput.getByIndices('indices')};`;
}
- return '';
+ return 'output[outputIdx] = sum * uniforms.alpha;';
})()}
}
}`;
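
Net effect of this hunk: the shader previously wrote `sum * uniforms.alpha` unconditionally and then, when a relative position bias was present, did a second read-modify-write to add the bias; now each branch emits exactly one store per element. A scalar sketch of the behavior (illustrative, not the generated WGSL):

```ts
// Bias, when present, is folded into the same write as the scaled sum.
function attentionProb(sum: number, alpha: number, bias?: number): number {
  return bias === undefined ? sum * alpha : sum * alpha + bias;
}
```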
@@ -512,7 +511,6 @@ const createVxAttentionScoreProgramInfo =
// we need to transpose output from BNSH_v to BSND_v
let batchIdx = workgroup_id.z / uniforms.num_heads;
let currentBatchHeadNumber = workgroup_id.z % uniforms.num_heads;
- let headOffset = (batchIdx * uniforms.M * uniforms.num_heads + currentBatchHeadNumber) * uniforms.N;
if (m < uniforms.M && n < uniforms.N) {
let outputIdx = batchIdx * uniforms.M *uniforms.v_hidden_size + m * uniforms.v_hidden_size
+ currentBatchHeadNumber * uniforms.N + n;
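
The deleted `headOffset` was unused: the BNSH_v to BSND_v transpose is carried entirely by the `outputIdx` arithmetic that remains. A sketch of that index math (names mirror the shader's uniforms; here `vHiddenSize` stands in for `uniforms.v_hidden_size = num_heads * N`):

```ts
// Flat index of element (batch, head, m, n) in the transposed BSND_v output.
function outputIndex(
  batch: number, head: number, m: number, n: number,
  M: number, N: number, vHiddenSize: number,
): number {
  return batch * M * vHiddenSize + m * vHiddenSize + head * N + n;
}
```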
2 changes: 1 addition & 1 deletion js/web/lib/wasm/jsep/webgpu/ops/multihead-attentiion.ts
@@ -339,7 +339,7 @@ export const multiHeadAttention = (context: ComputeContext, attributes: Attentio

if (kvBNSH) {
return applyAttention(
- context, Q, key, value, keyPaddingMask, undefined, undefined, undefined, relativePositionBias, params,
+ context, Q, key, value, keyPaddingMask, undefined, pastKey, pastValue, relativePositionBias, params,
attributes);
}
if (!key || !value) {
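With pastKey/pastValue forwarded on the kvBNSH path, the kernel can produce present key/value outputs. The test cases added below pin down the semantics: present K/V is past K/V concatenated with the new K/V along the sequence axis. A toy sketch under that reading:

```ts
// For a [1, 1, S, 1] BNSH tensor the data flattens to a plain array, so
// concatenating along the sequence axis is an append.
function presentFromPast(past: number[], current: number[]): number[] {
  return [...past, ...current];
}

console.log(presentFromPast([4], [2])); // [4, 2], as in the tests below
```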
236 changes: 233 additions & 3 deletions js/web/test/data/ops/multihead-attention.jsonc
@@ -604,7 +604,7 @@
]
},
{
"name": "MultiHeadAttention Basic, 4 heads and head-size=1 with pastKey and pastValue",
"name": "MultiHeadAttention Basic, 4 heads and head-size=1 with pastKey, pastValue, presentKey and presentValue",
"operator": "MultiHeadAttention",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [{ "name": "num_heads", "data": 4, "type": "int" }],
@@ -765,7 +765,83 @@
]
},
{
"name": "MultiHeadAttention Basic, one head and head-size=4 with RelativePositionBias, PastKey and PastValue",
"name": "MultiHeadAttention Basic, one head and head-size one with RelativePositionBias, pastKey, pastValue, presentKey and presentValue",
"operator": "MultiHeadAttention",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
"cases": [
{
"name": "T[0]",
"inputs": [
// Q
{
"data": [1.0],
"dims": [1, 1, 1],
"type": "float32"
},
// K
{
"data": [2.0],
"dims": [1, 1, 1],
"type": "float32"
},
// V
{
"data": [3.0],
"dims": [1, 1, 1],
"type": "float32"
},
// Bias
{
"data": null,
"type": "float32"
},
// Mask
{
"data": null,
"type": "int32"
},
// RelativePositionBias
{
"data": [10, 20],
"dims": [1, 1, 1, 2],
"type": "float32"
},
// PastKey
{
"data": [4.0],
"dims": [1, 1, 1, 1],
"type": "float32"
},
// PastValue
{
"data": [5.0],
"dims": [1, 1, 1, 1],
"type": "float32"
}
],
"outputs": [
{
"data": [3.0006706714630127],
"dims": [1, 1, 1],
"type": "float32"
},
{
"data": [4, 2],
"dims": [1, 1, 2, 1],
"type": "float32"
},
{
"data": [5, 3],
"dims": [1, 1, 2, 1],
"type": "float32"
}
]
}
]
},
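The expected output above can be checked by hand: head_size = hidden_size / num_heads = 1, so the default scale alpha = 1/sqrt(head_size) = 1, and the effective key/value sequences are past-then-current. A quick verification sketch:

```ts
const q = 1.0;
const k = [4.0, 2.0]; // [pastKey, K]
const v = [5.0, 3.0]; // [pastValue, V]
const rpb = [10, 20]; // RelativePositionBias
const scores = k.map((ki, i) => q * ki + rpb[i]); // [14, 22]
const mx = Math.max(...scores);
const exps = scores.map((s) => Math.exp(s - mx)); // stable softmax
const sum = exps.reduce((a, b) => a + b, 0);
const out = v.reduce((acc, vi, i) => acc + (exps[i] / sum) * vi, 0);
console.log(out); // ≈ 3.00067, matching 3.0006706714630127 above
```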
{
"name": "MultiHeadAttention Basic, one head and head-size=4 with RelativePositionBias, PastKey, PastValue, PresentKey and PresentValue",
"operator": "MultiHeadAttention",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
@@ -803,7 +879,7 @@
},
// RelativePositionBias
{
"data": [10, 20],
"data": [100, 200],
"dims": [1, 1, 1, 2],
"type": "float32"
},
@@ -821,8 +897,162 @@
}
],
"outputs": [
{
"data": [9, 10, 11, 12],
"dims": [1, 1, 4],
"type": "float32"
},
// Present key
{
"data": [13, 14, 15, 16, 5, 6, 7, 8],
"dims": [1, 1, 2, 4],
"type": "float32"
},
// Present value
{
"data": [17, 18, 19, 20, 9, 10, 11, 12],
"dims": [1, 1, 2, 4],
"type": "float32"
}
]
}
]
},
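In the head-size=4 case above the expected output equals the current V exactly, which deserves a note: alpha = 1/sqrt(4) = 0.5, dot(Q, pastKey) = 150 and dot(Q, K) = 70, so the biased scores are [0.5*150 + 100, 0.5*70 + 200] = [175, 235]. A score gap of 60 saturates the softmax, as this check shows:

```ts
const pastWeight = Math.exp(175 - 235);   // e^-60 ≈ 8.8e-27
const w0 = pastWeight / (pastWeight + 1); // softmax weight of pastValue
console.log(9 + w0 * (17 - 9)); // first output component: 9 to f32 precision
```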
{
"name": "MultiHeadAttention Basic, one head and head-size one with pastKey and pastValue; kvBNSH (4-dim Key and Value, 3-dim Q)",
"operator": "MultiHeadAttention",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
"cases": [
{
"name": "T[0]",
"inputs": [
// Q
{
"data": [1.0],
"dims": [1, 1, 1],
"type": "float32"
},
// K
{
"data": [2.0],
"dims": [1, 1, 1, 1],
"type": "float32"
},
// V
{
"data": [3.0],
"dims": [1, 1, 1, 1],
"type": "float32"
},
// Bias
{
"data": null,
"type": "float32"
},
// Mask
{
"data": null,
"type": "int32"
},
// RelativePositionBias
{
"data": [10, 20],
"dims": [1, 1, 1, 2],
"type": "float32"
},
// PastKey
{
"data": [4.0],
"dims": [1, 1, 1, 1],
"type": "float32"
},
// PastValue
{
"data": [5.0],
"dims": [1, 1, 1, 1],
"type": "float32"
}
],
"outputs": [
{
"data": [3.0006706714630127],
"dims": [1, 1, 1],
"type": "float32"
},
{
"data": [4, 2],
"dims": [1, 1, 2, 1],
"type": "float32"
},
{
"data": [5, 3],
"dims": [1, 1, 2, 1],
"type": "float32"
}
]
}
]
},
{
"name": "MultiHeadAttention Basic, one head and head-size 4 with pastKey and pastValue; Key and Value 4-dims",
"operator": "MultiHeadAttention",
"opset": { "domain": "com.microsoft", "version": 1 },
"attributes": [{ "name": "num_heads", "data": 1, "type": "int" }],
"cases": [
{
"name": "T[0]",
"inputs": [
// Q
{
"data": [1, 2, 3, 4],
"dims": [1, 1, 4],
"type": "float32"
},
// K
{
"data": [5, 6, 7, 8],
"dims": [1, 1, 1, 4],
"type": "float32"
},
// V
{
"data": [9, 10, 11, 12],
"dims": [1, 1, 1, 4],
"type": "float32"
},
// Bias
{
"data": null,
"type": "float32"
},
// Mask
{
"data": null,
"type": "int32"
},
// RelativePositionBias
{
"data": [50, 100],
"dims": [1, 1, 1, 2],
"type": "float32"
},
// PastKey
{
"data": [13, 14, 15, 16],
"dims": [1, 1, 1, 4],
"type": "float32"
},
// PastValue
{
"data": [17, 18, 19, 20],
"dims": [1, 1, 1, 4],
"type": "float32"
}
],
"outputs": [
{
"data": [9.000362396240234, 10.00036334991455, 11.000362396240234, 12.000362396240234],
"dims": [1, 1, 4],
"type": "float32"
},
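The first expected output of this final case checks out the same way: alpha = 0.5 and biased scores [0.5*150 + 50, 0.5*70 + 100] = [125, 135], so the past entries keep a small but visible softmax weight of about e^-10:

```ts
const w0 = Math.exp(125 - 135) / (Math.exp(125 - 135) + 1); // ≈ 4.54e-5
console.log(17 * w0 + 9 * (1 - w0)); // ≈ 9.000363, matching 9.000362396240234
```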
