diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts b/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts index 6b4bd24e986dd..47d838270211f 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts @@ -322,36 +322,35 @@ const computeAttentionProbs = // x holds the N and y holds the M let headIdx = workgroup_id.z; - let m = workgroup_id.y * TILE_SIZE + local_id.y; - let n = workgroup_id.x * TILE_SIZE + local_id.x; + let m = workgroup_id.y * TILE_SIZE; + let n = workgroup_id.x * TILE_SIZE; + let lm = m + local_id.y; + let ln = n + local_id.x; let qOffset = ${parameters.sequenceLength * vectorizedHeadSize} * headIdx + m * K; let kOffset = ${parameters.kvSequenceLength * vectorizedHeadSize} * headIdx + n * K; var value = ${fillVector(dataType, components)}; for (var w: u32 = 0u; w < K; w += TILE_SIZE) { - if (m < M && w + local_id.x < K) { - tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + w + local_id.x]; + if (m + local_id.y < M && w + local_id.x < K) { + tileQ[TILE_SIZE * local_id.y + local_id.x] = q[qOffset + local_id.y * K + w + local_id.x]; } - if (n < N && w + local_id.x < K) { - tileK[TILE_SIZE * local_id.y + local_id.x] = key[kOffset + w + local_id.x]; + if (n + local_id.y < N && w + local_id.x < K) { + tileK[TILE_SIZE * local_id.y + local_id.x] = key[kOffset + local_id.y * K + w + local_id.x]; } workgroupBarrier(); - for (var k: u32 = 0u; k