Skip to content

Commit

Permalink
Add ruff fix for unused variables
Browse files Browse the repository at this point in the history
  • Loading branch information
anzr299 committed Jul 8, 2024
1 parent b56a3a4 commit 5c0d504
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 16 deletions.
1 change: 1 addition & 0 deletions nncf/torch/graph/operator_metatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1100,6 +1100,7 @@ class PTScaledDotProductAttentionMetatype(PTOperatorMetatype):
hw_config_names = [HWConfigOpName.SCALED_DOT_PRODUCT_ATTENTION]
target_input_ports = [0, 1]


@PT_OPERATOR_METATYPES.register()
class PTEmptyMetatype(PTOperatorMetatype):
name = "EmptyOP"
Expand Down
30 changes: 14 additions & 16 deletions tests/torch/fx/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,7 @@


from dataclasses import dataclass
from typing import Tuple

import os

import numpy as np
import openvino.torch # noqa
import pytest
import torch
Expand All @@ -29,12 +25,10 @@
from torch._export import capture_pre_autograd_graph

import nncf
from nncf.common.logging.track_progress import track
from nncf.common.graph.operator_metatypes import UnknownMetatype
from nncf.torch.dynamic_graph.patch_pytorch import disable_patching
from tests.torch.fx.helpers import TinyImagenetDatasetManager

from nncf.torch.graph.operator_metatypes import PT_OPERATOR_METATYPES
from nncf.common.graph.operator_metatypes import UnknownMetatype
from tests.torch.fx.helpers import TinyImagenetDatasetManager

IMAGE_SIZE = 64
BATCH_SIZE = 128
Expand All @@ -54,12 +48,14 @@ class ModelCase:
MODELS = (
ModelCase(
"resnet18",
"https://storage.openvinotoolkit.org/repositories/nncf/openvino_notebook_ckpts/302_resnet18_fp32_v1.pth"
"https://storage.openvinotoolkit.org/repositories/nncf/openvino_notebook_ckpts/302_resnet18_fp32_v1.pth",
),
)


def get_model(model_id: str, checkpoint_url: str, device: torch.device, num_classes: int = 200, in_features: int = 512) -> torch.nn.Module:
def get_model(
model_id: str, checkpoint_url: str, device: torch.device, num_classes: int = 200, in_features: int = 512
) -> torch.nn.Module:
model = getattr(models, model_id)(weights=None)
# Update the last FC layer for Tiny ImageNet number of classes.
model.fc = nn.Linear(in_features=in_features, out_features=num_classes, bias=True)
Expand All @@ -73,14 +69,16 @@ def getNodeType(node: torch.fx.node) -> str:
if node.op == "call_function" and hasattr(node.target, "overloadpacket"):
node_type = str(node.target).split(".")[1]
return node_type
return ''
return ""


def isNodeMetatype(node_type: str) -> bool:
op_type = PT_OPERATOR_METATYPES.get_operator_metatype_by_op_name(node_type)
if(op_type is UnknownMetatype):
if op_type is UnknownMetatype:
return False
return True


def retrieve_nodes(model: torch.fx.GraphModule):
for node in model.graph.nodes:
yield node
Expand All @@ -92,7 +90,7 @@ def test_sanity(test_case: ModelCase, tiny_imagenet_dataset):
torch.manual_seed(42)
device = torch.device("cpu")
model = get_model(test_case.model_id, test_case.checkpoint_url, device)
_, val_dataloader, calibration_dataset = tiny_imagenet_dataset
_, _, calibration_dataset = tiny_imagenet_dataset

def transform_fn(data_item):
return data_item[0].to(device)
Expand All @@ -105,6 +103,6 @@ def transform_fn(data_item):
exported_model = capture_pre_autograd_graph(model, args=(ex_input,))
nodes = retrieve_nodes(exported_model)
for node in nodes:
node_type = getNodeType(node)
if(node_type):
assert isNodeMetatype(node_type)
node_type = getNodeType(node)
if node_type:
assert isNodeMetatype(node_type)

0 comments on commit 5c0d504

Please sign in to comment.