
int4 not faster than fp16 and fp8 #2487

Open
ShuaiShao93 opened this issue Nov 22, 2024 · 10 comments
Labels
Performance (Issue about performance number), triaged (Issue has been triaged by maintainers)

Comments

@ShuaiShao93

System Info

x86_64, Debian 11, L4 GPU

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

  1. Build and benchmark bf16
python3 TensorRT-LLM/examples/llama/convert_checkpoint.py --model_dir ./llama-3.1-8b --output_dir ./tllm_checkpoint_1gpu_bf16 --dtype bfloat16

trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_bf16 --output_dir ./tmp/llama/8B/trt_engines/bf16/1-gpu  --gpt_attention_plugin auto  --gemm_plugin auto  --max_num_tokens 16384 --max_batch_size 8 --logits_dtype=float32

python3 TensorRT-LLM/examples/run.py --engine_dir=./tmp/llama/8B/trt_engines/bf16/1-gpu --max_output_len 1 --run_profiling --tokenizer_dir ./llama-3.1-8b --max_input_length=100000 --input_file batched_input.txt
  2. Build and benchmark fp8
python TensorRT-LLM/examples/quantization/quantize.py --model_dir ./llama-3.1-8b --dtype bfloat16 --qformat fp8 --kv_cache_dtype fp8 --output_dir ./tllm_checkpoint_1gpu_fp8 --calib_size 512

trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_fp8 --output_dir ./tmp/llama/8B/trt_engines/fp8/1-gpu  --gpt_attention_plugin auto  --gemm_plugin auto  --max_num_tokens 16384 --max_batch_size 8 --logits_dtype=float32

python3 TensorRT-LLM/examples/run.py --engine_dir=./tmp/llama/8B/trt_engines/fp8/1-gpu --max_output_len 1 --run_profiling --tokenizer_dir ./llama-3.1-8b --max_input_length=100000 --input_file batched_input.txt
  3. Build and benchmark int4
python TensorRT-LLM/examples/quantization/quantize.py --model_dir ./llama-3.1-8b --dtype float16 --qformat int4_awq  --awq_block_size 128  --output_dir ./tllm_checkpoint_1gpu_int4_awq   --calib_size 32

trtllm-build --checkpoint_dir ./tllm_checkpoint_1gpu_int4_awq --output_dir ./tmp/llama/8B/trt_engines/int4_awq/1-gpu  --gpt_attention_plugin auto  --gemm_plugin auto  --max_num_tokens 16384 --max_batch_size 8 --logits_dtype=float32

python3 TensorRT-LLM/examples/run.py --engine_dir=./tmp/llama/8B/trt_engines/int4_awq/1-gpu --max_output_len 1 --run_profiling --tokenizer_dir ./llama-3.1-8b --max_input_length=100000 --input_file batched_input.txt

Expected behavior

int4 should be faster than fp8, and fp8 faster than bf16.

Actual behavior

bf16

batch_size: 2, avg latency of 10 iterations: : 1.1307756423950195 sec

fp8

batch_size: 2, avg latency of 10 iterations: : 1.2754581451416016 sec

int4

batch_size: 2, avg latency of 10 iterations: : 1.2559791564941407 sec

Additional notes

The batched_input.txt file contains 2 inputs of 2k tokens each.

@ShuaiShao93 ShuaiShao93 added the bug label Nov 22, 2024
@ShuaiShao93 ShuaiShao93 changed the title from "int4 2x slower than fp16 and fp8" to "int4 2x not faster than fp16 and fp8" Nov 22, 2024
@ShuaiShao93 ShuaiShao93 changed the title from "int4 2x not faster than fp16 and fp8" to "int4 not faster than fp16 and fp8" Nov 22, 2024
@hello-11
Collaborator

@ShuaiShao93 Which version of TrtLLM are you using? Could you try the latest version?

@hello-11 hello-11 added the Performance label and removed the bug label Nov 25, 2024
@ShuaiShao93
Author

@ShuaiShao93 Which version of TrtLLM are you using? Could you try the latest version?

Yes, it's the latest, 0.14.0.

@aikitoria

aikitoria commented Nov 27, 2024

Why do you only generate a single output token? With that, performance will be dominated by prefill compute, which is roughly the same for each quant mode, no? int4 might even be slower here due to the additional dequant work?
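A rough roofline sketch of this point (my own illustration, not from the thread; the L4 peak numbers are assumptions, roughly 120 dense bf16 tensor TFLOPS and 300 GB/s of memory bandwidth): a 2k-token prefill is dominated by matmul compute, which weight-only int4 does not reduce since the math still runs in bf16, while 1-token decode is dominated by reading the weights, which is exactly the term int4/fp8 shrink.

# Back-of-the-envelope roofline; the peak numbers below are assumptions, not measurements.
PARAMS = 8e9                                     # Llama-3.1-8B weight count
BYTES_PER_WEIGHT = {"bf16": 2.0, "fp8": 1.0, "int4": 0.5}
PEAK_TFLOPS = 120                                # assumed L4 dense bf16 tensor peak
PEAK_BW_GBS = 300                                # assumed L4 memory bandwidth

for phase, tokens in [("prefill 2048 tok", 2048), ("decode 1 tok", 1)]:
    # ~2 FLOPs per weight per token for the linear layers (attention ignored);
    # compute is modeled at the bf16 rate for all formats: weight-only int4 really
    # does compute in bf16, and fp8 at best halves this term on sm89.
    compute_ms = 2 * PARAMS * tokens / (PEAK_TFLOPS * 1e12) * 1e3
    for fmt, nbytes in BYTES_PER_WEIGHT.items():
        weight_read_ms = PARAMS * nbytes / (PEAK_BW_GBS * 1e9) * 1e3
        bound = "compute" if compute_ms > weight_read_ms else "weight I/O"
        print(f"{phase:16s} {fmt:5s} compute~{compute_ms:6.1f} ms"
              f"  weight read~{weight_read_ms:5.1f} ms  -> {bound}-bound")

Under these assumptions the compute term dominates the 2k-token prefill by a wide margin and is the same for bf16 and weight-only int4, so near-identical end-to-end latencies are what you would expect; only the 13-53 ms weight-read term shrinks with quantization.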

@ShuaiShao93
Author

Why do you only generate a single output token?

We use the LLM as a judge, so we only need it to answer yes or no. I think this is a common use case today, and it's probably worth dedicated optimization.

With that, performance will be dominated by prefill compute, which is roughly the same for each quant mode, no? int4 might even be slower here due to the additional dequant work?

Why isn't int4 faster in prefill compute? Both I/O and compute should still be faster than fp16, I'd guess?

@aikitoria

aikitoria commented Nov 27, 2024

The int4 weights will first be dequantized to fp16 to run actual computations. Only fp8 is capable of avoiding that step, with a specific config.
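A conceptual sketch of that weight-only path (illustrative only, with made-up shapes and symmetric per-group scales, computed in float32 so it runs anywhere; it is not TensorRT-LLM's actual fused W4A16 kernel):

import torch

# Weight-only quantization: the weights are stored small, but they are expanded
# back to full precision before the GEMM, so the matmul cost itself is unchanged.
OUT_F, IN_F, GROUP = 4096, 4096, 128

q = torch.randint(-8, 8, (OUT_F, IN_F), dtype=torch.int8)    # stand-in for packed int4
scales = torch.rand(OUT_F, IN_F // GROUP) * 0.01             # one scale per 128-col group

def dequantize(q, scales, group=GROUP):
    w = q.float().view(q.shape[0], -1, group)                # regroup columns
    return (w * scales.unsqueeze(-1)).view(q.shape[0], -1)   # scale each group

x = torch.randn(1, IN_F)
w = dequantize(q, scales)   # extra work a plain bf16 engine never does
y = x @ w.t()               # the matmul still runs at full/half precision

During decode each weight is read once per generated token, so storing it in 4 bits cuts the dominant memory traffic; during a long prefill the matmul dwarfs the weight read, so the extra dequantization step can only break even or lose.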

@ShuaiShao93
Author

The int4 weights will first be dequantized to fp16 to run actual computations.

Doesn't the L4 GPU have int4 tensor cores? Why do we have to dequant first?

Only fp8 is capable of avoiding that step

fp8 is not faster than fp16 either

@aikitoria

Why do we have to dequant first

Doing the math in int4 would not be accurate enough to produce useful results. They are trying something like it with fp4 on Blackwell, but it seems to have questionable quality there as well.

@ShuaiShao93
Author

ShuaiShao93 commented Nov 27, 2024

Why do we have to dequant first

Doing the math in int4 would not be accurate enough to produce useful results. They are trying something like it with fp4 on Blackwell, but it seems to have questionable quality there as well.

I see, thanks for the explanation! Does this mean fp8 is likely to be the fastest option for now?

If so, why is it still slower than fp16 in the OP?

@hello-11 hello-11 added the triaged label Dec 10, 2024
@Tracin
Collaborator

Tracin commented Dec 10, 2024

@aikitoria is right. Weight-only quantization benefits the generation stage; the prefill-stage comparison is unfair. Please benchmark with gptManagerBenchmark.
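For example (an illustrative invocation only, reusing the benchmark.py flags that appear in the next comment, with a generation-heavy shape instead of "1000,1"):

python TensorRT-LLM/benchmarks/python/benchmark.py --dtype bfloat16 -m dec --engine_dir ./tmp/llama/8B/trt_engines/int4_awq/1-gpu/ --batch_size "1" --input_output_len "128,1024"

Generating ~1024 tokens puts most of the time in the generation stage, where the int4 engine's smaller weight reads should actually show up.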

@ShuaiShao93
Author

I used TensorRT-LLM/benchmarks/python/benchmark.py; fp8 is actually the slowest.

bf16

$  python TensorRT-LLM/benchmarks/python/benchmark.py --dtype bfloat16 -m dec --engine_dir ./tmp/llama/8B/trt_engines/bf16/1-gpu/  --batch_size "1" --input_output_len "1000,1"
[TensorRT-LLM] TensorRT-LLM version: 0.15.0
Allocated 1728.01 MiB for execution context memory.
/opt/conda/lib/python3.10/site-packages/torch/nested/__init__.py:226: UserWarning: The PyTorch API of nested tensors is in prototype stage and will change in the near future. (Triggered internally at ../aten/src/ATen/NestedTensorImpl.cpp:178.)
  return _nested.nested_tensor(
[TensorRT-LLM] TensorRT-LLM version: 0.15.0
[BENCHMARK] engine_dir 1-gpu world_size 1 num_heads 32 num_kv_heads 8 num_layers 32 hidden_size 4096 vocab_size 128256 precision bfloat16 batch_size 1 gpu_weights_percent 1.0 input_length 1000 output_length 1 gpu_peak_mem(gb) 17.16 build_time(s) None tokens_per_sec 3.51 percentile95(ms) 296.352 percentile99(ms) 298.563 latency(ms) 285.204 compute_cap sm89 quantization QuantMode.0 generation_time(ms) 0.047 total_generated_tokens 0.0 generation_tokens_per_second 0.0

fp8

$  python TensorRT-LLM/benchmarks/python/benchmark.py --dtype bfloat16 -m dec --engine_dir ./tmp/llama/8B/trt_engines/fp8/1-gpu/  --batch_size "1" --input_output_len "1000,1"
[TensorRT-LLM] TensorRT-LLM version: 0.15.0
Allocated 3040.69 MiB for execution context memory.
/opt/conda/lib/python3.10/site-packages/torch/nested/__init__.py:226: UserWarning: The PyTorch API of nested tensors is in prototype stage and will change in the near future. (Triggered internally at ../aten/src/ATen/NestedTensorImpl.cpp:178.)
  return _nested.nested_tensor(
[TensorRT-LLM] TensorRT-LLM version: 0.15.0
[BENCHMARK] engine_dir 1-gpu world_size 1 num_heads 32 num_kv_heads 8 num_layers 32 hidden_size 4096 vocab_size 128256 precision bfloat16 batch_size 1 gpu_weights_percent 1.0 input_length 1000 output_length 1 gpu_peak_mem(gb) 11.885 build_time(s) None tokens_per_sec 3.26 percentile95(ms) 319.658 percentile99(ms) 321.311 latency(ms) 306.941 compute_cap sm89 quantization QuantMode.FP8_QDQ|FP8_KV_CACHE generation_time(ms) 0.048 total_generated_tokens 0.0 generation_tokens_per_second 0.0

int4

$  python TensorRT-LLM/benchmarks/python/benchmark.py --dtype bfloat16 -m dec --engine_dir ./tmp/llama/8B/trt_engines/int4_awq/1-gpu/  --batch_size "1" --input_output_len "1000,1"
[TensorRT-LLM] TensorRT-LLM version: 0.15.0
Allocated 1601.76 MiB for execution context memory.
/opt/conda/lib/python3.10/site-packages/torch/nested/__init__.py:226: UserWarning: The PyTorch API of nested tensors is in prototype stage and will change in the near future. (Triggered internally at ../aten/src/ATen/NestedTensorImpl.cpp:178.)
  return _nested.nested_tensor(
[TensorRT-LLM] TensorRT-LLM version: 0.15.0
[BENCHMARK] engine_dir 1-gpu world_size 1 num_heads 32 num_kv_heads 8 num_layers 32 hidden_size 4096 vocab_size 128256 precision bfloat16 batch_size 1 gpu_weights_percent 1.0 input_length 1000 output_length 1 gpu_peak_mem(gb) 7.387 build_time(s) None tokens_per_sec 3.53 percentile95(ms) 287.507 percentile99(ms) 288.753 latency(ms) 283.471 compute_cap sm89 quantization QuantMode.PER_GROUP|INT4_WEIGHTS generation_time(ms) 0.047 total_generated_tokens 0.0 generation_tokens_per_second 0.0
