diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 776cbe3..a6b9a79 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -34,8 +34,14 @@ jobs: - name: Test with pytest run: | cd ./tests + export EXCORE_DEBUG=1 poetry run python init.py poetry run pytest --cov=../excore + poetry run pytest test_config.py + poetry run pytest test_config.py + poetry run pytest test_config.py + poetry run pytest test_config.py + - name: Upload coverage reports to Codecov uses: codecov/codecov-action@v3 diff --git a/excore/config/lazy_config.py b/excore/config/lazy_config.py index 21148b3..2c3b465 100644 --- a/excore/config/lazy_config.py +++ b/excore/config/lazy_config.py @@ -18,7 +18,7 @@ def __init__(self, config: ConfigDict) -> None: config.registered_fields = list(Registry._registry_pool.keys()) config.all_fields = set([*config.registered_fields, *config.primary_fields]) self._config = deepcopy(config) - self._origin_config = deepcopy(config) + self._original_config = deepcopy(config) self.__is_parsed__ = False def parse(self): @@ -27,10 +27,11 @@ def parse(self): self._config.parse() logger.success("Config parsing cost {:.4f}s!", time.time() - st) self.__is_parsed__ = True + logger.ex(self._config) @property def config(self): - return self._origin_config + return self._original_config def update(self, cfg: "LazyConfig"): self._config.update(cfg._config) @@ -75,7 +76,7 @@ def build_all(self) -> Tuple[ModuleWrapper, Dict]: return module_dict, isolated_dict def dump(self, dump_path: str) -> None: - self._origin_config.dump(dump_path) + self._original_config.dump(dump_path) def __str__(self): return str(self._config) diff --git a/excore/config/model.py b/excore/config/model.py index 55b4751..bbcf2ff 100644 --- a/excore/config/model.py +++ b/excore/config/model.py @@ -74,6 +74,7 @@ def _str_to_target(module_name): class ModuleNode(dict): cls: Any _no_call: bool = field(default=False, repr=False) + priority: int = field(default=0, 
repr=False) def _get_params(self, **kwargs): params = {} @@ -153,6 +154,7 @@ def from_node(cls, _other: "ModuleNode") -> "ModuleNode": class InterNode(ModuleNode): + priority = 2 pass @@ -165,12 +167,16 @@ def __call__(self, **kwargs): class ReusedNode(InterNode): + priority: int = 3 + @CacheOut() def __call__(self, **kwargs): return super().__call__(**kwargs) class ClassNode(InterNode): + priority: int = 1 + def __call__(self): return self.cls diff --git a/excore/config/parse.py b/excore/config/parse.py index bef133c..bb1951f 100644 --- a/excore/config/parse.py +++ b/excore/config/parse.py @@ -1,4 +1,6 @@ -from typing import Dict, List, Optional, Set, Type +from __future__ import annotations + +from typing import Any, Generator from .._exceptions import CoreConfigParseError, ImplicitModuleParseError from .._misc import _create_table @@ -38,8 +40,8 @@ def _check_implicit_module(module: ModuleNode) -> None: ) -def _dict2node(module_type: str, base: str, _dict: Dict): - ModuleType: Type[ModuleNode] = _dispatch_module_node[module_type] +def _dict2node(module_type: str, base: str, _dict: dict): + ModuleType: type[ModuleNode] = _dispatch_module_node[module_type] return {name: ModuleType.from_base_name(base, name, v) for name, v in _dict.items()} @@ -71,12 +73,12 @@ def _flatten_dict(dic): class ConfigDict(dict): - primary_fields: List - primary_to_registry: Dict[str, str] - registered_fields: List - all_fields: Set[str] - scratchpads_fields: Set[str] = set() - current_field: Optional[str] = None + primary_fields: list + primary_to_registry: dict[str, str] + registered_fields: list + all_fields: set[str] + scratchpads_fields: set[str] = set() + current_field: str | None = None def __new__(cls): if not hasattr(cls, "primary_fields"): @@ -104,49 +106,108 @@ def set_primary_fields(cls, primary_fields, primary_to_registry): cls.primary_fields = primary_fields cls.primary_to_registry = primary_to_registry - def parse(self): + def parse(self) -> None: + """ + Parsing config 
into `ModuleNode`s. The procedure is as follows: + + 1. Convert config nodes in `primary_fields` to `ModuleNode` at first, + the field will be regarded as a set of nodes, e.g. `self[Model]` + consists of two models. + + i. If the field is registered, the base registry is set to the field; + ii. If not, search `Registry` to get the base registry of each node. + + 2. Convert isolated objects to `ModuleNode`; + + i. If the field is registered, search `Registry` to get the base registry of each node, + raising an error if it cannot be found; + ii. If not, regard the field as a single node, and search `Registry`. + iii. If the search fails, regard the field as scratchpads. + + 3. Parse all the `ModuleNode`s to the target type of node, e.g. `ReusedNode` and `ClassNode`. + + Visit the top level of config dict: + i. If the name is in `primary_fields` or `scratchpads_fields`, parse each node of + the field; + ii. Else if the node is an instance of `ModuleNode`, parse it to the target node. + + 4. Wrap all the nodes of `primary_fields` into ModuleWrapper. + + 5. Clean up some remaining non-primary nodes. Only keep the primary nodes. + + In the 3rd step, it will parse every parameter of each module node + + if its name has a special prefix, e.g. `!`, `@` or `$`. + The special parameter should be a string, a list of strings or a dict of strings. + They will be parsed to target module nodes in the given format (alone, list or dict). + + According to the given string parameters, e.g. `['ResNet', 'SegHead']`, + + 1. It will first search from the top level of config. + 2. If it does not exist, it will search from `primary_fields`, + `scratchpads_fields` and `registered_fields`. + 3. If it still does not exist, it will be regarded as an implicit module, + which must have non-required parameters. + + For the first two situations, the node will be converted to the target type of node, + then it will be set back to config for cache. For the last situation, + it will only be set back when the target module type is `ReusedNode`.
But if the + target module type is `ClassNode`, it will not be set back. + + NOTE: Set converted nodes back to config is necceary for `ReusedNode`. + + NOTE: use `export EXCORE_DEBUG=1` to enable excore debug to + get more information when parsing. + """ + logger.ex("Parse primary modules.") self._parse_primary_modules() + logger.ex("Parse isolated objects.") self._parse_isolated_obj() + logger.ex("Parse inter modules.") self._parse_inter_modules() self._wrap() self._clean() - def _wrap(self): + def _wrap(self) -> None: for name in self.primary_keys(): self[name] = ModuleWrapper(self[name]) - def _clean(self): + def _clean(self) -> None: for name in self.non_primary_keys(): if name in self.registered_fields or isinstance(self[name], ModuleNode): self.pop(name) - def primary_keys(self): + def primary_keys(self) -> Generator[Any, Any, None]: for name in self.primary_fields: if name in self: yield name - def non_primary_keys(self): + def non_primary_keys(self) -> Generator[Any, Any, None]: keys = list(self.keys()) for k in keys: if k not in self.primary_fields: yield k - def _parse_primary_modules(self): + def _parse_primary_modules(self) -> None: for name in self.primary_keys(): + logger.ex(f"parse field {name}.") if name in self.registered_fields: base = name + logger.ex(f"Find field registed. Base field is {base}.") else: reg = Registry.get_registry(self.primary_to_registry.get(name, "")) if reg is None: - raise CoreConfigParseError(f"Undefined registry `{name}`") + raise CoreConfigParseError(f"Undefined registry `{name}`.") for n in self[name]: if n not in reg: - raise CoreConfigParseError(f"Unregistered module `{n}`") + raise CoreConfigParseError(f"Unregistered module `{n}`.") base = reg.name + logger.ex(f"Search from Registry. 
Base field is {base}.") self[name] = _dict2node(OTHER_FLAG, base, self.pop(name)) + logger.ex(f"Set ModuleNode to self[{name}].") - def _parse_isolated_registered_module(self, name): + def _parse_isolated_registered_module(self, name) -> None: v = _dict2node(OTHER_FLAG, name, self.pop(name)) for i in v.values(): _, _ = Registry.find(i.name) @@ -158,20 +219,25 @@ def _parse_implicit_module(self, name, module_type): _, base = Registry.find(name) if not base: raise CoreConfigParseError(f"Unregistered module `{name}`") + logger.ex(f"Find base {base} with implicit module {name}.") node = module_type.from_base_name(base, name) _check_implicit_module(node) if module_type == ReusedNode: + logger.ex("Target type is ReusedNode, set node to top level.") self[name] = node return node - def _parse_isolated_module(self, name): + def _parse_isolated_module(self, name) -> bool: + logger.ex(f"Not a registed field. Parse isolated module {name}.") _, base = Registry.find(name) if base: + logger.ex("Find registed. Convert to `ModuleNode`.") self[name] = ModuleNode.from_base_name(base, name) << self[name] return True return False - def _parse_scratchpads(self, name): + def _parse_scratchpads(self, name) -> None: + logger.ex(f"Not a registed node. Regrad as scratchpads {name}.") has_module = False modules = self[name] for k, v in list(modules.items()): @@ -180,21 +246,25 @@ def _parse_scratchpads(self, name): _, base = Registry.find(k) if base: has_module = True + logger.ex("Find item registed. 
Convert to `ModuleNode`.") modules[k] = ModuleNode.from_base_name(base, k) << v if has_module: + logger.ex(f"Add `{name}` to scratchpads_fields.") self.scratchpads_fields.add(name) self.all_fields.add(name) - def _parse_isolated_obj(self): + def _parse_isolated_obj(self) -> None: for name in self.non_primary_keys(): + logger.ex(f"parse module {name}.") modules = self[name] if isinstance(modules, dict): if name in self.registered_fields: + logger.ex("Find module registed.") self._parse_isolated_registered_module(name) elif not self._parse_isolated_module(name): self._parse_scratchpads(name) - def _contain_module(self, name): + def _contain_module(self, name) -> bool: is_contain = False for k in self.all_fields: if k not in self: @@ -251,46 +321,56 @@ def _apply_hooks(self, node, hooks, attrs): node = self[hook](node=node) return node + def _convert_node(self, name, source, target_type) -> tuple[ModuleNode, type[ModuleNode]]: + ori_type = source[name].__class__ + logger.ex(f"Original_type is {ori_type}, target_type is {target_type}.") + node = source[name] + if target_type.priority != ori_type.priority or target_type is ClassNode: + node = target_type.from_node(source[name]) + self._parse_module(node) + if target_type.priority > ori_type.priority: + source[name] = node + return node, ori_type + def _parse_single_param(self, name, ori_name, field, target_type, attrs, hooks): if name in self.all_fields: raise CoreConfigParseError( f"Conflict name: `{name}`, the class name cannot be same with field name" ) ori_type = None + cache_field = self.current_field if not field and name in self: - ori_type = self[name].__class__ - if ori_type in (ModuleNode, ClassNode): - node = target_type.from_node(self[name]) - self._parse_module(node) - self[name] = node - node = self[name] + logger.ex("Find module in top level.") + node, ori_type = self._convert_node(name, self, target_type) elif field or self._contain_module(name): self.current_field = field or self.current_field - 
ori_type = self[self.current_field][name].__class__ - if ori_type in (ModuleNode, ClassNode): - node = target_type.from_node(self[self.current_field][name]) - base = self.current_field - self._parse_module(node) - self.current_field = base - self[self.current_field][name] = node - node = self[self.current_field][name] + logger.ex(f"Find module in second level, " f"current_field is {self.current_field}.") + node, ori_type = self._convert_node(name, self[self.current_field], target_type) else: + logger.ex("Implicit module.") node = self._parse_implicit_module(name, target_type) - if ori_type and ori_type not in (ModuleNode, ClassNode, target_type): + self.current_field = cache_field + + # InterNode and ReusedNode + if ori_type and ori_type.priority + target_type.priority == 5: raise CoreConfigParseError( f"Error when parsing param `{ori_name}`, " - f"target_type is `{target_type}`, but got `{ori_type}`" + f"target_type is `{target_type}`, but got original_type `{ori_type}`. " + f"Please considering using `scratchpads` to avoid conflicts." 
) node = self._apply_hooks(node, hooks, attrs) return node def _parse_param(self, ori_name, module_type): + logger.ex(f"Parse with {ori_name}, {module_type}") if module_type == REFER_FLAG: return VariableReference(ori_name) target_type = _dispatch_module_node[module_type] name, attrs, hooks = _parse_param_name(ori_name) name, field = self._get_name_and_field(name, ori_name) + logger.ex(f"Get name:{name}, field:{field}, attrs:{attrs}, hooks:{hooks}.") if isinstance(name, list): + logger.ex(f"Detect output type list {ori_name}.") return [ self._parse_single_param(n, ori_name, field, target_type, attrs, hooks) for n in name @@ -298,18 +378,23 @@ def _parse_param(self, ori_name, module_type): return self._parse_single_param(name, ori_name, field, target_type, attrs, hooks) def _parse_module(self, node: ModuleNode): + logger.ex(f"Parse ModuleNode {node}.") for param_name in list(node.keys()): true_name, module_type = _is_special(param_name) if not module_type: + logger.ex(f"Skip parameter {param_name}.") continue value = node.pop(param_name) is_dict = False if isinstance(value, list): + logger.ex(f"{param_name}: List parameter {value}.") value = [self._parse_param(v, module_type) for v in value] value = _flatten_list(value) elif isinstance(value, str): + logger.ex(f"{param_name}: Single parameter {value}.") value = self._parse_param(value, module_type) elif isinstance(value, dict): + logger.ex(f"{param_name}: Dict parameter {value}.") value = {k: self._parse_param(v, module_type) for k, v in value.items()} value = _flatten_dict(value) is_dict = True @@ -317,6 +402,7 @@ def _parse_module(self, node: ModuleNode): raise CoreConfigParseError(f"Wrong type: {param_name, value}") if isinstance(value, VariableReference): ref_name = value() + logger.ex(f"Detect VariableReference, value is parsed to {ref_name}.") if not value.has_env: if ref_name not in self: raise CoreConfigParseError(f"Can not find reference: {ref_name}.") @@ -328,12 +414,14 @@ def _parse_module(self, node: 
ModuleNode): def _parse_inter_modules(self): for name in list(self.keys()): + logger.ex(f"Parse inter module {name}") module = self[name] if ( name in self.primary_fields or name in self.scratchpads_fields and isinstance(module, dict) ): + logger.ex(f"Parse Dict {name}") for m in module.values(): self._parse_module(m) elif isinstance(module, ModuleNode): diff --git a/excore/engine/logging.py b/excore/engine/logging.py index 71d4ebc..aa415d3 100644 --- a/excore/engine/logging.py +++ b/excore/engine/logging.py @@ -77,7 +77,7 @@ def _call_importance(__message: str, *args, **kwargs): def _excore_debug(__message: str, *args, **kwargs): - logger.log("EXCORE", __message, *args, **kwargs) + logger._log("EXCORE", False, logger._options, __message, args, kwargs) def _enable_excore_debug(): diff --git a/tests/test_a_registry.py b/tests/test_a_registry.py index 93e5da2..a4999ea 100644 --- a/tests/test_a_registry.py +++ b/tests/test_a_registry.py @@ -6,16 +6,18 @@ load_registries() -import source_code as S # noqa: E402 - Registry.unlock_register() def test_print(): + import source_code as S + print(S.MODEL) def test_find(): + import source_code as S + tar, base = S.BACKBONE.find("ResNet") assert base == "Backbone" assert tar.split(".")[-1] == "ResNet" @@ -28,6 +30,8 @@ def test_register_module(): def test_global(): + import source_code as S + g = Registry.make_global() assert g["time"] == "time" assert g["ResNet"] == "torchvision.models.resnet.ResNet" @@ -37,9 +41,13 @@ def test_global(): def test_module_table(): + import source_code as S + print(S.MODEL.module_table()) print(S.MODEL.module_table("*resnet*")) def test_id(): + import source_code as S + assert Registry.get_registry("Head") == S.HEAD diff --git a/tests/test_config.py b/tests/test_config.py index 328ca59..569da58 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,4 +1,6 @@ import os +import random +from copy import deepcopy import pytest import torch @@ -13,11 +15,31 @@ ModuleBuildError, ) from 
excore.config.model import ModuleNode, ReusedNode +from excore.engine import logger + + +def shuffle_fields(): + random.shuffle(config.parse.ConfigDict.primary_fields) + + +def shuffle_dict(d): + n = deepcopy(d) + n.clear() + items = list(d.items()) + random.shuffle(items) + for k, v in items: + if isinstance(v, dict): + v = shuffle_dict(v) + n[k] = v + return n class TestConfig: def _load(self, path, check=True): - cfg = config.load(path) + shuffle_fields() + cfg = config.load(path, parse_config=False) + cfg._config = shuffle_dict(cfg._config) + logger.ex(cfg) modules, info = config.build_all(cfg) if check: self.check_info(info) @@ -189,9 +211,11 @@ def test_scratchpads(self): from source_code.dataset.data import MockData from source_code.models.nets import VGG, TestClass - assert modules.Model.cls == [VGG, MockData] + assert modules.Model.cls == [VGG, MockData] or modules.Model.cls == [MockData, VGG] assert modules.Model.cls1 == [VGG, MockData] - assert modules.DataModule.train == [VGG, MockData, TestClass] + assert VGG in modules.DataModule.train + assert MockData in modules.DataModule.train + assert TestClass in modules.DataModule.train assert modules.DataModule.val == VGG def test_get_error(self): @@ -219,4 +243,10 @@ def test_dict_param(self): assert modules.Model.cls["b"] == VGG assert modules.Model.cls1["a"] == VGG - assert modules.Model.cls1["b"] == [VGG, TestClass] + assert modules.Model.cls1["b"] == [VGG, TestClass] or modules.Model.cls1["b"] == [ + TestClass, + VGG, + ] + + +TestConfig().test_scratchpads() diff --git a/tests/test_config_extention.py b/tests/test_config_extention.py index d23e7fc..3715f12 100644 --- a/tests/test_config_extention.py +++ b/tests/test_config_extention.py @@ -4,6 +4,7 @@ from excore.config._json_schema import parse_registry R = Registry("R") +R.unlock_register() class Tmp: