Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TorchFX][FBC] Constant linear layers support #2866

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion nncf/experimental/torch/fx/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,5 +33,5 @@ def __init__(
:param priority: Transformation priority.
"""
super().__init__(TransformationType.INSERT)
self.tranformation_fn = transformation_fn
self.transformation_fn = transformation_fn
self.priority = priority
21 changes: 19 additions & 2 deletions nncf/experimental/torch/fx/model_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from nncf.common.graph.model_transformer import ModelTransformer
from nncf.experimental.torch.fx.commands import FXApplyTransformationCommand
from nncf.experimental.torch.fx.node_utils import get_graph_node_by_name
from nncf.torch.graph.transformations.commands import PTModelExtractionCommand
from nncf.torch.graph.transformations.layout import PTTransformationLayout

Expand Down Expand Up @@ -97,7 +98,23 @@ def _apply_model_extraction(
# TODO(dlyakhov): reduce memory consumption by
# more optimal splitting implementation.
splitted_gm = split_by_tags(model, tags)
return splitted_gm.extracted

extracted_model = splitted_gm.extracted
daniil-lyakhov marked this conversation as resolved.
Show resolved Hide resolved
graph: torch.fx.Graph = extracted_model.graph
# Check that the extracted model has inputs.
# The target layer may have two constant inputs;
# in that case a placeholder is placed
# on its input port.
target_node = get_graph_node_by_name(graph, node_name)
input_node = target_node.all_input_nodes[0]
AlexanderDokuchaev marked this conversation as resolved.
Show resolved Hide resolved
if input_node.op != "placeholder":
with graph.inserting_before(target_node):
new_input_node = graph.create_node(
"placeholder", "placeholder_node", (), {}, name="placeholder_graph_node"
)
target_node.replace_input_with(input_node, new_input_node)
extracted_model.graph.eliminate_dead_code()
return extracted_model

@staticmethod
def _apply_transformation(
Expand All @@ -112,5 +129,5 @@ def _apply_transformation(
:return: Target model after all transformations were applied.
"""
for transformation in transformations:
transformation.tranformation_fn(model)
transformation.transformation_fn(model)
return model
12 changes: 6 additions & 6 deletions nncf/experimental/torch/fx/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def bias_update_transformation(model: torch.fx.GraphModule):
return bias_update_transformation


def qdq_insertion_tranformation_builder(
def qdq_insertion_transformation_builder(
quantizer: FakeQuantize, target_points: List[PTTargetPoint]
) -> TransformationFNType:
"""
Expand All @@ -92,7 +92,7 @@ def qdq_insertion_tranformation_builder(
inherited from the given quantizer to each given target point.
"""

def qdq_insertion_tranformation(model: torch.fx.GraphModule):
def qdq_insertion_transformation(model: torch.fx.GraphModule):
if any(tp.target_type != TargetType.OPERATION_WITH_WEIGHTS for tp in target_points) and len(target_points) > 1:
raise RuntimeError(
"Insertion of shared qdq pair for the weights is not supported."
Expand All @@ -101,7 +101,7 @@ def qdq_insertion_tranformation(model: torch.fx.GraphModule):
for target_point in target_points:
insert_one_qdq(model, target_point, quantizer)

return qdq_insertion_tranformation
return qdq_insertion_transformation


def insert_one_qdq(model: torch.fx.GraphModule, target_point: PTTargetPoint, quantizer: FakeQuantize):
Expand Down Expand Up @@ -310,7 +310,7 @@ def _is_bn_node(node: torch.fx.Node):

def fuse_conv_bn(model: torch.fx.GraphModule) -> None:
"""
BatchNorm operations have 3 output ports, to make it easier for alorithms to work with
BatchNorm operations have 3 output ports, to make it easier for algorithms to work with
the target graph BatchNorm operations are being fused

:param model: Model to apply transformations to.
Expand Down Expand Up @@ -342,7 +342,7 @@ def apply_quantization_transformations(model: torch.fx.GraphModule) -> None:
:param model: Model to apply transformations to.
"""
# BatchNorm operations have 3 output ports,
# to make it easier for alorithms to work
# to make it easier for algorithms to work
# with the target graph BatchNorm operations
# are being fused
fuse_conv_bn(model)
Expand Down Expand Up @@ -484,7 +484,7 @@ def _merge_node_and_bias(model: torch.fx.GraphModule, is_target_node: Callable[[
Check which node should be merged by the given `is_target_node` predicate.

:param model: Target model.
:param is_target_node: Predicate to specify nodes which shoudld be merged with the bias
:param is_target_node: Predicate to specify nodes which should be merged with the bias
"""
add_node_targets = (torch.ops.aten.add_.Tensor,)
for n in model.graph.nodes:
Expand Down
6 changes: 3 additions & 3 deletions nncf/quantization/algorithms/min_max/torch_fx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from nncf.experimental.common.tensor_statistics.collectors import TensorCollector
from nncf.experimental.common.tensor_statistics.statistics import MinMaxTensorStatistic
from nncf.experimental.torch.fx.commands import FXApplyTransformationCommand
from nncf.experimental.torch.fx.transformations import qdq_insertion_tranformation_builder
from nncf.experimental.torch.fx.transformations import qdq_insertion_transformation_builder
from nncf.parameters import ModelType
from nncf.parameters import TargetDevice
from nncf.quantization.advanced_parameters import StatisticsType
Expand Down Expand Up @@ -288,7 +288,7 @@ def create_quantizer_insertion_command(
quantizer = FXMinMaxAlgoBackend._create_quantizer(
quantizer_config, scale_shape, parameters, target_point.target_type
)
transformation = qdq_insertion_tranformation_builder(quantizer, [target_point])
transformation = qdq_insertion_transformation_builder(quantizer, [target_point])
return FXApplyTransformationCommand(transformation)

@staticmethod
Expand All @@ -308,7 +308,7 @@ def create_unified_scales_quantizers_insertion_commands(

transformations = []
for tp in target_points:
transformation = qdq_insertion_tranformation_builder(quantizer, [tp])
transformation = qdq_insertion_transformation_builder(quantizer, [tp])
transformations.append(FXApplyTransformationCommand(transformation))
return transformations

Expand Down
2 changes: 1 addition & 1 deletion nncf/torch/graph/transformations/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def serialize_transformations(transformations_layout: TransformationLayout) -> D
"""
Serializes given transformation layout to a dict.

:param tranformation_layout: Given transformation layout.
:param transformation_layout: Given transformation layout.
:return: Serialized representation of given transformation layout as a dict.
"""
transformation_commands = []
Expand Down
6 changes: 6 additions & 0 deletions tests/post_training/data/ptq_reference_data.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,12 @@ torchvision/resnet18_backend_CUDA_TORCH:
metric_value: 0.69152
torchvision/resnet18_backend_FX_TORCH:
metric_value: 0.6946
torchvision/swin_v2_s_backend_FP32:
metric_value: 0.83712
torchvision/swin_v2_s_backend_OV:
metric_value: 0.83638
torchvision/swin_v2_s_backend_FX_TORCH:
metric_value: 0.82908
timm/crossvit_9_240_backend_CUDA_TORCH:
metric_value: 0.689
timm/crossvit_9_240_backend_FP32:
Expand Down
11 changes: 11 additions & 0 deletions tests/post_training/model_scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,17 @@
"backends": [BackendType.FX_TORCH, BackendType.TORCH, BackendType.CUDA_TORCH, BackendType.OV, BackendType.ONNX],
"batch_size": 128,
},
{
"reported_name": "torchvision/swin_v2_s",
"model_id": "swin_v2_s",
"pipeline_cls": ImageClassificationTorchvision,
"compression_params": {
"model_type": ModelType.TRANSFORMER,
"advanced_parameters": AdvancedQuantizationParameters(smooth_quant_alpha=0.5),
},
"backends": [BackendType.FX_TORCH, BackendType.OV],
"batch_size": 1,
},
# Timm models
{
"reported_name": "timm/crossvit_9_240",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,14 @@ def prepare_model(self) -> None:
model = model_cls(weights=self.model_weights)
model.eval()

self.static_input_size = [self.batch_size, 3, 224, 224]
default_input_size = [self.batch_size, 3, 224, 224]
self.dummy_tensor = self.model_weights.transforms()(torch.rand(default_input_size))
self.static_input_size = list(self.dummy_tensor.shape)

self.input_size = self.static_input_size.copy()
if self.batch_size > 1: # Dynamic batch_size shape export
self.input_size[0] = -1

self.dummy_tensor = torch.rand(self.static_input_size)

if self.backend == BackendType.FX_TORCH:
with torch.no_grad():
with disable_patching():
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
strict digraph {
"0 _conv_w" [id=0, type=get_attr];
"1 add" [id=1, type=add];
"2 placeholder_graph_node" [id=2, type=input];
"3 conv2d" [id=3, type=conv2d];
"4 output" [id=4, type=output];
"0 _conv_w" -> "1 add";
"1 add" -> "3 conv2d";
"2 placeholder_graph_node" -> "3 conv2d";
"3 conv2d" -> "4 output";
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
strict digraph {
"0 _conv_w" [id=0, type=get_attr];
"1 add" [id=1, type=add];
"2 arg0_1" [id=2, type=input];
"3 conv2d" [id=3, type=conv2d];
"4 output" [id=4, type=output];
"0 _conv_w" -> "1 add";
"1 add" -> "3 conv2d";
"2 arg0_1" -> "3 conv2d";
"3 conv2d" -> "4 output";
}
65 changes: 65 additions & 0 deletions tests/torch/fx/test_model_transformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from pathlib import Path
from typing import Any, Tuple, Type

import pytest
import torch
from torch._export import capture_pre_autograd_graph

from nncf.common.graph.transformations.layout import TransformationLayout
from nncf.experimental.torch.fx.model_transformer import FXModelTransformer
from nncf.experimental.torch.fx.nncf_graph_builder import GraphConverter
from nncf.torch import disable_patching
from nncf.torch.graph.transformations.commands import PTModelExtractionCommand
from tests.torch.test_compressed_graph import check_graph
from tests.torch.test_models.synthetic import ConvolutionWithAllConstantInputsModel
from tests.torch.test_models.synthetic import ConvolutionWithNotTensorBiasModel


@dataclass
class ModelExtractionTestCase:
    """
    Parameters of a single model extraction test.

    :param model: Model class (not an instance); the test instantiates it without arguments.
    :param input_shape: Shape of the all-ones tensor used to capture the model.
    :param command: Model extraction command registered in the transformation layout.
    :param ref: Not read by the test in this module; defaults to None.
    """

    # The test calls `model()` and reads `model.__name__`, so this is a class, not an instance.
    model: Type[torch.nn.Module]
    input_shape: Tuple[int, ...]
    command: PTModelExtractionCommand
    ref: None = None


# Relative directory passed to `check_graph` as the location of the reference graphs.
EXTRACTED_GRAPHS_DIR_NAME = Path("fx", "extracted")

# One case per synthetic model; every case uses the same (1, 1, 3, 3) input shape
# and an extraction command built from the "conv2d" node name.
MODEL_EXTRACTION_CASES = (
    ModelExtractionTestCase(
        model=ConvolutionWithNotTensorBiasModel,
        input_shape=(1, 1, 3, 3),
        command=PTModelExtractionCommand(["conv2d"], ["conv2d"]),
    ),
    ModelExtractionTestCase(
        model=ConvolutionWithAllConstantInputsModel,
        input_shape=(1, 1, 3, 3),
        command=PTModelExtractionCommand(["conv2d"], ["conv2d"]),
    ),
)


def idfn(value: Any):
    """
    Builds a test id for a parametrized value.

    :param value: Parametrized value.
    :return: Name of the model class when the value is a `ModelExtractionTestCase`,
        None otherwise.
    """
    return value.model.__name__ if isinstance(value, ModelExtractionTestCase) else None


@pytest.mark.parametrize("test_case", MODEL_EXTRACTION_CASES, ids=idfn)
def test_model_extraction(test_case: ModelExtractionTestCase):
# Runs once per entry of MODEL_EXTRACTION_CASES; test ids come from `idfn`.
# Capture the freshly instantiated model on an all-ones tensor of the case's input shape.
with torch.no_grad():
with disable_patching():
captured_model = capture_pre_autograd_graph(test_case.model(), (torch.ones(test_case.input_shape),))
# Apply the extraction command of the case through a transformation layout.
layout = TransformationLayout()
layout.register(test_case.command)
extracted_model = FXModelTransformer(captured_model).transform(layout)
# Convert the extracted model to an NNCF graph and pass it to `check_graph` together with
# the "<ModelClassName>.dot" file name and the extracted graphs directory
# (presumably compared against that reference file -- confirm in `check_graph`).
nncf_graph = GraphConverter.create_nncf_graph(extracted_model)
check_graph(nncf_graph, f"{test_case.model.__name__}.dot", str(EXTRACTED_GRAPHS_DIR_NAME))
5 changes: 2 additions & 3 deletions tests/torch/fx/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import json
import os
from dataclasses import dataclass
Expand Down Expand Up @@ -39,8 +38,8 @@
from tests.torch import test_models
from tests.torch.test_compressed_graph import check_graph

FX_DIR_NAME = "fx"
FX_QUANTIZED_DIR_NAME = "fx/quantized"
FX_DIR_NAME = Path("fx")
FX_QUANTIZED_DIR_NAME = Path("fx") / "quantized"


@dataclass
Expand Down
21 changes: 21 additions & 0 deletions tests/torch/test_models/synthetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -501,3 +501,24 @@ def forward(self, x):
unbinded_processed[0] = self.conv4(y_unbinded[0])
y = torch.cat(unbinded_processed, axis=0)
return y


class ConvolutionWithNotTensorBiasModel(torch.nn.Module):
    """
    Convolution without a bias whose weight is the result of an addition
    (a parameter plus a constant) rather than the parameter itself.
    """

    def __init__(self):
        super().__init__()
        unit_kernel = torch.ones(1, 1, 1, 1)
        self._conv_w = nn.Parameter(unit_kernel)

    def forward(self, x):
        # The weight passed to the convolution is computed on the fly from the parameter.
        return nn.functional.conv2d(x, self._conv_w + 10)


class ConvolutionWithAllConstantInputsModel(torch.nn.Module):
    """
    Convolution whose input and weight both come from model parameters;
    the model input is only added to the convolution result.
    """

    def __init__(self):
        super().__init__()
        # Registration order is kept: weight first, then the constant convolution input.
        self._conv_w = nn.Parameter(torch.ones(1, 1, 1, 1))
        self._conv_i = nn.Parameter(torch.ones(1, 1, 1, 1))

    def forward(self, x):
        conv_out = nn.functional.conv2d(self._conv_i, self._conv_w + 10)
        return x + conv_out
Loading