fix 3bit and remove itrex
wenhuach21 committed Sep 20, 2024
1 parent 8af56ba commit 3330c15
Showing 6 changed files with 75 additions and 72 deletions.
10 changes: 5 additions & 5 deletions README.md
@@ -11,13 +11,13 @@ AutoRound
<div align="left">

AutoRound is an advanced quantization algorithm for low-bits LLM inference. It's tailored for a wide range
of models. Our method adopts sign gradient descent to fine-tune rounding values and minmax values of weights in just 200
of models. AutoRound adopts sign gradient descent to fine-tune rounding values and minmax values of weights in just 200
steps,
which competes impressively against recent methods without introducing any additional inference overhead and keeping low
tuning cost. The below
image presents an overview of AutoRound. Check out our paper on [arxiv](https://arxiv.org/pdf/2309.05516v4) for more
details and visit [low_bit_open_llm_leaderboard](https://huggingface.co/spaces/Intel/low_bit_open_llm_leaderboard) for
more accuracy data across various models.
more accuracy data and recipes across various models.

<div align="center">

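For context, a minimal sketch of the tuning flow this paragraph summarizes, assuming the `AutoRound` class and `save_quantized` signature from the project's quick-start (argument names may differ across releases):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRound

model_name = "facebook/opt-125m"  # placeholder model
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype="auto")
tokenizer = AutoTokenizer.from_pretrained(model_name)

# ~200 sign-gradient-descent steps tune the rounding and min/max values of the weights
autoround = AutoRound(model, tokenizer, bits=4, group_size=128, iters=200)
autoround.quantize()
autoround.save_quantized("./tmp_autoround", format="auto_round")
```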
@@ -177,8 +177,8 @@ and mixed precision. However, it has not yet gained widespread community adoption
install from the source.

**AutoGPTQ Format**: This format is well-suited for symmetric quantization on CUDA devices and is widely adopted by the
community. It also benefits from the Marlin kernel, which can boost inference performance notably. However, the
asymmetric kernel has issues that can cause considerable accuracy drops, particularly at 2-bit quantization and small models.
community. It also benefits from the Marlin kernel, which can boost inference performance notably. However, **the
asymmetric kernel has issues** that can cause considerable accuracy drops, particularly at 2-bit quantization and small models.
Additionally, symmetric quantization tends to perform poorly at 2-bit precision.

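Given that trade-off, one hedged way to target the AutoGPTQ kernels is to export symmetric quantization, which avoids the problematic asymmetric kernel and stays Marlin-compatible; this continues the sketch above and assumes `save_quantized` accepts an `auto_gptq` format string:

```python
# reuses `model` and `tokenizer` from the earlier sketch
autoround = AutoRound(model, tokenizer, bits=4, group_size=128, sym=True)
autoround.quantize()
autoround.save_quantized("./tmp_autogptq", format="auto_gptq")
```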
**AutoAWQ format**: This format is well-suited for asymmetric 4-bit quantization on CUDA devices and is widely adopted
@@ -206,7 +206,7 @@ print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]))

### AutoRound format

**CPU**: no extra operations
**CPU**: pip install intel-extension-for-transformers

**HPU**: docker image with Gaudi Software Stack is recommended. More details can be found
in [Gaudi Guide](https://docs.habana.ai/en/latest/).
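Since this commit drops intel-extension-for-transformers from requirements.txt, CPU inference on the AutoRound format now assumes the kernel package is installed separately; a rough load-and-generate sketch with a placeholder checkpoint path:

```python
# pip install intel-extension-for-transformers   (no longer pulled in automatically)
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRoundConfig  # import registers the auto-round quantization backend

quantized_dir = "./tmp_autoround"  # placeholder AutoRound-format checkpoint
model = AutoModelForCausalLM.from_pretrained(quantized_dir, device_map="cpu")
tokenizer = AutoTokenizer.from_pretrained(quantized_dir)

inputs = tokenizer("There is a girl who likes adventure,", return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=50)[0]))
```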
21 changes: 11 additions & 10 deletions auto_round/__main__.py
@@ -27,7 +27,7 @@
from auto_round import AutoRoundConfig
from auto_round.eval.evaluation import simple_evaluate
from auto_round.utils import detect_device, get_library_version, detect_device_count

from auto_round.utils import logger

def setup_parser():
parser = argparse.ArgumentParser()
@@ -48,7 +48,7 @@ def setup_parser():
parser.add_argument("--batch_size", default=8, type=int,
help="train batch size")

parser.add_argument("--eval_bs", default=1, type=int,
parser.add_argument("--eval_bs", default=None, type=int,
help="eval batch size")

parser.add_argument("--device", default="auto", type=str,
@@ -164,7 +164,7 @@ def tune(args):
model_name = args.model
if model_name[-1] == "/":
model_name = model_name[:-1]
print(model_name, flush=True)
logger.info(f"start to quantize {model_name}")

device_str = detect_device(args.device)
torch_dtype = "auto"
@@ -231,8 +231,7 @@

if hasattr(tokenizer, "model_max_length"):
if tokenizer.model_max_length < seqlen:
print(f"change sequence length to {tokenizer.model_max_length} due to the limitation of model_max_length",
flush=True)
logger.info(f"change sequence length to {tokenizer.model_max_length} due to the limitation of model_max_length")
seqlen = min(seqlen, tokenizer.model_max_length)
args.seqlen = seqlen

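These hunks replace `print(..., flush=True)` with a shared logger; a minimal sketch of the pattern, assuming `auto_round.utils.logger` is a standard `logging.Logger`:

```python
import logging

# stand-in for auto_round.utils.logger (assumed to be a plain logging.Logger)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("autoround")

model_name = "facebook/opt-125m"  # placeholder
logger.info(f"start to quantize {model_name}")      # replaces print(..., flush=True)
logger.warning("reset `quant_lm_head` to `False`")  # replaces ad-hoc warning prints
```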
@@ -248,7 +247,7 @@ def tune(args):
if isinstance(m, torch.nn.Linear) or isinstance(m, transformers.modeling_utils.Conv1D):
if m.weight.shape[0] % 32 != 0 or m.weight.shape[1] % 32 != 0:
layer_config[n] = {"bits": 32}
print(
logger.info(
f"{n} will not be quantized due to its shape not being divisible by 32,"
" resulting in an exporting issue to autogptq")
fp_layers_list = args.fp_layers_list.split(",")
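The per-layer override above, consolidated into a runnable loop: weights whose dimensions are not divisible by 32 cannot be packed by the AutoGPTQ exporter, so they stay in full precision (`model` is assumed to be an already-loaded transformers model; the Conv1D branch is omitted here):

```python
import torch

layer_config = {}
for name, module in model.named_modules():
    if isinstance(module, torch.nn.Linear):
        out_features, in_features = module.weight.shape
        if out_features % 32 != 0 or in_features % 32 != 0:
            # keep this layer unquantized to avoid the autogptq export issue
            layer_config[name] = {"bits": 32}
```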
@@ -258,7 +257,7 @@ def tune(args):
name = n.split('.')[-1]
if n in fp_layers_list or name in fp_layers_list:
layer_config[n] = {"bits": 32}
print(
logger.info(
f"{n} will not be quantized.")
lm_head_layer_name = "lm_head"
for n, _ in model.named_modules():
@@ -271,8 +270,8 @@ def tune(args):
for item in tied_keys:
if lm_head_layer_name in item: ##TODO extend to encoder-decoder layer, seq classification model
args.quant_lm_head = False
print(
f"warning, disable quant_lm_head as quantizing lm_head with tied weights has not been "
logger.warning(
f"reset `quant_lm_head` to `False` as quantizing lm_head with tied weights has not been "
f"supported currently")
break
if args.quant_lm_head:
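A condensed sketch of the tied-weights guard; reading tied parameter names via `_tied_weights_keys` is an assumption about the transformers model object, not necessarily how the script obtains `tied_keys`:

```python
# assumption: transformers exposes tied parameter names via `_tied_weights_keys`
tied_keys = getattr(model, "_tied_weights_keys", None) or []
if args.quant_lm_head and any("lm_head" in key for key in tied_keys):
    args.quant_lm_head = False
    logger.warning("reset `quant_lm_head` to `False` as quantizing an lm_head with "
                   "tied weights is not currently supported")
```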
@@ -316,7 +315,7 @@ def tune(args):
tasks = tasks.split(',')

if not args.disable_eval:
print(f"Using the latest {lm_eval_version}")
logger.info(f"Using lm-eval version {lm_eval_version}")
model_args = f"pretrained={eval_folder}"
model_args = model_args + f",trust_remote_code={not args.disable_trust_remote_code}"
user_model = None
@@ -350,6 +349,8 @@ def eval(args):

def run():
args = setup_parser()
if args.eval_bs is None:
args.eval_bs = "auto"
if args.eval:
eval(args)
else:
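Taken together, the `--eval_bs` default of `None` and the new lines in `run()` let the CLI fall back to lm-eval's automatic batch sizing; a small self-contained sketch of the resulting behavior:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--eval_bs", default=None, type=int, help="eval batch size")
args = parser.parse_args([])   # no flag given, so eval_bs stays None

if args.eval_bs is None:
    args.eval_bs = "auto"      # lm-eval then searches for the largest batch that fits
print(args.eval_bs)            # -> auto
```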
10 changes: 5 additions & 5 deletions auto_round/eval/evaluation.py
@@ -19,25 +19,27 @@

import lm_eval
from lm_eval import simple_evaluate as lm_simple_evaluate
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"


def simple_evaluate(
model,
model_args: Optional[Union[str, dict]] = None,
user_model = None,
user_model=None,
batch_size: Optional[int] = None,
max_batch_size: Optional[int] = None,
device: Optional[str] = None,
**kwargs):

try:
from auto_round import AutoRoundConfig
except:
from auto_round.auto_quantizer import AutoHfQuantizer

if model_args is None:
model_args = ""

if isinstance(model_args, dict):
lm = lm_eval.api.registry.get_model(model).create_from_arg_obj(
model_args,
@@ -66,5 +68,3 @@ def simple_evaluate(
max_batch_size=max_batch_size,
device=device,
**kwargs)


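A hedged usage example for the wrapper above; the checkpoint path and task name are placeholders:

```python
from auto_round.eval.evaluation import simple_evaluate
from lm_eval.utils import make_table

res = simple_evaluate(
    model="hf",
    model_args="pretrained=./tmp_autoround,trust_remote_code=True",
    tasks=["lambada_openai"],   # placeholder task list
    batch_size="auto",
)
print(make_table(res))
```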
5 changes: 4 additions & 1 deletion auto_round/utils.py
@@ -476,6 +476,7 @@ def detect_device(device=None):
Returns:
str: The device to use for computations, formatted as a string.
"""

def is_valid_digit(s):
try:
num = int(s)
@@ -912,6 +913,8 @@ def get_autogptq_packing_qlinear(backend, bits=4, group_size=128, sym=False):
class: The dynamically imported QuantLinear class configured according to the specified parameters.
"""
use_triton = True
if bits not in [2, 4, 8]:
use_triton = False
disable_exllamav2 = True
disable_exllamav1 = False
disable_marlin = True
@@ -966,4 +969,4 @@ def get_autogptq_packing_qlinear(backend, bits=4, group_size=128, sym=False):
use_qigen=use_qigen,
use_marlin=not disable_marlin,
)
return QuantLinear
return QuantLinear
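This guard is the "3bit" half of the commit title: the Triton QuantLinear path is only selected for bit widths it is known to handle, so 3-bit packing falls back to a non-Triton kernel. A tiny sketch of the decision:

```python
def use_triton_kernel(bits: int) -> bool:
    # Triton packing is only used for 2/4/8-bit weights
    return bits in (2, 4, 8)

assert use_triton_kernel(4)
assert not use_triton_kernel(3)   # the case this commit fixes
```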
100 changes: 50 additions & 50 deletions examples/language-modeling/main.py
@@ -43,7 +43,7 @@
parser.add_argument("--train_bs", default=8, type=int,
help="train batch size")

parser.add_argument("--eval_bs", default=4, type=int,
parser.add_argument("--eval_bs", default=None, type=int,
help="eval batch size")

parser.add_argument("--device", default="auto", type=str,
@@ -390,6 +390,9 @@
print('does not support cpu, xpu model evaluation.')
exit() ## does not support cpu,xpu model eval

if args.disable_eval:
exit()

from packaging.version import Version
from auto_round.utils import get_library_version

@@ -402,55 +405,52 @@
use_eval_legacy = False
from eval_legacy import eval_model

use_qdq = False
if args.deployment_device and 'fake' in args.deployment_device:
use_qdq = True
if args.format and ('fake' in args.format or 'qdq' in args.format):
use_qdq = True

# evaluation
if not args.disable_eval:
if use_eval_legacy:
print("Using the legacy lm_eval(0.3.0)")
else:
print(f"Using the lm_eval version {lm_eval_version}")

if isinstance(tasks, str):
tasks = tasks.split(',')

if lm_eval_version < Version("0.4.2"):
if args.eval_bs is None:
args.eval_bs = 1
if use_eval_legacy:
print("Using the legacy lm_eval(0.3.0)")
else:
print(f"Using the latest {lm_eval_version}")

if isinstance(tasks, str):
tasks = tasks.split(',')

if use_qdq and lm_eval_version < Version("0.4.2"):
if use_eval_legacy:
if "mmlu" in tasks:
tmp_tasks = tasks
tasks = ["hendrycksTest-*" if x == "mmlu" else x for x in tmp_tasks]
if "truthfulqa_mc1" in tasks or "truthfulqa_mc2" in tasks:
tmp_tasks = tasks
tasks = ["truthfulqa_mc" if "truthfulqa_mc" in x else x for x in tmp_tasks]
seen = set()
if "mmlu" in tasks:
tmp_tasks = tasks
tasks = ["hendrycksTest-*" if x == "mmlu" else x for x in tmp_tasks]
if "truthfulqa_mc1" in tasks or "truthfulqa_mc2" in tasks:
tmp_tasks = tasks
tasks = [x for x in tmp_tasks if not (x in seen or seen.add(x))]

excel_name = f"{output_dir}_result.xlsx"
output_dir += "/"
print(excel_name, flush=True)
eval_model(
model_path=output_dir, tasks=tasks, dtype=dtype, limit=None,
eval_bs=args.eval_bs, use_accelerate=args.low_gpu_mem_usage,
device=torch_device, excel_file=excel_name,
trust_remote_code=not args.disable_trust_remote_code)

if lm_eval_version >= Version("0.4.2"):
from eval.evaluation import simple_evaluate

model_args = f"pretrained={eval_folder}"
model_args = model_args + f",trust_remote_code={not args.disable_trust_remote_code}"
user_model = None
if args.act_bits <= 8:
user_model = model.to(device_str)

res = simple_evaluate(model="hf", model_args=model_args,
tasks=tasks,
batch_size=args.eval_bs, user_model=user_model)
from lm_eval.utils import make_table

print(make_table(res))
tasks = ["truthfulqa_mc" if "truthfulqa_mc" in x else x for x in tmp_tasks]
seen = set()
tmp_tasks = tasks
tasks = [x for x in tmp_tasks if not (x in seen or seen.add(x))]

excel_name = f"{output_dir}_result.xlsx"
output_dir += "/"
print(excel_name, flush=True)
eval_model(
model_path=output_dir, tasks=tasks, dtype=dtype, limit=None,
eval_bs=args.eval_bs, use_accelerate=args.low_gpu_mem_usage,
device=torch_device, excel_file=excel_name,
trust_remote_code=not args.disable_trust_remote_code)

if lm_eval_version >= Version("0.4.2"):
if args.eval_bs is None:
args.eval_bs = "auto"
from eval.evaluation import simple_evaluate

model_args = f"pretrained={eval_folder}"
model_args = model_args + f",trust_remote_code={not args.disable_trust_remote_code}"
user_model = None
if args.act_bits <= 8:
user_model = model.to(device_str)

res = simple_evaluate(model="hf", model_args=model_args,
tasks=tasks,
batch_size=args.eval_bs, user_model=user_model)
from lm_eval.utils import make_table

print(make_table(res))
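The restructured evaluation section boils down to a dispatch on the installed lm-eval version, with the legacy path also gated on a qdq/fake export; a condensed sketch rather than a line-for-line rewrite, reusing `eval_model`, `eval_folder`, and the other names from the hunk above:

```python
from packaging.version import Version
from lm_eval.utils import make_table

if use_qdq and lm_eval_version < Version("0.4.2"):
    if args.eval_bs is None:
        args.eval_bs = 1            # the legacy harness has no "auto" batching
    eval_model(model_path=output_dir, tasks=tasks, dtype=dtype,
               eval_bs=args.eval_bs, device=torch_device, excel_file=excel_name,
               trust_remote_code=not args.disable_trust_remote_code)
elif lm_eval_version >= Version("0.4.2"):
    if args.eval_bs is None:
        args.eval_bs = "auto"
    res = simple_evaluate(model="hf", model_args=f"pretrained={eval_folder}",
                          tasks=tasks, batch_size=args.eval_bs, user_model=user_model)
    print(make_table(res))
```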
1 change: 0 additions & 1 deletion requirements.txt
@@ -8,6 +8,5 @@ triton
numpy < 2.0
threadpoolctl
lm-eval==0.4.4
intel-extension-for-transformers
tqdm
packaging
