Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
rk119 committed Oct 3, 2024
1 parent a26ca16 commit 76a53d7
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 12 deletions.
15 changes: 9 additions & 6 deletions tests/torch/fx/test_min_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import pytest

from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.layer_attributes import BaseLayerAttributes
from nncf.common.graph.transformations.commands import TargetType
from nncf.quantization.algorithms.min_max.backend import MinMaxAlgoBackend
from nncf.quantization.algorithms.min_max.torch_fx_backend import FXMinMaxAlgoBackend
Expand Down Expand Up @@ -61,19 +62,21 @@ def matmul_metatype(self):
return PTLinearMetatype

@staticmethod
def get_conv_node_attrs(weight_port_id: int, weight_shape: Tuple[int]):
def get_conv_node_attrs(weight_port_id: int, weight_shape: Tuple[int]) -> BaseLayerAttributes:
# This method isn't needed for Torch FX backend
pass
return None

@staticmethod
def get_depthwiseconv_node_attrs(weight_port_id: int, weight_shape: Tuple[int]):
def get_depthwiseconv_node_attrs(weight_port_id: int, weight_shape: Tuple[int]) -> BaseLayerAttributes:
# This method isn't needed for Torch FX backend
pass
return None

@staticmethod
def get_matmul_node_attrs(weight_port_id: int, transpose_weight: Tuple[int], weight_shape: Tuple[int]):
def get_matmul_node_attrs(
weight_port_id: int, transpose_weight: Tuple[int], weight_shape: Tuple[int]
) -> BaseLayerAttributes:
# This method isn't needed for Torch FX backend
pass
return None

def test_get_channel_axes_matmul_node_ov_onnx(self):
pytest.skip("Test is not applied for Torch FX backend.")
Expand Down
15 changes: 9 additions & 6 deletions tests/torch/ptq/test_min_max.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import pytest

from nncf.common.graph.graph import NNCFGraph
from nncf.common.graph.layer_attributes import BaseLayerAttributes
from nncf.common.graph.transformations.commands import TargetType
from nncf.quantization.algorithms.min_max.backend import MinMaxAlgoBackend
from nncf.quantization.algorithms.min_max.torch_backend import PTMinMaxAlgoBackend
Expand Down Expand Up @@ -60,19 +61,21 @@ def matmul_metatype(self):
return PTLinearMetatype

@staticmethod
def get_conv_node_attrs(weight_port_id: int, weight_shape: Tuple[int]):
def get_conv_node_attrs(weight_port_id: int, weight_shape: Tuple[int]) -> BaseLayerAttributes:
# This method isn't needed for Torch backend
pass
return None

@staticmethod
def get_depthwiseconv_node_attrs(weight_port_id: int, weight_shape: Tuple[int]):
def get_depthwiseconv_node_attrs(weight_port_id: int, weight_shape: Tuple[int]) -> BaseLayerAttributes:
# This method isn't needed for Torch backend
pass
return None

@staticmethod
def get_matmul_node_attrs(weight_port_id: int, transpose_weight: Tuple[int], weight_shape: Tuple[int]):
def get_matmul_node_attrs(
weight_port_id: int, transpose_weight: Tuple[int], weight_shape: Tuple[int]
) -> BaseLayerAttributes:
# This method isn't needed for Torch backend
pass
return None

def test_get_channel_axes_matmul_node_ov_onnx(self):
pytest.skip("Test is not applied for Torch backend.")
Expand Down

0 comments on commit 76a53d7

Please sign in to comment.