From 6ee91a9fc1074e3ff641b0621e9e02f8c6fe60a0 Mon Sep 17 00:00:00 2001 From: wenhuach21 Date: Fri, 20 Sep 2024 12:51:43 +0800 Subject: [PATCH] Fix 3bit packing for auto-gptq format (#264) --- README.md | 37 ++++++----- auto_round/__main__.py | 21 +++--- auto_round/eval/evaluation.py | 10 +-- auto_round/utils.py | 5 +- examples/language-modeling/main.py | 100 ++++++++++++++--------------- requirements.txt | 1 - test/test_export.py | 4 ++ 7 files changed, 96 insertions(+), 82 deletions(-) diff --git a/README.md b/README.md index 1f24f900..c9637df2 100644 --- a/README.md +++ b/README.md @@ -11,13 +11,13 @@ AutoRound
AutoRound is an advanced quantization algorithm for low-bits LLM inference. It's tailored for a wide range -of models. Our method adopts sign gradient descent to fine-tune rounding values and minmax values of weights in just 200 +of models. AutoRound adopts sign gradient descent to fine-tune rounding values and minmax values of weights in just 200 steps, which competes impressively against recent methods without introducing any additional inference overhead and keeping low tuning cost. The below image presents an overview of AutoRound. Check out our paper on [arxiv](https://arxiv.org/pdf/2309.05516v4) for more details and visit [low_bit_open_llm_leaderboard](https://huggingface.co/spaces/Intel/low_bit_open_llm_leaderboard) for -more accuracy data across various models. +more accuracy data and recipes across various models.
@@ -48,9 +48,8 @@ pip install -vvv --no-build-isolation -e . pip install auto-round ``` - - ## Model Quantization + ### API Usage (Gaudi2/CPU/GPU) ```python @@ -130,8 +129,9 @@ autoround.save_quantized(output_dir, format='auto_round', inplace=True) ### Basic Usage (version > 0.3.0) -A user guide detailing the full list of supported arguments is provided by calling ```auto_round -h``` on the terminal. Alternatively, you can use ```auto-round``` instead of ```auto_round```. +A user guide detailing the full list of supported arguments is provided by calling ```auto_round -h``` on the terminal. +Alternatively, you can use ```auto-round``` instead of ```auto_round```. ```bash auto_round --model facebook/opt-125m \ @@ -141,6 +141,7 @@ auto_round --model facebook/opt-125m \ --disable_eval \ --output_dir ./tmp_autoround ``` + We provide two recipes for best accuracy and fast running speed with low memory. Details as below.
Other Recipes @@ -167,30 +168,35 @@ We provide two recipes for best accuracy and fast running speed with low memory. --batch_size 4 \ --disable_eval ``` +
#### Formats -**AutoRound format**:This format is well-suited for CPU and HPU devices, as well as mixed-precision inference. It +**AutoRound format**:This format is well-suited for CPU, HPU devices, 2 bits, as well as mixed-precision inference. [2,4] +bits are supported. It resolves the asymmetric quantization kernel issues found in the AutoGPTQ format and supports both LM-head quantization and mixed precision. However, it has not yet gained widespread community adoption. For CUDA support, you will need to install from the source. **AutoGPTQ Format**: This format is well-suited for symmetric quantization on CUDA devices and is widely adopted by the -community. It also benefits from the Marlin kernel, which can boost inference performance notably. However, the -asymmetric kernel has issues that can cause considerable accuracy drops, particularly at 2-bit quantization and small models. +community, [2,3,4,8] bits are supported, for 3 bits, pip install auto-gptq first before quantization. It also benefits +from the Marlin kernel, which can boost inference performance notably. However, **the +asymmetric kernel has issues** that can cause considerable accuracy drops, particularly at 2-bit quantization and small +models. Additionally, symmetric quantization tends to perform poorly at 2-bit precision. **AutoAWQ format**: This format is well-suited for asymmetric 4-bit quantization on CUDA devices and is widely adopted -within the community. Asymmetric quantization typically improves accuracy but may reduce inference speed. It features -specialized layer fusion tailored for Llama models. However, it supports only 4-bit asymmetric quantization. +within the community, only 4-bits asymmetric quantization is supported. Asymmetric quantization typically improves +accuracy but may reduce inference speed. It features +specialized layer fusion tailored for Llama models. ## Model Inference - Please run the quantization code first ### AutoGPTQ/AutoAWQ format + ```python from transformers import AutoModelForCausalLM, AutoTokenizer @@ -203,10 +209,9 @@ inputs = tokenizer(text, return_tensors="pt").to(model.device) print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0])) ``` - ### AutoRound format -**CPU**: no extra operations +**CPU**: pip install intel-extension-for-transformers **HPU**: docker image with Gaudi Software Stack is recommended. More details can be found in [Gaudi Guide](https://docs.habana.ai/en/latest/). @@ -214,7 +219,6 @@ in [Gaudi Guide](https://docs.habana.ai/en/latest/). **CUDA**: git clone https://github.com/intel/auto-round.git && cd auto-round && pip install -vvv --no-build-isolation -e . - #### CPU/HPU/CUDA on 0.3.0+ ```python @@ -223,7 +227,7 @@ from auto_round import AutoRoundConfig device = "auto" ##cpu, hpu, cuda quantization_config = AutoRoundConfig( - backend=device + backend=device ) quantized_model_path = "./tmp_autoround" model = AutoModelForCausalLM.from_pretrained(quantized_model_path, @@ -248,6 +252,7 @@ text = "There is a girl who likes adventure," inputs = tokenizer(text, return_tensors="pt").to(model.device) print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0])) ``` +
Evaluation @@ -259,8 +264,8 @@ auto_round --model saved_quantized_model \ --task lambada_openai \ --eval_bs 1 ``` -
+ ## Support List diff --git a/auto_round/__main__.py b/auto_round/__main__.py index e065cead..3de56571 100644 --- a/auto_round/__main__.py +++ b/auto_round/__main__.py @@ -27,6 +27,7 @@ from auto_round import AutoRoundConfig from auto_round.eval.evaluation import simple_evaluate from auto_round.utils import detect_device, get_library_version, detect_device_count +from auto_round.utils import logger def setup_parser(): @@ -48,7 +49,7 @@ def setup_parser(): parser.add_argument("--batch_size", default=8, type=int, help="train batch size") - parser.add_argument("--eval_bs", default=1, type=int, + parser.add_argument("--eval_bs", default=None, type=int, help="eval batch size") parser.add_argument("--device", default="auto", type=str, @@ -164,7 +165,7 @@ def tune(args): model_name = args.model if model_name[-1] == "/": model_name = model_name[:-1] - print(model_name, flush=True) + logger.info(f"start to quantize {model_name}") device_str = detect_device(args.device) torch_dtype = "auto" @@ -231,8 +232,8 @@ def tune(args): if hasattr(tokenizer, "model_max_length"): if tokenizer.model_max_length < seqlen: - print(f"change sequence length to {tokenizer.model_max_length} due to the limitation of model_max_length", - flush=True) + logger.info( + f"change sequence length to {tokenizer.model_max_length} due to the limitation of model_max_length") seqlen = min(seqlen, tokenizer.model_max_length) args.seqlen = seqlen @@ -248,7 +249,7 @@ def tune(args): if isinstance(m, torch.nn.Linear) or isinstance(m, transformers.modeling_utils.Conv1D): if m.weight.shape[0] % 32 != 0 or m.weight.shape[1] % 32 != 0: layer_config[n] = {"bits": 32} - print( + logger.info( f"{n} will not be quantized due to its shape not being divisible by 32," " resulting in an exporting issue to autogptq") fp_layers_list = args.fp_layers_list.split(",") @@ -258,7 +259,7 @@ def tune(args): name = n.split('.')[-1] if n in fp_layers_list or name in fp_layers_list: layer_config[n] = {"bits": 32} - print( + logger.info( f"{n} will not be quantized.") lm_head_layer_name = "lm_head" for n, _ in model.named_modules(): @@ -271,8 +272,8 @@ def tune(args): for item in tied_keys: if lm_head_layer_name in item: ##TODO extend to encoder-decoder layer, seq classification model args.quant_lm_head = False - print( - f"warning, disable quant_lm_head as quantizing lm_head with tied weights has not been " + logger.warning( + f"reset `quant_lm_head` to `False` as quantizing lm_head with tied weights has not been " f"supported currently") break if args.quant_lm_head: @@ -316,7 +317,7 @@ def tune(args): tasks = tasks.split(',') if not args.disable_eval: - print(f"Using the latest {lm_eval_version}") + logger.info(f"Using lm-eval version {lm_eval_version}") model_args = f"pretrained={eval_folder}" model_args = model_args + f",trust_remote_code={not args.disable_trust_remote_code}" user_model = None @@ -350,6 +351,8 @@ def eval(args): def run(): args = setup_parser() + if args.eval_bs is None: + args.eval_bs = "auto" if args.eval: eval(args) else: diff --git a/auto_round/eval/evaluation.py b/auto_round/eval/evaluation.py index 8e50d889..9d1dcdb2 100644 --- a/auto_round/eval/evaluation.py +++ b/auto_round/eval/evaluation.py @@ -19,17 +19,19 @@ import lm_eval from lm_eval import simple_evaluate as lm_simple_evaluate +import os + +os.environ["TOKENIZERS_PARALLELISM"] = "false" def simple_evaluate( model, model_args: Optional[Union[str, dict]] = None, - user_model = None, + user_model=None, batch_size: Optional[int] = None, max_batch_size: Optional[int] = None, device: Optional[str] = None, **kwargs): - try: from auto_round import AutoRoundConfig except: @@ -37,7 +39,7 @@ def simple_evaluate( if model_args is None: model_args = "" - + if isinstance(model_args, dict): lm = lm_eval.api.registry.get_model(model).create_from_arg_obj( model_args, @@ -66,5 +68,3 @@ def simple_evaluate( max_batch_size=max_batch_size, device=device, **kwargs) - - diff --git a/auto_round/utils.py b/auto_round/utils.py index d180980c..c0354318 100644 --- a/auto_round/utils.py +++ b/auto_round/utils.py @@ -476,6 +476,7 @@ def detect_device(device=None): Returns: str: The device to use for computations, formatted as a string. """ + def is_valid_digit(s): try: num = int(s) @@ -912,6 +913,8 @@ def get_autogptq_packing_qlinear(backend, bits=4, group_size=128, sym=False): class: The dynamically imported QuantLinear class configured according to the specified parameters. """ use_triton = True + if bits not in [2, 4, 8]: + use_triton = False disable_exllamav2 = True disable_exllamav1 = False disable_marlin = True @@ -966,4 +969,4 @@ def get_autogptq_packing_qlinear(backend, bits=4, group_size=128, sym=False): use_qigen=use_qigen, use_marlin=not disable_marlin, ) - return QuantLinear \ No newline at end of file + return QuantLinear diff --git a/examples/language-modeling/main.py b/examples/language-modeling/main.py index 1b78d3d6..e40ed069 100644 --- a/examples/language-modeling/main.py +++ b/examples/language-modeling/main.py @@ -43,7 +43,7 @@ parser.add_argument("--train_bs", default=8, type=int, help="train batch size") - parser.add_argument("--eval_bs", default=4, type=int, + parser.add_argument("--eval_bs", default=None, type=int, help="eval batch size") parser.add_argument("--device", default="auto", type=str, @@ -390,6 +390,9 @@ print('does not support cpu, xpu model evaluation.') exit() ## does not support cpu,xpu model eval + if args.disable_eval: + exit() + from packaging.version import Version from auto_round.utils import get_library_version @@ -402,55 +405,52 @@ use_eval_legacy = False from eval_legacy import eval_model - use_qdq = False - if args.deployment_device and 'fake' in args.deployment_device: - use_qdq = True - if args.format and ('fake' in args.format or 'qdq' in args.format): - use_qdq = True - # evaluation - if not args.disable_eval: + if use_eval_legacy: + print("Using the legacy lm_eval(0.3.0)") + else: + print(f"Using the lm_eval version {lm_eval_version}") + + if isinstance(tasks, str): + tasks = tasks.split(',') + + if lm_eval_version < Version("0.4.2"): + if args.eval_bs is None: + args.eval_bs = 1 if use_eval_legacy: - print("Using the legacy lm_eval(0.3.0)") - else: - print(f"Using the latest {lm_eval_version}") - - if isinstance(tasks, str): - tasks = tasks.split(',') - - if use_qdq and lm_eval_version < Version("0.4.2"): - if use_eval_legacy: - if "mmlu" in tasks: - tmp_tasks = tasks - tasks = ["hendrycksTest-*" if x == "mmlu" else x for x in tmp_tasks] - if "truthfulqa_mc1" in tasks or "truthfulqa_mc2" in tasks: - tmp_tasks = tasks - tasks = ["truthfulqa_mc" if "truthfulqa_mc" in x else x for x in tmp_tasks] - seen = set() + if "mmlu" in tasks: + tmp_tasks = tasks + tasks = ["hendrycksTest-*" if x == "mmlu" else x for x in tmp_tasks] + if "truthfulqa_mc1" in tasks or "truthfulqa_mc2" in tasks: tmp_tasks = tasks - tasks = [x for x in tmp_tasks if not (x in seen or seen.add(x))] - - excel_name = f"{output_dir}_result.xlsx" - output_dir += "/" - print(excel_name, flush=True) - eval_model( - model_path=output_dir, tasks=tasks, dtype=dtype, limit=None, - eval_bs=args.eval_bs, use_accelerate=args.low_gpu_mem_usage, - device=torch_device, excel_file=excel_name, - trust_remote_code=not args.disable_trust_remote_code) - - if lm_eval_version >= Version("0.4.2"): - from eval.evaluation import simple_evaluate - - model_args = f"pretrained={eval_folder}" - model_args = model_args + f",trust_remote_code={not args.disable_trust_remote_code}" - user_model = None - if args.act_bits <= 8: - user_model = model.to(device_str) - - res = simple_evaluate(model="hf", model_args=model_args, - tasks=tasks, - batch_size=args.eval_bs, user_model=user_model) - from lm_eval.utils import make_table - - print(make_table(res)) + tasks = ["truthfulqa_mc" if "truthfulqa_mc" in x else x for x in tmp_tasks] + seen = set() + tmp_tasks = tasks + tasks = [x for x in tmp_tasks if not (x in seen or seen.add(x))] + + excel_name = f"{output_dir}_result.xlsx" + output_dir += "/" + print(excel_name, flush=True) + eval_model( + model_path=output_dir, tasks=tasks, dtype=dtype, limit=None, + eval_bs=args.eval_bs, use_accelerate=args.low_gpu_mem_usage, + device=torch_device, excel_file=excel_name, + trust_remote_code=not args.disable_trust_remote_code) + + if lm_eval_version >= Version("0.4.2"): + if args.eval_bs is None: + args.eval_bs = "auto" + from eval.evaluation import simple_evaluate + + model_args = f"pretrained={eval_folder}" + model_args = model_args + f",trust_remote_code={not args.disable_trust_remote_code}" + user_model = None + if args.act_bits <= 8: + user_model = model.to(device_str) + + res = simple_evaluate(model="hf", model_args=model_args, + tasks=tasks, + batch_size=args.eval_bs, user_model=user_model) + from lm_eval.utils import make_table + + print(make_table(res)) diff --git a/requirements.txt b/requirements.txt index 987d26b6..0cc1327b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,6 +8,5 @@ triton numpy < 2.0 threadpoolctl lm-eval==0.4.4 -intel-extension-for-transformers tqdm packaging \ No newline at end of file diff --git a/test/test_export.py b/test/test_export.py index aa23ff58..c7d576f5 100644 --- a/test/test_export.py +++ b/test/test_export.py @@ -111,6 +111,10 @@ def test_autoround_format(self): quantized_model_path = "./saved" autoround.save_quantized(output_dir=quantized_model_path, inplace=False, format="auto_round") + try: + import intel_extension_for_transformers + except: + return from auto_round.auto_quantizer import AutoHfQuantizer model = AutoModelForCausalLM.from_pretrained(quantized_model_path, device_map="auto")