forked from Dao-AILab/flash-attention
Integrated Rotary Positional Embeddings (RoPEs) into flash_attn_kvcache #83
Open
alexkranias-amd wants to merge 26 commits into main_perf from alexkranias/rotary-embedding
Conversation
* flash_attn_func works Compress This is a combination of 12 commits. add scripts save add our kernel import our kernel round trip use bshd layout figure out segfault fix show backward failure with prints save backward work run forward only test smallest config on everything add test fix remove pre commit install triton skip dropout pin d 32 factor d just run power of 2 remove timeout run serially clean up clean up 2
* Varlen works This is a combination of 6 commits. save some tests passing enable more enable everything move around alibi works
* keep interface and kernel seperate
* clean up
* Compress kvcache work This is a combination of 11 commits. kvcache work This is a combination of 4 commits. kvcache is not supported save save decode save clean up merge save cases save save save save key mask on triton side fix q size issue test combos save
* fix causal. use cache_seqlens
* clean and test what works
* some configs work on new_kv but fails on 1,8
* cache overwrite correct
* new_kv works more or less
* test local
* work on paged kv attention
* prefill paged attention
* fix has_batch_idx and skip local and rotatary emb
* save
* save
* save
* save
* handle new_kv when paged kv cache
* all except has_batch_idx works
* major options are green
* test all
* add tests
* save
* clean up
* minor clean up
* simplest config
* save debug true
* save
* refactor slightly
* save work
* need key masking
* force hip
* use is_hip
* save
* fix cache_seq_len issue
* work on new_kv
* pass new_kv data
* save
* benchmark fwd only
* disable debug
* pandas pdf
* save
* set methods
* record number of heads
* use configs
* flexiable dim, n-heads, headofdim
* better benchmarking
* basic inplace update working
* works upto 64
* new_kv supported!
* test case for has_batch_idx
* has_batch_idx works!
* save
* save
* save
* save ref
* fix mqa and gqa by duplicating
* GQA and MQA working by kernel modifications
* fix new_kv with gqa
* cache index
* deal with nans on fwd_splitk
* save
* causal working on basic case
* causal works!
* alibi works!
* clean up
* clean prefill changes
* remove bwd stuff
* limit decode test to test_op_fwd
* add ref
* use bfloat
Fixes after rebase rebase fixes deal with kvcache failure new run for branch cancel-in-progress fix varlen_fwd bug
* Clean Clean This is a combination of 4 commits. clean 1 clean 2 clean more match main typo fix
* use is_hip()
* clean up more
* skip odd d only
* fix bug
* skip randomly
* use Flag
* update readme
* remove quantization
* remove bwd
* minor
* print
* remove verbose print
* qunatize zero's out the d stride
…n(torch.autograd.Function)
- added a pyskip for an odd case of using mha_type: "gqa"
- changed batch_size=1 and nheads=1
micmelesse force-pushed the main_perf branch 2 times, most recently from 5d03d58 to 730d260 on October 28, 2024 at 19:31.
Motivation
Original Paper: RoFormer: Enhanced Transformer with Rotary Position Embedding
Rotary Positional Embeddings (RoPEs) are a common positional embedding type used in many transformer models today.
RoPEs work by applying a unique rotation transformation to the vectors that represent each token within our q and k tensors, based on each token's position $$m$$ in the sequence.
To compute attention, we must first compute $$\text{matmul}(Q, K^T)$$, which effectively takes the dot product between the vector embeddings of tokens in $$Q$$ and $$K$$. Given two tokens at positions $$i$$ and $$j$$: the closer $$i$$ and $$j$$ are to each other, the more similarly their embedding vectors are rotated, so the dot product between the two vectors is largely unchanged. The further apart the two tokens are, the more the rotations applied to their embeddings diverge, which causes the dot product to decay. As the dot product decays, so does the attention weight between the two tokens, so the model effectively learns that, for a given token, nearby tokens should receive more attention than tokens much further away.
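As a minimal sketch of why this works (following the RoFormer formulation), consider a single 2-dimensional chunk of q and k rotated by position-dependent angles:

$$R(\theta) = \begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix}, \qquad q_m = R(m\theta)\,q, \qquad k_n = R(n\theta)\,k$$

$$q_m^{\top} k_n = q^{\top} R(m\theta)^{\top} R(n\theta)\, k = q^{\top} R\big((n - m)\theta\big)\, k$$

The score contributed by this chunk depends only on the relative offset $$n - m$$, which is what produces the distance-dependent behavior described above.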
A more detailed explanation
Fundamentally, RoPE works by dividing the embedding space of our q and k vectors (the $$\text{head\_dim}$$) into many chunks of two. Each 2-dimensional chunk can be thought of as a vector subcomponent of q and k projected onto a 2-dimensional plane that exists within the higher-dimensional space of the q and k embedding. RoPE "rotates" these planar chunks of q and k based on the position of the token in the sequence: each chunk is rotated by a unique angle $$\theta_{m, d/2}$$ that depends on the index $$m$$ of the token in the sequence and on the dimension $$d$$ of the subcomponents of q and k being rotated.
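For illustration, here is a minimal, non-fused reference of this pairwise rotation in PyTorch. The non-interleaved layout, the `base` value, and the helper name `rope_reference` are assumptions for the sketch, not code from this PR:

```python
import torch

def rope_reference(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    """Apply RoPE to x of shape (seqlen, nheads, head_dim), head_dim even.

    Non-interleaved layout: channel i in the first half of head_dim is
    paired with channel i in the second half, and each pair is rotated by
    an angle that depends on the token position m and the pair index.
    """
    seqlen, _, head_dim = x.shape
    half = head_dim // 2
    # theta_i = base^(-2i / head_dim); the angle for token m is m * theta_i
    inv_freq = 1.0 / (base ** (torch.arange(half, dtype=torch.float32) * 2 / head_dim))
    angles = torch.outer(torch.arange(seqlen, dtype=torch.float32), inv_freq)
    cos = angles.cos()[:, None, :]  # (seqlen, 1, half), broadcast over heads
    sin = angles.sin()[:, None, :]
    x1, x2 = x[..., :half].float(), x[..., half:].float()
    # 2-D rotation of each (x1, x2) pair
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1).to(x.dtype)
```

A dedicated kernel such as the one in `flash_attn.layers.rotary` implements the same math, but precomputes the `cos`/`sin` tables once and applies the rotation on the GPU.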
Implementation
RoPE is applied to Q and K at every attention layer. For developing a kernel there are two options:
1. Keep a separate RoPE kernel and call it before the FlashAttention kernel.
2. Fuse RoPE into the FlashAttention kernel.
Since Tri Dao already had a functional separate RoPE kernel, I implemented approach 1 first.
Separate RoPE and FlashAttention Kernels
We import `apply_rotary_emb` from `flash_attn.layers.rotary`. Within `class _attention(torch.autograd.Function)`, before calling `splitk_flash_attn`, we rotate `q` and `input_metadata.k_new` by calling `apply_rotary_emb`, which in turn launches a Triton kernel.
Fused RoPE into FlashAttention
TODO
More Notes
Can be found at the following issue: https://github.com/ROCm/triton-internal/issues/33