diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 8df5ae6..776cbe3 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -9,16 +9,19 @@ on: jobs: test: runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.8, 3.9, '3.10'] steps: - - uses: actions/checkout@v1 + - uses: actions/checkout@v2 with: fetch-depth: 1 - - name: Set up Python 3.8 - uses: actions/setup-python@v1 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 with: - python-version: 3.8 + python-version: ${{ matrix.python-version }} - name: Install Poetry uses: snok/install-poetry@v1 diff --git a/example/init.py b/example/init.py index 7f328b9..3818dfa 100644 --- a/example/init.py +++ b/example/init.py @@ -41,6 +41,7 @@ def init(): toml.dump(cfg, f) excute("excore update") excute("excore auto-register") + excute("excore config-extention") excute("excore generate-typehints temp_typing --config ./configs/run.toml") diff --git a/excore/config/_json_schema.py b/excore/config/_json_schema.py index 9389e06..5408b85 100644 --- a/excore/config/_json_schema.py +++ b/excore/config/_json_schema.py @@ -1,11 +1,13 @@ +from __future__ import annotations + import inspect import json import os import os.path as osp -from collections.abc import Sequence -from inspect import Parameter, _empty, isclass +import sys +from inspect import Parameter, _empty, _ParameterKind, isclass from types import ModuleType -from typing import Dict, Optional, Union, _GenericAlias +from typing import Any, Dict, Sequence, Union, get_args, get_origin import toml @@ -15,51 +17,45 @@ from ..engine.registry import Registry, load_registries from .model import _str_to_target -NoneType = type(None) +if sys.version_info >= (3, 10, 0): + from types import NoneType, UnionType +else: + NoneType = type(None) + + # just a placeholder + class UnionType: + pass -TYPE_MAPPER = { + +TYPE_MAPPER: dict[type, str] = { int: "number", # sometimes default value are not accurate str: "string", float: "number", list: "array", tuple: "array", dict: "object", + Dict: "object", bool: "boolean", } -SPECIAL_KEYS = {"kwargs": "object", "args": "array"} - -def _get_type(t): - if isinstance(t, _GenericAlias): - if isinstance(t.__args__[0], _GenericAlias): - return None - return TYPE_MAPPER.get(t) - potential_type = TYPE_MAPPER.get(t) - if potential_type is None: - return "string" - return potential_type - - -def _init_json_schema(settings: Optional[Dict]) -> Dict: +def _init_json_schema(settings: dict | None) -> dict[str, Any]: default_schema = { - "$schema": "https://json-schema.org/draft/2020-12/schema", - "$id": "https://example.com/product.schema.json", "title": "ExCore", "description": "Uesd for ExCore config file completion", "type": "object", "properties": {}, } default_schema.update(settings or {}) - assert len(default_schema) == 6 + assert len(default_schema) == 4 return default_schema def _generate_json_schema_and_class_mapping( - fields: Dict, - save_path: Optional[str] = None, - class_mapping_save_path: Optional[str] = None, - schema_settings: Optional[Dict] = None, + fields: dict, + save_path: str | None = None, + class_mapping_save_path: str | None = None, + schema_settings: dict | None = None, ) -> None: load_registries() schema = _init_json_schema(schema_settings) @@ -90,7 +86,7 @@ def _generate_json_schema_and_class_mapping( logger.success("class mapping has been written to {}", class_mapping_save_path) -def _check(bases): +def _check(bases) -> bool: for b in bases: if b is object: return False @@ -99,7 +95,7 @@ def _check(bases): return False -def parse_registry(reg: Registry): +def parse_registry(reg: Registry) -> tuple[dict, dict[str, list[str | int]]]: props = { "type": "object", "properties": {}, @@ -143,63 +139,73 @@ def parse_registry(reg: Registry): return props, class_mapping -def _clean(anno): - if not hasattr(anno, "__origin__"): +def _remove_optional(anno): + origin = get_origin(anno) + inner_types = get_args(anno) + if origin is not Union and len(inner_types) != 2: return anno - # Optional - if anno.__origin__ is type or (anno.__origin__ is Union and anno.__args__[1] is NoneType): - return _clean(anno.__args__[0]) + filter_types = [i for i in inner_types if i is not NoneType] + if len(filter_types) == 1: + return _remove_optional(filter_types[0]) return anno -def _parse_inner_types(prop, args): - inner_types = [_get_type(i) for i in args] +def _parse_inner_types(prop: dict, inner_types: Sequence[type]) -> None: first_type = inner_types[0] is_all_the_same = True for t in inner_types: is_all_the_same &= t == first_type - if is_all_the_same: - prop["items"] = {"type": first_type} + if is_all_the_same and first_type in TYPE_MAPPER: + prop["items"] = {"type": TYPE_MAPPER.get(first_type)} -def _parse_generic_alias(prop, anno) -> Optional[str]: - potential_type = None - if anno.__origin__ in (Sequence, list, tuple): +def _parse_typehint(prop: dict, anno: type) -> str | None: + potential_type = TYPE_MAPPER.get(anno) + if potential_type is not None: + return potential_type + origin = get_origin(anno) + inner_types = get_args(anno) + if origin in (Sequence, list, tuple): potential_type = "array" - # Do not support like `List[ResNet]`. - _parse_inner_types(prop, anno.__args__) - elif anno.__origin__ == Union: - potential_type = None - elif anno == Dict: - return "object" - elif len(anno.__args__) > 0: - potential_type = _get_type(anno.__args__[0]) - return potential_type - - -def parse_single_param(param: Parameter): + _parse_inner_types(prop, inner_types) + elif origin in (Union, UnionType) and len(inner_types) == 2: + filter_types = [i for i in inner_types if i is not NoneType] + if len(filter_types) == 1: + return _parse_typehint(prop, filter_types[0]) + return None + elif origin in (Union, UnionType): + return None + return potential_type or "string" + + +def parse_single_param(param: Parameter) -> tuple[bool, dict[str, Any]]: prop = {} anno = param.annotation potential_type = None - anno = _clean(anno) + anno = _remove_optional(anno) # hardcore for torch.optim if param.default.__class__.__name__ == "_RequiredParameter": param._default = _empty - if isinstance(anno, _GenericAlias): - potential_type = _parse_generic_alias(prop, anno) + if isinstance(anno, str): + raise RuntimeError( + "Use a higher version of python, e.g. 3.10, " + "and remove `from __future__ import annotations`." + ) elif anno is not _empty: - potential_type = _get_type(anno) + potential_type = _parse_typehint(prop, anno) # determine type by default value elif param.default is not _empty: - potential_type = _get_type(type(param.default)) + potential_type = TYPE_MAPPER[type(param.default)] if isinstance(param.default, (list, tuple)): types = [type(t) for t in param.default] _parse_inner_types(prop, types) - if param.name in SPECIAL_KEYS: - return False, SPECIAL_KEYS[param.name] + elif param._kind is _ParameterKind.VAR_POSITIONAL: + return False, {"type": "array"} + elif param._kind is _ParameterKind.VAR_KEYWORD: + return False, {"type": "object"} if anno is _empty and param.default is _empty: potential_type = "number" if potential_type: @@ -207,15 +213,15 @@ def parse_single_param(param: Parameter): return param.default is _empty, prop -def _json_schema_path(): +def _json_schema_path() -> str: return os.path.join(_cache_dir, _json_schema_file) -def _class_mapping_path(): +def _class_mapping_path() -> str: return os.path.join(_cache_dir, _class_mapping_file) -def _generate_taplo_config(path): +def _generate_taplo_config(path: str) -> None: cfg = dict( schema=dict( path=osp.join(osp.expanduser(path), _json_schema_file), diff --git a/tests/test_config_extention.py b/tests/test_config_extention.py index 041e3f1..d23e7fc 100644 --- a/tests/test_config_extention.py +++ b/tests/test_config_extention.py @@ -29,14 +29,18 @@ def __init__( uni3: Union[Tmp, Tmp2], tup: Tuple[int, str, float], lis: List[str], + test1: Union[str, None], + test2: Union[str, None, int], op: Optional[str] = None, op1: Optional[Union[str, int]] = None, default_i=1, default_s="", default_f=0.0, - default_d={}, # noqa: B006 + default_d={}, # noqa: B006 # pylint: disable=W0102 default_tuple=(0, "", 0.0), default_list=[0, 1], # noqa: B006 + *args, + **kwargs, ): pass @@ -51,27 +55,32 @@ def _assert(p: dict, name: str, t: str, item=None): def test_type_parsing(): - property, _ = parse_registry(R) - property = property["properties"]["A"]["properties"] - for k, v in property.items(): + properties, _ = parse_registry(R) + properties = properties["properties"]["A"]["properties"] + for k, v in properties.items(): print(k, v) - _assert(property, "i", "number") - _assert(property, "s", "string") - _assert(property, "f", "number") - _assert(property, "d", "object") - _assert(property, "d1", "object") - _assert(property, "obj", "string") - _assert(property, "uni1", "") - _assert(property, "uni2", "") - _assert(property, "uni3", "") - _assert(property, "tup", "array") - _assert(property, "lis", "array", "string") - _assert(property, "op", "string") - _assert(property, "op1", "") - _assert(property, "default_i", "number") - _assert(property, "default_s", "string") - _assert(property, "default_f", "number") - _assert(property, "default_d", "object") - _assert(property, "default_tuple", "array") - _assert(property, "default_list", "array", "number") + _assert(properties, "i", "number") + _assert(properties, "s", "string") + _assert(properties, "f", "number") + _assert(properties, "d", "object") + _assert(properties, "d1", "object") + _assert(properties, "obj", "string") + _assert(properties, "uni1", "") + _assert(properties, "uni2", "") + _assert(properties, "uni3", "") + _assert(properties, "tup", "array") + _assert(properties, "lis", "array", "string") + _assert(properties, "test1", "string") + _assert(properties, "test2", "") + _assert(properties, "op", "string") + _assert(properties, "op1", "") + _assert(properties, "default_i", "number") + _assert(properties, "default_s", "string") + _assert(properties, "default_f", "number") + _assert(properties, "default_d", "object") + _assert(properties, "default_tuple", "array") + _assert(properties, "default_list", "array", "number") + _assert(properties, "args", "array") + _assert(properties, "kwargs", "object") + R.clear() diff --git a/tests/test_config_extention_310.py b/tests/test_config_extention_310.py new file mode 100644 index 0000000..c61f428 --- /dev/null +++ b/tests/test_config_extention_310.py @@ -0,0 +1,75 @@ +import sys + +import pytest +from test_config_extention import _assert + +from excore import Registry +from excore.config._json_schema import parse_registry + +S = Registry("S") + + +class Tmp: + pass + + +class Tmp2: + pass + + +if sys.version_info >= (3, 10, 0): + + @S.register() + class B: + def __init__( + self, + i: int, + s: str, + f: float, + d: dict, + obj: Tmp, + uni1: int | float, + uni2: int | str, + uni3: Tmp | Tmp2, + tup: tuple[int, str, float], + lis: list[str], + test1: str | None | int, + op: str | None = None, + op1: str | int | None = None, + default_i=1, + default_s="", + default_f=0.0, + default_d={}, # noqa: B006 # pylint: disable=W0102 + default_tuple=(0, "", 0.0), + default_list=[0, 1], # noqa: B006 + ): + pass + + +@pytest.mark.skipif(sys.version_info < (3, 10), reason="Python version >= 3.10 required") +def test_type_parsing(): + properties, _ = parse_registry(S) + properties = properties["properties"]["B"]["properties"] + for k, v in properties.items(): + print(k, v) + + _assert(properties, "i", "number") + _assert(properties, "s", "string") + _assert(properties, "f", "number") + _assert(properties, "d", "object") + _assert(properties, "obj", "string") + _assert(properties, "uni1", "") + _assert(properties, "uni2", "") + _assert(properties, "uni3", "") + _assert(properties, "tup", "array") + _assert(properties, "lis", "array", "string") + _assert(properties, "test1", "") + _assert(properties, "op", "string") + _assert(properties, "op1", "") + _assert(properties, "default_i", "number") + _assert(properties, "default_s", "string") + _assert(properties, "default_f", "number") + _assert(properties, "default_d", "object") + _assert(properties, "default_tuple", "array") + _assert(properties, "default_list", "array", "number") + S.clear()