[js/webgpu] Optimize maybeTransposeToBNSHAndAddBias
With this optimization, 96 MultiHeadAttention|Transpose ops in phi3
disappear. Phi3 goes from 107 to 113 tokens/s on my dGPUs.
qjia7 committed Oct 12, 2024
1 parent 3321735 commit 335f67c
Showing 1 changed file with 6 additions and 0 deletions.
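A side note on the "96 ops" figure: it is consistent with phi3-mini's architecture under the assumption (not stated in the commit) that each of the model's 32 decoder layers dispatches one transpose apiece for Q, K, and V. A hedged back-of-envelope check in TypeScript:

// Assumptions: phi3-mini decoder layer count, and one Transpose each for Q, K, V.
const decoderLayers = 32;
const transposesPerLayer = 3; // Q, K, V
console.log(decoderLayers * transposesPerLayer); // 96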
js/web/lib/wasm/jsep/webgpu/ops/multihead-attention.ts
@@ -338,6 +338,9 @@ export const maybeTransposeToBNSHAndAddBias = (
     if (input.dims.length === 3) {
       reshapedInput = input.reshape([batchSize, sequenceLength, numHeads, headSize]);
     }
+    if (numHeads === 1 || sequenceLength === 1) {
+      return reshapedInput;
+    }
     return context.compute(createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), {
       inputs: [reshapedInput],
       outputs: [-1],
@@ -356,6 +359,9 @@
       biasOffset!,
     );
     reshapedInput = reshapedInput.reshape([batchSize, sequenceLength, numHeads, headSize]);
+    if (numHeads === 1 || sequenceLength === 1) {
+      return reshapedInput;
+    }
     return context.compute(createTransposeProgramInfo(reshapedInput, weightTransposeAttribute.perm), {
       inputs: [reshapedInput],
       outputs: [-1],
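Why the guard is safe: the transpose being skipped only swaps the sequenceLength and numHeads axes (BSNH to BNSH, i.e. perm [0, 2, 1, 3]). When either of those axes has length 1, the row-major flat layout of the tensor is identical before and after the swap, so the GPU dispatch buys nothing and the reshaped view can be returned as-is. This is exactly the decode path of a generative model such as phi3, where each step processes a single new token (sequenceLength === 1). A minimal standalone TypeScript sketch of that argument (not onnxruntime code; the perm and the concrete sizes are illustrative assumptions):

// Row-major flat offset of coordinate `idx` in a tensor of shape `dims`.
const flatIndex = (idx: number[], dims: number[]): number =>
  idx.reduce((acc, i, axis) => acc * dims[axis] + i, 0);

// Decode-style shapes: one new token per step, so sequenceLength === 1.
// (Concrete sizes are made up for illustration.)
const [batch, seq, heads, headSize] = [1, 1, 32, 96];
const bsnh = [batch, seq, heads, headSize];
const bnsh = [batch, heads, seq, headSize]; // shape after perm [0, 2, 1, 3]

// Every element keeps the same flat offset when a size-1 axis is swapped,
// so the "transpose" changes only metadata, never data.
for (let b = 0; b < batch; ++b) {
  for (let n = 0; n < heads; ++n) {
    for (let h = 0; h < headSize; ++h) {
      const src = flatIndex([b, 0, n, h], bsnh); // BSNH coordinate
      const dst = flatIndex([b, n, 0, h], bnsh); // same element, BNSH coordinate
      if (src !== dst) throw new Error('layouts differ');
    }
  }
}
console.log('BSNH and BNSH share one flat layout when sequenceLength === 1');

The same reasoning covers numHeads === 1. One subtlety: the early return yields a view whose dims are still BSNH rather than BNSH; presumably this is harmless because the downstream attention programs address the data through the explicit batchSize / numHeads / sequenceLength / headSize parameters rather than through the tensor's dims.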
