This repository has been archived by the owner on Apr 17, 2023. It is now read-only.

Enable type checks #208

Status: Draft. Wants to merge 6 commits into base: ote.

51 changes: 51 additions & 0 deletions .mypy.ini
@@ -0,0 +1,51 @@
[mypy]
python_version = 3.8
ignore_missing_imports = True
show_error_codes = True
check_untyped_defs = True
strict = True

[mypy-mmdet.core.*]
ignore_errors = True

[mypy-mmdet.datasets.*]
ignore_errors = True

[mypy-mmdet.integration.*]
ignore_errors = True

[mypy-mmdet.models.*]
ignore_errors = True

[mypy-mmdet.ops.*]
ignore_errors = True

[mypy-mmdet.parallel.*]
ignore_errors = True

[mypy-mmdet.utils.*]
ignore_errors = True

[mypy-mmdet]
ignore_errors = True

[mypy-mmdet.version]
ignore_errors = True

[mypy-mmdet.apis]
ignore_errors = True

[mypy-mmdet.apis.train]
ignore_errors = True

[mypy-mmdet.apis.test]
ignore_errors = True

[mypy-mmdet.apis.export]
ignore_errors = True

[mypy-mmdet.apis.inference]
ignore_errors = True

[mypy-mmdet.apis.fake_input]
ignore_errors = True
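
The config above turns strict checking on globally, then waives errors module by module so the existing mmdet code still passes while new code is held to the stricter rules. For illustration only (not part of this PR), a minimal sketch of driving the same check from Python through mypy's documented api.run helper; the 'mmdet' target path is an assumption based on this repository's layout:

# Illustrative sketch: run mypy programmatically with the .mypy.ini added above.
# mypy.api.run returns (stdout, stderr, exit_status).
from mypy import api

stdout, stderr, exit_status = api.run(['--config-file', '.mypy.ini', 'mmdet'])
print(stdout, end='')
if exit_status != 0:
    raise SystemExit('type check failed')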
83 changes: 45 additions & 38 deletions .pre-commit-config.yaml
@@ -1,40 +1,47 @@
repos:
  - repo: https://gitlab.com/pycqa/flake8.git
    rev: 3.8.3
    hooks:
      - id: flake8
  - repo: https://github.com/asottile/seed-isort-config
    rev: v2.2.0
    hooks:
      - id: seed-isort-config
  - repo: https://github.com/timothycrosley/isort
    rev: 4.3.21
    hooks:
      - id: isort
  - repo: https://github.com/pre-commit/mirrors-yapf
    rev: v0.30.0
    hooks:
      - id: yapf
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v3.1.0
    hooks:
      - id: trailing-whitespace
      - id: check-yaml
      - id: end-of-file-fixer
      - id: requirements-txt-fixer
      - id: double-quote-string-fixer
      - id: check-merge-conflict
      - id: fix-encoding-pragma
        args: ["--remove"]
      - id: mixed-line-ending
        args: ["--fix=lf"]
  - repo: https://github.com/jumanjihouse/pre-commit-hooks
    rev: 2.1.4
    hooks:
      - id: markdownlint
        args: ["-r", "~MD002,~MD013,~MD024,~MD029,~MD033,~MD034,~MD036"]
  - repo: https://github.com/myint/docformatter
    rev: v1.3.1
    hooks:
      - id: docformatter
        args: ["--in-place", "--wrap-descriptions", "79"]
repos:
#  - repo: https://gitlab.com/pycqa/flake8.git
#    rev: 3.8.3
#    hooks:
#      - id: flake8
#  - repo: https://github.com/asottile/seed-isort-config
#    rev: v2.2.0
#    hooks:
#      - id: seed-isort-config
#  - repo: https://github.com/timothycrosley/isort
#    rev: 4.3.21
#    hooks:
#      - id: isort
#  - repo: https://github.com/pre-commit/mirrors-yapf
#    rev: v0.30.0
#    hooks:
#      - id: yapf
#  - repo: https://github.com/pre-commit/pre-commit-hooks
#    rev: v3.1.0
#    hooks:
#      - id: trailing-whitespace
#      - id: check-yaml
#      - id: end-of-file-fixer
#      - id: requirements-txt-fixer
#      - id: double-quote-string-fixer
#      - id: check-merge-conflict
#      - id: fix-encoding-pragma
#        args: ["--remove"]
#      - id: mixed-line-ending
#        args: ["--fix=lf"]
#  - repo: https://github.com/jumanjihouse/pre-commit-hooks
#    rev: 2.1.4
#    hooks:
#      - id: markdownlint
#        args: ["-r", "~MD002,~MD013,~MD024,~MD029,~MD033,~MD034,~MD036"]
#  - repo: https://github.com/myint/docformatter
#    rev: v1.3.1
#    hooks:
#      - id: docformatter
#        args: ["--in-place", "--wrap-descriptions", "79"]
  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: 'v0.812'
    hooks:
      - id: mypy
        args: ["--config-file=.mypy.ini"]
        additional_dependencies: [numpy==1.19.5, types-PyYAML, attrs==21.2.*, types-requests, types-Deprecated, types-docutils, types_futures, types-python-dateutil]
        exclude: "ci|configs|data|demo|docker|docs|experiments|resources|tools|.dev_scripts|tests|setup.py"
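
The mypy hook runs in an isolated pre-commit environment, so every stub package and pinned runtime dependency it needs must be listed under additional_dependencies, and the exclude pattern keeps unchecked directories out of its scope. For illustration only, a hypothetical loader showing why the types-PyYAML stubs are on that list; the function and its name are not from this repo:

# Illustrative sketch: without the types-PyYAML stubs declared above, strict
# mypy reports missing library stubs for the yaml import.
from typing import Any, Dict

import yaml

def load_config(path: str) -> Dict[str, Any]:
    with open(path) as f:
        data: Dict[str, Any] = yaml.safe_load(f)
        return data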
32 changes: 16 additions & 16 deletions mmdet/apis/ote/apis/detection/__init__.py
@@ -24,19 +24,19 @@
from .train_task import OTEDetectionTrainingTask

__all__ = [
config_from_string,
config_to_string,
generate_label_schema,
get_task_class,
load_template,
OpenVINODetectionTask,
OTEDetectionConfig,
OTEDetectionInferenceTask,
OTEDetectionNNCFTask,
OTEDetectionTrainingTask,
patch_config,
prepare_for_testing,
prepare_for_training,
save_config_to_file,
set_hyperparams,
]
'config_from_string',
'config_to_string',
'generate_label_schema',
'get_task_class',
'load_template',
'OpenVINODetectionTask',
'OTEDetectionConfig',
'OTEDetectionInferenceTask',
'OTEDetectionNNCFTask',
'OTEDetectionTrainingTask',
'patch_config',
'prepare_for_testing',
'prepare_for_training',
'save_config_to_file',
'set_hyperparams',
]
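
Quoting these names matters: __all__ is consumed by name, so its entries have to be strings for star imports to work at runtime and for type checkers to track the re-exports. A minimal illustration with a hypothetical module:

# Illustrative sketch: __all__ must list exported names as strings. With bare
# objects instead, "from package import *" raises TypeError (attribute name
# must be string) and static analysis of the re-exports breaks down.
class ExampleTask:
    pass

__all__ = ['ExampleTask']  # the exported name, as a string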
4 changes: 2 additions & 2 deletions mmdet/apis/ote/apis/detection/inference_task.py
@@ -221,7 +221,7 @@ def hook(module, input, output):

@staticmethod
def _infer_detector(model: torch.nn.Module, config: Config, dataset: DatasetEntity, dump_features: bool = False,
eval: Optional[bool] = False, metric_name: Optional[str] = 'mAP') -> Tuple[List, float]:
eval: bool = False, metric_name: str = 'mAP') -> Tuple[List, float]:
model.eval()
test_config = prepare_for_testing(config, dataset)
mm_val_dataset = build_dataset(test_config.data.test)
@@ -256,7 +256,7 @@ def dummy_dump_features_hook(mod, inp, out):
result = eval_model(return_loss=False, rescale=True, **data)
eval_predictions.extend(result)

metric = None
metric = 0.0
if eval:
metric = mm_val_dataset.evaluate(eval_predictions, metric=metric_name)[metric_name]

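Two small changes keep this signature truthful: eval and metric_name have non-None defaults, so the Optional wrappers go, and metric now starts at 0.0 so the declared Tuple[List, float] return type holds even when eval is False. For illustration, the same pattern with hypothetical names:

# Illustrative sketch (hypothetical names): non-Optional defaults and a float
# initial value keep the annotated return type accurate on every path.
from typing import List, Tuple

def run_eval(predictions: List[float], eval: bool = False,
             metric_name: str = 'mAP') -> Tuple[List[float], float]:
    metric = 0.0  # not None, so Tuple[..., float] holds even without evaluation
    if eval:
        metric = sum(predictions) / max(len(predictions), 1)  # stand-in for dataset.evaluate()
    return predictions, metric
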
6 changes: 3 additions & 3 deletions mmdet/apis/ote/apis/detection/nncf_task.py
@@ -17,7 +17,7 @@
import logging
import os
from collections import defaultdict
from typing import Optional
from typing import DefaultDict, Optional

import torch
from ote_sdk.configuration import cfg_helper
@@ -189,7 +189,7 @@ def optimize(
else:
update_progress_callback = default_progress_callback
time_monitor = TrainingProgressCallback(update_progress_callback)
learning_curves = defaultdict(OTELoggerHook.Curve)
learning_curves: DefaultDict[str, OTELoggerHook.Curve] = defaultdict(OTELoggerHook.Curve)
training_config = prepare_for_training(config, train_dataset, val_dataset, time_monitor, learning_curves)
mm_train_dataset = build_dataset(training_config.data.train)

@@ -235,7 +235,7 @@ def save_model(self, output_model: ModelEntity):
hyperparams_str = ids_to_strings(cfg_helper.convert(hyperparams, dict, enum_to_str=True))
labels = {label.name: label.color.rgb_tuple for label in self._labels}
modelinfo = {
'compression_state': self._compression_ctrl.get_compression_state(),
'compression_state': self._compression_ctrl.get_compression_state(), # type: ignore # FIXME.
Author comment: @AlexanderDokuchaev Please note.

'meta': {
'config': self._config,
'nncf_enable_compression': True,
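defaultdict(OTELoggerHook.Curve) alone leaves the key type for mypy to guess, so strict checking asks for an explicit annotation; DefaultDict[str, OTELoggerHook.Curve] supplies it. A self-contained sketch of the pattern with a stand-in Curve class:

# Illustrative sketch: without the annotation mypy cannot infer the key type of
# the defaultdict and requests one; the annotation pins it to str.
from collections import defaultdict
from typing import DefaultDict, List

class Curve:  # stand-in for OTELoggerHook.Curve
    def __init__(self) -> None:
        self.x: List[float] = []
        self.y: List[float] = []

learning_curves: DefaultDict[str, Curve] = defaultdict(Curve)
learning_curves['val/mAP'].y.append(0.5)  # a missing key gets a fresh Curve
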
8 changes: 5 additions & 3 deletions mmdet/apis/ote/apis/detection/openvino_task.py
@@ -25,6 +25,7 @@
from compression.graph import load_model, save_model
from compression.graph.model_utils import compress_model_weights, get_nodes_by_type
from compression.pipeline.initializer import create_pipeline
from openvino.inference_engine import ExecutableNetwork
from ote_sdk.entities.annotation import Annotation, AnnotationSceneEntity, AnnotationSceneKind
from ote_sdk.entities.datasets import DatasetEntity
from ote_sdk.entities.inference_parameters import InferenceParameters, default_progress_callback
@@ -52,7 +53,7 @@
logger = logging.getLogger(__name__)


def get_output(net, outputs, name):
def get_output(net: ExecutableNetwork, outputs: Dict[str, np.ndarray], name: str) -> np.ndarray:
try:
key = net.get_ov_name_for_tensor(name)
assert key in outputs, f'"{key}" is not a valid output identifier'
@@ -63,7 +64,8 @@ def get_output(net, outputs, name):
return outputs[key]


def extract_detections(output, net, input_size):
def extract_detections(output: Dict[str, np.ndarray], net: ExecutableNetwork,
input_size: Tuple[int, int]) -> Dict[str, np.ndarray]:
if 'detection_out' in output:
detection_out = output['detection_out']
output['labels'] = detection_out[0, 0, :, 1].astype(np.int32)
@@ -114,7 +116,7 @@ def __init__(
self.confidence_threshold = confidence_threshold

@staticmethod
def resize_image(image: np.ndarray, size: Tuple[int], keep_aspect_ratio: bool = False) -> np.ndarray:
def resize_image(image: np.ndarray, size: Tuple[int, int], keep_aspect_ratio: bool = False) -> np.ndarray:
if not keep_aspect_ratio:
resized_frame = cv2.resize(image, size)
else:
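The size fix is about tuple arity: Tuple[int] describes a one-element tuple, so the old annotation rejected a (width, height) pair, while Tuple[int, int] matches it exactly and Tuple[int, ...] would allow any length. For illustration:

# Illustrative sketch: Tuple[int] is a 1-tuple, Tuple[int, int] a pair,
# Tuple[int, ...] a variable-length homogeneous tuple.
from typing import Tuple

def clamp_size(size: Tuple[int, int]) -> Tuple[int, int]:
    width, height = size  # unpacking type-checks because the arity is fixed
    return max(width, 1), max(height, 1)

clamp_size((640, 480))    # accepted
# clamp_size((640,))      # rejected by mypy: wrong tuple length
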
6 changes: 3 additions & 3 deletions mmdet/apis/ote/apis/detection/train_task.py
@@ -18,7 +18,7 @@
import os
from collections import defaultdict
from glob import glob
from typing import List, Optional
from typing import DefaultDict, Dict, List, Optional

import torch
from ote_sdk.configuration import cfg_helper
@@ -45,7 +45,7 @@

class OTEDetectionTrainingTask(OTEDetectionInferenceTask, ITrainingTask):

def _generate_training_metrics(self, learning_curves, map) -> Optional[List[MetricsGroup]]:
def _generate_training_metrics(self, learning_curves: Dict[str, OTELoggerHook.Curve], map: float) -> List[MetricsGroup]:
"""
Parses the mmdetection logs to get metrics from the latest training run

@@ -97,7 +97,7 @@ def train(self, dataset: DatasetEntity, output_model: ModelEntity, train_paramet
if train_parameters is not None:
update_progress_callback = train_parameters.update_progress
time_monitor = TrainingProgressCallback(update_progress_callback)
learning_curves = defaultdict(OTELoggerHook.Curve)
learning_curves: DefaultDict[str, OTELoggerHook.Curve] = defaultdict(OTELoggerHook.Curve)
training_config = prepare_for_training(config, train_dataset, val_dataset, time_monitor, learning_curves)
self._training_work_dir = training_config.work_dir
mm_train_dataset = build_dataset(training_config.data.train)
2 changes: 1 addition & 1 deletion mmdet/apis/ote/extension/datasets/__init__.py
@@ -14,4 +14,4 @@

from .mmdataset import OTEDataset, get_annotation_mmdet_format

__all__ = [OTEDataset, get_annotation_mmdet_format]
__all__ = ['OTEDataset', 'get_annotation_mmdet_format']
6 changes: 3 additions & 3 deletions mmdet/apis/ote/extension/datasets/mmdataset.py
@@ -13,7 +13,7 @@
# and limitations under the License.

from copy import deepcopy
from typing import List
from typing import List, Sequence, Optional

import numpy as np
from ote_sdk.entities.dataset_item import DatasetItemEntity
@@ -88,7 +88,7 @@ class _DataInfoProxy:
forwards data access operations to ote_dataset and converts the dataset items to the view
convenient for mmdetection.
"""
def __init__(self, ote_dataset, classes):
def __init__(self, ote_dataset: DatasetEntity, classes: Optional[Sequence[str]]):
self.ote_dataset = ote_dataset
self.CLASSES = classes

@@ -113,7 +113,7 @@ def __getitem__(self, index):
return data_info

def __init__(self, ote_dataset: DatasetEntity, pipeline, classes=None, test_mode: bool = False):
self.ote_dataset = ote_dataset
self.ote_dataset: DatasetEntity = ote_dataset
self.test_mode = test_mode
self.CLASSES = self.get_classes(classes)

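Annotating classes as Optional[Sequence[str]] records that callers may legitimately pass None and that any read-only sequence of names, not only a list, is acceptable. A standalone sketch of the pattern with stand-in types, not the real _DataInfoProxy:

# Illustrative sketch: Optional[...] makes the None case explicit, so mypy
# forces a None check before the class names are used.
from typing import List, Optional, Sequence

class LabelProxy:
    def __init__(self, labels: List[int], classes: Optional[Sequence[str]]) -> None:
        self.labels = labels
        self.classes: List[str] = list(classes) if classes is not None else []

    def label_name(self, index: int) -> str:
        label = self.labels[index]
        return self.classes[label] if label < len(self.classes) else 'unknown'
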
6 changes: 3 additions & 3 deletions mmdet/apis/ote/extension/utils/__init__.py
@@ -17,6 +17,6 @@
from .pipelines import LoadImageFromOTEDataset, LoadAnnotationFromOTEDataset
from .runner import EpochRunnerWithCancel

__all__ = [CancelTrainingHook, FixedMomentumUpdaterHook, LoadImageFromOTEDataset, EpochRunnerWithCancel,
LoadAnnotationFromOTEDataset, OTELoggerHook, OTEProgressHook, EarlyStoppingHook,
ReduceLROnPlateauLrUpdaterHook]
__all__ = ['CancelTrainingHook', 'FixedMomentumUpdaterHook', 'LoadImageFromOTEDataset', 'EpochRunnerWithCancel',
'LoadAnnotationFromOTEDataset', 'OTELoggerHook', 'OTEProgressHook', 'EarlyStoppingHook',
'ReduceLROnPlateauLrUpdaterHook']
2 changes: 1 addition & 1 deletion tests/conftest.py
@@ -10,7 +10,7 @@
_e2e_pytest_addoption = None
pass

import config
import e2e_config
Author comment: @LeonidBeynenson Please note.


def pytest_addoption(parser):
if _e2e_pytest_addoption:
File renamed without changes.
10 changes: 4 additions & 6 deletions tests/test_ote_training.py
@@ -21,7 +21,7 @@
from collections import namedtuple, OrderedDict
from copy import deepcopy
from pprint import pformat
from typing import List, Optional, Union
from typing import List, Optional, Tuple, Union

import pytest
import yaml
@@ -34,10 +34,8 @@
from ote_sdk.entities.model import (
ModelEntity,
ModelFormat,
ModelPrecision,
ModelStatus,
ModelOptimizationType,
OptimizationMethod,
)
from ote_sdk.entities.model_template import parse_model_template, TargetDevice
from ote_sdk.entities.optimization_parameters import OptimizationParameters
@@ -53,7 +51,7 @@

logger = logging.getLogger(__name__)

def DATASET_PARAMETERS_FIELDS():
def DATASET_PARAMETERS_FIELDS() -> Tuple[str, ...]:
return ('annotations_train',
'images_train_dir',
'annotations_val',
@@ -63,7 +61,7 @@ def DATASET_PARAMETERS_FIELDS():
)

ROOT_PATH_KEY = '_root_path'
DatasetParameters = namedtuple('DatasetParameters', DATASET_PARAMETERS_FIELDS())
DatasetParameters = namedtuple('DatasetParameters', DATASET_PARAMETERS_FIELDS()) # type: ignore
Author comment: @LeonidBeynenson Please note.


@pytest.fixture
def dataset_definitions_fx(request):
@@ -659,7 +657,7 @@ class OTETestStage:
time the stage is called the exception is re-raised.
"""
def __init__(self, action: BaseOTETestAction,
depends_stages: Optional[List['OTETestStage']]=None):
depends_stages: Optional[List['OTETestStage']] = None):
self.was_processed = False
self.stored_exception = None
self.action = action
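The # type: ignore is needed because the namedtuple field list is built by a function call, which mypy cannot evaluate statically; this PR keeps the dynamic form. An alternative, shown only for illustration and listing just the fields visible in this diff, would be a statically declared typing.NamedTuple:

# Illustrative alternative, not part of this PR: static fields need no ignore.
from typing import NamedTuple

class DatasetParameters(NamedTuple):
    annotations_train: str
    images_train_dir: str
    annotations_val: str
    # ... the remaining dataset-path fields from DATASET_PARAMETERS_FIELDS()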