Skip to content

Commit

Permalink
[TorchFX] Model transformer tests
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Sep 20, 2024
1 parent 388fdca commit c33e6ca
Show file tree
Hide file tree
Showing 20 changed files with 872 additions and 23 deletions.
18 changes: 12 additions & 6 deletions nncf/experimental/torch/fx/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def module_insertion_transformation(model: torch.fx.GraphModule):
if target_point.target_type == TargetType.OPERATOR_POST_HOOK:
_set_new_node_meta(new_node, target_node, module_to_insert)
with graph.inserting_after(target_node):
for user in target_node.users:
for user in list(target_node.users):
if user is new_node:
continue
user.replace_input_with(target_node, new_node)
Expand Down Expand Up @@ -110,12 +110,13 @@ def leaf_module_insertion_transformation(model: torch.fx.GraphModule):
return leaf_module_insertion_transformation


def bias_update_transformation_builder(node: NNCFNode, value: torch.Tensor) -> TransformationFNType:
def bias_update_transformation_builder(node: NNCFNode, value: torch.Tensor, input_port_id: int) -> TransformationFNType:
"""
Return transformation which updates constant of the given node with bias to the given value.
:param node: Node with bias which requires bias constant update.
:param value: New value to use as the bias constant.
:param input_port_id: Input port id to get constant node from.
:return: Transformation which updates constant of the given node with bias to the given value.
"""

Expand All @@ -131,22 +132,27 @@ def bias_update_transformation(model: torch.fx.GraphModule):
raise nncf.InternalError(f"Node {graph_node.name} has {len(add_nodes)} outputs with adds, 1 expected")

bias_node = add_nodes[0]
constant_update_fn(model, bias_node, value, input_port_id=1)
constant_update_fn(model, bias_node, value, input_port_id=input_port_id)

return bias_update_transformation


def constant_update_transformation_builder(node: NNCFNode, value: torch.Tensor) -> TransformationFNType:
def constant_update_transformation_builder(
node: NNCFNode, value: torch.Tensor, input_port_id: int
) -> TransformationFNType:
"""
Return transformation which updates constant of the given node to the given value.
:param node: Node which requires bias constant update.
:param value: New value to use as the node constant.
:param input_port_id: Input port id to get constant node from.
:return: Transformation which updates constant of the given node to the given value.
"""

def constant_update_transformation(model: torch.fx.GraphModule):
constant_update_fn(model, get_graph_node_by_name(model.graph, node.node_name), value, input_port_id=1)
constant_update_fn(
model, get_graph_node_by_name(model.graph, node.node_name), value, input_port_id=input_port_id
)

return constant_update_transformation

Expand Down Expand Up @@ -197,7 +203,7 @@ def qdq_insertion_transformation_builder(

def qdq_insertion_transformation(model: torch.fx.GraphModule):
if any(tp.target_type != TargetType.OPERATION_WITH_WEIGHTS for tp in target_points) and len(target_points) > 1:
raise RuntimeError(
raise nncf.InternalError(
"Insertion of shared qdq pair for the weights is not supported."
" Please use non shared qdq pairs for the weights quantization."
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def target_point(target_type: TargetType, target_node_name: str, port_id: int) -
def create_bias_correction_command(
node: NNCFNode, bias_value: Tensor, nncf_graph: NNCFGraph
) -> FXApplyTransformationCommand:
return FXApplyTransformationCommand(bias_update_transformation_builder(node, bias_value.data))
return FXApplyTransformationCommand(bias_update_transformation_builder(node, bias_value.data, input_port_id=1))

@staticmethod
def model_extraction_command(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def target_point(target_type: TargetType, target_node_name: str, port_id: int) -
def create_bias_correction_command(
node: NNCFNode, bias_value: Tensor, nncf_graph: NNCFGraph
) -> FXApplyTransformationCommand:
return FXApplyTransformationCommand(bias_update_transformation_builder(node, bias_value.data))
return FXApplyTransformationCommand(bias_update_transformation_builder(node, bias_value.data, input_port_id=1))

@staticmethod
def model_extraction_command(
Expand Down
3 changes: 2 additions & 1 deletion nncf/quantization/algorithms/min_max/torch_fx_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import Dict, List, Optional, Set, Tuple

import torch
from torch.quantization.fake_quantize import FakeQuantize

import nncf
import nncf.torch.graph.operator_metatypes as om
Expand Down Expand Up @@ -240,7 +241,7 @@ def _create_quantizer(
scale_shape: Tuple,
parameters: FakeQuantizeParameters,
target_type: TargetType,
) -> BaseQuantizer:
) -> FakeQuantize:
mode = quantizer_config.mode
quantizer_cls = QUANTIZATION_MODULES.get(mode)
narrow_range = target_type == TargetType.OPERATION_WITH_WEIGHTS and mode == QuantizationMode.SYMMETRIC
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,9 @@ def get_weight_value(node_with_weight: NNCFNode, model: torch.fx.GraphModule, nn

@staticmethod
def weight_update_command(node_with_weight: NNCFNode, weight_value: torch.Tensor) -> OVWeightUpdateCommand:
return FXApplyTransformationCommand(constant_update_transformation_builder(node_with_weight, weight_value.data))
return FXApplyTransformationCommand(
constant_update_transformation_builder(node_with_weight, weight_value.data, input_port_id=1)
)

@staticmethod
def scale_insertion_command(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
strict digraph {
"0 arg0_1" [id=0, type=input];
"1 _param_constant0" [id=1, type=get_attr];
"2 _param_constant1" [id=2, type=get_attr];
"3 conv2d" [id=3, type=conv2d];
"4 _param_constant2" [id=4, type=get_attr];
"5 _param_constant3" [id=5, type=get_attr];
"6 conv2d_1" [id=6, type=conv2d];
"7 add__updated_constant0" [id=7, type=get_attr];
"8 add_" [id=8, type=add_];
"9 _tensor_constant0_1" [id=9, type=get_attr];
"10 add__1" [id=10, type=add_];
"11 add" [id=11, type=add];
"12 _param_constant4" [id=12, type=get_attr];
"13 _param_constant5" [id=13, type=get_attr];
"14 conv2d_2" [id=14, type=conv2d];
"15 _tensor_constant0_2" [id=15, type=get_attr];
"16 add_1" [id=16, type=add];
"17 output" [id=17, type=output];
"0 arg0_1" -> "3 conv2d" [label="(1, 3, 3, 3)", style=solid];
"1 _param_constant0" -> "3 conv2d" [label="(3, 3, 1, 1)", style=solid];
"2 _param_constant1" -> "3 conv2d" [label="(3,)", style=solid];
"3 conv2d" -> "6 conv2d_1" [label="(1, 3, 3, 3)", style=solid];
"3 conv2d" -> "8 add_" [label="(1, 3, 3, 3)", style=solid];
"4 _param_constant2" -> "6 conv2d_1" [label="(3, 3, 1, 1)", style=solid];
"5 _param_constant3" -> "6 conv2d_1" [label="(3,)", style=solid];
"6 conv2d_1" -> "10 add__1" [label="(1, 3, 3, 3)", style=solid];
"7 add__updated_constant0" -> "8 add_" [label="(1,)", style=solid];
"8 add_" -> "11 add" [label="(1, 3, 3, 3)", style=solid];
"9 _tensor_constant0_1" -> "10 add__1" [label="(1,)", style=solid];
"10 add__1" -> "11 add" [label="(1, 3, 3, 3)", style=solid];
"11 add" -> "14 conv2d_2" [label="(1, 3, 3, 3)", style=solid];
"12 _param_constant4" -> "14 conv2d_2" [label="(3, 3, 1, 1)", style=solid];
"13 _param_constant5" -> "14 conv2d_2" [label="(3,)", style=solid];
"14 conv2d_2" -> "16 add_1" [label="(1, 3, 3, 3)", style=solid];
"15 _tensor_constant0_2" -> "16 add_1" [label="(1,)", style=solid];
"16 add_1" -> "17 output" [label="(1, 3, 3, 3)", style=solid];
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
strict digraph {
"0 arg0_1" [id=0, type=input];
"1 _param_constant0" [id=1, type=get_attr];
"2 _param_constant1" [id=2, type=get_attr];
"3 TEST_MODULE_0" [id=3, type=call_module];
"4 TEST_MODULE_1" [id=4, type=call_module];
"5 conv2d" [id=5, type=conv2d];
"6 TEST_MODULE_3" [id=6, type=call_module];
"7 _param_constant2" [id=7, type=get_attr];
"8 _param_constant3" [id=8, type=get_attr];
"9 TEST_MODULE_2" [id=9, type=call_module];
"10 conv2d_1" [id=10, type=conv2d];
"11 _tensor_constant0" [id=11, type=get_attr];
"12 add_" [id=12, type=add_];
"13 _tensor_constant0_1" [id=13, type=get_attr];
"14 add__1" [id=14, type=add_];
"15 add" [id=15, type=add];
"16 _param_constant4" [id=16, type=get_attr];
"17 _param_constant5" [id=17, type=get_attr];
"18 conv2d_2" [id=18, type=conv2d];
"19 _tensor_constant0_2" [id=19, type=get_attr];
"20 add_1" [id=20, type=add];
"21 output" [id=21, type=output];
"0 arg0_1" -> "3 TEST_MODULE_0" [label="(1, 3, 3, 3)", style=solid];
"0 arg0_1" -> "5 conv2d" [label="(1, 3, 3, 3)", style=solid];
"1 _param_constant0" -> "4 TEST_MODULE_1" [label="(3, 3, 1, 1)", style=solid];
"1 _param_constant0" -> "5 conv2d" [label="(3, 3, 1, 1)", style=solid];
"2 _param_constant1" -> "5 conv2d" [label="(3,)", style=solid];
"5 conv2d" -> "6 TEST_MODULE_3" [label="(1, 3, 3, 3)", style=solid];
"5 conv2d" -> "10 conv2d_1" [label="(1, 3, 3, 3)", style=solid];
"5 conv2d" -> "12 add_" [label="(1, 3, 3, 3)", style=solid];
"7 _param_constant2" -> "9 TEST_MODULE_2" [label="(3, 3, 1, 1)", style=solid];
"7 _param_constant2" -> "10 conv2d_1" [label="(3, 3, 1, 1)", style=solid];
"8 _param_constant3" -> "10 conv2d_1" [label="(3,)", style=solid];
"10 conv2d_1" -> "14 add__1" [label="(1, 3, 3, 3)", style=solid];
"11 _tensor_constant0" -> "12 add_" [label="(1,)", style=solid];
"12 add_" -> "15 add" [label="(1, 3, 3, 3)", style=solid];
"13 _tensor_constant0_1" -> "14 add__1" [label="(1,)", style=solid];
"14 add__1" -> "15 add" [label="(1, 3, 3, 3)", style=solid];
"15 add" -> "18 conv2d_2" [label="(1, 3, 3, 3)", style=solid];
"16 _param_constant4" -> "18 conv2d_2" [label="(3, 3, 1, 1)", style=solid];
"17 _param_constant5" -> "18 conv2d_2" [label="(3,)", style=solid];
"18 conv2d_2" -> "20 add_1" [label="(1, 3, 3, 3)", style=solid];
"19 _tensor_constant0_2" -> "20 add_1" [label="(1,)", style=solid];
"20 add_1" -> "21 output" [label="(1, 3, 3, 3)", style=solid];
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
strict digraph {
"0 arg0_1" [id=0, type=input];
"1 _param_constant0" [id=1, type=get_attr];
"2 _param_constant1" [id=2, type=get_attr];
"3 TEST_MODULE_0" [id=3, type=call_module];
"4 TEST_MODULE_1" [id=4, type=call_module];
"5 conv2d" [id=5, type=conv2d];
"6 TEST_MODULE_3" [id=6, type=call_module];
"7 _param_constant2" [id=7, type=get_attr];
"8 _param_constant3" [id=8, type=get_attr];
"9 TEST_MODULE_2" [id=9, type=call_module];
"10 conv2d_1" [id=10, type=conv2d];
"11 _tensor_constant0" [id=11, type=get_attr];
"12 add_" [id=12, type=add_];
"13 _tensor_constant0_1" [id=13, type=get_attr];
"14 add__1" [id=14, type=add_];
"15 add" [id=15, type=add];
"16 _param_constant4" [id=16, type=get_attr];
"17 _param_constant5" [id=17, type=get_attr];
"18 conv2d_2" [id=18, type=conv2d];
"19 _tensor_constant0_2" [id=19, type=get_attr];
"20 add_1" [id=20, type=add];
"21 output" [id=21, type=output];
"0 arg0_1" -> "3 TEST_MODULE_0" [label="(1, 3, 3, 3)", style=solid];
"1 _param_constant0" -> "4 TEST_MODULE_1" [label="(3, 3, 1, 1)", style=solid];
"2 _param_constant1" -> "5 conv2d" [label="(3,)", style=solid];
"3 TEST_MODULE_0" -> "5 conv2d" [label="(1, 3, 3, 3)", style=solid];
"4 TEST_MODULE_1" -> "5 conv2d" [label="(3, 3, 1, 1)", style=solid];
"5 conv2d" -> "6 TEST_MODULE_3" [label="(1, 3, 3, 3)", style=solid];
"6 TEST_MODULE_3" -> "10 conv2d_1" [label="(1, 3, 3, 3)", style=solid];
"6 TEST_MODULE_3" -> "12 add_" [label="(1, 3, 3, 3)", style=solid];
"7 _param_constant2" -> "9 TEST_MODULE_2" [label="(3, 3, 1, 1)", style=solid];
"8 _param_constant3" -> "10 conv2d_1" [label="(3,)", style=solid];
"9 TEST_MODULE_2" -> "10 conv2d_1" [label="(3, 3, 1, 1)", style=solid];
"10 conv2d_1" -> "14 add__1" [label="(1, 3, 3, 3)", style=solid];
"11 _tensor_constant0" -> "12 add_" [label="(1,)", style=solid];
"12 add_" -> "15 add" [label="(1, 3, 3, 3)", style=solid];
"13 _tensor_constant0_1" -> "14 add__1" [label="(1,)", style=solid];
"14 add__1" -> "15 add" [label="(1, 3, 3, 3)", style=solid];
"15 add" -> "18 conv2d_2" [label="(1, 3, 3, 3)", style=solid];
"16 _param_constant4" -> "18 conv2d_2" [label="(3, 3, 1, 1)", style=solid];
"17 _param_constant5" -> "18 conv2d_2" [label="(3,)", style=solid];
"18 conv2d_2" -> "20 add_1" [label="(1, 3, 3, 3)", style=solid];
"19 _tensor_constant0_2" -> "20 add_1" [label="(1,)", style=solid];
"20 add_1" -> "21 output" [label="(1, 3, 3, 3)", style=solid];
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
strict digraph {
"0 arg0_1" [id=0, type=input];
"1 _param_constant2" [id=1, type=get_attr];
"2 _param_constant3" [id=2, type=get_attr];
"3 conv2d_1" [id=3, type=conv2d];
"4 _tensor_constant0" [id=4, type=get_attr];
"5 add_" [id=5, type=add_];
"6 _tensor_constant0_1" [id=6, type=get_attr];
"7 add__1" [id=7, type=add_];
"8 add" [id=8, type=add];
"9 _param_constant4" [id=9, type=get_attr];
"10 _param_constant5" [id=10, type=get_attr];
"11 conv2d_2" [id=11, type=conv2d];
"12 _tensor_constant0_2" [id=12, type=get_attr];
"13 add_1" [id=13, type=add];
"14 output" [id=14, type=output];
"0 arg0_1" -> "3 conv2d_1" [label="(1, 3, 3, 3)", style=solid];
"0 arg0_1" -> "5 add_" [label="(1, 3, 3, 3)", style=solid];
"1 _param_constant2" -> "3 conv2d_1" [label="(3, 3, 1, 1)", style=solid];
"2 _param_constant3" -> "3 conv2d_1" [label="(3,)", style=solid];
"3 conv2d_1" -> "7 add__1" [label="(1, 3, 3, 3)", style=solid];
"4 _tensor_constant0" -> "5 add_" [label="(1,)", style=solid];
"5 add_" -> "8 add" [label="(1, 3, 3, 3)", style=solid];
"6 _tensor_constant0_1" -> "7 add__1" [label="(1,)", style=solid];
"7 add__1" -> "8 add" [label="(1, 3, 3, 3)", style=solid];
"8 add" -> "11 conv2d_2" [label="(1, 3, 3, 3)", style=solid];
"9 _param_constant4" -> "11 conv2d_2" [label="(3, 3, 1, 1)", style=solid];
"10 _param_constant5" -> "11 conv2d_2" [label="(3,)", style=solid];
"11 conv2d_2" -> "13 add_1" [label="(1, 3, 3, 3)", style=solid];
"12 _tensor_constant0_2" -> "13 add_1" [label="(1,)", style=solid];
"13 add_1" -> "14 output" [label="(1, 3, 3, 3)", style=solid];
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
strict digraph {
"0 arg0_1" [id=0, type=input];
"1 _param_constant0" [id=1, type=get_attr];
"2 _param_constant1" [id=2, type=get_attr];
"3 conv2d" [id=3, type=conv2d];
"4 _param_constant2" [id=4, type=get_attr];
"5 _param_constant3" [id=5, type=get_attr];
"6 conv2d_1_scale_0" [id=6, type=get_attr];
"7 conv2d_1_zero_point_0" [id=7, type=get_attr];
"8 quantize_per_channel_default" [id=8, type=quantize_per_channel];
"9 dequantize_per_channel_default" [id=9, type=dequantize_per_channel];
"10 conv2d_1" [id=10, type=conv2d];
"11 _tensor_constant0" [id=11, type=get_attr];
"12 add_" [id=12, type=add_];
"13 _tensor_constant0_1" [id=13, type=get_attr];
"14 add__1" [id=14, type=add_];
"15 add" [id=15, type=add];
"16 _param_constant4" [id=16, type=get_attr];
"17 _param_constant5" [id=17, type=get_attr];
"18 conv2d_2" [id=18, type=conv2d];
"19 _tensor_constant0_2" [id=19, type=get_attr];
"20 add_1" [id=20, type=add];
"21 output" [id=21, type=output];
"0 arg0_1" -> "3 conv2d" [label="(1, 3, 3, 3)", style=solid];
"1 _param_constant0" -> "3 conv2d" [label="(3, 3, 1, 1)", style=solid];
"2 _param_constant1" -> "3 conv2d" [label="(3,)", style=solid];
"3 conv2d" -> "10 conv2d_1" [label="(1, 3, 3, 3)", style=solid];
"3 conv2d" -> "12 add_" [label="(1, 3, 3, 3)", style=solid];
"4 _param_constant2" -> "8 quantize_per_channel_default" [label="(3, 3, 1, 1)", style=solid];
"5 _param_constant3" -> "10 conv2d_1" [label="(3,)", style=solid];
"6 conv2d_1_scale_0" -> "8 quantize_per_channel_default" [label="(1,)", style=solid];
"6 conv2d_1_scale_0" -> "9 dequantize_per_channel_default" [label="(1,)", style=solid];
"7 conv2d_1_zero_point_0" -> "8 quantize_per_channel_default" [label="(1,)", style=solid];
"7 conv2d_1_zero_point_0" -> "9 dequantize_per_channel_default" [label="(1,)", style=solid];
"8 quantize_per_channel_default" -> "9 dequantize_per_channel_default" [label="(3, 3, 1, 1)", style=solid];
"9 dequantize_per_channel_default" -> "10 conv2d_1" [label="(3, 3, 1, 1)", style=solid];
"10 conv2d_1" -> "14 add__1" [label="(1, 3, 3, 3)", style=solid];
"11 _tensor_constant0" -> "12 add_" [label="(1,)", style=solid];
"12 add_" -> "15 add" [label="(1, 3, 3, 3)", style=solid];
"13 _tensor_constant0_1" -> "14 add__1" [label="(1,)", style=solid];
"14 add__1" -> "15 add" [label="(1, 3, 3, 3)", style=solid];
"15 add" -> "18 conv2d_2" [label="(1, 3, 3, 3)", style=solid];
"16 _param_constant4" -> "18 conv2d_2" [label="(3, 3, 1, 1)", style=solid];
"17 _param_constant5" -> "18 conv2d_2" [label="(3,)", style=solid];
"18 conv2d_2" -> "20 add_1" [label="(1, 3, 3, 3)", style=solid];
"19 _tensor_constant0_2" -> "20 add_1" [label="(1,)", style=solid];
"20 add_1" -> "21 output" [label="(1, 3, 3, 3)", style=solid];
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
strict digraph {
"0 arg0_1" [id=0, type=input];
"1 _param_constant0" [id=1, type=get_attr];
"2 _param_constant1" [id=2, type=get_attr];
"3 conv2d" [id=3, type=conv2d];
"4 _param_constant2" [id=4, type=get_attr];
"5 _param_constant3" [id=5, type=get_attr];
"6 quantize_per_tensor_default" [id=6, type=quantize_per_tensor];
"7 dequantize_per_tensor_default" [id=7, type=dequantize_per_tensor];
"8 conv2d_1" [id=8, type=conv2d];
"9 _tensor_constant0" [id=9, type=get_attr];
"10 add_" [id=10, type=add_];
"11 _tensor_constant0_1" [id=11, type=get_attr];
"12 add__1" [id=12, type=add_];
"13 add" [id=13, type=add];
"14 _param_constant4" [id=14, type=get_attr];
"15 _param_constant5" [id=15, type=get_attr];
"16 conv2d_2" [id=16, type=conv2d];
"17 _tensor_constant0_2" [id=17, type=get_attr];
"18 add_1" [id=18, type=add];
"19 output" [id=19, type=output];
"0 arg0_1" -> "3 conv2d" [label="(1, 3, 3, 3)", style=solid];
"1 _param_constant0" -> "3 conv2d" [label="(3, 3, 1, 1)", style=solid];
"2 _param_constant1" -> "3 conv2d" [label="(3,)", style=solid];
"3 conv2d" -> "8 conv2d_1" [label="(1, 3, 3, 3)", style=solid];
"3 conv2d" -> "10 add_" [label="(1, 3, 3, 3)", style=solid];
"4 _param_constant2" -> "6 quantize_per_tensor_default" [label="(3, 3, 1, 1)", style=solid];
"5 _param_constant3" -> "8 conv2d_1" [label="(3,)", style=solid];
"6 quantize_per_tensor_default" -> "7 dequantize_per_tensor_default" [label="(3, 3, 1, 1)", style=solid];
"7 dequantize_per_tensor_default" -> "8 conv2d_1" [label="(3, 3, 1, 1)", style=solid];
"8 conv2d_1" -> "12 add__1" [label="(1, 3, 3, 3)", style=solid];
"9 _tensor_constant0" -> "10 add_" [label="(1,)", style=solid];
"10 add_" -> "13 add" [label="(1, 3, 3, 3)", style=solid];
"11 _tensor_constant0_1" -> "12 add__1" [label="(1,)", style=solid];
"12 add__1" -> "13 add" [label="(1, 3, 3, 3)", style=solid];
"13 add" -> "16 conv2d_2" [label="(1, 3, 3, 3)", style=solid];
"14 _param_constant4" -> "16 conv2d_2" [label="(3, 3, 1, 1)", style=solid];
"15 _param_constant5" -> "16 conv2d_2" [label="(3,)", style=solid];
"16 conv2d_2" -> "18 add_1" [label="(1, 3, 3, 3)", style=solid];
"17 _tensor_constant0_2" -> "18 add_1" [label="(1,)", style=solid];
"18 add_1" -> "19 output" [label="(1, 3, 3, 3)", style=solid];
}
Loading

0 comments on commit c33e6ca

Please sign in to comment.