
Commit

Attention WIP
dakenf committed Sep 28, 2023
1 parent e797f53 commit ebec851
Showing 1 changed file with 120 additions and 99 deletions.
219 changes: 120 additions & 99 deletions js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts
@@ -211,43 +211,71 @@ export const parseAttentionAttributes = (attributes: AttentionAttrs): AttentionA
createAttributeWithCacheKey({...attributes});

export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView, N: number, D: number) => {
const components = getMaxComponents(D);
const inputHelper = outputVariable('x', input.dataType, input.dims, components);

let threadMaxValue = 'threadMaxVector';
if (components === 2) {
threadMaxValue = 'max(threadMaxVector.x, threadMaxVector.y)';
} else if (components === 4) {
threadMaxValue = 'max(max(threadMaxVector.x, threadMaxVector.y), max(threadMaxVector.z, threadMaxVector.w))';
}
const dataType = tensorTypeToWsglStorageType(input.dataType);
let WG = 64;
const dComp = D / components;
if (dComp < WG) {
WG = 1;
} else if (dComp / 8 < 64) {
WG = Math.ceil(dComp / 8);
}
const elementsPerWG = Math.ceil(D / components / WG);

// 6.2.4 in wgsl spec
const threadMaxMinValue = dataType === 'f32' ? '-3.402823e+38f' : '-65504.0h';
const getShaderSource = (shaderHelper: ShaderHelper) => `
const dInv: ${dataType} = 1 / ${dataType}(${D});
@group(0) @binding(0) var<storage, read_write> x: array<${dataType}>;
@group(0) @binding(1) var<storage, read_write> x2: array<${dataType}>;
const dComp = ${D / components};
var<workgroup> wgMax: array<${dataType}, ${WG}>;
var<workgroup> wgSum: array<${dataType}, ${WG}>;
${shaderHelper.declareVariables(inputHelper)}
@compute @workgroup_size(${WG}, 1, 1)
fn main(@builtin(workgroup_id) workgroup_id : vec3<u32>,
@builtin(local_invocation_index) local_index : u32) {
let localOffset = local_index * ${elementsPerWG};
let offset: u32 = workgroup_id.x * dComp + localOffset;
var threadMaxVector = ${fillVector(dataType, components, threadMaxMinValue)};
for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) {
threadMaxVector = max(x[offset + i], threadMaxVector);
}
wgMax[local_index] = ${threadMaxValue};
workgroupBarrier();
${shaderHelper.mainStart()}
if (global_idx >= ${N}) {
return;
var maxValue = ${threadMaxMinValue};
for (var i = 0u; i < ${WG}; i++) {
maxValue = max(wgMax[i], maxValue);
}
let offset: u32 = global_idx * ${D};
var threadMax = -3.402823e+38f; // 6.2.4 in wgsl spec
for (var i: u32 = 0; i < ${D}; i++) {
threadMax = max(f32(x[offset + i]), threadMax);
var sumVector = ${fillVector(dataType, components, '0')};
for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) {
sumVector += exp(x[offset + i] - maxValue);
}
wgSum[local_index] = ${sumVector('sumVector', components)};
workgroupBarrier();
var sum: f32 = 0.0;
for (var i: u32 = 0; i < ${D}; i++) {
let val: f32 = exp(f32(x[offset + i]) - threadMax);
// x[offset + i] = ${dataType}(val);
sum += val;
var sum: ${dataType} = 0;
for (var i = 0u; i < ${WG}; i++) {
sum += wgSum[i];
}
// for (var i: u32 = 0; i < ${D}; i++) {
// sum += x[offset + i];
// }
if (sum == 0) {
for (var i: u32 = 0; i < ${D}; i++) {
x[offset + i] = dInv;
x2[offset + i] = dInv;
for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) {
x[offset + i] = ${fillVector(dataType, components, 'dInv')};
}
} else {
for (var i: u32 = 0; i < ${D}; i++) {
x[offset + i] = ${dataType}(exp(f32(x[offset + i]) - threadMax) / sum);
x2[offset + i] = x[offset + i];
for (var i: u32 = 0; i < ${elementsPerWG} && i + localOffset < dComp; i++) {
x[offset + i] = exp(x[offset + i] - maxValue) / sum;
}
}
}`;
@@ -257,13 +285,11 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView
name: 'computeAttentionProbsSoftmax',
cacheHint: '0',
inputTypes: [GpuDataType.default],
outputs: [
{dims: input.dims, dataType: DataType.float, gpuDataType: GpuDataType.default}
],
outputs: [],
getShaderSource,
dispatchGroup: () => ({x: Math.ceil(N / 64)})
dispatchGroup: () => ({x: N})
},
{inputs: [input], outputs: [-1]});
{inputs: [input], outputs: []});
};
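
A minimal single-threaded TypeScript sketch of what the vectorized workgroup-reduction shader above computes per row: a numerically stable, in-place softmax over each row of length D, with the same 1/D fallback when the exponent sum is zero. The function and parameter names here are illustrative and are not part of the source file.

const softmaxRowsInPlaceReference = (x: Float32Array, n: number, d: number): void => {
  for (let row = 0; row < n; row++) {
    const offset = row * d;
    // Subtract the row maximum before exponentiating to avoid overflow.
    let rowMax = -Infinity;
    for (let i = 0; i < d; i++) {
      rowMax = Math.max(rowMax, x[offset + i]);
    }
    let sum = 0;
    for (let i = 0; i < d; i++) {
      sum += Math.exp(x[offset + i] - rowMax);
    }
    // When the sum is zero, fall back to the uniform value 1 / d (the dInv branch above).
    for (let i = 0; i < d; i++) {
      x[offset + i] = sum === 0 ? 1 / d : Math.exp(x[offset + i] - rowMax) / sum;
    }
  }
};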

const computeAttentionProbs =
@@ -372,11 +398,10 @@ const computeAttentionProbs =
};

const computeVxAttentionScore = (context: ComputeContext, probs: TensorView, v: TensorView, params: AttentionParameters) => {
const outputShape = [params.batchSize, params.sequenceLength, params.vHiddenSize];
const outputShape = [params.batchSize, params.numHeads, params.sequenceLength, params.vHeadSize];

const components = getMaxComponents(params.totalSequenceLength);
const probsHelper = inputVariable('probs', probs.dataType, probs.dims, components);
const vHelper = inputVariable('v', v.dataType, v.dims, components);
const probsHelper = inputVariable('probs', probs.dataType, probs.dims);
const vHelper = inputVariable('v', v.dataType, v.dims);
const output = outputVariable('output', probs.dataType, outputShape);

const dataType = tensorTypeToWsglStorageType(probs.dataType);
@@ -385,15 +410,16 @@ const computeVxAttentionScore = (context: ComputeContext, probs: TensorView, v:
const dispatch = {
x: Math.ceil(params.vHeadSize / TILE_SIZE),
y: Math.ceil(params.sequenceLength / TILE_SIZE),
z: params.batchSize * params.numHeads,
z: params.batchSize * params.numHeads
};

const getShaderSource = (shaderHelper: ShaderHelper) => `
const M: u32 = ${params.sequenceLength}u;
const N: u32 = ${params.vHeadSize}u;
const K: u32 = ${params.totalSequenceLength / components}u;
const K: u32 = ${params.totalSequenceLength}u;
const numHeads: u32 = ${params.numHeads}u;
const TILE_SIZE = ${TILE_SIZE}u;
var<workgroup> tileQ: array<${probsHelper.type.storage}, ${TILE_SIZE * TILE_SIZE}>;
var<workgroup> tileK: array<${probsHelper.type.storage}, ${TILE_SIZE * TILE_SIZE}>;
@@ -406,35 +432,31 @@ const computeVxAttentionScore = (context: ComputeContext, probs: TensorView, v:
workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index;
let headIdx = workgroup_id.z;
let m = workgroup_id.y * TILE_SIZE;
let n = workgroup_id.x * TILE_SIZE;
let lm = m + local_id.y;
let ln = n + local_id.x;
let m = workgroup_id.y * TILE_SIZE + local_id.y;
let n = workgroup_id.x * TILE_SIZE + local_id.x;
let offsetA = headIdx * (M * K) + m * K;
let offsetB = headIdx * (N * K) + n * K;
let offsetB = headIdx * (N * K) + n;
var value = ${fillVector(dataType, components)};
var value = ${dataType}(0);
for (var w: u32 = 0u; w < K; w += TILE_SIZE) {
if (m + local_id.y < M && w + local_id.x < K) {
tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + local_id.y * K + w + local_id.x];
if (m < M && w + local_id.x < K) {
tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x];
}
if (n + local_id.y < N && w + local_id.x < K) {
tileK[TILE_SIZE * local_id.y + local_id.x] = v[offsetB + local_id.y * K + w + local_id.x];
if (n < N && w + local_id.y < K) {
tileK[TILE_SIZE * local_id.y + local_id.x] = v[offsetB + (w + local_id.y) * N];
}
workgroupBarrier();
for (var k: u32 = 0u; k<TILE_SIZE && w+k < K; k++) {
value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * local_id.x + k];
value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * k + local_id.x];
}
workgroupBarrier();
}
let batchIdx = workgroup_id.z / ${params.numHeads};
let currentBatchHeadNumber = workgroup_id.z % ${params.numHeads};
let headOffset = (batchIdx * M * ${params.numHeads} + currentBatchHeadNumber) * ${params.vHeadSize};
if (lm < M && ln < N) {
let outputIdx = batchIdx * ${params.sequenceLength * params.vHiddenSize} + lm * ${params.vHiddenSize} + currentBatchHeadNumber * ${params.vHeadSize} + ln;
output[outputIdx] = ${sumVector('value', components)};
let headOffset = headIdx * M * N;
if (m < M && n < N) {
let outputIdx = headOffset + m * N + n;
output[outputIdx] = value;
}
}`;

@@ -447,7 +469,7 @@ let batchIdx = workgroup_id.z / ${params.numHeads};
getShaderSource,
dispatchGroup: () => (dispatch)
},
{inputs: [probs, v], outputs: [0]})[0];
{inputs: [probs, v], outputs: [-1]})[0];
};
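
A plain TypeScript sketch of the per-head matrix product the tiled shader above computes: for each of batchSize * numHeads heads, probs (sequenceLength x totalSequenceLength) times v (totalSequenceLength x vHeadSize), written out in [batch, numHeads, sequenceLength, vHeadSize] order ahead of the transpose pass below. Parameter names are illustrative assumptions, not part of the source file.

const vxAttentionScoreReference = (
    probs: Float32Array, v: Float32Array, batchHeads: number, seqLen: number,
    totalSeqLen: number, vHeadSize: number): Float32Array => {
  const output = new Float32Array(batchHeads * seqLen * vHeadSize);
  for (let h = 0; h < batchHeads; h++) {
    for (let m = 0; m < seqLen; m++) {
      for (let n = 0; n < vHeadSize; n++) {
        let sum = 0;
        for (let k = 0; k < totalSeqLen; k++) {
          // probs is [batchHeads, seqLen, totalSeqLen]; v is [batchHeads, totalSeqLen, vHeadSize].
          sum += probs[(h * seqLen + m) * totalSeqLen + k] * v[(h * totalSeqLen + k) * vHeadSize + n];
        }
        output[(h * seqLen + m) * vHeadSize + n] = sum;
      }
    }
  }
  return output;
};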

export const applyAttention =
@@ -456,39 +478,41 @@ export const applyAttention =
relativePositionBias: TensorView|undefined, parameters: AttentionParameters, attributes: AttentionAttrs) => {
const probs = computeAttentionProbs(context, q, k, relativePositionBias, parameters, attributes);

computeVxAttentionScore(context, probs, v, parameters);
// const attentionResult = computeVxAttentionScore(context, probs, v, parameters);

// const outputShape = [parameters.batchSize, parameters.sequenceLength, parameters.vHiddenSize];
// const input = inputVariable('input', q.dataType, attentionResult.dims);
// const output = outputVariable('output', q.dataType, outputShape);
// const outputSize = parameters.batchSize * parameters.sequenceLength * parameters.vHeadSize * parameters.numHeads;
// const getShaderSource = (shaderHelper: ShaderHelper) => `
// ${shaderHelper.declareVariables(input, output)}
//
// ${shaderHelper.mainStart()}
// ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
// let h = global_idx % ${parameters.vHeadSize};
// let n = (global_idx / ${parameters.vHeadSize}) % ${parameters.sequenceLength};
// let s = (global_idx / (${parameters.vHeadSize * parameters.numHeads})) % ${parameters.sequenceLength};
// let b = global_idx / (${parameters.vHeadSize * parameters.sequenceLength * parameters.numHeads});
//
// var inputOffset = b * ${parameters.numHeads * parameters.sequenceLength * parameters.vHeadSize} + n * ${parameters.sequenceLength * parameters.vHeadSize} + s * ${parameters.vHeadSize} + h;
// var outputOffset = b * ${parameters.sequenceLength * parameters.vHiddenSize} + s * ${parameters.vHiddenSize} + n * ${parameters.vHeadSize} + h;
//
// output[outputOffset] = input[inputOffset];
// }`;
//
// context.compute(
// {
// name: 'AttentionTranspose',
// cacheHint: JSON.stringify(parameters),
// inputTypes: [GpuDataType.default],
// outputs: [{dims: outputShape, dataType: DataType.float, gpuDataType: GpuDataType.default}],
// getShaderSource,
// dispatchGroup: () => ({ x: Math.ceil(outputSize / 64) }),
// },
// {inputs: [attentionResult], outputs: [0]});
// computeVxAttentionScore(context, probs, v, parameters);
const attentionResult = computeVxAttentionScore(context, probs, v, parameters);

const outputShape = [parameters.batchSize, parameters.sequenceLength, parameters.vHiddenSize];
const input = inputVariable('input', q.dataType, attentionResult.dims);
const output = outputVariable('output', q.dataType, outputShape);
const outputSize = parameters.batchSize * parameters.sequenceLength * parameters.vHeadSize * parameters.numHeads;
const getShaderSource = (shaderHelper: ShaderHelper) => `
${shaderHelper.declareVariables(input, output)}
${shaderHelper.mainStart()}
${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
let h = global_idx % ${parameters.vHeadSize};
let s = (global_idx / ${parameters.vHeadSize}) % ${parameters.sequenceLength};
let n = (global_idx / (${parameters.vHeadSize} * ${parameters.sequenceLength})) % ${parameters.numHeads};
let b = global_idx / (${parameters.vHeadSize} * ${parameters.sequenceLength} * ${parameters.numHeads});
// Calculate the offsets for the input and output tensors
var inputOffset = b * ${parameters.numHeads} * ${parameters.sequenceLength} * ${parameters.vHeadSize} + n * ${parameters.sequenceLength} * ${parameters.vHeadSize} + s * ${parameters.vHeadSize} + h;
var outputOffset = b * ${parameters.sequenceLength} * ${parameters.vHiddenSize} + s * ${parameters.vHiddenSize} + h + n * ${parameters.vHeadSize};
// Copy the value from the input tensor to the output tensor
output[outputOffset] = input[inputOffset];
}`;

context.compute(
{
name: 'AttentionTranspose',
cacheHint: JSON.stringify(parameters),
inputTypes: [GpuDataType.default],
outputs: [{dims: outputShape, dataType: DataType.float, gpuDataType: GpuDataType.default}],
getShaderSource,
dispatchGroup: () => ({ x: Math.ceil(outputSize / 64) }),
},
{inputs: [attentionResult], outputs: [0]});
};
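
The AttentionTranspose pass above reorders the per-head result from [batch, numHeads, sequenceLength, vHeadSize] into [batch, sequenceLength, vHiddenSize], with vHiddenSize = numHeads * vHeadSize. A small TypeScript sketch of that index math, using illustrative names that are not part of the source file:

const bnshToBsdReference = (
    input: Float32Array, batchSize: number, numHeads: number, seqLen: number,
    headSize: number): Float32Array => {
  const output = new Float32Array(batchSize * seqLen * numHeads * headSize);
  for (let b = 0; b < batchSize; b++) {
    for (let n = 0; n < numHeads; n++) {
      for (let s = 0; s < seqLen; s++) {
        for (let h = 0; h < headSize; h++) {
          // Same offsets as the shader: input is BNSH, output is B x S x (numHeads * headSize).
          const inputOffset = ((b * numHeads + n) * seqLen + s) * headSize + h;
          const outputOffset = (b * seqLen + s) * numHeads * headSize + n * headSize + h;
          output[outputOffset] = input[inputOffset];
        }
      }
    }
  }
  return output;
};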

const prepare = (context: ComputeContext, parameters: AttentionParameters, attributes: AttentionAttrs) => {
@@ -558,19 +582,18 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters, attri
tileInput[TILE_SIZE * local_id.y + local_id.x] = input[inputOffset + w + local_id.x];
}
if (n < N && w + local_id.y < K) {
tileWeightQ[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetQ + n + (w + local_id.y) * ldb];
}
if (n < N && w + local_id.y < K) {
tileWeightK[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetK + n + (w + local_id.y) * ldb];
}
if (n < N && w + local_id.y < K) {
tileWeightV[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetV + n + (w + local_id.y) * ldb];
let offset = n + (w + local_id.y) * ldb;
tileWeightQ[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetQ + offset];
tileWeightK[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetK + offset];
tileWeightV[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetV + offset];
}
workgroupBarrier();
for (var k: u32 = 0u; k<TILE_SIZE && w+k < K; k++) {
valueQ += tileInput[TILE_SIZE * local_id.y + k] * tileWeightQ[TILE_SIZE * k + local_id.x];
valueK += tileInput[TILE_SIZE * local_id.y + k] * tileWeightK[TILE_SIZE * k + local_id.x];
valueV += tileInput[TILE_SIZE * local_id.y + k] * tileWeightV[TILE_SIZE * k + local_id.x];
let inputTileOffset = TILE_SIZE * local_id.y + k;
let weightTileOffset = TILE_SIZE * k + local_id.x;
valueQ += tileInput[inputTileOffset] * tileWeightQ[weightTileOffset];
valueK += tileInput[inputTileOffset] * tileWeightK[weightTileOffset];
valueV += tileInput[inputTileOffset] * tileWeightV[weightTileOffset];
}
workgroupBarrier();
@@ -586,9 +609,7 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters, attri
let outputIdx = offset + m * N + n;
outputQ[outputIdx] = valueQ;
outputK[outputIdx] = valueK;
// transpose V to use vec4 optimizations in compute score
let outputIdxV = offset + n * M + m;
outputV[outputIdxV] = valueV;
outputV[outputIdx] = valueV;
}
}`;
