[js/web] JSEP Attention & MultiHeadAttention (#17742)
### Description
This is a narrow implementation of Attention/MultiHeadAttention as it
does not support:
a. inputs 5-7 for MHA
b. packed QKV/KV
c. past/present
d. attention mask

It works well for Stable Diffusion, however, and can be extended later. It
reduces VRAM usage because it combines many ops into a few.
I've updated the demo at https://islamov.ai/stable-diffusion-webgpu/; it
takes ~13 s for one image with 20 steps on an RTX 3090 Ti and about 25 s on
an M1 Pro.
VRAM usage is about 8 GB if you don't use img2img.
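
For readers unfamiliar with what "combining many ops" means here, the following is a minimal CPU reference of single-head scaled dot-product attention, i.e. the MatMul, scale, Softmax, MatMul chain that a fused attention kernel evaluates in one place. It is only a sketch: the names `attentionReference` and `softmaxRow` are hypothetical, the scale is assumed to be the conventional 1/sqrt(head size), and nothing here reflects the actual WebGPU shaders in this commit.

```typescript
// Hypothetical reference implementation on plain arrays (not the JSEP kernel).
// q: [seqQ][d], k: [seqK][d], v: [seqK][dv]
type Matrix = number[][];

// Numerically stable softmax over one row of scores.
function softmaxRow(row: number[]): number[] {
  const max = Math.max(...row);
  const exps = row.map((x: number) => Math.exp(x - max));
  const sum = exps.reduce((a: number, b: number) => a + b, 0);
  return exps.map((e: number) => e / sum);
}

// Computes softmax(q * k^T / sqrt(d)) * v for a single head, no mask, no past state.
function attentionReference(q: Matrix, k: Matrix, v: Matrix): Matrix {
  const d = q[0].length;
  const dv = v[0].length;
  const scale = 1 / Math.sqrt(d); // assumed default scale
  return q.map((qRow: number[]) => {
    // Step 1: scaled dot product of this query row with every key row.
    const scores = k.map(
      (kRow: number[]) => scale * qRow.reduce((acc: number, x: number, i: number) => acc + x * kRow[i], 0));
    // Step 2: normalize the scores into attention probabilities.
    const probs = softmaxRow(scores);
    // Step 3: probability-weighted sum of the value rows.
    const out = new Array<number>(dv).fill(0);
    probs.forEach((p: number, j: number) => {
      for (let c = 0; c < dv; c++) {
        out[c] += p * v[j][c];
      }
    });
    return out;
  });
}
```

When every key row is identical, the probabilities are uniform and each output row is the mean of the value rows, which is a convenient sanity check for a fused implementation.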

Going to focus on SDXL now

---------

Co-authored-by: Guenther Schmuelling <[email protected]>
Co-authored-by: Yulong Wang <[email protected]>
3 people authored Nov 17, 2023
1 parent a5537f2 commit fac3e33
Showing 13 changed files with 1,866 additions and 0 deletions.
2 changes: 2 additions & 0 deletions js/web/docs/webgpu-operators.md
@@ -20,6 +20,7 @@ Do not modify directly.*
| Asinh | ai.onnx(9+) | |
| Atan | ai.onnx(7+) | |
| Atanh | ai.onnx(9+) | |
| Attention | com.microsoft(1+) | need implementing mask and past/present |
| AveragePool | ai.onnx(7-9,10,11+); com.ms.internal.nhwc(7-9,10,11+) | need perf optimization; need implementing activation |
| BiasAdd | com.microsoft(1+) | |
| BiasSplitGelu | com.microsoft(1+) | |
@@ -61,6 +62,7 @@ Do not modify directly.*
| MemcpyFromHost | ai.onnx(1+) | |
| MemcpyToHost | ai.onnx(1+) | |
| Mul | ai.onnx(7-12,13,14+) | |
| MultiHeadAttention | com.microsoft(1+) | need implementing mask and past/present |
| Neg | ai.onnx(6-12,13+) | |
| Not | ai.onnx(1+) | |
| Pad | ai.onnx(2-10,11-12,13-17,18,19+) | |
4 changes: 4 additions & 0 deletions js/web/lib/wasm/jsep/webgpu/op-resolve-rules.ts
@@ -2,6 +2,7 @@
// Licensed under the MIT License.

import {argMax, argMin, parseArgMinMaxAttributes} from './ops/argminmax';
import {attention, parseAttentionAttributes} from './ops/attention';
import {biasAdd} from './ops/bias-add';
import {biasSplitGelu} from './ops/bias-split-gelu';
import * as binaryOps from './ops/binary-op';
@@ -16,6 +17,7 @@ import {gemm, parseGemmAttributes} from './ops/gemm';
import {instanceNorm, parseInstanceNormAttributes} from './ops/instance-norm';
import {layerNorm, parseLayerNormAttributes} from './ops/layer-norm';
import {matMul} from './ops/matmul';
import {multiHeadAttention, parseMultiHeadAttentionAttributes} from './ops/multi-head-attentiion';
import {pad, parsePadAttributes} from './ops/pad';
import * as pool from './ops/pool';
import {range} from './ops/range';
@@ -46,6 +48,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
['Asinh', [unaryOps.asinh]],
['Atan', [unaryOps.atan]],
['Atanh', [unaryOps.atanh]],
['Attention', [attention, parseAttentionAttributes]],
// TODO: support new attributes for AveragePool-10
['AveragePool', [pool.averagePool, pool.parseAveragePoolAttributes]],
['BiasAdd', [biasAdd]],
@@ -86,6 +89,7 @@ export const WEBGPU_OP_RESOLVE_RULES: Map<string, OperatorImplementation> = new
// TODO: support new attributes for MaxPool-8 and MaxPool-10
['MaxPool', [pool.maxPool, pool.parseMaxPoolAttributes]],
['Mul', [binaryOps.mul]],
['MultiHeadAttention', [multiHeadAttention, parseMultiHeadAttentionAttributes]],
['Neg', [unaryOps.neg]],
['Not', [unaryOps.not]],
['Pad', [pad, parsePadAttributes]],
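
The entries added above pair a kernel function with an attribute parser, while entries such as `Neg` carry only the kernel. A toy sketch of how a table of this shape could be consulted is shown below. Everything in it is an assumption for illustration: `TOY_RULES`, `runOp`, the `Scale` operator, and the simplified `RunFn`/`ParseFn` types are hypothetical and do not mirror the real `OperatorImplementation` type or the JSEP dispatch code.

```typescript
// Hypothetical, simplified stand-ins for the real types in onnxruntime-web.
type Attributes = Record<string, unknown>;
type RunFn = (input: number[], attributes: Attributes) => number[];
type ParseFn = (raw: Attributes) => Attributes;
// An entry is a kernel, optionally followed by an attribute parser.
type OpEntry = [RunFn, ParseFn?];

const TOY_RULES: Map<string, OpEntry> = new Map<string, OpEntry>([
  // Kernel only: no attributes to parse.
  ['Neg', [(input: number[]) => input.map((x: number) => -x)]],
  // Kernel plus parser: the parser normalizes raw attributes once, up front.
  ['Scale', [
    (input: number[], attrs: Attributes) => input.map((x: number) => x * (attrs.factor as number)),
    (raw: Attributes) => ({factor: typeof raw.factor === 'number' ? raw.factor : 1}),
  ]],
]);

// Looks up an operator by name, parses attributes if a parser exists, then runs the kernel.
function runOp(name: string, input: number[], rawAttributes: Attributes = {}): number[] {
  const entry = TOY_RULES.get(name);
  if (entry === undefined) {
    throw new Error(`no kernel registered for operator: ${name}`);
  }
  const [run, parse] = entry;
  const attributes = parse ? parse(rawAttributes) : rawAttributes;
  return run(input, attributes);
}
```

Keeping the parser separate from the kernel means attribute validation and defaulting live in one place per operator, and the kernel can assume well-formed attributes.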