diff --git a/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts index 0949d65174b41..1a31253905694 100644 --- a/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts +++ b/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts @@ -338,6 +338,9 @@ export const maybeTransposeToBNSHAndAddBias = ( if (input.dims.length === 3) { reshapedInput = input.reshape([batchSize, sequenceLength, numHeads, headSize]); } + if (numHeads === 1 || sequenceLength === 1) { + return reshapedInput; + } return context.compute(createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), { inputs: [reshapedInput], outputs: [-1], @@ -356,6 +359,9 @@ export const maybeTransposeToBNSHAndAddBias = ( biasOffset!, ); reshapedInput = reshapedInput.reshape([batchSize, sequenceLength, numHeads, headSize]); + if (numHeads === 1 || sequenceLength === 1) { + return reshapedInput; + } return context.compute(createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), { inputs: [reshapedInput], outputs: [-1],