diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts b/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts index 744f6d3a04bc4..004176b2822c6 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts @@ -373,15 +373,16 @@ const computeAttentionProbs = }; const computeVxAttentionScore = (context: ComputeContext, probs: TensorView, v: TensorView, params: AttentionParameters) => { - const outputShape = [params.batchSize, params.numHeads, params.sequenceLength, params.vHeadSize]; + const outputShape = [params.batchSize, params.sequenceLength, params.vHiddenSize]; - const probsHelper = inputVariable('probs', probs.dataType, probs.dims); - const vHelper = inputVariable('v', v.dataType, v.dims); + const components = getMaxComponents(params.totalSequenceLength); + const probsHelper = inputVariable('probs', probs.dataType, probs.dims, components); + const vHelper = inputVariable('v', v.dataType, v.dims, components); const output = outputVariable('output', probs.dataType, outputShape); const dataType = tensorTypeToWsglStorageType(probs.dataType); - const TILE_SIZE = 1; + const TILE_SIZE = 8; const dispatch = { x: Math.ceil(params.vHeadSize / TILE_SIZE), y: Math.ceil(params.totalSequenceLength / TILE_SIZE), @@ -391,10 +392,9 @@ const computeVxAttentionScore = (context: ComputeContext, probs: TensorView, v: const getShaderSource = (shaderHelper: ShaderHelper) => ` const M: u32 = ${params.sequenceLength}u; const N: u32 = ${params.vHeadSize}u; - const K: u32 = ${params.totalSequenceLength}u; - const numHeads: u32 = ${params.numHeads}u; + const K: u32 = ${params.totalSequenceLength / components}u; const TILE_SIZE = ${TILE_SIZE}u; - + var tileQ: array<${probsHelper.type.storage}, ${TILE_SIZE * TILE_SIZE}>; var tileK: array<${probsHelper.type.storage}, ${TILE_SIZE * TILE_SIZE}>; @@ -407,31 +407,33 @@ const computeVxAttentionScore = (context: ComputeContext, probs: TensorView, v: workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index; let headIdx = workgroup_id.z; - let m = workgroup_id.y * TILE_SIZE + local_id.y; - let n = workgroup_id.x * TILE_SIZE + local_id.x; + let m = workgroup_id.y * TILE_SIZE; + let n = workgroup_id.x * TILE_SIZE; + let lm = m + local_id.y; + let ln = n + local_id.x; let offsetA = headIdx * (M * K) + m * K; - let offsetB = headIdx * (N * K) + n; + let offsetB = headIdx * (N * K) + n * K; - var value = ${dataType}(0); + var value = ${fillVector(dataType, components)}; for (var w: u32 = 0u; w < K; w += TILE_SIZE) { - if (m < M && w + local_id.x < K) { - tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + w + local_id.x]; + if (m + local_id.y < M && w + local_id.x < K) { + tileQ[TILE_SIZE * local_id.y + local_id.x] = probs[offsetA + local_id.y * K + w + local_id.x]; } - if (n < N && w + local_id.y < K) { - tileK[TILE_SIZE * local_id.y + local_id.x] = v[offsetB + (w + local_id.y) * N]; + if (n + local_id.y < N && w + local_id.x < K) { + tileK[TILE_SIZE * local_id.y + local_id.x] = v[offsetB + local_id.y * K + w + local_id.x]; } workgroupBarrier(); for (var k: u32 = 0u; k (dispatch) }, - {inputs: [probs, v], outputs: [-1]})[0]; + {inputs: [probs, v], outputs: [0]})[0]; }; export const applyAttention = @@ -453,39 +455,43 @@ export const applyAttention = relativePositionBias: TensorView|undefined, parameters: AttentionParameters, attributes: AttentionAttrs) => { const probs = computeAttentionProbs(context, q, k, relativePositionBias, parameters, attributes); - const attentionResult = computeVxAttentionScore(context, probs, v, parameters); - - const outputShape = [parameters.batchSize, parameters.sequenceLength, parameters.vHiddenSize]; - const input = inputVariable('input', q.dataType, attentionResult.dims); - const output = outputVariable('output', q.dataType, outputShape); - const getShaderSource = (shaderHelper: ShaderHelper) => ` - ${shaderHelper.declareVariables(input, output)} - - ${shaderHelper.mainStart(parameters.numHeads * parameters.batchSize)} - let batchIndex = global_idx / ${parameters.numHeads}; - let headIndex = global_idx % ${parameters.numHeads}; - // let in = input[0]; - - var inputOffset = ${parameters.sequenceLength * parameters.vHeadSize} * global_idx; - var outputOffset = (batchIndex * ${parameters.sequenceLength * parameters.numHeads} + headIndex) * ${parameters.vHeadSize}; - for (var j = 0; j < ${parameters.sequenceLength}; j++) { - for (var i: u32 = 0; i < ${parameters.vHeadSize}; i++) { - output[outputOffset + i] = input[inputOffset + i]; - } - inputOffset += ${parameters.vHeadSize}; - outputOffset += ${parameters.vHiddenSize}; - } - }`; - - context.compute( - { - ...transposeProgramMetadata, - cacheHint: JSON.stringify(parameters), - outputs: [{dims: outputShape, dataType: DataType.float, gpuDataType: GpuDataType.default}], - getShaderSource, - dispatchGroup: () => ({ x: 1 }), - }, - {inputs: [attentionResult], outputs: [0]}); + computeVxAttentionScore(context, probs, v, parameters); + // const attentionResult = computeVxAttentionScore(context, probs, v, parameters); + // + // const outputShape = [parameters.batchSize, parameters.sequenceLength, parameters.vHiddenSize]; + // const input = inputVariable('input', q.dataType, attentionResult.dims); + // const output = outputVariable('output', q.dataType, outputShape); + // const getShaderSource = (shaderHelper: ShaderHelper) => ` + // ${shaderHelper.declareVariables(input, output)} + // + // ${shaderHelper.mainStart(parameters.numHeads * parameters.batchSize)} + // let headOffset = global_idx % ${parameters.vHeadSize}; + // let sequenceIndex = (global_idx / ${parameters.vHeadSize}) % ${parameters.sequenceLength}; + // let batchIndex = global_idx / ${parameters.numHeads}; + // let headIndex = global_idx % ${parameters.numHeads}; + // // let in = input[0]; + // + // var inputOffset = ${parameters.sequenceLength * parameters.vHeadSize} * global_idx; + // var outputOffset = (batchIndex * ${parameters.sequenceLength * parameters.numHeads} + headIndex) + // * ${parameters.vHeadSize}; + // for (var j = 0; j < ${parameters.sequenceLength}; j++) { + // for (var i: u32 = 0; i < ${parameters.vHeadSize}; i++) { + // output[outputOffset + i] = input[inputOffset + i]; + // } + // inputOffset += ${parameters.vHeadSize}; + // outputOffset += ${parameters.vHiddenSize}; + // } + // }`; + // + // context.compute( + // { + // ...transposeProgramMetadata, + // cacheHint: JSON.stringify(parameters), + // outputs: [{dims: outputShape, dataType: DataType.float, gpuDataType: GpuDataType.default}], + // getShaderSource, + // dispatchGroup: () => ({ x: 1 }), + // }, + // {inputs: [attentionResult], outputs: [0]}); }; const prepare = (context: ComputeContext, parameters: AttentionParameters, attributes: AttentionAttrs) => { @@ -583,7 +589,9 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters, attri let outputIdx = offset + m * N + n; outputQ[outputIdx] = valueQ; outputK[outputIdx] = valueK; - outputV[outputIdx] = valueV; + // transpose V to use vec4 optimizations in compute score + let outputIdxV = offset + n * M + m; + outputV[outputIdxV] = valueV; } }`;