# [TorchFX][FBC] Constant linear layers support (#2866)
### Changes

Support for linear layers whose inputs are all constants: model extraction now inserts a synthetic placeholder input when the extracted target node has none.

### Reason for changes

To support the torchvision swin_v2_s model in FBC (fast bias correction).

### Related tickets

#2766 

### Tests
Build post_training_quantization/444/ finished successfully.
Unit test `test_model_transformer.test_model_extraction` is added.
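
For context, a hedged sketch (illustration only, mirroring the `ConvolutionWithAllConstantInputsModel` test model added below) of the pattern this change targets: an operation whose inputs are all constants, so the subgraph extracted around it initially has no placeholder input.

```python
import torch
import torch.nn as nn


class AllConstantInputsConv(nn.Module):
    """Toy model: conv2d consumes only constants; the activation `x` bypasses it."""

    def __init__(self):
        super().__init__()
        self._conv_w = nn.Parameter(torch.ones((1, 1, 1, 1)))
        self._conv_i = nn.Parameter(torch.ones((1, 1, 1, 1)))

    def forward(self, x):
        w = self._conv_w + 10  # weight derived from constants only
        return x + nn.functional.conv2d(self._conv_i, w)
```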
daniil-lyakhov authored Aug 13, 2024
1 parent 27296b4 commit a7d575e
Showing 13 changed files with 161 additions and 19 deletions.
2 changes: 1 addition & 1 deletion nncf/experimental/torch/fx/commands.py
@@ -33,5 +33,5 @@ def __init__(
:param priority: Transformation priority.
"""
super().__init__(TransformationType.INSERT)
-        self.tranformation_fn = transformation_fn
+        self.transformation_fn = transformation_fn
self.priority = priority
21 changes: 19 additions & 2 deletions nncf/experimental/torch/fx/model_transformer.py
@@ -18,6 +18,7 @@

from nncf.common.graph.model_transformer import ModelTransformer
from nncf.experimental.torch.fx.commands import FXApplyTransformationCommand
+from nncf.experimental.torch.fx.node_utils import get_graph_node_by_name
from nncf.torch.graph.transformations.commands import PTModelExtractionCommand
from nncf.torch.graph.transformations.layout import PTTransformationLayout

@@ -97,7 +98,23 @@ def _apply_model_extraction(
# TODO(dlyakhov): reduce memory consumption by
# more optimal splitting implementation.
splitted_gm = split_by_tags(model, tags)
-        return splitted_gm.extracted
+
+        extracted_model = splitted_gm.extracted
+        graph: torch.fx.Graph = extracted_model.graph
+        # Check that the extracted model has an input:
+        # the target layer may have only constant inputs,
+        # in which case a placeholder is inserted
+        # at the input port.
+        target_node = get_graph_node_by_name(graph, node_name)
+        input_node = target_node.all_input_nodes[0]
+        if input_node.op != "placeholder":
+            with graph.inserting_before(target_node):
+                new_input_node = graph.create_node(
+                    "placeholder", "placeholder_node", (), {}, name="placeholder_graph_node"
+                )
+            target_node.replace_input_with(input_node, new_input_node)
+        extracted_model.graph.eliminate_dead_code()
+        return extracted_model

@staticmethod
def _apply_transformation(
@@ -112,5 +129,5 @@ def _apply_transformation(
:return: Target model after all transformations were applied.
"""
for transformation in transformations:
-            transformation.tranformation_fn(model)
+            transformation.transformation_fn(model)
return model
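
As an aside, a minimal standalone sketch of the placeholder-insertion step above, using plain `torch.fx` APIs. The toy model and node lookup are assumptions for illustration; the real code operates on the subgraph returned by `split_by_tags`.

```python
import torch
import torch.fx as fx


class AllConstantConv(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._conv_w = torch.nn.Parameter(torch.ones((1, 1, 1, 1)))
        self._conv_i = torch.nn.Parameter(torch.ones((1, 1, 1, 1)))

    def forward(self, x):
        # conv2d sees only constants; the activation `x` bypasses it.
        return x + torch.nn.functional.conv2d(self._conv_i, self._conv_w)


gm = fx.symbolic_trace(AllConstantConv())
graph = gm.graph
target_node = next(
    n for n in graph.nodes if n.op == "call_function" and n.target is torch.nn.functional.conv2d
)
input_node = target_node.all_input_nodes[0]  # get_attr node for _conv_i
if input_node.op != "placeholder":
    # Feed the first input port from a fresh placeholder instead of the constant.
    with graph.inserting_before(target_node):
        new_input = graph.create_node(
            "placeholder", "placeholder_node", (), {}, name="placeholder_graph_node"
        )
    target_node.replace_input_with(input_node, new_input)
graph.eliminate_dead_code()  # drops the now-unused get_attr for _conv_i
gm.recompile()
print(gm.graph)  # conv2d now consumes `placeholder_graph_node`
```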
12 changes: 6 additions & 6 deletions nncf/experimental/torch/fx/transformations.py
@@ -79,7 +79,7 @@ def bias_update_transformation(model: torch.fx.GraphModule):
return bias_update_transformation


-def qdq_insertion_tranformation_builder(
+def qdq_insertion_transformation_builder(
quantizer: FakeQuantize, target_points: List[PTTargetPoint]
) -> TransformationFNType:
"""
Expand All @@ -92,7 +92,7 @@ def qdq_insertion_tranformation_builder(
inherited from the given quantizer to each given target point.
"""

-    def qdq_insertion_tranformation(model: torch.fx.GraphModule):
+    def qdq_insertion_transformation(model: torch.fx.GraphModule):
if any(tp.target_type != TargetType.OPERATION_WITH_WEIGHTS for tp in target_points) and len(target_points) > 1:
raise RuntimeError(
"Insertion of shared qdq pair for the weights is not supported."
@@ -101,7 +101,7 @@ def qdq_insertion_tranformation(model: torch.fx.GraphModule):
for target_point in target_points:
insert_one_qdq(model, target_point, quantizer)

-    return qdq_insertion_tranformation
+    return qdq_insertion_transformation


def insert_one_qdq(model: torch.fx.GraphModule, target_point: PTTargetPoint, quantizer: FakeQuantize):
@@ -310,7 +310,7 @@ def _is_bn_node(node: torch.fx.Node):

def fuse_conv_bn(model: torch.fx.GraphModule) -> None:
"""
-    BatchNorm operations have 3 output ports, to make it easier for alorithms to work with
+    BatchNorm operations have 3 output ports, to make it easier for algorithms to work with
the target graph BatchNorm operations are being fused
:param model: Model to apply transformations to.
@@ -342,7 +342,7 @@ def apply_quantization_transformations(model: torch.fx.GraphModule) -> None:
:param model: Model to apply transformations to.
"""
# BatchNorm operations have 3 output ports,
-    # to make it easier for alorithms to work
+    # to make it easier for algorithms to work
# with the target graph BatchNorm operations
# are being fused
fuse_conv_bn(model)
@@ -484,7 +484,7 @@ def _merge_node_and_bias(model: torch.fx.GraphModule, is_target_node: Callable[[
Check which node should be merged by the given `is_target_node` predicate.
:param model: Target model.
-    :param is_target_node: Predicate to specify nodes which shoudld be merged with the bias
+    :param is_target_node: Predicate to specify nodes which should be merged with the bias
"""
add_node_targets = (torch.ops.aten.add_.Tensor,)
for n in model.graph.nodes:
6 changes: 3 additions & 3 deletions nncf/quantization/algorithms/min_max/torch_fx_backend.py
@@ -28,7 +28,7 @@
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.experimental.common.tensor_statistics.statistics import MinMaxTensorStatistic
from nncf.experimental.torch.fx.commands import FXApplyTransformationCommand
-from nncf.experimental.torch.fx.transformations import qdq_insertion_tranformation_builder
+from nncf.experimental.torch.fx.transformations import qdq_insertion_transformation_builder
from nncf.parameters import ModelType
from nncf.parameters import TargetDevice
from nncf.quantization.advanced_parameters import StatisticsType
@@ -288,7 +288,7 @@ def create_quantizer_insertion_command(
quantizer = FXMinMaxAlgoBackend._create_quantizer(
quantizer_config, scale_shape, parameters, target_point.target_type
)
-        transformation = qdq_insertion_tranformation_builder(quantizer, [target_point])
+        transformation = qdq_insertion_transformation_builder(quantizer, [target_point])
return FXApplyTransformationCommand(transformation)

@staticmethod
Expand All @@ -308,7 +308,7 @@ def create_unified_scales_quantizers_insertion_commands(

transformations = []
for tp in target_points:
-            transformation = qdq_insertion_tranformation_builder(quantizer, [tp])
+            transformation = qdq_insertion_transformation_builder(quantizer, [tp])
transformations.append(FXApplyTransformationCommand(transformation))
return transformations

2 changes: 1 addition & 1 deletion nncf/torch/graph/transformations/serialization.py
@@ -29,7 +29,7 @@ def serialize_transformations(transformations_layout: TransformationLayout) -> D
"""
Serializes given transformation layout to a dict.
-    :param tranformation_layout: Given transformation layout.
+    :param transformation_layout: Given transformation layout.
:return: Serialized representation of given transformation layout as a dict.
"""
transformation_commands = []
6 changes: 6 additions & 0 deletions tests/post_training/data/ptq_reference_data.yaml
@@ -34,6 +34,12 @@ torchvision/resnet18_backend_CUDA_TORCH:
metric_value: 0.69152
torchvision/resnet18_backend_FX_TORCH:
metric_value: 0.6946
+torchvision/swin_v2_s_backend_FP32:
+  metric_value: 0.83712
+torchvision/swin_v2_s_backend_OV:
+  metric_value: 0.83638
+torchvision/swin_v2_s_backend_FX_TORCH:
+  metric_value: 0.82908
timm/crossvit_9_240_backend_CUDA_TORCH:
metric_value: 0.689
timm/crossvit_9_240_backend_FP32:
11 changes: 11 additions & 0 deletions tests/post_training/model_scope.py
@@ -75,6 +75,17 @@
"backends": [BackendType.FX_TORCH, BackendType.TORCH, BackendType.CUDA_TORCH, BackendType.OV, BackendType.ONNX],
"batch_size": 128,
},
+    {
+        "reported_name": "torchvision/swin_v2_s",
+        "model_id": "swin_v2_s",
+        "pipeline_cls": ImageClassificationTorchvision,
+        "compression_params": {
+            "model_type": ModelType.TRANSFORMER,
+            "advanced_parameters": AdvancedQuantizationParameters(smooth_quant_alpha=0.5),
+        },
+        "backends": [BackendType.FX_TORCH, BackendType.OV],
+        "batch_size": 1,
+    },
# Timm models
{
"reported_name": "timm/crossvit_9_240",
tests/post_training/pipelines/image_classification_torchvision.py
@@ -43,13 +43,14 @@ def prepare_model(self) -> None:
model = model_cls(weights=self.model_weights)
model.eval()

-        self.static_input_size = [self.batch_size, 3, 224, 224]
+        default_input_size = [self.batch_size, 3, 224, 224]
+        self.dummy_tensor = self.model_weights.transforms()(torch.rand(default_input_size))
+        self.static_input_size = list(self.dummy_tensor.shape)

self.input_size = self.static_input_size.copy()
if self.batch_size > 1: # Dynamic batch_size shape export
self.input_size[0] = -1

-        self.dummy_tensor = torch.rand(self.static_input_size)

if self.backend == BackendType.FX_TORCH:
with torch.no_grad():
with disable_patching():
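
As an aside, a hedged illustration of why the dummy tensor is now produced by the weights' preprocessing transforms: for swin_v2_s the torchvision transforms resize inputs to 256x256, so a hard-coded [N, 3, 224, 224] export shape would not match the model's actual input (the shape values here are assumed from torchvision defaults):

```python
import torch
from torchvision.models import Swin_V2_S_Weights

weights = Swin_V2_S_Weights.DEFAULT
dummy = weights.transforms()(torch.rand([1, 3, 224, 224]))
# For swin_v2_s the preprocessing resizes/crops to 256x256, so
# list(dummy.shape) is expected to be [1, 3, 256, 256] rather than
# the previously hard-coded [1, 3, 224, 224].
print(list(dummy.shape))
```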
11 changes: 11 additions & 0 deletions tests/torch/data/reference_graphs/fx/extracted/ConvolutionWithAllConstantInputsModel.dot
@@ -0,0 +1,11 @@
strict digraph {
"0 _conv_w" [id=0, type=get_attr];
"1 add" [id=1, type=add];
"2 placeholder_graph_node" [id=2, type=input];
"3 conv2d" [id=3, type=conv2d];
"4 output" [id=4, type=output];
"0 _conv_w" -> "1 add";
"1 add" -> "3 conv2d";
"2 placeholder_graph_node" -> "3 conv2d";
"3 conv2d" -> "4 output";
}
11 changes: 11 additions & 0 deletions tests/torch/data/reference_graphs/fx/extracted/ConvolutionWithNotTensorBiasModel.dot
@@ -0,0 +1,11 @@
strict digraph {
"0 _conv_w" [id=0, type=get_attr];
"1 add" [id=1, type=add];
"2 arg0_1" [id=2, type=input];
"3 conv2d" [id=3, type=conv2d];
"4 output" [id=4, type=output];
"0 _conv_w" -> "1 add";
"1 add" -> "3 conv2d";
"2 arg0_1" -> "3 conv2d";
"3 conv2d" -> "4 output";
}
65 changes: 65 additions & 0 deletions tests/torch/fx/test_model_transformer.py
@@ -0,0 +1,65 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from pathlib import Path
from typing import Any, Tuple

import pytest
import torch
from torch._export import capture_pre_autograd_graph

from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.experimental.torch.fx.model_transformer import FXModelTransformer
from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter
from nncf.torch import disable_patching
from nncf.torch.graph.transformations.commands import PTModelExtractionCommand
from tests.torch.test_compressed_graph import check_graph
from tests.torch.test_models.synthetic import ConvolutionWithAllConstantInputsModel
from tests.torch.test_models.synthetic import ConvolutionWithNotTensorBiasModel


@dataclass
class ModelExtractionTestCase:
model: torch.nn.Module
input_shape: Tuple[int, ...]
command: PTModelExtractionCommand
ref: None = None


EXTRACTED_GRAPHS_DIR_NAME = Path("fx") / "extracted"

MODEL_EXTRACTION_CASES = (
ModelExtractionTestCase(
ConvolutionWithNotTensorBiasModel, (1, 1, 3, 3), PTModelExtractionCommand(["conv2d"], ["conv2d"])
),
ModelExtractionTestCase(
ConvolutionWithAllConstantInputsModel, (1, 1, 3, 3), PTModelExtractionCommand(["conv2d"], ["conv2d"])
),
)


def idfn(value: Any):
if isinstance(value, ModelExtractionTestCase):
return value.model.__name__
return None


@pytest.mark.parametrize("test_case", MODEL_EXTRACTION_CASES, ids=idfn)
def test_model_extraction(test_case: ModelExtractionTestCase):
with torch.no_grad():
with disable_patching():
captured_model = capture_pre_autograd_graph(test_case.model(), (torch.ones(test_case.input_shape),))
layout = TransformationLayout()
layout.register(test_case.command)
extracted_model = FXModelTransformer(captured_model).transform(layout)
nncf_graph = GraphConverter.create_nncf_graph(extracted_model)
check_graph(nncf_graph, f"{test_case.model.__name__}.dot", str(EXTRACTED_GRAPHS_DIR_NAME))
5 changes: 2 additions & 3 deletions tests/torch/fx/test_models.py
@@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import json
import os
from dataclasses import dataclass
@@ -39,8 +38,8 @@
from tests.torch import test_models
from tests.torch.test_compressed_graph import check_graph

FX_DIR_NAME = "fx"
FX_QUANTIZED_DIR_NAME = "fx/quantized"
FX_DIR_NAME = Path("fx")
FX_QUANTIZED_DIR_NAME = Path("fx") / "quantized"


@dataclass
21 changes: 21 additions & 0 deletions tests/torch/test_models/synthetic.py
@@ -501,3 +501,24 @@ def forward(self, x):
unbinded_processed[0] = self.conv4(y_unbinded[0])
y = torch.cat(unbinded_processed, axis=0)
return y


+class ConvolutionWithNotTensorBiasModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self._conv_w = nn.Parameter(torch.ones((1, 1, 1, 1)))
+
+    def forward(self, x):
+        w = self._conv_w + 10
+        return nn.functional.conv2d(x, w)
+
+
+class ConvolutionWithAllConstantInputsModel(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+        self._conv_w = nn.Parameter(torch.ones((1, 1, 1, 1)))
+        self._conv_i = nn.Parameter(torch.ones((1, 1, 1, 1)))
+
+    def forward(self, x):
+        w = self._conv_w + 10
+        return x + nn.functional.conv2d(self._conv_i, w)
