From f50eccd3aa22369022f66574784644f105116194 Mon Sep 17 00:00:00 2001
From: Asthestarsfalll <72954905+Asthestarsfalll@users.noreply.github.com>
Date: Mon, 25 Nov 2024 19:12:43 +0800
Subject: [PATCH] :sparkles::recycle: Enhance type hints and others (#41)
* stage1
* stage2
* Simplify _is_special
* Fix _get_params
* refine error message
* fix something
* fix
---
.github/workflows/codestyle-check.yml | 8 +-
.github/workflows/mypy.yaml | 35 +++++
.github/workflows/tests.yaml | 7 +-
excore/__init__.py | 7 +-
excore/_constants.py | 86 ++++++++-----
excore/_exceptions.py | 4 +
excore/_misc.py | 38 ++----
excore/cli/_cache.py | 36 +++---
excore/cli/_extention.py | 40 +++---
excore/cli/_registry.py | 90 ++++++-------
excore/cli/_workspace.py | 32 ++---
excore/config/_json_schema.py | 92 +++++++------
excore/config/action.py | 12 +-
excore/config/config.py | 16 +--
excore/config/lazy_config.py | 31 +++--
excore/config/model.py | 149 +++++++++++----------
excore/config/parse.py | 123 ++++++++++--------
excore/engine/hook.py | 36 +++---
excore/engine/logging.py | 60 ++++++---
excore/engine/registry.py | 179 ++++++++++++++++----------
excore/plugins/hub.py | 22 ++--
excore/plugins/path_manager.py | 26 ++--
pyproject.toml | 10 +-
tests/init.py | 6 +
tests/source_code/__init__.py | 11 +-
tests/test_a_registry.py | 18 +--
tests/test_config.py | 3 -
tests/test_config_extention.py | 9 +-
tests/test_config_extention_310.py | 11 ++
tests/test_z_cli.py | 4 +
30 files changed, 707 insertions(+), 494 deletions(-)
create mode 100644 .github/workflows/mypy.yaml
diff --git a/.github/workflows/codestyle-check.yml b/.github/workflows/codestyle-check.yml
index 2720b14..7187e75 100644
--- a/.github/workflows/codestyle-check.yml
+++ b/.github/workflows/codestyle-check.yml
@@ -9,7 +9,11 @@ on:
jobs:
code-style-check:
runs-on: ubuntu-latest
- name: CodeStyle Check (Python 3.8)
+ strategy:
+ matrix:
+ python-version: [3.8, 3.9, '3.10']
+
+ name: CodeStyle Check ${{ matrix.python-version }}
steps:
- name: Checkout
uses: actions/checkout@v3
@@ -17,7 +21,7 @@ jobs:
- name: Install python
uses: actions/setup-python@v4
with:
- python-version: 3.8
+ python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml
new file mode 100644
index 0000000..3aa6e62
--- /dev/null
+++ b/.github/workflows/mypy.yaml
@@ -0,0 +1,35 @@
+name: MYPY Check
+
+on:
+ push:
+ branches: [main]
+ pull_request:
+ branches: [main]
+
+jobs:
+ test:
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ python-version: [3.8, 3.9, '3.10']
+
+ steps:
+ - uses: actions/checkout@v2
+ with:
+ fetch-depth: 1
+
+ - name: Set up Python ${{ matrix.python-version }}
+ uses: actions/setup-python@v2
+ with:
+ python-version: ${{ matrix.python-version }}
+
+ - name: Install Poetry
+ uses: snok/install-poetry@v1
+ with:
+ version: 1.6.1
+
+ - name: Install Dependencies
+ run: poetry install --with dev
+
+ - name: Mypy Check
+ run: poetry run mypy excore
diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml
index a6b9a79..fc788a5 100644
--- a/.github/workflows/tests.yaml
+++ b/.github/workflows/tests.yaml
@@ -31,12 +31,17 @@ jobs:
- name: Install Dependencies
run: poetry install --with dev
- - name: Test with pytest
+ - name: Initialize ExCore
run: |
cd ./tests
export EXCORE_DEBUG=1
poetry run python init.py
+
+ - name: Test with pytest
+ run: |
+ cd ./tests
poetry run pytest --cov=../excore
+ poetry run python init.py
poetry run pytest test_config.py
poetry run pytest test_config.py
poetry run pytest test_config.py
diff --git a/excore/__init__.py b/excore/__init__.py
index ff27640..07a9b44 100644
--- a/excore/__init__.py
+++ b/excore/__init__.py
@@ -3,7 +3,7 @@
from rich.traceback import install
from . import config, plugins
-from ._constants import __author__, __version__, _load_workspace_config, _workspace_cfg
+from ._constants import __author__, __version__, workspace
from .config.action import DictAction
from .config.config import build_all, load
from .config.parse import set_primary_fields
@@ -43,7 +43,6 @@
install()
init_logger()
-_load_workspace_config()
-set_primary_fields(_workspace_cfg)
+set_primary_fields(workspace)
_enable_excore_debug()
-sys.path.append(_workspace_cfg["base_dir"])
+sys.path.append(workspace.base_dir)
diff --git a/excore/_constants.py b/excore/_constants.py
index d14333b..3e44f02 100644
--- a/excore/_constants.py
+++ b/excore/_constants.py
@@ -1,60 +1,80 @@
+from __future__ import annotations
+
import os
import os.path as osp
+from dataclasses import dataclass, field
+from typing import Any
+
+import toml
from .engine.logging import logger
__author__ = "Asthestarsfalll"
__version__ = "0.1.1beta"
-_cache_base_dir = osp.expanduser("~/.cache/excore/")
_workspace_config_file = "./.excore.toml"
_registry_cache_file = "registry_cache.pkl"
_json_schema_file = "excore_schema.json"
_class_mapping_file = "class_mapping.json"
-def _load_workspace_config():
- if osp.exists(_workspace_config_file):
- _workspace_cfg.update(toml.load(_workspace_config_file))
- logger.ex("load `.excore.toml`")
- else:
- logger.warning("Please use `excore init` in your command line first")
+@dataclass
+class _WorkspaceConfig:
+ name: str = field(default="")
+ src_dir: str = field(default="")
+ base_dir: str = field(default="")
+ cache_base_dir: str = field(default=osp.expanduser("~/.cache/excore/"))
+ cache_dir: str = field(default="")
+ registry_cache_file: str = field(default="")
+ json_schema_file: str = field(default="")
+ class_mapping_file: str = field(default="")
+ registries: list[str] = field(default_factory=list)
+ primary_fields: list[str] = field(default_factory=list)
+ primary_to_registry: dict[str, str] = field(default_factory=dict)
+ json_schema_fields: dict[str, str | list[str]] = field(default_factory=dict)
+ props: dict[Any, Any] = field(default_factory=dict)
+ @property
+ def base_name(self):
+ return osp.split(self.cache_dir)[-1]
-def _update_name(base_name):
- name = base_name
+ def __post_init__(self) -> None:
+ if not osp.exists(_workspace_config_file):
+ self.base_dir = os.getcwd()
+ self.cache_dir = self._get_cache_dir()
+ self.registry_cache_file = osp.join(self.cache_dir, _registry_cache_file)
+ self.json_schema_file = osp.join(self.cache_dir, _json_schema_file)
+ self.class_mapping_file = osp.join(self.cache_dir, _class_mapping_file)
+ logger.warning("Please use `excore init` in your command line first")
+ else:
+ self.update(toml.load(_workspace_config_file))
- suffix = 1
- while osp.exists(osp.join(_cache_base_dir, name)):
- name = f"{_base_name}_{suffix}"
- suffix += 1
+ def _get_cache_dir(self) -> str:
+ base_name = osp.basename(osp.normpath(os.getcwd()))
+ base_name = self._update_name(base_name)
+ return osp.join(self.cache_base_dir, base_name)
- return name
+ def _update_name(self, base_name: str) -> str:
+ name = base_name
+ suffix = 1
+ while osp.exists(osp.join(self.cache_base_dir, name)):
+ name = f"{base_name}_{suffix}"
+ suffix += 1
-if not osp.exists(_workspace_config_file):
- _base_name = osp.basename(osp.normpath(os.getcwd()))
- _base_name = _update_name(_base_name)
-else:
- import toml # pylint: disable=import-outside-toplevel
+ return name
- cfg = toml.load(_workspace_config_file)
- _base_name = cfg["name"]
+ def update(self, _cfg: dict[Any, Any]) -> None:
+ self.__dict__.update(_cfg)
-_cache_dir = osp.join(_cache_base_dir, _base_name)
+ def dump(self, path: str) -> None:
+ with open(path, "w") as f:
+ cfg = self.__dict__
+ cfg.pop("base_dir", None)
+ toml.dump(cfg, f)
-# TODO: Use a data class to store this
-_workspace_cfg = dict(
- name="",
- src_dir="",
- base_dir=os.getcwd(),
- registries=[],
- primary_fields=[],
- primary_to_registry={},
- json_schema_fields={},
- props={},
-)
+workspace = _WorkspaceConfig()
LOGO = r"""
▓█████ ▒██ ██▒ ▄████▄ ▒█████ ██▀███ ▓█████
▓█ ▀ ▒▒ █ █ ▒░▒██▀ ▀█ ▒██▒ ██▒▓██ ▒ ██▒▓█ ▀
diff --git a/excore/_exceptions.py b/excore/_exceptions.py
index a5f2101..7ce6cb8 100644
--- a/excore/_exceptions.py
+++ b/excore/_exceptions.py
@@ -56,3 +56,7 @@ class HookManagerBuildError(BaseException):
class HookBuildError(BaseException):
pass
+
+
+class AnnotationsFutureError(Exception):
+ pass
diff --git a/excore/_misc.py b/excore/_misc.py
index 03812ac..0fb0eb5 100644
--- a/excore/_misc.py
+++ b/excore/_misc.py
@@ -1,14 +1,15 @@
+from __future__ import annotations
+
import functools
-import threading
-import time
+from typing import Any, Callable, Sequence
from tabulate import tabulate
class CacheOut:
- def __call__(self, func):
+ def __call__(self, func: Callable[..., Any]):
@functools.wraps(func)
- def _cache(self):
+ def _cache(self) -> Any:
if not hasattr(self, "cached_elem"):
cached_elem = func(self)
if cached_elem != self:
@@ -19,27 +20,14 @@ def _cache(self):
return _cache
-class FileLock:
- def __init__(self, file_path, timeout=15):
- self.file_path = file_path
- self.timeout = timeout
- self.lock = threading.Lock()
-
- def __enter__(self):
- start_time = time.time()
- while not self.lock.acquire(False):
- if time.time() - start_time >= self.timeout:
- raise TimeoutError("Failed to acquire lock on file")
- time.sleep(0.1)
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- self.lock.release()
-
-
-def _create_table(header, contents, split=True, prefix="\n", **tabel_kwargs):
- if split:
- contents = [(i,) for i in contents]
+def _create_table(
+ header: str | list[str] | tuple[str, ...] | None,
+ contents: Sequence[str] | Sequence[Sequence[str]],
+ prefix: str = "\n",
+ **tabel_kwargs: Any,
+) -> str:
+ if len(contents) > 0 and isinstance(contents[0], str):
+ contents = [(i,) for i in contents] # type: ignore
if header is None:
header = ()
if not isinstance(header, (list, tuple)):
diff --git a/excore/cli/_cache.py b/excore/cli/_cache.py
index 1516e60..743a95b 100644
--- a/excore/cli/_cache.py
+++ b/excore/cli/_cache.py
@@ -1,14 +1,17 @@
+from __future__ import annotations
+
import os
import typer
-from .._constants import _base_name, _cache_base_dir, _cache_dir
+from excore import workspace
+
from .._misc import _create_table
from ..engine.logging import logger
from ._app import app
-def _clear_cache(cache_dir):
+def _clear_cache(cache_dir: str) -> None:
if os.path.exists(cache_dir):
import shutil # pylint: disable=import-outside-toplevel
@@ -19,38 +22,41 @@ def _clear_cache(cache_dir):
@app.command()
-def clear_cache():
+def clear_cache() -> None:
"""
Remove the cache folder which belongs to current workspace.
"""
- if not typer.confirm(f"Are you sure you want to clear cache of {_base_name}?"):
+ if not typer.confirm(
+ f"Are you sure you want to clear cache of {workspace.name}?"
+ f" Cache dir is {workspace.cache_dir}."
+ ):
return
-
- target = os.path.join(_cache_dir, _base_name)
- _clear_cache(target)
+ _clear_cache(workspace.cache_dir)
@app.command()
-def clear_all_cache():
+def clear_all_cache() -> None:
"""
Remove the whole cache folder.
"""
+ if not os.path.exists(workspace.cache_base_dir):
+ logger.warning("Cache dir {} does not exist", workspace.cache_base_dir)
+ return
+ print(_create_table("Cache Names", os.listdir(workspace.cache_base_dir)))
if not typer.confirm("Are you sure you want to clear all cache?"):
return
- _clear_cache(_cache_base_dir)
+ _clear_cache(workspace.cache_base_dir)
@app.command()
-def cache_list():
+def cache_list() -> None:
"""
Show cache folders.
"""
- tabel = _create_table("NAMES", os.listdir(_cache_base_dir))
+ tabel = _create_table("Cache Names", os.listdir(workspace.cache_base_dir))
logger.info(tabel)
@app.command()
-def cache_dir():
- # if not os.path.exists(_workspace_config_file):
- # raise RuntimeError("Not in ExCore project")
- print(_cache_dir)
+def cache_dir() -> None:
+ print(workspace.cache_dir)
diff --git a/excore/cli/_extention.py b/excore/cli/_extention.py
index ac7253a..f6ae725 100644
--- a/excore/cli/_extention.py
+++ b/excore/cli/_extention.py
@@ -1,42 +1,44 @@
+from __future__ import annotations
+
import os
import sys
-from typing import Optional
from typer import Argument as CArg
from typer import Option as COp
from typing_extensions import Annotated
-from .._constants import _cache_base_dir, _workspace_cfg
+from excore import workspace
+
from ..config._json_schema import _generate_json_schema_and_class_mapping, _generate_taplo_config
from ..engine.logging import logger
from ._app import app
@app.command()
-def config_extention():
+def config_extention() -> None:
"""
Generate json_schema for onfig completion and class_mapping for class navigation.
"""
- target_dir = os.path.join(_cache_base_dir, _workspace_cfg["name"])
- os.makedirs(target_dir, exist_ok=True)
- _generate_taplo_config(target_dir)
- if not _workspace_cfg["json_schema_fields"]:
+ _generate_taplo_config()
+ if not workspace.json_schema_fields:
logger.warning("You should set json_schema_fields first")
sys.exit(0)
- _generate_json_schema_and_class_mapping(_workspace_cfg["json_schema_fields"])
+ _generate_json_schema_and_class_mapping(workspace.json_schema_fields)
-def _generate_typehints(entry: str, class_name: str, info_class_name: str, config: Optional[str]):
- if not _workspace_cfg["primary_fields"]:
+def _generate_typehints(
+ entry: str, class_name: str, info_class_name: str, config: str = ""
+) -> None:
+ if not workspace.primary_fields:
logger.critical("Please initialize the workspace first.")
return
- target_file = os.path.join(_workspace_cfg["src_dir"], entry + ".py")
+ target_file = os.path.join(workspace.src_dir, entry + ".py")
logger.info(f"Generating module type hints in {target_file}.")
with open(target_file, "w", encoding="UTF-8") as f:
f.write(f"from typing import Union{', Any' if config else ''}\n")
f.write("from excore.config.model import ModuleNode, ModuleWrapper\n\n")
f.write(f"class {class_name}:\n")
- for i in _workspace_cfg["primary_fields"]:
+ for i in workspace.primary_fields:
f.write(f" {i}: Union[ModuleNode, ModuleWrapper]\n")
logger.info(f"Generating isolated objects type hints in {target_file}.")
if config:
@@ -53,21 +55,19 @@ def _generate_typehints(entry: str, class_name: str, info_class_name: str, confi
@app.command()
def generate_typehints(
- entry: Annotated[str, CArg(help="The file to generate.")] = "module_types",
+ entry: str = CArg(default="module_types", help="The file to generate."),
class_name: Annotated[str, COp(help="The class name of type hints.")] = "TypedModules",
info_class_name: Annotated[str, COp(help="The class name of run_info.")] = "RunInfo",
- config: Annotated[
- Optional[str], COp(help="Used generate type hints for isolated objects.")
- ] = None,
-):
+ config: Annotated[str, COp(help="Used generate type hints for isolated objects.")] = "",
+) -> None:
"""
Generate type hints for modules and isolated objects.
"""
_generate_typehints(entry, class_name, info_class_name, config)
-def _quote(config: str, override: bool):
- config_paths = []
+def _quote(config: str, override: bool) -> None:
+ config_paths: list[str] = []
def _get_path(path, paths):
if not os.path.isdir(path):
@@ -94,7 +94,7 @@ def _get_path(path, paths):
def quote(
config: Annotated[str, CArg(help="Target config file or folder.")],
override: Annotated[bool, COp(help="Whether to override configs.")] = False,
-):
+) -> None:
"""
Quote all special keys in target config files.
"""
diff --git a/excore/cli/_registry.py b/excore/cli/_registry.py
index 2c5f6e5..6719921 100644
--- a/excore/cli/_registry.py
+++ b/excore/cli/_registry.py
@@ -1,22 +1,24 @@
+from __future__ import annotations
+
import ast
import importlib
import os
import os.path as osp
import sys
+from typing import Any
-import astor
+import astor # type: ignore
import typer
from typer import Argument as CArg
-from typing_extensions import Annotated
-from .._constants import _cache_base_dir, _workspace_cfg, _workspace_config_file
+from .._constants import _workspace_config_file, workspace
from .._misc import _create_table
from ..engine.logging import logger
from ..engine.registry import Registry
from ._app import app
-def _has_import_excore(node):
+def _has_import_excore(node) -> bool:
if isinstance(node, ast.Module):
for child in node.body:
f = _has_import_excore(child)
@@ -33,18 +35,18 @@ def _has_import_excore(node):
def _build_ast(name: str) -> ast.Assign:
- targets = [ast.Name(name.upper(), ast.Store)]
- func = ast.Name("Registry", ast.Load)
+ targets = [ast.Name(name.upper(), ast.Store)] # type: ignore
+ func = ast.Name("Registry", ast.Load) # type: ignore
args = [ast.Constant(name)]
- value = ast.Call(func, args, [])
- return ast.Assign(targets, value)
+ value = ast.Call(func, args, []) # type: ignore
+ return ast.Assign(targets, value) # type: ignore
def _generate_registries(entry="__init__"):
- if not _workspace_cfg["primary_fields"]:
+ if not workspace.primary_fields:
return
logger.info("Generating Registry definition code.")
- target_file = osp.join(_workspace_cfg["src_dir"], entry + ".py")
+ target_file = osp.join(workspace.src_dir, entry + ".py")
if not osp.exists(target_file):
with open(target_file, "w", encoding="UTF-8") as f:
@@ -59,7 +61,7 @@ def _generate_registries(entry="__init__"):
name = [ast.alias("Registry", None)]
source_code.body.insert(0, ast.ImportFrom("excore", name, 0))
- for name in _get_registries(_workspace_cfg["registries"]):
+ for name in _get_registries(workspace.registries):
if name.startswith("*"):
name = name[1:]
source_code.body.append(_build_ast(name))
@@ -69,7 +71,7 @@ def _generate_registries(entry="__init__"):
logger.success("Generate Registry definition in {} according to `primary_fields`", target_file)
-def _detect_assign(node, definition):
+def _detect_assign(node: ast.AST, definition: list) -> None:
if isinstance(node, ast.Module):
for child in node.body:
_detect_assign(child, definition)
@@ -79,19 +81,19 @@ def _detect_assign(node, definition):
and hasattr(node.value.func, "id")
and node.value.func.id == "Registry"
):
- definition.append(node.value.args[0].value)
+ definition.append(node.value.args[0].value) # type: ignore
def _detect_registy_difinition() -> bool:
- target_file = osp.join(_workspace_cfg["src_dir"], "__init__.py")
+ target_file = osp.join(workspace.src_dir, "__init__.py")
logger.info("Detect Registry definition in {}", target_file)
- definition = []
+ definition: list[Any] = []
with open(target_file, encoding="UTF-8") as f:
source_code = ast.parse(f.read())
_detect_assign(source_code, definition)
if len(definition) > 0:
logger.info("Find Registry definition: {}", definition)
- _workspace_cfg["registries"] = definition
+ workspace.registries = definition
return True
return False
@@ -104,7 +106,9 @@ def _format(reg_and_fields: str) -> str:
return splits[0] + ": " + ", ".join(fields)
-def _parse_registries(reg_and_fields):
+def _parse_registries(
+ reg_and_fields: list[str],
+) -> tuple[list[str], dict[Any, Any], dict[Any, Any]]:
fields = [i[1:].split(":") for i in reg_and_fields if i.startswith("*")]
targets = []
rev = {}
@@ -121,16 +125,15 @@ def _parse_registries(reg_and_fields):
rev[j] = i[0]
targets.extend(tar)
json_schema["isolated_fields"] = isolated_fields
- return set(targets), rev, json_schema
+ return list(set(targets)), rev, json_schema
-def _get_registries(reg_and_fields):
+def _get_registries(reg_and_fields) -> list[str]:
return [i.split(":")[0] for i in reg_and_fields]
-def _update(is_init=True, entry="__init__"):
- target_dir = osp.join(_cache_base_dir, _workspace_cfg["name"])
- os.makedirs(target_dir, exist_ok=True)
+def _update(is_init: bool = True, entry: str = "__init__") -> None:
+ os.makedirs(workspace.cache_dir, exist_ok=True)
logger.success("Generate `.taplo.toml`")
if is_init:
if not _detect_registy_difinition():
@@ -142,14 +145,14 @@ def _update(is_init=True, entry="__init__"):
regs.append(_format(inp))
else:
break
- _workspace_cfg["registries"] = regs
+ workspace.registries = regs
else:
logger.imp("You can define fields later.")
(
- _workspace_cfg["primary_fields"],
- _workspace_cfg["primary_to_registry"],
- _workspace_cfg["json_schema_fields"],
- ) = _parse_registries(_workspace_cfg["registries"])
+ workspace.primary_fields,
+ workspace.primary_to_registry,
+ workspace.json_schema_fields,
+ ) = _parse_registries(workspace.registries)
_generate_registries(entry)
else:
logger.imp(
@@ -158,20 +161,20 @@ def _update(is_init=True, entry="__init__"):
)
else:
(
- _workspace_cfg["primary_fields"],
- _workspace_cfg["primary_to_registry"],
- _workspace_cfg["json_schema_fields"],
- ) = _parse_registries([_format(i) for i in _workspace_cfg["registries"]])
+ workspace.primary_fields,
+ workspace.primary_to_registry,
+ workspace.json_schema_fields,
+ ) = _parse_registries([_format(i) for i in workspace.registries])
logger.success("Update primary_fields")
-def _get_default_module_name(target_dir):
+def _get_default_module_name(target_dir: str) -> str:
assert os.path.isdir(target_dir)
full_path = os.path.abspath(target_dir)
return full_path.split(os.sep)[-1]
-def _auto_register(target_dir, module_name):
+def _auto_register(target_dir: str, module_name: str) -> None:
for file_name in os.listdir(target_dir):
full_path = os.path.join(target_dir, file_name)
if os.path.isdir(full_path):
@@ -190,44 +193,43 @@ def _auto_register(target_dir, module_name):
@app.command()
-def auto_register():
+def auto_register() -> None:
"""
Automatically import all modules in `src_dir` and register all modules, then dump to files.
"""
if not os.path.exists(_workspace_config_file):
logger.critical("Please run `excore init` in your command line first!")
sys.exit(0)
- target_dir = osp.abspath(_workspace_cfg["src_dir"])
+ target_dir = osp.abspath(workspace.src_dir)
module_name = _get_default_module_name(target_dir)
- sys.path.append(os.getcwd())
_auto_register(target_dir, module_name)
Registry.dump()
@app.command()
-def primary_fields():
+def primary_fields() -> None:
"""
Show primary_fields.
"""
- table = _create_table("FIELDS", _parse_registries(_workspace_cfg["registries"])[0])
+ table = _create_table("FIELDS", _parse_registries(workspace.registries)[0])
logger.info(table)
@app.command()
-def registries():
+def registries() -> None:
"""
Show registries.
"""
- table = _create_table("Registry", [_format(i) for i in _workspace_cfg["registries"]])
+ table = _create_table("Registry", [_format(i) for i in workspace.registries])
logger.info(table)
@app.command()
def generate_registries(
- entry: Annotated[
- str, CArg(help="Used for detect or generate Registry definition code")
- ] = "__init__",
-):
+ entry: str = CArg(
+ default="__init__", help="Used for detect or generate Registry definition code"
+ ),
+) -> None:
"""
Generate registries definition code according to workspace config.
"""
diff --git a/excore/cli/_workspace.py b/excore/cli/_workspace.py
index 21d465d..2be4b6a 100644
--- a/excore/cli/_workspace.py
+++ b/excore/cli/_workspace.py
@@ -1,34 +1,24 @@
import os
import os.path as osp
-import toml
import typer
from typer import Argument as CArg
from typer import Option as COp
from typing_extensions import Annotated
-from .._constants import (
- LOGO,
- _base_name,
- _cache_base_dir,
- _cache_dir,
- _workspace_cfg,
- _workspace_config_file,
-)
+from .._constants import LOGO, _workspace_config_file, workspace
from ..engine.logging import logger
from ._app import app
from ._registry import _update
-def _dump_workspace_config():
+def _dump_workspace_config() -> None:
logger.info("Dump config to {}", _workspace_config_file)
- _workspace_cfg.pop("base_dir", None)
- with open(_workspace_config_file, "w", encoding="UTF-8") as f:
- toml.dump(_workspace_cfg, f)
+ workspace.dump(_workspace_config_file)
@app.command()
-def update():
+def update() -> None:
"""
Update workspace config file.
"""
@@ -42,11 +32,11 @@ def init(
entry: Annotated[
str, CArg(help="Used for detect or generate Registry definition code")
] = "__init__",
-):
+) -> None:
"""
Initialize workspace and generate a config file.
"""
- if osp.exists(_cache_dir) and not force:
+ if osp.exists(_workspace_config_file) and not force:
logger.warning("excore.toml already existed!")
return
cwd = os.getcwd()
@@ -56,17 +46,17 @@ def init(
colors=True,
)
logger.opt(colors=True).info(f"It will be generated in {cwd}\n")
- logger.opt(colors=True).info(f"WorkSpace Name [{_base_name}]:")
- name = typer.prompt("", default=_base_name, show_default=False, prompt_suffix="")
- if not force and os.path.exists(os.path.join(_cache_base_dir, _base_name)):
+ logger.opt(colors=True).info(f"WorkSpace Name [{workspace.base_name}]:")
+ name = typer.prompt("", default=workspace.base_name, show_default=False, prompt_suffix="")
+ if not force and os.path.exists(workspace.cache_dir):
logger.warning(f"name {name} already existed!")
return
logger.opt(colors=True).info("Source Code Directory(relative path):")
src_dir = typer.prompt("", prompt_suffix="")
- _workspace_cfg["name"] = name
- _workspace_cfg["src_dir"] = src_dir
+ workspace.name = name
+ workspace.src_dir = src_dir
_update(True, entry)
_dump_workspace_config()
diff --git a/excore/config/_json_schema.py b/excore/config/_json_schema.py
index 5408b85..247535a 100644
--- a/excore/config/_json_schema.py
+++ b/excore/config/_json_schema.py
@@ -2,31 +2,45 @@
import inspect
import json
-import os
-import os.path as osp
import sys
from inspect import Parameter, _empty, _ParameterKind, isclass
from types import ModuleType
-from typing import Any, Dict, Sequence, Union, get_args, get_origin
+from typing import TYPE_CHECKING, Any, Callable, Dict, Sequence, Union, get_args, get_origin
import toml
-from .._constants import _cache_dir, _class_mapping_file, _json_schema_file
+from excore import workspace
+
+from .._exceptions import AnnotationsFutureError
from ..engine.hook import ConfigArgumentHook
from ..engine.logging import logger
from ..engine.registry import Registry, load_registries
from .model import _str_to_target
if sys.version_info >= (3, 10, 0):
- from types import NoneType, UnionType
+ from types import NoneType, UnionType # type: ignore
else:
- NoneType = type(None)
+ NoneType = type(None) # type: ignore
# just a placeholder
- class UnionType:
+ class UnionType: # type: ignore
pass
+if TYPE_CHECKING:
+ from typing import TypedDict
+
+ from typing_extensions import NotRequired
+
+ class Property(TypedDict):
+ properties: NotRequired[Property]
+ type: NotRequired[str]
+ items: NotRequired[dict]
+ value: NotRequired[str]
+ description: NotRequired[str]
+ required: NotRequired[list[str]]
+
+
TYPE_MAPPER: dict[type, str] = {
int: "number", # sometimes default value are not accurate
str: "string",
@@ -76,8 +90,8 @@ def _generate_json_schema_and_class_mapping(
for name, v in props["properties"].items():
schema["properties"][name] = v
json_str = json.dumps(schema, indent=2)
- save_path = save_path or _json_schema_path()
- class_mapping_save_path = class_mapping_save_path or _class_mapping_path()
+ save_path = save_path or workspace.json_schema_file
+ class_mapping_save_path = class_mapping_save_path or workspace.class_mapping_file
with open(save_path, "w", encoding="UTF-8") as f:
f.write(json_str)
logger.success("json schema has been written to {}", save_path)
@@ -95,14 +109,14 @@ def _check(bases) -> bool:
return False
-def parse_registry(reg: Registry) -> tuple[dict, dict[str, list[str | int]]]:
- props = {
+def parse_registry(reg: Registry) -> tuple[Property, dict[str, list[str | int]]]:
+ props: Property = {
"type": "object",
"properties": {},
}
- class_mapping = {}
+ class_mapping: dict[str, list[str | int]] = {}
for name, item_dir in reg.items():
- func = _str_to_target(item_dir)
+ func = _str_to_target(item_dir) # type: ignore
if isinstance(func, ModuleType):
continue
class_mapping[name] = [inspect.getfile(func), inspect.getsourcelines(func)[1]]
@@ -110,7 +124,7 @@ def parse_registry(reg: Registry) -> tuple[dict, dict[str, list[str | int]]]:
is_hook = isclass(func) and issubclass(func, ConfigArgumentHook)
if isclass(func) and _check(func.__bases__):
func = func.__init__
- params = inspect.signature(func).parameters
+ params = inspect.signature(func).parameters # type: ignore
param_props = {"type": "object", "properties": {}}
if doc_string:
# TODO: parse doc string to each parameters
@@ -122,11 +136,16 @@ def parse_registry(reg: Registry) -> tuple[dict, dict[str, list[str | int]]]:
continue
try:
is_required, item = parse_single_param(param_obj)
- except Exception:
+ except Exception as e:
from rich.console import Console
- logger.error(f"Skip {param_obj.name} of {name}")
Console().print_exception()
+ if isinstance(e, AnnotationsFutureError):
+ logger.error(
+ f"Skip {name} due to mismatch of python version and annotations future."
+ )
+ break
+ logger.error(f"Skip parameter {param_obj.name} of {name}")
continue
items[param_name] = item
if is_required:
@@ -135,14 +154,16 @@ def parse_registry(reg: Registry) -> tuple[dict, dict[str, list[str | int]]]:
param_props["properties"] = items
if required:
param_props["required"] = required
- props["properties"][name] = param_props
+ props["properties"][name] = param_props # type: ignore
return props, class_mapping
def _remove_optional(anno):
origin = get_origin(anno)
+ if origin is not Union:
+ return anno
inner_types = get_args(anno)
- if origin is not Union and len(inner_types) != 2:
+ if len(inner_types) != 2:
return anno
filter_types = [i for i in inner_types if i is not NoneType]
if len(filter_types) == 1:
@@ -150,7 +171,7 @@ def _remove_optional(anno):
return anno
-def _parse_inner_types(prop: dict, inner_types: Sequence[type]) -> None:
+def _parse_inner_types(prop: Property, inner_types: Sequence[type]) -> None:
first_type = inner_types[0]
is_all_the_same = True
for t in inner_types:
@@ -159,11 +180,13 @@ def _parse_inner_types(prop: dict, inner_types: Sequence[type]) -> None:
prop["items"] = {"type": TYPE_MAPPER.get(first_type)}
-def _parse_typehint(prop: dict, anno: type) -> str | None:
+def _parse_typehint(prop: Property, anno: type) -> str | None:
potential_type = TYPE_MAPPER.get(anno)
if potential_type is not None:
return potential_type
origin = get_origin(anno)
+ if anno is Callable:
+ return "string"
inner_types = get_args(anno)
if origin in (Sequence, list, tuple):
potential_type = "array"
@@ -178,8 +201,8 @@ def _parse_typehint(prop: dict, anno: type) -> str | None:
return potential_type or "string"
-def parse_single_param(param: Parameter) -> tuple[bool, dict[str, Any]]:
- prop = {}
+def parse_single_param(param: Parameter) -> tuple[bool, Property]:
+ prop: Property = {}
anno = param.annotation
potential_type = None
@@ -187,24 +210,25 @@ def parse_single_param(param: Parameter) -> tuple[bool, dict[str, Any]]:
# hardcore for torch.optim
if param.default.__class__.__name__ == "_RequiredParameter":
- param._default = _empty
+ param._default = _empty # type: ignore
if isinstance(anno, str):
- raise RuntimeError(
+ raise AnnotationsFutureError(
"Use a higher version of python, e.g. 3.10, "
"and remove `from __future__ import annotations`."
)
elif anno is not _empty:
potential_type = _parse_typehint(prop, anno)
# determine type by default value
- elif param.default is not _empty:
- potential_type = TYPE_MAPPER[type(param.default)]
+ elif param.default is not _empty and param.default is not None:
+ # TODO: Allow user to add more type mapper.
+ potential_type = TYPE_MAPPER.get(type(param.default), "number")
if isinstance(param.default, (list, tuple)):
types = [type(t) for t in param.default]
_parse_inner_types(prop, types)
- elif param._kind is _ParameterKind.VAR_POSITIONAL:
+ elif param.kind is _ParameterKind.VAR_POSITIONAL:
return False, {"type": "array"}
- elif param._kind is _ParameterKind.VAR_KEYWORD:
+ elif param.kind is _ParameterKind.VAR_KEYWORD:
return False, {"type": "object"}
if anno is _empty and param.default is _empty:
potential_type = "number"
@@ -213,18 +237,10 @@ def parse_single_param(param: Parameter) -> tuple[bool, dict[str, Any]]:
return param.default is _empty, prop
-def _json_schema_path() -> str:
- return os.path.join(_cache_dir, _json_schema_file)
-
-
-def _class_mapping_path() -> str:
- return os.path.join(_cache_dir, _class_mapping_file)
-
-
-def _generate_taplo_config(path: str) -> None:
+def _generate_taplo_config() -> None:
cfg = dict(
schema=dict(
- path=osp.join(osp.expanduser(path), _json_schema_file),
+ path=workspace.json_schema_file,
enabled=True,
),
formatting=dict(align_entries=False),
diff --git a/excore/config/action.py b/excore/config/action.py
index b7c4943..4af9acc 100644
--- a/excore/config/action.py
+++ b/excore/config/action.py
@@ -2,9 +2,11 @@
Copy and adapt from mmengine/config/config.py
"""
+from __future__ import annotations
+
import copy
from argparse import Action, ArgumentParser, Namespace
-from typing import Any, Sequence, Union
+from typing import Any, Sequence
__all__ = ["DictAction"]
@@ -37,7 +39,7 @@ def __init__(
self._dict = {}
@staticmethod
- def _parse_int_float_bool(val: str) -> Union[int, float, bool, Any]:
+ def _parse_int_float_bool(val: str) -> int | float | bool | Any:
"""parse int/float/bool value in the string."""
try:
return int(val)
@@ -54,7 +56,7 @@ def _parse_int_float_bool(val: str) -> Union[int, float, bool, Any]:
return val
@staticmethod
- def _parse_iterable(val: str) -> Union[list, tuple, Any]:
+ def _parse_iterable(val: str) -> list | tuple | Any:
"""Parse iterable values in the string.
All elements inside '()' or '[]' are treated as iterable values.
@@ -137,8 +139,8 @@ def __call__(
self,
parser: ArgumentParser,
namespace: Namespace,
- values: Union[str, Sequence[Any], None],
- option_string: str = None,
+ values: str | Sequence[Any] | None,
+ option_string: str | None = None,
):
"""Parse Variables in string and add them into argparser.
diff --git a/excore/config/config.py b/excore/config/config.py
index b68a1ea..f457dca 100644
--- a/excore/config/config.py
+++ b/excore/config/config.py
@@ -1,6 +1,8 @@
+from __future__ import annotations
+
import os
import time
-from typing import Dict, Optional, Tuple
+from typing import Any
import toml
@@ -14,10 +16,6 @@
__all__ = ["load", "build_all", "load_config"]
-# TODO: Improve error messages. high priority.
-# TODO: Support multiple same modules parsing.
-
-
BASE_CONFIG_KEY = "__base__"
@@ -39,7 +37,7 @@ def load_config(filename: str, base_key: str = "__base__") -> ConfigDict:
return base_cfg
-def _merge_config(base_cfg, new_cfg):
+def _merge_config(base_cfg: ConfigDict, new_cfg: dict) -> None:
for k, v in new_cfg.items():
if k in base_cfg and isinstance(v, dict):
_merge_config(base_cfg[k], v)
@@ -50,8 +48,8 @@ def _merge_config(base_cfg, new_cfg):
def load(
filename: str,
*,
- dump_path: Optional[str] = None,
- update_dict: Optional[Dict] = None,
+ dump_path: str | None = None,
+ update_dict: dict[str, Any] | None = None,
base_key: str = BASE_CONFIG_KEY,
parse_config: bool = True,
) -> LazyConfig:
@@ -89,7 +87,7 @@ def load(
return lazy_config
-def build_all(cfg: LazyConfig) -> Tuple[ModuleWrapper, dict]:
+def build_all(cfg: LazyConfig) -> tuple[ModuleWrapper, dict[str, Any]]:
st = time.time()
modules = cfg.build_all()
logger.success("Modules building costs {:.4f}s!", time.time() - st)
diff --git a/excore/config/lazy_config.py b/excore/config/lazy_config.py
index 2c3b465..a38a6c2 100644
--- a/excore/config/lazy_config.py
+++ b/excore/config/lazy_config.py
@@ -1,8 +1,10 @@
+from __future__ import annotations
+
import time
from copy import deepcopy
-from typing import Any, Dict, Tuple
+from typing import Any
-from ..engine.hook import ConfigHookManager
+from ..engine.hook import ConfigHookManager, Hook
from ..engine.logging import logger
from ..engine.registry import Registry
from .model import ConfigHookNode, InterNode, ModuleWrapper
@@ -10,7 +12,9 @@
class LazyConfig:
- hook_key = "ConfigHook"
+ hook_key: str = "ConfigHook"
+ modules_dict: dict[str, ModuleWrapper]
+ isolated_dict: dict[str, Any]
def __init__(self, config: ConfigDict) -> None:
self.modules_dict, self.isolated_dict = {}, {}
@@ -21,29 +25,30 @@ def __init__(self, config: ConfigDict) -> None:
self._original_config = deepcopy(config)
self.__is_parsed__ = False
- def parse(self):
+ def parse(self) -> None:
st = time.time()
self.build_config_hooks()
self._config.parse()
logger.success("Config parsing cost {:.4f}s!", time.time() - st)
self.__is_parsed__ = True
- logger.ex(self._config)
+ logger.ex(str(self._config))
@property
- def config(self):
+ def config(self) -> ConfigDict:
return self._original_config
- def update(self, cfg: "LazyConfig"):
+ def update(self, cfg: LazyConfig) -> None:
self._config.update(cfg._config)
- def build_config_hooks(self):
+ def build_config_hooks(self) -> None:
hook_cfgs = self._config.pop(LazyConfig.hook_key, [])
hooks = []
if hook_cfgs:
_, base = Registry.find(list(hook_cfgs.keys())[0])
+ assert base is not None, hook_cfgs
reg = Registry.get_registry(base)
for name, params in hook_cfgs.items():
- hook = ConfigHookNode.from_str(reg[name], params)()
+ hook: Hook = ConfigHookNode.from_str(reg[name], params)() # type: ignore
if hook:
hooks.append(hook)
else:
@@ -57,10 +62,12 @@ def __getattr__(self, __name: str) -> Any:
return self._config[__name]
raise AttributeError(__name)
- def build_all(self) -> Tuple[ModuleWrapper, Dict]:
+ def build_all(self) -> tuple[ModuleWrapper, dict[str, Any]]:
if not self.__is_parsed__:
self.parse()
- module_dict, isolated_dict = ModuleWrapper(), {}
+ module_dict = ModuleWrapper()
+ isolated_dict: dict[str, Any] = {}
+
self.hooks.call_hooks("pre_build", self, module_dict, isolated_dict)
for name in self.target_modules:
if name not in self._config:
@@ -78,5 +85,5 @@ def build_all(self) -> Tuple[ModuleWrapper, Dict]:
def dump(self, dump_path: str) -> None:
self._original_config.dump(dump_path)
- def __str__(self):
+ def __str__(self) -> str:
return str(self._config)
diff --git a/excore/config/model.py b/excore/config/model.py
index bbcf2ff..aaa07c4 100644
--- a/excore/config/model.py
+++ b/excore/config/model.py
@@ -1,33 +1,51 @@
+from __future__ import annotations
+
import importlib
import os
import re
from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+from typing import TYPE_CHECKING, Type, Union
from .._exceptions import EnvVarParseError, ModuleBuildError, StrToClassError
from .._misc import CacheOut
-from ..engine.hook import ConfigArgumentHook
+from ..engine.hook import ConfigArgumentHook, Hook
from ..engine.logging import logger
from ..engine.registry import Registry
+if TYPE_CHECKING:
+ from types import FunctionType, ModuleType
+ from typing import Any, Dict, Literal
+
+ from typing_extensions import Self
+
+ NodeClassType = Type[Any]
+ NodeParams = Dict[Any, Any]
+ NodeInstance = object
+
+ NoCallSkipFlag = Self
+ ConfigHookSkipFlag = Type[None]
+
+ SpecialFlag = Literal["@", "!", "$", "&", ""]
+
+
__all__ = ["silent"]
-REUSE_FLAG = "@"
-INTER_FLAG = "!"
-CLASS_FLAG = "$"
-REFER_FLAG = "&"
-OTHER_FLAG = ""
+REUSE_FLAG: Literal["@"] = "@"
+INTER_FLAG: Literal["!"] = "!"
+CLASS_FLAG: Literal["$"] = "$"
+REFER_FLAG: Literal["&"] = "&"
+OTHER_FLAG: Literal[""] = ""
LOG_BUILD_MESSAGE = True
DO_NOT_CALL_KEY = "__no_call__"
-def silent():
+def silent() -> None:
global LOG_BUILD_MESSAGE # pylint: disable=global-statement
LOG_BUILD_MESSAGE = False
-def _is_special(k: str) -> Tuple[str, str]:
+def _is_special(k: str) -> tuple[str, SpecialFlag]:
"""
Determine if the given string begin with target special character.
`@` denotes reused module, which will only be built once and cached out.
@@ -41,31 +59,27 @@ def _is_special(k: str) -> Tuple[str, str]:
Returns:
Tuple[str, str]: A tuple containing the modified string and the special character.
"""
- if k.startswith(REUSE_FLAG):
- return k[1:], REUSE_FLAG
- if k.startswith(INTER_FLAG):
- return k[1:], INTER_FLAG
- if k.startswith(CLASS_FLAG):
- return k[1:], CLASS_FLAG
- if k.startswith(REFER_FLAG):
- return k[1:], REFER_FLAG
+ pattern = re.compile(r"^([@!$&])(.*)$")
+ match = pattern.match(k)
+ if match:
+ return match.group(2), match.group(1) # type: ignore
return k, ""
-def _str_to_target(module_name):
- module_name = module_name.split(".")
- if len(module_name) == 1:
- return importlib.import_module(module_name[0])
- target_name = module_name.pop(-1)
+def _str_to_target(module_name: str) -> ModuleType | NodeClassType | FunctionType:
+ module_names = module_name.split(".")
+ if len(module_names) == 1:
+ return importlib.import_module(module_names[0])
+ target_name = module_names.pop(-1)
try:
- module = importlib.import_module(".".join(module_name))
+ module = importlib.import_module(".".join(module_names))
except ModuleNotFoundError as exc:
- raise StrToClassError(f"Cannot import such module: `{'.'.join(module_name)}`") from exc
+ raise StrToClassError(f"Cannot import such module: `{'.'.join(module_names)}`") from exc
try:
module = getattr(module, target_name)
except AttributeError as exc:
raise StrToClassError(
- f"Cannot find such module `{target_name}` form `{'.'.join(module_name)}`"
+ f"Cannot find such module `{target_name}` form `{'.'.join(module_names)}`"
) from exc
return module
@@ -76,57 +90,57 @@ class ModuleNode(dict):
_no_call: bool = field(default=False, repr=False)
priority: int = field(default=0, repr=False)
- def _get_params(self, **kwargs):
- params = {}
+ def _get_params(self, **params: NodeParams) -> NodeParams:
+ return_params = {}
for k, v in self.items():
- if isinstance(v, ModuleWrapper):
+ if isinstance(v, (ModuleWrapper, ModuleNode)):
v = v()
- params[k] = v
- params.update(kwargs)
- return params
+ return_params[k] = v
+ return_params.update(params)
+ return return_params
@property
- def name(self):
+ def name(self) -> str:
return self.cls.__name__
- def add(self, **kwargs):
- self.update(kwargs)
+ def add(self, **params: NodeParams) -> Self:
+ self.update(params)
return self
- def _instantiate(self, params):
+ def _instantiate(self, params: NodeParams) -> NodeInstance:
try:
module = self.cls(**params)
except Exception as exc:
raise ModuleBuildError(
- f"Build Error with module {self.cls} and arguments {params}"
+ f"Instantiate Error with module {self.cls} and arguments {params}"
) from exc
if LOG_BUILD_MESSAGE:
logger.success(
- f"Successfully build module: {self.cls.__name__}, with arguments {params}"
+ f"Successfully instantiate module: {self.cls.__name__}, with arguments {params}"
)
return module
- def __call__(self, **kwargs):
+ def __call__(self, **params: NodeParams) -> NoCallSkipFlag | NodeInstance: # type: ignore
if self._no_call:
return self
- params = self._get_params(**kwargs)
+ params = self._get_params(**params)
module = self._instantiate(params)
return module
- def __lshift__(self, kwargs):
- if not isinstance(kwargs, dict):
- raise TypeError(f"Expect type is dict, but got {type(kwargs)}")
- self.update(kwargs)
+ def __lshift__(self, params: NodeParams) -> Self:
+ if not isinstance(params, dict):
+ raise TypeError(f"Expect type is dict, but got {type(params)}")
+ self.update(params)
return self
- def __rshift__(self, __other):
+ def __rshift__(self, __other: ModuleNode) -> Self:
if not isinstance(__other, ModuleNode):
raise TypeError(f"Expect type is `ModuleNode`, but got {type(__other)}")
__other.update(self)
return self
@classmethod
- def from_str(cls, str_target, params=None):
+ def from_str(cls, str_target: str, params: NodeParams | None = None) -> ModuleNode:
node = cls(_str_to_target(str_target))
if params:
node.update(params)
@@ -135,7 +149,7 @@ def from_str(cls, str_target, params=None):
return node
@classmethod
- def from_base_name(cls, base, name, params=None):
+ def from_base_name(cls, base: str, name: str, params: NodeParams | None = None) -> ModuleNode:
try:
cls_name = Registry.get_registry(base)[name]
except KeyError as exc:
@@ -145,7 +159,7 @@ def from_base_name(cls, base, name, params=None):
return cls.from_str(cls_name, params)
@classmethod
- def from_node(cls, _other: "ModuleNode") -> "ModuleNode":
+ def from_node(cls, _other: ModuleNode) -> ModuleNode:
if _other.__class__.__name__ == cls.__name__:
return _other
node = cls(_other.cls) << _other
@@ -159,10 +173,10 @@ class InterNode(ModuleNode):
class ConfigHookNode(ModuleNode):
- def __call__(self, **kwargs):
+ def __call__(self, **params: NodeParams) -> NodeInstance | ConfigHookSkipFlag | Hook:
if issubclass(self.cls, ConfigArgumentHook):
return None
- params = self._get_params(**kwargs)
+ params = self._get_params(**params)
return self._instantiate(params)
@@ -170,24 +184,24 @@ class ReusedNode(InterNode):
priority: int = 3
@CacheOut()
- def __call__(self, **kwargs):
- return super().__call__(**kwargs)
+ def __call__(self, **params: NodeParams) -> NodeInstance | NoCallSkipFlag: # type: ignore
+ return super().__call__(**params)
class ClassNode(InterNode):
priority: int = 1
- def __call__(self):
+ def __call__(self) -> NodeClassType | FunctionType: # type: ignore
return self.cls
class ChainedInvocationWrapper(ConfigArgumentHook):
- def __init__(self, node: ModuleNode, attrs: Sequence[str]) -> None:
+ def __init__(self, node: ModuleNode, attrs: list[str]) -> None:
super().__init__(node)
self.attrs = attrs
- def hook(self, **kwargs):
- target = self.node(**kwargs)
+ def hook(self, **params: NodeParams) -> Any:
+ target = self.node(**params)
if isinstance(target, ModuleNode):
raise ModuleBuildError(f"Do not support `{DO_NOT_CALL_KEY}`")
if self.attrs:
@@ -201,7 +215,7 @@ def hook(self, **kwargs):
@dataclass
class VariableReference:
- def __init__(self, value: str):
+ def __init__(self, value: str) -> None:
env_names = re.findall(r"\$\{([^}]+)\}", value)
self.has_env = len(env_names) > 0
for env in env_names:
@@ -210,16 +224,22 @@ def __init__(self, value: str):
value = re.sub(r"\$\{" + re.escape(env) + r"\}", env_value, value)
self.value = value
- def __call__(self):
+ def __call__(self) -> str:
return self.value
+ConfigNode = Union[ModuleNode, ConfigArgumentHook]
+NodeType = Type[ModuleNode]
+
+
class ModuleWrapper(dict):
def __init__(
self,
- modules: Optional[Union[Dict[str, ModuleNode], List[ModuleNode], ModuleNode]] = None,
- is_dict=False,
- ):
+ modules: (
+ dict[str, ConfigNode] | list[ConfigNode] | ConfigNode | VariableReference | None
+ ) = None,
+ is_dict: bool = False,
+ ) -> None:
super().__init__()
if modules is None:
return
@@ -241,14 +261,14 @@ def __init__(
f"Expect modules to be `list`, `dict` or `ModuleNode`, but got {type(modules)}"
)
- def _get_name(self, m):
+ def _get_name(self, m) -> Any:
if hasattr(m, "name"):
return m.name
return m.__class__.__name__
- def __lshift__(self, kwargs):
+ def __lshift__(self, params: NodeParams) -> None:
if len(self) == 1:
- self[list(self.keys())[0]] << kwargs
+ self[list(self.keys())[0]] << params
else:
raise RuntimeError("Wrapped more than 1 ModuleNode, index first")
@@ -274,10 +294,9 @@ def __repr__(self) -> str:
return f"ModuleWrapper{list(self.values())}"
-_dispatch_module_node = {
+_dispatch_module_node: dict[SpecialFlag, NodeType] = {
OTHER_FLAG: ModuleNode,
REUSE_FLAG: ReusedNode,
INTER_FLAG: InterNode,
CLASS_FLAG: ClassNode,
- REFER_FLAG: VariableReference,
}
diff --git a/excore/config/parse.py b/excore/config/parse.py
index bb1951f..d09654e 100644
--- a/excore/config/parse.py
+++ b/excore/config/parse.py
@@ -1,6 +1,6 @@
from __future__ import annotations
-from typing import Any, Generator
+from typing import TYPE_CHECKING
from .._exceptions import CoreConfigParseError, ImplicitModuleParseError
from .._misc import _create_table
@@ -18,6 +18,13 @@
_is_special,
)
+if TYPE_CHECKING:
+ from typing import Generator, Sequence
+
+ from typing_extensions import Self
+
+ from .model import ConfigNode, NodeType, SpecialFlag
+
def _check_implicit_module(module: ModuleNode) -> None:
import inspect
@@ -40,12 +47,12 @@ def _check_implicit_module(module: ModuleNode) -> None:
)
-def _dict2node(module_type: str, base: str, _dict: dict):
- ModuleType: type[ModuleNode] = _dispatch_module_node[module_type]
+def _dict2node(module_type: SpecialFlag, base: str, _dict: dict):
+ ModuleType = _dispatch_module_node[module_type]
return {name: ModuleType.from_base_name(base, name, v) for name, v in _dict.items()}
-def _parse_param_name(name):
+def _parse_param_name(name) -> tuple[str, list[str], list[str]]:
names = name.split("@")
attrs = names.pop(0).split(".")
attrs = [i for i in attrs if i]
@@ -53,25 +60,18 @@ def _parse_param_name(name):
return attrs.pop(0), attrs, hooks
-def _flatten_list(lis):
+def _flatten_list(
+ lis: Sequence[ConfigNode | VariableReference | list[ConfigNode]],
+) -> Sequence[ConfigNode | VariableReference]:
new_lis = []
for i in lis:
if isinstance(i, list):
new_lis.extend(i)
else:
- new_lis.append(i)
+ new_lis.append(i) # type: ignore
return new_lis
-def _flatten_dict(dic):
- new_dic = {}
- for k, v in dic.items():
- if isinstance(v, list):
- v = _flatten_list(v)
- new_dic[k] = v
- return new_dic
-
-
class ConfigDict(dict):
primary_fields: list
primary_to_registry: dict[str, str]
@@ -80,7 +80,7 @@ class ConfigDict(dict):
scratchpads_fields: set[str] = set()
current_field: str | None = None
- def __new__(cls):
+ def __new__(cls) -> Self:
if not hasattr(cls, "primary_fields"):
raise RuntimeError("Call `set_primary_fields` before `load`")
@@ -91,11 +91,13 @@ class ConfigDictImpl(ConfigDict):
primary_to_registry = ConfigDict.primary_to_registry
scratchpads_fields = ConfigDict.scratchpads_fields
- inst = super().__new__(ConfigDictImpl)
- return inst
+ inst = super().__new__(ConfigDictImpl) # type: ignore
+ return inst # type: ignore
@classmethod
- def set_primary_fields(cls, primary_fields, primary_to_registry):
+ def set_primary_fields(
+ cls, primary_fields: Sequence[str], primary_to_registry: dict[str, str]
+ ) -> None:
"""
Sets the `primary_fields` attribute to the specified list of module names,
and `registered_fields` attributes based on the current state
@@ -103,7 +105,7 @@ def set_primary_fields(cls, primary_fields, primary_to_registry):
Note that `set_primary_fields` must be called before `config.load`.
"""
- cls.primary_fields = primary_fields
+ cls.primary_fields = list(primary_fields)
cls.primary_to_registry = primary_to_registry
def parse(self) -> None:
@@ -177,12 +179,12 @@ def _clean(self) -> None:
if name in self.registered_fields or isinstance(self[name], ModuleNode):
self.pop(name)
- def primary_keys(self) -> Generator[Any, Any, None]:
+ def primary_keys(self) -> Generator[str, None, None]:
for name in self.primary_fields:
if name in self:
yield name
- def non_primary_keys(self) -> Generator[Any, Any, None]:
+ def non_primary_keys(self) -> Generator[str, None, None]:
keys = list(self.keys())
for k in keys:
if k not in self.primary_fields:
@@ -207,7 +209,7 @@ def _parse_primary_modules(self) -> None:
self[name] = _dict2node(OTHER_FLAG, base, self.pop(name))
logger.ex(f"Set ModuleNode to self[{name}].")
- def _parse_isolated_registered_module(self, name) -> None:
+ def _parse_isolated_registered_module(self, name: str) -> None:
v = _dict2node(OTHER_FLAG, name, self.pop(name))
for i in v.values():
_, _ = Registry.find(i.name)
@@ -215,7 +217,7 @@ def _parse_isolated_registered_module(self, name) -> None:
raise CoreConfigParseError(f"Unregistered module `{i.name}`")
self[name] = v
- def _parse_implicit_module(self, name, module_type):
+ def _parse_implicit_module(self, name: str, module_type: NodeType) -> ModuleNode:
_, base = Registry.find(name)
if not base:
raise CoreConfigParseError(f"Unregistered module `{name}`")
@@ -227,7 +229,7 @@ def _parse_implicit_module(self, name, module_type):
self[name] = node
return node
- def _parse_isolated_module(self, name) -> bool:
+ def _parse_isolated_module(self, name: str) -> bool:
logger.ex(f"Not a registed field. Parse isolated module {name}.")
_, base = Registry.find(name)
if base:
@@ -236,7 +238,7 @@ def _parse_isolated_module(self, name) -> bool:
return True
return False
- def _parse_scratchpads(self, name) -> None:
+ def _parse_scratchpads(self, name: str) -> None:
logger.ex(f"Not a registed node. Regrad as scratchpads {name}.")
has_module = False
modules = self[name]
@@ -264,7 +266,7 @@ def _parse_isolated_obj(self) -> None:
elif not self._parse_isolated_module(name):
self._parse_scratchpads(name)
- def _contain_module(self, name) -> bool:
+ def _contain_module(self, name: str) -> bool:
is_contain = False
for k in self.all_fields:
if k not in self:
@@ -282,7 +284,9 @@ def _contain_module(self, name) -> bool:
self.current_field = k
return is_contain
- def _get_name_and_field(self, name, ori_name):
+ def _get_name_and_field(
+ self, name: str, ori_name: str
+ ) -> tuple[str, str | None] | tuple[list[str], str]:
if not name.startswith("$"):
return name, None
names = name[1:].split("::")
@@ -295,6 +299,10 @@ def _get_name_and_field(self, name, ori_name):
raise CoreConfigParseError(f"Cannot find field `{base}` with `{ori_name}`")
if len(spec_name) > 0:
if spec_name[0] == "*":
+ logger.warning(
+ f"`The results of {names} "
+ "depend on their definition in config files when using `*`."
+ )
return [k.name for k in modules.values()], base
if spec_name[0] not in modules:
raise CoreConfigParseError(
@@ -310,20 +318,22 @@ def _get_name_and_field(self, name, ori_name):
)
return list(modules.keys())[0], base
- def _apply_hooks(self, node, hooks, attrs):
+ def _apply_hooks(self, node: ModuleNode, hooks: list[str], attrs: list[str]) -> ConfigNode:
if attrs:
- node = ChainedInvocationWrapper(node, attrs)
+ node = ChainedInvocationWrapper(node, attrs) # type: ignore
if not hooks:
return node
for hook in hooks:
if hook not in self:
- raise CoreConfigParseError(f"Unregistered hook {hook}")
+ raise CoreConfigParseError(f"Unregistered hook `{hook}`")
node = self[hook](node=node)
return node
- def _convert_node(self, name, source, target_type) -> tuple[ModuleNode, type[ModuleNode]]:
+ def _convert_node(
+ self, name: str, source: ConfigDict, target_type: NodeType
+ ) -> tuple[ModuleNode, NodeType]:
ori_type = source[name].__class__
- logger.ex(f"Original_type is {ori_type}, target_type is {target_type}.")
+ logger.ex(f"Original_type is `{ori_type}`, target_type is `{target_type}`.")
node = source[name]
if target_type.priority != ori_type.priority or target_type is ClassNode:
node = target_type.from_node(source[name])
@@ -332,7 +342,15 @@ def _convert_node(self, name, source, target_type) -> tuple[ModuleNode, type[Mod
source[name] = node
return node, ori_type
- def _parse_single_param(self, name, ori_name, field, target_type, attrs, hooks):
+ def _parse_single_param(
+ self,
+ name: str,
+ ori_name: str,
+ field: str | None,
+ target_type: NodeType,
+ attrs: list[str],
+ hooks: list[str],
+ ) -> ConfigNode:
if name in self.all_fields:
raise CoreConfigParseError(
f"Conflict name: `{name}`, the class name cannot be same with field name"
@@ -358,26 +376,28 @@ def _parse_single_param(self, name, ori_name, field, target_type, attrs, hooks):
f"target_type is `{target_type}`, but got original_type `{ori_type}`. "
f"Please considering using `scratchpads` to avoid conflicts."
)
- node = self._apply_hooks(node, hooks, attrs)
+ node = self._apply_hooks(node, hooks, attrs) # type: ignore
return node
- def _parse_param(self, ori_name, module_type):
+ def _parse_param(
+ self, ori_name: str, module_type: SpecialFlag
+ ) -> ConfigNode | list[ConfigNode] | VariableReference:
logger.ex(f"Parse with {ori_name}, {module_type}")
if module_type == REFER_FLAG:
return VariableReference(ori_name)
target_type = _dispatch_module_node[module_type]
name, attrs, hooks = _parse_param_name(ori_name)
- name, field = self._get_name_and_field(name, ori_name)
- logger.ex(f"Get name:{name}, field:{field}, attrs:{attrs}, hooks:{hooks}.")
- if isinstance(name, list):
+ names, field = self._get_name_and_field(name, ori_name)
+ logger.ex(f"Get name:{names}, field:{field}, attrs:{attrs}, hooks:{hooks}.")
+ if isinstance(names, list):
logger.ex(f"Detect output type list {ori_name}.")
return [
self._parse_single_param(n, ori_name, field, target_type, attrs, hooks)
- for n in name
+ for n in names
]
- return self._parse_single_param(name, ori_name, field, target_type, attrs, hooks)
+ return self._parse_single_param(names, ori_name, field, target_type, attrs, hooks)
- def _parse_module(self, node: ModuleNode):
+ def _parse_module(self, node: ModuleNode) -> None:
logger.ex(f"Parse ModuleNode {node}.")
for param_name in list(node.keys()):
true_name, module_type = _is_special(param_name)
@@ -396,7 +416,6 @@ def _parse_module(self, node: ModuleNode):
elif isinstance(value, dict):
logger.ex(f"{param_name}: Dict parameter {value}.")
value = {k: self._parse_param(v, module_type) for k, v in value.items()}
- value = _flatten_dict(value)
is_dict = True
else:
raise CoreConfigParseError(f"Wrong type: {param_name, value}")
@@ -410,9 +429,10 @@ def _parse_module(self, node: ModuleNode):
else:
node[true_name] = ref_name
else:
+ # FIXME: Parsing inner VariableReference
node[true_name] = ModuleWrapper(value, is_dict)
- def _parse_inter_modules(self):
+ def _parse_inter_modules(self) -> None:
for name in list(self.keys()):
logger.ex(f"Parse inter module {name}")
module = self[name]
@@ -427,33 +447,32 @@ def _parse_inter_modules(self):
elif isinstance(module, ModuleNode):
self._parse_module(module)
- def __str__(self):
- _dict = {}
+ def __str__(self) -> str:
+ _dict: dict = {}
for k, v in self.items():
self._flatten(_dict, k, v)
return _create_table(
None,
[(k, v) for k, v in _dict.items()],
- False,
)
- def _flatten(self, _dict, k, v):
+ def _flatten(self, _dict: dict, k: str, v: dict) -> None:
if isinstance(v, dict) and not isinstance(v, ModuleNode):
for _k, _v in v.items():
_dict[".".join([k, _k])] = _v
else:
_dict[k] = v
- def dump(self, path: str):
+ def dump(self, path: str) -> None:
import toml
with open(path, "w", encoding="UTF-8") as f:
toml.dump(self, f)
-def set_primary_fields(cfg):
- primary_fields = cfg["primary_fields"]
- primary_to_registry = cfg["primary_to_registry"]
+def set_primary_fields(cfg) -> None:
+ primary_fields = cfg.primary_fields
+ primary_to_registry = cfg.primary_to_registry
if hasattr(ConfigDict, "primary_fields"):
logger.ex("`primary_fields` will be set to {}", primary_fields)
if primary_fields:
diff --git a/excore/engine/hook.py b/excore/engine/hook.py
index b147b1c..241ef7c 100644
--- a/excore/engine/hook.py
+++ b/excore/engine/hook.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from collections import defaultdict
from typing import Any, Callable, Protocol, Sequence, final
@@ -31,7 +33,7 @@ class Hook(Protocol):
class MetaHookManager(type):
- stages = None
+ stages: tuple[str, ...] = ()
"""
A metaclass that is used to validate the `stages` attribute of a `HookManager` subclass.
@@ -66,7 +68,7 @@ def __new__(cls, name, bases, attrs):
"""
inst = type.__new__(cls, name, bases, attrs)
stages = inst.stages
- if inst.__name__ != "HookManager" and stages is None:
+ if inst.__name__ != "HookManager" and not stages:
raise HookManagerBuildError(
f"The hook manager `{inst.__name__}` must have valid stages"
)
@@ -75,7 +77,7 @@ def __new__(cls, name, bases, attrs):
class HookManager(metaclass=MetaHookManager):
- stages = tuple()
+ stages: tuple[str, ...] = tuple()
"""
Manages a set of hooks that can be triggered by events that occur during program execution.
@@ -109,7 +111,7 @@ class HookManager(metaclass=MetaHookManager):
for different applications.
"""
- def __init__(self, hooks: Sequence[Hook]):
+ def __init__(self, hooks: Sequence[Hook]) -> None:
assert isinstance(hooks, Sequence)
__error_msg = "The hook `{}` must have a valid `{}`, got {}"
@@ -127,12 +129,12 @@ def __init__(self, hooks: Sequence[Hook]):
__error_msg.format(h.__class__.__name__, "__CallInter__", h.__CallInter__)
)
self.hooks = defaultdict(list)
- self.calls = defaultdict(int)
+ self.calls: dict[str, int] = defaultdict(int)
for h in hooks:
self.hooks[h.__HookType__].append(h)
@staticmethod
- def check_life_span(hook: Hook):
+ def check_life_span(hook: Hook) -> bool:
"""
Checks whether a given `Hook` object has exceeded its maximum lifespan.
@@ -142,7 +144,7 @@ def check_life_span(hook: Hook):
hook.__LifeSpan__ -= 1
return hook.__LifeSpan__ <= 0
- def exist(self, stage):
+ def exist(self, stage) -> bool:
"""
Determines whether any hooks are registered for a given event stage.
@@ -154,19 +156,19 @@ def exist(self, stage):
"""
return self.hooks[stage] != []
- def pre_call(self):
+ def pre_call(self) -> Any:
"""
Called before any hooks are executed during an event stage.
"""
return
- def after_call(self):
+ def after_call(self) -> Any:
"""
Called after all hooks have been executed during an event stage.
"""
return
- def __call__(self, stage, *inps):
+ def __call__(self, stage, *inps) -> None:
"""
Executes all hooks registered for a given event stage.
@@ -174,7 +176,7 @@ def __call__(self, stage, *inps):
stage (str): The name of the event stage to trigger.
*inps: Input arguments to pass to the hook functions.
"""
- dead_hook_idx = []
+ dead_hook_idx: list[int] = []
calls = self.calls[stage]
for idx, hook in enumerate(self.hooks[stage]):
if calls % hook.__CallInter__ == 0:
@@ -185,7 +187,7 @@ def __call__(self, stage, *inps):
self.hooks[stage].pop(idx)
self.calls[stage] = calls + 1
- def call_hooks(self, stage, *inps):
+ def call_hooks(self, stage, *inps) -> None:
"""
Convenience method for calling all hooks at a given event stage.
@@ -199,7 +201,7 @@ def call_hooks(self, stage, *inps):
class ConfigHookManager(HookManager):
- stages = ("pre_build", "every_build", "after_build")
+ stages: tuple[str, ...] = ("pre_build", "every_build", "after_build")
"""A subclass of HookManager that allows hooks to be registered and executed
at specific points in the build process.
@@ -215,17 +217,19 @@ def __init__(
self,
node: Callable,
enabled: bool = True,
- ):
+ ) -> None:
self.node = node
self.enabled = enabled
+ if not hasattr(node, "name"):
+ raise ValueError("The `node` must have name attribute.")
self.name = node.name
self._is_initialized = True
- def hook(self, **kwargs):
+ def hook(self, **kwargs: Any) -> Any:
raise NotImplementedError(f"`{self.__class__.__name__}` do not implement `hook` method.")
@final
- def __call__(self, **kwargs):
+ def __call__(self, **kwargs: Any) -> Any:
if not getattr(self, "_is_initialized", False):
raise CoreConfigSupportError(
f"Call super().__init__() in class `{self.__class__.__name__}`"
diff --git a/excore/engine/logging.py b/excore/engine/logging.py
index aa415d3..980294f 100644
--- a/excore/engine/logging.py
+++ b/excore/engine/logging.py
@@ -1,11 +1,29 @@
+from __future__ import annotations
+
import os
import sys
+from typing import TYPE_CHECKING
from loguru import logger as _logger
+if TYPE_CHECKING:
+ from typing import Any, Callable, TextIO
+
+ from loguru import FilterDict, FilterFunction, FormatFunction, Message, Record, Writable
+ from loguru._handler import Handler
+ from loguru._logger import Logger
+
+ class PatchedLogger(Logger):
+ def ex(self, __message: str, *args: Any, **kwargs: Any) -> None:
+ pass
+
+ def imp(self, __message: str, *args: Any, **kwargs: Any) -> None:
+ pass
+
+
__all__ = ["logger", "add_logger", "remove_logger", "debug_only", "log_to_file_only"]
-LOGGERS = {}
+LOGGERS: dict[str, int] = {}
FORMAT = (
"{time:YYYY-MM-DD HH:mm:ss} | "
@@ -22,21 +40,21 @@ def _trace_patcher(log_record):
log_record["function"] = "\b"
-logger = _logger.patch(_trace_patcher)
+logger: PatchedLogger = _logger.patch(_trace_patcher) # type: ignore
def add_logger(
- name,
- sink,
+ name: str,
+ sink: TextIO | Writable | Callable[[Message], None] | Handler,
*,
- level=None, # pylint: disable=unused-argument
- format=None, # pylint: disable=unused-argument
- filter=None, # pylint: disable=unused-argument
- colorize=None, # pylint: disable=unused-argument
- serialize=None, # pylint: disable=unused-argument
- backtrace=None, # pylint: disable=unused-argument
- diagnose=None, # pylint: disable=unused-argument
- enqueue=None, # pylint: disable=unused-argument
+ level: str | int | None = None, # pylint: disable=unused-argument
+ format: str | FormatFunction | None = None, # pylint: disable=unused-argument
+ filter: str | FilterFunction | FilterDict | None = None, # pylint: disable=unused-argument
+ colorize: bool | None = None, # pylint: disable=unused-argument
+ serialize: bool | None = None, # pylint: disable=unused-argument
+ backtrace: bool | None = None, # pylint: disable=unused-argument
+ diagnose: bool | None = None, # pylint: disable=unused-argument
+ enqueue: bool | None = None, # pylint: disable=unused-argument
) -> None:
params = {k: v for k, v in locals().items() if v is not None}
params.pop("sink")
@@ -54,14 +72,14 @@ def remove_logger(name: str) -> None:
logger.warning(f"Cannot find logger with name {name}")
-def log_to_file_only(file_name: str, *args, **kwargs) -> None:
+def log_to_file_only(file_name: str, *args: Any, **kwargs: Any) -> None:
logger.remove(None)
logger.add(file_name, *args, **kwargs)
logger.success(f"Log to file {file_name} only")
-def debug_only(*args, **kwargs) -> None:
- def _debug_only(record):
+def debug_only(*args: Any, **kwargs: Any) -> None:
+ def _debug_only(record: Record) -> bool:
return record["level"].name == "DEBUG"
filter = kwargs.pop("filter", None)
@@ -72,27 +90,27 @@ def _debug_only(record):
logger.debug("DEBUG ONLY!!!")
-def _call_importance(__message: str, *args, **kwargs):
+def _call_importance(__message: str, *args: Any, **kwargs: Any) -> None:
logger.log("IMPORT", __message, *args, **kwargs)
-def _excore_debug(__message: str, *args, **kwargs):
+def _excore_debug(__message: str, *args: Any, **kwargs: Any) -> None:
logger._log("EXCORE", False, logger._options, __message, args, kwargs)
-def _enable_excore_debug():
+def _enable_excore_debug() -> None:
if os.getenv("EXCORE_DEBUG"):
logger.remove()
logger.add(sys.stdout, format=FORMAT, level="EXCORE")
logger.ex("Enabled excore debug")
-def init_logger():
+def init_logger() -> None:
logger.remove()
logger.add(sys.stderr, format=FORMAT)
logger.level("SUCCESS", color="")
logger.level("WARNING", color="")
logger.level("IMPORT", no=45, color="")
logger.level("EXCORE", no=9, color="")
- logger.imp = _call_importance
- logger.ex = _excore_debug
+ logger.imp = _call_importance # type: ignore
+ logger.ex = _excore_debug # type: ignore
diff --git a/excore/engine/registry.py b/excore/engine/registry.py
index 00ccf33..a6e70c2 100644
--- a/excore/engine/registry.py
+++ b/excore/engine/registry.py
@@ -1,14 +1,18 @@
+from __future__ import annotations
+
import fnmatch
import functools
import inspect
import os
import re
import sys
-from types import ModuleType
-from typing import Any, Callable, Dict, List, Optional, Sequence, Union
+from types import FunctionType, ModuleType
+from typing import Any, Callable, Literal, Sequence, Type, overload
+
+from filelock import FileLock
-from .._constants import _cache_dir, _registry_cache_file, _workspace_config_file
-from .._misc import FileLock, _create_table
+from .._constants import _workspace_config_file, workspace
+from .._misc import _create_table
from .logging import logger
_name_re = re.compile(r"^[A-Za-z0-9_]+$")
@@ -16,19 +20,21 @@
__all__ = ["Registry"]
+_ClassType = Type[Any]
+
# TODO: Maybe some methods need to be cleared.
-def _is_pure_ascii(name: str):
+def _is_pure_ascii(name: str) -> None:
if not _name_re.match(name):
raise ValueError(
- f"""Unexpected name, only support ASCII letters, ASCII digits,
- underscores, and dashes, but got {name}"""
+ "Unexpected name, only support ASCII letters, ASCII digits, "
+ f"underscores, and dashes, but got {name}."
)
-def _is_function_or_class(module):
+def _is_function_or_class(module: Any) -> bool:
return inspect.isfunction(module) or inspect.isclass(module)
@@ -36,7 +42,7 @@ def _default_filter_func(values: Sequence[Any]) -> bool:
return all(v for v in values)
-def _default_match_func(m, base_module):
+def _default_match_func(m: str, base_module: ModuleType) -> bool:
if not m.startswith("__"):
m = getattr(base_module, m)
if inspect.isfunction(m) or inspect.isclass(m):
@@ -44,12 +50,12 @@ def _default_match_func(m, base_module):
return False
-def _get_module_name(m):
+def _get_module_name(m: ModuleType | _ClassType | FunctionType) -> str:
return getattr(m, "__qualname__", m.__name__)
class RegistryMeta(type):
- _registry_pool: Dict[str, "Registry"] = dict()
+ _registry_pool: dict[str, Registry] = {}
"""Metaclass that governs the creation of instances of its subclasses, which are
`Registry` objects.
@@ -65,7 +71,7 @@ class RegistryMeta(type):
message is logged indicating that the extra arguments will be ignored.
"""
- def __call__(cls, name, **kwargs) -> "Registry":
+ def __call__(cls, name: str, **kwargs: Any) -> Registry:
r"""Assert only call `__init__` once"""
_is_pure_ascii(name)
extra_field = kwargs.get("extra_field", None)
@@ -85,70 +91,68 @@ def __call__(cls, name, **kwargs) -> "Registry":
# Maybe someday we can get rid of Registry?
-class Registry(dict, metaclass=RegistryMeta):
- _globals: Optional["Registry"] = None
- _registry_dir = "registry"
+class Registry(dict, metaclass=RegistryMeta): # type: ignore
+ _globals: Registry | None = None
# just a workaround for twice registry
- _prevent_register = False
+ _prevent_register: bool = False
+
+ extra_info: dict[str, str]
+
"""A registry that stores functions and classes by name.
Attributes:
name (str): The name of the registry.
- extra_field (Optional[Union[str, Sequence[str]]]): A field or fields that can be
+ extra_field (str|Sequence[str]|None): A field or fields that can be
used to store additional information about each function or class in the
registry.
- extra_info (Dict[str, List[Any]]): A dictionary that maps each registered name
+ extra_info (dict[str, list[Any]]): A dictionary that maps each registered name
to a list of extra values associated with that name (if any).
- _globals (Optional[Registry]): A static variable that stores a global registry
+ _globals (Registry|None): A static variable that stores a global registry
containing all functions and classes registered using Registry.
"""
- def __init__(
- self, /, name: str, *, extra_field: Optional[Union[str, Sequence[str]]] = None
- ) -> None:
+ def __init__(self, /, name: str, *, extra_field: str | Sequence[str] | None = None) -> None:
super().__init__()
self.name = name
if extra_field:
self.extra_field = [extra_field] if isinstance(extra_field, str) else extra_field
- self.extra_info = dict()
+ self.extra_info = {}
@classmethod
- def dump(cls):
- file_path = os.path.join(_cache_dir, cls._registry_dir, _registry_cache_file)
- os.makedirs(os.path.join(_cache_dir, cls._registry_dir), exist_ok=True)
+ def dump(cls) -> None:
+ file_path = workspace.registry_cache_file
import pickle # pylint: disable=import-outside-toplevel
- with FileLock(file_path): # noqa: SIM117
- with open(file_path, "wb") as f:
- pickle.dump(cls._registry_pool, f)
+ with FileLock(file_path + ".lock", timeout=5), open(file_path, "wb") as f:
+ pickle.dump(cls._registry_pool, f)
+ logger.success(f"Dump registry cache to {workspace.registry_cache_file}!")
@classmethod
- def load(cls):
+ def load(cls) -> None:
if not os.path.exists(_workspace_config_file):
logger.warning("Please run `excore init` in your command line first!")
- sys.exit(0)
- file_path = os.path.join(_cache_dir, cls._registry_dir, _registry_cache_file)
+ sys.exit(1)
+ file_path = workspace.registry_cache_file
if not os.path.exists(file_path):
# shall we need to be silent? Or raise error?
logger.critical(
"Registry cache file do not exist!"
" Please run `excore auto-register in your command line first`"
)
- sys.exit(0)
+ sys.exit(1)
import pickle # pylint: disable=import-outside-toplevel
- with FileLock(file_path): # noqa: SIM117
- with open(file_path, "rb") as f:
- data = pickle.load(f)
+ with FileLock(file_path + ".lock"), open(file_path, "rb") as f:
+ data = pickle.load(f)
cls._registry_pool.update(data)
@classmethod
- def lock_register(cls):
+ def lock_register(cls) -> None:
cls._prevent_register = True
@classmethod
- def unlock_register(cls):
+ def unlock_register(cls) -> None:
cls._prevent_register = False
@classmethod
@@ -161,7 +165,7 @@ def get_registry(cls, name: str, default: Any = None) -> Any:
@classmethod
@functools.lru_cache(32)
- def find(cls, name: str) -> Any:
+ def find(cls, name: str) -> tuple[Any, str] | tuple[None, None]:
"""
Searches all registries for an element with the given name. If found,
returns a tuple containing the element and the name of the registry where it
@@ -173,7 +177,7 @@ def find(cls, name: str) -> Any:
return (None, None)
@classmethod
- def make_global(cls):
+ def make_global(cls) -> Registry:
"""
Creates a global `Registry` instance that contains all elements from all
other registries. If the global registry already exists, returns it instead
@@ -187,7 +191,7 @@ def make_global(cls):
cls._globals = reg
return reg
- def __setitem__(self, k, v) -> None:
+ def __setitem__(self, k: str, v: Any) -> None:
_is_pure_ascii(k)
super().__setitem__(k, v)
@@ -195,18 +199,47 @@ def __repr__(self) -> str:
return _create_table(
["NAEM", "DIR"],
[(k, v) for k, v in self.items()],
- False,
)
__str__ = __repr__
+ @overload
def register_module(
self,
- module: Union[Callable, ModuleType],
- force: bool = False,
- _is_str: bool = False,
+ module: Callable[..., Any],
+ force: bool = ...,
+ _is_str: bool = ...,
+ **extra_info: Any,
+ ) -> Callable[..., Any]:
+ pass
+
+ @overload
+ def register_module(
+ self,
+ module: ModuleType,
+ force: bool = ...,
+ _is_str: bool = ...,
+ **extra_info: Any,
+ ) -> ModuleType:
+ pass
+
+ @overload
+ def register_module(
+ self,
+ module: str,
+ force: bool = ...,
+ _is_str: Literal[True] = ...,
+ **extra_info: Any,
+ ) -> str:
+ pass
+
+ def register_module(
+ self,
+ module,
+ force=False,
+ _is_str=False,
**extra_info,
- ) -> Union[Callable, ModuleType]:
+ ):
if Registry._prevent_register:
logger.ex("Registry has been locked!!!")
return module
@@ -216,6 +249,7 @@ def register_module(
name = _get_module_name(module)
else:
name = module.split(".")[-1]
+
if not force and name in self:
raise ValueError(f"The name {name} exists")
@@ -238,6 +272,8 @@ def register_module(
)
else:
target = module
+
+ logger.ex(f"Register {name} with {target}.")
self[name] = target
# update to globals
@@ -246,7 +282,7 @@ def register_module(
return module
- def register(self, force: bool = False, **extra_info) -> Callable:
+ def register(self, force: bool = False, **extra_info: Any) -> Callable[..., Any]:
"""
Decorator that registers a function or class with the current `Registry`.
Any keyword arguments provided are added to the `extra_info` list for the
@@ -257,8 +293,8 @@ def register(self, force: bool = False, **extra_info) -> Callable:
def register_all(
self,
- modules: Sequence[Callable],
- extra_info: Optional[Sequence[Dict[str, Any]]] = None,
+ modules: Sequence[Callable[..., Any]],
+ extra_info: Sequence[dict[str, Any]] | None = None,
force: bool = False,
_is_str: bool = False,
) -> None:
@@ -274,14 +310,14 @@ def register_all(
def merge(
self,
- others: Union["Registry", List["Registry"]],
+ others: Registry | Sequence[Registry],
force: bool = False,
) -> None:
"""
Merge the contents of one or more other registries into the current one.
If `force` is True, overwrites any existing elements with the same names.
"""
- if not isinstance(others, list):
+ if not isinstance(others, (list, tuple, Sequence)):
others = [others]
for other in others:
if not isinstance(other, Registry):
@@ -291,9 +327,9 @@ def merge(
def filter(
self,
- filter_field: Union[Sequence[str], str],
- filter_func: Callable = _default_filter_func,
- ) -> List[str]:
+ filter_field: Sequence[str] | str,
+ filter_func: Callable[[Sequence[Any]], bool] = _default_filter_func,
+ ) -> list[str]:
"""
Returns a sorted list of all names in the registry for which the values of
the given extra field(s) pass a filtering function.
@@ -310,7 +346,12 @@ def filter(
out = list(sorted(out))
return out
- def match(self, base_module, match_func=_default_match_func, force=False):
+ def match(
+ self,
+ base_module: ModuleType,
+ match_func: Callable[[str, ModuleType], bool] = _default_match_func,
+ force: bool = False,
+ ) -> None:
"""
Registers all functions or classes from the given module that pass a matching
function. If `match_func` is not provided, uses `_default_match_func`.
@@ -328,11 +369,11 @@ def match(self, base_module, match_func=_default_match_func, force=False):
def module_table(
self,
- filter: Optional[Union[Sequence[str], str]] = None,
- select_info: Optional[Union[Sequence[str], str]] = None,
- module_list: Optional[Sequence[str]] = None,
- **table_kwargs,
- ) -> Any:
+ filter: Sequence[str] | str | None = None,
+ select_info: Sequence[str] | str | None = None,
+ module_list: Sequence[str] | None = None,
+ **table_kwargs: Any,
+ ) -> str:
"""
Returns a table containing information about each registered function or
class, filtered by name and/or extra info fields. `select_info` specifies
@@ -347,16 +388,16 @@ def module_table(
else:
select_info = []
- all_modules = module_list if module_list else self.keys()
+ all_modules = module_list if module_list else list(self.keys())
if filter:
- modules = set()
+            modules: list[str] = []; set_modules: set[str] = set()
filters = [filter] if isinstance(filter, str) else filter
for f in filters:
include_models = fnmatch.filter(all_modules, f)
if len(include_models):
- modules = modules.union(include_models)
+                modules = list(set_modules := set_modules.union(include_models))
else:
- modules = all_modules
+ modules = all_modules # type: ignore
modules = list(sorted(modules))
@@ -371,14 +412,13 @@ def module_table(
table = _create_table(
table_headers,
[(i, *[self.extra_info[i][idx] for idx in select_idx]) for i in modules],
- False,
**table_kwargs,
)
table = "\n" + table
return table
@classmethod
- def registry_table(cls, **table_kwargs) -> Any:
+    def registry_table(cls, **table_kwargs: Any) -> str:
"""
Returns a table containing the names of all available registries.
"""
@@ -386,15 +426,14 @@ def registry_table(cls, **table_kwargs) -> Any:
table = _create_table(
table_headers,
list(sorted([[i] for i in cls._registry_pool])),
- False,
**table_kwargs,
)
table = "\n" + table
return table
-def load_registries():
- if not os.path.exists(os.path.join(_cache_dir, Registry._registry_dir, _registry_cache_file)):
+def load_registries() -> None:
+ if not os.path.exists(workspace.registry_cache_file):
logger.warning("Please run `excore auto-register` in your command line first!")
return
Registry.load()
@@ -406,4 +445,4 @@ def load_registries():
"No module has been registered, \
you may need to call `excore.registry.auto_register` first"
)
- sys.exit(0)
+ sys.exit(1)
diff --git a/excore/plugins/hub.py b/excore/plugins/hub.py
index 20fe42b..2c7c9f3 100644
--- a/excore/plugins/hub.py
+++ b/excore/plugins/hub.py
@@ -21,7 +21,7 @@
import requests
from tqdm import tqdm
-from .._constants import __version__, _cache_dir
+from .._constants import __version__, workspace
from .._exceptions import (
GitCheckoutError,
GitPullError,
@@ -153,7 +153,7 @@ def fetch(
kwargs = {"stderr": subprocess.PIPE, "stdout": subprocess.PIPE} if silent else {}
if commit is None:
# shallow clone repo by branch/tag
- p = subprocess.Popen(
+ p = subprocess.Popen( # type: ignore
[
"git",
"clone",
@@ -169,14 +169,16 @@ def fetch(
else:
# clone repo and checkout to commit_id
p = subprocess.Popen( # pylint: disable=consider-using-with
- ["git", "clone", git_url, repo_dir], **kwargs
+ ["git", "clone", git_url, repo_dir],
+ **kwargs, # type: ignore
)
cls._check_clone_pipe(p)
with cd(repo_dir):
logger.debug("git checkout to {}", commit)
p = subprocess.Popen( # pylint: disable=consider-using-with
- ["git", "checkout", commit], **kwargs
+ ["git", "checkout", commit],
+ **kwargs, # type: ignore
)
_, err = p.communicate()
if p.returncode:
@@ -289,7 +291,7 @@ def _get_repo(
raise InvalidProtocol(
"Invalid protocol, the value should be one of {}.".format(", ".join(PROTOCOLS.keys()))
)
- cache_dir = os.path.expanduser(os.path.join(_cache_dir, "hub"))
+ cache_dir = os.path.expanduser(os.path.join(workspace.cache_dir, "hub"))
with cd(cache_dir):
fetcher = PROTOCOLS[protocol]
repo_dir = fetcher.fetch(git_host, repo_info, use_cache, commit)
@@ -303,14 +305,14 @@ def _check_dependencies(module: types.ModuleType) -> None:
dependencies = getattr(module, HUBDEPENDENCY)
if not dependencies:
return
- missing_deps = [m for m in dependencies if importlib.util.find_spec(m)]
+    missing_deps = [m for m in dependencies if importlib.util.find_spec(m) is None] # type: ignore
if len(missing_deps):
raise RuntimeError("Missing dependencies: {}".format(", ".join(missing_deps)))
def load_module(name: str, path: str) -> types.ModuleType:
- spec = importlib.util.spec_from_file_location(name, path)
- module = importlib.util.module_from_spec(spec)
+ spec = importlib.util.spec_from_file_location(name, path) # type: ignore
+ module = importlib.util.module_from_spec(spec) # type: ignore
spec.loader.exec_module(module)
return module
@@ -323,7 +325,7 @@ def _init_hub(
commit: Optional[str] = None,
protocol: str = DEFAULT_PROTOCOL,
):
- cache_dir = os.path.expanduser(os.path.join(_cache_dir, "hub"))
+ cache_dir = os.path.expanduser(os.path.join(workspace.cache_dir, "hub"))
os.makedirs(cache_dir, exist_ok=True)
absolute_repo_dir = _get_repo(
git_host, repo_info, use_cache=use_cache, commit=commit, protocol=protocol
@@ -412,7 +414,7 @@ def pretrained_model_func(pretrained=False, **kwargs):
digest = sha256.hexdigest()[:6]
filename = digest + "_" + filename
- cached_file = os.path.join(_cache_dir, filename)
+ cached_file = os.path.join(workspace.cache_dir, filename)
download_from_url(self.url, cached_file)
self.load_func(cached_file, model)
return model
diff --git a/excore/plugins/path_manager.py b/excore/plugins/path_manager.py
index d944a3e..8b9733f 100644
--- a/excore/plugins/path_manager.py
+++ b/excore/plugins/path_manager.py
@@ -1,8 +1,12 @@
+from __future__ import annotations
+
import shutil
import time
from dataclasses import is_dataclass
from pathlib import Path
-from typing import Callable, Dict, List, Optional, Protocol, Union
+from typing import Callable, Protocol
+
+from typing_extensions import Self
from ..engine.logging import logger
@@ -10,9 +14,9 @@
class DataclassProtocol(Protocol):
- __dataclass_fields__: Dict
- __dataclass_params__: Dict
- __post_init__: Optional[Callable]
+ __dataclass_fields__: dict
+ __dataclass_params__: dict
+ __post_init__: Callable | None
class PathManager:
@@ -59,9 +63,9 @@ def __init__(
self,
/,
base_path: str,
- sub_folders: Union[List[str], DataclassProtocol],
+ sub_folders: list[str] | DataclassProtocol,
config_name: str,
- instance_name: Optional[str] = None,
+ instance_name: str | None = None,
*,
remove_if_fail: bool = False,
sub_folder_exist_ok: bool = False,
@@ -76,9 +80,9 @@ def __init__(
self.sub_exist_ok = sub_folder_exist_ok
self.config_first = config_name_first
self.return_str = return_str
- self._info = {}
+ self._info: dict[str, Path] = {}
- def _get_sub_folders(self, sub_folders):
+ def _get_sub_folders(self, sub_folders) -> None:
if not isinstance(sub_folders, list):
if not is_dataclass(sub_folders):
raise TypeError("Only Support dataclass or list of str")
@@ -101,7 +105,7 @@ def mkdir(self) -> None:
sub.mkdir(parents=True, exist_ok=self.sub_exist_ok)
self._info[str(f)] = sub
- def get(self, name: str) -> Union[Path, str]:
+ def get(self, name: str) -> Path | str | None:
"""
Retrieve the path for a specific sub-folder by name.
"""
@@ -131,11 +135,11 @@ def remove_all(self) -> None:
logger.info(f"Remove sub_folders {f}")
shutil.rmtree(str(f))
- def __enter__(self):
+ def __enter__(self) -> Self:
self.init()
return self
- def __exit__(self, exc_type, exc_value, traceback):
+ def __exit__(self, exc_type, exc_value, traceback) -> bool:
if exc_type is not None:
self.final()
return False
diff --git a/pyproject.toml b/pyproject.toml
index 04282ff..d524b84 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "excore"
-version = "0.1.1beta"
+version = "0.1.1beta1"
description = "Build your own development toolkit efficiently."
authors = ["Asthestarsfalll <1186454801@qq.com>"]
license = "MIT"
@@ -30,10 +30,10 @@ requests = "2.28.1"
toml = "0.10.2"
tabulate = "*"
tqdm = "*"
-pynvml = "*"
rich = "*"
astor = "*"
-typer = { extras = ["all"], version = "^0.9.0" }
+typer = "^0.13.1"
+filelock = "^3.15.4"
[tool.poetry.group.dev.dependencies]
black = "23.1.0"
@@ -43,6 +43,10 @@ pytest = "^7.4.4"
pytest-cov = "^4.1.0"
torch = { version = "1.13.1", source = "torch" }
torchvision = { version = "0.14.1", source = "torch" }
+mypy = "^1.13.0"
+types-toml = "^0.10.8.20240310"
+types-tqdm = "^4.67.0.20241119"
+types-tabulate = "^0.9.0.20240106"
[[tool.poetry.source]]
name = "torch"
diff --git a/tests/init.py b/tests/init.py
index 9c79a1b..1e09e31 100644
--- a/tests/init.py
+++ b/tests/init.py
@@ -1,3 +1,4 @@
+import os
import subprocess
import toml
@@ -18,6 +19,8 @@ def excute(command: str, inputs=None):
command.split(" "),
capture_output=True,
)
+ print(result.stdout)
+ print(result.stderr)
assert result.returncode == 0, result.stderr
@@ -40,6 +43,9 @@ def init():
toml.dump(cfg, f)
excute("excore update")
excute("excore auto-register")
+ import excore
+
+ assert os.path.exists(os.path.join(excore.workspace.cache_base_dir, "tests"))
if __name__ == "__main__":
diff --git a/tests/source_code/__init__.py b/tests/source_code/__init__.py
index 0cb1414..2ad986b 100644
--- a/tests/source_code/__init__.py
+++ b/tests/source_code/__init__.py
@@ -15,6 +15,15 @@
OPTIM = Registry("Optimizer")
TRANSFORM = Registry("Transform")
MODULE = Registry("module")
-
MODULE.register_module(time)
MODULE.register_module(torch)
+MODEL = Registry("Model")
+DATA = Registry("Data")
+BACKBONE = Registry("Backbone")
+HEAD = Registry("Head")
+HOOK = Registry("Hook")
+LOSS = Registry("Loss")
+LRSCHE = Registry("LRSche")
+OPTIMIZER = Registry("Optimizer")
+TRANSFORM = Registry("Transform")
+MODULE = Registry("module")
diff --git a/tests/test_a_registry.py b/tests/test_a_registry.py
index a4999ea..d2c6f61 100644
--- a/tests/test_a_registry.py
+++ b/tests/test_a_registry.py
@@ -2,52 +2,46 @@
import pytest
-from excore import Registry, load_registries
+from excore import Registry, load_registries, _enable_excore_debug
load_registries()
-Registry.unlock_register()
+_enable_excore_debug()
+import source_code as S
-def test_print():
- import source_code as S
+def test_print():
print(S.MODEL)
def test_find():
- import source_code as S
-
tar, base = S.BACKBONE.find("ResNet")
assert base == "Backbone"
assert tar.split(".")[-1] == "ResNet"
def test_register_module():
+ Registry.unlock_register()
reg = Registry("__test")
reg.register_module(time)
assert reg["time"] == "time"
def test_global():
- import source_code as S
-
g = Registry.make_global()
assert g["time"] == "time"
assert g["ResNet"] == "torchvision.models.resnet.ResNet"
with pytest.raises(ValueError):
S.MODEL.register_module(time)
assert id(S.MODEL) == id(Registry.get_registry("Model"))
+ Registry.lock_register()
def test_module_table():
- import source_code as S
-
print(S.MODEL.module_table())
print(S.MODEL.module_table("*resnet*"))
def test_id():
- import source_code as S
-
assert Registry.get_registry("Head") == S.HEAD
diff --git a/tests/test_config.py b/tests/test_config.py
index 569da58..86d1875 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -247,6 +247,3 @@ def test_dict_param(self):
TestClass,
VGG,
]
-
-
-TestConfig().test_scratchpads()
diff --git a/tests/test_config_extention.py b/tests/test_config_extention.py
index 3715f12..70b3ebd 100644
--- a/tests/test_config_extention.py
+++ b/tests/test_config_extention.py
@@ -1,4 +1,4 @@
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
from excore import Registry
from excore.config._json_schema import parse_registry
@@ -15,6 +15,9 @@ class Tmp2:
pass
+t = Tmp()
+
+
@R.register()
class A:
def __init__(
@@ -25,6 +28,7 @@ def __init__(
d: Dict,
d1: dict,
obj: Tmp,
+ c: Callable,
uni1: Union[int, float],
uni2: Union[int, str],
uni3: Union[Tmp, Tmp2],
@@ -40,6 +44,7 @@ def __init__(
default_d={}, # noqa: B006 # pylint: disable=W0102
default_tuple=(0, "", 0.0),
default_list=[0, 1], # noqa: B006
+ e=t,
*args,
**kwargs,
):
@@ -67,6 +72,8 @@ def test_type_parsing():
_assert(properties, "d", "object")
_assert(properties, "d1", "object")
_assert(properties, "obj", "string")
+ _assert(properties, "c", "string")
+ _assert(properties, "e", "number")
_assert(properties, "uni1", "")
_assert(properties, "uni2", "")
_assert(properties, "uni3", "")
diff --git a/tests/test_config_extention_310.py b/tests/test_config_extention_310.py
index c61f428..ee8640c 100644
--- a/tests/test_config_extention_310.py
+++ b/tests/test_config_extention_310.py
@@ -1,4 +1,5 @@
import sys
+from typing import Callable
import pytest
from test_config_extention import _assert
@@ -17,6 +18,8 @@ class Tmp2:
pass
+t = Tmp()
+
if sys.version_info >= (3, 10, 0):
@S.register()
@@ -28,6 +31,7 @@ def __init__(
f: float,
d: dict,
obj: Tmp,
+ c: Callable,
uni1: int | float,
uni2: int | str,
uni3: Tmp | Tmp2,
@@ -42,6 +46,9 @@ def __init__(
default_d={}, # noqa: B006 # pylint: disable=W0102
default_tuple=(0, "", 0.0),
default_list=[0, 1], # noqa: B006
+ e=t,
+ *args,
+ **kwargs,
):
pass
@@ -58,6 +65,8 @@ def test_type_parsing():
_assert(properties, "f", "number")
_assert(properties, "d", "object")
_assert(properties, "obj", "string")
+ _assert(properties, "c", "string")
+ _assert(properties, "e", "number")
_assert(properties, "uni1", "")
_assert(properties, "uni2", "")
_assert(properties, "uni3", "")
@@ -72,4 +81,6 @@ def test_type_parsing():
_assert(properties, "default_d", "object")
_assert(properties, "default_tuple", "array")
_assert(properties, "default_list", "array", "number")
+ _assert(properties, "args", "array")
+ _assert(properties, "kwargs", "object")
S.clear()
diff --git a/tests/test_z_cli.py b/tests/test_z_cli.py
index 5b8c525..6e21d43 100644
--- a/tests/test_z_cli.py
+++ b/tests/test_z_cli.py
@@ -8,6 +8,7 @@ def test_init_force():
def test_generate_registries():
+ # FIXME(typer): Argument
excute("excore generate-registries temp")
assert os.path.exists("./source_code/temp.py")
from source_code import temp # noqa: F401
@@ -26,6 +27,7 @@ def test_primary():
def test_typehints():
+ # FIXME(typer): Argument
excute(
"excore generate-typehints temp_typing --class-name "
"TypedWrapper --info-class-name Info --config ./configs/launch/test_optim.toml"
@@ -42,3 +44,5 @@ def test_quote():
excute("excore quote ./configs/lrsche")
assert os.path.exists("./configs/lrsche/lrsche_overrode.toml")
assert os.path.exists("./configs/lrsche/lrsche_error_overrode.toml")
+ os.remove("./configs/lrsche/lrsche_overrode.toml")
+ os.remove("./configs/lrsche/lrsche_error_overrode.toml")