Llama: Merge query/key/value projection layers #498

Open · wants to merge 8 commits into main
Conversation

@mryab (Member) commented on Sep 2, 2023

This PR improves inference throughput by ~7% (measured on a single A100-80GB) by merging the query/key/value projections into a single large matrix multiplication. This reduces the overhead of launching several matmul kernels, which turns out to be substantial for single-sequence, single-token inference steps. The PR also adds a --throughput dry_run option that estimates throughput without starting a server.
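
The core of the change, as a minimal PyTorch sketch (hypothetical class and method names, not the actual code from this PR; it also assumes the query, key, and value projections share one output dimension, whereas Llama variants with grouped-query attention use narrower key/value projections):

```python
import torch
import torch.nn as nn

class FusedQKV(nn.Module):
    """Illustrative fused query/key/value projection (hypothetical)."""

    def __init__(self, hidden_size: int):
        super().__init__()
        # One (3 * hidden, hidden) weight replaces three (hidden, hidden)
        # weights, so a single matmul kernel is launched instead of three.
        self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size, bias=False)

    @classmethod
    def from_separate(cls, q_proj: nn.Linear, k_proj: nn.Linear, v_proj: nn.Linear) -> "FusedQKV":
        # Stack the pretrained weights along the output dimension;
        # nn.Linear stores its weight as (out_features, in_features).
        fused = cls(q_proj.in_features)
        with torch.no_grad():
            fused.qkv_proj.weight.copy_(
                torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0)
            )
        return fused

    def forward(self, hidden_states: torch.Tensor):
        # One big matmul, then split the result back into q/k/v.
        query, key, value = self.qkv_proj(hidden_states).chunk(3, dim=-1)
        return query, key, value
```

Since concatenation preserves row order, the fused layer is numerically equivalent to the three separate projections; the gain comes purely from reduced kernel-launch overhead, which dominates when each step processes a single token.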

Sample results from running experiments with and without the optimization (the command in each case is CUDA_VISIBLE_DEVICES=0 python -m petals.cli.run_server petals-team/StableBeluga2 --throughput dry_run):

Current code (branch https://github.com/bigscience-workshop/petals/tree/no_qkv_merge):

Sep 03 00:50:35.135 [INFO] Inference throughput: 532.7 tokens/sec per block (1 tokens/batch, NVIDIA A100-SXM-80GB GPU, bfloat16, quantized to nf4)
Sep 03 00:50:47.722 [INFO] Forward pass throughput: 51749.0 tokens/sec per block (1024 tokens/batch, NVIDIA A100-SXM-80GB GPU, bfloat16, quantized to nf4)

Sep 03 00:52:07.524 [INFO] Inference throughput: 576.4 tokens/sec per block (1 tokens/batch, NVIDIA A100-SXM-80GB GPU, bfloat16, quantized to nf4)
Sep 03 00:52:20.919 [INFO] Forward pass throughput: 36552.9 tokens/sec per block (1024 tokens/batch, NVIDIA A100-SXM-80GB GPU, bfloat16, quantized to nf4)

Sep 03 00:53:54.616 [INFO] Inference throughput: 512.7 tokens/sec per block (1 tokens/batch, NVIDIA A100-SXM-80GB GPU, bfloat16, quantized to nf4)
Sep 03 00:54:14.464 [INFO] Forward pass throughput: 50242.5 tokens/sec per block (1024 tokens/batch, NVIDIA A100-SXM-80GB GPU, bfloat16, quantized to nf4)

Code from this PR:

Sep 03 00:55:25.680 [INFO] Inference throughput: 564.7 tokens/sec per block (1 tokens/batch, NVIDIA A100-SXM-80GB GPU, bfloat16, quantized to nf4)
Sep 03 00:55:38.648 [INFO] Forward pass throughput: 33023.0 tokens/sec per block (1024 tokens/batch, NVIDIA A100-SXM-80GB GPU, bfloat16, quantized to nf4)

Sep 03 00:56:45.526 [INFO] Inference throughput: 578.4 tokens/sec per block (1 tokens/batch, NVIDIA A100-SXM-80GB GPU, bfloat16, quantized to nf4)
Sep 03 00:56:59.632 [INFO] Forward pass throughput: 54655.0 tokens/sec per block (1024 tokens/batch, NVIDIA A100-SXM-80GB GPU, bfloat16, quantized to nf4)

Sep 03 00:58:18.783 [INFO] Inference throughput: 593.1 tokens/sec per block (1 tokens/batch, NVIDIA A100-SXM-80GB GPU, bfloat16, quantized to nf4)
Sep 03 00:58:33.015 [INFO] Forward pass throughput: 36200.4 tokens/sec per block (1024 tokens/batch, NVIDIA A100-SXM-80GB GPU, bfloat16, quantized to nf4)
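
Averaged over the three runs above, inference throughput goes from ~540.6 to ~578.7 tokens/sec per block, i.e. the ~7% improvement quoted above; forward-pass throughput varies widely between runs in both configurations, so those numbers are harder to compare.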

@mryab mryab marked this pull request as ready for review September 2, 2023 22:06
@borzunov borzunov changed the title Merge query/key/value projection layers Llama: Merge query/key/value projection layers Sep 4, 2023