Commit

align config
Signed-off-by: yiliu30 <[email protected]>
yiliu30 committed Jun 19, 2024
1 parent a545ec1 commit c7cae86
Showing 6 changed files with 54 additions and 36 deletions.
17 changes: 12 additions & 5 deletions auto_round/autoround.py
@@ -43,9 +43,9 @@
logger,
sampling_inputs,
to_device, get_layer_names_in_block,
ORIGIN_LINEAR,
)

from auto_round import config as ar_config
class AutoRound(object):
"""This is Signround+ which is an advanced version of Signround. For more information,
please refer to Cheng, Wenhua, et al. "Optimize weight rounding via signed gradient descent
@@ -92,6 +92,7 @@ class AutoRound(object):
data_type (str): The data type to be used (default is "int").
scale_dtype (str): The data type of quantization scale to be used (default is "float32"), different kernels
have different choices.
enable_teq (bool): Whether to enable weight TEQ (Trainable Equivalent Transformation) (default is False).
Returns:
The quantized model.
@@ -127,6 +128,7 @@ def __init__(
dynamic_max_gap: int = -1,
data_type: str = "int", ##only support int for now
scale_dtype: str = "fp16",
enable_teq: bool = False,
**kwargs,
):
self.quantized = False
@@ -208,10 +210,12 @@ def __init__(
self.serialization_dict["autoround_version"] = __version__
if "scale_dtype" in self.serialization_dict.keys():
self.serialization_dict["scale_dtype"] = str(self.serialization_dict["scale_dtype"])

if is_optimum_habana_available():
logger.info("Optimum Habana is available, import htcore explicitly.")
import habana_frameworks.torch.core as htcore # pylint: disable=E0401
import habana_frameworks.torch.hpu as hthpu # pylint: disable=E0401
self.enable_teq = enable_teq

def check_configs(self):
"""Checks if the configurations are valid.
@@ -285,6 +289,9 @@ def quantize(self):
unquantized_layers = []
for n, m in self.model.named_modules():
if isinstance(m, tuple(self.supported_types)):
# For TEQ, `Linear` is replaced with `MulLinear`; strip the appended `ORIGIN_LINEAR` suffix so the name matches its entry in weight_config.
if self.enable_teq and n.endswith(ORIGIN_LINEAR):
n = n.replace("." + ORIGIN_LINEAR, "")
if self.weight_config[n]["bits"] == 16:
unquantized_layers.append(n)
else:
@@ -829,7 +836,7 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch
if q_input is not None:
input_ids = q_input
torch.cuda.empty_cache()
quantized_layer_names, unquantized_layer_names = wrapper_block(block, self.enable_minmax_tuning)
quantized_layer_names, unquantized_layer_names = wrapper_block(block, self.enable_minmax_tuning, self.enable_teq)

round_params = []
minmax_params = []
@@ -847,7 +854,7 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch
else:
trainable_params.append({"params":round_params})

if ar_config.layer_equalization_transform:
if self.enable_teq:
import auto_round.scale as scale_utils
leq_params_lst = scale_utils.get_scale_param_from_block(block)
trainable_params.append({"params": leq_params_lst})
@@ -930,15 +937,15 @@ def quant_block(self, block, input_ids, input_others, q_input=None, device=torch
best_v = collect_round_v(block)
best_min_scale, best_max_scale = collect_minmax_scale(block)
best_leq_weight_scales = None
if ar_config.layer_equalization_transform:
if self.enable_teq:
best_leq_weight_scales = collect_weight_scale(block)
last_best_iter = i
logger.info(f"get better result at iter {i}, the loss is {total_loss}")
if self.not_use_best_mse and i == self.iters - 1:
best_v = collect_round_v(block)
best_min_scale, best_max_scale = collect_minmax_scale(block)
best_leq_weight_scales = None
if ar_config.layer_equalization_transform:
if self.enable_teq:
best_leq_weight_scales = collect_weight_scale(block)
logger.info(f"get better result at last iter {i}, the loss is {total_loss}")

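For context on what the new `enable_teq` flag turns on: TEQ (Trainable Equivalent Transformation) learns a per-input-channel scale that is folded into the weight and divided out of the activation, so the layer output is unchanged while the weight distribution becomes easier to quantize. Below is a minimal, self-contained sketch of that equivalence; the tensor names are illustrative and not part of auto_round.

```python
import torch

# Sketch of the TEQ equivalence, assuming a per-input-channel scale `s`:
# y = (x / s) @ (W * s).T == x @ W.T, so the transform is output-preserving
# while reshaping the weight range that the quantizer later sees.
torch.manual_seed(0)
x = torch.randn(4, 8)                        # activations: (batch, in_features)
linear = torch.nn.Linear(8, 16, bias=False)
s = torch.nn.Parameter(torch.rand(8) + 0.5)  # trainable equalization scale

y_ref = linear(x)
y_teq = torch.nn.functional.linear(x / s, linear.weight * s)
assert torch.allclose(y_ref, y_teq, atol=1e-5)
```

In this commit the scale is learned during block tuning (see the new `self.enable_teq` branches in `quant_block`) and the best scale is folded into the layer when it is unwrapped.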
24 changes: 14 additions & 10 deletions auto_round/quantizer.py
@@ -20,9 +20,9 @@
get_scale_shape,
set_module,
logger,
ORIGIN_LINEAR,
)
from typing import Optional, Dict
import auto_round.config as ar_config
def round_ste(x: torch.Tensor):
"""Straight-Through Estimator for rounding.
This function is adapted from omniquant.
@@ -193,7 +193,7 @@ def quant_weight(


class WrapperLinear(torch.nn.Module):
def __init__(self, orig_layer, enable_minmax_tuning=True):
def __init__(self, orig_layer, enable_minmax_tuning=True, enable_teq=False):
"""A wrapper module for linear layers that enables quantization and min-max tuning of weights.
Args:
@@ -234,12 +234,13 @@ def __init__(self, orig_layer, enable_minmax_tuning=True):
else:
self.min_scale = torch.tensor(1.0, device=self.orig_layer.weight.device, dtype=weight_dtype)
self.max_scale = torch.tensor(1.0, device=self.orig_layer.weight.device, dtype=weight_dtype)

if ar_config.layer_equalization_transform:

self.enable_teq = enable_teq
if self.enable_teq:
from auto_round import scale
self.weight_scale_calculator = scale.ScaleCalculator(self.orig_layer.weight.data, self.orig_layer.weight.device)

def unwrapper(self, v, min_scale, max_scale, leq_weight_scale):
def unwrapper(self, v, min_scale, max_scale, leq_weight_scale=None):
"""Unwrapper the layer to the original layer.
Args:
@@ -252,14 +253,17 @@ def unwrapper(self, v, min_scale, max_scale, leq_weight_scale):
- torch.nn.Module: The original linear layer with updated weights after quantization and dequantization.
"""

if ar_config.layer_equalization_transform:
if leq_weight_scale is not None:
assert self.enable_teq, f"enable_teq is False, but got leq_weight_scale {leq_weight_scale}"
logger.warning(f"Layer equalization transform is enabled for {self.orig_layer}")
assert leq_weight_scale is not None, "leq_weight_scale is required for layer equalization transform"
from .scale import replace_linear_with_smoothed_linear
from auto_round.scale import replace_linear_with_smoothed_linear
logger.debug(f"Replace {self.orig_layer} with `MulLinear`")
logger.debug(f"The range of original layer weight: {self.orig_layer.weight.min()} - {self.orig_layer.weight.max()}")
self.orig_layer = replace_linear_with_smoothed_linear(self.orig_layer, leq_weight_scale)
logger.debug(f"The range of new layer weight: {self.orig_layer.linear.weight.min()} - {self.orig_layer.linear.weight.max()}")
weight = getattr(self.orig_layer, ORIGIN_LINEAR).weight
logger.debug(f"The range of new layer weight: {weight.min()} - {weight.max()}")

min_scale.clamp_(0, 1.0)
max_scale.clamp_(0, 1.0)
@@ -430,7 +434,7 @@ def forward(self, x, **kwargs):
return hidden_states


def wrapper_block(block, enable_minmax_tuning):
def wrapper_block(block, enable_minmax_tuning, enable_teq=False):
"""Wraps the layers in the given block with a custom Wrapper module.
Args:
@@ -447,7 +451,7 @@ def wrapper_block(block, enable_minmax_tuning):
if not check_to_quantized(m):
unquantized_layers.append(n)
continue
new_m = WrapperLinear(m, enable_minmax_tuning=enable_minmax_tuning)
new_m = WrapperLinear(m, enable_minmax_tuning=enable_minmax_tuning, enable_teq=enable_teq)
set_module(block, n, new_m)
quantized_layers.append(n)

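The reason `quantize()` in autoround.py now strips a suffix from layer names: after unwrapping with a TEQ scale, the original `Linear` is nested inside the `MulLinear` wrapper under the attribute named by `ORIGIN_LINEAR`, so `named_modules()` reports it with that suffix appended. A hypothetical stand-in (not the project's `MulLinear`) makes the naming effect visible:

```python
import torch

ORIGIN_LINEAR = "_ORIGIN_LINEAR"  # mirrors the constant added to auto_round/utils.py

# Hypothetical wrapper, shown only to illustrate why layer names gain a
# `_ORIGIN_LINEAR` suffix once the original Linear becomes a child module.
class WrappedLinear(torch.nn.Module):
    def __init__(self, linear):
        super().__init__()
        self.add_module(ORIGIN_LINEAR, linear)   # nest the original Linear as a child

    def forward(self, x):
        return getattr(self, ORIGIN_LINEAR)(x)

block = torch.nn.Sequential()
block.add_module("fc", WrappedLinear(torch.nn.Linear(8, 8)))

for name, module in block.named_modules():
    if isinstance(module, torch.nn.Linear):
        print(name)                                    # fc._ORIGIN_LINEAR
        print(name.replace("." + ORIGIN_LINEAR, ""))   # fc, the key weight_config knows
```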
22 changes: 7 additions & 15 deletions auto_round/scale.py
@@ -12,18 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# [x] insert scale calculator ar `WrapperLinear`
# [x] `init` insert `self.input_scale_calculator = ScaleCalculatorV(module.in_features, module.weight.device)`
# [x] add parameter of `self.input_scale_calculator` into optimizer
# [x] `forward` transform `input` and `weight`
# [x] at the `unwrapper` stage, replace the original `Linear` with `MulLinear`
# [x] use the best weight scale instead of the final weight scale
# [ ] save and export



import torch
from .utils import logger
from auto_round import utils


def get_scale_param_from_block(block: torch.nn.Module):
scale_params = []
@@ -58,10 +50,10 @@ def __init__(self, module, weight_scale=None):
if weight_scale is None:
weight_scale = torch.ones(module.in_features)
self.register_buffer("weight_scale", weight_scale)
logger.info(f"Original module weight shape: {module.weight.shape}.")
utils.logger.info(f"Original module weight shape: {module.weight.shape}.")
module.weight *= weight_scale.reshape(1, -1)
self.add_module("linear", module)
logger.info(f"MulLinear: {module} has been wrapped as `MulLinear`.")
self.add_module(utils.ORIGIN_LINEAR, module)
utils.logger.info(f"MulLinear: {module} has been wrapped as `MulLinear`.")

def forward(self, X):
updated_x = _transform_input(X, self.weight_scale)
@@ -89,8 +81,8 @@ def bias(self, bias):

def replace_linear_with_smoothed_linear(module, weight_scale):
from .scale import MulLinear
logger.info(f"Replace {module} with `MulLinear`.")
logger.info(f"weight_scale shape: {weight_scale.shape}, weight scale min: {weight_scale.min()}, weight scale max: {weight_scale.max()}")
utils.logger.info(f"Replace {module} with `MulLinear`.")
utils.logger.info(f"weight_scale shape: {weight_scale.shape}, weight scale min: {weight_scale.min()}, weight scale max: {weight_scale.max()}")
return MulLinear(module, weight_scale)


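During block tuning the learned equalization scales get their own optimizer parameter group alongside the rounding values (see `get_scale_param_from_block` and the extra `trainable_params.append` in `quant_block`). The project's `ScaleCalculator` is not part of this diff, so the sketch below uses a toy scale holder purely to illustrate the gathering pattern:

```python
import torch

# Toy stand-in for a scale-calculator module; the real auto_round.scale.ScaleCalculator
# is not shown in this commit, so only the parameter-gathering pattern is illustrated.
class ToyScaleCalculator(torch.nn.Module):
    def __init__(self, in_features):
        super().__init__()
        self.scale = torch.nn.Parameter(torch.ones(in_features))

def get_scale_params(block):
    # Collect every parameter that lives under a scale-calculator child module.
    return [p for name, p in block.named_parameters() if "scale_calculator" in name]

block = torch.nn.Sequential(torch.nn.Linear(8, 8))
block[0].add_module("weight_scale_calculator", ToyScaleCalculator(8))

round_params = [torch.nn.Parameter(torch.zeros(8, 8))]        # placeholder rounding values
trainable_params = [{"params": round_params}]
trainable_params.append({"params": get_scale_params(block)})  # extra group, as in quant_block
optimizer = torch.optim.SGD(trainable_params, lr=1e-2)
```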
1 change: 1 addition & 0 deletions auto_round/utils.py
@@ -44,6 +44,7 @@ def warning_once(self, msg: str):
import importlib
import transformers

ORIGIN_LINEAR = "_ORIGIN_LINEAR"
class LazyImport(object):
"""Lazy import python module till use."""

4 changes: 3 additions & 1 deletion examples/language-modeling/eval_legacy/evaluation.py
@@ -209,7 +209,9 @@ def simple_evaluate(

def eval_model(model_path=None, tasks=["lambada_openai", "hellaswag", "winogrande", "piqa"],
eval_bs=32, use_accelerate=True, dtype=None, limit=None, trust_remote_code=True,
device="cuda:0", seed=0, nsamples=128, mark="paper", excel_file="tmp.xlsx"):
device="cuda:0", seed=0, nsamples=128, mark="paper", excel_file="tmp.xlsx",
model_tokenizer_pairs=None,
):
print("evaluation with official lm-eval", flush=True)
try:
import lm_eval
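The new `model_tokenizer_pairs` argument lets the caller hand an already-instantiated model and tokenizer to the evaluator instead of a path, which main.py relies on when TEQ is enabled. Its handling is not shown in this hunk; the helper below is one hypothetical way such an argument could be honored and is not the project's code:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical resolution step: reuse in-memory objects when provided,
# otherwise fall back to loading from the given path.
def resolve_model_and_tokenizer(model_path, model_tokenizer_pairs=None):
    if model_tokenizer_pairs is not None:
        model, tokenizer = model_tokenizer_pairs
        return model, tokenizer
    model = AutoModelForCausalLM.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    return model, tokenizer
```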
22 changes: 17 additions & 5 deletions examples/language-modeling/main.py
@@ -123,7 +123,8 @@

parser.add_argument("--model_dtype", default=None, type=str,
help="force to convert the dtype, some backends supports fp16 dtype better")

parser.add_argument("--enable_teq", action='store_true',
help="whether to enable teqg")
args = parser.parse_args()
if args.low_gpu_mem_usage:
print(
@@ -308,7 +309,9 @@ def get_library_version(library_name):
low_gpu_mem_usage=not args.disable_low_gpu_mem_usage,
seed=args.seed, gradient_accumulate_steps=args.gradient_accumulate_steps,
scale_dtype=args.scale_dtype, weight_config=weight_config,
enable_minmax_tuning=not args.disable_minmax_tuning)
enable_minmax_tuning=not args.disable_minmax_tuning,
enable_teq=args.enable_teq,
)
model, _ = autoround.quantize()
model_name = args.model_name.rstrip("/")

@@ -333,13 +336,22 @@ def get_library_version(library_name):
model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)


if not args.disable_eval and "fake" in deployment_device and lm_eval_version != "0.4.2":
excel_name = f"{output_dir}_result.xlsx"
output_dir += "/"
print(excel_name, flush=True)
eval_model(model_path=output_dir, tasks=tasks, dtype=dtype, limit=None,
eval_bs=args.eval_bs, use_accelerate=not args.disable_low_gpu_mem_usage,
device=torch_device, excel_file=excel_name)
if args.enable_teq:
# With `enable_teq`, the model contains `MulLinear` wrappers and cannot be reloaded from disk directly, so evaluate the in-memory model instead.
eval_model(model_path=output_dir, tasks=tasks, dtype=dtype, limit=None,
eval_bs=args.eval_bs, use_accelerate=not args.disable_low_gpu_mem_usage,
device=torch_device, excel_file=excel_name,
model_tokenizer_pairs=(model.to("cuda"), tokenizer)
)
else:
eval_model(model_path=output_dir, tasks=tasks, dtype=dtype, limit=None,
eval_bs=args.eval_bs, use_accelerate=not args.disable_low_gpu_mem_usage,
device=torch_device, excel_file=excel_name)

if not args.disable_eval and lm_eval_version == "0.4.2":
if "round" in deployment_device:
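A minimal end-to-end sketch of the new flag, assuming the `AutoRound` constructor arguments already used by this example script; the model name is a placeholder:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from auto_round import AutoRound

model_name = "facebook/opt-125m"  # placeholder; any supported causal LM works the same way
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Mirrors `--enable_teq` in main.py: equalization scales are learned jointly
# with the rounding values during quantization.
autoround = AutoRound(model, tokenizer, bits=4, group_size=128, enable_teq=True)
model, _ = autoround.quantize()

# With TEQ the quantized model contains `MulLinear` wrappers, so main.py
# evaluates this in-memory model rather than reloading it from disk.
```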
