Support QDQ format for weight-only quantization (#35)
## Type of Change

feature

## Description

Support QDQ format for weight-only quantization

The QDQ format requires:
- onnxruntime >= 1.19.0
- model opset version >= 21
- quantized bits of 4 or 8
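
A minimal usage sketch, based on the example `main.py` updated in this PR. The model path is illustrative, and the `algo_config` argument and `quant.process()` call are assumptions following onnxruntime's MatMulNBitsQuantizer-style API:

```python
from onnx_neural_compressor.quantization import QuantFormat, matmul_nbits_quantizer

# Request QDQ output for weight-only quantization; when the requirements above
# are not met, the QOperator format is used automatically.
algo_config = matmul_nbits_quantizer.RTNWeightOnlyQuantConfig(quant_format=QuantFormat.QDQ)

quant = matmul_nbits_quantizer.MatMulNBitsQuantizer(
    "/path/to/model.onnx",  # illustrative path to an opset >= 21 model
    n_bits=4,
    algo_config=algo_config,
)
quant.process()
```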

## Expected Behavior & Potential Risk

Weight-only quantization (RTN, AWQ, GPTQ, WOQ_TUNE) can produce quantized models in QDQ format when the requirements above are met; otherwise the QOperator format is used automatically.

## How has this PR been tested?

QDQ example configurations for Llama-2-7b (RTN, AWQ, GPTQ) are added to `examples/.config/model_params_onnxrt.json`, and the weight-only quantization example is updated to exercise the new `--quant_format` option.

## Dependency Change?

No library dependency is added or removed; the QDQ format additionally requires onnxruntime >= 1.19.0.

---------

Signed-off-by: Mengni Wang <[email protected]>
Signed-off-by: Wang, Mengni <[email protected]>
mengniwang95 authored Sep 23, 2024
1 parent 71c2484 commit 05bb58a
Showing 27 changed files with 933 additions and 697 deletions.
examples/.config/model_params_onnxrt.json (27 additions, 0 deletions)
@@ -18,6 +18,15 @@
"batch_size": 1,
"algorithm": "RTN"
},
"llama-2-7b-rtn-with-past-qdq": {
"model_name": "meta-llama/Llama-2-7b-hf",
"model_src_dir": "nlp/huggingface_model/text_generation/quantization/weight_only",
"dataset_location": "",
"input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf-with-past-opset-21",
"main_script": "main.py",
"batch_size": 1,
"algorithm": "RTN"
},
"llama-2-7b-awq": {
"model_name": "meta-llama/Llama-2-7b-hf",
"model_src_dir": "nlp/huggingface_model/text_generation/quantization/weight_only",
@@ -36,6 +45,15 @@
"batch_size": 1,
"algorithm": "AWQ"
},
"llama-2-7b-awq-with-past-qdq": {
"model_name": "meta-llama/Llama-2-7b-hf",
"model_src_dir": "nlp/huggingface_model/text_generation/quantization/weight_only",
"dataset_location": "",
"input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf-with-past-opset-21",
"main_script": "main.py",
"batch_size": 1,
"algorithm": "AWQ"
},
"llama-2-7b-gptq": {
"model_name": "meta-llama/Llama-2-7b-hf",
"model_src_dir": "nlp/huggingface_model/text_generation/quantization/weight_only",
@@ -54,6 +72,15 @@
"batch_size": 1,
"algorithm": "GPTQ"
},
"llama-2-7b-gptq-with-past-qdq": {
"model_name": "meta-llama/Llama-2-7b-hf",
"model_src_dir": "nlp/huggingface_model/text_generation/quantization/weight_only",
"dataset_location": "",
"input_model": "/tf_dataset2/models/onnx/Llama-2-7b-hf-with-past-opset-21",
"main_script": "main.py",
"batch_size": 1,
"algorithm": "GPTQ"
},
"llama-2-7b-woq_tune": {
"model_name": "meta-llama/Llama-2-7b-hf",
"model_src_dir": "nlp/huggingface_model/text_generation/quantization/weight_only",
@@ -41,13 +41,21 @@ python prepare_model.py --input_model="meta-llama/Llama-2-7b-hf" \

Set `algorithm=WOQ_TUNE` to tune the weight-only quantization algorithm, or set `algorithm` to `RTN`, `GPTQ`, or `AWQ` directly.

`quant_format=QDQ` works only when:
- onnxruntime >= 1.19.0
- opset version of the model >= 21
- quantized bits are 4 or 8

Otherwise, the QOperator format is used automatically.

```bash
bash run_quant.sh --input_model=/path/to/model \ # folder path of onnx model
--output_model=/path/to/model_tune \ # folder path to save onnx model
--batch_size=batch_size # optional \
--dataset=NeelNanda/pile-10k \
--tokenizer=meta-llama/Llama-2-7b-hf \ # model name or folder path containing all relevant files for model's tokenizer
- --algorithm=WOQ_TUNE # support WOQ_TUNE, RTN, AWQ, GPTQ
+ --algorithm=WOQ_TUNE # support WOQ_TUNE, RTN, AWQ, GPTQ \
+ --quant_format=QDQ # support QOperator and QDQ
```

## 2. Benchmark
@@ -34,7 +34,7 @@
from torch.utils import data

from onnx_neural_compressor import data_reader
- from onnx_neural_compressor.quantization import config, matmul_nbits_quantizer, tuning
+ from onnx_neural_compressor.quantization import QuantFormat, config, matmul_nbits_quantizer, tuning

logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.WARN
@@ -74,7 +74,8 @@
parser.add_argument(
"--tasks",
nargs="+",
- default=[
+ default=["lambada_openai"],
+ choices=[
"winogrande",
"copa",
"piqa",
@@ -105,6 +106,7 @@
default=[],
help="nodes that will not be quantized. Doesn't take effect when 'algorithm' is 'WOQ_TUNE'",
)
parser.add_argument("--quant_format", type=str, default="QDQ", choices=["QOperator", "QDQ"])
args = parser.parse_args()

if args.tune and not os.path.exists(args.output_model):
@@ -347,8 +349,11 @@ def rewind(self):

nodes_to_exclude = ["/lm_head/MatMul"] if not args.quantize_lm_head else []
nodes_to_exclude = list(set(args.nodes_to_exclude + nodes_to_exclude))
quant_format = QuantFormat.QOperator if args.quant_format == "QOperator" else QuantFormat.QDQ
if args.algorithm.upper() == "RTN":
- algo_config = matmul_nbits_quantizer.RTNWeightOnlyQuantConfig(layer_wise_quant=args.layer_wise)
+ algo_config = matmul_nbits_quantizer.RTNWeightOnlyQuantConfig(
+     layer_wise_quant=args.layer_wise, quant_format=quant_format
+ )
quant = matmul_nbits_quantizer.MatMulNBitsQuantizer(
model_path,
n_bits=4,
@@ -363,7 +368,9 @@
elif args.algorithm.upper() == "AWQ":
calibration_data_reader = AWQDataloader(model_path, pad_max=args.pad_max, batch_size=1)
algo_config = matmul_nbits_quantizer.AWQWeightOnlyQuantConfig(
- calibration_data_reader=calibration_data_reader, enable_mse_search=False
+ calibration_data_reader=calibration_data_reader,
+ enable_mse_search=False,
+ quant_format=quant_format,
)
quant = matmul_nbits_quantizer.MatMulNBitsQuantizer(
model_path,
@@ -379,7 +386,9 @@
elif args.algorithm.upper() == "GPTQ":
calibration_data_reader = GPTQDataloader(model_path, seqlen=args.seqlen, batch_size=1)
algo_config = matmul_nbits_quantizer.GPTQWeightOnlyQuantConfig(
- calibration_data_reader=calibration_data_reader, layer_wise_quant=args.layer_wise
+ calibration_data_reader=calibration_data_reader,
+ layer_wise_quant=args.layer_wise,
+ quant_format=quant_format,
)
quant = matmul_nbits_quantizer.MatMulNBitsQuantizer(
model_path,
@@ -395,7 +404,9 @@
elif args.algorithm.upper() == "WOQ_TUNE":
calibration_data_reader = GPTQDataloader(model_path, seqlen=args.seqlen, batch_size=1)
# set tolerable_loss to 0.5% for test, default is 1%
- custom_tune_config = tuning.TuningConfig(config_set=config.get_woq_tuning_config(), tolerable_loss=0.005)
+ custom_tune_config = tuning.TuningConfig(
+     config_set=config.get_woq_tuning_config(quant_format=quant_format), tolerable_loss=0.005
+ )
best_model = tuning.autotune(
model_input=model_path,
tune_config=custom_tune_config,
@@ -14,19 +14,19 @@ function init_params {
do
case $var in
--input_model=*)
- input_model=$(echo $var |cut -f2 -d=)
+ input_model=$(echo "$var" |cut -f2 -d=)
;;
--batch_size=*)
- batch_size=$(echo $var |cut -f2 -d=)
+ batch_size=$(echo "$var" |cut -f2 -d=)
;;
--tokenizer=*)
- tokenizer=$(echo $var |cut -f2 -d=)
+ tokenizer=$(echo "$var" |cut -f2 -d=)
;;
--mode=*)
- mode=$(echo $var |cut -f2 -d=)
+ mode=$(echo "$var" |cut -f2 -d=)
;;
--intra_op_num_threads=*)
- intra_op_num_threads=$(echo $var |cut -f2 -d=)
+ intra_op_num_threads=$(echo "$var" |cut -f2 -d=)
;;
esac
done
@@ -42,19 +42,27 @@ function run_benchmark {
input_model=$(dirname "$input_model")
fi

extra_cmd=""

if [[ "${tokenizer}" =~ "Phi-3-mini" ]]; then
extra_cmd="--trust_remote_code True"
extra_cmd=$extra_cmd"--trust_remote_code True "
fi

if [ "${batch_size}" ]; then
extra_cmd=$extra_cmd"--batch_size ${batch_size} "
fi
if [ "${tokenizer}" ]; then
extra_cmd=$extra_cmd"--tokenizer ${tokenizer} "
fi
if [ "${tasks}" ]; then
extra_cmd=$extra_cmd"--tasks ${tasks} "
fi
if [ "${intra_op_num_threads}" ]; then
extra_cmd=$extra_cmd"--intra_op_num_threads ${intra_op_num_threads} "
fi

eval "python main.py \
--model_path ${input_model} \
--batch_size=${batch_size-1} \
--tokenizer=${tokenizer-meta-llama/Llama-2-7b-hf} \
--tasks=${tasks-lambada_openai} \
--mode=${mode} \
--intra_op_num_threads=${intra_op_num_threads-24} \
--benchmark \
${extra_cmd}"
extra_cmd=$extra_cmd"--benchmark"
eval "python main.py --model_path ${input_model} --mode ${mode} ${extra_cmd}"

}

@@ -12,22 +12,25 @@ function init_params {
do
case $var in
--input_model=*)
- input_model=$(echo $var |cut -f2 -d=)
+ input_model=$(echo "$var" |cut -f2 -d=)
;;
--output_model=*)
- output_model=$(echo $var |cut -f2 -d=)
+ output_model=$(echo "$var" |cut -f2 -d=)
;;
--batch_size=*)
- batch_size=$(echo $var |cut -f2 -d=)
+ batch_size=$(echo "$var" |cut -f2 -d=)
;;
--dataset=*)
- dataset=$(echo $var |cut -f2 -d=)
+ dataset=$(echo "$var" |cut -f2 -d=)
;;
--tokenizer=*)
- tokenizer=$(echo $var |cut -f2 -d=)
+ tokenizer=$(echo "$var" |cut -f2 -d=)
;;
--algorithm=*)
- algorithm=$(echo $var |cut -f2 -d=)
+ algorithm=$(echo "$var" |cut -f2 -d=)
;;
--quant_format=*)
quant_format=$(echo "$var" |cut -f2 -d=)
;;
esac
done
@@ -56,30 +59,42 @@ function run_tuning {
echo "Created directory $output_model"
fi

extra_cmd=""

if [[ "${tokenizer}" =~ "Phi-3-mini" ]]; then
nodes_to_exclude="/model/layers.*/self_attn/qkv_proj/MatMul /model/layers.*/mlp/down_proj/MatMul"
extra_cmd="--nodes_to_exclude ${nodes_to_exclude} --trust_remote_code True"
extra_cmd=$extra_cmd"--nodes_to_exclude ${nodes_to_exclude} --trust_remote_code True "
fi
if [[ "${tokenizer}" =~ "Llama-3-8B" ]]; then
nodes_to_exclude="/model/layers.*/mlp/down_proj/MatMul"
extra_cmd="--nodes_to_exclude ${nodes_to_exclude}"
extra_cmd=$extra_cmd"--nodes_to_exclude ${nodes_to_exclude} "
fi
if [[ "${tokenizer}" =~ "Qwen2-7B" ]]; then
nodes_to_exclude="/model/layers.*/mlp/down_proj/MatMul /model/layers.*/mlp/up_proj/MatMul"
extra_cmd="--nodes_to_exclude ${nodes_to_exclude}"
extra_cmd=$extra_cmd"--nodes_to_exclude ${nodes_to_exclude} "
fi

if [ "${tokenizer}" ]; then
extra_cmd=$extra_cmd"--tokenizer ${tokenizer} "
fi
if [ "${batch_size}" ]; then
extra_cmd=$extra_cmd"--batch_size ${batch_size} "
fi
if [ "${dataset}" ]; then
extra_cmd=$extra_cmd"--dataset ${dataset} "
fi
if [ "${algorithm}" ]; then
extra_cmd=$extra_cmd"--algorithm ${algorithm} "
fi
if [ "${tasks}" ]; then
extra_cmd=$extra_cmd"--tasks ${tasks} "
fi
if [ "${quant_format}" ]; then
extra_cmd=$extra_cmd"--quant_format ${quant_format} "
fi

eval "python main.py \
--model_path ${input_model} \
--tokenizer ${tokenizer-meta-llama/Llama-2-7b-hf} \
--output_model ${output_model} \
--batch_size ${batch_size-1} \
--dataset ${dataset-NeelNanda/pile-10k} \
--algorithm ${algorithm-WOQ_TUNE} \
--tasks ${tasks-lambada_openai} \
--layer_wise \
--tune \
${extra_cmd}"
extra_cmd=$extra_cmd"--layer_wise --tune"
eval "python main.py --model_path ${input_model} --output_model ${output_model} ${extra_cmd}"
}

main "$@"
onnx_neural_compressor/algorithms/layer_wise/core.py (5 additions, 2 deletions)
@@ -18,6 +18,7 @@
import copy
import os
import pathlib
import tempfile

import onnx
import onnxruntime as ort
@@ -60,7 +61,7 @@ def layer_wise_quant(
model = onnx_model.ONNXModel(model, ignore_warning=True, load_external_data=False)

origin_model = copy.deepcopy(model)

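# Intermediate split models are written to a temporary directory and cleaned up after merging.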
tmp_file = tempfile.TemporaryDirectory()
providers = kwargs.get("providers", ["CPUExecutionProvider"])

# get and check split nodes
@@ -97,7 +98,7 @@

# split model with given split node
split_model_part_1, split_model_part_2 = split_model.split_model_with_node(
- split_node.name, model.model_path, save_both_split_models
+ split_node.name, model.model_path, save_both_split_models, save_path=tmp_file.name
)

if not save_both_split_models:
@@ -201,6 +202,8 @@ def layer_wise_quant(
onnx.external_data_helper.load_external_data_for_model(
quantized_model_merged.model, os.path.dirname(quantized_model_merged.model_path)
)

tmp_file.cleanup()
return quantized_model_merged


(Diffs for the remaining changed files are not shown.)