I really love this project and the accompanying blogpost, so thanks! I've reimplemented some of the inference techniques to speed up an implementation of Whisper that I am using. I had a few questions about the attention kernels, as they have been giving me some performance-related issues.
By adding print statements, I can see that during the attention calculation (not including prefilling) the shapes are essentially:
- `k`: `(bs, n_heads, max_seq_len, d_head)`
- `v`: `(bs, n_heads, max_seq_len, d_head)`
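For context, this is roughly how I'm writing into the cache on each decode step (a minimal sketch of a gpt-fast-style static cache; the `KVCache` class name, the buffer names, and the dtype are my own simplifications, not code from this repo):

```python
import torch

class KVCache(torch.nn.Module):
    """Preallocated (static) KV cache, sketched after the gpt-fast approach."""

    def __init__(self, bs, n_heads, max_seq_len, d_head, dtype=torch.float16):
        super().__init__()
        cache_shape = (bs, n_heads, max_seq_len, d_head)
        self.register_buffer("k_cache", torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: positions being written this step (a single index during decoding)
        # k_val / v_val: (bs, n_heads, seq_len, d_head) projections for the new token(s)
        self.k_cache[:, :, input_pos] = k_val
        self.v_cache[:, :, input_pos] = v_val
        # The full max_seq_len-long buffers are what get handed to attention every step.
        return self.k_cache, self.v_cache
```

The attention mask is then the only thing that tells the kernel which of those `max_seq_len` slots are actually populated.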
I understand that `max_seq_len` is there because of the static KV cache implementation. My understanding is that, thanks to the attention mask, `F.scaled_dot_product_attention` combined with `torch.compile` should be able to tell that it doesn't need to compute attention over the entire `max_seq_len`. In my case, however, I've found that the `max_seq_len` value has a big effect on inference speed, which suggests to me that the full attention (over the entire `max_seq_len` context) is being performed on every iteration. The effect is vastly reduced when using the following context manager, as is done in generate.py:
`with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):`
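For reference, here is roughly how the decode-step attention call looks with that context manager around it (a minimal, self-contained sketch of the pattern as I understand it from generate.py; the shapes, `cur_len`, and the boolean-mask construction are my own stand-ins, not code from this repo):

```python
import torch
import torch.nn.functional as F

bs, n_heads, max_seq_len, d_head = 1, 8, 4096, 64
kw = dict(device="cuda", dtype=torch.float16)
q = torch.randn(bs, n_heads, 1, d_head, **kw)            # single decode step
k = torch.randn(bs, n_heads, max_seq_len, d_head, **kw)  # full static K cache
v = torch.randn(bs, n_heads, max_seq_len, d_head, **kw)  # full static V cache

# Boolean mask: only the first `cur_len` cache slots are valid at this step.
cur_len = 17
mask = torch.zeros(1, 1, 1, max_seq_len, dtype=torch.bool, device="cuda")
mask[..., :cur_len] = True

# Force the math backend, as generate.py does around the decode loop.
with torch.backends.cuda.sdp_kernel(
    enable_flash=False, enable_mem_efficient=False, enable_math=True
):
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)  # (bs, n_heads, 1, d_head)
```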
If I exclude this, I see a 3x reduction in tok/s. Does this make sense, or is it a sign that I've implemented something wrong? Even with this context manager, I see a significant (50%+) increase in tok/s when I reduce the context length from 4096 to 2048 or 1024.
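To give a concrete sense of what I mean by "big effect", this is the kind of microbenchmark I've been using to compare cache lengths at the SDPA level (a rough sketch; the shapes, iteration count, and timing harness are mine, and it isolates the kernel rather than the full model):

```python
import time
import torch
import torch.nn.functional as F

def time_decode_step(max_seq_len, iters=200, bs=1, n_heads=8, d_head=64):
    """Time a single-token SDPA call against a full-length static cache."""
    kw = dict(device="cuda", dtype=torch.float16)
    q = torch.randn(bs, n_heads, 1, d_head, **kw)
    k = torch.randn(bs, n_heads, max_seq_len, d_head, **kw)
    v = torch.randn(bs, n_heads, max_seq_len, d_head, **kw)
    mask = torch.zeros(1, 1, 1, max_seq_len, dtype=torch.bool, device="cuda")
    mask[..., :32] = True  # pretend only 32 cache slots are actually populated

    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters

for max_seq_len in (1024, 2048, 4096):
    print(f"max_seq_len={max_seq_len}: {time_decode_step(max_seq_len) * 1e6:.1f} us/step")
```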
Thanks in advance. If it helps, here is my CUDA-graph-friendly Whisper implementation using a static KV cache: