refine autoawq exporting code (#261)
wenhuach21 authored Sep 18, 2024
1 parent 7816eea commit 6539d50
Showing 2 changed files with 57 additions and 28 deletions.
2 changes: 1 addition & 1 deletion auto_round/export/export_to_autogptq/export.py
@@ -37,7 +37,7 @@
import torch

from auto_round.utils import check_to_quantized, get_block_names, \
    get_module, logger, get_layer_names_in_block, set_module
    get_module, logger, set_module
import copy
import json
import os
83 changes: 56 additions & 27 deletions auto_round/export/export_to_awq/export.py
@@ -35,10 +35,46 @@
import torch
import torch.nn as nn
from auto_round.export.register import register_format
from auto_round.utils import convert_dtype_torch2str_hf, logger
from auto_round.utils import convert_dtype_torch2str_hf, logger, get_module, set_module
import copy
import json
from typing import Dict, List, Optional, Union
from .utils import WQLinear_GEMM, clear_memory, get_self_modules
from concurrent.futures import ThreadPoolExecutor
import threadpoolctl as tctl
from tqdm import tqdm


def pack_layer(name, model, layer_config, backend, pbar):
    """Pack one quantized linear layer into a WQLinear_GEMM module and swap it into the model."""
    with tctl.threadpool_limits(limits=1):
        pbar.set_description(f"packing {name}")
        if name == "lm_head":  ## lm_head is not supported
            pbar.update(1)
            return
        config = layer_config[name]
        if config["bits"] > 8:  ## layers configured above 8 bits are not packed
            pbar.update(1)
            return
        config["zp"] = config["zp"].to(torch.float32)  ## cast zero points before packing
        scale, zp = config["scale"], config["zp"]
        scale = scale.t().contiguous()
        zp = zp.t().contiguous()
        bits = config["bits"]
        group_size = config["group_size"]
        linear_layer = get_module(model, name)
        q_linear = WQLinear_GEMM.from_linear(
            linear=linear_layer,
            w_bit=bits,
            group_size=group_size,
            init_only=False,
            scales=scale,
            zeros=zp,
        )
        linear_layer.cpu()
        q_linear.to("cpu")
        set_module(model, name, q_linear)
        clear_memory()
        pbar.update(1)


@register_format("auto_awq")
@@ -67,36 +103,30 @@ def save_quantized_as_autoawq(output_dir, inplace=True, **kwargs):
    else:
        compressed_model = copy.deepcopy(model.to("cpu"))

    from .utils import WQLinear_GEMM, clear_memory, get_self_modules
    names = list(layer_config.keys())

    q_linear_module = WQLinear_GEMM
    self_modules = get_self_modules(compressed_model)
    layers = []
    for i in range(len(self_modules)):
        module = self_modules[i]
        named_linears = get_named_linears(module)
        for name, linear_layer in named_linears.items():
            key = get_module_name(compressed_model, linear_layer)
            logger.info(f"packing {name}")
            layers.append(key)
            config = layer_config[key]
            if config["bits"] > 8:
                modules_to_not_convert.append(name)
                continue
            config["zp"] = config["zp"].to(torch.float32)
            scale, zp = config["scale"], config["zp"]
            scale = scale.t().contiguous()
            zp = zp.t().contiguous()
            q_linear = q_linear_module.from_linear(
                linear=linear_layer,
                w_bit=bits,
                group_size=group_size,
                init_only=False,
                scales=scale,
                zeros=zp,
            )
            linear_layer.cpu()
            q_linear.to(next(module.parameters()).device)
            set_op_by_name(module, name, q_linear)
            clear_memory()

    backend = None
    ## pack layers with two worker threads; pack_layer caps native BLAS/OpenMP threads per worker
    with ThreadPoolExecutor(max_workers=2) as executor:
        with tqdm(total=len(names), leave=True) as pbar:
            def wrapper(name):
                pack_layer(name, model, layer_config, backend, pbar)

            for _ in executor.map(wrapper, names):
                pass
    if output_dir is None:
        return model

    quant_config = {}
    quant_config["quant_method"] = "awq"
@@ -123,11 +153,11 @@ def save_quantized_as_autoawq(output_dir, inplace=True, **kwargs):


def save_quantized(
model,
save_dir,
quant_config,
safetensors=True,
shard_size="5GB",
model,
save_dir,
quant_config,
safetensors=True,
shard_size="5GB",
):
save_dir = save_dir[:-1] if save_dir[-1] == "/" else save_dir

@@ -220,4 +250,3 @@ def get_module_name(model, module_to_find):
        if module is module_to_find:
            return name
    return None
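
For readers skimming the diff, here is a minimal standalone sketch of the threaded packing pattern the commit introduces: a two-worker ThreadPoolExecutor maps a per-layer function over the layer names, threadpoolctl pins each worker's native BLAS/OpenMP pools to a single thread, and a shared tqdm bar reports progress. The do_pack body and the layer names are illustrative placeholders (not this commit's code), and the snippet assumes threadpoolctl and tqdm are installed.

from concurrent.futures import ThreadPoolExecutor

import threadpoolctl as tctl
from tqdm import tqdm


def do_pack(name, pbar):
    ## placeholder for per-layer packing work (e.g. WQLinear_GEMM.from_linear)
    with tctl.threadpool_limits(limits=1):
        pbar.set_description(f"packing {name}")
        pbar.update(1)


names = [f"model.layers.{i}.self_attn.q_proj" for i in range(4)]  ## hypothetical layer names
with ThreadPoolExecutor(max_workers=2) as executor:
    with tqdm(total=len(names), leave=True) as pbar:
        for _ in executor.map(lambda name: do_pack(name, pbar), names):
            pass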
