Quantization-Aware Training (QAT)

Quantization-Aware Training (QAT) refers to applying fake quantization during training or fine-tuning, so that the final quantized model exhibits higher accuracy and lower perplexity than a model quantized with post-training quantization alone. Fake quantization refers to rounding float values to quantized values without actually casting them to dtypes with lower bit-widths, in contrast to post-training quantization (PTQ), which does cast the quantized values to lower bit-width dtypes, e.g.:

# PTQ: x_q is quantized and cast to int8
# scale and zero point (zp) refer to parameters used to quantize x_float
# qmin and qmax refer to the range of quantized values
x_q = (x_float / scale + zp).round().clamp(qmin, qmax).cast(int8)

# QAT: x_fq is still in float
# Fake quantize simulates the numerics of quantize + dequantize
x_fq = (x_float / scale + zp).round().clamp(qmin, qmax)
x_fq = (x_fq - zp) * scale
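
To make the difference concrete, here is a minimal standalone sketch in plain PyTorch (not a torchao API; the symmetric per-tensor int8 parameters below are chosen purely for illustration) showing that a fake quantized tensor stays in float but only takes on quantized values:

import torch

x_float = torch.randn(4, 8)

# assumed symmetric per-tensor int8 parameters, for illustration only
qmin, qmax = -128, 127
scale = x_float.abs().max() / qmax
zp = 0.0

# fake quantize = quantize + dequantize, but the result stays in float
x_fq = torch.clamp(torch.round(x_float / scale + zp), qmin, qmax)
x_fq = (x_fq - zp) * scale

print(x_fq.dtype)                    # still torch.float32, not torch.int8
print(x_fq.unique().numel())         # at most 256 distinct values
print((x_fq - x_float).abs().max())  # rounding error, roughly bounded by scale / 2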

QAT typically involves applying a transformation to your model before and after training. In torchao, these are represented as the prepare and convert steps: (1) prepare inserts fake quantize operations into linear layers, and (2) convert transforms the fake quantize operations to actual quantize and dequantize operations after training, thereby producing a quantized model (dequantize operations are typically fused with linear after lowering). Between these two steps, training can proceed exactly as before.

(Figure: the QAT prepare and convert flow)

torchao APIs

torchao currently supports two QAT APIs, one through the quantize_ API (recommended) and one through the Quantizer classes (legacy). The quantize_ API allows flexible configuration of quantization settings for both activations and weights, while the Quantizer classes each hardcode a specific quantization setting.

For example, to run QAT on a single GPU, first set up a model and training loop:

import torch
from torchtune.models.llama3 import llama3

# Set up smaller version of llama3 to fit in a single GPU
def get_model():
    return llama3(
        vocab_size=4096,
        num_layers=16,
        num_heads=16,
        num_kv_heads=4,
        embed_dim=2048,
        max_seq_len=2048,
    ).cuda()

# Example training loop
def train_loop(m: torch.nn.Module):
    optimizer = torch.optim.SGD(m.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-5)
    loss_fn = torch.nn.CrossEntropyLoss()
    for i in range(10):
        # dummy token ids and next-token targets, for illustration only
        example = torch.randint(0, 4096, (2, 16)).cuda()
        target = torch.randint(0, 4096, (2, 16)).cuda()
        output = m(example)  # (batch, seq_len, vocab_size) logits
        loss = loss_fn(output.reshape(-1, 4096), target.reshape(-1))
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

quantize_ API (recommended)

The recommended way to run QAT in torchao is through the quantize_ API:

  1. Prepare: specify how weights and/or activations should be fake quantized through FakeQuantizeConfig, and pass these configs to intx_quantization_aware_training
  2. Convert: after training, quantize the model using the standard post-training quantization (PTQ) functions, such as int8_dynamic_activation_int4_weight

For example:

from torchao.quantization import (
    quantize_,
    int8_dynamic_activation_int4_weight,
)
from torchao.quantization.qat import (
    FakeQuantizeConfig,
    from_intx_quantization_aware_training,
    intx_quantization_aware_training,
)
model = get_model()

# prepare: insert fake quantization ops
# swaps `torch.nn.Linear` with `FakeQuantizedLinear`
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=32)
quantize_(
    model,
    intx_quantization_aware_training(activation_config, weight_config),
)

# train
train_loop(model)

# convert: transform fake quantization ops into actual quantized ops
# swaps `FakeQuantizedLinear` back to `torch.nn.Linear` and inserts
# quantized activation and weight tensor subclasses
quantize_(model, from_intx_quantization_aware_training())
quantize_(model, int8_dynamic_activation_int4_weight(group_size=32))

# inference or generate
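# for example, a minimal eager-mode sanity check (a sketch with dummy token
# ids, reusing the toy vocab and sequence sizes from `get_model` above):
model.eval()
with torch.no_grad():
    tokens = torch.randint(0, 4096, (1, 16)).cuda()
    logits = model(tokens)  # runs with the quantized weights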

To fake quantize embedding layers in addition to linear layers, call the following with a filter function during the prepare step:

quantize_(
    model,
    intx_quantization_aware_training(weight_config=weight_config),
    filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
)
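
Putting this together with the linear prepare step above, a sketch that applies both prepare calls to the same model (reusing activation_config and weight_config from the earlier example) and then lists the swapped modules could look like:

model = get_model()

# fake quantize linear activations and weights
quantize_(
    model,
    intx_quantization_aware_training(activation_config, weight_config),
)

# additionally fake quantize embedding weights (weight-only)
quantize_(
    model,
    intx_quantization_aware_training(weight_config=weight_config),
    filter_fn=lambda m, _: isinstance(m, torch.nn.Embedding),
)

# optional sanity check: list modules whose class names indicate they were
# swapped during prepare (e.g. `FakeQuantizedLinear`)
for name, child in model.named_modules():
    if "FakeQuantized" in type(child).__name__:
        print(name, type(child).__name__)

train_loop(model)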

Quantizer API (legacy)

Alternatively, torchao provides a few hardcoded quantization settings through Quantizer classes, such as Int8DynActInt4WeightQATQuantizer (int8 dynamic per-token activations + int4 per-group weights) and Int4WeightOnlyEmbeddingQATQuantizer (int4 per-group embedding weights).

For example:

from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
qat_quantizer = Int8DynActInt4WeightQATQuantizer(group_size=32)
model = get_model()

# prepare: insert fake quantization ops
# swaps `torch.nn.Linear` with `Int8DynActInt4WeightQATLinear`
model = qat_quantizer.prepare(model)

# train
train_loop(model)

# convert: transform fake quantization ops into actual quantized ops
# swaps `Int8DynActInt4WeightQATLinear` with `Int8DynActInt4WeightLinear`
model = qat_quantizer.convert(model)

# inference or generate

To use multiple Quantizers in the same model for different layer types, users can also leverage the ComposableQATQuantizer as follows:

from torchao.quantization.qat import (
    ComposableQATQuantizer,
    Int4WeightOnlyEmbeddingQATQuantizer,
    Int8DynActInt4WeightQATQuantizer,
)

group_size = 32
model = get_model()
qat_quantizer = ComposableQATQuantizer([
    Int8DynActInt4WeightQATQuantizer(groupsize=group_size),
    Int4WeightOnlyEmbeddingQATQuantizer(group_size=group_size),
])

# prepare + train + convert as before
model = qat_quantizer.prepare(model)
train_loop(model)
model = qat_quantizer.convert(model)

torchtune integration

torchao QAT is integrated with torchtune to allow users to run quantization-aware fine-tuning as follows:

tune run --nproc_per_node 8 qat_distributed --config llama3/8B_qat_full

torchtune also supports a QAT + LoRA distributed training recipe that, in our early experiments, is 1.89x faster and uses 36.1% less memory than vanilla QAT:

tune run --nnodes 1 --nproc_per_node 4 qat_lora_finetune_distributed --config llama3/8B_qat_lora

For more detail, please refer to this QAT tutorial.

Evaluation Results

Evaluation was performed on 6-8 A100 GPUs (80GB each) using the torchtune QAT integration described above. We fine-tune Llama3-8B on the C4 dataset (en subset) for 5000 steps using a group size of 256 for the weights. Note that extensive hyperparameter tuning may further improve these results.

Results for int8 per token dynamic activations + int4 per group weights, using a learning rate of 2e-5:

|                 | hellaswag (acc) | hellaswag (acc_norm) | wikitext (word_perplexity) | wikitext (byte_perplexity) | wikitext (bits_per_byte) |
| --------------- | --------------- | -------------------- | -------------------------- | -------------------------- | ------------------------ |
| No quantization | 57.86%          | 76.60%               | 8.905                      | 1.505                      | 0.590                    |
| PTQ             | 51.74%          | 70.66%               | 11.878                     | 1.588                      | 0.668                    |
| QAT (quantized) | 57.25%          | 76.51%               | 9.859                      | 1.534                      | 0.617                    |
| PTQ degradation | -6.11%          | -5.94%               | +2.973                     | +0.083                     | +0.078                   |
| QAT degradation | -0.61%          | -0.21%               | +0.947                     | +0.029                     | +0.027                   |

Results for int4 per group weights, using a learning rate of 2e-6. For this quantization scheme, the quantized path uses the more efficient int4 tinygemm kernel.

|                 | hellaswag (acc) | hellaswag (acc_norm) | wikitext (word_perplexity) | wikitext (byte_perplexity) | wikitext (bits_per_byte) |
| --------------- | --------------- | -------------------- | -------------------------- | -------------------------- | ------------------------ |
| No quantization | 57.16%          | 77.02%               | 8.858                      | 1.504                      | 0.589                    |
| PTQ             | 55.06%          | 74.24%               | 10.311                     | 1.547                      | 0.630                    |
| QAT (quantized) | 55.86%          | 75.06%               | 10.134                     | 1.542                      | 0.625                    |
| PTQ degradation | -2.10%          | -2.78%               | +1.453                     | +0.043                     | +0.041                   |
| QAT degradation | -1.30%          | -1.96%               | +1.276                     | +0.038                     | +0.036                   |

For more details, please refer to this blog post.