
Update benchmark_mha.py to compare with PyTorch SDPA #21449

Merged (6 commits into main) Jul 27, 2024

Conversation

@tianleiwu (Contributor) commented Jul 23, 2024

Description

  • Update benchmark_mha.py to compare with the PyTorch SDPA API.
  • Write results to a CSV file.
  • Use the sdpa_kernel CUDA provider option instead of environment variables for better control.
  • Add arguments (--use_gpu, --causal, etc.) to allow testing different scenarios.
  • Update benchmark_mha.sh to add CPU benchmarks.
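As a sketch of the kind of CSV output this enables, the rows could be written with the standard library's csv module (column names here are taken from the example tables below; the script's actual schema may differ):

```python
import csv

# Hypothetical result rows, shaped like the GPU results table below.
rows = [
    {"format": "Q,K,V", "batch_size": 4, "sequence_length": 2048,
     "num_heads": 32, "head_size": 128, "latency": 0.0016,
     "tflops": 170.0, "kernel": "ort:default"},
]

with open("benchmark_mha_results.csv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
    writer.writeheader()
    writer.writerows(rows)
```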

For the Q,K,V format, torch uses the BNSH layout while ORT uses BSNH, so the comparison is not apples-to-apples. However, a large latency difference would still be a warning sign.
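The layout difference can be sketched with NumPy shape stand-ins (illustrative only; the benchmark feeds real tensors to each framework):

```python
import numpy as np

batch, seq_len, num_heads, head_size = 4, 2048, 32, 128

# ORT MultiHeadAttention takes Q/K/V in BSNH: (batch, sequence, num_heads, head_size).
q_bsnh = np.zeros((batch, seq_len, num_heads, head_size), dtype=np.float16)

# PyTorch scaled_dot_product_attention expects BNSH: (batch, num_heads, sequence, head_size),
# so converting between the two is a transpose of the middle axes.
q_bnsh = q_bsnh.transpose(0, 2, 1, 3)

print(q_bnsh.shape)  # (4, 32, 2048, 128)
```

The transpose (or the memory copy it implies once the tensor is made contiguous) is why the two frameworks are not measured on identical work.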

Example GPU results

Example results on A100-SXM4-80GB with settings (use_gpu=TRUE, enable_cuda_graph=FALSE, causal=FALSE, past_sequence_length=0, intra_op_num_threads=0) on Azure Linux. ORT: built from source with CUDA 12.5; PyTorch 2.3.1 for CUDA 12.1.

| format | batch_size | sequence_length | num_heads | head_size | latency (s) | TFLOPS | kernel |
|--------|------------|-----------------|-----------|-----------|-------------|--------|--------|
| Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.5 | ort:flash |
| Q,KV | 4 | 2048 | 32 | 128 | 0.0015 | 179.0 | ort:default |
| Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 170.0 | ort:default |
| Q,K,V | 4 | 2048 | 32 | 128 | 0.0016 | 169.5 | ort:flash |
| QKV | 4 | 2048 | 32 | 128 | 0.0016 | 168.5 | ort:default |
| QKV | 4 | 2048 | 32 | 128 | 0.0016 | 167.4 | ort:flash |
| Q,K,V | 4 | 2048 | 32 | 128 | 0.0017 | 159.4 | torch:default |
| Q,K,V | 4 | 2048 | 32 | 128 | 0.0018 | 155.0 | torch:flash |
| Q,KV | 4 | 2048 | 32 | 128 | 0.0030 | 92.7 | ort:efficient |
| Q,K,V | 4 | 2048 | 32 | 128 | 0.0030 | 90.9 | ort:efficient |
| QKV | 4 | 2048 | 32 | 128 | 0.0031 | 89.9 | ort:efficient |
| Q,K,V | 4 | 2048 | 32 | 128 | 0.0031 | 89.0 | torch:efficient |
| Q,K,V | 4 | 2048 | 32 | 128 | 0.0054 | 51.3 | torch:math |
| Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 191.0 | ort:default |
| Q,KV | 4 | 4096 | 32 | 128 | 0.0058 | 190.6 | ort:flash |
| Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 187.8 | ort:default |
| Q,K,V | 4 | 4096 | 32 | 128 | 0.0059 | 186.7 | ort:flash |
| QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.9 | ort:flash |
| QKV | 4 | 4096 | 32 | 128 | 0.0059 | 185.8 | ort:default |
| Q,K,V | 4 | 4096 | 32 | 128 | 0.0067 | 163.4 | torch:default |
| Q,K,V | 4 | 4096 | 32 | 128 | 0.0070 | 157.2 | torch:flash |
| Q,KV | 4 | 4096 | 32 | 128 | 0.0113 | 97.6 | ort:efficient |
| Q,K,V | 4 | 4096 | 32 | 128 | 0.0114 | 96.4 | ort:efficient |
| QKV | 4 | 4096 | 32 | 128 | 0.0114 | 96.2 | ort:efficient |
| Q,K,V | 4 | 4096 | 32 | 128 | 0.0127 | 86.3 | torch:efficient |
| Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.8 | ort:flash |
| Q,KV | 8 | 2048 | 32 | 128 | 0.0031 | 177.7 | ort:default |
| Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.8 | ort:default |
| Q,K,V | 8 | 2048 | 32 | 128 | 0.0032 | 170.3 | ort:flash |
| QKV | 8 | 2048 | 32 | 128 | 0.0032 | 169.2 | ort:default |
| QKV | 8 | 2048 | 32 | 128 | 0.0033 | 169.0 | ort:flash |
| Q,K,V | 8 | 2048 | 32 | 128 | 0.0034 | 161.9 | torch:default |
| Q,K,V | 8 | 2048 | 32 | 128 | 0.0036 | 152.9 | torch:flash |
| Q,KV | 8 | 2048 | 32 | 128 | 0.0059 | 93.5 | ort:efficient |
| Q,K,V | 8 | 2048 | 32 | 128 | 0.0060 | 91.3 | ort:efficient |
| QKV | 8 | 2048 | 32 | 128 | 0.0060 | 91.0 | ort:efficient |
| Q,K,V | 8 | 2048 | 32 | 128 | 0.0064 | 86.0 | torch:efficient |
| Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.8 | ort:flash |
| Q,KV | 8 | 4096 | 32 | 128 | 0.0115 | 190.7 | ort:default |
| Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.1 | ort:default |
| Q,K,V | 8 | 4096 | 32 | 128 | 0.0118 | 187.0 | ort:flash |
| QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:default |
| QKV | 8 | 4096 | 32 | 128 | 0.0118 | 185.6 | ort:flash |
| Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.7 | torch:default |
| Q,K,V | 8 | 4096 | 32 | 128 | 0.0139 | 158.3 | torch:flash |
| Q,KV | 8 | 4096 | 32 | 128 | 0.0225 | 97.7 | ort:efficient |
| Q,K,V | 8 | 4096 | 32 | 128 | 0.0227 | 96.8 | ort:efficient |
| QKV | 8 | 4096 | 32 | 128 | 0.0228 | 96.3 | ort:efficient |
| Q,K,V | 8 | 4096 | 32 | 128 | 0.0260 | 84.5 | torch:efficient |
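The TFLOPS column can be approximately reconstructed from the shapes and latency. A minimal sketch, assuming the benchmark counts only the two batched matmuls (Q·Kᵀ and P·V) and ignores softmax; this matches the table to within latency rounding, but it is an assumption about the script's exact FLOP count:

```python
def attention_tflops(batch_size, sequence_length, num_heads, head_size, latency_s):
    """Rough non-causal attention throughput: two matmuls (Q @ K^T and P @ V),
    each 2 * S * S * D multiply-adds per head, softmax ignored."""
    flops = 4 * batch_size * num_heads * sequence_length**2 * head_size
    return flops / latency_s / 1e12

# First GPU row above: batch 4, seq 2048, 32 heads, head_size 128, 1.5 ms.
print(round(attention_tflops(4, 2048, 32, 128, 0.0015), 1))  # 183.3
```

The table shows 179.5 for that row rather than ~183 because latencies are rounded to four decimals before printing.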

Example CPU results

Dell XPS 8960 with an i9-13900 CPU, with settings (use_gpu=FALSE, causal=FALSE, past_sequence_length=0), on Windows. ORT: built from source with CUDA 12.5; PyTorch 2.3.1 for CUDA 12.1.

| format | causal | batch_size | seq_len | num_heads | head_size | threads | latency (s) | kernel |
|--------|--------|------------|---------|-----------|-----------|---------|-------------|--------|
| Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0005 | ort:flash |
| Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:flash |
| Q,K,V | FALSE | 1 | 128 | 32 | 128 | 0 | 0.0009 | ort:math |
| Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0009 | ort:flash |
| Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0014 | ort:flash |
| Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0025 | ort:flash |
| Q,K,V | FALSE | 1 | 128 | 32 | 128 | 2 | 0.0045 | torch:default |
| Q,K,V | FALSE | 1 | 128 | 32 | 128 | 24 | 0.0046 | torch:default |
| Q,K,V | FALSE | 1 | 128 | 32 | 128 | 8 | 0.0046 | torch:default |
| Q,K,V | FALSE | 1 | 128 | 32 | 128 | 4 | 0.0046 | torch:default |
| Q,K,V | FALSE | 1 | 128 | 32 | 128 | 1 | 0.0047 | torch:default |
| Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0019 | ort:flash |
| Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0019 | ort:flash |
| Q,K,V | FALSE | 1 | 256 | 32 | 128 | 0 | 0.0022 | ort:math |
| Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0030 | ort:flash |
| Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0047 | ort:flash |
| Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0086 | ort:flash |
| Q,K,V | FALSE | 1 | 256 | 32 | 128 | 2 | 0.0161 | torch:default |
| Q,K,V | FALSE | 1 | 256 | 32 | 128 | 4 | 0.0162 | torch:default |
| Q,K,V | FALSE | 1 | 256 | 32 | 128 | 8 | 0.0162 | torch:default |
| Q,K,V | FALSE | 1 | 256 | 32 | 128 | 24 | 0.0165 | torch:default |
| Q,K,V | FALSE | 1 | 256 | 32 | 128 | 1 | 0.0166 | torch:default |
| Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0077 | ort:flash |
| Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0091 | ort:flash |
| Q,K,V | FALSE | 1 | 512 | 32 | 128 | 0 | 0.0099 | ort:math |
| Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0103 | ort:flash |
| Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0177 | ort:flash |
| Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0328 | ort:flash |
| Q,K,V | FALSE | 1 | 512 | 32 | 128 | 2 | 0.0624 | torch:default |
| Q,K,V | FALSE | 1 | 512 | 32 | 128 | 4 | 0.0624 | torch:default |
| Q,K,V | FALSE | 1 | 512 | 32 | 128 | 8 | 0.0625 | torch:default |
| Q,K,V | FALSE | 1 | 512 | 32 | 128 | 24 | 0.0626 | torch:default |
| Q,K,V | FALSE | 1 | 512 | 32 | 128 | 1 | 0.0640 | torch:default |
| Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.0286 | ort:flash |
| Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0317 | ort:flash |
| Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.0367 | ort:flash |
| Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 0 | 0.0391 | ort:math |
| Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.0656 | ort:flash |
| Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.1235 | ort:flash |
| Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 24 | 0.2482 | torch:default |
| Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 2 | 0.2483 | torch:default |
| Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 4 | 0.2483 | torch:default |
| Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 8 | 0.2486 | torch:default |
| Q,K,V | FALSE | 1 | 1024 | 32 | 128 | 1 | 0.2538 | torch:default |
| Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1038 | ort:flash |
| Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.1050 | ort:flash |
| Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 0 | 0.1368 | ort:math |
| Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.1535 | ort:flash |
| Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.2461 | ort:flash |
| Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.4724 | ort:flash |
| Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 8 | 0.9835 | torch:default |
| Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 4 | 0.9841 | torch:default |
| Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 24 | 0.9841 | torch:default |
| Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 2 | 0.9873 | torch:default |
| Q,K,V | FALSE | 1 | 2048 | 32 | 128 | 1 | 0.9985 | torch:default |

Motivation and Context

To compare ORT latency with PyTorch SDPA on CPU and CUDA.

@tianleiwu tianleiwu marked this pull request as draft July 23, 2024 01:19
@tianleiwu tianleiwu marked this pull request as ready for review July 23, 2024 07:09
@tianleiwu tianleiwu merged commit 64819f6 into main Jul 27, 2024
96 of 103 checks passed
@tianleiwu tianleiwu deleted the tlwu/benchmark_torch_mha branch July 27, 2024 01:45