From 64674c50de7a19b790498c98e1111c0aaba5676a Mon Sep 17 00:00:00 2001
From: Jing Fang <126209182+fajin-corp@users.noreply.github.com>
Date: Mon, 19 Aug 2024 10:25:36 -0700
Subject: [PATCH] Added a tool to quantize Gather to GatherBlockQuantized
 (#21697)

### Description
Added code in MatMul4BitsQuantizer to quantize Gather to
GatherBlockQuantized.

Only Gather with constant data is quantized.

Since the quantized data is int4, the quantized model is forcibly upgraded
to ONNX opset 21.

The implementation purely relies on numpy. If optimization is needed,
C++ kernels can be added later.

Only the default RTN algorithm is supported, since GatherBlockQuantized
requires zero points to have the same type as the quantized data.

### Motivation and Context
Support quantizing gather to int4 in Web scenario.
---
 .../quantization/matmul_4bits_quantizer.py    | 349 ++++++++++++++++--
 .../python/tools/quantization/quantize.py     |  17 +-
 .../quantization/test_op_matmul_4bits.py      | 125 ++++++-
 3 files changed, 438 insertions(+), 53 deletions(-)

diff --git a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py
index c0cc4f038cd3b..975f82439c160 100644
--- a/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py
+++ b/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py
@@ -29,8 +29,14 @@
 
 
 class WeightOnlyQuantConfig:
-    def __init__(self, algorithm, quant_format):
-        """This is the Base class for Weight Only Quant Configuration.
+    def __init__(
+        self,
+        algorithm: str,
+        quant_format: QuantFormat,
+        op_types_to_quantize: tuple[str, ...] | None = None,
+        quant_axes: tuple[tuple[str, int], ...] | None = None,
+    ):
+        """This is the Base class for Weight Only blockwise quantization Configuration.
 
         Args:
             algorithm:
@@ -38,9 +44,15 @@ def __init__(self, algorithm, quant_format):
             quant_format: QuantFormat{QOperator, QDQ}.
                 QOperator format quantizes the model with quantized operators directly.
                 QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor.
+            op_types_to_quantize (optional):
+                set of operator types to quantize. Default {MatMul}
+            quant_axes (tuple[tuple[str, int], ...], optional):
+                (op_type, axis) pairs: which axis to quantize for each op. Default {MatMul: 0, Gather: 1}
         """
         self.algorithm = algorithm
         self.quant_format = quant_format
+        self.op_types_to_quantize = set(op_types_to_quantize) if op_types_to_quantize else {"MatMul"}
+        self.quant_axes = dict(quant_axes) if quant_axes else {"MatMul": 0, "Gather": 1}
 
 
 class RTNWeightOnlyQuantConfig(WeightOnlyQuantConfig):
@@ -48,6 +60,7 @@ def __init__(
         self,
         ratios=None,
         quant_format=QuantFormat.QOperator,
+        op_types_to_quantize: tuple[str, ...] | None = None,
     ):
         """
         This is a class for round-to-nearest (RTN) algorithm Weight Only Quant Configuration.
@@ -60,6 +73,8 @@ def __init__(
                 QOperator format quantizes the model with quantized operators directly.
                 QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor.
                 Defaults to QuantFormat.QOperator.
+            op_types_to_quantize (optional):
+                set of operator types to quantize.
         """
         assert quant_format == QuantFormat.QOperator, "RTN only supports QOperator format"
 
@@ -68,6 +83,7 @@ def __init__(
         super().__init__(
             algorithm="RTN",
             quant_format=quant_format,
+            op_types_to_quantize=op_types_to_quantize,
         )
         self.ratios = ratios
 
@@ -75,13 +91,14 @@ def __init__(
 class GPTQWeightOnlyQuantConfig(WeightOnlyQuantConfig):
     def __init__(
         self,
-        calibration_data_reader: CalibrationDataReader,
+        calibration_data_reader: CalibrationDataReader | None = None,
         percdamp=0.01,
         block_size=128,
         actorder=False,
         mse=False,
         perchannel=True,
         quant_format=QuantFormat.QOperator,
+        op_types_to_quantize: tuple[str, ...] | None = None,
     ):
         """
         This is a class for GPTQ algorithm Weight Only Quant Configuration.
@@ -104,12 +121,15 @@ def __init__(
                 QOperator format quantizes the model with quantized operators directly.
                 QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor.
                 Defaults to QuantFormat.QOperator.
+            op_types_to_quantize (optional):
+                set of operator types to quantize.
         """
         assert quant_format == QuantFormat.QOperator, "GPTQ only supports QOperator format"
 
         super().__init__(
             algorithm="GPTQ",
             quant_format=quant_format,
+            op_types_to_quantize=op_types_to_quantize,
         )
         self.calibration_data_reader = calibration_data_reader
         self.percdamp = percdamp
@@ -126,6 +146,8 @@ def __init__(
         bits=4,
         axis=1,
         quant_format=QuantFormat.QOperator,
+        op_types_to_quantize: tuple[str, ...] | None = None,
+        quant_axes: tuple[tuple[str, int], ...] | None = None,
     ):
         """
         This is a class for HQQ algorithm Weight Only Quant Configuration.
@@ -142,12 +164,18 @@ def __init__(
                 QOperator format quantizes the model with quantized operators directly.
                 QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor.
                 Defaults to QuantFormat.QOperator.
+            op_types_to_quantize (optional):
+                set of operator types to quantize.
+            quant_axes (dict[str, int], optional):
+                op:axis, which axis to quantize for an op. Default {MatMul: 0, Gather: 1}
         """
         assert quant_format == QuantFormat.QOperator, "HQQ only supports QOperator format"
 
         super().__init__(
             algorithm="HQQ",
             quant_format=quant_format,
+            op_types_to_quantize=op_types_to_quantize,
+            quant_axes=quant_axes,
         )
         self.block_size = block_size
         self.bits = bits
@@ -161,6 +189,8 @@ def __init__(
         is_symmetric: bool = False,
         accuracy_level: int | None = None,
         quant_format=QuantFormat.QOperator,
+        op_types_to_quantize: tuple[str, ...] | None = None,
+        quant_axes: tuple[tuple[str, int], ...] | None = None,
     ):
         """
         This is a class for weight only affine quantization configuration.
@@ -178,8 +208,17 @@ def __init__(
                 QOperator format quantizes the model with quantized operators directly.
                 QDQ format quantize the model by inserting QuantizeLinear/DeQuantizeLinear on the tensor.
                 Defaults to QuantFormat.QOperator.
+            op_types_to_quantize (optional):
+                set of operator types to quantize.
+            quant_axes (dict[str, int], optional):
+                op:axis, which axis to quantize for an op. Default {MatMul: 0, Gather: 1}
         """
-        super().__init__(algorithm="DEFAULT", quant_format=quant_format)
+        super().__init__(
+            algorithm="DEFAULT",
+            quant_format=quant_format,
+            op_types_to_quantize=op_types_to_quantize,
+            quant_axes=quant_axes,
+        )
         self.block_size = block_size
         self.is_symmetric = is_symmetric
         self.bits = 4
@@ -205,7 +244,7 @@ def optimize_weights(
         zero,
         min_max: list[int],
         axis: int = 0,
-        opt_params: dict = None,  # noqa: RUF013
+        opt_params: dict | None = None,
         verbose=False,
     ):
         import torch
@@ -223,14 +262,10 @@ def optimize_weights(
         scale = scale.to(dtype)
         zero = zero.to(dtype)
 
-        if lp_norm == 1:
-
-            def shrink_op(x, beta):
+        def shrink_op(x, beta, p=lp_norm):
+            if p == 1:
                 return torch.sign(x) * torch.nn.functional.relu(torch.abs(x) - 1.0 / beta)
-
-        else:
-
-            def shrink_op(x, beta, p=lp_norm):
+            else:
                 return torch.sign(x) * torch.nn.functional.relu(
                     torch.abs(x) - (1.0 / beta) * torch.pow(torch.abs(x) + 1e-8, p - 1)
                 )
@@ -335,11 +370,20 @@ def quantize_internal(
 
     def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeProto]:
         """
-        If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node.
-        If QOperator format, return MatMulNbits. If QDQ format, return DeQuantizeLinear + MatMul.
+        Target node:        QOperator node:            QDQ nodes:
+        MatMul              MatMulNBits                DeQuantizeLinear -> MatMul
+        Gather              GatherBlockQuantized       Gather, Gather, Gather (optional) -> DequantizeLinear
+        If the node is target node with fp32 or fp16 const weight, quantize the weight to int4 and
+        return the new nodes.
+        If QOperator format, return the corresponding QOperator nodes.
+        If QDQ format, return the corresponding QDQ nodes.
+        Gather (quantized data) + Gather (scales) + Gather (optional, zero points) -> DequantizeLinear is
+        not supported yet because Gather does not support int4 data.
         """
-        if node.op_type != "MatMul":
-            return [node]  # only care about MatMul for now
+        # With HQQ, zero points are in float. Current GatherBlockQuantized does not support float zero points.
+        if node.op_type == "Gather":
+            raise NotImplementedError("Gather quantization is not supported yet in HQQ")
+
         import torch
 
         logger.info(f"start to quantize {node.name} ...")
@@ -432,7 +476,7 @@ def __init__(self, config: DefaultWeightOnlyQuantConfig):
         self.config = config
 
     def int4_block_quant(self, fp32weight: npt.ArrayLike) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
-        """4b quantize fp32 weight to a blob"""
+        """4b quantize fp32 weight to int4 using C++ kernels."""
 
         if len(fp32weight.shape) != 2:
             raise ValueError("Current int4 block quantization only supports 2D tensors!")
@@ -465,16 +509,11 @@ def int4_block_quant(self, fp32weight: npt.ArrayLike) -> tuple[np.ndarray, np.nd
 
         return (packed, scales, zero_point)
 
-    def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeProto]:
+    def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeProto]:
         """
-        If the node is MatMul with fp32 const weight, quantize the weight with int4, and return the new node.
-        If QOperator format, return MatMulNbits. If QDQ format, return DeQuantizeLinear + MatMul.
+        Quantize weight B of MatMul node to int4.
+        Currently only support 2D constant matrix and axis 0 blockwise quantization.
         """
-
-        if node.op_type != "MatMul":
-            return [node]  # only care about MatMul for now
-
-        logger.info(f"start to quantize {node.name} ...")
         qtype = TensorProto.INT4 if self.config.is_symmetric else TensorProto.UINT4
         input_b = node.input[1]
         b_tensor, b_graph = get_initializer(input_b, graph_stack)
@@ -557,17 +596,206 @@ def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeP
             )
             output_nodes.extend([dq_node, matmul_node])
 
-        logger.info(f"complete quantization of {node.name} ...")
         return output_nodes
 
+    @staticmethod
+    def quant_slice_symmetric(data: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
+        max_val = np.max(data, axis=1, keepdims=True)  # extrema are taken along axis 1 only
+        min_val = np.min(data, axis=1, keepdims=True)
+        abs_max = np.where(np.abs(max_val) > np.abs(min_val), max_val, min_val)  # signed value of largest magnitude
+
+        scale = abs_max / -8.0  # largest magnitude maps to -8; if max == -min, max rounds to 8 and is clipped to 7
+        quantized_slice = np.where(scale == 0, 0, data / scale).round().clip(-8, 7).astype(np.int8)
+
+        return quantized_slice, scale
+
+    @staticmethod
+    def quant_slice_asymmetric(data: np.ndarray) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
+        min_val = np.minimum(data.min(axis=1, keepdims=True), 0)
+        max_val = np.maximum(data.max(axis=1, keepdims=True), 0)
+
+        scale = (max_val - min_val) / 15.0
+        zero_point = np.where(scale == 0, 8, -min_val / scale).round().clip(0, 15).astype(np.uint8)
+        quantized_slice = np.where(scale == 0, 8, data / scale + zero_point).round().clip(0, 15).astype(np.uint8)
+
+        return quantized_slice, scale, zero_point
+
+    @staticmethod
+    def pack_int8_to_int4(data: np.ndarray) -> np.ndarray:
+        """Pack int8 data to int4 and store in uint8 ndarray."""
+        data_flat = data.reshape(-1)
+        if len(data_flat) % 2 != 0:
+            data_flat = np.append(data_flat, 0)
+        quant_data_int4 = (data_flat[::2] & 0xF) | ((data_flat[1::2] & 0xF) << 4)
+
+        return quant_data_int4.astype("uint8")
+
+    @staticmethod
+    def quantize_ndarray(
+        data: np.ndarray,
+        quantize_axis: int,
+        block_size: int,
+        is_symmetric: bool,
+    ) -> tuple[np.ndarray, np.ndarray, np.ndarray | None]:
+        """Blockwise-quantize data to int4 with numpy; return (packed data, scales, packed zero points or None)."""
+        # View data as (m, k, n): k is the quantize axis (must be non-negative), m / n the dims before / after it
+        m = 1  # dimension of the matrix before the quantize axis
+        k = data.shape[quantize_axis]  # dimension of the matrix along the quantize axis
+        n = 1  # dimension of the matrix after the quantize axis
+        for i, dim in enumerate(data.shape):
+            if i < quantize_axis:
+                m *= dim
+            elif i > quantize_axis:
+                n *= dim
+
+        k_blocks = (k + block_size - 1) // block_size  # ceil(k / block_size); the last block may be partial
+        scales_shape = list(data.shape)
+        scales_shape[quantize_axis] = k_blocks
+
+        data_reshape = data.reshape((m, k, n))
+        scales = np.zeros((m, k_blocks, n), dtype=data.dtype)
+        if is_symmetric:
+            quant_data_int8 = np.zeros((m, k, n), dtype="int8")
+        else:
+            quant_data_int8 = np.zeros((m, k, n), dtype="uint8")
+            zero_point_int8 = np.zeros((m, k_blocks, n), dtype="uint8")
+
+        # quantize one block of up to block_size elements along k at a time
+        for i in range(0, k, block_size):
+            end_idx = min(i + block_size, k)
+            data_block = data_reshape[:, i:end_idx, :]  # named data_block so the builtin `slice` is not shadowed
+
+            if is_symmetric:
+                quantized_slice_int8, scale_slice = DefaultWeightOnlyQuantizer.quant_slice_symmetric(data_block)
+            else:
+                quantized_slice_int8, scale_slice, zero_point_slice_int8 = (
+                    DefaultWeightOnlyQuantizer.quant_slice_asymmetric(data_block)
+                )
+
+            quant_data_int8[:, i:end_idx, :] = quantized_slice_int8
+            j = i // block_size
+            scales[:, j : (j + 1), :] = scale_slice
+            if not is_symmetric:
+                zero_point_int8[:, j : (j + 1), :] = zero_point_slice_int8
+
+        # pack int8 to int4
+        quant_data_int4 = DefaultWeightOnlyQuantizer.pack_int8_to_int4(quant_data_int8)
+        zero_point_int4 = None
+        if not is_symmetric:
+            zero_point_int4 = DefaultWeightOnlyQuantizer.pack_int8_to_int4(zero_point_int8)
+        scales = scales.reshape(scales_shape)
+        return quant_data_int4, scales, zero_point_int4
+
+    def quantize_gather(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeProto]:
+        """Quantize constant data of a Gather node to int4; return [GatherBlockQuantized], or [node] if not constant."""
+        assert self.config.quant_format == QuantFormat.QOperator, "Gather only supports QOperator format currently."
+
+        qtype = TensorProto.INT4 if self.config.is_symmetric else TensorProto.UINT4
+        data_arg = node.input[0]
+        data_tensorproto, data_graphproto = get_initializer(data_arg, graph_stack)
+        if data_tensorproto is None:
+            logger.info("Gather doesn't have const weight. Skip quantization.")
+            return [node]  # only care about constant weight
+
+        data_ndarray = onnx.numpy_helper.to_array(data_tensorproto)
+        data_rank = len(data_ndarray.shape)
+        quantize_axis = self.config.quant_axes.get("Gather", 1)
+        block_size = self.config.block_size
+        # the axis must lie in [-rank, rank) and the block size must be a power of two >= 16
+        assert quantize_axis < data_rank and quantize_axis >= -data_rank, "Invalid quantize axis for Gather node."
+        assert block_size >= 16 and ((block_size - 1) & block_size == 0), "Invalid block size for Gather node."
+
+        quantize_axis = (quantize_axis + data_rank) % data_rank  # normalize a negative axis
+        quantized_data, scales, zero_points = self.quantize_ndarray(
+            data_ndarray, quantize_axis, block_size, self.config.is_symmetric
+        )
+
+        for graph_input in data_graphproto.input:  # `graph_input`, so the builtin `input` is not shadowed
+            if graph_input.name == data_arg:
+                data_graphproto.input.remove(graph_input)
+                break
+        # the packed nibbles are stored as raw bytes (raw=True) under the original, unpacked data shape
+        quantized_data_tensorproto = onnx.helper.make_tensor(
+            data_tensorproto.name + "_Q4", qtype, data_ndarray.shape, quantized_data.tobytes(), True
+        )
+        scales_tensorproto = onnx.numpy_helper.from_array(scales, data_tensorproto.name + "_scales")
+        input_names = [quantized_data_tensorproto.name, node.input[1], scales_tensorproto.name]
+        data_graphproto.initializer.extend([quantized_data_tensorproto, scales_tensorproto])
+        if not self.config.is_symmetric:
+            zp_tensorproto = onnx.helper.make_tensor(
+                data_tensorproto.name + "_zero_points", qtype, scales.shape, zero_points.tobytes(), True
+            )
+            input_names.append(zp_tensorproto.name)
+            data_graphproto.initializer.extend([zp_tensorproto])
+
+        try:
+            gather_axis = onnx.helper.get_node_attr_value(node, "axis")
+        except ValueError:
+            gather_axis = 0  # no "axis" attribute on the node: fall back to 0
+
+        kwargs = {
+            "gather_axis": gather_axis,
+            "quantize_axis": quantize_axis,
+            "block_size": block_size,
+        }
+
+        gather_q4_node = onnx.helper.make_node(
+            "GatherBlockQuantized",
+            inputs=input_names,
+            outputs=[node.output[0]],
+            name=node.name + "_Q4" if node.name else "",
+            domain="com.microsoft",
+            **kwargs,
+        )
+
+        return [gather_q4_node]
+
+    def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeProto]:
+        """
+        Target node:        QOperator node:            QDQ nodes:
+        MatMul              MatMulNBits                DeQuantizeLinear -> MatMul
+        Gather              GatherBlockQuantized       Gather, Gather, Gather (optional) -> DequantizeLinear
+        If the node is target node with fp32 or fp16 const weight, quantize the weight to int4 and
+        return the new nodes.
+        If QOperator format, return the corresponding QOperator nodes.
+        If QDQ format, return the corresponding QDQ nodes.
+        Gather (quantized data) + Gather (scales) + Gather (optional, zero points) -> DequantizeLinear is
+        not supported yet because Gather does not support int4 data.
+        """
+        logger.info(f"start to quantize {node.name} ...")
+
+        if node.op_type == "MatMul":
+            results = self.quantize_matmul(node, graph_stack)
+        elif node.op_type == "Gather":
+            results = self.quantize_gather(node, graph_stack)
+        else:
+            logger.error(f"Unsupported operator {node.op_type} for weight only quantization. Skip quantization.")
+            results = [node]
+
+        logger.info(f"complete quantization of {node.name} ...")
 
+        return results
+
+
+# TODO(fajin): change class name
 class MatMul4BitsQuantizer:
     """
-    Perform 4b quantization of constant MatMul weights.
-    If algo_config.quant_format is QOperator, the quantized weight is stored in a MatMulNBits node, which relaces the
-    MatMul node.
-    If algo_config.quant_format is QDQ, the quantized weight is stored in a DeQuantizeLinear node. The MatMul node is
-    replaced by the DequantizeLinear + MatMul nodes.
+    Target node:        QOperator node:            QDQ nodes:
+    MatMul              MatMulNBits                DeQuantizeLinear -> MatMul
+    Gather              GatherBlockQuantized       Gather, Gather, Gather (optional) -> DequantizeLinear
+
+    Perform 4b quantization of constant weights for target nodes.
+    If algo_config.quant_format is QOperator:
+      - nodes are replaced by the corresponding QOperator nodes.
+      - quantized weights are stored in the contrib ops.
+    If algo_config.quant_format is QDQ:
+      - the quantized weight is stored in a standard onnx node. For MatMul, it is DequantizeLinear. For Gather,
+        it is the three Gathers, one for quantized data, one for scales and one for optional zero points.
+      - The nodes are replaced by the corresponding QDQ nodes.
+      - currently Gather is not supported in QDQ because Gather does not support int4 yet.
+    Note:
+      - for quantized gather, the memory usage of "DequantizeLinear + Gather" is the same as the original Gather
+        during runtime. Therefore it is not recommended.
     """
 
     def __init__(
@@ -577,7 +805,10 @@ def __init__(
         is_symmetric: bool = False,
         accuracy_level: int | None = None,
         nodes_to_exclude=None,
+        nodes_to_include: list[str] | None = None,
         quant_format=QuantFormat.QOperator,
+        op_types_to_quantize: tuple[str, ...] | None = None,
+        quant_axes: tuple[tuple[str, int], ...] | None = None,
         algo_config: WeightOnlyQuantConfig | None = None,
     ):
         if nodes_to_exclude is None:
@@ -588,6 +819,7 @@ def __init__(
         self.is_symmetric = is_symmetric
         self.accuracy_level = accuracy_level
         self.nodes_to_exclude = set(nodes_to_exclude)
+        self.nodes_to_include = set(nodes_to_include) if nodes_to_include else None
         self.node_quantizer = None
         if algo_config is None:
             algo_config = DefaultWeightOnlyQuantConfig(
@@ -595,6 +827,8 @@ def __init__(
                 is_symmetric=is_symmetric,
                 accuracy_level=accuracy_level,
                 quant_format=quant_format,
+                op_types_to_quantize=op_types_to_quantize,
+                quant_axes=quant_axes,
             )
         self.algo_config = algo_config
         if algo_config.algorithm == "HQQ":
@@ -636,10 +870,13 @@ def _process_subgraph(self, graph_stack: list[GraphProto]):
             if node.name in self.nodes_to_exclude:
                 logger.info(f"exclude to quantize {node.name} as specified by nodes_to_exclude...")
                 out_nodes = [node]
-            elif self.algo_config is not None and self.algo_config.algorithm == "HQQ":
+            elif (self.nodes_to_include and node.name in self.nodes_to_include) or (
+                node.op_type in self.algo_config.op_types_to_quantize
+            ):
                 out_nodes = self.node_quantizer.quantize(node, graph_stack)
             else:
-                out_nodes = self.node_quantizer.quantize(node, graph_stack)
+                logger.info(f"skip to quantize {node.name} ...")
+                out_nodes = [node]
             new_nodes.extend(out_nodes)
 
         graph.ClearField("node")
@@ -716,7 +953,8 @@ def process(self):
             # Update domain opset
             if self.algo_config.quant_format == QuantFormat.QOperator:
                 self.model.set_opset_import("com.microsoft", 1)
-            else:
+
+            if self.algo_config.quant_format == QuantFormat.QDQ or "Gather" in self.algo_config.op_types_to_quantize:
                 opset_import = self.model.opset_import()
                 for opset in opset_import:
                     if opset.domain in [None, "ai.onnx", ""] and opset.version < 21:
@@ -751,6 +989,12 @@ def ort_convert_str_to_bool(value):
     return value.lower() in ("true", "1")
 
 
+# Custom function to parse str:int pairs
+def parse_key_value_pair(s):
+    key, value = s.split(":")
+    return key, int(value)
+
+
 def parse_args():
     parser = argparse.ArgumentParser(
         description="""Blockwise int4 quantization for MatMul 2D weight matrices.
@@ -800,6 +1044,13 @@ def parse_args():
         default=[],
         help="Specify the nodes to be excluded from quantization with node names",
     )
+    parser.add_argument(
+        "--nodes_to_include",
+        nargs="+",
+        type=str,
+        required=False,
+        help="Specify the specific nodes to be included from quantization with node names",
+    )
     parser.add_argument(
         "--quant_format",
         default="QOperator",
@@ -809,6 +1060,23 @@ def parse_args():
         "QOperator format quantizes the model with quantized operators directly."
         "QDQ format quantize the model by inserting DeQuantizeLinear before the MatMul.",
     )
+    parser.add_argument(
+        "--op_types_to_quantize",
+        default=["MatMul"],  # a list, not a str: the value is later passed to tuple(), which would split a str
+        type=str,
+        nargs="+",
+        choices=["MatMul", "Gather"],
+        help="op_types_to_quantize {MatMul, Gather}. Operators to quantize. Default is MatMul.",
+    )
+    parser.add_argument(
+        "--quant_axes",
+        type=parse_key_value_pair,
+        nargs="+",
+        required=False,
+        help="Key-value pairs in op_type:axis_to_quantize separated by space."
+        "Specify the axis to quantize for an op. Default {MatMul:0, Gather:1}"
+        "Example: --quant_axes MatMul:0 Gather:1",
+    )
 
     return parser.parse_args()
 
@@ -821,6 +1089,8 @@ def parse_args():
     input_model_path = args.input_model
     output_model_path = args.output_model
     quant_format = QuantFormat[args.quant_format]
+    op_types_to_quantize = tuple(args.op_types_to_quantize) if args.op_types_to_quantize else None
+    quant_axes = tuple(args.quant_axes) if args.quant_axes else None
 
     if os.path.exists(output_model_path):
         logger.error(f"file {output_model_path} already exists")
@@ -832,18 +1102,22 @@ def parse_args():
 
     model = onnx.load(input_model_path)
     if args.quant_method == "hqq":
-        quant_config = HQQWeightOnlyQuantConfig(block_size=args.block_size, bits=args.bits)
+        quant_config = HQQWeightOnlyQuantConfig(
+            block_size=args.block_size, bits=args.bits, op_types_to_quantize=op_types_to_quantize, quant_axes=quant_axes
+        )
     elif args.quant_method == "default":
         quant_config = DefaultWeightOnlyQuantConfig(
             block_size=args.block_size,
             is_symmetric=args.symmetric,
             accuracy_level=args.accuracy_level,
             quant_format=quant_format,
+            op_types_to_quantize=op_types_to_quantize,
+            quant_axes=quant_axes,
         )
     elif args.quant_method == "rtn":
-        quant_config = RTNWeightOnlyQuantConfig()
+        quant_config = RTNWeightOnlyQuantConfig(op_types_to_quantize=op_types_to_quantize)
     elif args.quant_method == "gptq":
-        quant_config = GPTQWeightOnlyQuantConfig(block_size=args.block_size)
+        quant_config = GPTQWeightOnlyQuantConfig(block_size=args.block_size, op_types_to_quantize=op_types_to_quantize)
     else:
         raise ValueError(f"Unsupported quantization method: {args.quant_method}")
 
@@ -851,6 +1125,7 @@ def parse_args():
         model=model,
         accuracy_level=args.accuracy_level,
         nodes_to_exclude=args.nodes_to_exclude,
+        nodes_to_include=args.nodes_to_include,
         algo_config=quant_config,
     )
     quant.process()
diff --git a/onnxruntime/python/tools/quantization/quantize.py b/onnxruntime/python/tools/quantization/quantize.py
index 2340c995d3d5b..745344dc01fcb 100644
--- a/onnxruntime/python/tools/quantization/quantize.py
+++ b/onnxruntime/python/tools/quantization/quantize.py
@@ -699,9 +699,8 @@ def quantize(
     Args:
         model_input (str | Path | ModelProto): Path to the model or ModelProto to quantize.
         model_output (str | Path): Path to save the quantized model.
-        quant_config (QuantConfig): Quantization Configuration.
+        quant_config (QuantConfig | WeightOnlyQuantConfig): Quantization Configuration.
     """
-
     if isinstance(quant_config, StaticQuantConfig):
         quantize_static(
             model_input,
@@ -734,4 +733,16 @@ def quantize(
             extra_options=quant_config.extra_options,
         )
     else:
-        raise TypeError("Invalid quantization config type, it must be either StaticQuantConfig or DynamicQuantConfig.")
+        # training package doesn't have quantize_matmul_4bits, avoid global import
+        from .matmul_4bits_quantizer import MatMul4BitsQuantizer, WeightOnlyQuantConfig
+
+        if isinstance(quant_config, WeightOnlyQuantConfig):
+            model = model_input if isinstance(model_input, onnx.ModelProto) else onnx.load(model_input)
+            quant = MatMul4BitsQuantizer(model, algo_config=quant_config)
+            quant.process()
+            quant.model.save_model_to_file(model_output, True)
+        else:
+            raise TypeError(
+                "Invalid quantization config type, it must be either StaticQuantConfig, "
+                "DynamicQuantConfig, or WeightOnlyQuantConfig."
+            )
diff --git a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py
index 0438d93227524..292dc50124c16 100644
--- a/onnxruntime/test/python/quantization/test_op_matmul_4bits.py
+++ b/onnxruntime/test/python/quantization/test_op_matmul_4bits.py
@@ -51,12 +51,19 @@ def fill_int4_data(self, shape: Union[int, Tuple[int, ...]], symmetric: bool) ->
 
         return line.reshape(shape)
 
-    def input_feeds(self, n: int, name2shape: Dict[str, Union[int, Tuple[int, ...]]]) -> TestDataFeeds:
+    def input_feeds(
+        self,
+        n: int,
+        name2shape: Dict[str, Union[int, Tuple[int, ...]]],
+        low: int = -1,
+        high: int = 2,
+        dtype: type = np.float32,
+    ) -> TestDataFeeds:
         input_data_list = []
         for _i in range(n):
             inputs = {}
             for name, shape in name2shape.items():
-                inputs.update({name: np.random.randint(-1, 2, shape).astype(np.float32)})
+                inputs.update({name: np.random.randint(low, high, shape).astype(dtype)})
             input_data_list.extend([inputs])
         dr = TestDataFeeds(input_data_list)
         return dr
@@ -111,6 +118,65 @@ def make_matmul(
 
         onnx.save(model, output_model_path)
 
+    def construct_model_gather(
+        self,
+        output_model_path: str,
+        symmetric: bool,
+        tdata: TensorProto.DataType,
+        tind: TensorProto.DataType,
+        vocab_size: int = 545,
+        embedding_len: int = 228,
+    ) -> None:
+        #      (input)
+        #         |
+        #       Gather
+        #         |
+        #      (output)
+        indices_name = "input"
+        output_name = "output"
+        initializers = []
+
+        def make_gather(
+            indices_name, data_shape: Union[int, Tuple[int, ...]], data_name: str, output_name: str, node_name: str
+        ):
+            weight_data = self.fill_int4_data(data_shape, symmetric).astype(
+                np.float32 if tdata == TensorProto.FLOAT else np.float16
+            )
+            initializers.append(onnx.numpy_helper.from_array(weight_data, name=data_name))
+            kwargs = {"axis": 0}
+            return onnx.helper.make_node(
+                "Gather",
+                [data_name, indices_name],
+                [output_name],
+                node_name,
+                **kwargs,
+            )
+
+        gather_node = make_gather(
+            indices_name,
+            (vocab_size, embedding_len),
+            "linear1.weight",
+            output_name,
+            "Gather_0",
+        )
+
+        # make graph
+        input_tensor = helper.make_tensor_value_info(indices_name, tind, [-1, 1000])
+        output_tensor = helper.make_tensor_value_info(output_name, tdata, [-1, 1000, embedding_len])
+        graph_name = "gather_4bits_test"
+        graph = helper.make_graph(
+            [gather_node],
+            graph_name,
+            [input_tensor],
+            [output_tensor],
+            initializer=initializers,
+        )
+        # QDQ and gather requires op set >= 21. The tool should automatically update the opset.
+        model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 19)])
+        model.ir_version = 9  # use stable onnx ir version
+
+        onnx.save(model, output_model_path)
+
     def quant_test(
         self,
         model_fp32_path: str,
@@ -118,9 +184,13 @@ def quant_test(
         block_size: int,
         is_symmetric: bool,
         quant_format: quant_utils.QuantFormat = quant_utils.QuantFormat.QOperator,
+        op_types_to_quantize: Tuple[str, ...] = ("MatMul",),
+        quant_axes: Tuple[Tuple[str, int], ...] = (("MatMul", 0), ("Gather", 1)),
+        rtol: float = 0.01,
+        atol: float = 0.05,
     ):
         use_qdq = quant_format == quant_utils.QuantFormat.QDQ
-        name_prefix = "DQ_MatMul" if use_qdq else "MatMulNBits"
+        name_prefix = "QDQ" if use_qdq else "QOperator"
         model_int4_path = str(
             Path(self._tmp_model_dir.name).joinpath(f"{name_prefix}_{block_size}_{is_symmetric}.onnx").absolute()
         )
@@ -130,13 +200,20 @@ def quant_test(
 
         model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path))
         quant_config = matmul_4bits_quantizer.DefaultWeightOnlyQuantConfig(
-            block_size=block_size, is_symmetric=is_symmetric, quant_format=quant_format
+            block_size=block_size,
+            is_symmetric=is_symmetric,
+            quant_format=quant_format,
+            op_types_to_quantize=op_types_to_quantize,
+            quant_axes=quant_axes,
         )
         quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, algo_config=quant_config)
         quant.process()
         quant.model.save_model_to_file(model_int4_path, False)
 
-        quant_nodes = {"DequantizeLinear": 1, "MatMul": 1} if use_qdq else {"MatMulNBits": 1}
+        if "Gather" in op_types_to_quantize:
+            quant_nodes = {"GatherBlockQuantized": 1}
+        else:
+            quant_nodes = {"DequantizeLinear": 1, "MatMul": 1} if use_qdq else {"MatMulNBits": 1}
         check_op_type_count(self, model_int4_path, **quant_nodes)
 
         if use_qdq:
@@ -163,7 +240,7 @@ def quant_test(
         data_reader.rewind()
 
         try:
-            check_model_correctness(self, model_fp32_path, model_int4_path, data_reader.get_next())
+            check_model_correctness(self, model_fp32_path, model_int4_path, data_reader.get_next(), rtol, atol)
         except Exception as exception:
             if "4b quantization not yet supported on this hardware platform!" in exception.args[0]:
                 # Currently we don't have int4 quantization support on all platforms, has to tolerate this exception
@@ -224,7 +301,7 @@ def test_quantize_matmul_int4_symmetric(self):
 
         model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_symmetric.onnx").absolute())
         self.construct_model_matmul(model_fp32_path, symmetric=True)
-        data_reader = self.input_feeds(1, {"input": [100, 52]})
+        data_reader = self.input_feeds(1, {"input": (100, 52)})
         self.quant_test(model_fp32_path, data_reader, 32, True)
 
     @unittest.skipIf(
@@ -233,9 +310,31 @@ def test_quantize_matmul_int4_symmetric(self):
     def test_quantize_matmul_int4_offsets(self):
         model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute())
         self.construct_model_matmul(model_fp32_path, symmetric=False)
-        data_reader = self.input_feeds(1, {"input": [100, 52]})
+        data_reader = self.input_feeds(1, {"input": (100, 52)})
         self.quant_test(model_fp32_path, data_reader, 32, False)
 
+    @unittest.skipIf(
+        find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits"
+    )
+    def test_quantize_gather_int4_symmetric(self):
+        np.random.seed(13)
+
+        model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("gather_fp32_symmetric.onnx").absolute())
+        self.construct_model_gather(model_fp32_path, True, TensorProto.FLOAT, TensorProto.INT32)
+        data_reader = self.input_feeds(1, {"input": (100, 1000)}, -545, 535, np.int32)
+        # Relaxed rtol/atol to cover quantization rounding error.
+        self.quant_test(model_fp32_path, data_reader, 32, True, op_types_to_quantize=("Gather",), rtol=0.2, atol=0.5)
+
+    @unittest.skipIf(
+        find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits"
+    )
+    def test_quantize_gather_int4_offsets(self):
+        model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("gather_fp32_offset.onnx").absolute())
+        self.construct_model_gather(model_fp32_path, False, TensorProto.FLOAT16, TensorProto.INT64)
+        data_reader = self.input_feeds(1, {"input": (100, 1000)}, -545, 535, np.int64)
+        # Relaxed rtol/atol to cover quantization rounding error.
+        self.quant_test(model_fp32_path, data_reader, 32, False, op_types_to_quantize=("Gather",), rtol=0.2, atol=0.5)
+
     @unittest.skipIf(
         find_spec("onnxruntime.training"), "Skip because training package doesn't has quantize_matmul_4bits"
     )
@@ -244,7 +343,7 @@ def test_quantize_matmul_int4_symmetric_qdq(self):
 
         model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_symmetric.onnx").absolute())
         self.construct_model_matmul(model_fp32_path, symmetric=True)
-        data_reader = self.input_feeds(1, {"input": [100, 52]})
+        data_reader = self.input_feeds(1, {"input": (100, 52)})
         self.quant_test(model_fp32_path, data_reader, 32, True, quant_utils.QuantFormat.QDQ)
 
     @unittest.skipIf(
@@ -253,7 +352,7 @@ def test_quantize_matmul_int4_symmetric_qdq(self):
     def test_quantize_matmul_int4_offsets_qdq(self):
         model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute())
         self.construct_model_matmul(model_fp32_path, symmetric=False)
-        data_reader = self.input_feeds(1, {"input": [100, 52]})
+        data_reader = self.input_feeds(1, {"input": (100, 52)})
         self.quant_test(model_fp32_path, data_reader, 32, False, quant_utils.QuantFormat.QDQ)
 
     @unittest.skipIf(
@@ -264,7 +363,7 @@ def test_quantize_matmul_int4_using_rtn_algo(self):
             self.skipTest("skip test_smooth_quant since neural_compressor is not installed")
         model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute())
         self.construct_model_matmul(model_fp32_path, symmetric=False)
-        data_reader = self.input_feeds(1, {"input": [100, 52]})
+        data_reader = self.input_feeds(1, {"input": (100, 52)})
         self.quant_test_with_algo("RTN", model_fp32_path, data_reader, 32, False)
 
     @unittest.skipIf(
@@ -275,7 +374,7 @@ def test_quantize_matmul_int4_using_gptq_algo(self):
             self.skipTest("skip test_smooth_quant since neural_compressor is not installed")
         model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute())
         self.construct_model_matmul(model_fp32_path, symmetric=False)
-        data_reader = self.input_feeds(1, {"input": [100, 52]})
+        data_reader = self.input_feeds(1, {"input": (100, 52)})
         self.quant_test_with_algo("GPTQ", model_fp32_path, data_reader, 32, False)
 
     @unittest.skipIf(
@@ -286,7 +385,7 @@ def test_quantize_matmul_int4_using_hqq_algo(self):
             self.skipTest("skip test_hqq_quant since torch is not installed")
         model_fp32_path = str(Path(self._tmp_model_dir.name).joinpath("matmul_fp32_offset.onnx").absolute())
         self.construct_model_matmul(model_fp32_path, symmetric=False)
-        data_reader = self.input_feeds(1, {"input": [100, 52]})
+        data_reader = self.input_feeds(1, {"input": (100, 52)})
         self.quant_test_with_algo("HQQ", model_fp32_path, data_reader, 32, False)