diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py
index 3b8d256a71d14..a08755d825ef3 100644
--- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py
+++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py
@@ -11,7 +11,7 @@
 import importlib
 import logging
 import os
-from typing import List, Optional
+from typing import Optional
 
 import numpy as np
 import numpy.typing as npt
@@ -155,7 +155,7 @@ def optimize_weights(
         tensor: np.ndarray,
         scale: np.ndarray,
         zero: np.ndarray,
-        min_max: List[int],
+        min_max: list[int],
         axis: int = 0,
         opt_params: Optional[dict] = None,
         verbose=False,
@@ -184,10 +184,10 @@ def shrink_op(x, beta, p=lp_norm):
         best_error = 1e4
 
         for i in range(iters):
-            w_q= np.round(w_f * scale + zero).clip(min_max[0], min_max[1])
-            w_r = (w_q- zero) / scale
+            w_q = np.round(w_f * scale + zero).clip(min_max[0], min_max[1])
+            w_r = (w_q - zero) / scale
             w_e = shrink_op(w_f - w_r, beta)
-            zero = np.mean(w_q- (w_f - w_e) * scale, axis=axis, keepdims=True)
+            zero = np.mean(w_q - (w_f - w_e) * scale, axis=axis, keepdims=True)
             beta *= kappa
 
             current_error = float(np.abs(w_f - w_r).mean())
@@ -263,8 +263,8 @@ def quantize_internal(
 
         # Quantize
         # Necessary for fake quantization backprop
-        w_q= np.round(weight * scale + zero).clip(min_max[0], min_max[1])
-        w_q= w_q.reshape(shape).astype(np.uint32)
+        w_q = np.round(weight * scale + zero).clip(min_max[0], min_max[1])
+        w_q = w_q.reshape(shape).astype(np.uint32)
         scale = np.reciprocal(scale)
 
         if axis == 1:
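
For context on the loop the second and third hunks reformat: it alternates round-and-clip quantization, dequantization, lp-norm shrinkage of the residual, and a zero-point refit, and (judging from the best_error/current_error context lines) stops early once the reconstruction error stops improving. The standalone sketch below replays that loop on toy data; the shapes, the 4-bit [0, 15] range, and the parameter values (p=0.7, beta=10, kappa=1.01, 20 iterations) are illustrative assumptions, not values taken from the file.

    import numpy as np

    def shrink_op(x, beta, p=0.7):
        # lp-norm shrinkage operator, same form as the nested shrink_op in the hunk context
        return np.sign(x) * np.maximum(np.abs(x) - (1.0 / beta) * np.power(np.abs(x), p - 1), 0.0)

    rng = np.random.default_rng(0)
    w_f = rng.standard_normal((4, 8)).astype(np.float32)  # toy float weights
    scale = np.full((1, 8), 8.0, dtype=np.float32)        # per-column scale (w_q ~ w_f * scale + zero)
    zero = np.full((1, 8), 7.5, dtype=np.float32)         # per-column zero point
    min_max = [0, 15]                                     # 4-bit unsigned range
    beta, kappa, best_error = 10.0, 1.01, 1e4

    for _ in range(20):
        w_q = np.round(w_f * scale + zero).clip(min_max[0], min_max[1])   # quantize
        w_r = (w_q - zero) / scale                                        # dequantize
        w_e = shrink_op(w_f - w_r, beta)                                  # shrink the residual
        zero = np.mean(w_q - (w_f - w_e) * scale, axis=0, keepdims=True)  # refit the zero point
        beta *= kappa
        current_error = float(np.abs(w_f - w_r).mean())
        if current_error >= best_error:
            break  # error stopped improving; keep the previous best
        best_error = current_error

    print(best_error)  # mean |w_f - w_r| should drop over the first iterations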