From 335f67cc4886a2d7747a50ed5e010ada0045f139 Mon Sep 17 00:00:00 2001
From: Qin Jiajia
Date: Sat, 12 Oct 2024 16:12:59 +0800
Subject: [PATCH] [js/webgpu] Optimize maybeTransposeToBNSHAndAddBias

With this optimization, 96 MultiHeadAttention|Transpose ops in Phi3
disappear. Phi3 improves from 107 to 113 tokens per second on my dGPUs.
---
 js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts b/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts
index 0949d65174b41..1a31253905694 100644
--- a/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts
+++ b/js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts
@@ -338,6 +338,9 @@ export const maybeTransposeToBNSHAndAddBias = (
     if (input.dims.length === 3) {
       reshapedInput = input.reshape([batchSize, sequenceLength, numHeads, headSize]);
     }
+    if (numHeads === 1 || sequenceLength === 1) {
+      return reshapedInput;
+    }
     return context.compute(createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), {
       inputs: [reshapedInput],
       outputs: [-1],
@@ -356,6 +359,9 @@ export const maybeTransposeToBNSHAndAddBias = (
         biasOffset!,
       );
       reshapedInput = reshapedInput.reshape([batchSize, sequenceLength, numHeads, headSize]);
+      if (numHeads === 1 || sequenceLength === 1) {
+        return reshapedInput;
+      }
       return context.compute(createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), {
         inputs: [reshapedInput],
         outputs: [-1],
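
Note (illustration only, not part of the patch): the TypeScript sketch below is a
minimal, self-contained check of the idea behind the change. When numHeads === 1 or
sequenceLength === 1, the BSNH-to-BNSH transpose with perm [0, 2, 1, 3] only swaps in
an axis of size 1, so for contiguous row-major data the transpose moves nothing and the
Transpose program can be skipped. The strides and transpose helpers here are
hypothetical stand-ins written for this sketch, not ONNX Runtime code.

// Illustrative sketch: a [0, 2, 1, 3] transpose leaves the flattened data
// unchanged when either of the swapped axes has size 1.

// Row-major strides for a shape.
const strides = (shape: number[]): number[] => {
  const s = new Array<number>(shape.length).fill(1);
  for (let i = shape.length - 2; i >= 0; i--) {
    s[i] = s[i + 1] * shape[i + 1];
  }
  return s;
};

// Naive transpose of row-major `data` of shape `shape` according to `perm`.
const transpose = (data: number[], shape: number[], perm: number[]): number[] => {
  const outShape = perm.map((p) => shape[p]);
  const inStrides = strides(shape);
  const outStrides = strides(outShape);
  const out = new Array<number>(data.length);
  for (let i = 0; i < data.length; i++) {
    // Decompose the output index into coordinates, then map them back to the input.
    let rem = i;
    let inIndex = 0;
    for (let d = 0; d < outShape.length; d++) {
      const coord = Math.floor(rem / outStrides[d]);
      rem -= coord * outStrides[d];
      inIndex += coord * inStrides[perm[d]];
    }
    out[i] = data[inIndex];
  }
  return out;
};

const perm = [0, 2, 1, 3]; // BSNH -> BNSH
for (const shape of [
  [2, 5, 1, 4], // numHeads === 1
  [2, 1, 3, 4], // sequenceLength === 1
]) {
  const size = shape.reduce((a, b) => a * b, 1);
  const data = Array.from({ length: size }, (_, i) => i);
  const transposed = transpose(data, shape, perm);
  console.log(shape.join('x'), JSON.stringify(transposed) === JSON.stringify(data)); // prints true
}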