refine quant tool
yufenglee committed Sep 26, 2023
1 parent 3a92c56 commit 43f73ce
Showing 2 changed files with 234 additions and 108 deletions.
@@ -9,111 +9,26 @@
from pathlib import Path
from typing import List, Tuple

import logging

import numpy as np
import numpy.typing as npt
import onnx
from onnx.onnx_pb import GraphProto, ModelProto, NodeProto, TensorProto

from .onnx_model import ONNXModel
from .quant_utils import attribute_to_kwarg, load_model_with_shape_infer
import multiprocessing

from joblib import Parallel, delayed, parallel_config

def __q4_block_size() -> int:
# happens to be 32 for now, but future quantization types
# may have bigger block size
return 32


def __q4_blob_size(block_size: int) -> int:
return block_size // 2


def __q4_buf_size(rows: int, cols: int) -> int:
block_size = __q4_block_size()
blob_size = __q4_blob_size(block_size)
k_blocks = (rows + block_size - 1) // block_size
return k_blocks * cols * blob_size


def int4_block_quant(fp32weight: npt.ArrayLike, symmetric: bool) -> np.ndarray:
"""4b quantize fp32 weight to a blob"""

if len(fp32weight.shape) != 2:
raise ValueError("Current int4 block quantization only supports 2D tensors!")
rows, cols = fp32weight.shape

block_size = __q4_block_size()
blob_size = __q4_blob_size(block_size)
k_blocks = (rows + block_size - 1) // block_size
padded_rows = k_blocks * block_size
pad_len = padded_rows - rows
if pad_len > 0:
fp32weight = np.pad(fp32weight, ((0, pad_len), (0, 0)), "constant")

# block wise quantization, each block comes from a single column
packed = np.zeros((cols * k_blocks, blob_size), dtype="uint8")
scales = np.zeros((cols * k_blocks), dtype=fp32weight.dtype)
zero_point = np.zeros((cols * k_blocks), dtype="uint8")
fp32weight = np.transpose(fp32weight)
def _process_column(n):
ncol = fp32weight[n, :]
blks = np.split(ncol, k_blocks)
blob_idx = n * k_blocks
#print(f"start to process {n}")
for blk in blks:
packed_blob = packed[blob_idx]

if symmetric:
amax_idx = np.argmax(np.abs(blk))
bmax = blk[amax_idx]
scale = bmax / (-8)
zp = 8
else:
vmin = np.min(blk)
vmax = np.max(blk)
vmin = min(vmin, 0.0)
vmax = max(vmax, 0.0)
scale = (vmax - vmin) / ((1 << 4) - 1)
zero_point_fp = vmin
if scale != 0.0:
zero_point_fp = 0.0 - vmin / scale
zp = min(15, max(0, round(zero_point_fp)))

reciprocal_scale = 1.0 / scale if scale != 0 else 0.0
scales[blob_idx] = scale
zero_point[blob_idx] = zp
blob_idx += 1

blk_int = np.clip(np.rint(blk * reciprocal_scale + zp), 0, 15).astype("uint8")
for i in range(0, blob_size, 2):
packed_blob[i//2] = blk_int[i] | blk_int[i+1] << 4
#print(f"end to process {n}")

with parallel_config(backend='threading', n_jobs=-1):
Parallel()(delayed(_process_column)(n) for n in range(cols))
return (packed.reshape((cols, k_blocks, blob_size)),
scales.reshape((cols, k_blocks)),
zero_point.reshape((cols, k_blocks)))


class MatMulWeight4Quantizer:
"""Perform 4b quantization of constant MatMul weights"""
import concurrent.futures

##################
# quantization types, must be consistent with native code type
# MLAS_BLK_QUANT_TYPE defined in mlas_q4.h
logger = logging.getLogger(__name__)

# 32 number block, symmetric quantization, with one fp32 as scale, zero point is always 0
BlkQ4Sym = 0

# 32 number block, quantization, with one fp32 as scale, one uint8 zero point
BlkQ4Zp8 = 1
class MatMul4BitsQuantizer:
"""Perform 4b quantization of constant MatMul weights"""

def __init__(self, model: ModelProto, quant_type: int):
def __init__(self, model: ModelProto, block_size: int, is_symmetric: bool):
self.model = ONNXModel(model)
self.quant_type = quant_type
self.block_size = block_size
self.is_symmetric = is_symmetric

@staticmethod
def __get_initializer(name, graph_path: List[GraphProto]) -> Tuple[TensorProto, GraphProto]:
@@ -124,14 +39,65 @@ def __get_initializer(name, graph_path: List[GraphProto]) -> Tuple[TensorProto,
return tensor, graph
return None, None

def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: List[GraphProto], symmetric) -> NodeProto:
def int4_block_quant(self, fp32weight: npt.ArrayLike) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""4b quantize fp32 weight to a blob"""

if len(fp32weight.shape) != 2:
raise ValueError("Current int4 block quantization only supports 2D tensors!")
rows, cols = fp32weight.shape

block_size = self.block_size
blob_size = block_size // 2
k_blocks = (rows + block_size - 1) // block_size
padded_rows = k_blocks * block_size
pad_len = padded_rows - rows
if pad_len > 0:
fp32weight = np.pad(fp32weight, ((0, pad_len), (0, 0)), "constant")

# block wise quantization, each block comes from a single column
packed = np.zeros((cols, k_blocks, blob_size), dtype="uint8")
scales = np.zeros((cols, k_blocks), dtype=fp32weight.dtype)
zero_point = np.zeros((cols, k_blocks), dtype="uint8")
fp32weight = np.transpose(fp32weight).copy()
for n in range(cols):
for k_id in range(0, rows, block_size):
if self.is_symmetric:
amax_idx = np.argmax(np.abs(fp32weight[n, k_id:k_id+block_size]))
bmax = fp32weight[n, k_id + amax_idx]
scale = bmax / (-8)
zp = 8
else:
vmin = np.min(fp32weight[n, k_id:k_id+block_size])
vmax = np.max(fp32weight[n, k_id:k_id+block_size])
vmin = min(vmin, 0.0)
vmax = max(vmax, 0.0)
scale = (vmax - vmin) / ((1 << 4) - 1)
zero_point_fp = vmin
if scale != 0.0:
zero_point_fp = 0.0 - vmin / scale
zp = min(15, max(0, round(zero_point_fp)))

reciprocal_scale = 1.0 / scale if scale != 0 else 0.0
scales[n, k_id // block_size] = scale
zero_point[n, k_id // block_size] = zp

blk_int0 = np.clip(np.rint(fp32weight[n, k_id:k_id+block_size:2] * reciprocal_scale + zp), 0, 15).astype("uint8")
blk_int1 = np.clip(np.rint(fp32weight[n, k_id + 1:k_id+block_size:2] * reciprocal_scale + zp), 0, 15).astype("uint8")
packed[n, k_id // block_size] = np.bitwise_or(blk_int0, np.left_shift(blk_int1, 4))

return (packed.reshape((cols, k_blocks, blob_size)),
scales.reshape((cols, k_blocks)),
zero_point.reshape((cols, k_blocks)))

def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: List[GraphProto]) -> NodeProto:
"""If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node"""

if node.op_type != "MatMul":
return node # only care about MatMul for now

logger.info(f"start to quantize {node.name} ...")
inputB = node.input[1] # noqa: N806
B, Bs_graph = MatMulWeight4Quantizer.__get_initializer(inputB, graph_stack) # noqa: N806
B, Bs_graph = MatMul4BitsQuantizer.__get_initializer(inputB, graph_stack) # noqa: N806
if B is None:
return node # only care about constant weight

@@ -140,7 +106,7 @@ def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: List[GraphProto],
if len(B_array.shape) != 2:
return node # can only process 2-D matrix

packed, scales, zero_points = int4_block_quant(B_array, symmetric)
packed, scales, zero_points = self.int4_block_quant(B_array)
B_quant = onnx.numpy_helper.from_array(packed) # noqa: N806
B_quant.name = B.name + "_Q4"
Bs_graph.initializer.remove(B)
@@ -154,7 +120,7 @@ def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: List[GraphProto],
Bs_graph.initializer.extend([B_quant, scales_tensor])

input_names = [node.input[0], B_quant.name, scales_tensor.name]
if not symmetric:
if not self.is_symmetric:
zp_tensor = onnx.numpy_helper.from_array(zero_points) # noqa: N806
zp_tensor.name = B.name + "_zero_points"
Bs_graph.initializer.extend([zp_tensor])  # B_quant was already added above
@@ -165,7 +131,7 @@ def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: List[GraphProto],
kwargs["K"] = rows
kwargs["N"] = cols
kwargs["bits"] = 4
kwargs["block_size"] = 32
kwargs["block_size"] = self.block_size

matmul_q4_node = onnx.helper.make_node(
"MatMulWithCompressWeight",
@@ -175,14 +141,16 @@ def _q4_matmul_node_weight(self, node: NodeProto, graph_stack: List[GraphProto],
domain="com.microsoft",
**kwargs,
)

logger.info(f"finish {node.name} ...")

return matmul_q4_node

def _process_subgraph(self, graph_stack: List[GraphProto], symmetric):
def _process_subgraph(self, graph_stack: List[GraphProto]):
new_nodes = []
graph = graph_stack[-1]

for node in graph.node:
logger.debug(node.name)
graph_attrs = [
attr
for attr in node.attribute
@@ -194,13 +162,13 @@ def _process_subgraph(self, graph_stack: List[GraphProto], symmetric):
if attr.type == onnx.AttributeProto.GRAPH:
# recursive call to take care of sub-graph
graph_stack.append(attr.g)
kv = {attr.name: self._process_subgraph(graph_stack, symmetric)}
kv = {attr.name: self._process_subgraph(graph_stack)}
elif attr.type == onnx.AttributeProto.GRAPHS:
value = []
for subgraph in attr.graphs:
# recursive call to take care of sub-graph
graph_stack.append(subgraph)
value.extend([self._process_subgraph(graph_stack, symmetric)])
value.extend([self._process_subgraph(graph_stack)])
kv = {attr.name: value}
else:
kv = attribute_to_kwarg(attr)
@@ -209,14 +177,14 @@ def _process_subgraph(self, graph_stack: List[GraphProto], symmetric):
node.op_type, node.input, node.output, name=node.name, **kwargs
)

new_nodes.append(self._q4_matmul_node_weight(node, graph_stack, symmetric))
new_nodes.append(self._q4_matmul_node_weight(node, graph_stack))

graph.ClearField("node")
graph.node.extend(new_nodes)
graph_stack.pop()
return graph

def process(self, symmetric=True):
def process(self):
# use a stack to keep track of sub-graphs
graph_stack = [self.model.graph()]
opset_import = self.model.opset_import()
@@ -228,7 +196,7 @@ def _process_subgraph(self, graph_stack: List[GraphProto], symmetric):
if not has_ms_domain:
opset_import.extend([onnx.helper.make_opsetid("com.microsoft", 1)])

self._process_subgraph(graph_stack, symmetric)
self._process_subgraph(graph_stack)


def parse_args():
@@ -243,6 +211,8 @@ def parse_args():

parser.add_argument("--input_model", required=True, help="Path to the input model file")
parser.add_argument("--output_model", required=True, help="Path to the output model file")
parser.add_argument("--block_size", required=False, default=32)
parser.add_argument("--symmetric", required=False, default=True, help="Indicate whether to quantize the model symmetrically")
parser.add_argument("-e", "--use_external_data_format", required=False, action="store_true")
parser.set_defaults(use_external_data_format=False)

@@ -256,6 +226,6 @@ def parse_args():
output_model_path = args.output_model

model = load_model_with_shape_infer(Path(input_model_path))
quant = MatMulWeight4Quantizer(model, 0)
quant = MatMul4BitsQuantizer(model, args.block_size, args.symmetric)
quant.process()
quant.model.save_model_to_file(output_model_path, True)
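
A minimal sketch of the math implemented by the new int4_block_quant method above, replayed on a single 4-element block with asymmetric quantization; the quantize_block helper and the sample values are illustrative assumptions, not code from the commit:

import numpy as np

def quantize_block(blk: np.ndarray):
    """Quantize one fp32 block to uint4 values (stored in uint8)."""
    # The grid always covers 0.0 so that zero stays exactly representable.
    vmin = min(float(blk.min()), 0.0)
    vmax = max(float(blk.max()), 0.0)
    scale = (vmax - vmin) / 15  # 15 steps between the 16 uint4 levels
    zero_point_fp = 0.0 - vmin / scale if scale != 0.0 else 0.0
    zp = min(15, max(0, round(zero_point_fp)))
    reciprocal_scale = 1.0 / scale if scale != 0.0 else 0.0
    q = np.clip(np.rint(blk * reciprocal_scale + zp), 0, 15).astype("uint8")
    return q, scale, zp

blk = np.array([-1.0, 0.0, 0.5, 2.0], dtype=np.float32)
q, scale, zp = quantize_block(blk)
# scale = (2.0 - (-1.0)) / 15 = 0.2, zp = round(1.0 / 0.2) = 5
# q = rint(blk / 0.2 + 5) = [0, 5, 8, 15]

# Packing: even-index values fill the low nibble, odd-index values the high
# nibble, so a block of block_size numbers occupies block_size // 2 bytes.
packed = (q[0::2] | (q[1::2] << 4)).astype("uint8")
# packed = [0x50, 0xF8]

# Round trip: dequantized values land back on the 0.2-wide grid.
dequant = (q.astype(np.float32) - zp) * scale
# dequant = [-1.0, 0.0, 0.6, 2.0]

With the refactored entry point, a run might look like python matmul_4bits_quantizer.py --input_model model.onnx --output_model model_q4.onnx --block_size 32 (the script name is assumed from the new class name; the names of the two changed files are not shown on this page).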
