diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts b/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts index 3779ccf2ef6f2..ac9e61346049c 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts @@ -309,8 +309,8 @@ const computeAttentionProbs = const beta: ${dataType} = 1.0; const TILE_SIZE = ${TILE_SIZE}u; - var tileQ: mat4x4<${dataType}>; - var tileK: mat4x4<${dataType}>; + var tileQ: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; + var tileK: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>; ${shaderHelper.declareVariables(qInput, kInput, output)} @@ -331,15 +331,15 @@ const computeAttentionProbs = var value: ${dataType} = 0; for (var w: u32 = 0u; w < K; w += TILE_SIZE) { if (m < M && w + local_id.x < K) { - tileQ[local_id.x][local_id.y] = q[qOffset + w + local_id.x]; + tileQ[TILE_SIZE * local_id.x + local_id.y] = q[qOffset + w + local_id.x]; } if (n < N && w + local_id.x < K) { - tileK[local_id.x][local_id.y] = key[kOffset + w + local_id.x]; + tileK[TILE_SIZE * local_id.x + local_id.y] = key[kOffset + w + local_id.x]; } workgroupBarrier(); for (var k: u32 = 0u; k