[performance] Torch SDPA cuDNN backend vs FlashAttention v3 #41
I think
BTW, as I mentioned in pytorch/pytorch#136169 (comment), the functions that are ultimately called can differ depending on the inputs, because aten and torch.compile dispatch differently in different cases. So if we change the input shapes or dtypes, the underlying function calls can change too.
It looks like there is already a cuDNN version: https://github.com/pytorch-labs/tritonbench/blob/main/tritonbench/operators/flash_attention/operator.py#L349
Thanks @FindHao for your detailed response. The first thing I tried was to use the existing attention implementations for the attention operator, but I ran into several issues (see #11) getting it to work properly. Tomorrow I will apply
I can only see the FlexAttention compiled version. Am I missing other compiled variants? As you mentioned, using Torch to compile SDPA:

```python
import torch
import triton.profiler as proton
from torch.nn.functional import scaled_dot_product_attention

# Compile SDPA once with Inductor and reuse it; max-autotune searches kernel configurations.
compiled_sdpa = torch.compile(
    scaled_dot_product_attention,
    fullgraph=True,
    backend="inductor",
    mode="max-autotune",
)

# Assumes query/key/value, batch, seq_len, num_heads, head_dim, is_causal, warmup_iter are defined earlier.
flops = 4 * batch * seq_len**2 * num_heads * head_dim // (2 if is_causal else 1)

# Warm up so compilation and autotuning fall outside the profiled scope.
for _ in range(warmup_iter):
    _ = compiled_sdpa(query, key, value, is_causal=is_causal)

with proton.scope("torch_scaled_dot_product_attention_compiled", metrics={"flops": flops}):
    attn_out_sdpa_compiled = compiled_sdpa(query, key, value, is_causal=is_causal)
```

[Figure: attention mathematical expression for FLOPs]

The updated benchmark results show a strangely large performance gap. Torch compiled

Update: Incorrect tensor shape layout for FA3, using the SDPA layout without permutation

I was incorrectly using the FA3 tensor layout, without permuting the original

```python
"""
FA expects         -> (batch, seqlen, nheads, headdim)
Torch SDPA expects -> (batch, nheads, seqlen, headdim)
ref:
torch.functional: https://github.com/pytorch/pytorch/blob/8231180147a096a703d8891756068c89365292e0/torch/nn/functional.py#L5617
"""
query = query.permute(0, 2, 1, 3)  # B, H, S, D
key = key.permute(0, 2, 1, 3)  # B, H, S, D
value = value.permute(0, 2, 1, 3)  # B, H, S, D
```

Still, the reported TFLOP/s for the Inductor-compiled version of
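The `flops` expression in the benchmark script counts only the two matrix multiplications of attention. A short derivation, assuming each multiply-add counts as 2 FLOPs and ignoring softmax and scaling, with $B$ = `batch`, $H$ = `num_heads`, $S$ = `seq_len`, $D$ = `head_dim`, and $t$ the measured time per call in seconds:

```latex
\begin{aligned}
\text{FLOPs}(QK^\top) &= 2 \cdot B \cdot H \cdot S \cdot S \cdot D \\
\text{FLOPs}(PV)      &= 2 \cdot B \cdot H \cdot S \cdot S \cdot D \\
\text{FLOPs}_{\text{total}} &= 4\,B\,H\,S^{2}D, \qquad
\text{FLOPs}_{\text{causal}} \approx \tfrac{1}{2}\cdot 4\,B\,H\,S^{2}D \\
\text{TFLOP/s} &= \frac{\text{FLOPs}}{t \cdot 10^{12}}
\end{aligned}
```

The causal case is only approximately half because the mask keeps $S(S+1)/2$ of the $S^2$ score entries; the integer division by 2 in the script uses the rounded-down convention.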
Sorry, I misread the function name in the code. Yes, we don’t have a torch-compiled version.
Can you compare the outputs of the different implementations? Use torch.allclose or similar to compare every implementation's output against the baseline's.
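As a rough sketch of that check: `torch.allclose` tests elementwise whether |input − other| ≤ atol + rtol · |other|. The pure-Python stand-in below applies the same criterion to flattened outputs so it runs without a GPU. The helper names, the example values, and the loose tolerances (`rtol=atol=1e-2`, assumed here to suit BF16 outputs) are hypothetical, not part of Tritonbench.

```python
from typing import Dict, Sequence


def allclose(actual: Sequence[float], expected: Sequence[float],
             rtol: float = 1e-2, atol: float = 1e-2) -> bool:
    """Elementwise |actual - expected| <= atol + rtol * |expected|, as in torch.allclose.

    NaNs compare as not close, matching torch.allclose's default equal_nan=False.
    """
    if len(actual) != len(expected):
        return False
    return all(abs(a - e) <= atol + rtol * abs(e) for a, e in zip(actual, expected))


def compare_to_baseline(baseline: Sequence[float],
                        candidates: Dict[str, Sequence[float]],
                        rtol: float = 1e-2, atol: float = 1e-2) -> Dict[str, dict]:
    """Report, per implementation, whether it matches the baseline and its max abs error."""
    report = {}
    for name, out in candidates.items():
        max_abs = max((abs(a - e) for a, e in zip(out, baseline)), default=0.0)
        report[name] = {"allclose": allclose(out, baseline, rtol, atol),
                        "max_abs_err": max_abs}
    return report


# Hypothetical flattened outputs, for illustration only.
baseline = [0.5, -1.25, 2.0]
report = compare_to_baseline(baseline, {
    "sdpa_cudnn": [0.5, -1.25, 2.001],
    "fa3_wrong_layout": [0.1, 0.9, -3.0],
})
```

Reporting the maximum absolute error alongside the boolean helps distinguish a tolerance that is slightly too tight from a genuinely wrong result such as a mis-permuted layout.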
The output matches the reference output, so I assume the issue might be related to Proton when profiling compiled Torch code. I will forward this issue to the Triton repo directly.
Can you share your execution-time results for the different implementations? Are the reported TFLOP/s consistent with the execution times?
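One way to answer this is to recompute throughput from a measured per-call latency, using the same FLOP formula as the benchmark script, and compare the result with both Proton's number and the hardware peak. A minimal sketch: the helper names, the 10% agreement tolerance, and the peak value (roughly the dense BF16 figure for an H100 SXM) are assumptions, and the latency should come from a synchronized timer such as `triton.testing.do_bench`, which returns milliseconds.

```python
from typing import Tuple

# ASSUMPTION: approximate dense BF16 peak of an H100 SXM; substitute your GPU's spec value.
PEAK_BF16_TFLOPS = 989.0


def attention_flops(batch: int, seq_len: int, num_heads: int, head_dim: int,
                    is_causal: bool) -> int:
    """Same formula as the benchmark script: 4 * B * S^2 * H * D, halved if causal."""
    return 4 * batch * seq_len**2 * num_heads * head_dim // (2 if is_causal else 1)


def tflops_per_second(flops: int, elapsed_ms: float) -> float:
    """Convert a FLOP count and a per-call latency in milliseconds to TFLOP/s."""
    return flops / (elapsed_ms * 1e-3) / 1e12


def check_measurement(reported_tflops: float, flops: int, elapsed_ms: float,
                      peak_tflops: float = PEAK_BF16_TFLOPS,
                      rel_tol: float = 0.1) -> Tuple[float, bool, bool]:
    """Return (TFLOP/s derived from latency, agrees with reported, both below peak)."""
    derived = tflops_per_second(flops, elapsed_ms)
    agrees = abs(reported_tflops - derived) <= rel_tol * derived
    plausible = derived <= peak_tflops and reported_tflops <= peak_tflops
    return derived, agrees, plausible


# Hypothetical shapes and latency, for illustration only.
flops = attention_flops(batch=4, seq_len=4096, num_heads=16, head_dim=128, is_causal=False)
derived, agrees, plausible = check_measurement(550.0, flops, elapsed_ms=1.0)
```

A profiler-reported value that disagrees with the latency-derived one, or that exceeds the hardware peak, points to a counting or attribution problem rather than a genuinely faster kernel.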
@FindHao, do you mean the Proton results I shared in the previous image? There you can see that it works as expected, reporting coherent values, except when we compile
cc @fywkevin for Proton results when profiling compiled Torch kernels
Target: Figure out whether, by default, Torch `scaled_dot_product_attention` executes `torch.ops.aten._scaled_dot_product_cudnn_attention` when running on the Hopper architecture, where `PLATFORM_SUPPORTS_CUDNN_ATTENTION` and `SM90OrLater` evaluate to `True`.

I have created a simple benchmarking script in which FLOPs are counted via Proton. It compares Torch SDPA under `SDPBackend.CUDNN_ATTENTION` with an explicit call to the aten op. Both approaches launch the very same kernel, `cudnn_generated_fort_native_sdpa_sm90_knob_7_64x128x128_4x1x1_kernel0_0`, and achieve almost equal performance. Without the explicit aten op call or the cuDNN backend setting, SDPA performance is significantly lower. Isn't Torch supposed to auto-detect the optimal SDPA backend from the environment (e.g. compute capability >= 9.0, i.e. SM90)?
I have also benchmarked against the FlashAttention v3 kernel, but the Proton measurements are inconsistent with the numbers officially reported by the authors and with the GPU hardware specs for `BFLOAT16` (see the FA and Torch SDPA shape layouts).

Q: Would it be worth making this explicit when testing the `flash_attention` operator in Tritonbench?

Benchmarking Results
Environment
@FindHao @xuzhao9