diff --git a/.github/workflows/codestyle-check.yml b/.github/workflows/codestyle-check.yml index 2720b14..7187e75 100644 --- a/.github/workflows/codestyle-check.yml +++ b/.github/workflows/codestyle-check.yml @@ -9,7 +9,11 @@ on: jobs: code-style-check: runs-on: ubuntu-latest - name: CodeStyle Check (Python 3.8) + strategy: + matrix: + python-version: [3.8, 3.9, '3.10'] + + name: CodeStyle Check ${{ matrix.python-version }} steps: - name: Checkout uses: actions/checkout@v3 @@ -17,7 +21,7 @@ jobs: - name: Install python uses: actions/setup-python@v4 with: - python-version: 3.8 + python-version: ${{ matrix.python-version }} - name: Install dependencies run: | diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml new file mode 100644 index 0000000..3aa6e62 --- /dev/null +++ b/.github/workflows/mypy.yaml @@ -0,0 +1,35 @@ +name: MYPY Check + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + test: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.8, 3.9, '3.10'] + + steps: + - uses: actions/checkout@v2 + with: + fetch-depth: 1 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + + - name: Install Poetry + uses: snok/install-poetry@v1 + with: + version: 1.6.1 + + - name: Install Dependencies + run: poetry install --with dev + + - name: Mypy Check + run: poetry run mypy excore diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index a6b9a79..fc788a5 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -31,12 +31,17 @@ jobs: - name: Install Dependencies run: poetry install --with dev - - name: Test with pytest + - name: Initialize ExCore run: | cd ./tests export EXCORE_DEBUG=1 poetry run python init.py + + - name: Test with pytest + run: | + cd ./tests poetry run pytest --cov=../excore + poetry run python init.py poetry run pytest test_config.py poetry run pytest test_config.py poetry run pytest test_config.py diff --git a/excore/__init__.py b/excore/__init__.py index ff27640..07a9b44 100644 --- a/excore/__init__.py +++ b/excore/__init__.py @@ -3,7 +3,7 @@ from rich.traceback import install from . import config, plugins -from ._constants import __author__, __version__, _load_workspace_config, _workspace_cfg +from ._constants import __author__, __version__, workspace from .config.action import DictAction from .config.config import build_all, load from .config.parse import set_primary_fields @@ -43,7 +43,6 @@ install() init_logger() -_load_workspace_config() -set_primary_fields(_workspace_cfg) +set_primary_fields(workspace) _enable_excore_debug() -sys.path.append(_workspace_cfg["base_dir"]) +sys.path.append(workspace.base_dir) diff --git a/excore/_constants.py b/excore/_constants.py index d14333b..3e44f02 100644 --- a/excore/_constants.py +++ b/excore/_constants.py @@ -1,60 +1,80 @@ +from __future__ import annotations + import os import os.path as osp +from dataclasses import dataclass, field +from typing import Any + +import toml from .engine.logging import logger __author__ = "Asthestarsfalll" __version__ = "0.1.1beta" -_cache_base_dir = osp.expanduser("~/.cache/excore/") _workspace_config_file = "./.excore.toml" _registry_cache_file = "registry_cache.pkl" _json_schema_file = "excore_schema.json" _class_mapping_file = "class_mapping.json" -def _load_workspace_config(): - if osp.exists(_workspace_config_file): - _workspace_cfg.update(toml.load(_workspace_config_file)) - logger.ex("load `.excore.toml`") - else: - logger.warning("Please use `excore init` in your command line first") +@dataclass +class _WorkspaceConfig: + name: str = field(default="") + src_dir: str = field(default="") + base_dir: str = field(default="") + cache_base_dir: str = field(default=osp.expanduser("~/.cache/excore/")) + cache_dir: str = field(default="") + registry_cache_file: str = field(default="") + json_schema_file: str = field(default="") + class_mapping_file: str = field(default="") + registries: list[str] = field(default_factory=list) + primary_fields: list[str] = field(default_factory=list) + primary_to_registry: dict[str, str] = field(default_factory=dict) + json_schema_fields: dict[str, str | list[str]] = field(default_factory=dict) + props: dict[Any, Any] = field(default_factory=dict) + @property + def base_name(self): + return osp.split(self.cache_dir)[-1] -def _update_name(base_name): - name = base_name + def __post_init__(self) -> None: + if not osp.exists(_workspace_config_file): + self.base_dir = os.getcwd() + self.cache_dir = self._get_cache_dir() + self.registry_cache_file = osp.join(self.cache_dir, _registry_cache_file) + self.json_schema_file = osp.join(self.cache_dir, _json_schema_file) + self.class_mapping_file = osp.join(self.cache_dir, _class_mapping_file) + logger.warning("Please use `excore init` in your command line first") + else: + self.update(toml.load(_workspace_config_file)) - suffix = 1 - while osp.exists(osp.join(_cache_base_dir, name)): - name = f"{_base_name}_{suffix}" - suffix += 1 + def _get_cache_dir(self) -> str: + base_name = osp.basename(osp.normpath(os.getcwd())) + base_name = self._update_name(base_name) + return osp.join(self.cache_base_dir, base_name) - return name + def _update_name(self, base_name: str) -> str: + name = base_name + suffix = 1 + while osp.exists(osp.join(self.cache_base_dir, name)): + name = f"{base_name}_{suffix}" + suffix += 1 -if not osp.exists(_workspace_config_file): - _base_name = osp.basename(osp.normpath(os.getcwd())) - _base_name = _update_name(_base_name) -else: - import toml # pylint: disable=import-outside-toplevel + return name - cfg = toml.load(_workspace_config_file) - _base_name = cfg["name"] + def update(self, _cfg: dict[Any, Any]) -> None: + self.__dict__.update(_cfg) -_cache_dir = osp.join(_cache_base_dir, _base_name) + def dump(self, path: str) -> None: + with open(path, "w") as f: + cfg = self.__dict__ + cfg.pop("base_dir", None) + toml.dump(cfg, f) -# TODO: Use a data class to store this -_workspace_cfg = dict( - name="", - src_dir="", - base_dir=os.getcwd(), - registries=[], - primary_fields=[], - primary_to_registry={}, - json_schema_fields={}, - props={}, -) +workspace = _WorkspaceConfig() LOGO = r""" ▓█████ ▒██ ██▒ ▄████▄ ▒█████ ██▀███ ▓█████ ▓█ ▀ ▒▒ █ █ ▒░▒██▀ ▀█ ▒██▒ ██▒▓██ ▒ ██▒▓█ ▀ diff --git a/excore/_exceptions.py b/excore/_exceptions.py index a5f2101..7ce6cb8 100644 --- a/excore/_exceptions.py +++ b/excore/_exceptions.py @@ -56,3 +56,7 @@ class HookManagerBuildError(BaseException): class HookBuildError(BaseException): pass + + +class AnnotationsFutureError(Exception): + pass diff --git a/excore/_misc.py b/excore/_misc.py index 03812ac..0fb0eb5 100644 --- a/excore/_misc.py +++ b/excore/_misc.py @@ -1,14 +1,15 @@ +from __future__ import annotations + import functools -import threading -import time +from typing import Any, Callable, Sequence from tabulate import tabulate class CacheOut: - def __call__(self, func): + def __call__(self, func: Callable[..., Any]): @functools.wraps(func) - def _cache(self): + def _cache(self) -> Any: if not hasattr(self, "cached_elem"): cached_elem = func(self) if cached_elem != self: @@ -19,27 +20,14 @@ def _cache(self): return _cache -class FileLock: - def __init__(self, file_path, timeout=15): - self.file_path = file_path - self.timeout = timeout - self.lock = threading.Lock() - - def __enter__(self): - start_time = time.time() - while not self.lock.acquire(False): - if time.time() - start_time >= self.timeout: - raise TimeoutError("Failed to acquire lock on file") - time.sleep(0.1) - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.lock.release() - - -def _create_table(header, contents, split=True, prefix="\n", **tabel_kwargs): - if split: - contents = [(i,) for i in contents] +def _create_table( + header: str | list[str] | tuple[str, ...] | None, + contents: Sequence[str] | Sequence[Sequence[str]], + prefix: str = "\n", + **tabel_kwargs: Any, +) -> str: + if len(contents) > 0 and isinstance(contents[0], str): + contents = [(i,) for i in contents] # type: ignore if header is None: header = () if not isinstance(header, (list, tuple)): diff --git a/excore/cli/_cache.py b/excore/cli/_cache.py index 1516e60..743a95b 100644 --- a/excore/cli/_cache.py +++ b/excore/cli/_cache.py @@ -1,14 +1,17 @@ +from __future__ import annotations + import os import typer -from .._constants import _base_name, _cache_base_dir, _cache_dir +from excore import workspace + from .._misc import _create_table from ..engine.logging import logger from ._app import app -def _clear_cache(cache_dir): +def _clear_cache(cache_dir: str) -> None: if os.path.exists(cache_dir): import shutil # pylint: disable=import-outside-toplevel @@ -19,38 +22,41 @@ def _clear_cache(cache_dir): @app.command() -def clear_cache(): +def clear_cache() -> None: """ Remove the cache folder which belongs to current workspace. """ - if not typer.confirm(f"Are you sure you want to clear cache of {_base_name}?"): + if not typer.confirm( + f"Are you sure you want to clear cache of {workspace.name}?" + f" Cache dir is {workspace.cache_dir}." + ): return - - target = os.path.join(_cache_dir, _base_name) - _clear_cache(target) + _clear_cache(workspace.cache_dir) @app.command() -def clear_all_cache(): +def clear_all_cache() -> None: """ Remove the whole cache folder. """ + if not os.path.exists(workspace.cache_base_dir): + logger.warning("Cache dir {} does not exist", workspace.cache_base_dir) + return + print(_create_table("Cache Names", os.listdir(workspace.cache_base_dir))) if not typer.confirm("Are you sure you want to clear all cache?"): return - _clear_cache(_cache_base_dir) + _clear_cache(workspace.cache_base_dir) @app.command() -def cache_list(): +def cache_list() -> None: """ Show cache folders. """ - tabel = _create_table("NAMES", os.listdir(_cache_base_dir)) + tabel = _create_table("Cache Names", os.listdir(workspace.cache_base_dir)) logger.info(tabel) @app.command() -def cache_dir(): - # if not os.path.exists(_workspace_config_file): - # raise RuntimeError("Not in ExCore project") - print(_cache_dir) +def cache_dir() -> None: + print(workspace.cache_dir) diff --git a/excore/cli/_extention.py b/excore/cli/_extention.py index ac7253a..f6ae725 100644 --- a/excore/cli/_extention.py +++ b/excore/cli/_extention.py @@ -1,42 +1,44 @@ +from __future__ import annotations + import os import sys -from typing import Optional from typer import Argument as CArg from typer import Option as COp from typing_extensions import Annotated -from .._constants import _cache_base_dir, _workspace_cfg +from excore import workspace + from ..config._json_schema import _generate_json_schema_and_class_mapping, _generate_taplo_config from ..engine.logging import logger from ._app import app @app.command() -def config_extention(): +def config_extention() -> None: """ Generate json_schema for onfig completion and class_mapping for class navigation. """ - target_dir = os.path.join(_cache_base_dir, _workspace_cfg["name"]) - os.makedirs(target_dir, exist_ok=True) - _generate_taplo_config(target_dir) - if not _workspace_cfg["json_schema_fields"]: + _generate_taplo_config() + if not workspace.json_schema_fields: logger.warning("You should set json_schema_fields first") sys.exit(0) - _generate_json_schema_and_class_mapping(_workspace_cfg["json_schema_fields"]) + _generate_json_schema_and_class_mapping(workspace.json_schema_fields) -def _generate_typehints(entry: str, class_name: str, info_class_name: str, config: Optional[str]): - if not _workspace_cfg["primary_fields"]: +def _generate_typehints( + entry: str, class_name: str, info_class_name: str, config: str = "" +) -> None: + if not workspace.primary_fields: logger.critical("Please initialize the workspace first.") return - target_file = os.path.join(_workspace_cfg["src_dir"], entry + ".py") + target_file = os.path.join(workspace.src_dir, entry + ".py") logger.info(f"Generating module type hints in {target_file}.") with open(target_file, "w", encoding="UTF-8") as f: f.write(f"from typing import Union{', Any' if config else ''}\n") f.write("from excore.config.model import ModuleNode, ModuleWrapper\n\n") f.write(f"class {class_name}:\n") - for i in _workspace_cfg["primary_fields"]: + for i in workspace.primary_fields: f.write(f" {i}: Union[ModuleNode, ModuleWrapper]\n") logger.info(f"Generating isolated objects type hints in {target_file}.") if config: @@ -53,21 +55,19 @@ def _generate_typehints(entry: str, class_name: str, info_class_name: str, confi @app.command() def generate_typehints( - entry: Annotated[str, CArg(help="The file to generate.")] = "module_types", + entry: str = CArg(default="module_types", help="The file to generate."), class_name: Annotated[str, COp(help="The class name of type hints.")] = "TypedModules", info_class_name: Annotated[str, COp(help="The class name of run_info.")] = "RunInfo", - config: Annotated[ - Optional[str], COp(help="Used generate type hints for isolated objects.") - ] = None, -): + config: Annotated[str, COp(help="Used generate type hints for isolated objects.")] = "", +) -> None: """ Generate type hints for modules and isolated objects. """ _generate_typehints(entry, class_name, info_class_name, config) -def _quote(config: str, override: bool): - config_paths = [] +def _quote(config: str, override: bool) -> None: + config_paths: list[str] = [] def _get_path(path, paths): if not os.path.isdir(path): @@ -94,7 +94,7 @@ def _get_path(path, paths): def quote( config: Annotated[str, CArg(help="Target config file or folder.")], override: Annotated[bool, COp(help="Whether to override configs.")] = False, -): +) -> None: """ Quote all special keys in target config files. """ diff --git a/excore/cli/_registry.py b/excore/cli/_registry.py index 2c5f6e5..6719921 100644 --- a/excore/cli/_registry.py +++ b/excore/cli/_registry.py @@ -1,22 +1,24 @@ +from __future__ import annotations + import ast import importlib import os import os.path as osp import sys +from typing import Any -import astor +import astor # type: ignore import typer from typer import Argument as CArg -from typing_extensions import Annotated -from .._constants import _cache_base_dir, _workspace_cfg, _workspace_config_file +from .._constants import _workspace_config_file, workspace from .._misc import _create_table from ..engine.logging import logger from ..engine.registry import Registry from ._app import app -def _has_import_excore(node): +def _has_import_excore(node) -> bool: if isinstance(node, ast.Module): for child in node.body: f = _has_import_excore(child) @@ -33,18 +35,18 @@ def _has_import_excore(node): def _build_ast(name: str) -> ast.Assign: - targets = [ast.Name(name.upper(), ast.Store)] - func = ast.Name("Registry", ast.Load) + targets = [ast.Name(name.upper(), ast.Store)] # type: ignore + func = ast.Name("Registry", ast.Load) # type: ignore args = [ast.Constant(name)] - value = ast.Call(func, args, []) - return ast.Assign(targets, value) + value = ast.Call(func, args, []) # type: ignore + return ast.Assign(targets, value) # type: ignore def _generate_registries(entry="__init__"): - if not _workspace_cfg["primary_fields"]: + if not workspace.primary_fields: return logger.info("Generating Registry definition code.") - target_file = osp.join(_workspace_cfg["src_dir"], entry + ".py") + target_file = osp.join(workspace.src_dir, entry + ".py") if not osp.exists(target_file): with open(target_file, "w", encoding="UTF-8") as f: @@ -59,7 +61,7 @@ def _generate_registries(entry="__init__"): name = [ast.alias("Registry", None)] source_code.body.insert(0, ast.ImportFrom("excore", name, 0)) - for name in _get_registries(_workspace_cfg["registries"]): + for name in _get_registries(workspace.registries): if name.startswith("*"): name = name[1:] source_code.body.append(_build_ast(name)) @@ -69,7 +71,7 @@ def _generate_registries(entry="__init__"): logger.success("Generate Registry definition in {} according to `primary_fields`", target_file) -def _detect_assign(node, definition): +def _detect_assign(node: ast.AST, definition: list) -> None: if isinstance(node, ast.Module): for child in node.body: _detect_assign(child, definition) @@ -79,19 +81,19 @@ def _detect_assign(node, definition): and hasattr(node.value.func, "id") and node.value.func.id == "Registry" ): - definition.append(node.value.args[0].value) + definition.append(node.value.args[0].value) # type: ignore def _detect_registy_difinition() -> bool: - target_file = osp.join(_workspace_cfg["src_dir"], "__init__.py") + target_file = osp.join(workspace.src_dir, "__init__.py") logger.info("Detect Registry definition in {}", target_file) - definition = [] + definition: list[Any] = [] with open(target_file, encoding="UTF-8") as f: source_code = ast.parse(f.read()) _detect_assign(source_code, definition) if len(definition) > 0: logger.info("Find Registry definition: {}", definition) - _workspace_cfg["registries"] = definition + workspace.registries = definition return True return False @@ -104,7 +106,9 @@ def _format(reg_and_fields: str) -> str: return splits[0] + ": " + ", ".join(fields) -def _parse_registries(reg_and_fields): +def _parse_registries( + reg_and_fields: list[str], +) -> tuple[list[str], dict[Any, Any], dict[Any, Any]]: fields = [i[1:].split(":") for i in reg_and_fields if i.startswith("*")] targets = [] rev = {} @@ -121,16 +125,15 @@ def _parse_registries(reg_and_fields): rev[j] = i[0] targets.extend(tar) json_schema["isolated_fields"] = isolated_fields - return set(targets), rev, json_schema + return list(set(targets)), rev, json_schema -def _get_registries(reg_and_fields): +def _get_registries(reg_and_fields) -> list[str]: return [i.split(":")[0] for i in reg_and_fields] -def _update(is_init=True, entry="__init__"): - target_dir = osp.join(_cache_base_dir, _workspace_cfg["name"]) - os.makedirs(target_dir, exist_ok=True) +def _update(is_init: bool = True, entry: str = "__init__") -> None: + os.makedirs(workspace.cache_dir, exist_ok=True) logger.success("Generate `.taplo.toml`") if is_init: if not _detect_registy_difinition(): @@ -142,14 +145,14 @@ def _update(is_init=True, entry="__init__"): regs.append(_format(inp)) else: break - _workspace_cfg["registries"] = regs + workspace.registries = regs else: logger.imp("You can define fields later.") ( - _workspace_cfg["primary_fields"], - _workspace_cfg["primary_to_registry"], - _workspace_cfg["json_schema_fields"], - ) = _parse_registries(_workspace_cfg["registries"]) + workspace.primary_fields, + workspace.primary_to_registry, + workspace.json_schema_fields, + ) = _parse_registries(workspace.registries) _generate_registries(entry) else: logger.imp( @@ -158,20 +161,20 @@ def _update(is_init=True, entry="__init__"): ) else: ( - _workspace_cfg["primary_fields"], - _workspace_cfg["primary_to_registry"], - _workspace_cfg["json_schema_fields"], - ) = _parse_registries([_format(i) for i in _workspace_cfg["registries"]]) + workspace.primary_fields, + workspace.primary_to_registry, + workspace.json_schema_fields, + ) = _parse_registries([_format(i) for i in workspace.registries]) logger.success("Update primary_fields") -def _get_default_module_name(target_dir): +def _get_default_module_name(target_dir: str) -> str: assert os.path.isdir(target_dir) full_path = os.path.abspath(target_dir) return full_path.split(os.sep)[-1] -def _auto_register(target_dir, module_name): +def _auto_register(target_dir: str, module_name: str) -> None: for file_name in os.listdir(target_dir): full_path = os.path.join(target_dir, file_name) if os.path.isdir(full_path): @@ -190,44 +193,43 @@ def _auto_register(target_dir, module_name): @app.command() -def auto_register(): +def auto_register() -> None: """ Automatically import all modules in `src_dir` and register all modules, then dump to files. """ if not os.path.exists(_workspace_config_file): logger.critical("Please run `excore init` in your command line first!") sys.exit(0) - target_dir = osp.abspath(_workspace_cfg["src_dir"]) + target_dir = osp.abspath(workspace.src_dir) module_name = _get_default_module_name(target_dir) - sys.path.append(os.getcwd()) _auto_register(target_dir, module_name) Registry.dump() @app.command() -def primary_fields(): +def primary_fields() -> None: """ Show primary_fields. """ - table = _create_table("FIELDS", _parse_registries(_workspace_cfg["registries"])[0]) + table = _create_table("FIELDS", _parse_registries(workspace.registries)[0]) logger.info(table) @app.command() -def registries(): +def registries() -> None: """ Show registries. """ - table = _create_table("Registry", [_format(i) for i in _workspace_cfg["registries"]]) + table = _create_table("Registry", [_format(i) for i in workspace.registries]) logger.info(table) @app.command() def generate_registries( - entry: Annotated[ - str, CArg(help="Used for detect or generate Registry definition code") - ] = "__init__", -): + entry: str = CArg( + default="__init__", help="Used for detect or generate Registry definition code" + ), +) -> None: """ Generate registries definition code according to workspace config. """ diff --git a/excore/cli/_workspace.py b/excore/cli/_workspace.py index 21d465d..2be4b6a 100644 --- a/excore/cli/_workspace.py +++ b/excore/cli/_workspace.py @@ -1,34 +1,24 @@ import os import os.path as osp -import toml import typer from typer import Argument as CArg from typer import Option as COp from typing_extensions import Annotated -from .._constants import ( - LOGO, - _base_name, - _cache_base_dir, - _cache_dir, - _workspace_cfg, - _workspace_config_file, -) +from .._constants import LOGO, _workspace_config_file, workspace from ..engine.logging import logger from ._app import app from ._registry import _update -def _dump_workspace_config(): +def _dump_workspace_config() -> None: logger.info("Dump config to {}", _workspace_config_file) - _workspace_cfg.pop("base_dir", None) - with open(_workspace_config_file, "w", encoding="UTF-8") as f: - toml.dump(_workspace_cfg, f) + workspace.dump(_workspace_config_file) @app.command() -def update(): +def update() -> None: """ Update workspace config file. """ @@ -42,11 +32,11 @@ def init( entry: Annotated[ str, CArg(help="Used for detect or generate Registry definition code") ] = "__init__", -): +) -> None: """ Initialize workspace and generate a config file. """ - if osp.exists(_cache_dir) and not force: + if osp.exists(_workspace_config_file) and not force: logger.warning("excore.toml already existed!") return cwd = os.getcwd() @@ -56,17 +46,17 @@ def init( colors=True, ) logger.opt(colors=True).info(f"It will be generated in {cwd}\n") - logger.opt(colors=True).info(f"WorkSpace Name [{_base_name}]:") - name = typer.prompt("", default=_base_name, show_default=False, prompt_suffix="") - if not force and os.path.exists(os.path.join(_cache_base_dir, _base_name)): + logger.opt(colors=True).info(f"WorkSpace Name [{workspace.base_name}]:") + name = typer.prompt("", default=workspace.base_name, show_default=False, prompt_suffix="") + if not force and os.path.exists(workspace.cache_dir): logger.warning(f"name {name} already existed!") return logger.opt(colors=True).info("Source Code Directory(relative path):") src_dir = typer.prompt("", prompt_suffix="") - _workspace_cfg["name"] = name - _workspace_cfg["src_dir"] = src_dir + workspace.name = name + workspace.src_dir = src_dir _update(True, entry) _dump_workspace_config() diff --git a/excore/config/_json_schema.py b/excore/config/_json_schema.py index 5408b85..247535a 100644 --- a/excore/config/_json_schema.py +++ b/excore/config/_json_schema.py @@ -2,31 +2,45 @@ import inspect import json -import os -import os.path as osp import sys from inspect import Parameter, _empty, _ParameterKind, isclass from types import ModuleType -from typing import Any, Dict, Sequence, Union, get_args, get_origin +from typing import TYPE_CHECKING, Any, Callable, Dict, Sequence, Union, get_args, get_origin import toml -from .._constants import _cache_dir, _class_mapping_file, _json_schema_file +from excore import workspace + +from .._exceptions import AnnotationsFutureError from ..engine.hook import ConfigArgumentHook from ..engine.logging import logger from ..engine.registry import Registry, load_registries from .model import _str_to_target if sys.version_info >= (3, 10, 0): - from types import NoneType, UnionType + from types import NoneType, UnionType # type: ignore else: - NoneType = type(None) + NoneType = type(None) # type: ignore # just a placeholder - class UnionType: + class UnionType: # type: ignore pass +if TYPE_CHECKING: + from typing import TypedDict + + from typing_extensions import NotRequired + + class Property(TypedDict): + properties: NotRequired[Property] + type: NotRequired[str] + items: NotRequired[dict] + value: NotRequired[str] + description: NotRequired[str] + required: NotRequired[list[str]] + + TYPE_MAPPER: dict[type, str] = { int: "number", # sometimes default value are not accurate str: "string", @@ -76,8 +90,8 @@ def _generate_json_schema_and_class_mapping( for name, v in props["properties"].items(): schema["properties"][name] = v json_str = json.dumps(schema, indent=2) - save_path = save_path or _json_schema_path() - class_mapping_save_path = class_mapping_save_path or _class_mapping_path() + save_path = save_path or workspace.json_schema_file + class_mapping_save_path = class_mapping_save_path or workspace.class_mapping_file with open(save_path, "w", encoding="UTF-8") as f: f.write(json_str) logger.success("json schema has been written to {}", save_path) @@ -95,14 +109,14 @@ def _check(bases) -> bool: return False -def parse_registry(reg: Registry) -> tuple[dict, dict[str, list[str | int]]]: - props = { +def parse_registry(reg: Registry) -> tuple[Property, dict[str, list[str | int]]]: + props: Property = { "type": "object", "properties": {}, } - class_mapping = {} + class_mapping: dict[str, list[str | int]] = {} for name, item_dir in reg.items(): - func = _str_to_target(item_dir) + func = _str_to_target(item_dir) # type: ignore if isinstance(func, ModuleType): continue class_mapping[name] = [inspect.getfile(func), inspect.getsourcelines(func)[1]] @@ -110,7 +124,7 @@ def parse_registry(reg: Registry) -> tuple[dict, dict[str, list[str | int]]]: is_hook = isclass(func) and issubclass(func, ConfigArgumentHook) if isclass(func) and _check(func.__bases__): func = func.__init__ - params = inspect.signature(func).parameters + params = inspect.signature(func).parameters # type: ignore param_props = {"type": "object", "properties": {}} if doc_string: # TODO: parse doc string to each parameters @@ -122,11 +136,16 @@ def parse_registry(reg: Registry) -> tuple[dict, dict[str, list[str | int]]]: continue try: is_required, item = parse_single_param(param_obj) - except Exception: + except Exception as e: from rich.console import Console - logger.error(f"Skip {param_obj.name} of {name}") Console().print_exception() + if isinstance(e, AnnotationsFutureError): + logger.error( + f"Skip {name} due to mismatch of python version and annotations future." + ) + break + logger.error(f"Skip parameter {param_obj.name} of {name}") continue items[param_name] = item if is_required: @@ -135,14 +154,16 @@ def parse_registry(reg: Registry) -> tuple[dict, dict[str, list[str | int]]]: param_props["properties"] = items if required: param_props["required"] = required - props["properties"][name] = param_props + props["properties"][name] = param_props # type: ignore return props, class_mapping def _remove_optional(anno): origin = get_origin(anno) + if origin is not Union: + return anno inner_types = get_args(anno) - if origin is not Union and len(inner_types) != 2: + if len(inner_types) != 2: return anno filter_types = [i for i in inner_types if i is not NoneType] if len(filter_types) == 1: @@ -150,7 +171,7 @@ def _remove_optional(anno): return anno -def _parse_inner_types(prop: dict, inner_types: Sequence[type]) -> None: +def _parse_inner_types(prop: Property, inner_types: Sequence[type]) -> None: first_type = inner_types[0] is_all_the_same = True for t in inner_types: @@ -159,11 +180,13 @@ def _parse_inner_types(prop: dict, inner_types: Sequence[type]) -> None: prop["items"] = {"type": TYPE_MAPPER.get(first_type)} -def _parse_typehint(prop: dict, anno: type) -> str | None: +def _parse_typehint(prop: Property, anno: type) -> str | None: potential_type = TYPE_MAPPER.get(anno) if potential_type is not None: return potential_type origin = get_origin(anno) + if anno is Callable: + return "string" inner_types = get_args(anno) if origin in (Sequence, list, tuple): potential_type = "array" @@ -178,8 +201,8 @@ def _parse_typehint(prop: dict, anno: type) -> str | None: return potential_type or "string" -def parse_single_param(param: Parameter) -> tuple[bool, dict[str, Any]]: - prop = {} +def parse_single_param(param: Parameter) -> tuple[bool, Property]: + prop: Property = {} anno = param.annotation potential_type = None @@ -187,24 +210,25 @@ def parse_single_param(param: Parameter) -> tuple[bool, dict[str, Any]]: # hardcore for torch.optim if param.default.__class__.__name__ == "_RequiredParameter": - param._default = _empty + param._default = _empty # type: ignore if isinstance(anno, str): - raise RuntimeError( + raise AnnotationsFutureError( "Use a higher version of python, e.g. 3.10, " "and remove `from __future__ import annotations`." ) elif anno is not _empty: potential_type = _parse_typehint(prop, anno) # determine type by default value - elif param.default is not _empty: - potential_type = TYPE_MAPPER[type(param.default)] + elif param.default is not _empty and param.default is not None: + # TODO: Allow user to add more type mapper. + potential_type = TYPE_MAPPER.get(type(param.default), "number") if isinstance(param.default, (list, tuple)): types = [type(t) for t in param.default] _parse_inner_types(prop, types) - elif param._kind is _ParameterKind.VAR_POSITIONAL: + elif param.kind is _ParameterKind.VAR_POSITIONAL: return False, {"type": "array"} - elif param._kind is _ParameterKind.VAR_KEYWORD: + elif param.kind is _ParameterKind.VAR_KEYWORD: return False, {"type": "object"} if anno is _empty and param.default is _empty: potential_type = "number" @@ -213,18 +237,10 @@ def parse_single_param(param: Parameter) -> tuple[bool, dict[str, Any]]: return param.default is _empty, prop -def _json_schema_path() -> str: - return os.path.join(_cache_dir, _json_schema_file) - - -def _class_mapping_path() -> str: - return os.path.join(_cache_dir, _class_mapping_file) - - -def _generate_taplo_config(path: str) -> None: +def _generate_taplo_config() -> None: cfg = dict( schema=dict( - path=osp.join(osp.expanduser(path), _json_schema_file), + path=workspace.json_schema_file, enabled=True, ), formatting=dict(align_entries=False), diff --git a/excore/config/action.py b/excore/config/action.py index b7c4943..4af9acc 100644 --- a/excore/config/action.py +++ b/excore/config/action.py @@ -2,9 +2,11 @@ Copy and adapt from mmengine/config/config.py """ +from __future__ import annotations + import copy from argparse import Action, ArgumentParser, Namespace -from typing import Any, Sequence, Union +from typing import Any, Sequence __all__ = ["DictAction"] @@ -37,7 +39,7 @@ def __init__( self._dict = {} @staticmethod - def _parse_int_float_bool(val: str) -> Union[int, float, bool, Any]: + def _parse_int_float_bool(val: str) -> int | float | bool | Any: """parse int/float/bool value in the string.""" try: return int(val) @@ -54,7 +56,7 @@ def _parse_int_float_bool(val: str) -> Union[int, float, bool, Any]: return val @staticmethod - def _parse_iterable(val: str) -> Union[list, tuple, Any]: + def _parse_iterable(val: str) -> list | tuple | Any: """Parse iterable values in the string. All elements inside '()' or '[]' are treated as iterable values. @@ -137,8 +139,8 @@ def __call__( self, parser: ArgumentParser, namespace: Namespace, - values: Union[str, Sequence[Any], None], - option_string: str = None, + values: str | Sequence[Any] | None, + option_string: str | None = None, ): """Parse Variables in string and add them into argparser. diff --git a/excore/config/config.py b/excore/config/config.py index b68a1ea..f457dca 100644 --- a/excore/config/config.py +++ b/excore/config/config.py @@ -1,6 +1,8 @@ +from __future__ import annotations + import os import time -from typing import Dict, Optional, Tuple +from typing import Any import toml @@ -14,10 +16,6 @@ __all__ = ["load", "build_all", "load_config"] -# TODO: Improve error messages. high priority. -# TODO: Support multiple same modules parsing. - - BASE_CONFIG_KEY = "__base__" @@ -39,7 +37,7 @@ def load_config(filename: str, base_key: str = "__base__") -> ConfigDict: return base_cfg -def _merge_config(base_cfg, new_cfg): +def _merge_config(base_cfg: ConfigDict, new_cfg: dict) -> None: for k, v in new_cfg.items(): if k in base_cfg and isinstance(v, dict): _merge_config(base_cfg[k], v) @@ -50,8 +48,8 @@ def _merge_config(base_cfg, new_cfg): def load( filename: str, *, - dump_path: Optional[str] = None, - update_dict: Optional[Dict] = None, + dump_path: str | None = None, + update_dict: dict[str, Any] | None = None, base_key: str = BASE_CONFIG_KEY, parse_config: bool = True, ) -> LazyConfig: @@ -89,7 +87,7 @@ def load( return lazy_config -def build_all(cfg: LazyConfig) -> Tuple[ModuleWrapper, dict]: +def build_all(cfg: LazyConfig) -> tuple[ModuleWrapper, dict[str, Any]]: st = time.time() modules = cfg.build_all() logger.success("Modules building costs {:.4f}s!", time.time() - st) diff --git a/excore/config/lazy_config.py b/excore/config/lazy_config.py index 2c3b465..a38a6c2 100644 --- a/excore/config/lazy_config.py +++ b/excore/config/lazy_config.py @@ -1,8 +1,10 @@ +from __future__ import annotations + import time from copy import deepcopy -from typing import Any, Dict, Tuple +from typing import Any -from ..engine.hook import ConfigHookManager +from ..engine.hook import ConfigHookManager, Hook from ..engine.logging import logger from ..engine.registry import Registry from .model import ConfigHookNode, InterNode, ModuleWrapper @@ -10,7 +12,9 @@ class LazyConfig: - hook_key = "ConfigHook" + hook_key: str = "ConfigHook" + modules_dict: dict[str, ModuleWrapper] + isolated_dict: dict[str, Any] def __init__(self, config: ConfigDict) -> None: self.modules_dict, self.isolated_dict = {}, {} @@ -21,29 +25,30 @@ def __init__(self, config: ConfigDict) -> None: self._original_config = deepcopy(config) self.__is_parsed__ = False - def parse(self): + def parse(self) -> None: st = time.time() self.build_config_hooks() self._config.parse() logger.success("Config parsing cost {:.4f}s!", time.time() - st) self.__is_parsed__ = True - logger.ex(self._config) + logger.ex(str(self._config)) @property - def config(self): + def config(self) -> ConfigDict: return self._original_config - def update(self, cfg: "LazyConfig"): + def update(self, cfg: LazyConfig) -> None: self._config.update(cfg._config) - def build_config_hooks(self): + def build_config_hooks(self) -> None: hook_cfgs = self._config.pop(LazyConfig.hook_key, []) hooks = [] if hook_cfgs: _, base = Registry.find(list(hook_cfgs.keys())[0]) + assert base is not None, hook_cfgs reg = Registry.get_registry(base) for name, params in hook_cfgs.items(): - hook = ConfigHookNode.from_str(reg[name], params)() + hook: Hook = ConfigHookNode.from_str(reg[name], params)() # type: ignore if hook: hooks.append(hook) else: @@ -57,10 +62,12 @@ def __getattr__(self, __name: str) -> Any: return self._config[__name] raise AttributeError(__name) - def build_all(self) -> Tuple[ModuleWrapper, Dict]: + def build_all(self) -> tuple[ModuleWrapper, dict[str, Any]]: if not self.__is_parsed__: self.parse() - module_dict, isolated_dict = ModuleWrapper(), {} + module_dict = ModuleWrapper() + isolated_dict: dict[str, Any] = {} + self.hooks.call_hooks("pre_build", self, module_dict, isolated_dict) for name in self.target_modules: if name not in self._config: @@ -78,5 +85,5 @@ def build_all(self) -> Tuple[ModuleWrapper, Dict]: def dump(self, dump_path: str) -> None: self._original_config.dump(dump_path) - def __str__(self): + def __str__(self) -> str: return str(self._config) diff --git a/excore/config/model.py b/excore/config/model.py index bbcf2ff..aaa07c4 100644 --- a/excore/config/model.py +++ b/excore/config/model.py @@ -1,33 +1,51 @@ +from __future__ import annotations + import importlib import os import re from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import TYPE_CHECKING, Type, Union from .._exceptions import EnvVarParseError, ModuleBuildError, StrToClassError from .._misc import CacheOut -from ..engine.hook import ConfigArgumentHook +from ..engine.hook import ConfigArgumentHook, Hook from ..engine.logging import logger from ..engine.registry import Registry +if TYPE_CHECKING: + from types import FunctionType, ModuleType + from typing import Any, Dict, Literal + + from typing_extensions import Self + + NodeClassType = Type[Any] + NodeParams = Dict[Any, Any] + NodeInstance = object + + NoCallSkipFlag = Self + ConfigHookSkipFlag = Type[None] + + SpecialFlag = Literal["@", "!", "$", "&", ""] + + __all__ = ["silent"] -REUSE_FLAG = "@" -INTER_FLAG = "!" -CLASS_FLAG = "$" -REFER_FLAG = "&" -OTHER_FLAG = "" +REUSE_FLAG: Literal["@"] = "@" +INTER_FLAG: Literal["!"] = "!" +CLASS_FLAG: Literal["$"] = "$" +REFER_FLAG: Literal["&"] = "&" +OTHER_FLAG: Literal[""] = "" LOG_BUILD_MESSAGE = True DO_NOT_CALL_KEY = "__no_call__" -def silent(): +def silent() -> None: global LOG_BUILD_MESSAGE # pylint: disable=global-statement LOG_BUILD_MESSAGE = False -def _is_special(k: str) -> Tuple[str, str]: +def _is_special(k: str) -> tuple[str, SpecialFlag]: """ Determine if the given string begin with target special character. `@` denotes reused module, which will only be built once and cached out. @@ -41,31 +59,27 @@ def _is_special(k: str) -> Tuple[str, str]: Returns: Tuple[str, str]: A tuple containing the modified string and the special character. """ - if k.startswith(REUSE_FLAG): - return k[1:], REUSE_FLAG - if k.startswith(INTER_FLAG): - return k[1:], INTER_FLAG - if k.startswith(CLASS_FLAG): - return k[1:], CLASS_FLAG - if k.startswith(REFER_FLAG): - return k[1:], REFER_FLAG + pattern = re.compile(r"^([@!$&])(.*)$") + match = pattern.match(k) + if match: + return match.group(2), match.group(1) # type: ignore return k, "" -def _str_to_target(module_name): - module_name = module_name.split(".") - if len(module_name) == 1: - return importlib.import_module(module_name[0]) - target_name = module_name.pop(-1) +def _str_to_target(module_name: str) -> ModuleType | NodeClassType | FunctionType: + module_names = module_name.split(".") + if len(module_names) == 1: + return importlib.import_module(module_names[0]) + target_name = module_names.pop(-1) try: - module = importlib.import_module(".".join(module_name)) + module = importlib.import_module(".".join(module_names)) except ModuleNotFoundError as exc: - raise StrToClassError(f"Cannot import such module: `{'.'.join(module_name)}`") from exc + raise StrToClassError(f"Cannot import such module: `{'.'.join(module_names)}`") from exc try: module = getattr(module, target_name) except AttributeError as exc: raise StrToClassError( - f"Cannot find such module `{target_name}` form `{'.'.join(module_name)}`" + f"Cannot find such module `{target_name}` form `{'.'.join(module_names)}`" ) from exc return module @@ -76,57 +90,57 @@ class ModuleNode(dict): _no_call: bool = field(default=False, repr=False) priority: int = field(default=0, repr=False) - def _get_params(self, **kwargs): - params = {} + def _get_params(self, **params: NodeParams) -> NodeParams: + return_params = {} for k, v in self.items(): - if isinstance(v, ModuleWrapper): + if isinstance(v, (ModuleWrapper, ModuleNode)): v = v() - params[k] = v - params.update(kwargs) - return params + return_params[k] = v + return_params.update(params) + return return_params @property - def name(self): + def name(self) -> str: return self.cls.__name__ - def add(self, **kwargs): - self.update(kwargs) + def add(self, **params: NodeParams) -> Self: + self.update(params) return self - def _instantiate(self, params): + def _instantiate(self, params: NodeParams) -> NodeInstance: try: module = self.cls(**params) except Exception as exc: raise ModuleBuildError( - f"Build Error with module {self.cls} and arguments {params}" + f"Instantiate Error with module {self.cls} and arguments {params}" ) from exc if LOG_BUILD_MESSAGE: logger.success( - f"Successfully build module: {self.cls.__name__}, with arguments {params}" + f"Successfully instantiate module: {self.cls.__name__}, with arguments {params}" ) return module - def __call__(self, **kwargs): + def __call__(self, **params: NodeParams) -> NoCallSkipFlag | NodeInstance: # type: ignore if self._no_call: return self - params = self._get_params(**kwargs) + params = self._get_params(**params) module = self._instantiate(params) return module - def __lshift__(self, kwargs): - if not isinstance(kwargs, dict): - raise TypeError(f"Expect type is dict, but got {type(kwargs)}") - self.update(kwargs) + def __lshift__(self, params: NodeParams) -> Self: + if not isinstance(params, dict): + raise TypeError(f"Expect type is dict, but got {type(params)}") + self.update(params) return self - def __rshift__(self, __other): + def __rshift__(self, __other: ModuleNode) -> Self: if not isinstance(__other, ModuleNode): raise TypeError(f"Expect type is `ModuleNode`, but got {type(__other)}") __other.update(self) return self @classmethod - def from_str(cls, str_target, params=None): + def from_str(cls, str_target: str, params: NodeParams | None = None) -> ModuleNode: node = cls(_str_to_target(str_target)) if params: node.update(params) @@ -135,7 +149,7 @@ def from_str(cls, str_target, params=None): return node @classmethod - def from_base_name(cls, base, name, params=None): + def from_base_name(cls, base: str, name: str, params: NodeParams | None = None) -> ModuleNode: try: cls_name = Registry.get_registry(base)[name] except KeyError as exc: @@ -145,7 +159,7 @@ def from_base_name(cls, base, name, params=None): return cls.from_str(cls_name, params) @classmethod - def from_node(cls, _other: "ModuleNode") -> "ModuleNode": + def from_node(cls, _other: ModuleNode) -> ModuleNode: if _other.__class__.__name__ == cls.__name__: return _other node = cls(_other.cls) << _other @@ -159,10 +173,10 @@ class InterNode(ModuleNode): class ConfigHookNode(ModuleNode): - def __call__(self, **kwargs): + def __call__(self, **params: NodeParams) -> NodeInstance | ConfigHookSkipFlag | Hook: if issubclass(self.cls, ConfigArgumentHook): return None - params = self._get_params(**kwargs) + params = self._get_params(**params) return self._instantiate(params) @@ -170,24 +184,24 @@ class ReusedNode(InterNode): priority: int = 3 @CacheOut() - def __call__(self, **kwargs): - return super().__call__(**kwargs) + def __call__(self, **params: NodeParams) -> NodeInstance | NoCallSkipFlag: # type: ignore + return super().__call__(**params) class ClassNode(InterNode): priority: int = 1 - def __call__(self): + def __call__(self) -> NodeClassType | FunctionType: # type: ignore return self.cls class ChainedInvocationWrapper(ConfigArgumentHook): - def __init__(self, node: ModuleNode, attrs: Sequence[str]) -> None: + def __init__(self, node: ModuleNode, attrs: list[str]) -> None: super().__init__(node) self.attrs = attrs - def hook(self, **kwargs): - target = self.node(**kwargs) + def hook(self, **params: NodeParams) -> Any: + target = self.node(**params) if isinstance(target, ModuleNode): raise ModuleBuildError(f"Do not support `{DO_NOT_CALL_KEY}`") if self.attrs: @@ -201,7 +215,7 @@ def hook(self, **kwargs): @dataclass class VariableReference: - def __init__(self, value: str): + def __init__(self, value: str) -> None: env_names = re.findall(r"\$\{([^}]+)\}", value) self.has_env = len(env_names) > 0 for env in env_names: @@ -210,16 +224,22 @@ def __init__(self, value: str): value = re.sub(r"\$\{" + re.escape(env) + r"\}", env_value, value) self.value = value - def __call__(self): + def __call__(self) -> str: return self.value +ConfigNode = Union[ModuleNode, ConfigArgumentHook] +NodeType = Type[ModuleNode] + + class ModuleWrapper(dict): def __init__( self, - modules: Optional[Union[Dict[str, ModuleNode], List[ModuleNode], ModuleNode]] = None, - is_dict=False, - ): + modules: ( + dict[str, ConfigNode] | list[ConfigNode] | ConfigNode | VariableReference | None + ) = None, + is_dict: bool = False, + ) -> None: super().__init__() if modules is None: return @@ -241,14 +261,14 @@ def __init__( f"Expect modules to be `list`, `dict` or `ModuleNode`, but got {type(modules)}" ) - def _get_name(self, m): + def _get_name(self, m) -> Any: if hasattr(m, "name"): return m.name return m.__class__.__name__ - def __lshift__(self, kwargs): + def __lshift__(self, params: NodeParams) -> None: if len(self) == 1: - self[list(self.keys())[0]] << kwargs + self[list(self.keys())[0]] << params else: raise RuntimeError("Wrapped more than 1 ModuleNode, index first") @@ -274,10 +294,9 @@ def __repr__(self) -> str: return f"ModuleWrapper{list(self.values())}" -_dispatch_module_node = { +_dispatch_module_node: dict[SpecialFlag, NodeType] = { OTHER_FLAG: ModuleNode, REUSE_FLAG: ReusedNode, INTER_FLAG: InterNode, CLASS_FLAG: ClassNode, - REFER_FLAG: VariableReference, } diff --git a/excore/config/parse.py b/excore/config/parse.py index bb1951f..d09654e 100644 --- a/excore/config/parse.py +++ b/excore/config/parse.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import Any, Generator +from typing import TYPE_CHECKING from .._exceptions import CoreConfigParseError, ImplicitModuleParseError from .._misc import _create_table @@ -18,6 +18,13 @@ _is_special, ) +if TYPE_CHECKING: + from typing import Generator, Sequence + + from typing_extensions import Self + + from .model import ConfigNode, NodeType, SpecialFlag + def _check_implicit_module(module: ModuleNode) -> None: import inspect @@ -40,12 +47,12 @@ def _check_implicit_module(module: ModuleNode) -> None: ) -def _dict2node(module_type: str, base: str, _dict: dict): - ModuleType: type[ModuleNode] = _dispatch_module_node[module_type] +def _dict2node(module_type: SpecialFlag, base: str, _dict: dict): + ModuleType = _dispatch_module_node[module_type] return {name: ModuleType.from_base_name(base, name, v) for name, v in _dict.items()} -def _parse_param_name(name): +def _parse_param_name(name) -> tuple[str, list[str], list[str]]: names = name.split("@") attrs = names.pop(0).split(".") attrs = [i for i in attrs if i] @@ -53,25 +60,18 @@ def _parse_param_name(name): return attrs.pop(0), attrs, hooks -def _flatten_list(lis): +def _flatten_list( + lis: Sequence[ConfigNode | VariableReference | list[ConfigNode]], +) -> Sequence[ConfigNode | VariableReference]: new_lis = [] for i in lis: if isinstance(i, list): new_lis.extend(i) else: - new_lis.append(i) + new_lis.append(i) # type: ignore return new_lis -def _flatten_dict(dic): - new_dic = {} - for k, v in dic.items(): - if isinstance(v, list): - v = _flatten_list(v) - new_dic[k] = v - return new_dic - - class ConfigDict(dict): primary_fields: list primary_to_registry: dict[str, str] @@ -80,7 +80,7 @@ class ConfigDict(dict): scratchpads_fields: set[str] = set() current_field: str | None = None - def __new__(cls): + def __new__(cls) -> Self: if not hasattr(cls, "primary_fields"): raise RuntimeError("Call `set_primary_fields` before `load`") @@ -91,11 +91,13 @@ class ConfigDictImpl(ConfigDict): primary_to_registry = ConfigDict.primary_to_registry scratchpads_fields = ConfigDict.scratchpads_fields - inst = super().__new__(ConfigDictImpl) - return inst + inst = super().__new__(ConfigDictImpl) # type: ignore + return inst # type: ignore @classmethod - def set_primary_fields(cls, primary_fields, primary_to_registry): + def set_primary_fields( + cls, primary_fields: Sequence[str], primary_to_registry: dict[str, str] + ) -> None: """ Sets the `primary_fields` attribute to the specified list of module names, and `registered_fields` attributes based on the current state @@ -103,7 +105,7 @@ def set_primary_fields(cls, primary_fields, primary_to_registry): Note that `set_primary_fields` must be called before `config.load`. """ - cls.primary_fields = primary_fields + cls.primary_fields = list(primary_fields) cls.primary_to_registry = primary_to_registry def parse(self) -> None: @@ -177,12 +179,12 @@ def _clean(self) -> None: if name in self.registered_fields or isinstance(self[name], ModuleNode): self.pop(name) - def primary_keys(self) -> Generator[Any, Any, None]: + def primary_keys(self) -> Generator[str, None, None]: for name in self.primary_fields: if name in self: yield name - def non_primary_keys(self) -> Generator[Any, Any, None]: + def non_primary_keys(self) -> Generator[str, None, None]: keys = list(self.keys()) for k in keys: if k not in self.primary_fields: @@ -207,7 +209,7 @@ def _parse_primary_modules(self) -> None: self[name] = _dict2node(OTHER_FLAG, base, self.pop(name)) logger.ex(f"Set ModuleNode to self[{name}].") - def _parse_isolated_registered_module(self, name) -> None: + def _parse_isolated_registered_module(self, name: str) -> None: v = _dict2node(OTHER_FLAG, name, self.pop(name)) for i in v.values(): _, _ = Registry.find(i.name) @@ -215,7 +217,7 @@ def _parse_isolated_registered_module(self, name) -> None: raise CoreConfigParseError(f"Unregistered module `{i.name}`") self[name] = v - def _parse_implicit_module(self, name, module_type): + def _parse_implicit_module(self, name: str, module_type: NodeType) -> ModuleNode: _, base = Registry.find(name) if not base: raise CoreConfigParseError(f"Unregistered module `{name}`") @@ -227,7 +229,7 @@ def _parse_implicit_module(self, name, module_type): self[name] = node return node - def _parse_isolated_module(self, name) -> bool: + def _parse_isolated_module(self, name: str) -> bool: logger.ex(f"Not a registed field. Parse isolated module {name}.") _, base = Registry.find(name) if base: @@ -236,7 +238,7 @@ def _parse_isolated_module(self, name) -> bool: return True return False - def _parse_scratchpads(self, name) -> None: + def _parse_scratchpads(self, name: str) -> None: logger.ex(f"Not a registed node. Regrad as scratchpads {name}.") has_module = False modules = self[name] @@ -264,7 +266,7 @@ def _parse_isolated_obj(self) -> None: elif not self._parse_isolated_module(name): self._parse_scratchpads(name) - def _contain_module(self, name) -> bool: + def _contain_module(self, name: str) -> bool: is_contain = False for k in self.all_fields: if k not in self: @@ -282,7 +284,9 @@ def _contain_module(self, name) -> bool: self.current_field = k return is_contain - def _get_name_and_field(self, name, ori_name): + def _get_name_and_field( + self, name: str, ori_name: str + ) -> tuple[str, str | None] | tuple[list[str], str]: if not name.startswith("$"): return name, None names = name[1:].split("::") @@ -295,6 +299,10 @@ def _get_name_and_field(self, name, ori_name): raise CoreConfigParseError(f"Cannot find field `{base}` with `{ori_name}`") if len(spec_name) > 0: if spec_name[0] == "*": + logger.warning( + f"`The results of {names} " + "depend on their definition in config files when using `*`." + ) return [k.name for k in modules.values()], base if spec_name[0] not in modules: raise CoreConfigParseError( @@ -310,20 +318,22 @@ def _get_name_and_field(self, name, ori_name): ) return list(modules.keys())[0], base - def _apply_hooks(self, node, hooks, attrs): + def _apply_hooks(self, node: ModuleNode, hooks: list[str], attrs: list[str]) -> ConfigNode: if attrs: - node = ChainedInvocationWrapper(node, attrs) + node = ChainedInvocationWrapper(node, attrs) # type: ignore if not hooks: return node for hook in hooks: if hook not in self: - raise CoreConfigParseError(f"Unregistered hook {hook}") + raise CoreConfigParseError(f"Unregistered hook `{hook}`") node = self[hook](node=node) return node - def _convert_node(self, name, source, target_type) -> tuple[ModuleNode, type[ModuleNode]]: + def _convert_node( + self, name: str, source: ConfigDict, target_type: NodeType + ) -> tuple[ModuleNode, NodeType]: ori_type = source[name].__class__ - logger.ex(f"Original_type is {ori_type}, target_type is {target_type}.") + logger.ex(f"Original_type is `{ori_type}`, target_type is `{target_type}`.") node = source[name] if target_type.priority != ori_type.priority or target_type is ClassNode: node = target_type.from_node(source[name]) @@ -332,7 +342,15 @@ def _convert_node(self, name, source, target_type) -> tuple[ModuleNode, type[Mod source[name] = node return node, ori_type - def _parse_single_param(self, name, ori_name, field, target_type, attrs, hooks): + def _parse_single_param( + self, + name: str, + ori_name: str, + field: str | None, + target_type: NodeType, + attrs: list[str], + hooks: list[str], + ) -> ConfigNode: if name in self.all_fields: raise CoreConfigParseError( f"Conflict name: `{name}`, the class name cannot be same with field name" @@ -358,26 +376,28 @@ def _parse_single_param(self, name, ori_name, field, target_type, attrs, hooks): f"target_type is `{target_type}`, but got original_type `{ori_type}`. " f"Please considering using `scratchpads` to avoid conflicts." ) - node = self._apply_hooks(node, hooks, attrs) + node = self._apply_hooks(node, hooks, attrs) # type: ignore return node - def _parse_param(self, ori_name, module_type): + def _parse_param( + self, ori_name: str, module_type: SpecialFlag + ) -> ConfigNode | list[ConfigNode] | VariableReference: logger.ex(f"Parse with {ori_name}, {module_type}") if module_type == REFER_FLAG: return VariableReference(ori_name) target_type = _dispatch_module_node[module_type] name, attrs, hooks = _parse_param_name(ori_name) - name, field = self._get_name_and_field(name, ori_name) - logger.ex(f"Get name:{name}, field:{field}, attrs:{attrs}, hooks:{hooks}.") - if isinstance(name, list): + names, field = self._get_name_and_field(name, ori_name) + logger.ex(f"Get name:{names}, field:{field}, attrs:{attrs}, hooks:{hooks}.") + if isinstance(names, list): logger.ex(f"Detect output type list {ori_name}.") return [ self._parse_single_param(n, ori_name, field, target_type, attrs, hooks) - for n in name + for n in names ] - return self._parse_single_param(name, ori_name, field, target_type, attrs, hooks) + return self._parse_single_param(names, ori_name, field, target_type, attrs, hooks) - def _parse_module(self, node: ModuleNode): + def _parse_module(self, node: ModuleNode) -> None: logger.ex(f"Parse ModuleNode {node}.") for param_name in list(node.keys()): true_name, module_type = _is_special(param_name) @@ -396,7 +416,6 @@ def _parse_module(self, node: ModuleNode): elif isinstance(value, dict): logger.ex(f"{param_name}: Dict parameter {value}.") value = {k: self._parse_param(v, module_type) for k, v in value.items()} - value = _flatten_dict(value) is_dict = True else: raise CoreConfigParseError(f"Wrong type: {param_name, value}") @@ -410,9 +429,10 @@ def _parse_module(self, node: ModuleNode): else: node[true_name] = ref_name else: + # FIXME: Parsing inner VariableReference node[true_name] = ModuleWrapper(value, is_dict) - def _parse_inter_modules(self): + def _parse_inter_modules(self) -> None: for name in list(self.keys()): logger.ex(f"Parse inter module {name}") module = self[name] @@ -427,33 +447,32 @@ def _parse_inter_modules(self): elif isinstance(module, ModuleNode): self._parse_module(module) - def __str__(self): - _dict = {} + def __str__(self) -> str: + _dict: dict = {} for k, v in self.items(): self._flatten(_dict, k, v) return _create_table( None, [(k, v) for k, v in _dict.items()], - False, ) - def _flatten(self, _dict, k, v): + def _flatten(self, _dict: dict, k: str, v: dict) -> None: if isinstance(v, dict) and not isinstance(v, ModuleNode): for _k, _v in v.items(): _dict[".".join([k, _k])] = _v else: _dict[k] = v - def dump(self, path: str): + def dump(self, path: str) -> None: import toml with open(path, "w", encoding="UTF-8") as f: toml.dump(self, f) -def set_primary_fields(cfg): - primary_fields = cfg["primary_fields"] - primary_to_registry = cfg["primary_to_registry"] +def set_primary_fields(cfg) -> None: + primary_fields = cfg.primary_fields + primary_to_registry = cfg.primary_to_registry if hasattr(ConfigDict, "primary_fields"): logger.ex("`primary_fields` will be set to {}", primary_fields) if primary_fields: diff --git a/excore/engine/hook.py b/excore/engine/hook.py index b147b1c..241ef7c 100644 --- a/excore/engine/hook.py +++ b/excore/engine/hook.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from collections import defaultdict from typing import Any, Callable, Protocol, Sequence, final @@ -31,7 +33,7 @@ class Hook(Protocol): class MetaHookManager(type): - stages = None + stages: tuple[str, ...] = () """ A metaclass that is used to validate the `stages` attribute of a `HookManager` subclass. @@ -66,7 +68,7 @@ def __new__(cls, name, bases, attrs): """ inst = type.__new__(cls, name, bases, attrs) stages = inst.stages - if inst.__name__ != "HookManager" and stages is None: + if inst.__name__ != "HookManager" and not stages: raise HookManagerBuildError( f"The hook manager `{inst.__name__}` must have valid stages" ) @@ -75,7 +77,7 @@ def __new__(cls, name, bases, attrs): class HookManager(metaclass=MetaHookManager): - stages = tuple() + stages: tuple[str, ...] = tuple() """ Manages a set of hooks that can be triggered by events that occur during program execution. @@ -109,7 +111,7 @@ class HookManager(metaclass=MetaHookManager): for different applications. """ - def __init__(self, hooks: Sequence[Hook]): + def __init__(self, hooks: Sequence[Hook]) -> None: assert isinstance(hooks, Sequence) __error_msg = "The hook `{}` must have a valid `{}`, got {}" @@ -127,12 +129,12 @@ def __init__(self, hooks: Sequence[Hook]): __error_msg.format(h.__class__.__name__, "__CallInter__", h.__CallInter__) ) self.hooks = defaultdict(list) - self.calls = defaultdict(int) + self.calls: dict[str, int] = defaultdict(int) for h in hooks: self.hooks[h.__HookType__].append(h) @staticmethod - def check_life_span(hook: Hook): + def check_life_span(hook: Hook) -> bool: """ Checks whether a given `Hook` object has exceeded its maximum lifespan. @@ -142,7 +144,7 @@ def check_life_span(hook: Hook): hook.__LifeSpan__ -= 1 return hook.__LifeSpan__ <= 0 - def exist(self, stage): + def exist(self, stage) -> bool: """ Determines whether any hooks are registered for a given event stage. @@ -154,19 +156,19 @@ def exist(self, stage): """ return self.hooks[stage] != [] - def pre_call(self): + def pre_call(self) -> Any: """ Called before any hooks are executed during an event stage. """ return - def after_call(self): + def after_call(self) -> Any: """ Called after all hooks have been executed during an event stage. """ return - def __call__(self, stage, *inps): + def __call__(self, stage, *inps) -> None: """ Executes all hooks registered for a given event stage. @@ -174,7 +176,7 @@ def __call__(self, stage, *inps): stage (str): The name of the event stage to trigger. *inps: Input arguments to pass to the hook functions. """ - dead_hook_idx = [] + dead_hook_idx: list[int] = [] calls = self.calls[stage] for idx, hook in enumerate(self.hooks[stage]): if calls % hook.__CallInter__ == 0: @@ -185,7 +187,7 @@ def __call__(self, stage, *inps): self.hooks[stage].pop(idx) self.calls[stage] = calls + 1 - def call_hooks(self, stage, *inps): + def call_hooks(self, stage, *inps) -> None: """ Convenience method for calling all hooks at a given event stage. @@ -199,7 +201,7 @@ def call_hooks(self, stage, *inps): class ConfigHookManager(HookManager): - stages = ("pre_build", "every_build", "after_build") + stages: tuple[str, ...] = ("pre_build", "every_build", "after_build") """A subclass of HookManager that allows hooks to be registered and executed at specific points in the build process. @@ -215,17 +217,19 @@ def __init__( self, node: Callable, enabled: bool = True, - ): + ) -> None: self.node = node self.enabled = enabled + if not hasattr(node, "name"): + raise ValueError("The `node` must have name attribute.") self.name = node.name self._is_initialized = True - def hook(self, **kwargs): + def hook(self, **kwargs: Any) -> Any: raise NotImplementedError(f"`{self.__class__.__name__}` do not implement `hook` method.") @final - def __call__(self, **kwargs): + def __call__(self, **kwargs: Any) -> Any: if not getattr(self, "_is_initialized", False): raise CoreConfigSupportError( f"Call super().__init__() in class `{self.__class__.__name__}`" diff --git a/excore/engine/logging.py b/excore/engine/logging.py index aa415d3..980294f 100644 --- a/excore/engine/logging.py +++ b/excore/engine/logging.py @@ -1,11 +1,29 @@ +from __future__ import annotations + import os import sys +from typing import TYPE_CHECKING from loguru import logger as _logger +if TYPE_CHECKING: + from typing import Any, Callable, TextIO + + from loguru import FilterDict, FilterFunction, FormatFunction, Message, Record, Writable + from loguru._handler import Handler + from loguru._logger import Logger + + class PatchedLogger(Logger): + def ex(self, __message: str, *args: Any, **kwargs: Any) -> None: + pass + + def imp(self, __message: str, *args: Any, **kwargs: Any) -> None: + pass + + __all__ = ["logger", "add_logger", "remove_logger", "debug_only", "log_to_file_only"] -LOGGERS = {} +LOGGERS: dict[str, int] = {} FORMAT = ( "{time:YYYY-MM-DD HH:mm:ss} | " @@ -22,21 +40,21 @@ def _trace_patcher(log_record): log_record["function"] = "\b" -logger = _logger.patch(_trace_patcher) +logger: PatchedLogger = _logger.patch(_trace_patcher) # type: ignore def add_logger( - name, - sink, + name: str, + sink: TextIO | Writable | Callable[[Message], None] | Handler, *, - level=None, # pylint: disable=unused-argument - format=None, # pylint: disable=unused-argument - filter=None, # pylint: disable=unused-argument - colorize=None, # pylint: disable=unused-argument - serialize=None, # pylint: disable=unused-argument - backtrace=None, # pylint: disable=unused-argument - diagnose=None, # pylint: disable=unused-argument - enqueue=None, # pylint: disable=unused-argument + level: str | int | None = None, # pylint: disable=unused-argument + format: str | FormatFunction | None = None, # pylint: disable=unused-argument + filter: str | FilterFunction | FilterDict | None = None, # pylint: disable=unused-argument + colorize: bool | None = None, # pylint: disable=unused-argument + serialize: bool | None = None, # pylint: disable=unused-argument + backtrace: bool | None = None, # pylint: disable=unused-argument + diagnose: bool | None = None, # pylint: disable=unused-argument + enqueue: bool | None = None, # pylint: disable=unused-argument ) -> None: params = {k: v for k, v in locals().items() if v is not None} params.pop("sink") @@ -54,14 +72,14 @@ def remove_logger(name: str) -> None: logger.warning(f"Cannot find logger with name {name}") -def log_to_file_only(file_name: str, *args, **kwargs) -> None: +def log_to_file_only(file_name: str, *args: Any, **kwargs: Any) -> None: logger.remove(None) logger.add(file_name, *args, **kwargs) logger.success(f"Log to file {file_name} only") -def debug_only(*args, **kwargs) -> None: - def _debug_only(record): +def debug_only(*args: Any, **kwargs: Any) -> None: + def _debug_only(record: Record) -> bool: return record["level"].name == "DEBUG" filter = kwargs.pop("filter", None) @@ -72,27 +90,27 @@ def _debug_only(record): logger.debug("DEBUG ONLY!!!") -def _call_importance(__message: str, *args, **kwargs): +def _call_importance(__message: str, *args: Any, **kwargs: Any) -> None: logger.log("IMPORT", __message, *args, **kwargs) -def _excore_debug(__message: str, *args, **kwargs): +def _excore_debug(__message: str, *args: Any, **kwargs: Any) -> None: logger._log("EXCORE", False, logger._options, __message, args, kwargs) -def _enable_excore_debug(): +def _enable_excore_debug() -> None: if os.getenv("EXCORE_DEBUG"): logger.remove() logger.add(sys.stdout, format=FORMAT, level="EXCORE") logger.ex("Enabled excore debug") -def init_logger(): +def init_logger() -> None: logger.remove() logger.add(sys.stderr, format=FORMAT) logger.level("SUCCESS", color="") logger.level("WARNING", color="") logger.level("IMPORT", no=45, color="") logger.level("EXCORE", no=9, color="") - logger.imp = _call_importance - logger.ex = _excore_debug + logger.imp = _call_importance # type: ignore + logger.ex = _excore_debug # type: ignore diff --git a/excore/engine/registry.py b/excore/engine/registry.py index 00ccf33..a6e70c2 100644 --- a/excore/engine/registry.py +++ b/excore/engine/registry.py @@ -1,14 +1,18 @@ +from __future__ import annotations + import fnmatch import functools import inspect import os import re import sys -from types import ModuleType -from typing import Any, Callable, Dict, List, Optional, Sequence, Union +from types import FunctionType, ModuleType +from typing import Any, Callable, Literal, Sequence, Type, overload + +from filelock import FileLock -from .._constants import _cache_dir, _registry_cache_file, _workspace_config_file -from .._misc import FileLock, _create_table +from .._constants import _workspace_config_file, workspace +from .._misc import _create_table from .logging import logger _name_re = re.compile(r"^[A-Za-z0-9_]+$") @@ -16,19 +20,21 @@ __all__ = ["Registry"] +_ClassType = Type[Any] + # TODO: Maybe some methods need to be cleared. -def _is_pure_ascii(name: str): +def _is_pure_ascii(name: str) -> None: if not _name_re.match(name): raise ValueError( - f"""Unexpected name, only support ASCII letters, ASCII digits, - underscores, and dashes, but got {name}""" + "Unexpected name, only support ASCII letters, ASCII digits, " + f"underscores, and dashes, but got {name}." ) -def _is_function_or_class(module): +def _is_function_or_class(module: Any) -> bool: return inspect.isfunction(module) or inspect.isclass(module) @@ -36,7 +42,7 @@ def _default_filter_func(values: Sequence[Any]) -> bool: return all(v for v in values) -def _default_match_func(m, base_module): +def _default_match_func(m: str, base_module: ModuleType) -> bool: if not m.startswith("__"): m = getattr(base_module, m) if inspect.isfunction(m) or inspect.isclass(m): @@ -44,12 +50,12 @@ def _default_match_func(m, base_module): return False -def _get_module_name(m): +def _get_module_name(m: ModuleType | _ClassType | FunctionType) -> str: return getattr(m, "__qualname__", m.__name__) class RegistryMeta(type): - _registry_pool: Dict[str, "Registry"] = dict() + _registry_pool: dict[str, Registry] = {} """Metaclass that governs the creation of instances of its subclasses, which are `Registry` objects. @@ -65,7 +71,7 @@ class RegistryMeta(type): message is logged indicating that the extra arguments will be ignored. """ - def __call__(cls, name, **kwargs) -> "Registry": + def __call__(cls, name: str, **kwargs: Any) -> Registry: r"""Assert only call `__init__` once""" _is_pure_ascii(name) extra_field = kwargs.get("extra_field", None) @@ -85,70 +91,68 @@ def __call__(cls, name, **kwargs) -> "Registry": # Maybe someday we can get rid of Registry? -class Registry(dict, metaclass=RegistryMeta): - _globals: Optional["Registry"] = None - _registry_dir = "registry" +class Registry(dict, metaclass=RegistryMeta): # type: ignore + _globals: Registry | None = None # just a workaround for twice registry - _prevent_register = False + _prevent_register: bool = False + + extra_info: dict[str, str] + """A registry that stores functions and classes by name. Attributes: name (str): The name of the registry. - extra_field (Optional[Union[str, Sequence[str]]]): A field or fields that can be + extra_field (str|Sequence[str]|None): A field or fields that can be used to store additional information about each function or class in the registry. - extra_info (Dict[str, List[Any]]): A dictionary that maps each registered name + extra_info (dict[str, list[Any]]): A dictionary that maps each registered name to a list of extra values associated with that name (if any). - _globals (Optional[Registry]): A static variable that stores a global registry + _globals (Registry|None): A static variable that stores a global registry containing all functions and classes registered using Registry. """ - def __init__( - self, /, name: str, *, extra_field: Optional[Union[str, Sequence[str]]] = None - ) -> None: + def __init__(self, /, name: str, *, extra_field: str | Sequence[str] | None = None) -> None: super().__init__() self.name = name if extra_field: self.extra_field = [extra_field] if isinstance(extra_field, str) else extra_field - self.extra_info = dict() + self.extra_info = {} @classmethod - def dump(cls): - file_path = os.path.join(_cache_dir, cls._registry_dir, _registry_cache_file) - os.makedirs(os.path.join(_cache_dir, cls._registry_dir), exist_ok=True) + def dump(cls) -> None: + file_path = workspace.registry_cache_file import pickle # pylint: disable=import-outside-toplevel - with FileLock(file_path): # noqa: SIM117 - with open(file_path, "wb") as f: - pickle.dump(cls._registry_pool, f) + with FileLock(file_path + ".lock", timeout=5), open(file_path, "wb") as f: + pickle.dump(cls._registry_pool, f) + logger.success(f"Dump registry cache to {workspace.registry_cache_file}!") @classmethod - def load(cls): + def load(cls) -> None: if not os.path.exists(_workspace_config_file): logger.warning("Please run `excore init` in your command line first!") - sys.exit(0) - file_path = os.path.join(_cache_dir, cls._registry_dir, _registry_cache_file) + sys.exit(1) + file_path = workspace.registry_cache_file if not os.path.exists(file_path): # shall we need to be silent? Or raise error? logger.critical( "Registry cache file do not exist!" " Please run `excore auto-register in your command line first`" ) - sys.exit(0) + sys.exit(1) import pickle # pylint: disable=import-outside-toplevel - with FileLock(file_path): # noqa: SIM117 - with open(file_path, "rb") as f: - data = pickle.load(f) + with FileLock(file_path + ".lock"), open(file_path, "rb") as f: + data = pickle.load(f) cls._registry_pool.update(data) @classmethod - def lock_register(cls): + def lock_register(cls) -> None: cls._prevent_register = True @classmethod - def unlock_register(cls): + def unlock_register(cls) -> None: cls._prevent_register = False @classmethod @@ -161,7 +165,7 @@ def get_registry(cls, name: str, default: Any = None) -> Any: @classmethod @functools.lru_cache(32) - def find(cls, name: str) -> Any: + def find(cls, name: str) -> tuple[Any, str] | tuple[None, None]: """ Searches all registries for an element with the given name. If found, returns a tuple containing the element and the name of the registry where it @@ -173,7 +177,7 @@ def find(cls, name: str) -> Any: return (None, None) @classmethod - def make_global(cls): + def make_global(cls) -> Registry: """ Creates a global `Registry` instance that contains all elements from all other registries. If the global registry already exists, returns it instead @@ -187,7 +191,7 @@ def make_global(cls): cls._globals = reg return reg - def __setitem__(self, k, v) -> None: + def __setitem__(self, k: str, v: Any) -> None: _is_pure_ascii(k) super().__setitem__(k, v) @@ -195,18 +199,47 @@ def __repr__(self) -> str: return _create_table( ["NAEM", "DIR"], [(k, v) for k, v in self.items()], - False, ) __str__ = __repr__ + @overload def register_module( self, - module: Union[Callable, ModuleType], - force: bool = False, - _is_str: bool = False, + module: Callable[..., Any], + force: bool = ..., + _is_str: bool = ..., + **extra_info: Any, + ) -> Callable[..., Any]: + pass + + @overload + def register_module( + self, + module: ModuleType, + force: bool = ..., + _is_str: bool = ..., + **extra_info: Any, + ) -> ModuleType: + pass + + @overload + def register_module( + self, + module: str, + force: bool = ..., + _is_str: Literal[True] = ..., + **extra_info: Any, + ) -> str: + pass + + def register_module( + self, + module, + force=False, + _is_str=False, **extra_info, - ) -> Union[Callable, ModuleType]: + ): if Registry._prevent_register: logger.ex("Registry has been locked!!!") return module @@ -216,6 +249,7 @@ def register_module( name = _get_module_name(module) else: name = module.split(".")[-1] + if not force and name in self: raise ValueError(f"The name {name} exists") @@ -238,6 +272,8 @@ def register_module( ) else: target = module + + logger.ex(f"Register {name} with {target}.") self[name] = target # update to globals @@ -246,7 +282,7 @@ def register_module( return module - def register(self, force: bool = False, **extra_info) -> Callable: + def register(self, force: bool = False, **extra_info: Any) -> Callable[..., Any]: """ Decorator that registers a function or class with the current `Registry`. Any keyword arguments provided are added to the `extra_info` list for the @@ -257,8 +293,8 @@ def register(self, force: bool = False, **extra_info) -> Callable: def register_all( self, - modules: Sequence[Callable], - extra_info: Optional[Sequence[Dict[str, Any]]] = None, + modules: Sequence[Callable[..., Any]], + extra_info: Sequence[dict[str, Any]] | None = None, force: bool = False, _is_str: bool = False, ) -> None: @@ -274,14 +310,14 @@ def register_all( def merge( self, - others: Union["Registry", List["Registry"]], + others: Registry | Sequence[Registry], force: bool = False, ) -> None: """ Merge the contents of one or more other registries into the current one. If `force` is True, overwrites any existing elements with the same names. """ - if not isinstance(others, list): + if not isinstance(others, (list, tuple, Sequence)): others = [others] for other in others: if not isinstance(other, Registry): @@ -291,9 +327,9 @@ def merge( def filter( self, - filter_field: Union[Sequence[str], str], - filter_func: Callable = _default_filter_func, - ) -> List[str]: + filter_field: Sequence[str] | str, + filter_func: Callable[[Sequence[Any]], bool] = _default_filter_func, + ) -> list[str]: """ Returns a sorted list of all names in the registry for which the values of the given extra field(s) pass a filtering function. @@ -310,7 +346,12 @@ def filter( out = list(sorted(out)) return out - def match(self, base_module, match_func=_default_match_func, force=False): + def match( + self, + base_module: ModuleType, + match_func: Callable[[str, ModuleType], bool] = _default_match_func, + force: bool = False, + ) -> None: """ Registers all functions or classes from the given module that pass a matching function. If `match_func` is not provided, uses `_default_match_func`. @@ -328,11 +369,11 @@ def match(self, base_module, match_func=_default_match_func, force=False): def module_table( self, - filter: Optional[Union[Sequence[str], str]] = None, - select_info: Optional[Union[Sequence[str], str]] = None, - module_list: Optional[Sequence[str]] = None, - **table_kwargs, - ) -> Any: + filter: Sequence[str] | str | None = None, + select_info: Sequence[str] | str | None = None, + module_list: Sequence[str] | None = None, + **table_kwargs: Any, + ) -> str: """ Returns a table containing information about each registered function or class, filtered by name and/or extra info fields. `select_info` specifies @@ -347,16 +388,16 @@ def module_table( else: select_info = [] - all_modules = module_list if module_list else self.keys() + all_modules = module_list if module_list else list(self.keys()) if filter: - modules = set() + set_modules: set[str] = set() filters = [filter] if isinstance(filter, str) else filter for f in filters: include_models = fnmatch.filter(all_modules, f) if len(include_models): - modules = modules.union(include_models) + modules = list(set_modules.union(include_models)) else: - modules = all_modules + modules = all_modules # type: ignore modules = list(sorted(modules)) @@ -371,14 +412,13 @@ def module_table( table = _create_table( table_headers, [(i, *[self.extra_info[i][idx] for idx in select_idx]) for i in modules], - False, **table_kwargs, ) table = "\n" + table return table @classmethod - def registry_table(cls, **table_kwargs) -> Any: + def registry_table(cls, **table_kwargs) -> str: """ Returns a table containing the names of all available registries. """ @@ -386,15 +426,14 @@ def registry_table(cls, **table_kwargs) -> Any: table = _create_table( table_headers, list(sorted([[i] for i in cls._registry_pool])), - False, **table_kwargs, ) table = "\n" + table return table -def load_registries(): - if not os.path.exists(os.path.join(_cache_dir, Registry._registry_dir, _registry_cache_file)): +def load_registries() -> None: + if not os.path.exists(workspace.registry_cache_file): logger.warning("Please run `excore auto-register` in your command line first!") return Registry.load() @@ -406,4 +445,4 @@ def load_registries(): "No module has been registered, \ you may need to call `excore.registry.auto_register` first" ) - sys.exit(0) + sys.exit(1) diff --git a/excore/plugins/hub.py b/excore/plugins/hub.py index 20fe42b..2c7c9f3 100644 --- a/excore/plugins/hub.py +++ b/excore/plugins/hub.py @@ -21,7 +21,7 @@ import requests from tqdm import tqdm -from .._constants import __version__, _cache_dir +from .._constants import __version__, workspace from .._exceptions import ( GitCheckoutError, GitPullError, @@ -153,7 +153,7 @@ def fetch( kwargs = {"stderr": subprocess.PIPE, "stdout": subprocess.PIPE} if silent else {} if commit is None: # shallow clone repo by branch/tag - p = subprocess.Popen( + p = subprocess.Popen( # type: ignore [ "git", "clone", @@ -169,14 +169,16 @@ def fetch( else: # clone repo and checkout to commit_id p = subprocess.Popen( # pylint: disable=consider-using-with - ["git", "clone", git_url, repo_dir], **kwargs + ["git", "clone", git_url, repo_dir], + **kwargs, # type: ignore ) cls._check_clone_pipe(p) with cd(repo_dir): logger.debug("git checkout to {}", commit) p = subprocess.Popen( # pylint: disable=consider-using-with - ["git", "checkout", commit], **kwargs + ["git", "checkout", commit], + **kwargs, # type: ignore ) _, err = p.communicate() if p.returncode: @@ -289,7 +291,7 @@ def _get_repo( raise InvalidProtocol( "Invalid protocol, the value should be one of {}.".format(", ".join(PROTOCOLS.keys())) ) - cache_dir = os.path.expanduser(os.path.join(_cache_dir, "hub")) + cache_dir = os.path.expanduser(os.path.join(workspace.cache_dir, "hub")) with cd(cache_dir): fetcher = PROTOCOLS[protocol] repo_dir = fetcher.fetch(git_host, repo_info, use_cache, commit) @@ -303,14 +305,14 @@ def _check_dependencies(module: types.ModuleType) -> None: dependencies = getattr(module, HUBDEPENDENCY) if not dependencies: return - missing_deps = [m for m in dependencies if importlib.util.find_spec(m)] + missing_deps = [m for m in dependencies if importlib.util.find_spec(m)] # type: ignore if len(missing_deps): raise RuntimeError("Missing dependencies: {}".format(", ".join(missing_deps))) def load_module(name: str, path: str) -> types.ModuleType: - spec = importlib.util.spec_from_file_location(name, path) - module = importlib.util.module_from_spec(spec) + spec = importlib.util.spec_from_file_location(name, path) # type: ignore + module = importlib.util.module_from_spec(spec) # type: ignore spec.loader.exec_module(module) return module @@ -323,7 +325,7 @@ def _init_hub( commit: Optional[str] = None, protocol: str = DEFAULT_PROTOCOL, ): - cache_dir = os.path.expanduser(os.path.join(_cache_dir, "hub")) + cache_dir = os.path.expanduser(os.path.join(workspace.cache_dir, "hub")) os.makedirs(cache_dir, exist_ok=True) absolute_repo_dir = _get_repo( git_host, repo_info, use_cache=use_cache, commit=commit, protocol=protocol @@ -412,7 +414,7 @@ def pretrained_model_func(pretrained=False, **kwargs): digest = sha256.hexdigest()[:6] filename = digest + "_" + filename - cached_file = os.path.join(_cache_dir, filename) + cached_file = os.path.join(workspace.cache_dir, filename) download_from_url(self.url, cached_file) self.load_func(cached_file, model) return model diff --git a/excore/plugins/path_manager.py b/excore/plugins/path_manager.py index d944a3e..8b9733f 100644 --- a/excore/plugins/path_manager.py +++ b/excore/plugins/path_manager.py @@ -1,8 +1,12 @@ +from __future__ import annotations + import shutil import time from dataclasses import is_dataclass from pathlib import Path -from typing import Callable, Dict, List, Optional, Protocol, Union +from typing import Callable, Protocol + +from typing_extensions import Self from ..engine.logging import logger @@ -10,9 +14,9 @@ class DataclassProtocol(Protocol): - __dataclass_fields__: Dict - __dataclass_params__: Dict - __post_init__: Optional[Callable] + __dataclass_fields__: dict + __dataclass_params__: dict + __post_init__: Callable | None class PathManager: @@ -59,9 +63,9 @@ def __init__( self, /, base_path: str, - sub_folders: Union[List[str], DataclassProtocol], + sub_folders: list[str] | DataclassProtocol, config_name: str, - instance_name: Optional[str] = None, + instance_name: str | None = None, *, remove_if_fail: bool = False, sub_folder_exist_ok: bool = False, @@ -76,9 +80,9 @@ def __init__( self.sub_exist_ok = sub_folder_exist_ok self.config_first = config_name_first self.return_str = return_str - self._info = {} + self._info: dict[str, Path] = {} - def _get_sub_folders(self, sub_folders): + def _get_sub_folders(self, sub_folders) -> None: if not isinstance(sub_folders, list): if not is_dataclass(sub_folders): raise TypeError("Only Support dataclass or list of str") @@ -101,7 +105,7 @@ def mkdir(self) -> None: sub.mkdir(parents=True, exist_ok=self.sub_exist_ok) self._info[str(f)] = sub - def get(self, name: str) -> Union[Path, str]: + def get(self, name: str) -> Path | str | None: """ Retrieve the path for a specific sub-folder by name. """ @@ -131,11 +135,11 @@ def remove_all(self) -> None: logger.info(f"Remove sub_folders {f}") shutil.rmtree(str(f)) - def __enter__(self): + def __enter__(self) -> Self: self.init() return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type, exc_value, traceback) -> bool: if exc_type is not None: self.final() return False diff --git a/pyproject.toml b/pyproject.toml index 04282ff..d524b84 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "excore" -version = "0.1.1beta" +version = "0.1.1beta1" description = "Build your own development toolkit efficiently." authors = ["Asthestarsfalll <1186454801@qq.com>"] license = "MIT" @@ -30,10 +30,10 @@ requests = "2.28.1" toml = "0.10.2" tabulate = "*" tqdm = "*" -pynvml = "*" rich = "*" astor = "*" -typer = { extras = ["all"], version = "^0.9.0" } +typer = "^0.13.1" +filelock = "^3.15.4" [tool.poetry.group.dev.dependencies] black = "23.1.0" @@ -43,6 +43,10 @@ pytest = "^7.4.4" pytest-cov = "^4.1.0" torch = { version = "1.13.1", source = "torch" } torchvision = { version = "0.14.1", source = "torch" } +mypy = "^1.13.0" +types-toml = "^0.10.8.20240310" +types-tqdm = "^4.67.0.20241119" +types-tabulate = "^0.9.0.20240106" [[tool.poetry.source]] name = "torch" diff --git a/tests/init.py b/tests/init.py index 9c79a1b..1e09e31 100644 --- a/tests/init.py +++ b/tests/init.py @@ -1,3 +1,4 @@ +import os import subprocess import toml @@ -18,6 +19,8 @@ def excute(command: str, inputs=None): command.split(" "), capture_output=True, ) + print(result.stdout) + print(result.stderr) assert result.returncode == 0, result.stderr @@ -40,6 +43,9 @@ def init(): toml.dump(cfg, f) excute("excore update") excute("excore auto-register") + import excore + + assert os.path.exists(os.path.join(excore.workspace.cache_base_dir, "tests")) if __name__ == "__main__": diff --git a/tests/source_code/__init__.py b/tests/source_code/__init__.py index 0cb1414..2ad986b 100644 --- a/tests/source_code/__init__.py +++ b/tests/source_code/__init__.py @@ -15,6 +15,15 @@ OPTIM = Registry("Optimizer") TRANSFORM = Registry("Transform") MODULE = Registry("module") - MODULE.register_module(time) MODULE.register_module(torch) +MODEL = Registry("Model") +DATA = Registry("Data") +BACKBONE = Registry("Backbone") +HEAD = Registry("Head") +HOOK = Registry("Hook") +LOSS = Registry("Loss") +LRSCHE = Registry("LRSche") +OPTIMIZER = Registry("Optimizer") +TRANSFORM = Registry("Transform") +MODULE = Registry("module") diff --git a/tests/test_a_registry.py b/tests/test_a_registry.py index a4999ea..d2c6f61 100644 --- a/tests/test_a_registry.py +++ b/tests/test_a_registry.py @@ -2,52 +2,46 @@ import pytest -from excore import Registry, load_registries +from excore import Registry, load_registries, _enable_excore_debug load_registries() -Registry.unlock_register() +_enable_excore_debug() +import source_code as S -def test_print(): - import source_code as S +def test_print(): print(S.MODEL) def test_find(): - import source_code as S - tar, base = S.BACKBONE.find("ResNet") assert base == "Backbone" assert tar.split(".")[-1] == "ResNet" def test_register_module(): + Registry.unlock_register() reg = Registry("__test") reg.register_module(time) assert reg["time"] == "time" def test_global(): - import source_code as S - g = Registry.make_global() assert g["time"] == "time" assert g["ResNet"] == "torchvision.models.resnet.ResNet" with pytest.raises(ValueError): S.MODEL.register_module(time) assert id(S.MODEL) == id(Registry.get_registry("Model")) + Registry.lock_register() def test_module_table(): - import source_code as S - print(S.MODEL.module_table()) print(S.MODEL.module_table("*resnet*")) def test_id(): - import source_code as S - assert Registry.get_registry("Head") == S.HEAD diff --git a/tests/test_config.py b/tests/test_config.py index 569da58..86d1875 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -247,6 +247,3 @@ def test_dict_param(self): TestClass, VGG, ] - - -TestConfig().test_scratchpads() diff --git a/tests/test_config_extention.py b/tests/test_config_extention.py index 3715f12..70b3ebd 100644 --- a/tests/test_config_extention.py +++ b/tests/test_config_extention.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union from excore import Registry from excore.config._json_schema import parse_registry @@ -15,6 +15,9 @@ class Tmp2: pass +t = Tmp() + + @R.register() class A: def __init__( @@ -25,6 +28,7 @@ def __init__( d: Dict, d1: dict, obj: Tmp, + c: Callable, uni1: Union[int, float], uni2: Union[int, str], uni3: Union[Tmp, Tmp2], @@ -40,6 +44,7 @@ def __init__( default_d={}, # noqa: B006 # pylint: disable=W0102 default_tuple=(0, "", 0.0), default_list=[0, 1], # noqa: B006 + e=t, *args, **kwargs, ): @@ -67,6 +72,8 @@ def test_type_parsing(): _assert(properties, "d", "object") _assert(properties, "d1", "object") _assert(properties, "obj", "string") + _assert(properties, "c", "string") + _assert(properties, "e", "number") _assert(properties, "uni1", "") _assert(properties, "uni2", "") _assert(properties, "uni3", "") diff --git a/tests/test_config_extention_310.py b/tests/test_config_extention_310.py index c61f428..ee8640c 100644 --- a/tests/test_config_extention_310.py +++ b/tests/test_config_extention_310.py @@ -1,4 +1,5 @@ import sys +from typing import Callable import pytest from test_config_extention import _assert @@ -17,6 +18,8 @@ class Tmp2: pass +t = Tmp() + if sys.version_info >= (3, 10, 0): @S.register() @@ -28,6 +31,7 @@ def __init__( f: float, d: dict, obj: Tmp, + c: Callable, uni1: int | float, uni2: int | str, uni3: Tmp | Tmp2, @@ -42,6 +46,9 @@ def __init__( default_d={}, # noqa: B006 # pylint: disable=W0102 default_tuple=(0, "", 0.0), default_list=[0, 1], # noqa: B006 + e=t, + *args, + **kwargs, ): pass @@ -58,6 +65,8 @@ def test_type_parsing(): _assert(properties, "f", "number") _assert(properties, "d", "object") _assert(properties, "obj", "string") + _assert(properties, "c", "string") + _assert(properties, "e", "number") _assert(properties, "uni1", "") _assert(properties, "uni2", "") _assert(properties, "uni3", "") @@ -72,4 +81,6 @@ def test_type_parsing(): _assert(properties, "default_d", "object") _assert(properties, "default_tuple", "array") _assert(properties, "default_list", "array", "number") + _assert(properties, "args", "array") + _assert(properties, "kwargs", "object") S.clear() diff --git a/tests/test_z_cli.py b/tests/test_z_cli.py index 5b8c525..6e21d43 100644 --- a/tests/test_z_cli.py +++ b/tests/test_z_cli.py @@ -8,6 +8,7 @@ def test_init_force(): def test_generate_registries(): + # FIXME(typer): Argument excute("excore generate-registries temp") assert os.path.exists("./source_code/temp.py") from source_code import temp # noqa: F401 @@ -26,6 +27,7 @@ def test_primary(): def test_typehints(): + # FIXME(typer): Argument excute( "excore generate-typehints temp_typing --class-name " "TypedWrapper --info-class-name Info --config ./configs/launch/test_optim.toml" @@ -42,3 +44,5 @@ def test_quote(): excute("excore quote ./configs/lrsche") assert os.path.exists("./configs/lrsche/lrsche_overrode.toml") assert os.path.exists("./configs/lrsche/lrsche_error_overrode.toml") + os.remove("./configs/lrsche/lrsche_overrode.toml") + os.remove("./configs/lrsche/lrsche_error_overrode.toml")