Attention projections (QKV, O) disaggregation #1436
Conversation
This PR has been implemented for
Force-pushed from 59f209b to d140915.
Rebased onto
Force-pushed from 6c4349d to 4acab6c.
Force-pushed from 4207293 to e75dbb6.
Great work! I left some comments. The code works locally on my machine, so I think we only need a bit of cleanup before we can merge. We also need to remove the unused functions from the attention files (.cu and .cpp) and the deprecated parameters (weight_ptr, bias_ptr). Once that is done, can you apply the disaggregation to the other models as well (OPT, MPT, Falcon, etc.)?
@@ -171,6 +188,23 @@ void LLAMA::create_llama_model(FFModel &ff,
  }
}

Tensor mha_input = mha;
Can we just reuse the same `mha` tensor as the input of the output projection?
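To illustrate the suggestion, here is a minimal, hedged sketch; the `FFModel::dense` call shape and the `llama_config.hidden_size` field name are assumptions for illustration, not the PR's exact code:

```cpp
// Sketch only: instead of aliasing the attention output first,
//   Tensor mha_input = mha;
//   Tensor o_proj = ff.dense(mha_input, hidden_size, AC_MODE_NONE, false);
// the attention output tensor can be passed directly to the output projection.
Tensor o_proj = ff.dense(mha /* attention output */,
                         llama_config.hidden_size, // output dim (field name assumed)
                         AC_MODE_NONE,             // no activation
                         false /* no bias */);
```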
Force-pushed from 7a8d200 to 104ba3c.
Commented out some alignment tests, but they should be equivalent to the original tests.
Force-pushed from 1bc1c1e to e0ee241.
Description of changes:
This PR moves the QKV projection (and the output projection) from the attention operator into separate dense layers, in order to support LoRA on the QKV projection (and the output projection).
It also adds support for the LLaMA 3, LLaMA 3.1, and LLaMA 3.2 models.
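In other words, the attention operator no longer owns the projection weights; the projections become ordinary dense layers that LoRA adapters can attach to. Below is a hedged sketch of the resulting layer structure in the style of create_llama_model; the operator names, argument lists, and config field names (hidden_size, num_attention_heads) are illustrative assumptions, not the exact code in this PR:

```cpp
// Fused QKV projection as a standalone dense layer (no bias, as in LLaMA).
Tensor qkv = ff.dense(attn_norm, 3 * llama_config.hidden_size,
                      AC_MODE_NONE, false /*bias*/);

// The attention operator now consumes the pre-projected QKV tensor and holds
// no projection weights (call shape assumed for illustration).
Tensor mha = ff.inc_multihead_self_attention(qkv,
                                             llama_config.hidden_size,
                                             llama_config.num_attention_heads);

// Output projection as a separate dense layer; per the review comment above,
// the `mha` tensor is used directly as its input.
Tensor out = ff.dense(mha, llama_config.hidden_size,
                      AC_MODE_NONE, false /*bias*/);
```

With this split, LoRA can target the QKV and output-projection dense layers the same way it targets any other dense layer in the model.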
Related Issues:
Linked Issues:
Issues closed by this PR: