Fix 3bit packing for auto-gptq format (#264)
wenhuach21 authored Sep 20, 2024
1 parent 82322ac commit 6ee91a9
Showing 7 changed files with 96 additions and 82 deletions.
37 changes: 21 additions & 16 deletions README.md
@@ -11,13 +11,13 @@ AutoRound
<div align="left">

AutoRound is an advanced quantization algorithm for low-bit LLM inference, tailored for a wide range of models.
AutoRound adopts sign gradient descent to fine-tune rounding values and minmax values of weights in just 200 steps,
which competes impressively against recent methods without introducing any additional inference overhead and keeps the
tuning cost low. The image below presents an overview of AutoRound. Check out our paper
on [arxiv](https://arxiv.org/pdf/2309.05516v4) for more details and
visit [low_bit_open_llm_leaderboard](https://huggingface.co/spaces/Intel/low_bit_open_llm_leaderboard) for more
accuracy data and recipes across various models.
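
To make the tuning idea concrete, here is a minimal, illustrative sketch of sign-gradient-descent rounding tuning. It is **not** AutoRound's implementation: the toy weight `W`, scale, offset `V`, step size, and the plain weight-MSE loss are simplifications, since AutoRound optimizes a block-wise output reconstruction loss and also tunes the min-max clipping values.

```python
import torch

# Illustrative only: tune a per-weight rounding offset V with signSGD.
torch.manual_seed(0)
W = torch.randn(128, 128)                    # pretrained weights of one layer (toy)
scale = W.abs().max() / 7                    # toy 4-bit symmetric scale, levels in [-8, 7]
V = torch.zeros_like(W, requires_grad=True)  # learnable rounding offset, kept in [-0.5, 0.5]
lr = 1.0 / 200                               # step size spread over the 200 tuning steps

def ste_round(x):
    # Straight-through estimator: round in the forward pass, identity gradient backward.
    return x + (x.round() - x).detach()

for _ in range(200):
    W_q = torch.clamp(ste_round(W / scale + V), -8, 7) * scale  # fake-quantized weights
    loss = torch.nn.functional.mse_loss(W_q, W)                 # stand-in for the real loss
    loss.backward()
    with torch.no_grad():
        V -= lr * V.grad.sign()  # signSGD: update with only the sign of the gradient
        V.clamp_(-0.5, 0.5)
        V.grad.zero_()
```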

<div align="center">

@@ -48,9 +48,8 @@ pip install -vvv --no-build-isolation -e .
pip install auto-round
```



## Model Quantization

### API Usage (Gaudi2/CPU/GPU)

```python
@@ -130,8 +129,9 @@ autoround.save_quantized(output_dir, format='auto_round', inplace=True)
</details>

### Basic Usage (version > 0.3.0)

A user guide detailing the full list of supported arguments is provided by calling ```auto_round -h``` on the terminal.
Alternatively, you can use ```auto-round``` instead of ```auto_round```.

```bash
auto_round --model facebook/opt-125m \
@@ -141,6 +141,7 @@ auto_round --model facebook/opt-125m \
--disable_eval \
--output_dir ./tmp_autoround
```

We provide two recipes: one for best accuracy and one for fast running speed with low memory. Details are given below.
<details>
<summary>Other Recipes</summary>
@@ -167,30 +168,35 @@ We provide two recipes for best accuracy and fast running speed with low memory.
--batch_size 4 \
--disable_eval
```

</details>

#### Formats

**AutoRound format**: This format is well-suited for CPU and HPU devices, 2-bit quantization, and mixed-precision
inference; [2,4] bits are supported. It resolves the asymmetric quantization kernel issues found in the AutoGPTQ format
and supports both LM-head quantization and mixed precision. However, it has not yet gained widespread community
adoption. For CUDA support, you will need to install from the source.

**AutoGPTQ Format**: This format is well-suited for symmetric quantization on CUDA devices and is widely adopted by the
community; [2,3,4,8] bits are supported. For 3-bit quantization, run `pip install auto-gptq` first. It also benefits
from the Marlin kernel, which can boost inference performance notably. However, **the asymmetric kernel has issues**
that can cause considerable accuracy drops, particularly for 2-bit quantization and small models.
Additionally, symmetric quantization tends to perform poorly at 2-bit precision.

**AutoAWQ format**: This format is well-suited for asymmetric 4-bit quantization on CUDA devices and is widely adopted
within the community; only 4-bit asymmetric quantization is supported. Asymmetric quantization typically improves
accuracy but may reduce inference speed. It features specialized layer fusion tailored for Llama models.
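
For orientation, below is a hedged sketch of exporting one quantized model to each of these formats with the `save_quantized` call shown earlier; the `auto_gptq`/`auto_awq` format strings and the `inplace=False` usage are assumptions and may differ across versions.

```python
# Sketch: exporting to the three formats discussed above (format strings are assumed).
# inplace=False keeps the in-memory model untouched so several exports can be made in a row.
autoround.save_quantized("./tmp_autoround", format="auto_round", inplace=False)  # CPU/HPU, 2/4-bit, mixed precision
autoround.save_quantized("./tmp_autogptq", format="auto_gptq", inplace=False)    # CUDA, symmetric recommended
autoround.save_quantized("./tmp_autoawq", format="auto_awq", inplace=False)      # CUDA, 4-bit asymmetric only
```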

## Model Inference


Please run the quantization code above first.

### AutoGPTQ/AutoAWQ format

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -203,18 +209,16 @@ inputs = tokenizer(text, return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]))
```


### AutoRound format

**CPU**: `pip install intel-extension-for-transformers`

**HPU**: docker image with Gaudi Software Stack is recommended. More details can be found
in [Gaudi Guide](https://docs.habana.ai/en/latest/).

**CUDA**: `git clone https://github.com/intel/auto-round.git && cd auto-round && pip install -vvv --no-build-isolation -e .`


#### CPU/HPU/CUDA on 0.3.0+

```python
@@ -223,7 +227,7 @@ from auto_round import AutoRoundConfig
device = "auto" ##cpu, hpu, cuda
quantization_config = AutoRoundConfig(
    backend=device
)
quantized_model_path = "./tmp_autoround"
model = AutoModelForCausalLM.from_pretrained(quantized_model_path,
@@ -248,6 +252,7 @@ text = "There is a girl who likes adventure,"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]))
```

<br>
<details>
<summary>Evaluation</summary>
@@ -259,8 +264,8 @@ auto_round --model saved_quantized_model \
--task lambada_openai \
--eval_bs 1
```

</details>

## Support List

21 changes: 12 additions & 9 deletions auto_round/__main__.py
@@ -27,6 +27,7 @@
from auto_round import AutoRoundConfig
from auto_round.eval.evaluation import simple_evaluate
from auto_round.utils import detect_device, get_library_version, detect_device_count
from auto_round.utils import logger


def setup_parser():
@@ -48,7 +49,7 @@ def setup_parser():
parser.add_argument("--batch_size", default=8, type=int,
help="train batch size")

parser.add_argument("--eval_bs", default=1, type=int,
parser.add_argument("--eval_bs", default=None, type=int,
help="eval batch size")

parser.add_argument("--device", default="auto", type=str,
@@ -164,7 +165,7 @@ def tune(args):
model_name = args.model
if model_name[-1] == "/":
model_name = model_name[:-1]
logger.info(f"start to quantize {model_name}")

device_str = detect_device(args.device)
torch_dtype = "auto"
@@ -231,8 +232,8 @@ def tune(args):

if hasattr(tokenizer, "model_max_length"):
if tokenizer.model_max_length < seqlen:
print(f"change sequence length to {tokenizer.model_max_length} due to the limitation of model_max_length",
flush=True)
logger.info(
f"change sequence length to {tokenizer.model_max_length} due to the limitation of model_max_length")
seqlen = min(seqlen, tokenizer.model_max_length)
args.seqlen = seqlen

@@ -248,7 +249,7 @@ def tune(args):
if isinstance(m, torch.nn.Linear) or isinstance(m, transformers.modeling_utils.Conv1D):
if m.weight.shape[0] % 32 != 0 or m.weight.shape[1] % 32 != 0:
layer_config[n] = {"bits": 32}
logger.info(
f"{n} will not be quantized due to its shape not being divisible by 32,"
" resulting in an exporting issue to autogptq")
fp_layers_list = args.fp_layers_list.split(",")
@@ -258,7 +259,7 @@ def tune(args):
name = n.split('.')[-1]
if n in fp_layers_list or name in fp_layers_list:
layer_config[n] = {"bits": 32}
logger.info(
f"{n} will not be quantized.")
lm_head_layer_name = "lm_head"
for n, _ in model.named_modules():
@@ -271,8 +272,8 @@ def tune(args):
for item in tied_keys:
if lm_head_layer_name in item: ##TODO extend to encoder-decoder layer, seq classification model
args.quant_lm_head = False
logger.warning(
    f"reset `quant_lm_head` to `False` as quantizing lm_head with tied weights has not been "
f"supported currently")
break
if args.quant_lm_head:
@@ -316,7 +317,7 @@ def tune(args):
tasks = tasks.split(',')

if not args.disable_eval:
print(f"Using the latest {lm_eval_version}")
logger.info(f"Using lm-eval version {lm_eval_version}")
model_args = f"pretrained={eval_folder}"
model_args = model_args + f",trust_remote_code={not args.disable_trust_remote_code}"
user_model = None
@@ -350,6 +351,8 @@ def eval(args):

def run():
args = setup_parser()
if args.eval_bs is None:
args.eval_bs = "auto"
if args.eval:
eval(args)
else:
10 changes: 5 additions & 5 deletions auto_round/eval/evaluation.py
@@ -19,25 +19,27 @@

import lm_eval
from lm_eval import simple_evaluate as lm_simple_evaluate
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"


def simple_evaluate(
model,
model_args: Optional[Union[str, dict]] = None,
user_model=None,
batch_size: Optional[int] = None,
max_batch_size: Optional[int] = None,
device: Optional[str] = None,
**kwargs):

try:
from auto_round import AutoRoundConfig
except:
from auto_round.auto_quantizer import AutoHfQuantizer

if model_args is None:
model_args = ""

if isinstance(model_args, dict):
lm = lm_eval.api.registry.get_model(model).create_from_arg_obj(
model_args,
@@ -66,5 +68,3 @@ def simple_evaluate(
max_batch_size=max_batch_size,
device=device,
**kwargs)
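
A hedged usage sketch of this wrapper, mirroring how the CLI builds the string-form `model_args`; the path and task below are illustrative only, and extra keyword arguments such as `tasks` are assumed to be forwarded to `lm_eval.simple_evaluate` via `**kwargs`.

```python
# Sketch: evaluate a quantized checkpoint roughly the way auto_round/__main__.py does.
results = simple_evaluate(
    model="hf",
    model_args="pretrained=./tmp_autoround,trust_remote_code=True",  # illustrative path
    batch_size=1,
    device="cpu",
    tasks=["lambada_openai"],  # forwarded to lm_eval via **kwargs
)
```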


5 changes: 4 additions & 1 deletion auto_round/utils.py
@@ -476,6 +476,7 @@ def detect_device(device=None):
Returns:
str: The device to use for computations, formatted as a string.
"""

def is_valid_digit(s):
try:
num = int(s)
@@ -912,6 +913,8 @@ def get_autogptq_packing_qlinear(backend, bits=4, group_size=128, sym=False):
class: The dynamically imported QuantLinear class configured according to the specified parameters.
"""
use_triton = True
if bits not in [2, 4, 8]:  # the Triton kernel only packs 2/4/8 bits, so fall back for e.g. 3-bit
    use_triton = False
disable_exllamav2 = True
disable_exllamav1 = False
disable_marlin = True
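
To illustrate the effect of this guard, a hedged usage sketch follows; the backend string and the described fallback behavior are assumptions, not documented values.

```python
# Sketch: with the guard above, a 3-bit request no longer selects the Triton QuantLinear
# (Triton packing supports only 2/4/8 bits); one of auto-gptq's other kernels is chosen instead.
QuantLinear = get_autogptq_packing_qlinear(backend="gptq", bits=3, group_size=128, sym=True)
```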
@@ -966,4 +969,4 @@ def get_autogptq_packing_qlinear(backend, bits=4, group_size=128, sym=False):
use_qigen=use_qigen,
use_marlin=not disable_marlin,
)
return QuantLinear