Commit

Attention WIP
dakenf committed Sep 9, 2023
1 parent 6d828eb commit 6b9350c
Showing 1 changed file with 49 additions and 20 deletions.
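
In short, the diff replaces the flat one-thread-per-output-element Q·Kᵀ kernel in computeAttentionProbs with a tiled one: each workgroup computes a TILE_SIZE × TILE_SIZE block of the sequenceLength × totalSequenceLength probability matrix, with one grid layer per (batch, head) pair. A minimal sketch of that grid sizing, using illustrative names rather than the commit's code:

  // Sketch only: dispatch sizing for a tiled M x N matmul, one workgroup per tile.
  // M = sequenceLength (query rows), N = totalSequenceLength (key columns).
  function tileDispatch(M: number, N: number, batchSize: number, numHeads: number, tileSize = 8) {
    return {
      x: Math.ceil(N / tileSize), // tiles along the key axis
      y: Math.ceil(M / tileSize), // tiles along the query axis
      z: batchSize * numHeads,    // one grid layer per (batch, head) pair
    };
  }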
69 changes: 49 additions & 20 deletions js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts
@@ -279,11 +279,10 @@ const computeAttentionProbs =
      // TODO: handle mask

      const alpha = attributes.scale === 0 ? 1.0 / Math.sqrt(parameters.headSize) : attributes.scale;
-     const gemmSize = parameters.sequenceLength * parameters.totalSequenceLength;

      const dataType = tensorTypeToWsglStorageType(q.dataType);

-     const components = getMaxComponents(parameters.headSize);
+     const components = 1; //getMaxComponents(parameters.headSize);
      const qInput = inputVariable('q', q.dataType, q.dims, components);
      const kInput = inputVariable('key', key.dataType, key.dims, components);
      const output = outputVariable('output', q.dataType, probsShape);
@@ -293,37 +292,67 @@ const computeAttentionProbs =
      const N = parameters.totalSequenceLength;
      const K = vectorizedHeadSize;

-     const unitsOfWork = ShapeUtil.size(probsShape);
+     const TILE_SIZE = Math.min(8, vectorizedHeadSize);
+
+     const dispatch = {
+       x: Math.ceil(parameters.totalSequenceLength / TILE_SIZE),
+       y: Math.ceil(parameters.sequenceLength / TILE_SIZE),
+       z: parameters.batchSize * parameters.numHeads
+     };
+
      const inputs = [q, key];
      const getShaderSource = (shaderHelper: ShaderHelper) => `
  const M: u32 = ${M}u;
  const N: u32 = ${N}u;
  const K: u32 = ${K}u;
- const gemmSize: u32 = ${gemmSize};
  const alpha = ${dataType}(${alpha});
  const beta: ${dataType} = 1.0;
+ const TILE_SIZE = ${TILE_SIZE}u;
+ var<workgroup> tileQ: mat4x4<${dataType}>;
+ var<workgroup> tileK: mat4x4<${dataType}>;

  ${shaderHelper.declareVariables(qInput, kInput, output)}

- ${shaderHelper.mainStart()}
- ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(unitsOfWork)}
- let idxWoGemmSize = global_idx / gemmSize;
- let gemmOffset = global_idx % gemmSize;
- let m = gemmOffset / N;
- let n = gemmOffset % N;
+ @compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1)
+ fn main(@builtin(global_invocation_id) global_id : vec3<u32>, @builtin(workgroup_id) workgroup_id : vec3<u32>,
+         @builtin(local_invocation_id) local_id : vec3<u32>, @builtin(local_invocation_index) local_index : u32) {
+   let global_idx = (workgroup_id.z * ${dispatch.x * dispatch.y}u +
+     workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index;
-   let inputOffset = ${parameters.sequenceLength * vectorizedHeadSize} * idxWoGemmSize + m * K;
-   let kOffset = ${parameters.kvSequenceLength * vectorizedHeadSize} * idxWoGemmSize + n * K;
+   // x holds the N and y holds the M
+   let headIdx = workgroup_id.z;
+   let m = workgroup_id.y * TILE_SIZE + local_id.y;
+   let n = workgroup_id.x * TILE_SIZE + local_id.x;
-   var value: ${qInput.type.storage} = ${fillVector(dataType, components)};
-   for (var k: u32 = 0u; k<${K}u; k++) {
-     // no trans a + trans b
-     value += q[k + inputOffset] * key[k + kOffset];
+   let qOffset = ${parameters.sequenceLength * vectorizedHeadSize} * headIdx + m * K;
+   let kOffset = ${parameters.kvSequenceLength * vectorizedHeadSize} * headIdx + n * K;
+   var value: ${dataType} = 0;
+   for (var w: u32 = 0u; w < K; w += TILE_SIZE) {
+     if (m < M && w + local_id.x < K) {
+       tileQ[local_id.x][local_id.y] = q[qOffset + w + local_id.x];
+     }
+     if (n < N && w + local_id.x < K) {
+       tileK[local_id.x][local_id.y] = key[kOffset + w + local_id.x];
+     }
+     workgroupBarrier();
+     for (var k: u32 = 0u; k<TILE_SIZE; k++) {
+       value += tileQ[local_id.x][k] * tileK[k][local_id.y];
+     }
+     workgroupBarrier();
    }
+   //${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(ShapeUtil.size(probsShape))}
+   // output[global_idx] = f32(workgroup_id.z);
+   let headOffset = headIdx * M * N;
+   if (m < M && n < N) {
+     let outputIdx = headOffset + m * N + n;
+     output[outputIdx] = value;
+   }
-   let sum = ${sumVector('value', components)} * alpha;
-   // value += beta * output[global_id.x]; // no mask
-   output[global_idx] = sum;
  }`;

      const inputTypes = inputs.map(_ => GpuDataType.default);
@@ -335,7 +364,7 @@ const computeAttentionProbs =
        inputTypes,
        outputs: [{dims: probsShape, dataType: q.dataType, gpuDataType: GpuDataType.default}],
        getShaderSource,
-       dispatchGroup: () => ({x: Math.ceil(unitsOfWork / 64 /* workgroup size */)})
+       dispatchGroup: () => (dispatch)
      },
      {inputs, outputs: [-1]})[0];
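
A caveat when reading the new shader: tileQ and tileK are declared as mat4x4<...>, which holds only 4 × 4 elements, while TILE_SIZE can be as large as 8 and the tile loads and inner-product loop index up to TILE_SIZE - 1, so the tiles as committed would be indexed out of bounds for head sizes above 4 (consistent with the WIP commit title). A sketch of tile declarations sized by TILE_SIZE itself, in the same template-literal style; the helper name and sizes are illustrative, not the commit's code:

  // Sketch only: emit workgroup tile buffers whose dimensions match TILE_SIZE,
  // instead of a fixed mat4x4. dataType would be e.g. 'f32'.
  const tileDecls = (tileSize: number, dataType: string) => `
    const TILE_SIZE = ${tileSize}u;
    var<workgroup> tileQ: array<array<${dataType}, ${tileSize}>, ${tileSize}>;
    var<workgroup> tileK: array<array<${dataType}, ${tileSize}>, ${tileSize}>;
  `;

For example, tileDecls(8, 'f32') yields 8 × 8 tiles that both the loads and the inner loop can index safely.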

