
Attention projections (QKV, O) disaggregation #1436

Merged: 28 commits into inference on Oct 9, 2024

Conversation

@yingchen21 (Collaborator) commented on Jul 10, 2024:

Description of changes:
This PR moves the QKV projection (and the output projection) out of the attention operator into separate dense layers, so that LoRA can be applied to the QKV and output projections.

It also adds support for the Llama 3, Llama 3.1, and Llama 3.2 models.
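For context, the wiring after this change looks roughly like the sketch below. This is an illustrative stub only (the Tensor type, dense, attention_core, attention_block, and the layer names are placeholders, not FlexFlow's actual API): the point is that the QKV and output projections become ordinary dense layers around a projection-free attention core, so LoRA can attach to them.

    // Illustrative sketch with stand-in types; not FlexFlow's real classes or signatures.
    #include <string>

    struct Tensor { int dim; };  // placeholder for a framework tensor handle

    // Placeholder dense layer; after this PR the QKV and output projections
    // are layers like this, so LoRA adapters can target them directly.
    Tensor dense(Tensor in, int out_dim, const std::string &name) {
      (void)in; (void)name;
      return Tensor{out_dim};
    }

    // Placeholder attention core that no longer owns any projection weights.
    Tensor attention_core(Tensor qkv, int num_heads) {
      (void)num_heads;
      return Tensor{qkv.dim / 3};
    }

    Tensor attention_block(Tensor x, int hidden_size, int num_heads) {
      Tensor qkv = dense(x, 3 * hidden_size, "attn_qkv_proj");  // potential LoRA target
      Tensor mha = attention_core(qkv, num_heads);              // fused attention kernel
      return dense(mha, hidden_size, "attn_o_proj");            // potential LoRA target
    }

    int main() {
      attention_block(Tensor{4096}, 4096, 32);
      return 0;
    }

In this layout the attention kernel consumes an already-projected QKV tensor and returns the pre-projection output, so the projections are visible to the rest of the stack (tensor parallelism, fusion, PEFT) as regular linear layers.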


@yingchen21 (Collaborator, Author) commented:

This PR has been implemented for IncMultiHeadSelfAttention, TreeIncMultiHeadSelfAttention, and SpecIncMultiHeadSelfAttention. The CUDA implementation is tested under TP=2 and TP=4, both with and without fusion.
The backward pass was tested on an earlier commit; due to an issue with the PEFT test script, the backward pass of the latest commit has not been tested yet.

@yingchen21 (Collaborator, Author) commented:

Rebased onto the peft branch and tested the forward pass.

@yingchen21 force-pushed the attn-qkv-proj branch 2 times, most recently from 6c4349d to 4acab6c on August 7, 2024.
@yingchen21 marked this pull request as ready for review on August 28, 2024.

@goliaro (Collaborator) left a review comment:

Great work! I left some comments. The code works locally on my machine, so I think we only need a bit of cleanup before we can merge. We also need to remove the unused functions from the attention files (.cu and .cpp) and the deprecated parameters (weight_ptr, bias_ptr). Once that is done, can you apply the disaggregation to the other models as well (OPT, MPT, Falcon, etc.)?

Resolved review threads:
  • inference/models/llama.cc (outdated)
  • inference/models/llama.cc
@@ -171,6 +188,23 @@ void LLAMA::create_llama_model(FFModel &ff,
}
}

Tensor mha_input = mha;
Reviewer comment (Collaborator):
Can we just reuse the same mha tensor as the input to the output projection?
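For illustration only, the suggestion amounts to dropping the intermediate alias; a minimal stub follows (hypothetical types, not the actual FlexFlow call):

    // Stand-in types for illustration; not FlexFlow's real API.
    struct Tensor { int dim; };
    Tensor dense(Tensor in, int out_dim) { (void)in; return Tensor{out_dim}; }

    Tensor output_projection(Tensor mha, int hidden_size) {
      // Before: Tensor mha_input = mha; return dense(mha_input, hidden_size);
      return dense(mha, hidden_size);  // after: feed mha directly into the output projection
    }

    int main() { output_projection(Tensor{4096}, 4096); return 0; }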

Other resolved review threads:
  • src/ops/inc_multihead_self_attention.cc (outdated)
  • src/ops/inc_multihead_self_attention.cc (outdated)
  • src/ops/inc_multihead_self_attention.cu (outdated)
  • src/ops/kernels/linear_kernels.cu (outdated)
  • src/ops/linear.cc (outdated)
  • src/runtime/request_manager.cc (outdated)
  • src/runtime/request_manager.cc (outdated)
@yingchen21 changed the base branch from peft to inference on September 11, 2024.
@goliaro mentioned this pull request on Sep 27, 2024.
@goliaro changed the title from "Attn qkv proj" to "Attention projections (QKV, O) disaggregation" on Oct 9, 2024.
@goliaro merged commit 96628b3 into inference on Oct 9, 2024. 39 checks passed.
@goliaro deleted the attn-qkv-proj branch on November 4, 2024.