From 6539d506e25920884984025233cadd5d34a7be93 Mon Sep 17 00:00:00 2001
From: wenhuach21
Date: Wed, 18 Sep 2024 15:04:00 +0800
Subject: [PATCH] refine autoawq exporting code (#261)

---
 .../export/export_to_autogptq/export.py   |  2 +-
 auto_round/export/export_to_awq/export.py | 83 +++++++++++++------
 2 files changed, 57 insertions(+), 28 deletions(-)

diff --git a/auto_round/export/export_to_autogptq/export.py b/auto_round/export/export_to_autogptq/export.py
index f051304a..7f8a6800 100644
--- a/auto_round/export/export_to_autogptq/export.py
+++ b/auto_round/export/export_to_autogptq/export.py
@@ -37,7 +37,7 @@
 import torch
 
 from auto_round.utils import check_to_quantized, get_block_names, \
-    get_module, logger, get_layer_names_in_block, set_module
+    get_module, logger, set_module
 import copy
 import json
 import os
diff --git a/auto_round/export/export_to_awq/export.py b/auto_round/export/export_to_awq/export.py
index 00de4a5c..83f6824e 100644
--- a/auto_round/export/export_to_awq/export.py
+++ b/auto_round/export/export_to_awq/export.py
@@ -35,10 +35,46 @@
 import torch
 import torch.nn as nn
 from auto_round.export.register import register_format
-from auto_round.utils import convert_dtype_torch2str_hf, logger
+from auto_round.utils import convert_dtype_torch2str_hf, logger, get_module, set_module
 import copy
 import json
 from typing import Dict, List, Optional, Union
+from .utils import WQLinear_GEMM, clear_memory, get_self_modules
+from concurrent.futures import ThreadPoolExecutor
+import threadpoolctl as tctl
+from tqdm import tqdm
+
+
+def pack_layer(name, model, layer_config, backend, pbar):
+    with tctl.threadpool_limits(limits=1):
+        pbar.set_description(f"packing {name}")
+        if name == "lm_head":  ## does not support lm-head
+            pbar.update(1)
+            return
+        config = layer_config[name]
+        if config["bits"] > 8:
+            pbar.update(1)
+            return
+        scale, zp = config["scale"], config["zp"]
+        scale = scale.t().contiguous()
+        zp = zp.t().contiguous()
+        config["zp"] = config["zp"].to(torch.float32)
+        bits = config["bits"]
+        group_size = config["group_size"]
+        linear_layer = get_module(model, name)
+        q_linear = WQLinear_GEMM.from_linear(
+            linear=linear_layer,
+            w_bit=bits,
+            group_size=group_size,
+            init_only=False,
+            scales=scale,
+            zeros=zp,
+        )
+        linear_layer.cpu()
+        q_linear.to("cpu")
+        set_module(model, name, q_linear)
+        clear_memory()
+        pbar.update(1)
 
 
 @register_format("auto_awq")
@@ -67,36 +103,30 @@ def save_quantized_as_autoawq(output_dir, inplace=True, **kwargs):
     else:
         compressed_model = copy.deepcopy(model.to("cpu"))
 
-    from .utils import WQLinear_GEMM, clear_memory, get_self_modules
+    names = list(layer_config.keys())
 
-    q_linear_module = WQLinear_GEMM
     self_modules = get_self_modules(compressed_model)
+    layers = []
     for i in range(len(self_modules)):
         module = self_modules[i]
         named_linears = get_named_linears(module)
         for name, linear_layer in named_linears.items():
             key = get_module_name(compressed_model, linear_layer)
-            logger.info(f"packing {name}")
+            layers.append(key)
             config = layer_config[key]
             if config["bits"] > 8:
                 modules_to_not_convert.append(name)
-                continue
-            config["zp"] = config["zp"].to(torch.float32)
-            scale, zp = config["scale"], config["zp"]
-            scale = scale.t().contiguous()
-            zp = zp.t().contiguous()
-            q_linear = q_linear_module.from_linear(
-                linear=linear_layer,
-                w_bit=bits,
-                group_size=group_size,
-                init_only=False,
-                scales=scale,
-                zeros=zp,
-            )
-            linear_layer.cpu()
-            q_linear.to(next(module.parameters()).device)
-            set_op_by_name(module, name, q_linear)
-            clear_memory()
+
+    backend = None
+    with ThreadPoolExecutor(max_workers=2) as executor:
+        with tqdm(total=len(names), leave=True) as pbar:
+            def wrapper(name):
+                pack_layer(name, model, layer_config, backend, pbar)
+
+            for _ in executor.map(wrapper, names):
+                pass
+    if output_dir is None:
+        return model
 
     quant_config = {}
     quant_config["quant_method"] = "awq"
@@ -123,11 +153,11 @@ def save_quantized_as_autoawq(output_dir, inplace=True, **kwargs):
 
 
 def save_quantized(
-    model,
-    save_dir,
-    quant_config,
-    safetensors=True,
-    shard_size="5GB",
+        model,
+        save_dir,
+        quant_config,
+        safetensors=True,
+        shard_size="5GB",
 ):
     save_dir = save_dir[:-1] if save_dir[-1] == "/" else save_dir
 
@@ -220,4 +250,3 @@ def get_module_name(model, module_to_find):
         if module is module_to_find:
             return name
     return None
-
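
Reviewer note, not part of the patch: a minimal sketch of how this export path is typically exercised end to end. Only the "auto_awq" format name is confirmed by the register_format("auto_awq") hook above; the AutoRound constructor arguments, the save_quantized signature, and the example model name are assumptions for illustration.

    # Hedged usage sketch under the assumptions stated above; not part of this diff.
    from transformers import AutoModelForCausalLM, AutoTokenizer

    from auto_round import AutoRound  # assumed public entry point that builds layer_config

    model_name = "facebook/opt-125m"  # hypothetical example model
    model = AutoModelForCausalLM.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # bits/group_size end up in the per-layer config that pack_layer() reads.
    autoround = AutoRound(model, tokenizer, bits=4, group_size=128)
    autoround.quantize()

    # "auto_awq" matches the @register_format("auto_awq") registration in this patch,
    # so exporting in that format dispatches to save_quantized_as_autoawq above.
    autoround.save_quantized("./opt-125m-awq", format="auto_awq")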