Llama: Merge query/key/value projection layers #498
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This PR makes an ~7% optimization of the inference throughput (measured on a single A100-80GB) by merging the query/key/value projections into a single large matrix multiplication. This reduces the overhead of launching several matmul kernels, which turns out to be substantial for single-sequence single-token inference steps. Also, this code adds a
--throughput dry_run
option to estimate throughput without starting a server.Sample results from running experiments with and without the optimization (the command in each case is
CUDA_VISIBLE_DEVICES=0 python -m petals.cli.run_server petals-team/StableBeluga2 --throughput dry_run
):Current code (branch https://github.com/bigscience-workshop/petals/tree/no_qkv_merge):
Code from this PR: