I'm just curious whether your code here is using flash attention or not when mask is not None. My guess is that it falls back to memory-efficient attention instead, since PyTorch's flash attention kernel does not support an attention mask. In addition, if the memory-efficient kernel were used, half() would not have been needed when mask is not None.
Thank you!
++ I did some experiments. Even if sdp_flash is enabled, it is not executed when mask is not None. If we force PyTorch to use only the flash kernel, it spits out the error below, while the memory-efficient kernel does not.
/tmp/ipykernel_467687/3943656874.py:12: UserWarning: Memory efficient kernel not used because: (Triggered internally at /opt/conda/conda-bld/pytorch_1702400410390/work/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:367.)
y = torch.nn.functional.scaled_dot_product_attention(
/tmp/ipykernel_467687/3943656874.py:12: UserWarning: Memory Efficient attention has been runtime disabled. (Triggered internally at /opt/conda/conda-bld/pytorch_1702400410390/work/aten/src/ATen/native/transformers/sdp_utils_cpp.h:437.)
y = torch.nn.functional.scaled_dot_product_attention(
/tmp/ipykernel_467687/3943656874.py:12: UserWarning: Flash attention kernel not used because: (Triggered internally at /opt/conda/conda-bld/pytorch_1702400410390/work/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp:369.)
y = torch.nn.functional.scaled_dot_product_attention(
/tmp/ipykernel_467687/3943656874.py:12: UserWarning: Both fused kernels do not support non-null attn_mask. (Triggered internally at /opt/conda/conda-bld/pytorch_1702400410390/work/aten/src/ATen/native/transformers/sdp_utils_cpp.h:261.)
y = torch.nn.functional.scaled_dot_product_attention(
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[34], line 12
10 kv_mask = (torch.rand(B, S) > 0.1).to(device)
11 x = [x.half() for x in [q, k, v]]
---> 12 y = attn(*x, q_mask, kv_mask)
File ~/miniconda3/envs/torch212/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1517 else:
-> 1518 return self._call_impl(*args, **kwargs)
File ~/miniconda3/envs/torch212/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
1522 # If we don't have any hooks, we want to skip the rest of the logic in
1523 # this function, and just call forward.
1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1525 or _global_backward_pre_hooks or _global_backward_hooks
1526 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527 return forward_call(*args, **kwargs)
1529 try:
1530 result = None
Cell In[32], line 12, in TorchNativeAttention.forward(self, q, k, v, q_mask, kv_mask)
10 attn_mask = None
11 with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
---> 12 y = torch.nn.functional.scaled_dot_product_attention(
13 q, k, v, attn_mask=attn_mask, dropout_p=self.attn_dropout, is_causal=False
14 )
16 return y if attn_mask is None else y.nan_to_num()
RuntimeError: No available kernel. Aborting execution.
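For reference, here is roughly the cell I ran, reconstructed as a minimal sketch from the traceback above (the TorchNativeAttention wrapper and the shapes are from my notebook, not from this repo). Allowing only the flash kernel with a non-null attn_mask raises the error, while switching the flags to allow only the memory-efficient kernel runs fine:

```python
import torch
import torch.nn.functional as F

class TorchNativeAttention(torch.nn.Module):
    # Minimal stand-in for the wrapper in my notebook, not the repo's module.
    def __init__(self, attn_dropout: float = 0.0):
        super().__init__()
        self.attn_dropout = attn_dropout

    def forward(self, q, k, v, q_mask=None, kv_mask=None):
        attn_mask = None
        if q_mask is not None and kv_mask is not None:
            # (B, 1, S_q, S_kv) boolean mask: True = attend, False = masked out
            attn_mask = q_mask[:, None, :, None] & kv_mask[:, None, None, :]
        # Allow only the flash kernel; with a non-null attn_mask this raises
        # "RuntimeError: No available kernel. Aborting execution."
        # With enable_flash=False, enable_mem_efficient=True the same call succeeds.
        with torch.backends.cuda.sdp_kernel(
            enable_flash=True, enable_math=False, enable_mem_efficient=False
        ):
            y = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=attn_mask, dropout_p=self.attn_dropout, is_causal=False,
            )
        return y if attn_mask is None else y.nan_to_num()

device = "cuda"
B, H, S, D = 2, 8, 512, 64
attn = TorchNativeAttention().to(device)

q, k, v = (torch.randn(B, H, S, D, device=device) for _ in range(3))
q_mask = (torch.rand(B, S) > 0.1).to(device)
kv_mask = (torch.rand(B, S) > 0.1).to(device)

x = [t.half() for t in (q, k, v)]
y = attn(*x, q_mask, kv_mask)  # -> RuntimeError with the flash-only settings
```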
Hey @HJoonKwon! Damn, very good find, thank you! I guess this does matter in compiled forward, where we are padding inputs to static dimensions. We'd need to run the benchmarks, but maybe avoiding the call to half() could improve throughput then.
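Something like this could be a starting point for that benchmark (a rough sketch with made-up shapes, not the model's real forward; it just compares the SDPA call with and without the round-trip to half, restricted to the memory-efficient kernel since that is the one eligible with a non-null mask):

```python
import torch
import torch.nn.functional as F
from torch.utils import benchmark

device = "cuda"
B, H, S, D = 2, 8, 1024, 64  # made-up shapes
attn_mask = torch.rand(B, 1, S, S, device=device) > 0.1  # boolean mask, True = attend

def time_path(cast_to_half: bool):
    # fp32 inputs, as they would come out of the rest of the network
    q, k, v = (torch.randn(B, H, S, D, device=device) for _ in range(3))
    stmt = (
        "F.scaled_dot_product_attention(q.half(), k.half(), v.half(), attn_mask=attn_mask).float()"
        if cast_to_half
        else "F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)"
    )
    # Restrict SDPA to the memory-efficient kernel for an apples-to-apples comparison.
    with torch.backends.cuda.sdp_kernel(
        enable_flash=False, enable_math=False, enable_mem_efficient=True
    ):
        return benchmark.Timer(
            stmt=stmt, globals=dict(F=F, q=q, k=k, v=v, attn_mask=attn_mask)
        ).blocked_autorange(min_run_time=1.0)

print("cast to half:", time_path(True))
print("stay in fp32:", time_path(False))
```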
On the topic of FlashAttention, you link to FlashAttention and not FlashAttention2 here.
Isn't the second version used? If not, why? It seems to be quite a bit faster.
Thanks for your great work!
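For anyone who wants to check on their side: which FlashAttention version you get through F.scaled_dot_product_attention depends on the installed PyTorch (if I remember correctly, FlashAttention-2 only landed in the SDPA flash backend around PyTorch 2.2). A quick way to see the version and which backends are enabled:

```python
import torch

print(torch.__version__)
# Enabled SDPA backends for this session (enabled != guaranteed to be
# selected for a given input; that still depends on dtype, mask, etc.).
print("flash:        ", torch.backends.cuda.flash_sdp_enabled())
print("mem-efficient:", torch.backends.cuda.mem_efficient_sdp_enabled())
print("math:         ", torch.backends.cuda.math_sdp_enabled())
```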