diff --git a/nncf/experimental/torch/fx/commands.py b/nncf/experimental/torch/fx/commands.py
index 831f177cac7..064e08a223a 100644
--- a/nncf/experimental/torch/fx/commands.py
+++ b/nncf/experimental/torch/fx/commands.py
@@ -33,5 +33,5 @@ def __init__(
         :param priority: Transformation priority.
         """
         super().__init__(TransformationType.INSERT)
-        self.tranformation_fn = transformation_fn
+        self.transformation_fn = transformation_fn
         self.priority = priority
diff --git a/nncf/experimental/torch/fx/model_transformer.py b/nncf/experimental/torch/fx/model_transformer.py
index 4be8f306051..caf90dfac6d 100644
--- a/nncf/experimental/torch/fx/model_transformer.py
+++ b/nncf/experimental/torch/fx/model_transformer.py
@@ -18,6 +18,7 @@
 from nncf.common.graph.model_transformer import ModelTransformer
 from nncf.experimental.torch.fx.commands import FXApplyTransformationCommand
+from nncf.experimental.torch.fx.node_utils import get_graph_node_by_name
 from nncf.torch.graph.transformations.commands import PTModelExtractionCommand
 from nncf.torch.graph.transformations.layout import PTTransformationLayout
@@ -97,7 +98,23 @@ def _apply_model_extraction(
         # TODO(dlyakhov): reduce memory consumption by
         # more optimal splitting implementation.
         splitted_gm = split_by_tags(model, tags)
-        return splitted_gm.extracted
+
+        extracted_model = splitted_gm.extracted
+        graph: torch.fx.Graph = extracted_model.graph
+        # Check that the extracted model has an input node:
+        # the target layer may have only constant inputs,
+        # in which case a placeholder is created
+        # and attached to the input port.
+        target_node = get_graph_node_by_name(graph, node_name)
+        input_node = target_node.all_input_nodes[0]
+        if input_node.op != "placeholder":
+            with graph.inserting_before(target_node):
+                new_input_node = graph.create_node(
+                    "placeholder", "placeholder_node", (), {}, name="placeholder_graph_node"
+                )
+            target_node.replace_input_with(input_node, new_input_node)
+        extracted_model.graph.eliminate_dead_code()
+        return extracted_model
 
     @staticmethod
     def _apply_transformation(
@@ -112,5 +129,5 @@ def _apply_transformation(
         :return: Target model after all transformations were applied.
         """
         for transformation in transformations:
-            transformation.tranformation_fn(model)
+            transformation.transformation_fn(model)
         return model
diff --git a/nncf/experimental/torch/fx/transformations.py b/nncf/experimental/torch/fx/transformations.py
index a6248a2a76c..6e10e8f25f4 100644
--- a/nncf/experimental/torch/fx/transformations.py
+++ b/nncf/experimental/torch/fx/transformations.py
@@ -79,7 +79,7 @@ def bias_update_transformation(model: torch.fx.GraphModule):
     return bias_update_transformation
 
 
-def qdq_insertion_tranformation_builder(
+def qdq_insertion_transformation_builder(
     quantizer: FakeQuantize, target_points: List[PTTargetPoint]
 ) -> TransformationFNType:
     """
@@ -92,7 +92,7 @@
     inherited from the given quantizer to each given target point.
     """
 
-    def qdq_insertion_tranformation(model: torch.fx.GraphModule):
+    def qdq_insertion_transformation(model: torch.fx.GraphModule):
         if any(tp.target_type != TargetType.OPERATION_WITH_WEIGHTS for tp in target_points) and len(target_points) > 1:
             raise RuntimeError(
                 "Insertion of shared qdq pair for the weights is not supported."
@@ -101,7 +101,7 @@ def qdq_insertion_tranformation(model: torch.fx.GraphModule):
         for target_point in target_points:
             insert_one_qdq(model, target_point, quantizer)
 
-    return qdq_insertion_tranformation
+    return qdq_insertion_transformation
 
 
 def insert_one_qdq(model: torch.fx.GraphModule, target_point: PTTargetPoint, quantizer: FakeQuantize):
@@ -310,7 +310,7 @@ def _is_bn_node(node: torch.fx.Node):
 
 def fuse_conv_bn(model: torch.fx.GraphModule) -> None:
     """
-    BatchNorm operations have 3 output ports, to make it easier for alorithms to work with
+    BatchNorm operations have 3 output ports, to make it easier for algorithms to work with
     the target graph BatchNorm operations are being fused
 
     :param model: Model to apply transformations to.
@@ -342,7 +342,7 @@ def apply_quantization_transformations(model: torch.fx.GraphModule) -> None:
     :param model: Model to apply transformations to.
     """
     # BatchNorm operations have 3 output ports,
-    # to make it easier for alorithms to work
+    # to make it easier for algorithms to work
     # with the target graph BatchNorm operations
     # are being fused
     fuse_conv_bn(model)
@@ -484,7 +484,7 @@ def _merge_node_and_bias(model: torch.fx.GraphModule, is_target_node: Callable[[
     Check which node should be merged by the given `is_target_node` predicate.
 
     :param model: Target model.
-    :param is_target_node: Predicate to specify nodes which shoudld be merged with the bias
+    :param is_target_node: Predicate to specify nodes which should be merged with the bias
     """
     add_node_targets = (torch.ops.aten.add_.Tensor,)
     for n in model.graph.nodes:
diff --git a/nncf/quantization/algorithms/min_max/torch_fx_backend.py b/nncf/quantization/algorithms/min_max/torch_fx_backend.py
index c095836e674..f18c7fc385b 100644
--- a/nncf/quantization/algorithms/min_max/torch_fx_backend.py
+++ b/nncf/quantization/algorithms/min_max/torch_fx_backend.py
@@ -28,7 +28,7 @@
 from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
 from nncf.experimental.common.tensor_statistics.statistics import MinMaxTensorStatistic
 from nncf.experimental.torch.fx.commands import FXApplyTransformationCommand
-from nncf.experimental.torch.fx.transformations import qdq_insertion_tranformation_builder
+from nncf.experimental.torch.fx.transformations import qdq_insertion_transformation_builder
 from nncf.parameters import ModelType
 from nncf.parameters import TargetDevice
 from nncf.quantization.advanced_parameters import StatisticsType
@@ -288,7 +288,7 @@ def create_quantizer_insertion_command(
         quantizer = FXMinMaxAlgoBackend._create_quantizer(
             quantizer_config, scale_shape, parameters, target_point.target_type
         )
-        transformation = qdq_insertion_tranformation_builder(quantizer, [target_point])
+        transformation = qdq_insertion_transformation_builder(quantizer, [target_point])
         return FXApplyTransformationCommand(transformation)
 
     @staticmethod
@@ -308,7 +308,7 @@ def create_unified_scales_quantizers_insertion_commands(
 
         transformations = []
         for tp in target_points:
-            transformation = qdq_insertion_tranformation_builder(quantizer, [tp])
+            transformation = qdq_insertion_transformation_builder(quantizer, [tp])
             transformations.append(FXApplyTransformationCommand(transformation))
         return transformations
 
diff --git a/nncf/torch/graph/transformations/serialization.py b/nncf/torch/graph/transformations/serialization.py
index 282c59453eb..abb92379bfb 100644
--- a/nncf/torch/graph/transformations/serialization.py
+++ b/nncf/torch/graph/transformations/serialization.py
@@ -29,7 +29,7 @@ def serialize_transformations(transformations_layout: TransformationLayout) -> D
     """
     Serializes given transformation layout to a dict.
 
-    :param tranformation_layout: Given transformation layout.
+    :param transformation_layout: Given transformation layout.
     :return: Serialized representation of given transformation layout as a dict.
     """
     transformation_commands = []
diff --git a/tests/post_training/data/ptq_reference_data.yaml b/tests/post_training/data/ptq_reference_data.yaml
index 490cb7e73da..17321b06036 100644
--- a/tests/post_training/data/ptq_reference_data.yaml
+++ b/tests/post_training/data/ptq_reference_data.yaml
@@ -34,6 +34,12 @@ torchvision/resnet18_backend_CUDA_TORCH:
   metric_value: 0.69152
 torchvision/resnet18_backend_FX_TORCH:
   metric_value: 0.6946
+torchvision/swin_v2_s_backend_FP32:
+  metric_value: 0.83712
+torchvision/swin_v2_s_backend_OV:
+  metric_value: 0.83638
+torchvision/swin_v2_s_backend_FX_TORCH:
+  metric_value: 0.82908
 timm/crossvit_9_240_backend_CUDA_TORCH:
   metric_value: 0.689
 timm/crossvit_9_240_backend_FP32:
diff --git a/tests/post_training/model_scope.py b/tests/post_training/model_scope.py
index 7f78d70528c..1bec8d1d4ae 100644
--- a/tests/post_training/model_scope.py
+++ b/tests/post_training/model_scope.py
@@ -75,6 +75,17 @@
         "backends": [BackendType.FX_TORCH, BackendType.TORCH, BackendType.CUDA_TORCH, BackendType.OV, BackendType.ONNX],
         "batch_size": 128,
     },
+    {
+        "reported_name": "torchvision/swin_v2_s",
+        "model_id": "swin_v2_s",
+        "pipeline_cls": ImageClassificationTorchvision,
+        "compression_params": {
+            "model_type": ModelType.TRANSFORMER,
+            "advanced_parameters": AdvancedQuantizationParameters(smooth_quant_alpha=0.5),
+        },
+        "backends": [BackendType.FX_TORCH, BackendType.OV],
+        "batch_size": 1,
+    },
     # Timm models
     {
         "reported_name": "timm/crossvit_9_240",
diff --git a/tests/post_training/pipelines/image_classification_torchvision.py b/tests/post_training/pipelines/image_classification_torchvision.py
index 91e586605cb..c42aa9ab1bb 100644
--- a/tests/post_training/pipelines/image_classification_torchvision.py
+++ b/tests/post_training/pipelines/image_classification_torchvision.py
@@ -43,13 +43,14 @@ def prepare_model(self) -> None:
         model = model_cls(weights=self.model_weights)
         model.eval()
 
-        self.static_input_size = [self.batch_size, 3, 224, 224]
+        default_input_size = [self.batch_size, 3, 224, 224]
+        self.dummy_tensor = self.model_weights.transforms()(torch.rand(default_input_size))
+        self.static_input_size = list(self.dummy_tensor.shape)
+
         self.input_size = self.static_input_size.copy()
         if self.batch_size > 1:  # Dynamic batch_size shape export
             self.input_size[0] = -1
 
-        self.dummy_tensor = torch.rand(self.static_input_size)
-
         if self.backend == BackendType.FX_TORCH:
             with torch.no_grad():
                 with disable_patching():
diff --git a/tests/torch/data/reference_graphs/fx/extracted/ConstantModelExtractionModel.dot b/tests/torch/data/reference_graphs/fx/extracted/ConstantModelExtractionModel.dot
new file mode 100644
index 00000000000..0cfbbf5840b
--- /dev/null
+++ b/tests/torch/data/reference_graphs/fx/extracted/ConstantModelExtractionModel.dot
@@ -0,0 +1,11 @@
+strict digraph {
+"0 _conv_w" [id=0, type=get_attr];
+"1 add" [id=1, type=add];
+"2 placeholder_graph_node" [id=2, type=input];
+"3 conv2d" [id=3, type=conv2d];
+"4 output" [id=4, type=output];
+"0 _conv_w" -> "1 add";
+"1 add" -> "3 conv2d";
+"2 placeholder_graph_node" -> "3 conv2d";
+"3 conv2d" -> "4 output";
+}
diff --git a/tests/torch/data/reference_graphs/fx/extracted/ModelExtractionModel.dot b/tests/torch/data/reference_graphs/fx/extracted/ModelExtractionModel.dot
new file mode 100644
index 00000000000..47edd99d8c9
--- /dev/null
+++ b/tests/torch/data/reference_graphs/fx/extracted/ModelExtractionModel.dot
@@ -0,0 +1,11 @@
+strict digraph {
+"0 _conv_w" [id=0, type=get_attr];
+"1 add" [id=1, type=add];
+"2 arg0_1" [id=2, type=input];
+"3 conv2d" [id=3, type=conv2d];
+"4 output" [id=4, type=output];
+"0 _conv_w" -> "1 add";
+"1 add" -> "3 conv2d";
+"2 arg0_1" -> "3 conv2d";
+"3 conv2d" -> "4 output";
+}
diff --git a/tests/torch/fx/test_model_transformer.py b/tests/torch/fx/test_model_transformer.py
new file mode 100644
index 00000000000..7edb8fa482d
--- /dev/null
+++ b/tests/torch/fx/test_model_transformer.py
@@ -0,0 +1,65 @@
+# Copyright (c) 2024 Intel Corporation
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#      http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Tuple
+
+import pytest
+import torch
+from torch._export import capture_pre_autograd_graph
+
+from nncf.common.graph.transformations.layout import TransformationLayout
+from nncf.experimental.torch.fx.model_transformer import FXModelTransformer
+from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter
+from nncf.torch import disable_patching
+from nncf.torch.graph.transformations.commands import PTModelExtractionCommand
+from tests.torch.test_compressed_graph import check_graph
+from tests.torch.test_models.synthetic import ConvolutionWithAllConstantInputsModel
+from tests.torch.test_models.synthetic import ConvolutionWithNotTensorBiasModel
+
+
+@dataclass
+class ModelExtractionTestCase:
+    model: torch.nn.Module
+    input_shape: Tuple[int, ...]
+    command: PTModelExtractionCommand
+    ref: None = None
+
+
+EXTRACTED_GRAPHS_DIR_NAME = Path("fx") / "extracted"
+
+MODEL_EXTRACTION_CASES = (
+    ModelExtractionTestCase(
+        ConvolutionWithNotTensorBiasModel, (1, 1, 3, 3), PTModelExtractionCommand(["conv2d"], ["conv2d"])
+    ),
+    ModelExtractionTestCase(
+        ConvolutionWithAllConstantInputsModel, (1, 1, 3, 3), PTModelExtractionCommand(["conv2d"], ["conv2d"])
+    ),
+)
+
+
+def idfn(value: Any):
+    if isinstance(value, ModelExtractionTestCase):
+        return value.model.__name__
+    return None
+
+
+@pytest.mark.parametrize("test_case", MODEL_EXTRACTION_CASES, ids=idfn)
+def test_model_extraction(test_case: ModelExtractionTestCase):
+    with torch.no_grad():
+        with disable_patching():
+            captured_model = capture_pre_autograd_graph(test_case.model(), (torch.ones(test_case.input_shape),))
+    layout = TransformationLayout()
+    layout.register(test_case.command)
+    extracted_model = FXModelTransformer(captured_model).transform(layout)
+    nncf_graph = GraphConverter.create_nncf_graph(extracted_model)
+    check_graph(nncf_graph, f"{test_case.model.__name__}.dot", str(EXTRACTED_GRAPHS_DIR_NAME))
diff --git a/tests/torch/fx/test_models.py b/tests/torch/fx/test_models.py
index 8e3b991329c..1b7b6bf0208 100644
--- a/tests/torch/fx/test_models.py
+++ b/tests/torch/fx/test_models.py
@@ -9,7 +9,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import json
 import os
 from dataclasses import dataclass
@@ -39,8 +38,8 @@
 from tests.torch import test_models
 from tests.torch.test_compressed_graph import check_graph
 
-FX_DIR_NAME = "fx"
-FX_QUANTIZED_DIR_NAME = "fx/quantized"
+FX_DIR_NAME = Path("fx")
+FX_QUANTIZED_DIR_NAME = Path("fx") / "quantized"
 
 
 @dataclass
diff --git a/tests/torch/test_models/synthetic.py b/tests/torch/test_models/synthetic.py
index 24bddc340f0..06a54579e49 100644
--- a/tests/torch/test_models/synthetic.py
+++ b/tests/torch/test_models/synthetic.py
@@ -501,3 +501,24 @@ def forward(self, x):
         unbinded_processed[0] = self.conv4(y_unbinded[0])
         y = torch.cat(unbinded_processed, axis=0)
         return y
+
+
+class ConvolutionWithNotTensorBiasModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self._conv_w = nn.Parameter(torch.ones((1, 1, 1, 1)))
+
+    def forward(self, x):
+        w = self._conv_w + 10
+        return nn.functional.conv2d(x, w)
+
+
+class ConvolutionWithAllConstantInputsModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self._conv_w = nn.Parameter(torch.ones((1, 1, 1, 1)))
+        self._conv_i = nn.Parameter(torch.ones((1, 1, 1, 1)))
+
+    def forward(self, x):
+        w = self._conv_w + 10
+        return x + nn.functional.conv2d(self._conv_i, w)
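
Reviewer note (appended by the editor, not part of the patch): the `_apply_model_extraction` hunk is the core fix. `split_by_tags` returns a subgraph with no placeholder when every input of the target node is constant, as in `ConvolutionWithAllConstantInputsModel`, so the transformer synthesizes one; compare the `placeholder_graph_node` input in ConstantModelExtractionModel.dot with the ordinary `arg0_1` input in ModelExtractionModel.dot. Below is a minimal standalone sketch of that logic, assuming a captured `torch.fx.GraphModule`; the helper name `ensure_input_placeholder` is hypothetical and not part of this diff.

import torch.fx


def ensure_input_placeholder(extracted: torch.fx.GraphModule, node_name: str) -> None:
    # If every input of the target node is a constant (e.g. get_attr),
    # the split subgraph has no placeholder; create one on input port 0.
    graph = extracted.graph
    target_node = next(n for n in graph.nodes if n.name == node_name)
    input_node = target_node.all_input_nodes[0]
    if input_node.op != "placeholder":
        with graph.inserting_before(target_node):
            new_input = graph.create_node("placeholder", "placeholder_node", (), {})
        target_node.replace_input_with(input_node, new_input)
    # Drop the nodes that produced the replaced constant input and are now unused.
    graph.eliminate_dead_code()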
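
Reviewer note (appended by the editor, not part of the patch): the `image_classification_torchvision.py` hunk routes the dummy tensor through the weights' preprocessing transforms so that `static_input_size` reflects the shape the model is actually exported with. This matters for the newly added swin_v2_s, whose default torchvision weights preprocess to a 256x256 crop rather than the previously hard-coded 224x224; the crop size here comes from the torchvision weight metadata as I recall it, not from this diff. A quick check:

import torch
from torchvision.models import Swin_V2_S_Weights

weights = Swin_V2_S_Weights.DEFAULT
# The preprocessing pipeline resizes/crops, so the output shape can differ
# from the 224x224 input shape.
dummy = weights.transforms()(torch.rand(1, 3, 224, 224))
print(list(dummy.shape))  # expected: [1, 3, 256, 256]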