
onnx export with dynamic shapes, fast attention #324

Merged
merged 16 commits into main from fix_onnx_export
May 27, 2024

Conversation

@jpata (Owner) commented May 25, 2024


Here's what the direct export of torch.nn.functional.scaled_dot_product_attention to an unfused ONNX model, with full matrix multiplications, looks like:
[image: ONNX graph of the unfused export, with explicit MatMul -> Softmax -> MatMul nodes]
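For reference, a pure-NumPy sketch (an illustration, not the PR's code) of what the unfused graph computes: softmax(Q K^T / sqrt(d)) V, with the full N x N attention matrix materialized in memory.

```python
import numpy as np

def unfused_attention(q, k, v):
    # MatMul: scaled (N, N) score matrix is materialized in full
    d = q.shape[-1]
    scores = q @ k.swapaxes(-1, -2) / np.sqrt(d)
    # Softmax over the key axis (shifted for numerical stability)
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    # MatMul with the values
    return weights @ v

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((8, 16)) for _ in range(3))
out = unfused_attention(q, k, v)
```

The O(N^2) score matrix in the middle is exactly what drives the memory growth of the unfused model in the measurements below.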

Using the fused SDPA operation, which dispatches to flash attention on sufficiently new GPUs, the MatMul -> Softmax -> MatMul sequence at the very end is rolled into a single SDPA op that calls MultiHeadAttention:
[images: ONNX graph with the fused SDPA op, and the SDPA node calling MultiHeadAttention]
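A simplified sketch of why the fused kernel uses far less memory (an assumption-laden illustration, not the PR's implementation): it never materializes the full N x N attention matrix. Here queries are processed in blocks, so the peak intermediate is (block, N); real flash attention additionally tiles over keys with an online softmax, keeping the working set in fast on-chip memory.

```python
import numpy as np

def full_attention(q, k, v):
    # Reference: materializes the full (N, N) score matrix
    s = q @ k.T / np.sqrt(q.shape[-1])
    w = np.exp(s - s.max(axis=-1, keepdims=True))
    return (w / w.sum(axis=-1, keepdims=True)) @ v

def blocked_attention(q, k, v, block=4):
    # Process queries block by block: peak intermediate is (block, N)
    return np.concatenate([full_attention(q[i:i + block], k, v)
                           for i in range(0, len(q), block)])

rng = np.random.default_rng(1)
q, k, v = (rng.standard_normal((8, 16)) for _ in range(3))
# Blocking over queries is mathematically exact
assert np.allclose(blocked_attention(q, k, v), full_attention(q, k, v))
```

The result is bit-for-bit the same computation, which is why fusing is purely a speed/memory optimization.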

Here are the timings, showing the benefit of the fused model:

timing/gpu_fp32_fused.txt:Nelem=2560 mean_time=6.99 ms stddev_time=2.89 ms mem_used=1678 MB
timing/gpu_fp32_fused.txt:Nelem=5120 mean_time=16.59 ms stddev_time=0.15 ms mem_used=1946 MB
timing/gpu_fp32_fused.txt:Nelem=10240 mean_time=53.13 ms stddev_time=0.23 ms mem_used=1946 MB

timing/gpu_fp32_unfused.txt:Nelem=2560 mean_time=39.31 ms stddev_time=1.73 ms mem_used=3817 MB
timing/gpu_fp32_unfused.txt:Nelem=5120 mean_time=130.18 ms stddev_time=6.52 ms mem_used=12407 MB
timing/gpu_fp32_unfused.txt:Nelem=10240 mean_time=465.09 ms stddev_time=25.82 ms mem_used=46766 MB
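The timing files above report a mean and standard deviation per input size. A hypothetical harness in that spirit (the PR's actual timing script is not shown here) times repeated inference calls after a warm-up:

```python
import statistics
import time

def benchmark(fn, *args, warmup=2, repeats=10):
    # Warm-up runs are excluded from the statistics
    for _ in range(warmup):
        fn(*args)
    times_ms = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        fn(*args)
        times_ms.append((time.perf_counter() - t0) * 1e3)
    return statistics.mean(times_ms), statistics.stdev(times_ms)

# Dummy workload standing in for an ONNX Runtime session.run call
mean_ms, std_ms = benchmark(sum, range(10_000))
print(f"mean_time={mean_ms:.2f} ms stddev_time={std_ms:.2f} ms")
```

For GPU measurements the same idea applies, but the runtime must be synchronized before each timestamp so queued kernels are not missed.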


@jpata jpata changed the title enable onnx export via dynamo with dynamic shapes onnx export of quantized model with dynamic shapes May 25, 2024
@jpata jpata changed the title onnx export of quantized model with dynamic shapes onnx export with dynamic shapes, fast attention May 27, 2024
@jpata jpata linked an issue May 27, 2024 that may be closed by this pull request
3 tasks
@jpata jpata added the hard label May 27, 2024
@jpata jpata marked this pull request as ready for review May 27, 2024 15:27
@jpata jpata merged commit a7b00c1 into main May 27, 2024
5 checks passed
@jpata jpata added hard and removed hard labels May 27, 2024
@jpata jpata deleted the fix_onnx_export branch July 13, 2024 09:16
farakiko pushed a commit to farakiko/particleflow that referenced this pull request Aug 26, 2024
* enable onnx export via dynamo with dynamic shapes

* added standalone export script

* fp16 quantization sort of works also

* use sdpa

* MultiheadAttention op runs

* update timing study

* cleanup

* model closes

* update timing study

* onnx is factorized

* update onnx script

* revert main model code

* move to notebook
Successfully merging this pull request may close these issues.

Integrate new pytorch attention model in CMSSW via ONNX