diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts b/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts index b902e7d13f9ec..3779ccf2ef6f2 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts @@ -279,11 +279,10 @@ const computeAttentionProbs = // TODO: handle mask const alpha = attributes.scale === 0 ? 1.0 / Math.sqrt(parameters.headSize) : attributes.scale; - const gemmSize = parameters.sequenceLength * parameters.totalSequenceLength; const dataType = tensorTypeToWsglStorageType(q.dataType); - const components = getMaxComponents(parameters.headSize); + const components = 1; //getMaxComponents(parameters.headSize); const qInput = inputVariable('q', q.dataType, q.dims, components); const kInput = inputVariable('key', key.dataType, key.dims, components); const output = outputVariable('output', q.dataType, probsShape); @@ -293,37 +292,67 @@ const computeAttentionProbs = const N = parameters.totalSequenceLength; const K = vectorizedHeadSize; - const unitsOfWork = ShapeUtil.size(probsShape); + const TILE_SIZE = Math.min(8, vectorizedHeadSize); + + const dispatch = { + x: Math.ceil(parameters.totalSequenceLength / TILE_SIZE), + y: Math.ceil(parameters.sequenceLength / TILE_SIZE), + z: parameters.batchSize * parameters.numHeads + }; const inputs = [q, key]; const getShaderSource = (shaderHelper: ShaderHelper) => ` const M: u32 = ${M}u; const N: u32 = ${N}u; const K: u32 = ${K}u; - const gemmSize: u32 = ${gemmSize}; const alpha = ${dataType}(${alpha}); const beta: ${dataType} = 1.0; + const TILE_SIZE = ${TILE_SIZE}u; + + var tileQ: mat4x4<${dataType}>; + var tileK: mat4x4<${dataType}>; ${shaderHelper.declareVariables(qInput, kInput, output)} - ${shaderHelper.mainStart()} - ${shaderHelper.guardAgainstOutOfBoundsWorkgroupSizes(unitsOfWork)} - let idxWoGemmSize = global_idx / gemmSize; - let gemmOffset = global_idx % gemmSize; - let m = gemmOffset / N; - let n = gemmOffset % N; + @compute @workgroup_size(${TILE_SIZE}, ${TILE_SIZE}, 1) + fn main(@builtin(global_invocation_id) global_id : vec3, @builtin(workgroup_id) workgroup_id : vec3, + @builtin(local_invocation_id) local_id : vec3, @builtin(local_invocation_index) local_index : u32) { + let global_idx = (workgroup_id.z * ${dispatch.x * dispatch.y}u + + workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index; - let inputOffset = ${parameters.sequenceLength * vectorizedHeadSize} * idxWoGemmSize + m * K; - let kOffset = ${parameters.kvSequenceLength * vectorizedHeadSize} * idxWoGemmSize + n * K; + // x holds the N and y holds the M + let headIdx = workgroup_id.z; + let m = workgroup_id.y * TILE_SIZE + local_id.y; + let n = workgroup_id.x * TILE_SIZE + local_id.x; - var value: ${qInput.type.storage} = ${fillVector(dataType, components)}; - for (var k: u32 = 0u; k<${K}u; k++) { - // no trans a + trans b - value += q[k + inputOffset] * key[k + kOffset]; + let qOffset = ${parameters.sequenceLength * vectorizedHeadSize} * headIdx + m * K; + let kOffset = ${parameters.kvSequenceLength * vectorizedHeadSize} * headIdx + n * K; + + var value: ${dataType} = 0; + for (var w: u32 = 0u; w < K; w += TILE_SIZE) { + if (m < M && w + local_id.x < K) { + tileQ[local_id.x][local_id.y] = q[qOffset + w + local_id.x]; + } + if (n < N && w + local_id.x < K) { + tileK[local_id.x][local_id.y] = key[kOffset + w + local_id.x]; + } + workgroupBarrier(); + + for (var k: u32 = 0u; k GpuDataType.default); @@ -335,7 +364,7 @@ const computeAttentionProbs = inputTypes, outputs: [{dims: probsShape, dataType: q.dataType, gpuDataType: GpuDataType.default}], getShaderSource, - dispatchGroup: () => ({x: Math.ceil(unitsOfWork / 64 /* workgroup size */)}) + dispatchGroup: () => (dispatch) }, {inputs, outputs: [-1]})[0];