
Advanced Quantization Algorithm for LLMs. This is the official implementation of "Optimize Weight Rounding via Signed Gradient Descent for the Quantization of LLMs".

attafosu/auto-round

 
 

AutoRound

Advanced Quantization Algorithm for LLMs

AutoRound is an advanced quantization algorithm for low-bit LLM inference, tailored for a wide range of models. It uses sign gradient descent to fine-tune the rounding values and min-max values of weights in just 200 steps, and it competes strongly with recent methods without introducing any additional inference overhead. The image below presents an overview of AutoRound. Check out our updated paper on arXiv.
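To make the idea of tuning rounding values with sign gradient descent more concrete, here is a minimal NumPy sketch on a single toy weight matrix. It is not the AutoRound implementation: the function names (`qdq`, `sign_sgd_rounding`), the per-row asymmetric scale, the constant learning rate, and the straight-through gradient approximation are all assumptions made for illustration, and min-max tuning is left out.

```python
import numpy as np


def qdq(W, V, bits=4):
    """Quantize-dequantize W row-wise (asymmetric) with a rounding offset V."""
    qmax = 2 ** bits - 1
    w_min = W.min(axis=1, keepdims=True)
    w_max = W.max(axis=1, keepdims=True)
    scale = (w_max - w_min) / qmax          # assumes no row of W is constant
    zero_point = np.round(-w_min / scale)
    q = np.clip(np.round(W / scale + V) + zero_point, 0, qmax)
    return scale * (q - zero_point), scale


def sign_sgd_rounding(W, X, bits=4, iters=200):
    """Tune V in [-0.5, 0.5] to reduce the output error of the layer W @ X."""
    V = np.zeros_like(W)
    lr = 1.0 / iters                        # mirrors the documented default lr
    best_V, best_loss = V.copy(), np.inf
    for _ in range(iters):
        Wq, scale = qdq(W, V, bits)
        err = (Wq - W) @ X
        loss = float((err ** 2).mean())
        if loss < best_loss:                # keep the best offsets seen so far
            best_loss, best_V = loss, V.copy()
        # Straight-through estimate: treat d(Wq)/dV as `scale`.
        # The gradient is correct up to a positive constant, which sign() ignores.
        grad = (err @ X.T) * scale
        V = np.clip(V - lr * np.sign(grad), -0.5, 0.5)
    return best_V, best_loss


rng = np.random.default_rng(0)
W = rng.normal(size=(8, 32))                # toy weights (out_features, in_features)
X = rng.normal(size=(32, 64))               # toy calibration activations
rtn_loss = float((((qdq(W, np.zeros_like(W))[0] - W) @ X) ** 2).mean())
V, tuned_loss = sign_sgd_rounding(W, X)
print(f"round-to-nearest loss: {rtn_loss:.6f}, tuned loss: {tuned_loss:.6f}")
```

Because the loop records the best offsets it has seen, starting from plain round-to-nearest (`V = 0`), the returned loss can never be worse than the round-to-nearest baseline on the calibration data.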

What's New

  • [2024/07] Important change: the default value of nsamples has been changed from 512 to 128 to reduce memory usage, which may cause a slight accuracy drop in some scenarios
  • [2024/06] AutoRound format supports mixed bit-widths and group sizes for inference, resolving the significant performance drop issue with the asymmetric kernel
  • [2024/05] AutoRound supports lm-head quantization, saving 0.7G for LLaMA3-8B at W4G128.
  • [2024/05] AutoRound performs well in low_bit_open_llm_leaderboard
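As a rough sanity check on the lm-head figure above, the following back-of-the-envelope calculation assumes that the LLaMA3-8B lm-head is a 128256 x 4096 matrix stored in FP16, and that each group of 128 weights carries one FP16 scale and one 4-bit zero point after quantization. The per-group overhead is an assumption for illustration, not a description of AutoRound's actual packing format.

```python
# Assumed shapes for the LLaMA3-8B lm-head: vocabulary size x hidden size.
vocab_size, hidden_size = 128256, 4096
bits, group_size = 4, 128                   # W4G128

n_weights = vocab_size * hidden_size
fp16_bytes = n_weights * 2                  # 2 bytes per FP16 weight

n_groups = n_weights // group_size
# Packed weights plus, per group, one FP16 scale (2 bytes) and one 4-bit zero point.
packed_bytes = n_weights * bits / 8 + n_groups * (2 + bits / 8)

saved_gib = (fp16_bytes - packed_bytes) / 2 ** 30
print(f"estimated saving: {saved_gib:.2f} GiB")
```

Under these assumptions the estimate lands close to the 0.7G quoted above.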

Prerequisites

  • Python 3.9 or higher

Installation

Build from Source

pip install -vvv --no-build-isolation -e .
or
pip install -r requirements.txt
python setup.py install

Install from pypi

pip install auto-round

Model quantization

Gaudi2/ CPU/ GPU

We found a significant accuracy discrepancy with the qdq model when using the AutoGPTQ GPU backend with asymmetric quantization in some scenarios, especially at lower bit-widths such as 2. Please save the quantized model in AutoRound format to avoid this issue.

from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "facebook/opt-125m"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

from auto_round import AutoRound

bits, group_size, sym = 4, 128, False
autoround = AutoRound(model, tokenizer, bits=bits, group_size=group_size, sym=sym)
autoround.quantize()
output_dir = "./tmp_autoround"
autoround.save_quantized(output_dir)  # or: autoround.save_quantized(output_dir, format="auto_round")
Detailed Hyperparameters
  • model: The PyTorch model to be quantized.

  • tokenizer: An optional tokenizer for processing input data. If none, a dataset must be provided.

  • bits (int): Number of bits for quantization (default is 4).

  • group_size (int): Size of the quantization group (default is 128).

  • sym (bool): Whether to use symmetric quantization (default is False).

  • enable_quanted_input (bool): Whether to use the output of the previous quantized block as the input for the current block for tuning (default is True).

  • enable_minmax_tuning (bool): Whether to enable weight min-max tuning (default is True).

  • iters (int): Number of tuning iterations (default is 200).

  • lr (float): The learning rate for rounding values (default is None, in which case it is set to 1.0/iters automatically).

  • minmax_lr (float): The learning rate for min-max tuning (default is None, in which case it is set to lr automatically).

  • nsamples (int): Number of samples for tuning (default is 128).

  • seqlen (int): Data length of the sequence for tuning (default is 2048).

  • batch_size (int): Batch size for training (default is 8).

  • scale_dtype (str): The data type of the quantization scale (default is "float16"). Different kernels support different choices.

  • amp (bool): Whether to use automatic mixed precision (default is True).

  • nblocks (int): Packing several blocks as one for tuning together (default is 1).

  • gradient_accumulate_steps (int): Number of gradient accumulation steps (default is 1).

  • low_gpu_mem_usage (bool): Whether to save GPU memory at the cost of ~20% more tuning time (default is False).

  • dataset (Union[str, list, tuple, torch.utils.data.DataLoader]): The dataset name for tuning (default is "NeelNanda/pile-10k"). Local JSON files and combinations of datasets are supported, e.g. "./tmp.json,NeelNanda/pile-10k:train, mbpp:train+validation+test".

  • layer_config (dict): Configuration for weight quantization (default is an empty dictionary), mainly for mixed bits or mixed precision.

  • device: The device to be used for tuning. The default is set to 'auto', allowing for automatic detection.
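The combined dataset string shown for the dataset parameter packs several sources into one value: entries are separated by commas, an optional ":" introduces the split, and "+" joins several splits. The sketch below shows one way such a string could be read. `parse_dataset_spec` is a hypothetical helper written for this illustration and is not part of the auto_round API; it also assumes names and paths contain no ":" or ",".

```python
def parse_dataset_spec(spec):
    """Split 'name[:split+split],name,...' into (name, [splits]) pairs."""
    parsed = []
    for item in spec.split(","):
        item = item.strip()                 # tolerate spaces after commas
        if not item:
            continue
        name, sep, splits = item.partition(":")
        # No ':' means no explicit split was requested for this entry.
        parsed.append((name, splits.split("+") if sep else []))
    return parsed


spec = "./tmp.json,NeelNanda/pile-10k:train, mbpp:train+validation+test"
for name, splits in parse_dataset_spec(spec):
    print(name, splits)
```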

Tips

1. Consider increasing 'iters' (e.g. 1000) to achieve better results, albeit with increased tuning time.

2. Consider increasing 'nsamples' (e.g. 512) to achieve better results, albeit with more memory (~20G).

3. Setting 'minmax_lr' to 2.0/iters has been observed to occasionally yield improved results.
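The tips above interact with the documented defaults for lr and minmax_lr, so it can help to see the resolved values side by side. The helper below, `tuning_kwargs`, is a hypothetical function written for this illustration. It only mirrors the defaults described in the hyperparameter list (lr falls back to 1.0/iters, minmax_lr falls back to lr) and does not call into auto_round.

```python
def tuning_kwargs(iters=200, nsamples=128, lr=None, minmax_lr=None):
    """Resolve the documented defaults into explicit keyword arguments."""
    lr = 1.0 / iters if lr is None else lr
    minmax_lr = lr if minmax_lr is None else minmax_lr
    return {"iters": iters, "nsamples": nsamples, "lr": lr, "minmax_lr": minmax_lr}


default = tuning_kwargs()

# Apply all three tips: more iterations, more samples, a larger min-max learning rate.
iters = 1000
heavier = tuning_kwargs(iters=iters, nsamples=512, minmax_lr=2.0 / iters)

print(default)
print(heavier)
# These could then be passed on to the constructor shown earlier, e.g.
# AutoRound(model, tokenizer, bits=4, group_size=128, sym=False, **heavier)
```

Note that raising iters also lowers the default lr, since it is derived as 1.0/iters.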

Model inference

Please run the quantization code first.

CPU

# pip install intel-extension-for-transformers
from intel_extension_for_transformers.transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

quantized_model_path = "./tmp_autoround"
model = AutoModelForCausalLM.from_pretrained(quantized_model_path)
tokenizer = AutoTokenizer.from_pretrained(quantized_model_path)
text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]))

GPU

# pip install auto-gptq
from transformers import AutoModelForCausalLM, AutoTokenizer
# from auto_round.auto_quantizer import AutoHfQuantizer  # uncomment for models saved in auto_round format

quantized_model_path = "./tmp_autoround"
model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(quantized_model_path)
text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]))

Intel Gaudi-2

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round.auto_quantizer import AutoHfQuantizer
import habana_frameworks.torch.core as htcore
import habana_frameworks.torch.hpu as hthpu
quantized_model_path = "./tmp_autoround"
model = AutoModelForCausalLM.from_pretrained(quantized_model_path).to('hpu').to(torch.float32)
tokenizer = AutoTokenizer.from_pretrained(quantized_model_path)
text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]))

Support List

Model Supported
Intel/neural-chat-7b-v3-3 HF-int4-model, accuracy, recipe, example
Intel/neural-chat-7b-v3-1 HF-int4-model, accuracy, recipe, example
mistralai/Mistral-7B-v0.1 HF-int4-model-lmhead, HF-int4-model, accuracy, recipe, example
microsoft/phi-2 HF-int4-sym-model, accuracy, recipe, example
google/gemma-2b HF-int4-model, accuracy, recipe, example
tiiuae/falcon-7b HF-int4-model-G64, accuracy, recipe, example
mistralai/Mistral-7B-Instruct-v0.2 HF-int4-model (under review), accuracy, recipe, example
mistralai/Mixtral-8x7B-Instruct-v0.1 HF-int4-model (under review), accuracy, recipe, example
mistralai/Mixtral-8x7B-v0.1 HF-int4-model (under review), accuracy, recipe, example
meta-llama/Meta-Llama-3-8B-Instruct accuracy, recipe, example
google/gemma-7b accuracy, recipe, example
meta-llama/Llama-2-7b-chat-hf accuracy, recipe, example
Qwen/Qwen1.5-7B-Chat accuracy, sym recipe, asym recipe, example
baichuan-inc/Baichuan2-7B-Chat accuracy, recipe, example
01-ai/Yi-6B-Chat accuracy, recipe, example
facebook/opt-2.7b accuracy, recipe, example
bigscience/bloom-3b accuracy, recipe, example
EleutherAI/gpt-j-6b accuracy, recipe, example
Salesforce/codegen25-7b-multi example
huggyllama/llama-7b example
mosaicml/mpt-7b example
THUDM/chatglm3-6b example
MBZUAI/LaMini-GPT-124M example
EleutherAI/gpt-neo-125m example
databricks/dolly-v2-3b example
stabilityai/stablelm-base-alpha-3b example

Comparison with other methods

We provide a comprehensive comparison with other methods in our accuracy data section. In summary, our approach outperformed GPTQ in 30 of 32 comparisons, AWQ in 27 of 32, HQQ in 15 of 16, and OmniQuant in all 16, across llamav1/llamav2/mistral-7b on W4G-1, W4G128, W3G128, and W2G128, based on the average accuracies of 11 zero-shot tasks.

Reference

If you find SignRound useful for your research, please cite our paper:

@article{cheng2023optimize,
  title={Optimize Weight Rounding via Signed Gradient Descent for the Quantization of LLMs},
  author={Cheng, Wenhua and Zhang, Weiwei and Shen, Haihao and Cai, Yiyang and He, Xin and Lv, Kaokao and Liu, Yi},
  journal={arXiv preprint arXiv:2309.05516},
  year={2023}
}
