From 15a162824d0c5d8aa7a3d14ab6e9bb07e5732fb6 Mon Sep 17 00:00:00 2001
From: fxmarty <9808326+fxmarty@users.noreply.github.com>
Date: Thu, 7 Dec 2023 15:51:13 +0100
Subject: [PATCH] Fix GPTQ compatibility with AutoGPTQ (#1574)

* fix config saving

* add test
---
 optimum/gptq/constants.py       |  2 +-
 optimum/gptq/quantizer.py       | 32 ++++++++++++++++++--------------
 setup.py                        |  2 +-
 tests/gptq/test_quantization.py | 24 ++++++++++++++++++++++++
 4 files changed, 44 insertions(+), 16 deletions(-)

diff --git a/optimum/gptq/constants.py b/optimum/gptq/constants.py
index 70c2526651..2d3e51da7a 100644
--- a/optimum/gptq/constants.py
+++ b/optimum/gptq/constants.py
@@ -20,4 +20,4 @@
     "model.layers",
 ]
 
-GPTQ_CONFIG = "quantization_config.json"
+GPTQ_CONFIG = "quantize_config.json"
diff --git a/optimum/gptq/quantizer.py b/optimum/gptq/quantizer.py
index 280f5ed14f..1a3d4b8702 100644
--- a/optimum/gptq/quantizer.py
+++ b/optimum/gptq/quantizer.py
@@ -12,7 +12,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import copy
 import json
 import os
 from enum import Enum
@@ -35,7 +34,6 @@
 
 if is_accelerate_available():
     from accelerate import (
-        Accelerator,
         cpu_offload_with_hook,
         load_checkpoint_and_dispatch,
     )
@@ -146,6 +144,17 @@ def __init__(
         self.quant_method = QuantizationMethod.GPTQ
         self.cache_block_outputs = cache_block_outputs
 
+        self.serialization_keys = [
+            "bits",
+            "dataset",
+            "group_size",
+            "damp_percent",
+            "desc_act",
+            "sym",
+            "true_sequential",
+            "quant_method",
+        ]
+
         if self.bits not in [2, 3, 4, 8]:
             raise ValueError("only support quantize to [2,3,4,8] bits.")
         if self.group_size != -1 and self.group_size <= 0:
@@ -169,7 +178,10 @@ def to_dict(self):
         """
         Returns the args in dict format.
         """
-        return copy.deepcopy(self.__dict__)
+        gptq_dict = {}
+        for key in self.serialization_keys:
+            gptq_dict[key] = getattr(self, key)
+        return gptq_dict
 
     @classmethod
     def from_dict(cls, config_dict: Dict[str, Any]):
@@ -600,7 +612,7 @@ def pack_model(
 
         logger.info("Model packed.")
 
-    def save(self, model: nn.Module, save_dir: str, max_shard_size: str = "10GB", safe_serialization: bool = False):
+    def save(self, model: nn.Module, save_dir: str, max_shard_size: str = "10GB", safe_serialization: bool = True):
         """
         Save model state dict and configs
 
@@ -618,20 +630,12 @@ def save(self, model: nn.Module, save_dir: str, max_shard_size: str = "10GB", sa
                 which will be bigger than `max_shard_size`.
-            safe_serialization (`bool`, defaults to `False`):
+            safe_serialization (`bool`, defaults to `True`):
                 Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
         """
-
-        if not is_accelerate_available():
-            raise RuntimeError(
-                "You need to install accelerate in order to save a quantized model. You can do it with `pip install accelerate`"
-            )
-
         os.makedirs(save_dir, exist_ok=True)
-        # save model and config
-        accelerator = Accelerator()
-        accelerator.save_model(model, save_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)
+        model.save_pretrained(save_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)
         with open(os.path.join(save_dir, GPTQ_CONFIG), "w", encoding="utf-8") as f:
             json.dump(self.to_dict(), f, indent=2)
diff --git a/setup.py b/setup.py
index ec9ae3a31f..4c8e88d9fb 100644
--- a/setup.py
+++ b/setup.py
@@ -38,7 +38,7 @@
     "invisible-watermark",
 ]
 
-QUALITY_REQUIRE = ["black~=23.1", "ruff>=0.0.241,<=0.0.259"]
+QUALITY_REQUIRE = ["black~=23.1", "ruff==0.1.5"]
 
 BENCHMARK_REQUIRE = ["optuna", "tqdm", "scikit-learn", "seqeval", "torchvision", "evaluate>=0.2.0"]
 
diff --git a/tests/gptq/test_quantization.py b/tests/gptq/test_quantization.py
index 7f50a57496..a24b3e683c 100644
--- a/tests/gptq/test_quantization.py
+++ b/tests/gptq/test_quantization.py
@@ -23,9 +23,14 @@
 
 from optimum.gptq import GPTQQuantizer, load_quantized_model
 from optimum.gptq.data import get_dataset
+from optimum.utils.import_utils import is_auto_gptq_available
 from optimum.utils.testing_utils import require_accelerate, require_auto_gptq, require_torch_gpu
 
 
+if is_auto_gptq_available():
+    from auto_gptq import AutoGPTQForCausalLM
+
+
 @slow
 @require_auto_gptq
 @require_torch_gpu
@@ -125,7 +130,9 @@ def check_inference_correctness(self, model):
     def test_generate_quality(self):
         self.check_inference_correctness(self.quantized_model)
 
+    @require_torch_gpu
     @require_accelerate
+    @slow
     def test_serialization(self):
         """
         Test the serialization of the model and the loading of the quantized weights
@@ -148,6 +155,11 @@ def test_serialization(self):
                 exllama_config=self.exllama_config,
             )
             self.check_quantized_layers_type(quantized_model_from_saved, "cuda-old")
+
+            with torch.device("cuda"):
+                _ = AutoModelForCausalLM.from_pretrained(tmpdirname)
+                _ = AutoGPTQForCausalLM.from_quantized(tmpdirname)
+
             self.check_inference_correctness(quantized_model_from_saved)
@@ -177,6 +189,7 @@ def test_serialization(self):
         # act_order don't work with qlinear_cuda kernel
         pass
 
+    @require_torch_gpu
     def test_exllama_serialization(self):
         """
         Test the serialization of the model and the loading of the quantized weights with exllama kernel
@@ -195,6 +208,11 @@ def test_exllama_serialization(self):
                 empty_model, save_folder=tmpdirname, device_map={"": 0}, exllama_config={"version": 1}
             )
             self.check_quantized_layers_type(quantized_model_from_saved, "exllama")
+
+            with torch.device("cuda"):
+                _ = AutoModelForCausalLM.from_pretrained(tmpdirname)
+                _ = AutoGPTQForCausalLM.from_quantized(tmpdirname)
+
             self.check_inference_correctness(quantized_model_from_saved)
 
     def test_exllama_max_input_length(self):
@@ -245,6 +263,7 @@ def test_serialization(self):
         # don't need to test
         pass
 
+    @require_torch_gpu
     def test_exllama_serialization(self):
         """
         Test the serialization of the model and the loading of the quantized weights with exllamav2 kernel
@@ -265,6 +284,11 @@ def test_exllama_serialization(self):
                 device_map={"": 0},
             )
             self.check_quantized_layers_type(quantized_model_from_saved, "exllamav2")
+
+            with torch.device("cuda"):
+                _ = AutoModelForCausalLM.from_pretrained(tmpdirname)
+                _ = AutoGPTQForCausalLM.from_quantized(tmpdirname)
+
             self.check_inference_correctness(quantized_model_from_saved)
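
For illustration, a minimal sketch (not part of the patch) of the cross-loading the added tests exercise: a folder written by the patched GPTQQuantizer.save(), which now contains quantize_config.json and safetensors weights, should be readable by both the transformers and the AutoGPTQ entry points. The save_dir path below is a placeholder, and an installed auto-gptq plus a CUDA device are assumed.

import torch
from transformers import AutoModelForCausalLM

from auto_gptq import AutoGPTQForCausalLM

# Placeholder: directory previously written by GPTQQuantizer.save() for some quantized model.
save_dir = "./my-model-gptq"

# After this patch, the same checkpoint loads through both libraries.
with torch.device("cuda"):
    model_transformers = AutoModelForCausalLM.from_pretrained(save_dir)
    model_autogptq = AutoGPTQForCausalLM.from_quantized(save_dir)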