diff --git a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts
index 8e5c0ab3d149c..e1f2a47301bfb 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/attention.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/attention.ts
@@ -317,7 +317,7 @@ export const computeInPlaceSoftmax = (context: ComputeContext, input: TensorView
 };
 
 const computeAttentionProbs =
-    (context: ComputeContext, q: TensorView, key: TensorView, bias: TensorView|undefined,
+    (context: ComputeContext, q: TensorView, key: TensorView, _bias: TensorView|undefined,
      parameters: AttentionParameters, attributes: AttentionAttrs) => {
       const probsShape = [
         parameters.batchSize, parameters.numHeads, parameters.sequenceLength,
@@ -443,7 +443,7 @@ const computeVxAttentionScore =
   const K: u32 = ${params.totalSequenceLength}u;
   const numHeads: u32 = ${params.numHeads}u;
   const TILE_SIZE = ${TILE_SIZE}u;
-
+
   var<workgroup> tileQ: array<${probsHelper.type.storage}, ${TILE_SIZE * TILE_SIZE}>;
   var<workgroup> tileK: array<${probsHelper.type.storage}, ${TILE_SIZE * TILE_SIZE}>;
 
@@ -482,7 +482,7 @@ const computeVxAttentionScore =
     let currentBatchHeadNumber = workgroup_id.z % ${params.numHeads};
     let headOffset = (batchIdx * M * ${params.numHeads} + currentBatchHeadNumber) * ${params.vHeadSize};
     if (m < M && n < N) {
-      let outputIdx = batchIdx * ${params.sequenceLength * params.vHiddenSize} + m * ${params.vHiddenSize}
+      let outputIdx = batchIdx * ${params.sequenceLength * params.vHiddenSize} + m * ${params.vHiddenSize}
         + currentBatchHeadNumber * ${params.vHeadSize} + n;
       output[outputIdx] = value;
     }
@@ -502,8 +502,8 @@ const computeVxAttentionScore =
 };
 
 export const applyAttention =
-    (context: ComputeContext, q: TensorView, k: TensorView, v: TensorView, maskIndex: TensorView|undefined,
-     past: TensorView|undefined, pastKey: TensorView|undefined, pastValue: TensorView|undefined,
+    (context: ComputeContext, q: TensorView, k: TensorView, v: TensorView, _maskIndex: TensorView|undefined,
+     _past: TensorView|undefined, _pastKey: TensorView|undefined, _pastValue: TensorView|undefined,
      relativePositionBias: TensorView|undefined, parameters: AttentionParameters, attributes: AttentionAttrs) => {
       const probs = computeAttentionProbs(context, q, k, relativePositionBias, parameters, attributes);
 
@@ -538,7 +538,7 @@ const prepare = (context: ComputeContext, parameters: AttentionParameters) => {
   const numHeads: u32 = ${parameters.numHeads};
   const ldb = ${parameters.hiddenSize + parameters.hiddenSize + parameters.vHiddenSize}u;
   const TILE_SIZE = ${TILE_SIZE}u;
-
+
   var<workgroup> tileInput: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>;
   var<workgroup> tileWeightQ: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>;
   var<workgroup> tileWeightK: array<${dataType}, ${TILE_SIZE * TILE_SIZE}>;
diff --git a/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts b/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts
index 0a55ade2469eb..b7726a36bcaad 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/multi-head-attentiion.ts
@@ -280,7 +280,7 @@ const maybeTransposeToBNSHAndAddBias =
         reshapedInput = input.reshape([batchSize, sequenceLength, numHeads, headSize]);
       }
       return context.compute(
-          createTransposeProgramInfo(input.dataType, reshapedInput.dims.length, weightTransposeAttribute.perm),
+          createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm),
          {inputs: [reshapedInput], outputs: [-1]})[0];
     } else {
       if (sequenceLength === 1) {
@@ -290,7 +290,7 @@ const maybeTransposeToBNSHAndAddBias =
               addBiasTranspose(context, input, bias, batchSize, sequenceLength, numHeads * headSize, biasOffset!);
           reshapedInput = reshapedInput.reshape([batchSize, sequenceLength, numHeads, headSize]);
           return context.compute(
-              createTransposeProgramInfo(input.dataType, reshapedInput.dims.length, weightTransposeAttribute.perm),
+              createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm),
              {inputs: [reshapedInput], outputs: [-1]})[0];
         }
       }
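
Note on the two createTransposeProgramInfo hunks: the helper now receives the reshaped TensorView itself and derives the data type and rank from it, rather than taking `input.dataType` and `reshapedInput.dims.length` as separate arguments (the old call sites also mixed `input` and `reshapedInput` across those two arguments, which the single-argument form rules out). A minimal sketch of the idea follows; `describeTranspose` and the reduced `TensorView` interface are hypothetical stand-ins for illustration, not the real jsep types.

// Stand-in for the real jsep TensorView, reduced to the two fields the sketch reads.
interface TensorView {
  readonly dataType: number;          // ONNX element-type enum value
  readonly dims: readonly number[];   // tensor shape
}

// Hypothetical stand-in for the program-info builder: because the tensor is
// passed whole, the data type and rank can no longer disagree with each other.
function describeTranspose(input: TensorView, perm: readonly number[]): string {
  const rank = input.dims.length;  // previously a separate caller-supplied argument
  return `Transpose(dataType=${input.dataType}, rank=${rank}, perm=[${perm.join(', ')}])`;
}

// Example: the BSNH -> BNSH permutation [0, 2, 1, 3] that
// maybeTransposeToBNSHAndAddBias applies via weightTransposeAttribute.perm.
const reshapedInput: TensorView = {dataType: 1 /* float32 */, dims: [2, 8, 4, 64]};
console.log(describeTranspose(reshapedInput, [0, 2, 1, 3]));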