From 8ac657d83a4fe73c8c880aaa4e0db3cc987217d7 Mon Sep 17 00:00:00 2001
From: Arthur Islamov
Date: Fri, 29 Sep 2023 18:59:47 +0400
Subject: [PATCH] More cleanup

---
 js/web/lib/wasm/jsep/webgpu/ops/attention.ts | 137 ++++++++++---------
 1 file changed, 71 insertions(+), 66 deletions(-)

diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts
index 033d3c265cb1e..65ca25ea09f04 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts
@@ -75,6 +75,40 @@ export interface AttentionAttrs {
 }
 
 const validateAttentionInputs = (inputs: readonly TensorView[], attributes: AttentionAttrs): AttentionParameters => {
+  // Abbreviations and meanings:
+  //   B:    batch_size
+  //   S:    sequence_length (input sequence length of query)
+  //   P:    past_sequence_length (past sequence length of key or value)
+  //   L:    kv_sequence_length (input sequence length of key or value)
+  //   M:    max_sequence_length
+  //   T:    total_sequence_length = past_sequence_length + kv_sequence_length
+  //   N:    num_heads
+  //   H:    head size for Q and K, aka q_head_size or k_head_size or qk_head_size
+  //   H_v:  v_head_size
+  //   D_i:  input hidden size
+  //   D:    hidden size for Q and K (D = N * H), aka q_hidden_size or k_hidden_size or qk_hidden_size
+  //   D_v:  v_hidden_size = num_heads * v_head_size
+
+  // When past state is used, Q, K and V should have the same hidden size (unless we split it into past_key and past_value).
+
+  // Input shapes:
+  //   input        (Q/K/V)    : (B, S, D_i)
+  //   weights      (Q/K/V)    : (D_i, D + D + D_v)
+  //   bias         (Q/K/V)    : (D + D + D_v)
+  //   mask_index              : see below
+  //   past         (K/V)      : (2, B, N, P, H) or NULL
+  //   relative_position_bias  : (B, N, S, T) or NULL
+
+  // For mask_index, the following shapes are supported:
+  //   NULL, (B, 1), (1, 1)
+  //   (B), (2 * B), (3 * B + 2)
+  //   (B, T)
+  //   (B, S, T)
+  //   (B, 1, M, M)
+  //
+  // When a model is pruned (e.g. some attention heads are removed in Q/K/V), input_hidden_size can be larger
+  // than the hidden dimension of Q, K and V.
+
   const input = inputs[0];
   const weights = inputs[1];
   const bias = inputs[2];
@@ -396,7 +430,7 @@ const computeAttentionProbs =
     };
 
 const computeVxAttentionScore =
     (context: ComputeContext, probs: TensorView, v: TensorView, params: AttentionParameters) => {
-      const outputShape = [params.batchSize, params.numHeads, params.sequenceLength, params.vHeadSize];
+      const outputShape = [params.batchSize, params.sequenceLength, params.vHiddenSize];
 
       const probsHelper = inputVariable('probs', probs.dataType, probs.dims);
       const vHelper = inputVariable('v', v.dataType, v.dims);
@@ -429,33 +463,37 @@ const computeVxAttentionScore = (context: ComputeContext, probs: TensorView, v:
-   let global_idx = (workgroup_id.z * ${dispatch.x * dispatch.y}u + workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index;
-
    let headIdx = workgroup_id.z;
    let m = workgroup_id.y * TILE_SIZE + local_id.y;
    let n = workgroup_id.x * TILE_SIZE + local_id.x;
 
    let offsetA = headIdx * (M * K) + m * K;
    let offsetB = headIdx * (N * K) + n;
 
    var value = ${dataType}(0);
    for (var w: u32 = 0u; w < K; w += TILE_SIZE) {
      if (m < M && w + local_id.x < K) {
        tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x];
      }
      if (n < N && w + local_id.y < K) {
        tileK[TILE_SIZE * local_id.y + local_id.x] = v[offsetB + (w + local_id.y) * N];
      }
      workgroupBarrier();
      for (var k: u32 = 0u; k < TILE_SIZE && w+k < K; k++) {
        value += tileQ[TILE_SIZE * local_id.y + k] * tileK[TILE_SIZE * k + local_id.x];
      }
      workgroupBarrier();
    }
 
+   // we need to transpose the result from BNSH_v to BSND_v while writing it out
+   let batchIdx = workgroup_id.z / ${params.numHeads};
+   let currentBatchHeadNumber = workgroup_id.z % ${params.numHeads};
    if (m < M && n < N) {
-     output[global_idx] = value;
+     let outputIdx = batchIdx * M * N * ${params.numHeads} + m * N * ${params.numHeads}
+       + currentBatchHeadNumber * N + n;
+     output[outputIdx] = value;
    }
   }`;
 
      return context.compute(
          {
            name: 'AttentionScore',
            cacheHint: JSON.stringify(params),
            inputTypes: [GpuDataType.default, GpuDataType.default],
            outputs: [{dims: outputShape, dataType: probs.dataType, gpuDataType: GpuDataType.default}],
            getShaderSource,
            dispatchGroup: () => (dispatch)
          },
-         {inputs: [probs, v], outputs: [-1]})[0];
+         {inputs: [probs, v], outputs: [0]})[0];
     };
 
 export const applyAttention =
@@ -476,43 +514,10 @@ export const applyAttention =
      relativePositionBias: TensorView|undefined, parameters: AttentionParameters, attributes: AttentionAttrs) => {
       const probs = computeAttentionProbs(context, q, k, relativePositionBias, parameters, attributes);
-      // computeVxAttentionScore(context, probs, v, parameters);
-      const attentionResult = computeVxAttentionScore(context, probs, v, parameters);
-
-      const outputShape = [parameters.batchSize, parameters.sequenceLength, parameters.vHiddenSize];
-      const input = inputVariable('input', q.dataType, attentionResult.dims);
-      const output = outputVariable('output', q.dataType, outputShape);
-      const outputSize = parameters.batchSize * parameters.sequenceLength * parameters.vHeadSize * parameters.numHeads;
-      const getShaderSource = (shaderHelper: ShaderHelper) => `
-  ${shaderHelper.declareVariables(input, output)}
-
-  ${shaderHelper.mainStart()}
-    ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(outputSize)}
-    let h = global_idx % ${parameters.vHeadSize};
-    let s = (global_idx / ${parameters.vHeadSize}) % ${parameters.sequenceLength};
-    let n = (global_idx / (${parameters.vHeadSize} * ${parameters.sequenceLength})) % ${parameters.numHeads};
-    let b = global_idx / (${parameters.vHeadSize} * ${parameters.sequenceLength} * ${parameters.numHeads});
-
-    let inputOffset = b * ${parameters.numHeads} * ${parameters.sequenceLength} * ${parameters.vHeadSize}
-      + n * ${parameters.sequenceLength} * ${parameters.vHeadSize} + s * ${parameters.vHeadSize} + h;
-    let outputOffset = b * ${parameters.sequenceLength} * ${parameters.vHiddenSize}
-      + s * ${parameters.vHiddenSize} + h + n * ${parameters.vHeadSize};
-    output[outputOffset] = input[inputOffset];
-  }`;
-
-      context.compute(
-          {
-            name: 'AttentionTranspose',
-            cacheHint: JSON.stringify(parameters),
-            inputTypes: [GpuDataType.default],
-            outputs: [{dims: outputShape, dataType: q.dataType, gpuDataType: GpuDataType.default}],
-            getShaderSource,
-            dispatchGroup: () => ({x: Math.ceil(outputSize / 64)}),
-          },
-          {inputs: [attentionResult], outputs: [0]});
+      computeVxAttentionScore(context, probs, v, parameters);
     };
 
-const prepare = (context: ComputeContext, parameters: AttentionParameters, attributes: AttentionAttrs) => {
+const prepare = (context: ComputeContext, parameters: AttentionParameters) => {
   const outputShape = [
     parameters.batchSize,
     parameters.numHeads,
     parameters.sequenceLength,
@@ -533,7 +538,7 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters, attri
     z: parameters.batchSize * parameters.numHeads
   };
 
-  const getShaderSource = (shaderHelper: ShaderHelper) => `
+  const getShaderSource = () => `
   const M: u32 = ${M}u;
   const K: u32 = ${K}u;
   const N: u32 = ${N}u;
@@ -630,7 +635,7 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters, attri
 
 export const attention = (context: ComputeContext, attributes: AttentionAttrs): void => {
   const params = validateAttentionInputs(context.inputs, attributes);
-  const [q, k, v] = prepare(context, params, attributes);
+  const [q, k, v] = prepare(context, params);
 
   return applyAttention(
       context, q, k, v, context.inputs[4], undefined, undefined, undefined, context.inputs[5], params, attributes);
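
Note for reviewers: the substance of this patch is that computeVxAttentionScore now computes the final (B, S, N * H_v) offset and stores each probs x V element there directly, so the separate AttentionTranspose dispatch and its intermediate (B, N, S, H_v) buffer can be deleted. The plain-TypeScript sketch below replays both strategies on the CPU and checks that they produce the same buffer; it uses no JSEP APIs, and the helper names (bnshIndex, bsndIndex) are illustrative only, not identifiers from this patch.

// Layout helpers: old intermediate (B, N, S, H_v) vs final (B, S, N * H_v).
const bnshIndex = (b: number, n: number, s: number, h: number, N: number, S: number, Hv: number): number =>
    ((b * N + n) * S + s) * Hv + h;
const bsndIndex = (b: number, n: number, s: number, h: number, N: number, S: number, Hv: number): number =>
    (b * S + s) * (N * Hv) + n * Hv + h;

const [B, N, S, Hv] = [2, 3, 4, 5];  // toy shape: batch, heads, sequence, v head size
// Stand-in for one element of the probs x V product at (b, n, s, h).
const value = (b: number, n: number, s: number, h: number): number => bnshIndex(b, n, s, h, N, S, Hv) + 0.5;

const intermediate = new Float32Array(B * N * S * Hv);  // old pass 1 output, BNSH_v
const fused = new Float32Array(B * N * S * Hv);         // new single-pass output, BSND_v
for (let b = 0; b < B; b++) {
  for (let n = 0; n < N; n++) {
    for (let s = 0; s < S; s++) {
      for (let h = 0; h < Hv; h++) {
        intermediate[bnshIndex(b, n, s, h, N, S, Hv)] = value(b, n, s, h);
        fused[bsndIndex(b, n, s, h, N, S, Hv)] = value(b, n, s, h);
      }
    }
  }
}

// Old pass 2: the deleted AttentionTranspose shader as a CPU loop, using the same
// h/s/n/b decomposition of the flat index as the removed WGSL code.
const transposed = new Float32Array(B * N * S * Hv);
for (let i = 0; i < transposed.length; i++) {
  const h = i % Hv;
  const s = Math.floor(i / Hv) % S;
  const n = Math.floor(i / (Hv * S)) % N;
  const b = Math.floor(i / (Hv * S * N));
  transposed[bsndIndex(b, n, s, h, N, S, Hv)] = intermediate[i];
}

console.log(transposed.every((x, i) => x === fused[i]));  // expect: true

If the check prints true, the fused write is exactly the removed transpose composed with the old BNSH_v store, which means one fewer pass over B * N * S * H_v elements (and one fewer dispatch) per attention call.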