Skip to content

Commit

Permalink
Attention WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
dakenf committed Sep 9, 2023
1 parent 6b9350c commit cfa7b2e
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,8 @@ const computeAttentionProbs =
const beta: ${dataType} = 1.0;
const TILE_SIZE = ${TILE_SIZE}u;
var<workgroup> tileQ: mat4x4<${dataType}>;
var<workgroup> tileK: mat4x4<${dataType}>;
var<workgroup> tileQ: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>;
var<workgroup> tileK: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>;
${shaderHelper.declareVariables(qInput, kInput, output)}
Expand All @@ -331,15 +331,15 @@ const computeAttentionProbs =
var value: ${dataType} = 0;
for (var w: u32 = 0u; w < K; w += TILE_SIZE) {
if (m < M && w + local_id.x < K) {
tileQ[local_id.x][local_id.y] = q[qOffset + w + local_id.x];
tileQ[TILE_SIZE * local_id.x + local_id.y] = q[qOffset + w + local_id.x];
}
if (n < N && w + local_id.x < K) {
tileK[local_id.x][local_id.y] = key[kOffset + w + local_id.x];
tileK[TILE_SIZE * local_id.x + local_id.y] = key[kOffset + w + local_id.x];
}
workgroupBarrier();
for (var k: u32 = 0u; k<TILE_SIZE; k++) {
value += tileQ[local_id.x][k] * tileK[k][local_id.y];
value += tileQ[TILE_SIZE * local_id.x + k] * tileK[TILE_SIZE * k + local_id.y];
}
workgroupBarrier();
Expand Down

0 comments on commit cfa7b2e

Please sign in to comment.