Fix GPTQ compatibility with AutoGPTQ (#1574)
* fix config saving

* add test
fxmarty authored Dec 7, 2023
1 parent 74c7f22 commit 15a1628
Showing 4 changed files with 44 additions and 16 deletions.
optimum/gptq/constants.py (2 changes: 1 addition & 1 deletion)
@@ -20,4 +20,4 @@
"model.layers",
]

GPTQ_CONFIG = "quantization_config.json"
GPTQ_CONFIG = "quantize_config.json"
optimum/gptq/quantizer.py (32 changes: 18 additions & 14 deletions)
@@ -12,7 +12,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import copy
import json
import os
from enum import Enum
@@ -35,7 +34,6 @@

if is_accelerate_available():
from accelerate import (
-Accelerator,
cpu_offload_with_hook,
load_checkpoint_and_dispatch,
)
@@ -146,6 +144,17 @@ def __init__(
self.quant_method = QuantizationMethod.GPTQ
self.cache_block_outputs = cache_block_outputs

+self.serialization_keys = [
+    "bits",
+    "dataset",
+    "group_size",
+    "damp_percent",
+    "desc_act",
+    "sym",
+    "true_sequential",
+    "quant_method",
+]

if self.bits not in [2, 3, 4, 8]:
raise ValueError("only support quantize to [2,3,4,8] bits.")
if self.group_size != -1 and self.group_size <= 0:
@@ -169,7 +178,10 @@ def to_dict(self):
"""
Returns the args in dict format.
"""
-return copy.deepcopy(self.__dict__)
+gptq_dict = {}
+for key in self.serialization_keys:
+    gptq_dict[key] = getattr(self, key)
+return gptq_dict

@classmethod
def from_dict(cls, config_dict: Dict[str, Any]):
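With the explicit serialization_keys list, to_dict() now emits only the fields that an AutoGPTQ-style quantize_config.json consumer expects, rather than a deep copy of every attribute on the quantizer. A quick way to see what will be written, as a hedged sketch with default settings (the printed values are approximate, not taken from the commit):

import json

from optimum.gptq import GPTQQuantizer

# Constructing the quantizer does not need a GPU; this only inspects
# the dict that save() will dump into quantize_config.json.
quantizer = GPTQQuantizer(bits=4, dataset="c4", group_size=128)
print(json.dumps(quantizer.to_dict(), indent=2))
# roughly: {"bits": 4, "dataset": "c4", "group_size": 128, "damp_percent": 0.1,
#           "desc_act": false, "sym": true, "true_sequential": true,
#           "quant_method": "gptq"}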
@@ -600,7 +612,7 @@ def pack_model(

logger.info("Model packed.")

-def save(self, model: nn.Module, save_dir: str, max_shard_size: str = "10GB", safe_serialization: bool = False):
+def save(self, model: nn.Module, save_dir: str, max_shard_size: str = "10GB", safe_serialization: bool = True):
"""
Save model state dict and configs
@@ -618,20 +630,12 @@ def save(self, model: nn.Module, save_dir: str, max_shard_size: str = "10GB", sa
which will be bigger than `max_shard_size`.
</Tip>
-safe_serialization (`bool`, defaults to `False`):
+safe_serialization (`bool`, defaults to `True`):
Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
"""

-if not is_accelerate_available():
-    raise RuntimeError(
-        "You need to install accelerate in order to save a quantized model. You can do it with `pip install accelerate`"
-    )

os.makedirs(save_dir, exist_ok=True)
# save model and config
-accelerator = Accelerator()
-accelerator.save_model(model, save_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)
+model.save_pretrained(save_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)
with open(os.path.join(save_dir, GPTQ_CONFIG), "w", encoding="utf-8") as f:
json.dump(self.to_dict(), f, indent=2)
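Taken together, save() now defers sharding and safetensors handling to the model's own save_pretrained and only writes the config file itself, which is why the accelerate import and runtime check above could be dropped. A rough end-to-end usage sketch (the model id and output directory are placeholders, and the quantization step itself still needs a GPU and calibration data):

from transformers import AutoModelForCausalLM, AutoTokenizer

from optimum.gptq import GPTQQuantizer

model_id = "facebook/opt-125m"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto")

quantizer = GPTQQuantizer(bits=4, dataset="c4", group_size=128)
quantized_model = quantizer.quantize_model(model, tokenizer)

# Writes the (possibly sharded) .safetensors weights plus quantize_config.json
quantizer.save(quantized_model, "opt-125m-gptq")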

setup.py (2 changes: 1 addition & 1 deletion)
@@ -38,7 +38,7 @@
"invisible-watermark",
]

-QUALITY_REQUIRE = ["black~=23.1", "ruff>=0.0.241,<=0.0.259"]
+QUALITY_REQUIRE = ["black~=23.1", "ruff==0.1.5"]

BENCHMARK_REQUIRE = ["optuna", "tqdm", "scikit-learn", "seqeval", "torchvision", "evaluate>=0.2.0"]

tests/gptq/test_quantization.py (24 changes: 24 additions & 0 deletions)
@@ -23,9 +23,14 @@

from optimum.gptq import GPTQQuantizer, load_quantized_model
from optimum.gptq.data import get_dataset
+from optimum.utils.import_utils import is_auto_gptq_available
from optimum.utils.testing_utils import require_accelerate, require_auto_gptq, require_torch_gpu


+if is_auto_gptq_available():
+    from auto_gptq import AutoGPTQForCausalLM


@slow
@require_auto_gptq
@require_torch_gpu
@@ -125,7 +130,9 @@ def check_inference_correctness(self, model):
def test_generate_quality(self):
self.check_inference_correctness(self.quantized_model)

@require_torch_gpu
@require_accelerate
@slow
def test_serialization(self):
"""
Test the serialization of the model and the loading of the quantized weights
@@ -148,6 +155,11 @@ def test_serialization(self):
exllama_config=self.exllama_config,
)
self.check_quantized_layers_type(quantized_model_from_saved, "cuda-old")

+with torch.device("cuda"):
+    _ = AutoModelForCausalLM.from_pretrained(tmpdirname)
+    _ = AutoGPTQForCausalLM.from_quantized(tmpdirname)

self.check_inference_correctness(quantized_model_from_saved)
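The two loads added here are the actual compatibility check: a checkpoint written by optimum should now open both through transformers and through AutoGPTQ itself. Outside the test suite the round trip looks roughly like this, assuming a directory produced by GPTQQuantizer.save (the path is a placeholder):

import torch
from transformers import AutoModelForCausalLM

from auto_gptq import AutoGPTQForCausalLM

save_dir = "opt-125m-gptq"  # placeholder path

with torch.device("cuda"):
    # transformers picks up the quantization settings saved with the model
    model_hf = AutoModelForCausalLM.from_pretrained(save_dir)
    # AutoGPTQ reads quantize_config.json directly
    model_autogptq = AutoGPTQForCausalLM.from_quantized(save_dir)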


@@ -177,6 +189,7 @@ def test_serialization(self):
# act_order don't work with qlinear_cuda kernel
pass

+@require_torch_gpu
def test_exllama_serialization(self):
"""
Test the serialization of the model and the loading of the quantized weights with exllama kernel
@@ -195,6 +208,11 @@ def test_exllama_serialization(self):
empty_model, save_folder=tmpdirname, device_map={"": 0}, exllama_config={"version": 1}
)
self.check_quantized_layers_type(quantized_model_from_saved, "exllama")

+with torch.device("cuda"):
+    _ = AutoModelForCausalLM.from_pretrained(tmpdirname)
+    _ = AutoGPTQForCausalLM.from_quantized(tmpdirname)

self.check_inference_correctness(quantized_model_from_saved)

def test_exllama_max_input_length(self):
@@ -245,6 +263,7 @@ def test_serialization(self):
# don't need to test
pass

+@require_torch_gpu
def test_exllama_serialization(self):
"""
Test the serialization of the model and the loading of the quantized weights with exllamav2 kernel
@@ -265,6 +284,11 @@ def test_exllama_serialization(self):
device_map={"": 0},
)
self.check_quantized_layers_type(quantized_model_from_saved, "exllamav2")

+with torch.device("cuda"):
+    _ = AutoModelForCausalLM.from_pretrained(tmpdirname)
+    _ = AutoGPTQForCausalLM.from_quantized(tmpdirname)

self.check_inference_correctness(quantized_model_from_saved)

