[performance] Torch SDPA cuDNN backend vs FlashAttention v3 #41

Open
antferdom opened this issue Nov 7, 2024 · 10 comments

@antferdom

antferdom commented Nov 7, 2024

Goal: figure out whether, by default, Torch's scaled_dot_product_attention executes torch.ops.aten._scaled_dot_product_cudnn_attention when running on the Hopper architecture, i.e., when PLATFORM_SUPPORTS_CUDNN_ATTENTION and SM90OrLater both evaluate to True.

I have created a simple benchmarking script in which FLOP counting is done via Proton. It compares Torch SDPA with the SDPBackend.CUDNN_ATTENTION backend selected against an explicit call to the aten op. Both approaches launch the very same kernel, cudnn_generated_fort_native_sdpa_sm90_knob_7_64x128x128_4x1x1_kernel0_0, and achieve almost identical performance.

Without calling the aten op explicitly or selecting the cuDNN backend, SDPA performance is significantly lower. Isn't Torch supposed to auto-detect the optimal SDPA backend from the environment (e.g., compute capability >= 9.0, i.e., SM90)?
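To frame the question: as far as I know, SDPA's default dispatch is not an autotuner. It walks a priority list of backends and uses the first one whose input checks pass, and `torch.nn.attention.sdpa_kernel(SDPBackend.CUDNN_ATTENTION)` narrows that list. Below is a toy Python model of that behavior; the backend names, the priority order and the eligibility rules are hypothetical placeholders, not PyTorch's actual logic.

```python
from typing import Any, Callable, Dict, List

Params = Dict[str, Any]
Check = Callable[[Params], bool]


def select_backend(priority: List[str], checks: Dict[str, Check], params: Params) -> str:
    """Return the first backend in `priority` whose eligibility check accepts `params`.

    Nothing is timed here: a faster backend that sits lower in the list is only
    chosen if every backend ahead of it is ineligible or excluded from the list.
    """
    for name in priority:
        if checks[name](params):
            return name
    raise RuntimeError("no eligible attention backend for these inputs")


# Hypothetical eligibility rules (NOT PyTorch's real checks), for illustration only.
CHECKS: Dict[str, Check] = {
    "flash": lambda p: p["dtype"] in ("fp16", "bf16") and p["head_dim"] <= 256,
    "efficient": lambda p: True,
    "cudnn": lambda p: p["sm"] >= 90 and p["dtype"] in ("fp16", "bf16"),
    "math": lambda p: True,
}

# Hypothetical default order; the real one is defined inside PyTorch and may change between versions.
DEFAULT_PRIORITY = ["flash", "efficient", "cudnn", "math"]

if __name__ == "__main__":
    h100_bf16 = {"dtype": "bf16", "head_dim": 128, "sm": 90}
    # cuDNN is eligible, but "flash" is ahead of it in this hypothetical order.
    print(select_backend(DEFAULT_PRIORITY, CHECKS, h100_bf16))  # flash
    # Narrowing the list, as a backend context manager does, forces cuDNN.
    print(select_backend(["cudnn"], CHECKS, h100_bf16))  # cudnn
    # Different inputs change the outcome: fp32 fails the flash check.
    h100_fp32 = {"dtype": "fp32", "head_dim": 128, "sm": 90}
    print(select_backend(DEFAULT_PRIORITY, CHECKS, h100_fp32))  # efficient
```

Under this model, "auto-detect" means "first eligible", not "fastest", which would explain why forcing the cuDNN backend can change performance.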

I have also benchmarked against the FlashAttention v3 kernel, but the Proton measurements are inconsistent with the numbers officially reported by the authors and with the GPU's BF16 hardware specs (see FA and Torch SDPA tensor layouts).

Q: Would it be worth making this explicit when testing the flash_attention operator in Tritonbench?

Benchmarking Results

[Screenshot 2024-11-07 at 18.19.26: benchmarking results]

Environment

Collecting environment information...
PyTorch version: 2.6.0.dev20241107+cu124
Is debug build: False
CUDA used to build PyTorch: 12.4
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.30.5
Libc version: glibc-2.35

Python version: 3.10.15 (main, Oct  3 2024, 07:27:34) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-117-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: 12.4.131
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA H100 80GB HBM3
Nvidia driver version: 550.90.12
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.9.7
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.9.7
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture:                         x86_64
CPU op-mode(s):                       32-bit, 64-bit
Address sizes:                        46 bits physical, 57 bits virtual
Byte Order:                           Little Endian
CPU(s):                               30
On-line CPU(s) list:                  0-29
Vendor ID:                            GenuineIntel
Model name:                           Intel(R) Xeon(R) Platinum 8462Y+
CPU family:                           6
Model:                                143
Thread(s) per core:                   1
Core(s) per socket:                   1
Socket(s):                            30
Stepping:                             8
BogoMIPS:                             5600.00
Flags:                                fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss syscall nx pdpe1gb rdtscp lm constant_tsc arch_perfmon rep_good nopl xtopology cpuid tsc_known_freq pni pclmulqdq vmx ssse3 fma cx16 pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_fault invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves avx_vnni avx512_bf16 wbnoinvd arat avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b fsrm md_clear serialize tsxldtrk avx512_fp16 arch_capabilities
Virtualization:                       VT-x
Hypervisor vendor:                    KVM
Virtualization type:                  full
L1d cache:                            960 KiB (30 instances)
L1i cache:                            960 KiB (30 instances)
L2 cache:                             120 MiB (30 instances)
L3 cache:                             480 MiB (30 instances)
NUMA node(s):                         1
NUMA node0 CPU(s):                    0-29
Vulnerability Gather data sampling:   Not affected
Vulnerability Itlb multihit:          Not affected
Vulnerability L1tf:                   Not affected
Vulnerability Mds:                    Not affected
Vulnerability Meltdown:               Not affected
Vulnerability Mmio stale data:        Unknown: No mitigations
Vulnerability Reg file data sampling: Not affected
Vulnerability Retbleed:               Not affected
Vulnerability Spec rstack overflow:   Not affected
Vulnerability Spec store bypass:      Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1:             Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:             Mitigation; Enhanced / Automatic IBRS; IBPB conditional; RSB filling; PBRSB-eIBRS SW sequence; BHI SW loop, KVM SW loop
Vulnerability Srbds:                  Not affected
Vulnerability Tsx async abort:        Mitigation; TSX disabled

Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] nvidia-cublas-cu12==12.4.5.8
[pip3] nvidia-cuda-cupti-cu12==12.4.127
[pip3] nvidia-cuda-nvrtc-cu12==12.4.127
[pip3] nvidia-cuda-runtime-cu12==12.4.127
[pip3] nvidia-cudnn-cu12==9.1.0.70
[pip3] nvidia-cudnn-frontend==1.7.0
[pip3] nvidia-cufft-cu12==11.2.1.3
[pip3] nvidia-curand-cu12==10.3.5.147
[pip3] nvidia-cusolver-cu12==11.6.1.9
[pip3] nvidia-cusparse-cu12==12.3.1.170
[pip3] nvidia-cusparselt-cu12==0.6.2
[pip3] nvidia-nccl-cu12==2.21.5
[pip3] nvidia-nvjitlink-cu12==12.4.127
[pip3] nvidia-nvtx-cu12==12.4.127
[pip3] pytorch-triton==3.1.0+cf34004b8a
[pip3] torch==2.6.0.dev20241107+cu124
[pip3] torchao==0.5.0+cu124
[pip3] transformer_engine_torch==1.11.0
[pip3] triton==3.0.0
[conda] numpy                     1.26.4                   pypi_0    pypi
[conda] nvidia-cublas-cu12        12.4.5.8                 pypi_0    pypi
[conda] nvidia-cuda-cupti-cu12    12.4.127                 pypi_0    pypi
[conda] nvidia-cuda-nvrtc-cu12    12.4.127                 pypi_0    pypi
[conda] nvidia-cuda-runtime-cu12  12.4.127                 pypi_0    pypi
[conda] nvidia-cudnn-cu12         9.1.0.70                 pypi_0    pypi
[conda] nvidia-cudnn-frontend     1.7.0                    pypi_0    pypi
[conda] nvidia-cufft-cu12         11.2.1.3                 pypi_0    pypi
[conda] nvidia-curand-cu12        10.3.5.147               pypi_0    pypi
[conda] nvidia-cusolver-cu12      11.6.1.9                 pypi_0    pypi
[conda] nvidia-cusparse-cu12      12.3.1.170               pypi_0    pypi
[conda] nvidia-cusparselt-cu12    0.6.2                    pypi_0    pypi
[conda] nvidia-nccl-cu12          2.21.5                   pypi_0    pypi
[conda] nvidia-nvjitlink-cu12     12.4.127                 pypi_0    pypi
[conda] nvidia-nvtx-cu12          12.4.127                 pypi_0    pypi
[conda] pytorch-triton            3.1.0+cf34004b8a          pypi_0    pypi
[conda] torch                     2.6.0.dev20241107+cu124          pypi_0    pypi
[conda] torchao                   0.5.0+cu124              pypi_0    pypi
[conda] transformer-engine-torch  1.11.0                   pypi_0    pypi
[conda] triton                    3.0.0                    pypi_0    pypi

@FindHao @xuzhao9

@FindHao
Member

FindHao commented Nov 7, 2024

> Without calling explicit aten op or setting the cuDNN backend, the performance of SDPA is significantly lower. Isn’t Torch supposed to auto-detect the optimal backend for SDPA based on the environment information (e.g. compute capability >= 90)?

I think torch.nn.functional.* doesn't have such a capability because it directly calls the underlying implementations, but I'm not 100% sure. IIRC, the aten API can do some dispatching. Can you try wrapping your function with torch.compile with max-autotune enabled, e.g., torch.compile(your_function, mode="max-autotune"), and benchmark again? Torch-compiled flash_attention has been added to tritonbench.

@FindHao
Member

FindHao commented Nov 7, 2024

BTW, as I mentioned in pytorch/pytorch#136169 (comment), the function that is finally called can differ depending on the inputs, since aten and torch.compile dispatch differently for different cases. So changing the input shapes or dtypes can change the underlying function calls too.

@FindHao
Member

FindHao commented Nov 7, 2024

@antferdom
Author

antferdom commented Nov 7, 2024

Thanks @FindHao for your detailed response. The first thing I tried was to use the existing attention implementations for the attention operator, but I ran into several issues (see #11) getting it to work properly. Tomorrow I will use --only to run each implementation in isolation.

> Torch compiled flash_attention has been added to tritonbench.

I can only see a compiled FlexAttention version. Am I missing other compiled variants?

As you suggested, compiling SDPA with torch.compile:

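import torch
import triton.profiler as proton
from torch.nn.functional import scaled_dot_product_attention

# batch, seq_len, num_heads, head_dim, is_causal, warmup_iter and the
# query/key/value tensors (B, H, S, D) come from the surrounding benchmark script.

# Compile once and reuse the same callable for warmup and the profiled run.
compiled_sdpa = torch.compile(
    scaled_dot_product_attention,
    fullgraph=True,
    backend="inductor",
    mode="max-autotune",
)

# Q @ K^T and P @ V each cost 2 * B * H * S^2 * D FLOPs; a causal mask halves the work.
flops = 4 * batch * seq_len**2 * num_heads * head_dim // (2 if is_causal else 1)

# Warmup triggers compilation and autotuning outside the profiled scope.
for _ in range(warmup_iter):
    _ = compiled_sdpa(query, key, value, is_causal=is_causal)
with proton.scope("torch_scaled_dot_product_attention_compiled", metrics={"flops": flops}):
    attn_out_sdpa_compiled = compiled_sdpa(query, key, value, is_causal=is_causal)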

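Attention FLOP count used for the flops metric, with $B$ = batch, $S$ = seq_len, $H$ = num_heads, $D$ = head_dim. The two matmuls $QK^\top$ and $PV$ each cost $2BHS^2D$ FLOPs (softmax is not counted), and a causal mask halves the total:

$$\text{FLOPs} = \frac{4BHS^2D}{c}, \qquad c = \begin{cases} 2 & \text{if is\_causal} \\ 1 & \text{otherwise} \end{cases}$$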

The updated benchmark results show a strangely large performance gap: the Torch-compiled scaled_dot_product_attention becomes even faster than flash_attn_func_hopper, which seems incorrect. See the following Proton profiling results (double-check with ncu):

[Screenshot 2024-11-07 at 23.50.21: Proton profiling results]

Update: incorrect tensor layout for FA3: I was using the SDPA layout without permutation.

I was using the wrong tensor layout for FA3, without permuting the original $q, k, v$ tensors. @chengzeyi quickly figured it out, thanks!

"""
FA -> (batch, seqlen, nheads, headdim)
Torch sdpa expects -> (batch, nheads, seqlen, headdim)

ref: 
    torch.functional: https://github.com/pytorch/pytorch/blob/8231180147a096a703d8891756068c89365292e0/torch/nn/functional.py#L5617
"""
query = query.permute(0, 2, 1, 3) # B, H, S, D
key = key.permute(0, 2, 1, 3) # B, H, S, D
value = value.permute(0, 2, 1, 3) # B, H, S, D
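A quick way to see why a layout mix-up distorts TFLOP/s rather than failing outright: all four dimensions are still valid, so a kernel may simply treat the heads axis as the sequence axis. The FLOPs it executes then no longer match the count given to Proton. The sketch below uses plain Python shape tuples (no tensors) and a hypothetical shape of B=2, S=4096, H=16, D=128.

```python
def permute_shape(shape, dims):
    """Shape that tensor.permute(*dims) would produce for a tensor of `shape`."""
    return tuple(shape[d] for d in dims)


def attention_flops(batch, seq_len, num_heads, head_dim, is_causal=False):
    """Same formula as the benchmark's flops metric (matmuls only)."""
    return 4 * batch * seq_len**2 * num_heads * head_dim // (2 if is_causal else 1)


def flops_as_read(shape, layout, is_causal=False):
    """FLOPs a kernel performs when it reads a tensor of `shape` as `layout` ("BSHD" or "BHSD")."""
    dims = dict(zip(layout, shape))
    return attention_flops(dims["B"], dims["S"], dims["H"], dims["D"], is_causal)


if __name__ == "__main__":
    fa_shape = (2, 4096, 16, 128)  # hypothetical (batch, seqlen, nheads, headdim)
    sdpa_shape = permute_shape(fa_shape, (0, 2, 1, 3))
    print(sdpa_shape)  # (2, 16, 4096, 128)
    counted = flops_as_read(fa_shape, "BSHD")  # what the flops metric assumes
    # An SDPA-layout tensor read as FA layout: seqlen=16, nheads=4096.
    executed = flops_as_read(sdpa_shape, "BSHD")
    print(counted // executed)  # 256
```

With these numbers the counted FLOPs are 256x (= S/H) the executed FLOPs, so if runtime tracks the work actually done, the reported TFLOP/s would be inflated by a similar factor.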

Still, the reported TFLOP/s for the Inductor-compiled version of SDPA seems incorrect.

@FindHao
Member

FindHao commented Nov 8, 2024

> I can only see FlexAttention compiled version . Am I missing other compilations?

Sorry, I misread the function name in the code. Yes, we don't have a torch-compiled version.

> Still, the reported TFLOP/s for the Inductor compiled version of sdpa seems incorrect.

Can you compare the outputs of the different implementations, e.g., using torch.allclose to check each implementation against the baseline's output?
If the outputs match, the Inductor-compiled version is correct and something must be wrong with the FLOPS computation. If not, this is a correctness bug in the PyTorch compiler.
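The check above can be sketched in plain Python on flattened outputs, applying the same elementwise rule as torch.allclose, |out - ref| <= atol + rtol * |ref|. In the real benchmark you would pass the tensors to torch.allclose or torch.testing.assert_close directly. The default tolerances here are, to my knowledge, torch.testing.assert_close's bfloat16 defaults; treat them as an assumption, and the example outputs are made up.

```python
from typing import Dict, Sequence


def compare_to_baseline(outputs: Dict[str, Sequence[float]], baseline: str,
                        rtol: float = 1.6e-2, atol: float = 1e-5) -> Dict[str, bool]:
    """Check each implementation's (flattened) output against the baseline's.

    Elementwise rule, as in torch.allclose: |out - ref| <= atol + rtol * |ref|
    """
    ref = outputs[baseline]
    result = {}
    for name, out in outputs.items():
        if len(out) != len(ref):
            result[name] = False  # different size: e.g. a layout/shape bug
            continue
        result[name] = all(abs(o - r) <= atol + rtol * abs(r) for o, r in zip(out, ref))
    return result


if __name__ == "__main__":
    outputs = {  # hypothetical flattened outputs
        "math": [1.0, 2.0, -0.5],
        "cudnn": [1.001, 2.002, -0.5],
        "compiled": [1.0, 2.5, -0.5],
    }
    print(compare_to_baseline(outputs, "math"))
    # {'math': True, 'cudnn': True, 'compiled': False}
```

Using a relative tolerance scaled to the dtype matters here: with BF16 outputs, an exact-equality check would flag correct kernels that merely accumulate in a different order.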

@antferdom
Author

The output matches the reference output, so I assume the issue is related to Proton when profiling compiled Torch code. I will forward this issue to the Triton repo directly.

@FindHao
Member

FindHao commented Nov 12, 2024

Can you share your execution time results for the different implementations? Are the TFLOPS consistent with the execution times?
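This question can be turned into a mechanical check: recompute TFLOP/s as FLOPs divided by the measured kernel time, compare it with the profiler's figure, and compare both with the hardware peak. A value above peak points at the measurement (FLOP count, timed region, or tensor layout) rather than at the kernel. The sketch is plain Python; the peak constant, tolerance and example numbers are assumptions, not measurements from this issue.

```python
from typing import List


def tflops(flops: float, seconds: float) -> float:
    """Achieved throughput in TFLOP/s."""
    return flops / seconds / 1e12


def check_report(flops: float, seconds: float, reported_tflops: float,
                 peak_tflops: float, rel_tol: float = 0.05) -> List[str]:
    """Return a list of problems with a profiler's TFLOP/s figure (empty if consistent)."""
    achieved = tflops(flops, seconds)
    problems = []
    if abs(achieved - reported_tflops) > rel_tol * achieved:
        problems.append(f"reported {reported_tflops:.1f} vs flops/time {achieved:.1f}")
    if achieved > peak_tflops:
        problems.append(f"{achieved:.1f} TFLOP/s exceeds the assumed peak of {peak_tflops:.1f}")
    return problems


# Assumption: ~989 TFLOP/s dense BF16 Tensor Core peak for H100 SXM; verify on NVIDIA's datasheet.
H100_BF16_PEAK = 989.0

if __name__ == "__main__":
    flops = 4 * 2 * 4096**2 * 16 * 128  # hypothetical B=2, S=4096, H=16, D=128, non-causal
    print(check_report(flops, 5e-4, 550.0, H100_BF16_PEAK))  # []
    print(len(check_report(flops, 2e-4, 1374.4, H100_BF16_PEAK)))  # 1
```

In the second call the profiler and flops/time agree, yet the result exceeds peak, so the suspect is the FLOP count or the timed region rather than the arithmetic.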

@antferdom
Author

@FindHao, do you mean the Proton results I shared in the previous image? There you can see it working as expected and reporting coherent values, except when we compile the SDPA op. That's why I thought it was Proton's problem.

@xuzhao9
Contributor

xuzhao9 commented Nov 14, 2024

cc @fywkevin for Proton results when profiling compiled Torch kernels

@antferdom
Author

@xuzhao9 @fywkevin any update?
