Commit: lint
wejoncy committed Jan 12, 2024
1 parent 595d0dc commit 0448cd5
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py
@@ -11,7 +11,7 @@
 import importlib
 import logging
 import os
-from typing import List, Optional
+from typing import Optional
 
 import numpy as np
 import numpy.typing as npt
@@ -155,7 +155,7 @@ def optimize_weights(
         tensor: np.ndarray,
         scale: np.ndarray,
         zero: np.ndarray,
-        min_max: List[int],
+        min_max: list[int],
         axis: int = 0,
         opt_params: Optional[dict] = None,
         verbose=False,

Check warning (Code scanning / lintrunner) on the opt_params line above:
RUFF/UP007 Warning: Use X | Y for type annotations.
See https://docs.astral.sh/ruff/rules/non-pep604-annotation
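For reference, the RUFF/UP007 fix the scanner asks for is the PEP 604 union syntax in place of typing.Optional. A minimal illustrative sketch follows; the function names are hypothetical and not part of the commit:

from typing import Optional

# Flagged form (non-PEP 604 annotation, as kept in the diff):
def configure_old(opt_params: Optional[dict] = None) -> None:
    pass

# Form suggested by RUFF/UP007 (PEP 604 union, Python 3.10+ syntax):
def configure_new(opt_params: dict | None = None) -> None:
    pass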
@@ -184,10 +184,10 @@ def shrink_op(x, beta, p=lp_norm):
 
         best_error = 1e4
         for i in range(iters):
-            w_q= np.round(w_f * scale + zero).clip(min_max[0], min_max[1])
-            w_r = (w_q- zero) / scale
+            w_q = np.round(w_f * scale + zero).clip(min_max[0], min_max[1])
+            w_r = (w_q - zero) / scale
             w_e = shrink_op(w_f - w_r, beta)
-            zero = np.mean(w_q- (w_f - w_e) * scale, axis=axis, keepdims=True)
+            zero = np.mean(w_q - (w_f - w_e) * scale, axis=axis, keepdims=True)
             beta *= kappa
 
         current_error = float(np.abs(w_f - w_r).mean())
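The loop above performs a quantize/dequantize round trip and tracks the reconstruction error. A minimal standalone NumPy sketch of one such round trip; the variable names follow the diff, but the toy values are made up and not from the commit:

import numpy as np

# Toy stand-ins, not values from the commit:
w_f = np.array([[0.12, -0.53], [0.98, -0.07]])  # float weights
scale = np.array([[10.0], [10.0]])              # per-row scale
zero = np.array([[7.0], [7.0]])                 # per-row zero point
min_max = [0, 15]                               # 4-bit integer range

w_q = np.round(w_f * scale + zero).clip(min_max[0], min_max[1])  # quantize
w_r = (w_q - zero) / scale                                       # dequantize
print(float(np.abs(w_f - w_r).mean()))                           # reconstruction error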
@@ -263,8 +263,8 @@ def quantize_internal(
 
         # Quantize
         # Necessary for fake quantization backprop
-        w_q= np.round(weight * scale + zero).clip(min_max[0], min_max[1])
-        w_q= w_q.reshape(shape).astype(np.uint32)
+        w_q = np.round(weight * scale + zero).clip(min_max[0], min_max[1])
+        w_q = w_q.reshape(shape).astype(np.uint32)
 
         scale = np.reciprocal(scale)
         if axis == 1:
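A side note on the last context line: storing np.reciprocal(scale) lets a consumer dequantize with a multiply rather than a divide. The sketch below shows that equivalence with toy values; the intent is an assumption, not something stated in the commit:

import numpy as np

w_q = np.array([8.0, 2.0])
zero = np.array([7.0])
scale = np.array([10.0])

dequant_div = (w_q - zero) / scale        # dequantization as written in the optimizer loop
scale_inv = np.reciprocal(scale)          # reciprocal form stored after quantization
dequant_mul = (w_q - zero) * scale_inv    # equivalent multiply form
assert np.allclose(dequant_div, dequant_mul)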
