Qualcomm AI Engine Direct - Fix quantization annotation for per channel quant (#7026)

Summary:
- Fix the 8a8w config in custom annotation
- Enable setting the act observer and symmetric arguments for per-channel quant (see the usage sketch below)
- Remove unused custom annotation in llama.py
shewu-quic authored Nov 22, 2024
1 parent abd739e commit a7ed425
Showing 6 changed files with 58 additions and 20 deletions.
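As a usage sketch of the new knobs (not part of this commit): the call below assumes the Python module path executorch.backends.qualcomm.quantizer.qconfig, inferred from the file path in this commit, and uses only argument names that appear in the diff; MinMaxObserver is an arbitrary example observer from torch.ao.quantization.

```python
import torch
from torch.ao.quantization.observer import MinMaxObserver

# Module path inferred from backends/qualcomm/quantizer/qconfig.py (assumption).
from executorch.backends.qualcomm.quantizer.qconfig import (
    get_ptq_per_channel_quant_config,
)

# 8-bit activations / 8-bit per-channel weights, symmetric activations so the
# observer can default to zero_point 128 (see the qconfig.py change below).
quant_config = get_ptq_per_channel_quant_config(
    act_dtype=torch.uint8,
    weight_dtype=torch.int8,
    act_observer=MinMaxObserver,  # overrides the MovingAverageMinMaxObserver default
    act_symmetric=True,
)
```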
1 change: 1 addition & 0 deletions backends/qualcomm/_passes/layout_transform.py
@@ -64,6 +64,7 @@ class LayoutTransform(ExportPass):
         exir_ops.edge.aten.prelu.default,
         exir_ops.edge.aten.relu.default,
         exir_ops.edge.aten._softmax.default,  # TODO: Need to find a new solution to do "axis_order" to transform axis.
+        exir_ops.edge.aten.sigmoid.default,
         exir_ops.edge.aten.sqrt.default,
         exir_ops.edge.aten.sub.Tensor,
         exir_ops.edge.aten.sum.dim_IntList,
1 change: 1 addition & 0 deletions backends/qualcomm/partition/common_defs.py
@@ -14,6 +14,7 @@
     exir_ops.edge.aten.full.default,
     exir_ops.edge.aten.slice_scatter.default,
     exir_ops.edge.aten.copy.default,
+    exir_ops.edge.quantized_decomposed.embedding_4bit.dtype,
 ]
 
 to_be_implemented_operator = [
2 changes: 1 addition & 1 deletion backends/qualcomm/quantizer/custom_annotation.py
@@ -22,7 +22,7 @@
 from torch.fx import Node
 
 
-def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None:
+def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None:  # noqa: C901
     """
     This function is specific for matmul op 16a8w.
     """
29 changes: 22 additions & 7 deletions backends/qualcomm/quantizer/qconfig.py
@@ -221,6 +221,7 @@ def get_ptq_per_channel_quant_config(
     act_dtype=torch.uint8,
     weight_dtype=torch.int8,
     act_observer=MovingAverageMinMaxObserver,
+    act_symmetric: bool = False,
 ) -> QuantizationConfig:
     extra_args: Dict[str, Any] = {"eps": 2**-12}
 
@@ -241,13 +242,27 @@
     ), f"weight_dtype, {weight_dtype} is not one of supported types, {supported_weight_dtypes}"
 
     # torch do not support uint16 quantization, use int32 to bypass
-    act_quantization_spec = QuantizationSpec(
-        dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
-        quant_min=torch.iinfo(act_dtype).min,
-        quant_max=torch.iinfo(act_dtype).max,
-        qscheme=torch.per_tensor_affine,
-        observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
-    )
+    if act_symmetric:
+        # If zero_point is 128, htp can do optimizations.
+        # If we keep quant_min and quant_max none, observer will default use 128 as zero_point.
+        # If we provide uint8 quant_min/max, it will use 127 as zero_point, which is undesired.
+        act_quantization_spec = QuantizationSpec(
+            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
+            qscheme=torch.per_tensor_symmetric,
+            observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
+        )
+    else:
+        # PyTorch will remove redundant observers based on attributes such as:
+        # dtype, quant_min, quant_max, ch_axis, etc.
+        # Providing values like quant_min and quant_max can help observers compare
+        # and further reduce the number of observers.
+        act_quantization_spec = QuantizationSpec(
+            dtype=torch.int32 if act_dtype == torch.uint16 else act_dtype,
+            quant_min=torch.iinfo(act_dtype).min,
+            quant_max=torch.iinfo(act_dtype).max,
+            qscheme=torch.per_tensor_affine,
+            observer_or_fake_quant_ctr=act_observer.with_args(**extra_args),
+        )
 
     weight_quantization_spec = QuantizationSpec(
         dtype=torch.int8 if weight_dtype == "int4" else weight_dtype,
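The comments in the symmetric branch above can be reproduced directly with PyTorch observers. A small illustration, not part of the commit, assuming recent torch.ao.quantization behavior (an un-customized symmetric uint8 range defaults to zero_point 128, while an explicit 0-255 range yields 127):

```python
import torch
from torch.ao.quantization.observer import MinMaxObserver

# No quant_min/quant_max: symmetric uint8 observers default to zero_point 128.
obs_default = MinMaxObserver(dtype=torch.uint8, qscheme=torch.per_tensor_symmetric)
obs_default(torch.randn(4, 8))
_, zp_default = obs_default.calculate_qparams()
print(zp_default)  # expected: tensor([128])

# Passing the full uint8 range counts as a customized range, so the zero point
# becomes (0 + 255) // 2 == 127, which the HTP backend prefers to avoid.
obs_custom = MinMaxObserver(
    dtype=torch.uint8,
    qscheme=torch.per_tensor_symmetric,
    quant_min=0,
    quant_max=255,
)
obs_custom(torch.randn(4, 8))
_, zp_custom = obs_custom.calculate_qparams()
print(zp_custom)  # expected: tensor([127])
```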
40 changes: 32 additions & 8 deletions backends/qualcomm/quantizer/quantizer.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 from enum import IntEnum, unique
+from functools import partial
 from typing import Callable, Optional, Sequence, Set
 
 import torch
@@ -67,28 +68,44 @@ class QuantDtype(IntEnum):
     # PTQ
     (QuantDtype.use_16a16w, False): (
         get_16a16w_qnn_ptq_config,
-        get_ptq_per_channel_quant_config(torch.uint16, torch.int16),
+        partial(
+            get_ptq_per_channel_quant_config,
+            act_dtype=torch.uint16,
+            weight_dtype=torch.int16,
+        ),
     ),
     (QuantDtype.use_16a8w, False): (
         get_16a8w_qnn_ptq_config,
-        get_ptq_per_channel_quant_config(torch.uint16, torch.int8),
+        partial(
+            get_ptq_per_channel_quant_config,
+            act_dtype=torch.uint16,
+            weight_dtype=torch.int8,
+        ),
     ),
     (QuantDtype.use_16a4w, False): (
         get_16a4w_qnn_ptq_config,
-        get_ptq_per_channel_quant_config(torch.uint16, "int4"),
+        partial(
+            get_ptq_per_channel_quant_config,
+            act_dtype=torch.uint16,
+            weight_dtype="int4",
+        ),
     ),
     (QuantDtype.use_8a8w, False): (
         get_8a8w_qnn_ptq_config,
-        get_ptq_per_channel_quant_config(),
+        partial(get_ptq_per_channel_quant_config),
     ),
     # QAT,
     (QuantDtype.use_16a4w, True): (
         get_16a4w_qnn_qat_config,
-        get_qat_per_channel_quant_config(torch.uint16, "int4"),
+        partial(
+            get_qat_per_channel_quant_config,
+            act_dtype=torch.uint16,
+            weight_dtype="int4",
+        ),
     ),
     (QuantDtype.use_8a8w, True): (
         get_8a8w_qnn_qat_config,
-        get_qat_per_channel_quant_config(),
+        partial(get_qat_per_channel_quant_config),
     ),
 }
 
@@ -176,11 +193,18 @@ def set_quant_config(
                 f"the quant config, (quant_dtype: {quant_dtype}, is_qat: {is_qat}) is not support"
             )
 
-        quant_config_fuc, self.per_channel_quant_config = quant_config_dict[
+        quant_config_fuc, per_channel_quant_config_fuc = quant_config_dict[
             (quant_dtype, is_qat)
         ]
         self.quant_config = (
-            quant_config_fuc(act_observer) if act_observer else quant_config_fuc()
+            quant_config_fuc(act_observer=act_observer)
+            if act_observer
+            else quant_config_fuc()
+        )
+        self.per_channel_quant_config = (
+            per_channel_quant_config_fuc(act_observer=act_observer)
+            if act_observer
+            else per_channel_quant_config_fuc()
         )
 
     def set_per_channel_conv_quant(self, enable: bool) -> None:
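The table above now stores functools.partial objects instead of eagerly built configs, so set_quant_config can still pass act_observer when the per-channel config is actually constructed. A self-contained sketch of that pattern using a hypothetical stand-in (make_config is not an ExecuTorch API):

```python
from functools import partial

def make_config(act_dtype="uint8", weight_dtype="int8", act_observer=None):
    # Hypothetical stand-in for get_ptq_per_channel_quant_config.
    return {"act": act_dtype, "weight": weight_dtype, "observer": act_observer}

# Before: the config was built eagerly, so a caller-supplied observer
# could no longer influence it.
eager_config = make_config("uint16", "int8")

# After: only the dtypes are bound; the observer is decided at call time.
deferred = partial(make_config, act_dtype="uint16", weight_dtype="int8")

default_config = deferred()                      # falls back to the default observer
custom_config = deferred(act_observer="MinMax")  # observer injected later, as in set_quant_config
print(eager_config, default_config, custom_config)
```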
5 changes: 1 addition & 4 deletions examples/qualcomm/oss_scripts/llama3_2/llama.py
@@ -293,10 +293,7 @@ def compile(args):
     start_quantize_ts = time.time()
     single_llama.quantize(
         quant_dtype,
-        custom_annotations=(
-            custom_annotate_llama_last_conv_16a8w,
-            annotate_matmul_16a8w,
-        ),
+        custom_annotations=(annotate_matmul_16a8w,),
     )
     end_quantize_ts = time.time()
     logging.info(f"Time for quantizing: {end_quantize_ts - start_quantize_ts}")
