
Weights Compression

OpenVINO is the preferred backend to run Weights Compression with. PyTorch and Torch FX are also supported.

The algorithm description

The Weights Compression algorithm compresses the weights of a model. It can be used to reduce the footprint and improve the performance of large models whose weights are much larger than their activations, for example, Large Language Models (LLMs). The algorithm compresses the weights of Linear, Convolution and Embedding layers.

Supported modes

By default, weights are compressed asymmetrically to the 8-bit integer data type ("INT8_ASYM" mode). The OpenVINO backend also supports 4 modes of mixed-precision weight quantization with a 4-bit data type as the primary precision: INT4_SYM, INT4_ASYM, NF4 and E2M1.

  • INT4_SYM: the primary precision is a signed 4-bit integer, and weights are quantized to it symmetrically, without a zero point.
  • INT4_ASYM: the primary precision is an unsigned 4-bit integer, and weights are quantized to it asymmetrically, with a typical non-fixed zero point.
  • NF4: the primary precision is the nf4 data type, without a zero point.
  • E2M1: the primary precision is the e2m1 data type, without a zero point and with an 8-bit E8M0 scale.

All 4-bit modes support grouped quantization, where a small group of weights (e.g. 128) in the channel dimension shares quantization parameters (scale). By default, all embeddings, convolutions and last linear layers are compressed to the 8-bit integer data type; to quantize embeddings and last linear layers to 4-bit, use all_layers=True. The percentage of the remaining layers compressed to 4-bit is configured by the "ratio" parameter. For example, ratio=0.9 means that 90% of layers are compressed to the corresponding 4-bit data type and the rest to the 8-bit asymmetric integer data type.
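The INT4_SYM and INT4_ASYM arithmetic with grouped quantization can be sketched in NumPy. This is a simplified illustration under stated assumptions, not NNCF's implementation: the helper names are hypothetical, the symmetric scale is assumed to be max|w| / 7 per group, and the asymmetric scale and zero point are assumed to come from the group minimum and maximum. Each function returns the dequantized weight so the quantization error can be measured directly.

```python
import numpy as np

def _split_groups(w, group_size):
    """Reshape a [out, in] weight into [out, n_groups, group_size]; -1 means one group per channel."""
    out_ch, in_ch = w.shape
    gs = in_ch if group_size == -1 else group_size
    if in_ch % gs != 0:
        raise ValueError("input dimension must be divisible by group_size")
    return w.reshape(out_ch, in_ch // gs, gs)

def quantize_int4_sym(w, group_size=128):
    """INT4_SYM-like (assumed scheme): signed codes in [-8, 7], one scale per group, no zero point."""
    g = _split_groups(w, group_size)
    scale = np.abs(g).max(axis=-1, keepdims=True) / 7.0
    scale = np.where(scale == 0, 1.0, scale)          # guard all-zero groups
    q = np.clip(np.round(g / scale), -8, 7)           # integer codes
    return (q * scale).reshape(w.shape)               # dequantized weight

def quantize_int4_asym(w, group_size=128):
    """INT4_ASYM-like (assumed scheme): unsigned codes in [0, 15], one scale and zero point per group."""
    g = _split_groups(w, group_size)
    w_min = g.min(axis=-1, keepdims=True)
    w_max = g.max(axis=-1, keepdims=True)
    scale = (w_max - w_min) / 15.0
    scale = np.where(scale == 0, 1.0, scale)
    zero_point = np.clip(np.round(-w_min / scale), 0, 15)
    q = np.clip(np.round(g / scale) + zero_point, 0, 15)
    return ((q - zero_point) * scale).reshape(w.shape)

rng = np.random.default_rng(0)
w = rng.normal(size=(64, 1024)).astype(np.float32)
for gs in (32, 128, -1):
    mse_sym = np.mean((w - quantize_int4_sym(w, gs)) ** 2)
    mse_asym = np.mean((w - quantize_int4_asym(w, gs)) ** 2)
    print(f"group_size={gs:>4}: INT4_SYM MSE={mse_sym:.5f}, INT4_ASYM MSE={mse_asym:.5f}")
```

In this sketch, smaller groups give lower errors because each scale covers a narrower range of values; the price is more scales to store, which slightly increases the footprint.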

User guide

  • Compress weights asymmetrically to 8-bit integer data type.
from nncf import compress_weights
compressed_model = compress_weights(model) # model is openvino.Model object
  • Compress weights symmetrically to 8-bit integer data type.
from nncf import compress_weights, CompressWeightsMode
compressed_model = compress_weights(model, mode=CompressWeightsMode.INT8_SYM) # model is openvino.Model object
  • Compress weights symmetrically to 4-bit integer data type with group size = 128, except embeddings, convolutions and last linear layers - they are compressed asymmetrically to 8-bit integer data type.
from nncf import compress_weights, CompressWeightsMode
compressed_model = compress_weights(model, mode=CompressWeightsMode.INT4_SYM) # model is openvino.Model object
  • Generally, INT4_SYM mode is the fastest mixed-precision mode, but it may lead to significant accuracy degradation or perplexity increase. Compressing weights asymmetrically (INT4_ASYM mode) is a way to increase accuracy, but in turn it slows down inference a bit. If the accuracy or perplexity is still not satisfactory, there are 2 more hyper-parameters to tune: group_size and ratio. Please refer to the example of how to automatically tune these parameters. A smaller group size and a lower ratio of 4-bit layers usually improve accuracy at the expense of inference speed. To disable grouped quantization and quantize weights per-channel, set group_size = -1. Below is an example of how to compress the weights of 90% of layers to 4-bit integer asymmetrically with group size 64, and the rest of the layers to the 8-bit asymmetric integer data type. The same parametrization is applicable to INT4_SYM mode.
from nncf import compress_weights, CompressWeightsMode
compressed_model = compress_weights(model, mode=CompressWeightsMode.INT4_ASYM, group_size=64, ratio=0.9) # model is openvino.Model object
  • The accuracy of 4-bit compressed models can be improved by using the data-aware mixed-precision algorithm. It can find outliers in the input activations and assign a different quantization precision to minimize accuracy degradation. Below is an example of how to compress 80% of layers to 4-bit integer with the default data-aware mixed-precision algorithm. It requires just one extra parameter: an NNCF wrapper of the dataset. Refer to the full example of data-aware weight compression for more details. If the dataset is not specified, the data-free mixed-precision algorithm works based on weights only. Refer to the second table below for an evaluation of the data-free and data-aware methods on the wikitext dataset. On average, data-aware mixed-precision weight compression takes more time than the data-free one (~30% slower on Intel(R) Xeon(R) Gold 6430L), since it runs inference on the calibration dataset to find outliers in the input activations.
from nncf import compress_weights, CompressWeightsMode, Dataset
# data_source is any iterable of samples; transform_fn converts one sample into model inputs
nncf_dataset = Dataset(data_source, transform_fn)
compressed_model = compress_weights(model, mode=CompressWeightsMode.INT4_SYM, ratio=0.8, dataset=nncf_dataset) # model is openvino.Model object
  • The accuracy of 4-bit compressed models can also be improved by using the AWQ, Scale Estimation, GPTQ or Lora Correction algorithms on top of the data-based mixed-precision algorithm. These algorithms work by equalizing a subset of weights to minimize the difference between the original precision and the 4-bit precision. Unlike the others, the Lora Correction algorithm inserts additional Linear layers to reduce quantization noise and further improve accuracy. Inevitably, this approach introduces memory and runtime overhead, but the overhead is negligible, since the inserted weights are much smaller and can be quantized to 8-bit. The AWQ, Scale Estimation (SE) and Lora Correction (LC) algorithms can be used in any combination: AWQ + SE, AWQ + LC, SE + LC, AWQ + SE + LC. The GPTQ algorithm can be combined with AWQ and Scale Estimation in any combination: AWQ + GPTQ, GPTQ + SE, AWQ + GPTQ + SE. Below are examples demonstrating how to enable the AWQ, Scale Estimation, GPTQ or Lora Correction algorithms:

    Prepare the calibration dataset for data-based algorithms:

from functools import partial

import numpy as np
from datasets import load_dataset
from nncf import compress_weights, CompressWeightsMode, Dataset
from optimum.intel.openvino import OVModelForCausalLM
from transformers import AutoTokenizer

def transform_func(item, tokenizer, input_shapes):
    """Tokenize one dataset item and fill any remaining model inputs with placeholder zeros."""
    text = item['text']
    tokens = tokenizer(text)

    res = {'input_ids': np.expand_dims(np.array(tokens['input_ids']), 0),
           'attention_mask': np.expand_dims(np.array(tokens['attention_mask']), 0)}

    if 'position_ids' in input_shapes:
        position_ids = np.cumsum(res['attention_mask'], axis=1) - 1
        position_ids[res['attention_mask'] == 0] = 1
        res['position_ids'] = position_ids

    for name, shape in input_shapes.items():
        if name in res:
            continue
        res[name] = np.zeros(shape)

    return res

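def get_input_shapes(model, batch_size=1):
    """Collect the minimal static shape of every model input, with the batch dimension set to batch_size."""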
    inputs = {}

    for val in model.model.inputs:
        name = val.any_name
        shape = list(val.partial_shape.get_min_shape())
        shape[0] = batch_size
        inputs[name] = shape

    return inputs

# load your model and tokenizer
model = OVModelForCausalLM.from_pretrained(...)
tokenizer = AutoTokenizer.from_pretrained(...)

# prepare dataset for compression
dataset = load_dataset('wikitext', 'wikitext-2-v1', split='train')
dataset = dataset.filter(lambda example: len(example["text"]) > 80)
input_shapes = get_input_shapes(model)
nncf_dataset = Dataset(dataset, partial(transform_func, tokenizer=tokenizer,
                                        input_shapes=input_shapes))
  • How to compress 80% of layers to 4-bit integer with the default data-based mixed-precision algorithm plus AWQ and Scale Estimation. In addition to the data-based mixed-precision setup, it requires setting awq=True and scale_estimation=True.
model.model = compress_weights(model.model,
                               mode=CompressWeightsMode.INT4_SYM,
                               ratio=0.8,
                               dataset=nncf_dataset,
                               awq=True,
                               scale_estimation=True)
  • How to compress 80% of layers to 4-bit integer with the default data-based mixed-precision algorithm plus GPTQ. In addition to the data-based mixed-precision setup, it requires setting gptq=True.
model.model = compress_weights(model.model,
                               mode=CompressWeightsMode.INT4_SYM,
                               ratio=0.8,
                               dataset=nncf_dataset,
                               gptq=True)
  • How to compress 80% of layers to 4-bit integer with the default data-based mixed-precision algorithm plus the Lora Correction algorithm. In addition to the data-based mixed-precision setup, it requires setting lora_correction=True.
model.model = compress_weights(model.model,
                               mode=CompressWeightsMode.INT4_SYM,
                               ratio=0.8,
                               dataset=nncf_dataset,
                               lora_correction=True)
  • NF4 mode can be considered for improving accuracy, but currently models quantized to nf4 should not be expected to run faster than models quantized to 8-bit asymmetric integer. Here is an example of how to compress weights to the nf4 data type with group size = 128. Other group_size and ratio values are also supported.
from nncf import compress_weights, CompressWeightsMode
compressed_model = compress_weights(model, mode=CompressWeightsMode.NF4)
  • E2M1 mode can be considered for improving accuracy, but currently models quantized to e2m1 should not be expected to run faster than models quantized to 8-bit asymmetric integer. Here is an example of how to compress weights to the e2m1 data type with group size = 32 (recommended). Other group_size and ratio values are also supported.
from nncf import compress_weights, CompressWeightsMode
compressed_model = compress_weights(model, mode=CompressWeightsMode.E2M1, group_size=32, all_layers=True)
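The E2M1 mode from the last example can be illustrated with a simplified NumPy sketch. E2M1 is a 4-bit float whose magnitudes are limited to {0, 0.5, 1, 1.5, 2, 3, 4, 6}, and an E8M0 scale is exponent-only, i.e. a power of two. The scale rule used here, 2^(floor(log2(max|w|)) - 2) per group of 32, follows the OCP Microscaling (MX) convention and is an assumption; NNCF may choose scales differently. `quantize_e2m1` is a hypothetical helper.

```python
import numpy as np

# Non-negative magnitudes representable by E2M1 (1 sign, 2 exponent, 1 mantissa bit).
E2M1_GRID = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=np.float32)

def quantize_e2m1(w, group_size=32):
    """Round each group of a [out, in] weight to E2M1 values times a shared power-of-two scale.
    Returns (dequantized weight, per-group scales)."""
    out_ch, in_ch = w.shape
    g = w.reshape(out_ch, in_ch // group_size, group_size)
    amax = np.abs(g).max(axis=-1, keepdims=True)
    amax = np.where(amax == 0, 1.0, amax)                  # guard all-zero groups
    # E8M0-style scale (assumed MX-style rule): a pure power of two chosen so the group
    # maximum lands in [4, 8) after scaling, near the largest E2M1 exponent.
    scale = 2.0 ** (np.floor(np.log2(amax)) - 2)
    x = np.clip(np.abs(g) / scale, 0.0, 6.0)               # magnitudes above 6 saturate
    idx = np.abs(x[..., None] - E2M1_GRID).argmin(axis=-1)  # nearest representable magnitude
    q = np.sign(g) * E2M1_GRID[idx]
    return (q * scale).reshape(w.shape), scale.squeeze(-1)

rng = np.random.default_rng(0)
w = rng.normal(size=(8, 64)).astype(np.float32)
w_deq, scales = quantize_e2m1(w, group_size=32)
print(scales.shape)            # (8, 2): one power-of-two scale per group of 32 weights
print(np.abs(w - w_deq).max())
```

Because the scale is a power of two, dividing the dequantized weights by it recovers the E2M1 codes exactly, which is what makes an 8-bit exponent-only scale sufficient.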

Evaluation results

Data-free Mixed-Precision on Lambada OpenAI dataset

Here are the perplexity and model size before and after weight compression for different language models on the Lambada OpenAI dataset. In the mode names, g32 means a group size of 32 and r60 means a ratio of 0.6.

| Model | Mode | Perplexity (↓) | Perplexity Increase (↓) | Model Size (GB) |
|---|---|---|---|---|
| databricks/dolly-v2-3b | fp32 | 5.01 | 0 | 10.3 |
| databricks/dolly-v2-3b | int8_asym | 5.07 | 0.05 | 2.6 |
| databricks/dolly-v2-3b | int4_asym_g32_r50 | 5.28 | 0.26 | 2.2 |
| databricks/dolly-v2-3b | nf4_g128_r60 | 5.19 | 0.18 | 1.9 |
| facebook/opt-6.7b | fp32 | 4.25 | 0 | 24.8 |
| facebook/opt-6.7b | int8_asym | 4.27 | 0.01 | 6.2 |
| facebook/opt-6.7b | int4_asym_g64_r80 | 4.32 | 0.07 | 4.1 |
| facebook/opt-6.7b | nf4_g64 | 4.35 | 0.1 | 3.6 |
| meta-llama/Llama-2-7b-chat-hf | fp32 | 3.28 | 0 | 25.1 |
| meta-llama/Llama-2-7b-chat-hf | int8_asym | 3.29 | 0.01 | 6.3 |
| meta-llama/Llama-2-7b-chat-hf | int4_asym_g128_r80 | 3.41 | 0.14 | 4.0 |
| meta-llama/Llama-2-7b-chat-hf | nf4_g128 | 3.41 | 0.13 | 3.5 |
| togethercomputer/RedPajama-INCITE-7B-Instruct | fp32 | 4.15 | 0 | 25.6 |
| togethercomputer/RedPajama-INCITE-7B-Instruct | int8_asym | 4.17 | 0.02 | 6.4 |
| togethercomputer/RedPajama-INCITE-7B-Instruct | nf4_ov_g32_r60 | 4.28 | 0.13 | 5.1 |
| togethercomputer/RedPajama-INCITE-7B-Instruct | int4_asym_g128 | 4.17 | 0.02 | 3.6 |
| meta-llama/Llama-2-13b-chat-hf | fp32 | 2.92 | 0 | 48.5 |
| meta-llama/Llama-2-13b-chat-hf | int8_asym | 2.91 | 0 | 12.1 |
| meta-llama/Llama-2-13b-chat-hf | int4_sym_g64_r80 | 2.98 | 0.06 | 8.0 |
| meta-llama/Llama-2-13b-chat-hf | nf4_g128 | 2.95 | 0.04 | 6.6 |
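The data-free mixed-precision selection behind the r suffix can be illustrated with a hypothetical procedure: rank layers by how much their weights suffer from 4-bit quantization and assign 4-bit to the least sensitive layers until the requested ratio is reached, keeping the rest in 8-bit. The sensitivity metric (relative INT4_SYM reconstruction error) and counting the ratio in number of weights are assumptions for this sketch, so NNCF's actual criterion may differ; with equally sized layers, as in the usage example, the weight fraction equals the layer fraction.

```python
import numpy as np

def int4_sym_error(w, group_size=128):
    """Relative reconstruction error of a [out, in] weight after grouped symmetric int4 quantization."""
    g = w.reshape(w.shape[0], -1, group_size)
    scale = np.abs(g).max(axis=-1, keepdims=True) / 7.0
    scale = np.where(scale == 0, 1.0, scale)
    deq = np.clip(np.round(g / scale), -8, 7) * scale
    return np.linalg.norm(g - deq) / np.linalg.norm(g)

def assign_precisions(weights, ratio, group_size=128):
    """Hypothetical selection: map layer name -> 'int4' or 'int8' so that at most
    `ratio` of all weights end up in int4, preferring the least sensitive layers."""
    total = sum(w.size for w in weights.values())
    order = sorted(weights, key=lambda name: int4_sym_error(weights[name], group_size))
    precisions, in_int4 = {}, 0
    for name in order:
        if in_int4 + weights[name].size <= ratio * total:
            precisions[name] = "int4"
            in_int4 += weights[name].size
        else:
            precisions[name] = "int8"
    return precisions

rng = np.random.default_rng(0)
layers = {f"layer{i}": rng.normal(size=(256, 256)) for i in range(10)}
layers["layer3"][:, 0] *= 50   # an outlier input channel makes this layer sensitive to int4
print(assign_precisions(layers, ratio=0.8))
```

In this toy setup the layer with the outlier channel has the largest int4 error, so it is one of the two layers kept in 8-bit at ratio=0.8.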

Data-aware Mixed-Precision and AWQ methods on Wikitext dataset

Here is the word perplexity with data-free and data-aware mixed-precision INT4-INT8 weight compression for different language models on the wikitext dataset. The data suffix refers to data-aware mixed precision. The data_awq suffix refers to data-aware mixed precision with a modified AWQ algorithm. This modification applies only to MatMul-Multiply-MatMul patterns (for example, the MLP block in LLaMA).

| Model | Mode | Word Perplexity (↓) |
|---|---|---|
| meta-llama/llama-7b-chat-hf | fp16 | 11.57 |
| | int4_sym_g128_r80_data | 11.87 |
| | int4_sym_g128_r80 | 11.92 |
| | int4_sym_g128_r100_data_awq | 12.34 |
| | int4_sym_g128_r100 | 12.35 |
| stabilityai_stablelm-3b-4e1t | fp16 | 10.16 |
| | int4_sym_g64_r80_data | 10.67 |
| | int4_sym_g64_r80 | 10.83 |
| | int4_sym_g64_r100_data_awq | 10.89 |
| | int4_sym_g64_r100 | 11.07 |
| stable-zephyr-3b-dpo | int4_sym_g64_r80_data_awq | 21.62 |
| | int4_sym_g64_r80_data | 21.74 |
| | int4_sym_g64_r80 | 23.10 |
| | int4_sym_g64_r100_data_awq | 21.76 |
| | int4_sym_g64_r100 | 23.19 |
| HuggingFaceH4/zephyr-7b-beta | fp16 | 9.82 |
| | int4_sym_g128_r80_data | 10.13 |
| | int4_sym_g128 | 10.22 |
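The idea behind the AWQ rows above can be illustrated with a generic sketch of activation-aware scaling: multiply each input channel of the weight by s = mean|x|^alpha, divide the activations by the same s (so x @ W.T is unchanged in full precision), and grid-search alpha for the smallest output error after int4 quantization. This is a hypothetical illustration of the general technique, not NNCF's modified AWQ, which in a real model folds 1/s into the preceding operation and, as described above, applies only to MatMul-Multiply-MatMul patterns.

```python
import numpy as np

def fake_quant_int4_sym(w, group_size=32):
    """Grouped symmetric int4 quantize-dequantize of a [out, in] weight."""
    g = w.reshape(w.shape[0], -1, group_size)
    scale = np.abs(g).max(axis=-1, keepdims=True) / 7.0
    scale = np.where(scale == 0, 1.0, scale)
    return (np.clip(np.round(g / scale), -8, 7) * scale).reshape(w.shape)

def awq_search(w, x, alphas=np.linspace(0.0, 1.0, 11), group_size=32):
    """Return (best_alpha, best_mse) of per-input-channel scaling for W [out, in] and activations x [n, in]."""
    reference = x @ w.T                                  # full-precision output
    act_magnitude = np.abs(x).mean(axis=0)               # average |activation| per input channel
    best = (None, np.inf)
    for alpha in alphas:
        s = np.maximum(act_magnitude, 1e-8) ** alpha
        s = s / np.sqrt(s.max() * s.min())               # keep scales centered around 1
        w_q = fake_quant_int4_sym(w * s, group_size)     # quantize the scaled weight
        out = (x / s) @ w_q.T                            # inverse scale applied to activations
        err = np.mean((out - reference) ** 2)
        if err < best[1]:
            best = (alpha, err)
    return best

rng = np.random.default_rng(0)
x = rng.normal(size=(256, 128))
x[:, :4] *= 20                      # a few salient input channels with large activations
w = rng.normal(size=(64, 128))
alpha, err = awq_search(w, x)
plain = np.mean((x @ fake_quant_int4_sym(w).T - x @ w.T) ** 2)
print(f"best alpha={alpha:.1f}, output MSE={err:.4f}, without scaling={plain:.4f}")
```

Since alpha=0 reproduces plain quantization, the search can only match or reduce the output error on the data it sees.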

Scale Estimation and GPTQ methods on Lambada OpenAI dataset

Here are the perplexity and accuracy with data-free and data-aware mixed-precision INT4-INT8 weight compression for different language models on the Lambada OpenAI dataset. The _scale suffix refers to data-aware mixed precision with the Scale Estimation algorithm. The _gptq suffix refers to data-aware mixed precision with the GPTQ algorithm. The _gptq_scale suffix refers to the GPTQ algorithm combined with the Scale Estimation algorithm for calculating the quantization parameters. r100 means that embeddings and lm_head have INT8 precision and all other linear layers have INT4 precision.

| Model | Mode | Acc (↑) | Ppl (↓) |
|---|---|---|---|
| stabilityai_stablelm-2-zephyr-1_6b | fp32 | 0.5925 | 6.3024 |
| | int4_sym_r100_gs64_gptq_scale | 0.5795 | 7.1507 |
| | int4_sym_r100_gs64_gptq | 0.5676 | 7.2391 |
| | int4_sym_r100_gs64_scale | 0.5795 | 7.3245 |
| | int4_sym_r100_gs64 | 0.5465 | 8.649 |
| stable-zephyr-3b-dpo | fp32 | 0.6099 | 6.7151 |
| | int4_sym_r100_gs64_scale | 0.595 | 7.037 |
| | int4_sym_r100_gs64_gptq_scale | 0.5909 | 7.391 |
| | int4_sym_r100_gs64_gptq | 0.567 | 8.6787 |
| | int4_sym_r100_gs64 | 0.5639 | 9.349 |
| microsoft_Phi-3-mini-4k-instruct | fp32 | 0.6839 | 4.1681 |
| | int4_sym_r100_gs128_gptq_scale | 0.6757 | 4.5107 |
| | int4_sym_r100_gs128_scale | 0.6736 | 4.4711 |
| | int4_sym_r100_gs128_gptq | 0.6513 | 4.8365 |
| | int4_sym_r100_gs128 | 0.6342 | 5.3419 |
| mistralai_Mistral-7B-v0.1 | fp32 | 0.7592 | 3.1898 |
| | int4_sym_r100_gs128_scale | 0.7479 | 3.3527 |
| | int4_sym_r100_gs128 | 0.7421 | 3.4932 |

Accuracy/Footprint trade-off

Below are the tables showing the accuracy/footprint trade-off for Qwen/Qwen2-7B and microsoft/Phi-3-mini-4k-instruct compressed with different options.

The compression ratio is defined as the ratio between the size of the fp32 model and the size of the compressed one. Accuracy metrics are measured on 4 tasks: lambada openai, wikitext, winogrande and WWB. The average relative error in the tables below is the mean of the relative errors for each of the four tasks with respect to the metric value of the fp32 model. All int4 models are compressed group-wise with group_size=128 and mode=CompressWeightsMode.INT4_SYM, with a calibration dataset of 128 samples from wikitext-2-v1. The int8 model is compressed with mode=CompressWeightsMode.INT8_ASYM. The following advanced parameters were used for the AWQ, Scale Estimation and Lora Correction algorithms:

AdvancedCompressionParameters(
  awq_params=AdvancedAWQParameters(32, 0.05, 0.0, 1.0, 100),
  scale_estimation_params=AdvancedScaleEstimationParameters(32, 5, 10, -1.0),
  lora_correction_params=AdvancedLoraCorrectionParameters(adapter_rank=<LORA_RANK>)
)
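The compression rates in the tables below can be cross-checked with back-of-the-envelope arithmetic on bits stored per weight. The sketch is an idealized estimate under assumptions that are not taken from NNCF: one fp16 scale per group for int4, one fp16 scale per output channel for int8, no zero points, and no non-weight tensors. The helper names are hypothetical.

```python
def bits_per_weight(weight_bits, group_size, in_features, scale_bits=16, zero_point_bits=0):
    """Average storage bits per weight: the code itself plus per-group parameters amortized over the group."""
    gs = in_features if group_size == -1 else group_size   # -1 means per-channel
    return weight_bits + (scale_bits + zero_point_bits) / gs

def compression_rate(parts, baseline_bits=32):
    """parts: (fraction_of_weights, bits_per_weight) pairs. Returns fp32 size / compressed size."""
    total = sum(n for n, _ in parts)
    compressed = sum(n * bits for n, bits in parts)
    return baseline_bits * total / compressed

int4_g128 = bits_per_weight(4, 128, in_features=4096)      # 4 + 16/128 = 4.125 bits
int8_channel = bits_per_weight(8, -1, in_features=4096)    # 8 + 16/4096 bits
print(round(compression_rate([(1.0, int4_g128)]), 2))                        # 7.76
print(round(compression_rate([(0.8, int4_g128), (0.2, int8_channel)]), 2))   # 6.53
print(round(compression_rate([(1.0, int8_channel)]), 2))                     # 4.0
```

The idealized all-int4 estimate (7.76x) is above the measured 6.7x and 7.4x; one plausible contributor is that embeddings and the last linear layer stay in 8-bit by default, while the int8 estimate is close to the measured 3.9x and 4.0x.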

The tables show the following:

  • Keeping more layers in 8-bit improves accuracy, but it increases the footprint significantly.
  • Scale Estimation, AWQ and GPTQ improve the accuracy of the baseline int4 model without increasing the footprint.
  • The Lora Correction algorithm further improves the accuracy of int4 models, with a much smaller footprint than mixed-precision models of the same or worse accuracy.
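The intuition behind Lora Correction, where a small low-rank adapter compensates for int4 quantization noise at a modest memory cost, can be illustrated with a data-free simplification: approximate the residual W - Q(W) with a rank-r truncated SVD. This is a hypothetical sketch; NNCF's Lora Correction is data-aware and uses calibration activations, so its adapters are not the plain SVD factors computed here.

```python
import numpy as np

def fake_quant_int4_sym(w, group_size=128):
    """Grouped symmetric int4 quantize-dequantize of a [out, in] weight."""
    g = w.reshape(w.shape[0], -1, group_size)
    scale = np.abs(g).max(axis=-1, keepdims=True) / 7.0
    scale = np.where(scale == 0, 1.0, scale)
    return (np.clip(np.round(g / scale), -8, 7) * scale).reshape(w.shape)

def low_rank_correction(w, rank, group_size=128):
    """Return (w_q, a, b) with w ≈ w_q + a @ b, where a: [out, rank] and b: [rank, in]."""
    w_q = fake_quant_int4_sym(w, group_size)
    u, s, vt = np.linalg.svd(w - w_q, full_matrices=False)   # SVD of the quantization residual
    a = u[:, :rank] * s[:rank]          # fold singular values into the left factor
    b = vt[:rank, :]
    return w_q, a, b

rng = np.random.default_rng(0)
w = rng.normal(size=(512, 512))
for rank in (0, 8, 32, 128):
    w_q, a, b = low_rank_correction(w, rank)
    err = np.linalg.norm(w - (w_q + a @ b)) / np.linalg.norm(w)
    print(f"rank={rank:>3}: relative error={err:.4f}, adapter params={a.size + b.size} of {w.size}")
```

The adapter adds rank × (in + out) parameters, about 3% of this layer at rank 8, which illustrates why a low-rank correction is cheaper than keeping whole layers in 8-bit; a plain SVD recovers only part of the quantization noise in this sketch.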

Accuracy/footprint trade-off for Qwen/Qwen2-7B:

| Mode | %int4 | %int8 | LoRA rank | Average relative error | Compression rate |
|---|---|---|---|---|---|
| fp32 | 0% | 0% | | 0.0% | 1.0x |
| int8 | 0% | 100% | | 7.9% | 3.9x |
| int4 + awq + scale estimation + lora correction | 100% | 0% | 256 | 16.5% | 5.8x |
| int4 + awq + scale estimation | 40% | 60% | | 17.1% | 4.7x |
| int4 + awq + scale estimation | 60% | 40% | | 17.1% | 5.2x |
| int4 + awq + scale estimation + lora correction | 100% | 0% | 32 | 17.4% | 6.5x |
| int4 + awq + scale estimation + lora correction | 100% | 0% | 8 | 17.5% | 6.6x |
| int4 + awq + scale estimation | 80% | 20% | | 17.5% | 5.8x |
| int4 + awq + scale estimation + lora correction | 100% | 0% | 16 | 18.0% | 6.6x |
| int4 + awq + scale estimation | 100% | 0% | | 18.4% | 6.7x |
| int4 + awq + scale estimation + gptq | 100% | 0% | | 20.2% | 6.7x |
| int4 | 100% | 0% | | 21.4% | 6.7x |

Accuracy/footprint trade-off for microsoft/Phi-3-mini-4k-instruct:

| Mode | %int4 | %int8 | LoRA rank | Average relative error | Compression rate |
|---|---|---|---|---|---|
| fp32 | 0% | 0% | | 0.0% | 1.0x |
| int8 | 0% | 100% | | 7.3% | 4.0x |
| int4 + scale estimation | 40% | 60% | | 16.9% | 4.9x |
| int4 + scale estimation | 60% | 40% | | 18.4% | 5.5x |
| int4 + scale estimation + lora correction | 100% | 0% | 256 | 18.7% | 6.2x |
| int4 + scale estimation + lora correction | 100% | 0% | 16 | 20.5% | 7.3x |
| int4 + scale estimation + lora correction | 100% | 0% | 32 | 20.6% | 7.2x |
| int4 + scale estimation | 80% | 20% | | 21.3% | 6.3x |
| int4 + scale estimation + gptq | 100% | 0% | | 21.7% | 7.4x |
| int4 + scale estimation + lora correction | 100% | 0% | 8 | 22.1% | 7.3x |
| int4 + scale estimation | 100% | 0% | | 24.5% | 7.4x |
| int4 | 100% | 0% | | 25.3% | 7.4x |

Limitations

  • The algorithm is supported for OpenVINO, PyTorch and Torch FX models.
  • Compression is applied in-place.
  • The compressed model is not trainable.
  • INT4_SYM, INT4_ASYM, NF4 and E2M1 modes, grouped quantization and mixed-precision selection are available for the OpenVINO backend only.
  • NF4 and E2M1 support is experimental on GPU and NPU: models quantized to nf4/e2m1 should not be expected to run faster than models quantized to 8-bit integer.

Additional resources

List of notebooks demonstrating OpenVINO conversion and inference together with NNCF weight compression for models from various domains: