From 9add286c4a1ac9cc4142e10d19b5d510f0ff9635 Mon Sep 17 00:00:00 2001 From: Nikolay Date: Tue, 1 Oct 2024 20:47:23 +0200 Subject: [PATCH] Support for 3D activations in data-aware weight compression --- .../weight_compression/activation_stats.py | 5 +- .../weight_compression/algorithm.py | 6 +- .../algorithms/weight_compression/awq.py | 3 +- .../algorithms/weight_compression/gptq.py | 6 +- .../weight_compression/lora_correction.py | 3 +- .../weight_compression/openvino_backend.py | 3 +- .../weight_compression/scale_estimation.py | 3 +- .../weight_compression/torch_backend.py | 2 +- .../weight_compression/torch_fx_backend.py | 2 +- .../weight_compression/weight_lowering.py | 6 +- nncf/quantization/quantize_model.py | 30 ++-- .../native/quantization/test_quantize_api.py | 2 +- .../quantization/test_weights_compression.py | 156 ++++++++++++++---- tests/torch/fx/test_compress_weights.py | 5 +- tests/torch/ptq/test_weights_compression.py | 5 +- 15 files changed, 174 insertions(+), 63 deletions(-) diff --git a/nncf/quantization/algorithms/weight_compression/activation_stats.py b/nncf/quantization/algorithms/weight_compression/activation_stats.py index 359887e7769..e172b9696ab 100644 --- a/nncf/quantization/algorithms/weight_compression/activation_stats.py +++ b/nncf/quantization/algorithms/weight_compression/activation_stats.py @@ -28,7 +28,10 @@ def process_stats(stats: List[Tensor], subset_size: int) -> Tuple[Tensor, Tensor X - average channel magnitude across tokens in the sequence [HiddenDim, SampleSize] :rtype: Tuple[TTensor, TTensor] """ - X = fns.stack([fns.mean(stat, axis=0) for stat in stats]) # [Batch, HiddenDim] + # Transposed input activations are not supported, the hidden dimension is always the last dimension. + # Need to reduce over all axis, except hidden dimension. 
+ reduction_axis = tuple(range(stats[0].ndim)[:-1]) + X = fns.stack([fns.mean(stat, axis=reduction_axis) for stat in stats]) # [Batch, HiddenDim] X_full = fns.transpose(X) # [HiddenDim, Batch] # prevent high memory and time consumption diff --git a/nncf/quantization/algorithms/weight_compression/algorithm.py b/nncf/quantization/algorithms/weight_compression/algorithm.py index 24fc509c85e..d326f5b2227 100644 --- a/nncf/quantization/algorithms/weight_compression/algorithm.py +++ b/nncf/quantization/algorithms/weight_compression/algorithm.py @@ -369,7 +369,7 @@ def apply( self._set_weight_compression_config(ratio_defining_params, model, graph, activations) nncf_logger.info(self._get_bitwidth_distribution_str(all_weight_params, ratio_defining_params)) - if self._awq and activations is not None and self._mode != CompressWeightsMode.E2M1: + if self._awq and activations is not None: awq_params = self._advanced_parameters.awq_params awq_algo = AWQ( model, @@ -399,7 +399,7 @@ def apply( backend_entity=self._backend_entity, ) else: - if self._scale_estimation and activations is not None and self._mode != CompressWeightsMode.E2M1: + if self._scale_estimation and activations is not None: scale_estimation_params = self._advanced_parameters.scale_estimation_params scale_algo = ScaleEstimation( model, @@ -549,6 +549,8 @@ def _get_activations( matmul_metatypes = self._backend_entity.matmul_metatypes filtered_nodes = filter(lambda node: node.metatype in matmul_metatypes, nodes_to_compress) for node in filtered_nodes: + if node.layer_attributes.input_attributes["transpose"]: + raise nncf.UnsupportedModelError("Transposed input is not supported") act_node, output_port_id = self._get_activation_node_and_port(node, graph) act_node_name = act_node.node_name if act_node_name in all_act_nodes: diff --git a/nncf/quantization/algorithms/weight_compression/awq.py b/nncf/quantization/algorithms/weight_compression/awq.py index 6271bd6c255..9361bd39bd1 100644 --- a/nncf/quantization/algorithms/weight_compression/awq.py +++ b/nncf/quantization/algorithms/weight_compression/awq.py @@ -13,6 +13,7 @@ from dataclasses import dataclass from typing import Any, Dict, List, Optional, TypeVar +import nncf from nncf import Dataset from nncf import nncf_logger from nncf.common.factory import ModelTransformerFactory @@ -115,7 +116,7 @@ def _set_backend_entity(self, model: TModel) -> None: self._backend_entity = OVAWQAlgoAlgoBackend(model, self.name_to_node_mapping) self._patterns = self._backend_entity.get_awq_patterns() else: - raise RuntimeError( + raise nncf.UnsupportedBackendError( "Cannot return backend-specific AWQ entity because {} is not supported!".format(model_backend.value) ) diff --git a/nncf/quantization/algorithms/weight_compression/gptq.py b/nncf/quantization/algorithms/weight_compression/gptq.py index bd6518c86ad..539b61e81c8 100644 --- a/nncf/quantization/algorithms/weight_compression/gptq.py +++ b/nncf/quantization/algorithms/weight_compression/gptq.py @@ -169,9 +169,9 @@ def _calculate_hessian(self, node: NNCFNode, inputs: List[Tensor]) -> Tensor: nsamples = 0 if node.metatype in self._backend_entity.convolution_metatypes: - raise RuntimeError("Convolution metatypes are not supported") + raise nncf.UnsupportedModelError("Convolution metatypes are not supported") if node.layer_attributes.input_attributes["transpose"]: - raise RuntimeError("Transpose is not supported") + raise nncf.UnsupportedModelError("Transposed input is not supported") hessian = fns.zeros( (inputs[0].shape[-1], inputs[0].shape[-1]), 
backend=inputs[0].backend, dtype=TensorDataType.float32 @@ -264,7 +264,7 @@ def _quantize_weights( scales.append(scale) else: if self._scale_estimation and block_compression_config.num_bits == 4: - activations = [inp.squeeze()[:, (i1 + i) : (i1 + i + group_size)] for inp in inputs] + activations = [inp.squeeze()[..., (i1 + i) : (i1 + i + group_size)] for inp in inputs] scale, zero_point = ScaleEstimation.calculate_quantization_params( self._backend_entity, activations, diff --git a/nncf/quantization/algorithms/weight_compression/lora_correction.py b/nncf/quantization/algorithms/weight_compression/lora_correction.py index 8907cb64a2c..a44cde169bd 100644 --- a/nncf/quantization/algorithms/weight_compression/lora_correction.py +++ b/nncf/quantization/algorithms/weight_compression/lora_correction.py @@ -14,6 +14,7 @@ import matplotlib.pyplot as plt import pandas as pd +import nncf from nncf.common.logging import nncf_logger from nncf.common.utils.debug import DEBUG_LOG_DIR from nncf.common.utils.debug import is_debug @@ -176,7 +177,7 @@ def calculate_low_rank_matrices( indexes = do_nf4_quantization(compressed_weight.tensor, compressed_weight.scale, is_normalized_weight=True) fq_weights = do_nf4_dequantization(indexes, compressed_weight.scale, reduction_axis) else: - raise ValueError( + raise nncf.InternalError( f"{mode.value} mode is invalid for Lora Correction algorithm. Supported modes: INT4_SYM, INT4_ASYM, NF4" ) # fq_w + residual = w => residual = w - fq_w diff --git a/nncf/quantization/algorithms/weight_compression/openvino_backend.py b/nncf/quantization/algorithms/weight_compression/openvino_backend.py index 2b9bc3712b8..33da8033a59 100644 --- a/nncf/quantization/algorithms/weight_compression/openvino_backend.py +++ b/nncf/quantization/algorithms/weight_compression/openvino_backend.py @@ -13,6 +13,7 @@ import openvino as ov from openvino.runtime import opset13 as opset +import nncf from nncf.common.graph import NNCFGraph from nncf.common.graph import NNCFNode from nncf.common.graph.operator_metatypes import OperatorMetatype @@ -221,7 +222,7 @@ def _create_compression_subgraph( elif compression_config.mode == CompressWeightsMode.INT8_ASYM: compression_dtype = ov.Type.u8 else: - raise ValueError(f"{compression_config.mode.value} is not supported.") + raise nncf.ParameterNotSupportedError(f"{compression_config.mode.value} is not supported.") original_shape = weight.shape compressed_weight = compress_weight(weight, reduction_axes, compression_config, layer_scales, layer_zero_points) diff --git a/nncf/quantization/algorithms/weight_compression/scale_estimation.py b/nncf/quantization/algorithms/weight_compression/scale_estimation.py index f9ad77375a2..6cf065ba0bd 100644 --- a/nncf/quantization/algorithms/weight_compression/scale_estimation.py +++ b/nncf/quantization/algorithms/weight_compression/scale_estimation.py @@ -12,6 +12,7 @@ from copy import deepcopy from typing import Any, Dict, List, Optional, Tuple, TypeVar +import nncf from nncf import Dataset from nncf.common.graph.graph import NNCFGraph from nncf.common.graph.graph import NNCFNode @@ -100,7 +101,7 @@ def _set_backend_entity(self, model: TModel) -> None: self._backend_entity = OVWeightCompressionAlgoBackend(model, self.name_to_node_mapping) else: - raise RuntimeError( + raise nncf.UnsupportedBackendError( "Cannot return backend-specific AWQ entity because {} is not supported!".format(model_backend.value) ) diff --git a/nncf/quantization/algorithms/weight_compression/torch_backend.py 
b/nncf/quantization/algorithms/weight_compression/torch_backend.py index f46d9727d63..42585e0f1e1 100644 --- a/nncf/quantization/algorithms/weight_compression/torch_backend.py +++ b/nncf/quantization/algorithms/weight_compression/torch_backend.py @@ -208,7 +208,7 @@ def transform_model( CompressWeightsMode.INT8_SYM, CompressWeightsMode.INT8, ]: - raise ValueError(f"{compression_config.mode.value} is not supported.") + raise nncf.ParameterNotSupportedError(f"{compression_config.mode.value} is not supported.") weight_node = get_const_node(wc_params.node_with_weight, wc_params.weight_port_id, graph) weight_name = weight_node.layer_attributes.name diff --git a/nncf/quantization/algorithms/weight_compression/torch_fx_backend.py b/nncf/quantization/algorithms/weight_compression/torch_fx_backend.py index ca3e2d16331..fa6090fccf7 100644 --- a/nncf/quantization/algorithms/weight_compression/torch_fx_backend.py +++ b/nncf/quantization/algorithms/weight_compression/torch_fx_backend.py @@ -172,7 +172,7 @@ def transform_model( CompressWeightsMode.INT8_SYM, CompressWeightsMode.INT8, ]: - raise ValueError(f"{compression_config.mode.value} is not supported.") + raise nncf.ParameterNotSupportedError(f"{compression_config.mode.value} is not supported.") weight_node = get_const_node(wc_params.node_with_weight, wc_params.weight_port_id, graph) weight_name = weight_node.node_name weight = self.get_weight(wc_params.node_with_weight, wc_params.weight_port_id, model, graph) diff --git a/nncf/quantization/algorithms/weight_compression/weight_lowering.py b/nncf/quantization/algorithms/weight_compression/weight_lowering.py index ef36a157040..342725c0237 100644 --- a/nncf/quantization/algorithms/weight_compression/weight_lowering.py +++ b/nncf/quantization/algorithms/weight_compression/weight_lowering.py @@ -82,12 +82,14 @@ def reshape_weight_for_grouped_quantization( if isinstance(reduction_axes, tuple) and len(reduction_axes) == 1: reduction_axes = reduction_axes[0] if not isinstance(reduction_axes, int): - raise NotImplementedError( + raise nncf.UnsupportedModelError( f"Group-wise quantization expects a single reduction axis, but given: {reduction_axes}." ) channel_size = weight.shape[reduction_axes] if channel_size % group_size != 0: - raise nncf.ValidationError(f"Channel size {channel_size} should be divisible by size of group {group_size}") + raise nncf.UnsupportedModelError( + f"Channel size {channel_size} should be divisible by size of group {group_size}" + ) num_groups_per_channel = channel_size // group_size shape = list(weight.shape) # [a1, r, a2] - "r" refers to number of channels along reduction axis diff --git a/nncf/quantization/quantize_model.py b/nncf/quantization/quantize_model.py index 1f633458e98..ecd0b0d7a9c 100644 --- a/nncf/quantization/quantize_model.py +++ b/nncf/quantization/quantize_model.py @@ -160,7 +160,7 @@ def quantize( :rtype: TModel """ if subset_size < 1: - raise ValueError("Subset size must be positive.") + raise nncf.ValidationError("Subset size must be positive.") advanced_parameters = _update_advanced_quantization_parameters(advanced_parameters, calibration_dataset) @@ -463,27 +463,27 @@ def compress_weights( from nncf.torch.quantization.quantize_model import compress_weights_impl as pt_compression_weights_impl if mode not in [CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT8_SYM]: - raise AttributeError( + raise nncf.ParameterNotSupportedError( "Torch backend supports only INT8_ASYM, INT8_SYM modes for weight compression, " f"but given {mode.value} mode." 
) if True in [awq, scale_estimation, gptq, lora_correction]: - raise AttributeError( + raise nncf.ParameterNotSupportedError( "Torch backend does not support 'awq', 'scale_estimation', 'gptq' and 'lora_correction' options. " "Set them to None." ) if is_wrapped_model(model): if not model.nncf.trace_parameters: - raise ValueError( + raise nncf.ValidationError( "Tracing capabilities with tracing parameters are required in the PyTorch model " "for nncf.compress_weights(). Please wrap the model using " "nncf.torch.wrap_model(model, example_input, trace_parameters=True) before calling " "nncf.compress_weights()." ) elif dataset is None: - raise AttributeError("Please provide a dataset of at least one element for PyTorch model tracing.") + raise nncf.ValidationError("Please provide a dataset of at least one element for PyTorch model tracing.") else: example_input = next(iter(dataset.get_inference_data())) model = wrap_model(model, example_input=example_input, trace_parameters=True) @@ -496,18 +496,18 @@ def compress_weights( ) if mode not in [CompressWeightsMode.INT8_ASYM, CompressWeightsMode.INT8_SYM]: - raise AttributeError( + raise nncf.ParameterNotSupportedError( "TorchFX backend supports only INT8_ASYM, INT8_SYM modes for weight compression, " f"but given {mode.value} mode." ) if any((awq, scale_estimation, gptq, lora_correction)): - raise AttributeError( + raise nncf.ParameterNotSupportedError( "TorchFX backend does not support 'awq', 'scale_estimation', 'gptq'," "and 'lora_correction' options. Set them to None." ) if dataset: - raise AttributeError( + raise nncf.ParameterNotSupportedError( "TorchFX only supports data-free weights compression," "Set the 'dataset' option to None" ) compression_weights_impl = fx_compression_weights_impl @@ -518,13 +518,13 @@ def compress_weights( if any((awq, scale_estimation, gptq, lora_correction)) and ( dataset is None or mode == CompressWeightsMode.E2M1 ): - raise AttributeError( + raise nncf.ParameterNotSupportedError( "Scale estimation, AWQ, GPTQ or Lora Correction algorithm is defined, " "but dataset is None or mode is E2M1." ) if gptq and lora_correction: - raise AttributeError( + raise nncf.ValidationError( "Simultaneous use of Lora correction and GPTQ algorithms is not supported. Select one of them." ) @@ -536,7 +536,7 @@ def compress_weights( if group_size is None: group_size = -1 if ratio != 1 or group_size != -1: - raise AttributeError( + raise nncf.ParameterNotSupportedError( "INT8 mode assumes per-channel quantization of all layers in 8 bit. " "Default values of `ratio` (1) and `group_size` (-1) parameters can not be overridden" ) @@ -551,7 +551,7 @@ def compress_weights( } unsupported_for_int8 = [name for name, value in options.items() if value is not None] if unsupported_for_int8: - raise AttributeError( + raise nncf.ParameterNotSupportedError( f"INT8 modes do not support {', '.join(unsupported_for_int8)} option(s). Set them to None." ) @@ -578,14 +578,14 @@ def compress_weights( else SensitivityMetric.MAX_ACTIVATION_VARIANCE ) if ratio != 1 and dataset is None and sensitivity_metric != SensitivityMetric.WEIGHT_QUANTIZATION_ERROR: - raise AttributeError( + raise nncf.ValidationError( f"Mixed precision selection based on the given sensitivity metric={sensitivity_metric.value} requires " "a dataset, but it's not provided." 
) if ratio < 0 or ratio > 1: - raise ValueError(f"The ratio should be between 0 and 1, but ratio={ratio} is specified.") + raise nncf.ValidationError(f"The ratio should be between 0 and 1, but ratio={ratio} is specified.") if subset_size is None or subset_size <= 0: - raise ValueError(f"The subset_size value should be positive, but subset_size={subset_size} is given.") + raise nncf.ValidationError(f"The subset_size value should be positive, but subset_size={subset_size} is given.") if compression_weights_impl is None: raise nncf.UnsupportedBackendError(f"Unsupported type of backend: {backend}") diff --git a/tests/openvino/native/quantization/test_quantize_api.py b/tests/openvino/native/quantization/test_quantize_api.py index 4eeca520af8..b272553c985 100644 --- a/tests/openvino/native/quantization/test_quantize_api.py +++ b/tests/openvino/native/quantization/test_quantize_api.py @@ -32,6 +32,6 @@ def get_mock_model() -> Model: def test_non_positive_subset_size(): model_to_test = get_mock_model() - with pytest.raises(ValueError) as e: + with pytest.raises(nncf.ValidationError) as e: nncf.quantize(model_to_test, Dataset(MockDataset(INPUT_SHAPE)), subset_size=0) assert "Subset size must be positive." in e.info diff --git a/tests/openvino/native/quantization/test_weights_compression.py b/tests/openvino/native/quantization/test_weights_compression.py index 4ec30282bbd..5f610e438bf 100644 --- a/tests/openvino/native/quantization/test_weights_compression.py +++ b/tests/openvino/native/quantization/test_weights_compression.py @@ -20,15 +20,16 @@ from attr import dataclass from openvino.runtime import opset13 as opset +import nncf from nncf import CompressWeightsMode from nncf import SensitivityMetric from nncf.common.utils.debug import nncf_debug from nncf.data.dataset import Dataset -from nncf.errors import ValidationError from nncf.experimental.common.tensor_statistics.collectors import AggregatorBase from nncf.openvino.graph.node_utils import get_const_value from nncf.quantization import compress_weights from nncf.quantization.advanced_parameters import AdvancedCompressionParameters as CompressionParams +from nncf.quantization.advanced_parameters import AdvancedGPTQParameters as GPTQParams from nncf.quantization.advanced_parameters import AdvancedLoraCorrectionParameters as LoraParams from nncf.quantization.algorithms.weight_compression.config import WeightCompressionConfig from nncf.quantization.algorithms.weight_compression.config import WeightCompressionParameters @@ -77,23 +78,28 @@ class LMLinearModel(OVReferenceModel): - HIDDEN_DIM = 16 OUTPUT_DIM = 32 - INPUT_SHAPE = [24, HIDDEN_DIM] # [SeqLen, HiddenDim] + INPUT_SHAPE = [24, 16] # [SeqLen, HiddenDim] - def _create_ov_model(self, transpose_b: bool = True): - input_1 = opset.parameter(self.INPUT_SHAPE, name="Input") + def _create_ov_model(self, transpose_b: bool = True, transpose_a=False, input_shape=None): + self._input_shape = self.INPUT_SHAPE if input_shape is None else input_shape + hdim_axis = 0 if transpose_a else -1 + self._hidden_dim = self._input_shape[hdim_axis] + input_1 = opset.parameter(self._input_shape, name="Input") weight_shape = self.get_weight_shape(transpose_b) data = self._rng.random(weight_shape).astype(np.float32) - matmul = opset.matmul(input_1, data, transpose_a=False, transpose_b=transpose_b, name="MatMul") + matmul = opset.matmul(input_1, data, transpose_a=transpose_a, transpose_b=transpose_b, name="MatMul") result = opset.result(matmul, name="Result") result.get_output_tensor(0).set_names(set(["Result"])) model 
= ov.Model([result], [input_1]) return model - @classmethod - def get_weight_shape(cls, transpose_b: bool = True): - return [cls.OUTPUT_DIM, cls.HIDDEN_DIM] if transpose_b else [cls.HIDDEN_DIM, cls.OUTPUT_DIM] + @property + def hidden_dim(self): + return self._hidden_dim + + def get_weight_shape(self, transpose_b: bool = True): + return [self.OUTPUT_DIM, self.hidden_dim] if transpose_b else [self.hidden_dim, self.OUTPUT_DIM] def get_next_node(node): @@ -681,12 +687,12 @@ def test_calculate_scale_per_group(desc: CalculateScaleDesc): def test_raise_error_for_many_axes(): - with pytest.raises(RuntimeError): + with pytest.raises(nncf.UnsupportedModelError): reshape_weight_for_grouped_quantization(WEIGHTS_2x4, reduction_axes=(0, 1), group_size=1) def test_raise_error_channel_size_is_not_divisible_by_group_size(): - with pytest.raises(ValidationError): + with pytest.raises(nncf.UnsupportedModelError): reshape_weight_for_grouped_quantization(WEIGHTS_2x4, reduction_axes=(0,), group_size=3) @@ -707,7 +713,7 @@ def test_raise_error_channel_size_is_not_divisible_by_group_size(): ), ) def test_raise_error_with_unsupported_params_for_int8(mode, params): - with pytest.raises(AttributeError): + with pytest.raises(nncf.ParameterNotSupportedError): compress_weights(ov.Model([], []), mode=mode, **params) @@ -717,15 +723,44 @@ def test_raise_error_with_unsupported_params_for_int8(mode, params): ({"dataset": "anything", "lora_correction": True, "gptq": True},), ) def test_raise_error_with_unsupported_params_for_int4(mode, params): - with pytest.raises(AttributeError): + with pytest.raises(nncf.ValidationError): compress_weights(ov.Model([], []), mode=mode, **params) +@pytest.mark.parametrize( + "algo", + ( + "lora_correction", + "awq", + "scale_estimation", + "gptq", + ), +) +def test_raise_error_with_unsupported_params_for_e2m1(algo): + with pytest.raises(nncf.ParameterNotSupportedError): + compress_weights(ov.Model([], []), dataset="anything", mode=CompressWeightsMode.E2M1, **{algo: True}) + + +@pytest.mark.parametrize("mode", INT4_NF4_MODES) +@pytest.mark.parametrize( + "algo", + ( + "lora_correction", + "awq", + "scale_estimation", + "gptq", + ), +) +def test_raise_error_with_unsupported_params_for_empty_dataset(mode, algo): + with pytest.raises(nncf.ParameterNotSupportedError): + compress_weights(ov.Model([], []), dataset=None, mode=mode, **{algo: True}) + + @pytest.mark.parametrize("mode", INT4_NF4_MODES) @pytest.mark.parametrize("metric", DATA_BASED_SENSITIVITY_METRICS) def test_raise_error_with_data_metric_and_without_dataset(mode, metric): model = IntegerModel().ov_model - with pytest.raises(AttributeError): + with pytest.raises(nncf.ValidationError): compress_weights(model, mode=mode, sensitivity_metric=metric, group_size=-1, ratio=0.8) @@ -862,7 +897,7 @@ def test_default_subset_value(): def test_invalid_subset_size(subset_size): model = IdentityMatmul().ov_model dataset = Dataset([ACTIVATION]) - with pytest.raises(ValueError): + with pytest.raises(nncf.ValidationError): compress_weights(model, mode=CompressWeightsMode.INT4_ASYM, ratio=0.5, dataset=dataset, subset_size=subset_size) @@ -1059,11 +1094,12 @@ def get_shape_for_second_input(op_with_weights: ov.Node) -> List[int]: ) def test_lora_adapters_in_the_graph(params, transpose_b): advanced_parameters = CompressionParams() if params is None else CompressionParams(lora_correction_params=params) - model = LMLinearModel(transpose_b=transpose_b).ov_model - dataset = Dataset(np.ones(inp.shape) for inp in model.inputs) + model = 
LMLinearModel(transpose_b=transpose_b) + ov_model = model.ov_model + dataset = Dataset(np.ones(inp.shape) for inp in ov_model.inputs) compressed_model = compress_weights( - model, + ov_model, mode=CompressWeightsMode.INT4_SYM, ratio=1.0, group_size=8, @@ -1079,8 +1115,8 @@ def test_lora_adapters_in_the_graph(params, transpose_b): next_node = target_input.get_node() assert next_node.type_info.name == "MatMul" shape = get_shape_for_second_input(next_node) - if shape != LMLinearModel.get_weight_shape(transpose_b): - assert shape == [advanced_parameters.lora_correction_params.adapter_rank, LMLinearModel.HIDDEN_DIM] + if shape != model.get_weight_shape(transpose_b): + assert shape == [advanced_parameters.lora_correction_params.adapter_rank, model.hidden_dim] node = get_next_node(next_node) assert node.type_info.name == "MatMul" assert get_shape_for_second_input(node) == [ @@ -1107,9 +1143,9 @@ def test_lora_adapters_in_the_graph(params, transpose_b): def test_lora_adapters_reduce_noise(zero_seed, mode, apply_regularization, is_per_channel, mocker, tmp_path): mocker.patch("nncf.quantization.algorithms.weight_compression.lora_correction.DEBUG_LOG_DIR", str(tmp_path)) - model_cls = LMLinearModel - group_size = -1 if is_per_channel else model_cls.HIDDEN_DIM // 2 - model = model_cls().ov_model + model = LMLinearModel() + group_size = -1 if is_per_channel else model.hidden_dim // 2 + model = model.ov_model n_iters = 1 ie = ov.Core() input_data = [np.ones(inp.shape) for inp in model.inputs] @@ -1126,7 +1162,7 @@ def test_lora_adapters_reduce_noise(zero_seed, mode, apply_regularization, is_pe int4_out = next(iter(int4_out.values())) noise_before = np.mean(np.abs(fp32_out - int4_out)) - model = model_cls().ov_model + model = LMLinearModel().ov_model with nncf_debug(): int4_model = compress_weights( @@ -1204,8 +1240,9 @@ def test_compression_with_lora_for_different_dtypes(activation_dtype, weight_dty def test_compression_with_lora_with_subset_size(mocker): subset_size = 2 dataset_size = 4 - model = LMLinearModel().ov_model - input_data = [np.ones(inp.shape) for inp in model.inputs] * dataset_size + model = LMLinearModel() + ov_model = model.ov_model + input_data = [np.ones(inp.shape) for inp in ov_model.inputs] * dataset_size dataset = Dataset(input_data) from nncf.quantization.algorithms.weight_compression import lora_correction @@ -1213,7 +1250,7 @@ def test_compression_with_lora_with_subset_size(mocker): get_stats_spy = mocker.spy(lora_correction, "process_stats") compress_weights( - model, + ov_model, mode=CompressWeightsMode.INT4_SYM, ratio=1.0, group_size=8, @@ -1229,8 +1266,8 @@ def test_compression_with_lora_with_subset_size(mocker): get_stats_spy.assert_called_once() s, X = get_stats_spy.spy_return - assert X.shape == (LMLinearModel.HIDDEN_DIM, subset_size) - assert s.shape == (LMLinearModel.HIDDEN_DIM,) + assert X.shape == (model.hidden_dim, subset_size) + assert s.shape == (model.hidden_dim,) def test_lora_with_mixed_precision(): @@ -1245,3 +1282,64 @@ def test_lora_with_mixed_precision(): op_name = op.get_friendly_name() if op.get_type_name() == "Constant" and ("/zero_point" in op_name or "/scale" in op_name): assert op.get_shape() == [sz, 1] + + +@pytest.mark.parametrize( + "kwargs", + [ + dict(scale_estimation=True), + dict(lora_correction=True), + dict( + gptq=True, + scale_estimation=True, + advanced_parameters=CompressionParams(gptq_params=GPTQParams(subset_size=2)), + ), + ], +) +def test_compression_with_3D_activations(kwargs): + dataset_size = 4 + model = LMLinearModel(input_shape=[3, 
5, 16]).ov_model + input_data = [np.ones(inp.shape) for inp in model.inputs] * dataset_size + dataset = Dataset(input_data) + + compress_weights( + model, + mode=CompressWeightsMode.INT4_SYM, + ratio=1.0, + group_size=8, + subset_size=2, + dataset=dataset, + all_layers=True, + **kwargs, + ) + + +@pytest.mark.parametrize( + "kwargs", + [ + dict(scale_estimation=True), + dict(lora_correction=True), + dict( + gptq=True, + scale_estimation=True, + advanced_parameters=CompressionParams(gptq_params=GPTQParams(subset_size=2)), + ), + ], +) +def test_compression_with_transposed_activations(kwargs): + dataset_size = 4 + model = LMLinearModel(input_shape=[24, 16], transpose_a=True, transpose_b=False).ov_model + input_data = [np.ones(inp.shape) for inp in model.inputs] * dataset_size + dataset = Dataset(input_data) + + with pytest.raises(nncf.UnsupportedModelError): + compress_weights( + model, + mode=CompressWeightsMode.INT4_SYM, + ratio=1.0, + group_size=8, + subset_size=2, + dataset=dataset, + all_layers=True, + **kwargs, + ) diff --git a/tests/torch/fx/test_compress_weights.py b/tests/torch/fx/test_compress_weights.py index 1d5012d5d57..1c4e5dd73dc 100644 --- a/tests/torch/fx/test_compress_weights.py +++ b/tests/torch/fx/test_compress_weights.py @@ -15,6 +15,7 @@ import torch from torch._export import capture_pre_autograd_graph +import nncf from nncf import CompressWeightsMode from nncf.common.factory import NNCFGraphFactory from nncf.data.dataset import Dataset @@ -214,7 +215,7 @@ def test_raise_error_with_unsupported_params_for_int8(mode, params): dummy_torch_model = EmptyModel() dummy_input = torch.Tensor() exported_model = _capture_model(dummy_torch_model, dummy_input) - with pytest.raises(AttributeError): + with pytest.raises(nncf.ParameterNotSupportedError): compress_weights(exported_model, mode=mode, **params) @@ -223,7 +224,7 @@ def test_raise_error_with_not_int8(mode): dummy_torch_model = EmptyModel() dummy_input = torch.Tensor() exported_model = _capture_model(dummy_torch_model, dummy_input) - with pytest.raises(AttributeError): + with pytest.raises(nncf.ParameterNotSupportedError): compress_weights(exported_model, mode=mode) diff --git a/tests/torch/ptq/test_weights_compression.py b/tests/torch/ptq/test_weights_compression.py index 5e4ca75e128..cdeb6fbdc99 100644 --- a/tests/torch/ptq/test_weights_compression.py +++ b/tests/torch/ptq/test_weights_compression.py @@ -13,6 +13,7 @@ import torch import torch.nn.functional as F +import nncf from nncf import CompressWeightsMode from nncf import SensitivityMetric from nncf.quantization import compress_weights @@ -220,7 +221,7 @@ def test_raise_error_with_unsupported_params_for_int8(mode, params): dummy_torch_model = EmptyModel() dummy_input = torch.Tensor() wrapped_model = wrap_model(dummy_torch_model, example_input=dummy_input, trace_parameters=True) - with pytest.raises(AttributeError): + with pytest.raises(nncf.ParameterNotSupportedError): compress_weights(wrapped_model, mode=mode, **params) @@ -229,7 +230,7 @@ def test_raise_error_with_not_int8(mode): dummy_torch_model = EmptyModel() dummy_input = torch.Tensor() wrapped_model = wrap_model(dummy_torch_model, example_input=dummy_input, trace_parameters=True) - with pytest.raises(AttributeError): + with pytest.raises(nncf.ParameterNotSupportedError): compress_weights(wrapped_model, mode=mode)
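
The process_stats() change at the top of the patch is the core of the 3D support: instead of always averaging over axis 0, the mean is now taken over every axis except the last one, so a [SeqLen, HiddenDim] sample and a [Batch, SeqLen, HiddenDim] sample both collapse to a [HiddenDim] vector. A minimal NumPy sketch of that logic (nncf's fns.mean / fns.stack are replaced with their NumPy counterparts purely for illustration):

import numpy as np

def mean_over_non_hidden_axes(stats):
    # The hidden dimension is assumed to be the last axis of every sample,
    # so reduce over all preceding axes regardless of the sample's rank.
    reduction_axis = tuple(range(stats[0].ndim))[:-1]
    return np.stack([s.mean(axis=reduction_axis) for s in stats])  # [SampleSize, HiddenDim]

samples_2d = [np.random.rand(24, 16) for _ in range(4)]   # [SeqLen, HiddenDim]
samples_3d = [np.random.rand(3, 5, 16) for _ in range(4)]  # [Batch, SeqLen, HiddenDim]
assert mean_over_non_hidden_axes(samples_2d).shape == (4, 16)
assert mean_over_non_hidden_axes(samples_3d).shape == (4, 16)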
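
The GPTQ hunk swaps the hard-coded column slice inp.squeeze()[:, i1 + i : i1 + i + group_size] for an Ellipsis-based slice, which picks the same group of hidden-dimension columns whether the squeezed activation is 2D or 3D. A small NumPy illustration (the group offsets below are made up for the example):

import numpy as np

group_size, i1, i = 8, 0, 8  # hypothetical offsets into the hidden dimension
act_2d = np.random.rand(24, 16)    # [SeqLen, HiddenDim]
act_3d = np.random.rand(3, 5, 16)  # [Batch, SeqLen, HiddenDim]

for act in (act_2d, act_3d):
    # "..." keeps every leading axis intact and slices only the last (hidden) axis;
    # a "[:, start:stop]" slice would instead cut the SeqLen axis of the 3D tensor.
    group = act[..., (i1 + i) : (i1 + i + group_size)]
    assert group.shape[-1] == group_size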
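
In weight_lowering.py, reshape_weight_for_grouped_quantization() now raises nncf.UnsupportedModelError (instead of NotImplementedError / nncf.ValidationError) both when more than one reduction axis is given and when the channel size is not divisible by the group size. The reshape itself is unchanged; below is a rough NumPy paraphrase of the group-wise split, not the exact nncf implementation:

import numpy as np

def reshape_for_groups(weight, reduction_axis, group_size):
    channel_size = weight.shape[reduction_axis]
    if channel_size % group_size != 0:
        # the patched nncf code raises nncf.UnsupportedModelError here
        raise ValueError(f"Channel size {channel_size} should be divisible by size of group {group_size}")
    # split the reduction axis into [num_groups_per_channel, group_size]
    shape = list(weight.shape)
    shape[reduction_axis : reduction_axis + 1] = (channel_size // group_size, group_size)
    return weight.reshape(shape)

w = np.random.rand(32, 16)  # [OutDim, HiddenDim]
assert reshape_for_groups(w, reduction_axis=1, group_size=8).shape == (32, 2, 8)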
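
Much of the rest of the diff replaces generic RuntimeError / ValueError / AttributeError with nncf's own exception types (nncf.ValidationError, nncf.ParameterNotSupportedError, nncf.UnsupportedModelError, nncf.UnsupportedBackendError, nncf.InternalError), so callers can tell invalid arguments apart from unsupported models or backends. For instance, a caller could catch the new exception type and fall back to data-free INT8 compression; the fallback policy below is purely illustrative and is not something the patch itself implements:

import nncf

def compress_with_fallback(model, dataset):
    try:
        return nncf.compress_weights(
            model,
            mode=nncf.CompressWeightsMode.INT4_SYM,
            dataset=dataset,
            scale_estimation=True,
        )
    except nncf.ParameterNotSupportedError:
        # e.g. the Torch and TorchFX backends reject data-aware options such as scale_estimation
        return nncf.compress_weights(model, mode=nncf.CompressWeightsMode.INT8_ASYM)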
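
Finally, test_compression_with_3D_activations shows the user-facing behaviour this patch enables: data-aware compression (scale estimation, LoRA correction, GPTQ) of a model whose MatMul activations are 3D. A condensed stand-alone version of that test, with the shapes and hyper-parameters copied from it (the single-MatMul model below mirrors LMLinearModel(input_shape=[3, 5, 16])):

import numpy as np
import openvino as ov
from openvino.runtime import opset13 as opset

import nncf

input_1 = opset.parameter([3, 5, 16], name="Input")  # [Batch, SeqLen, HiddenDim]
weight = np.random.rand(16, 32).astype(np.float32)   # [HiddenDim, OutputDim]
matmul = opset.matmul(input_1, weight, transpose_a=False, transpose_b=False, name="MatMul")
model = ov.Model([opset.result(matmul, name="Result")], [input_1])

dataset = nncf.Dataset([np.ones(inp.shape) for inp in model.inputs] * 4)
compressed_model = nncf.compress_weights(
    model,
    mode=nncf.CompressWeightsMode.INT4_SYM,
    ratio=1.0,
    group_size=8,
    subset_size=2,
    dataset=dataset,
    all_layers=True,
    scale_estimation=True,
)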