Skip to content

Commit

Permalink
fix mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
Asthestarsfalll committed Nov 25, 2024
1 parent 940f2b6 commit ac65b01
Show file tree
Hide file tree
Showing 16 changed files with 117 additions and 98 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/mypy.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,4 @@ jobs:
run: poetry install --with dev

- name: Mypy Check
run: poetry run mypy excore/ --install-types
run: yes y | poetry run mypy excore/ --install-types
2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ jobs:
- name: Test with pytest
run: |
cd ./tests
poetry run python check.py
poetry run pytest --cov=../excore
poetry run python init.py
poetry run pytest test_config.py
poetry run pytest test_config.py
poetry run pytest test_config.py
Expand Down
11 changes: 5 additions & 6 deletions excore/_misc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import functools
from typing import Any, Callable
from typing import Any, Callable, Sequence

from tabulate import tabulate

Expand All @@ -21,14 +21,13 @@ def _cache(self) -> Any:


def _create_table(
header: str | list[str] | tuple[str, ...],
contents: list[str | tuple[str, ...] | list[str]],
split: bool = True,
header: str | list[str] | tuple[str, ...] | None,
contents: Sequence[str] | Sequence[Sequence[str]],
prefix: str = "\n",
**tabel_kwargs: Any,
) -> str:
if split:
contents = [(i,) for i in contents]
if isinstance(contents[0], str):
contents = [(i,) for i in contents] # type: ignore
if header is None:
header = ()
if not isinstance(header, (list, tuple)):
Expand Down
2 changes: 1 addition & 1 deletion excore/cli/_extention.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def generate_typehints(


def _quote(config: str, override: bool) -> None:
config_paths = []
config_paths: list[str] = []

def _get_path(path, paths):
if not os.path.isdir(path):
Expand Down
14 changes: 7 additions & 7 deletions excore/cli/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import sys
from typing import Any

import astor
import astor # type: ignore
import typer
from typer import Argument as CArg

Expand Down Expand Up @@ -35,11 +35,11 @@ def _has_import_excore(node) -> bool:


def _build_ast(name: str) -> ast.Assign:
targets = [ast.Name(name.upper(), ast.Store)]
func = ast.Name("Registry", ast.Load)
targets = [ast.Name(name.upper(), ast.Store)] # type: ignore
func = ast.Name("Registry", ast.Load) # type: ignore
args = [ast.Constant(name)]
value = ast.Call(func, args, [])
return ast.Assign(targets, value)
value = ast.Call(func, args, []) # type: ignore
return ast.Assign(targets, value) # type: ignore


def _generate_registries(entry="__init__"):
Expand Down Expand Up @@ -81,13 +81,13 @@ def _detect_assign(node: ast.AST, definition: list) -> None:
and hasattr(node.value.func, "id")
and node.value.func.id == "Registry"
):
definition.append(node.value.args[0].value)
definition.append(node.value.args[0].value) # type: ignore


def _detect_registy_difinition() -> bool:
target_file = osp.join(workspace.src_dir, "__init__.py")
logger.info("Detect Registry definition in {}", target_file)
definition = []
definition: list[Any] = []
with open(target_file, encoding="UTF-8") as f:
source_code = ast.parse(f.read())
_detect_assign(source_code, definition)
Expand Down
46 changes: 30 additions & 16 deletions excore/config/_json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import sys
from inspect import Parameter, _empty, _ParameterKind, isclass
from types import ModuleType
from typing import Any, Callable, Dict, Sequence, Union, get_args, get_origin
from typing import TYPE_CHECKING, Any, Callable, Dict, Sequence, Union, get_args, get_origin

import toml

Expand All @@ -20,13 +20,27 @@
if sys.version_info >= (3, 10, 0):
from types import NoneType, UnionType
else:
NoneType = type(None)
NoneType = type(None) # type: ignore

# just a placeholder
class UnionType:
class UnionType: # type: ignore
pass


if TYPE_CHECKING:
from typing import TypedDict

from typing_extensions import NotRequired

class Property(TypedDict):
properties: NotRequired[Property]
type: NotRequired[str]
items: NotRequired[dict]
value: NotRequired[str]
description: NotRequired[str]
required: NotRequired[list[str]]


TYPE_MAPPER: dict[type, str] = {
int: "number", # sometimes default value are not accurate
str: "string",
Expand Down Expand Up @@ -95,22 +109,22 @@ def _check(bases) -> bool:
return False


def parse_registry(reg: Registry) -> tuple[dict, dict[str, list[str | int]]]:
props = {
def parse_registry(reg: Registry) -> tuple[Property, dict[str, list[str | int]]]:
props: Property = {
"type": "object",
"properties": {},
}
class_mapping = {}
class_mapping: dict[str, list[str | int]] = {}
for name, item_dir in reg.items():
func = _str_to_target(item_dir)
func = _str_to_target(item_dir) # type: ignore
if isinstance(func, ModuleType):
continue
class_mapping[name] = [inspect.getfile(func), inspect.getsourcelines(func)[1]]
doc_string = func.__doc__
is_hook = isclass(func) and issubclass(func, ConfigArgumentHook)
if isclass(func) and _check(func.__bases__):
func = func.__init__
params = inspect.signature(func).parameters
params = inspect.signature(func).parameters # type: ignore
param_props = {"type": "object", "properties": {}}
if doc_string:
# TODO: parse doc string to each parameters
Expand Down Expand Up @@ -140,7 +154,7 @@ def parse_registry(reg: Registry) -> tuple[dict, dict[str, list[str | int]]]:
param_props["properties"] = items
if required:
param_props["required"] = required
props["properties"][name] = param_props
props["properties"][name] = param_props # type: ignore
return props, class_mapping


Expand All @@ -157,7 +171,7 @@ def _remove_optional(anno):
return anno


def _parse_inner_types(prop: dict, inner_types: Sequence[type]) -> None:
def _parse_inner_types(prop: Property, inner_types: Sequence[type]) -> None:
first_type = inner_types[0]
is_all_the_same = True
for t in inner_types:
Expand All @@ -166,7 +180,7 @@ def _parse_inner_types(prop: dict, inner_types: Sequence[type]) -> None:
prop["items"] = {"type": TYPE_MAPPER.get(first_type)}


def _parse_typehint(prop: dict, anno: type) -> str | None:
def _parse_typehint(prop: Property, anno: type) -> str | None:
potential_type = TYPE_MAPPER.get(anno)
if potential_type is not None:
return potential_type
Expand All @@ -187,16 +201,16 @@ def _parse_typehint(prop: dict, anno: type) -> str | None:
return potential_type or "string"


def parse_single_param(param: Parameter) -> tuple[bool, dict[str, Any]]:
prop = {}
def parse_single_param(param: Parameter) -> tuple[bool, Property]:
prop: Property = {}
anno = param.annotation
potential_type = None

anno = _remove_optional(anno)

# hardcore for torch.optim
if param.default.__class__.__name__ == "_RequiredParameter":
param._default = _empty
param._default = _empty # type: ignore

if isinstance(anno, str):
raise AnnotationsFutureError(
Expand All @@ -212,9 +226,9 @@ def parse_single_param(param: Parameter) -> tuple[bool, dict[str, Any]]:
if isinstance(param.default, (list, tuple)):
types = [type(t) for t in param.default]
_parse_inner_types(prop, types)
elif param._kind is _ParameterKind.VAR_POSITIONAL:
elif param.kind is _ParameterKind.VAR_POSITIONAL:
return False, {"type": "array"}
elif param._kind is _ParameterKind.VAR_KEYWORD:
elif param.kind is _ParameterKind.VAR_KEYWORD:
return False, {"type": "object"}
if anno is _empty and param.default is _empty:
potential_type = "number"
Expand Down
12 changes: 7 additions & 5 deletions excore/config/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
Copy and adapt from mmengine/config/config.py
"""

from __future__ import annotations

import copy
from argparse import Action, ArgumentParser, Namespace
from typing import Any, Sequence, Union
from typing import Any, Sequence

__all__ = ["DictAction"]

Expand Down Expand Up @@ -37,7 +39,7 @@ def __init__(
self._dict = {}

@staticmethod
def _parse_int_float_bool(val: str) -> Union[int, float, bool, Any]:
def _parse_int_float_bool(val: str) -> int | float | bool | Any:
"""parse int/float/bool value in the string."""
try:
return int(val)
Expand All @@ -54,7 +56,7 @@ def _parse_int_float_bool(val: str) -> Union[int, float, bool, Any]:
return val

@staticmethod
def _parse_iterable(val: str) -> Union[list, tuple, Any]:
def _parse_iterable(val: str) -> list | tuple | Any:
"""Parse iterable values in the string.
All elements inside '()' or '[]' are treated as iterable values.
Expand Down Expand Up @@ -137,8 +139,8 @@ def __call__(
self,
parser: ArgumentParser,
namespace: Namespace,
values: Union[str, Sequence[Any], None],
option_string: str = None,
values: str | Sequence[Any] | None,
option_string: str | None = None,
):
"""Parse Variables in string and add them into argparser.
Expand Down
9 changes: 6 additions & 3 deletions excore/config/lazy_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from copy import deepcopy
from typing import Any

from ..engine.hook import ConfigHookManager
from ..engine.hook import ConfigHookManager, Hook
from ..engine.logging import logger
from ..engine.registry import Registry
from .model import ConfigHookNode, InterNode, ModuleWrapper
Expand Down Expand Up @@ -45,9 +45,10 @@ def build_config_hooks(self) -> None:
hooks = []
if hook_cfgs:
_, base = Registry.find(list(hook_cfgs.keys())[0])
assert base is not None, hook_cfgs
reg = Registry.get_registry(base)
for name, params in hook_cfgs.items():
hook = ConfigHookNode.from_str(reg[name], params)()
hook: Hook = ConfigHookNode.from_str(reg[name], params)() # type: ignore
if hook:
hooks.append(hook)
else:
Expand All @@ -64,7 +65,9 @@ def __getattr__(self, __name: str) -> Any:
def build_all(self) -> tuple[ModuleWrapper, dict[str, Any]]:
if not self.__is_parsed__:
self.parse()
module_dict, isolated_dict = ModuleWrapper(), {}
module_dict = ModuleWrapper()
isolated_dict: dict[str, Any] = {}

self.hooks.call_hooks("pre_build", self, module_dict, isolated_dict)
for name in self.target_modules:
if name not in self._config:
Expand Down
26 changes: 14 additions & 12 deletions excore/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@

from .._exceptions import EnvVarParseError, ModuleBuildError, StrToClassError
from .._misc import CacheOut
from ..engine.hook import ConfigArgumentHook
from ..engine.hook import ConfigArgumentHook, Hook
from ..engine.logging import logger
from ..engine.registry import Registry

if TYPE_CHECKING:
from types import FunctionType, ModuleType
from typing import Any, Literal, Self
from typing import Any, Literal

from typing_extensions import Self

NodeClassType = Type[Any]
NodeParams = dict[Any, Any]
Expand All @@ -28,11 +30,11 @@

__all__ = ["silent"]

REUSE_FLAG = "@"
INTER_FLAG = "!"
CLASS_FLAG = "$"
REFER_FLAG = "&"
OTHER_FLAG = ""
REUSE_FLAG: Literal["@"] = "@"
INTER_FLAG: Literal["!"] = "!"
CLASS_FLAG: Literal["$"] = "$"
REFER_FLAG: Literal["&"] = "&"
OTHER_FLAG: Literal[""] = ""

LOG_BUILD_MESSAGE = True
DO_NOT_CALL_KEY = "__no_call__"
Expand Down Expand Up @@ -60,7 +62,7 @@ def _is_special(k: str) -> tuple[str, SpecialFlag]:
pattern = re.compile(r"^([@!$&])(.*)$")
match = pattern.match(k)
if match:
return match.group(2), match.group(1)
return match.group(2), match.group(1) # type: ignore
return k, ""


Expand Down Expand Up @@ -118,7 +120,7 @@ def _instantiate(self, params: NodeParams) -> NodeInstance:
)
return module

def __call__(self, **params: NodeParams) -> NoCallSkipFlag | NodeInstance:
def __call__(self, **params: NodeParams) -> NoCallSkipFlag | NodeInstance: # type: ignore
if self._no_call:
return self
params = self._get_params(**params)
Expand Down Expand Up @@ -171,7 +173,7 @@ class InterNode(ModuleNode):


class ConfigHookNode(ModuleNode):
def __call__(self, **params: NodeParams) -> NodeInstance | ConfigHookSkipFlag:
def __call__(self, **params: NodeParams) -> NodeInstance | ConfigHookSkipFlag | Hook:
if issubclass(self.cls, ConfigArgumentHook):
return None
params = self._get_params(**params)
Expand All @@ -182,14 +184,14 @@ class ReusedNode(InterNode):
priority: int = 3

@CacheOut()
def __call__(self, **params: NodeParams) -> NodeInstance | NoCallSkipFlag:
def __call__(self, **params: NodeParams) -> NodeInstance | NoCallSkipFlag: # type: ignore
return super().__call__(**params)


class ClassNode(InterNode):
priority: int = 1

def __call__(self) -> NodeClassType | FunctionType:
def __call__(self) -> NodeClassType | FunctionType: # type: ignore
return self.cls


Expand Down
Loading

0 comments on commit ac65b01

Please sign in to comment.