Support for 3D activations in data-aware weight compression
ljaljushkin committed Oct 1, 2024
1 parent 2c8b70c commit 9add286
Showing 15 changed files with 174 additions and 63 deletions.
(file name not shown)
@@ -28,7 +28,10 @@ def process_stats(stats: List[Tensor], subset_size: int) -> Tuple[Tensor, Tensor
     X - average channel magnitude across tokens in the sequence [HiddenDim, SampleSize]
     :rtype: Tuple[TTensor, TTensor]
     """
-    X = fns.stack([fns.mean(stat, axis=0) for stat in stats])  # [Batch, HiddenDim]
+    # Transposed input activations are not supported; the hidden dimension is always the last dimension.
+    # Need to reduce over all axes except the hidden dimension.
+    reduction_axis = tuple(range(stats[0].ndim)[:-1])
+    X = fns.stack([fns.mean(stat, axis=reduction_axis) for stat in stats])  # [Batch, HiddenDim]
     X_full = fns.transpose(X)  # [HiddenDim, Batch]

     # prevent high memory and time consumption
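The new reduction generalizes the old axis=0 mean to activations of any rank. A minimal NumPy sketch (standing in for the nncf tensor functions, which accept a tuple of axes the same way):

import numpy as np

def mean_over_all_but_last(stat: np.ndarray) -> np.ndarray:
    # Reduce over every axis except the last (hidden) dimension,
    # equivalent to `tuple(range(stats[0].ndim)[:-1])` in the diff above.
    reduction_axis = tuple(range(stat.ndim)[:-1])
    return np.mean(stat, axis=reduction_axis)

stat_2d = np.ones((10, 8))     # [SeqLen, HiddenDim]
stat_3d = np.ones((4, 10, 8))  # [Batch, SeqLen, HiddenDim]
assert mean_over_all_but_last(stat_2d).shape == (8,)  # reduces axis (0,)
assert mean_over_all_but_last(stat_3d).shape == (8,)  # reduces axes (0, 1)

Both ranks collapse to a per-hidden-channel mean, which is what lets 3D (batched) activation statistics flow through process_stats unchanged.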
6 changes: 4 additions & 2 deletions nncf/quantization/algorithms/weight_compression/algorithm.py
@@ -369,7 +369,7 @@ def apply(
         self._set_weight_compression_config(ratio_defining_params, model, graph, activations)
         nncf_logger.info(self._get_bitwidth_distribution_str(all_weight_params, ratio_defining_params))

-        if self._awq and activations is not None and self._mode != CompressWeightsMode.E2M1:
+        if self._awq and activations is not None:
             awq_params = self._advanced_parameters.awq_params
             awq_algo = AWQ(
                 model,
@@ -399,7 +399,7 @@ def apply(
                 backend_entity=self._backend_entity,
             )
         else:
-            if self._scale_estimation and activations is not None and self._mode != CompressWeightsMode.E2M1:
+            if self._scale_estimation and activations is not None:
                 scale_estimation_params = self._advanced_parameters.scale_estimation_params
                 scale_algo = ScaleEstimation(
                     model,
@@ -549,6 +549,8 @@ def _get_activations(
         matmul_metatypes = self._backend_entity.matmul_metatypes
         filtered_nodes = filter(lambda node: node.metatype in matmul_metatypes, nodes_to_compress)
         for node in filtered_nodes:
+            if node.layer_attributes.input_attributes["transpose"]:
+                raise nncf.UnsupportedModelError("Transposed input is not supported")
             act_node, output_port_id = self._get_activation_node_and_port(node, graph)
             act_node_name = act_node.node_name
             if act_node_name in all_act_nodes:
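For intuition, the new guard rejects graphs where a MatMul consumes its activation input transposed, since the hidden dimension is then no longer last. A hedged sketch of such a graph using the OpenVINO opset (illustrative only; the actual check reads NNCF layer attributes, not the ov graph directly):

import numpy as np
from openvino.runtime import opset13 as opset

# Activation laid out as [HiddenDim, SeqLen]; transpose_a=True makes the MatMul
# read it transposed, breaking the "hidden dim is last" assumption above.
act = opset.parameter([16, 8], np.float32)
weight = opset.constant(np.zeros((16, 4), dtype=np.float32))
matmul = opset.matmul(act, weight, transpose_a=True, transpose_b=False)  # would be rejected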
3 changes: 2 additions & 1 deletion nncf/quantization/algorithms/weight_compression/awq.py
@@ -13,6 +13,7 @@
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, TypeVar

+import nncf
 from nncf import Dataset
 from nncf import nncf_logger
 from nncf.common.factory import ModelTransformerFactory
@@ -115,7 +116,7 @@ def _set_backend_entity(self, model: TModel) -> None:
             self._backend_entity = OVAWQAlgoAlgoBackend(model, self.name_to_node_mapping)
             self._patterns = self._backend_entity.get_awq_patterns()
         else:
-            raise RuntimeError(
+            raise nncf.UnsupportedBackendError(
                 "Cannot return backend-specific AWQ entity because {} is not supported!".format(model_backend.value)
             )
6 changes: 3 additions & 3 deletions nncf/quantization/algorithms/weight_compression/gptq.py
@@ -169,9 +169,9 @@ def _calculate_hessian(self, node: NNCFNode, inputs: List[Tensor]) -> Tensor:
         nsamples = 0

         if node.metatype in self._backend_entity.convolution_metatypes:
-            raise RuntimeError("Convolution metatypes are not supported")
+            raise nncf.UnsupportedModelError("Convolution metatypes are not supported")
         if node.layer_attributes.input_attributes["transpose"]:
-            raise RuntimeError("Transpose is not supported")
+            raise nncf.UnsupportedModelError("Transposed input is not supported")

         hessian = fns.zeros(
             (inputs[0].shape[-1], inputs[0].shape[-1]), backend=inputs[0].backend, dtype=TensorDataType.float32
@@ -264,7 +264,7 @@ def _quantize_weights(
                     scales.append(scale)
                 else:
                     if self._scale_estimation and block_compression_config.num_bits == 4:
-                        activations = [inp.squeeze()[:, (i1 + i) : (i1 + i + group_size)] for inp in inputs]
+                        activations = [inp.squeeze()[..., (i1 + i) : (i1 + i + group_size)] for inp in inputs]
                         scale, zero_point = ScaleEstimation.calculate_quantization_params(
                             self._backend_entity,
                             activations,
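The indexing change is the 3D-enabling part of this hunk: `...` always addresses the last (hidden) axis, while the old `[:, ...]` hard-coded axis 1. A small NumPy sketch:

import numpy as np

inp_2d = np.arange(24).reshape(6, 4)     # [SampleSize, HiddenDim]
inp_3d = np.arange(48).reshape(2, 6, 4)  # [Batch, SampleSize, HiddenDim]

# Take a group of two hidden channels from the last axis, whatever the rank:
assert inp_2d[..., 0:2].shape == (6, 2)
assert inp_3d[..., 0:2].shape == (2, 6, 2)
# The old slice would cut the sample axis of a 3D tensor instead:
assert inp_3d[:, 0:2].shape == (2, 2, 4)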
(file name not shown)
@@ -14,6 +14,7 @@
 import matplotlib.pyplot as plt
 import pandas as pd

+import nncf
 from nncf.common.logging import nncf_logger
 from nncf.common.utils.debug import DEBUG_LOG_DIR
 from nncf.common.utils.debug import is_debug
@@ -176,7 +177,7 @@ def calculate_low_rank_matrices(
         indexes = do_nf4_quantization(compressed_weight.tensor, compressed_weight.scale, is_normalized_weight=True)
         fq_weights = do_nf4_dequantization(indexes, compressed_weight.scale, reduction_axis)
     else:
-        raise ValueError(
+        raise nncf.InternalError(
             f"{mode.value} mode is invalid for Lora Correction algorithm. Supported modes: INT4_SYM, INT4_ASYM, NF4"
         )
     # fq_w + residual = w => residual = w - fq_w
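The trailing comment states the invariant the algorithm works from: fq_w + residual = w. A hedged sketch of the underlying idea (plain NumPy and a truncated SVD; names are illustrative, not the nncf API):

import numpy as np

def low_rank_correction(w: np.ndarray, fq_w: np.ndarray, rank: int):
    # Approximate the quantization residual w - fq_w with two small matrices,
    # so that fq_w + U @ V is closer to the original weight w.
    residual = w - fq_w
    u, s, vt = np.linalg.svd(residual, full_matrices=False)
    U = u[:, :rank] * s[:rank]  # [OutDim, rank]
    V = vt[:rank, :]            # [rank, InDim]
    return U, V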
(file name not shown)
@@ -13,6 +13,7 @@
 import openvino as ov
 from openvino.runtime import opset13 as opset

+import nncf
 from nncf.common.graph import NNCFGraph
 from nncf.common.graph import NNCFNode
 from nncf.common.graph.operator_metatypes import OperatorMetatype
@@ -221,7 +222,7 @@ def _create_compression_subgraph(
         elif compression_config.mode == CompressWeightsMode.INT8_ASYM:
             compression_dtype = ov.Type.u8
         else:
-            raise ValueError(f"{compression_config.mode.value} is not supported.")
+            raise nncf.ParameterNotSupportedError(f"{compression_config.mode.value} is not supported.")

         original_shape = weight.shape
         compressed_weight = compress_weight(weight, reduction_axes, compression_config, layer_scales, layer_zero_points)
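The branch above dispatches compression modes to OpenVINO element types. A hedged sketch of the same dispatch as a lookup table; only the INT8_ASYM -> u8 entry is visible in this hunk, and the INT8_SYM entry is an assumption added for illustration:

import openvino as ov
import nncf
from nncf import CompressWeightsMode

_COMPRESSION_DTYPE = {
    CompressWeightsMode.INT8_ASYM: ov.Type.u8,  # confirmed by the diff above
    CompressWeightsMode.INT8_SYM: ov.Type.i8,   # assumption
}

def compression_dtype(mode: CompressWeightsMode) -> ov.Type:
    if mode not in _COMPRESSION_DTYPE:
        raise nncf.ParameterNotSupportedError(f"{mode.value} is not supported.")
    return _COMPRESSION_DTYPE[mode]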
(file name not shown)
@@ -12,6 +12,7 @@
 from copy import deepcopy
 from typing import Any, Dict, List, Optional, Tuple, TypeVar

+import nncf
 from nncf import Dataset
 from nncf.common.graph.graph import NNCFGraph
 from nncf.common.graph.graph import NNCFNode
@@ -100,7 +101,7 @@ def _set_backend_entity(self, model: TModel) -> None:

             self._backend_entity = OVWeightCompressionAlgoBackend(model, self.name_to_node_mapping)
         else:
-            raise RuntimeError(
+            raise nncf.UnsupportedBackendError(
                 "Cannot return backend-specific AWQ entity because {} is not supported!".format(model_backend.value)
             )
(file name not shown)
@@ -208,7 +208,7 @@ def transform_model(
             CompressWeightsMode.INT8_SYM,
             CompressWeightsMode.INT8,
         ]:
-            raise ValueError(f"{compression_config.mode.value} is not supported.")
+            raise nncf.ParameterNotSupportedError(f"{compression_config.mode.value} is not supported.")

         weight_node = get_const_node(wc_params.node_with_weight, wc_params.weight_port_id, graph)
         weight_name = weight_node.layer_attributes.name
(file name not shown)
@@ -172,7 +172,7 @@ def transform_model(
             CompressWeightsMode.INT8_SYM,
             CompressWeightsMode.INT8,
         ]:
-            raise ValueError(f"{compression_config.mode.value} is not supported.")
+            raise nncf.ParameterNotSupportedError(f"{compression_config.mode.value} is not supported.")
         weight_node = get_const_node(wc_params.node_with_weight, wc_params.weight_port_id, graph)
         weight_name = weight_node.node_name
         weight = self.get_weight(wc_params.node_with_weight, wc_params.weight_port_id, model, graph)
(file name not shown)
@@ -82,12 +82,14 @@ def reshape_weight_for_grouped_quantization(
     if isinstance(reduction_axes, tuple) and len(reduction_axes) == 1:
         reduction_axes = reduction_axes[0]
     if not isinstance(reduction_axes, int):
-        raise NotImplementedError(
+        raise nncf.UnsupportedModelError(
             f"Group-wise quantization expects a single reduction axis, but given: {reduction_axes}."
         )
     channel_size = weight.shape[reduction_axes]
     if channel_size % group_size != 0:
-        raise nncf.ValidationError(f"Channel size {channel_size} should be divisible by size of group {group_size}")
+        raise nncf.UnsupportedModelError(
+            f"Channel size {channel_size} should be divisible by size of group {group_size}"
+        )

     num_groups_per_channel = channel_size // group_size
     shape = list(weight.shape)  # [a1, r, a2] - "r" refers to number of channels along reduction axis
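A hedged NumPy sketch of the group-wise reshape this helper performs (a simplified stand-in; the real function operates on nncf Tensors):

import numpy as np

def reshape_for_groups(weight: np.ndarray, reduction_axis: int, group_size: int) -> np.ndarray:
    channel_size = weight.shape[reduction_axis]
    if channel_size % group_size != 0:
        raise ValueError(f"Channel size {channel_size} should be divisible by size of group {group_size}")
    shape = list(weight.shape)
    # Split the reduction axis r into [r // group_size, group_size]:
    shape[reduction_axis : reduction_axis + 1] = [channel_size // group_size, group_size]
    return weight.reshape(shape)

w = np.zeros((16, 128))  # [OutChannels, InChannels]
assert reshape_for_groups(w, reduction_axis=1, group_size=32).shape == (16, 4, 32)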
30 changes: 15 additions & 15 deletions nncf/quantization/quantize_model.py
@@ -160,7 +160,7 @@ def quantize(
     :rtype: TModel
     """
     if subset_size < 1:
-        raise ValueError("Subset size must be positive.")
+        raise nncf.ValidationError("Subset size must be positive.")

     advanced_parameters = _update_advanced_quantization_parameters(advanced_parameters, calibration_dataset)

@@ -463,27 +463,27 @@
         from nncf.torch.quantization.quantize_model import compress_weights_impl as pt_compression_weights_impl

         if mode not in [CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT8_SYM]:
-            raise AttributeError(
+            raise nncf.ParameterNotSupportedError(
                 "Torch backend supports only INT8_ASYM, INT8_SYM modes for weight compression, "
                 f"but given {mode.value} mode."
             )

         if True in [awq, scale_estimation, gptq, lora_correction]:
-            raise AttributeError(
+            raise nncf.ParameterNotSupportedError(
                 "Torch backend does not support 'awq', 'scale_estimation', 'gptq' and 'lora_correction' options. "
                 "Set them to None."
             )

         if is_wrapped_model(model):
             if not model.nncf.trace_parameters:
-                raise ValueError(
+                raise nncf.ValidationError(
                     "Tracing capabilities with tracing parameters are required in the PyTorch model "
                     "for nncf.compress_weights(). Please wrap the model using "
                     "nncf.torch.wrap_model(model, example_input, trace_parameters=True) before calling "
                     "nncf.compress_weights()."
                 )
         elif dataset is None:
-            raise AttributeError("Please provide a dataset of at least one element for PyTorch model tracing.")
+            raise nncf.ValidationError("Please provide a dataset of at least one element for PyTorch model tracing.")
         else:
             example_input = next(iter(dataset.get_inference_data()))
             model = wrap_model(model, example_input=example_input, trace_parameters=True)
@@ -496,18 +496,18 @@
         )

         if mode not in [CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT8_SYM]:
-            raise AttributeError(
+            raise nncf.ParameterNotSupportedError(
                 "TorchFX backend supports only INT8_ASYM, INT8_SYM modes for weight compression, "
                 f"but given {mode.value} mode."
             )

         if any((awq, scale_estimation, gptq, lora_correction)):
-            raise AttributeError(
+            raise nncf.ParameterNotSupportedError(
                 "TorchFX backend does not support 'awq', 'scale_estimation', 'gptq',"
                 "and 'lora_correction' options. Set them to None."
             )
         if dataset:
-            raise AttributeError(
+            raise nncf.ParameterNotSupportedError(
                 "TorchFX only supports data-free weights compression," "Set the 'dataset' option to None"
             )
         compression_weights_impl = fx_compression_weights_impl
@@ -518,13 +518,13 @@
         if any((awq, scale_estimation, gptq, lora_correction)) and (
             dataset is None or mode == CompressWeightsMode.E2M1
         ):
-            raise AttributeError(
+            raise nncf.ParameterNotSupportedError(
                 "Scale estimation, AWQ, GPTQ or Lora Correction algorithm is defined, "
                 "but dataset is None or mode is E2M1."
             )

         if gptq and lora_correction:
-            raise AttributeError(
+            raise nncf.ValidationError(
                 "Simultaneous use of Lora correction and GPTQ algorithms is not supported. Select one of them."
             )

@@ -536,7 +536,7 @@
         if group_size is None:
             group_size = -1
         if ratio != 1 or group_size != -1:
-            raise AttributeError(
+            raise nncf.ParameterNotSupportedError(
                 "INT8 mode assumes per-channel quantization of all layers in 8 bit. "
                 "Default values of `ratio` (1) and `group_size` (-1) parameters can not be overridden"
             )
@@ -551,7 +551,7 @@
         }
         unsupported_for_int8 = [name for name, value in options.items() if value is not None]
         if unsupported_for_int8:
-            raise AttributeError(
+            raise nncf.ParameterNotSupportedError(
                 f"INT8 modes do not support {', '.join(unsupported_for_int8)} option(s). Set them to None."
             )

@@ -578,14 +578,14 @@
             else SensitivityMetric.MAX_ACTIVATION_VARIANCE
         )
         if ratio != 1 and dataset is None and sensitivity_metric != SensitivityMetric.WEIGHT_QUANTIZATION_ERROR:
-            raise AttributeError(
+            raise nncf.ValidationError(
                 f"Mixed precision selection based on the given sensitivity metric={sensitivity_metric.value} requires "
                 "a dataset, but it's not provided."
             )
         if ratio < 0 or ratio > 1:
-            raise ValueError(f"The ratio should be between 0 and 1, but ratio={ratio} is specified.")
+            raise nncf.ValidationError(f"The ratio should be between 0 and 1, but ratio={ratio} is specified.")
         if subset_size is None or subset_size <= 0:
-            raise ValueError(f"The subset_size value should be positive, but subset_size={subset_size} is given.")
+            raise nncf.ValidationError(f"The subset_size value should be positive, but subset_size={subset_size} is given.")

     if compression_weights_impl is None:
         raise nncf.UnsupportedBackendError(f"Unsupported type of backend: {backend}")
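For callers, the net effect of this file's changes is that argument validation now raises nncf-specific exceptions instead of builtins. A hedged usage sketch (`model` is assumed to be an already-loaded OpenVINO model):

import nncf

# model = ov.Core().read_model("model.xml")  # assumed to exist
try:
    compressed = nncf.compress_weights(model, mode=nncf.CompressWeightsMode.INT4_SYM, ratio=1.5)
except nncf.ValidationError as err:
    print(f"Invalid argument: {err}")  # e.g. "The ratio should be between 0 and 1, ..."
except nncf.ParameterNotSupportedError as err:
    print(f"Unsupported option for this mode/backend: {err}")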
2 changes: 1 addition & 1 deletion tests/openvino/native/quantization/test_quantize_api.py
@@ -32,6 +32,6 @@ def get_mock_model() -> Model:
 def test_non_positive_subset_size():
     model_to_test = get_mock_model()

-    with pytest.raises(ValueError) as e:
+    with pytest.raises(nncf.ValidationError) as e:
         nncf.quantize(model_to_test, Dataset(MockDataset(INPUT_SHAPE)), subset_size=0)
     assert "Subset size must be positive." in e.info