diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts b/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts index d53d65448e768..8db353470e4c7 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/attentiion.ts @@ -403,7 +403,7 @@ const computeAttentionProbs = const computeVxAttentionScore = (context: ComputeContext, probs: TensorView, v: TensorView, params: AttentionParameters) => { const outputShape = [params.batchSize, params.numHeads, params.sequenceLength, params.vHeadSize]; - const components = getMaxComponents(params.totalSequenceLength); + const components = 1; //getMaxComponents(params.totalSequenceLength); const probsHelper = inputVariable('probs', probs.dataType, probs.dims, components); const vHelper = inputVariable('v', v.dataType, v.dims, components); const output = outputVariable('output', probs.dataType, outputShape); @@ -545,12 +545,12 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters, attri let global_idx = (workgroup_id.z * ${dispatch.x * dispatch.y}u + workgroup_id.y * ${dispatch.x}u + workgroup_id.x) * ${TILE_SIZE * TILE_SIZE}u + local_index; - let batchIndex = workgroup_id.z / ${parameters.batchSize}; - let headNumber = workgroup_id.z % ${parameters.batchSize}; + let batchIndex = workgroup_id.z / ${parameters.numHeads}; + let headNumber = workgroup_id.z % ${parameters.numHeads}; let m = workgroup_id.y * TILE_SIZE + local_id.y; let n = workgroup_id.x * TILE_SIZE + local_id.x; - let inputOffset = batchIndex * ${parameters.sequenceLength * parameters.inputHiddenSize} + m * K; + let inputOffset = batchIndex * (M * K) + m * K; let biasOffsetQ = headNumber * ${parameters.headSize}; let biasOffsetK = ${parameters.hiddenSize} + biasOffsetQ; let biasOffsetV = ${parameters.hiddenSize} + biasOffsetK; @@ -563,13 +563,13 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters, attri tileInput[TILE_SIZE * local_id.y + local_id.x] = input[inputOffset + w + local_id.x]; } if (n < N && w + local_id.y < K) { - tileWeightQ[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetQ + (w + local_id.y) * ldb]; + tileWeightQ[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetQ + n + (w + local_id.y) * ldb]; } if (n < N && w + local_id.y < K) { - tileWeightK[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetK + (w + local_id.y) * ldb]; + tileWeightK[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetK + n + (w + local_id.y) * ldb]; } if (n < N && w + local_id.y < K) { - tileWeightV[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetK + (w + local_id.y) * ldb]; + tileWeightV[TILE_SIZE * local_id.y + local_id.x] = weight[biasOffsetV + n + (w + local_id.y) * ldb]; } workgroupBarrier(); for (var k: u32 = 0u; k