Commit: Attention WIP

dakenf committed Sep 10, 2023
1 parent cfa7b2e commit b1d4e86
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts
@@ -282,7 +282,7 @@ const computeAttentionProbs =

   const dataType = tensorTypeToWsglStorageType(q.dataType);

-  const components = 1; //getMaxComponents(parameters.headSize);
+  const components = getMaxComponents(parameters.headSize);
   const qInput = inputVariable('q', q.dataType, q.dims, components);
   const kInput = inputVariable('key', key.dataType, key.dims, components);
   const output = outputVariable('output', q.dataType, probsShape);
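
This hunk turns on vectorized access: instead of hard-coding `components = 1`, the shader now asks `getMaxComponents(parameters.headSize)` for the widest WGSL vector width that evenly divides the head size, so each buffer element carries 1, 2, or 4 scalars. A minimal sketch of what such a helper plausibly does, assuming the usual divisibility rule (the actual helper lives in the shared ops utilities, not in this diff):

// Sketch of a getMaxComponents-style helper (assumed shape, not the
// exact repository code): pick the widest component count dividing size.
const getMaxComponents = (size: number): 1 | 2 | 4 => {
  if (size % 4 === 0) {
    return 4;  // read/write vec4: four scalars per element
  } else if (size % 2 === 0) {
    return 2;  // vec2
  }
  return 1;    // scalar fallback; vec3 is skipped because it has
               // 16-byte alignment in WGSL and would waste a lane
};

With headSize = 64, for example, this returns 4 and quarters the number of loads in the inner product.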
@@ -304,13 +304,13 @@ const computeAttentionProbs =
 const getShaderSource = (shaderHelper: ShaderHelper) => `
   const M: u32 = ${M}u;
   const N: u32 = ${N}u;
-  const K: u32 = ${K}u;
+  const K: u32 = ${K / components}u;
   const alpha = ${dataType}(${alpha});
   const beta: ${dataType} = 1.0;
   const TILE_SIZE = ${TILE_SIZE}u;

-  var<workgroup> tileQ: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>;
-  var<workgroup> tileK: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>;
+  var<workgroup> tileQ: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>;
+  var<workgroup> tileK: array<${qInput.type.storage}, ${TILE_SIZE * TILE_SIZE}>;

   ${shaderHelper.declareVariables(qInput, kInput, output)}
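
Two consequences of vectorization show up here: the inner dimension `K` is divided by `components`, since each array element now holds `components` scalars, and the workgroup tiles switch from the scalar `${dataType}` to `${qInput.type.storage}`, which resolves to the vector type when `components > 1`. For illustration only (assuming headSize = 64, f32 data, and a hypothetical TILE_SIZE of 12), the emitted WGSL would contain roughly:

// Illustrative only: what this template would emit when components = 4,
// so K = headSize / components = 16 (values here are assumptions).
const emittedSnippet = `
  const K: u32 = 16u;
  var<workgroup> tileQ: array<vec4<f32>, 144>;
  var<workgroup> tileK: array<vec4<f32>, 144>;
`;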
@@ -328,18 +328,18 @@ const computeAttentionProbs =
   let qOffset = ${parameters.sequenceLength * vectorizedHeadSize} * headIdx + m * K;
   let kOffset = ${parameters.kvSequenceLength * vectorizedHeadSize} * headIdx + n * K;
-  var value: ${dataType} = 0;
+  var value = ${fillVector(dataType, components)};
   for (var w: u32 = 0u; w < K; w += TILE_SIZE) {
     if (m < M && w + local_id.x < K) {
-      tileQ[TILE_SIZE * local_id.x + local_id.y] = q[qOffset + w + local_id.x];
+      tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + w + local_id.x];
     }
     if (n < N && w + local_id.x < K) {
-      tileK[TILE_SIZE * local_id.x + local_id.y] = key[kOffset + w + local_id.x];
+      tileK[TILE_SIZE * local_id.y + local_id.x] = key[kOffset + w + local_id.x];
     }
     workgroupBarrier();
     for (var k: u32 = 0u; k < TILE_SIZE; k++) {
-      value += tileQ[TILE_SIZE * local_id.x + k] * tileK[TILE_SIZE * k + local_id.y];
+      value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * k + local_id.x];
     }
     workgroupBarrier();
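
The accumulator change relies on `fillVector`, which must emit a zero of the right width so that `value` becomes, say, a `vec4<f32>` accumulating four partial products per iteration; the index swap in the tile writes (`local_id.x` now varying fastest within a row) makes neighboring threads touch neighboring tile slots, matching the already-consecutive global loads on the right-hand side. A plausible sketch of the helper, hedged as an assumption about its shape:

// Sketch of a fillVector-style helper (assumed, not verbatim): emits a
// zero literal of the requested vector width into the WGSL template.
const fillVector = (dataType: string, components?: number, value = '0'): string => {
  if (!components || components === 1) {
    return `${dataType}(${value})`;                    // e.g. 'f32(0)'
  }
  return `vec${components}<${dataType}>(${value})`;    // e.g. 'vec4<f32>(0)'
};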
@@ -351,7 +351,7 @@ const computeAttentionProbs =
   let headOffset = headIdx * M * N;
   if (m < M && n < N) {
     let outputIdx = headOffset + m * N + n;
-    output[outputIdx] = value;
+    output[outputIdx] = ${sumVector('value', components)};
   }
 }`;
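
Since `value` may now be a vector of per-lane partial sums, the final store reduces it back to one scalar probability. A sumVector-style helper would expand to component-wise addition in the generated source; a minimal sketch under that assumption:

// Sketch of a sumVector-style helper (assumed shape): reduces a WGSL
// vector expression to a scalar sum; scalars pass through unchanged.
const sumVector = (name: string, components: number): string => {
  if (components === 4) {
    return `(${name}.x + ${name}.y + ${name}.z + ${name}.w)`;
  }
  if (components === 2) {
    return `(${name}.x + ${name}.y)`;
  }
  return name;
};

// e.g. sumVector('value', 4) -> '(value.x + value.y + value.z + value.w)'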

