diff --git a/excore/cli/_registry.py b/excore/cli/_registry.py index 7c458ea..6f317ce 100644 --- a/excore/cli/_registry.py +++ b/excore/cli/_registry.py @@ -188,7 +188,8 @@ def auto_register(): Automatically import all modules in `src_dir` and register all modules, then dump to files. """ if not os.path.exists(_workspace_config_file): - logger.warning("Please run `excore init` in your command line first!") + logger.critical("Please run `excore init` in your command line first!") + sys.exit(0) target_dir = osp.abspath(_workspace_cfg["src_dir"]) module_name = _get_default_module_name(target_dir) sys.path.append(os.getcwd()) diff --git a/excore/config/lazy_config.py b/excore/config/lazy_config.py index 8b97343..8942a20 100644 --- a/excore/config/lazy_config.py +++ b/excore/config/lazy_config.py @@ -17,6 +17,7 @@ def __init__(self, config: ConfigDict) -> None: self.modules_dict, self.isolated_dict = {}, {} self.target_modules = config.target_fields config.registered_fields = list(Registry._registry_pool.keys()) + config.all_fields = set([*config.registered_fields, *config.target_fields]) self._config = deepcopy(config) self.build_config_hooks() self._config.parse() diff --git a/excore/config/parse.py b/excore/config/parse.py index dfc604c..50cf732 100644 --- a/excore/config/parse.py +++ b/excore/config/parse.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Type +from typing import Dict, List, Set, Type from .._exceptions import CoreConfigParseError from ..engine import Registry, logger @@ -35,6 +35,7 @@ class ConfigDict(dict): target_fields: List target_to_registry: Dict[str, str] registered_fields: List + all_fields: Set[str] def __new__(cls): if not hasattr(cls, "target_fields"): @@ -138,8 +139,7 @@ def _parse_isolated_obj(self): self._parse_isolated_module(name, ModuleNode) def _contain_module(self, name): - fileds = set([*self.target_fields, *self.registered_fields]) - for k in fileds: + for k in self.all_fields: if k not in self: continue for node in self[k].values(): @@ -175,23 +175,33 @@ def _parse_single_param(self, ori_name, module_type): target_type = _dispatch_module_node[module_type] name, attrs, hooks = _parse_param_name(ori_name) name = self._get_name(name, ori_name) + if name in self.all_fields: + raise CoreConfigParseError( + f"Conflict name: `{name}`, the class name cannot be same with registry" + ) ori_type = None if name in self: ori_type = self[name].__class__ if ori_type == ModuleNode: - self[name] = target_type.from_node(self[name]) + node = target_type.from_node(self[name]) + self._parse_module(node) + self[name] = node node = self[name] elif self._contain_module(name): ori_type = self[self.__base__][name].__class__ if ori_type == ModuleNode: - self[self.__base__][name] = target_type.from_node(self[self.__base__][name]) + node = target_type.from_node(self[self.__base__][name]) + base = self.__base__ + self._parse_module(node) + self.__base__ = base + self[self.__base__][name] = node node = self[self.__base__][name] else: node = self._parse_implicit_module(name, target_type) if ori_type and ori_type not in (ModuleNode, target_type): raise CoreConfigParseError( - f"Error when parsing params {ori_name}, \ - target_type is {target_type}, but got {ori_type}" + f"Error when parsing param {ori_name}, " + f"target_type is {target_type}, but got {ori_type}" ) name = node.name # for ModuleWrapper if attrs: @@ -204,7 +214,7 @@ def _parse_single_param(self, ori_name, module_type): delattr(self, "__base__") return node - def _parse_modules(self, node: ModuleNode): + def _parse_module(self, node: ModuleNode): for param_name in list(node.keys()): true_name, module_type = _is_special(param_name) if not module_type: @@ -229,9 +239,9 @@ def _parse_inter_modules(self): module = self[name] if name in self.target_fields and isinstance(module, dict): for m in module.values(): - self._parse_modules(m) + self._parse_module(m) elif isinstance(module, ModuleNode): - self._parse_modules(module) + self._parse_module(module) def __str__(self): _dict = {} diff --git a/tests/configs/launch/test_class.toml b/tests/configs/launch/test_class.toml new file mode 100644 index 0000000..68291a7 --- /dev/null +++ b/tests/configs/launch/test_class.toml @@ -0,0 +1,4 @@ +[Model.TestClass] +$cls = "VGG" + +[Backbone.VGG] diff --git a/tests/configs/launch/test_conflict_name.toml b/tests/configs/launch/test_conflict_name.toml new file mode 100644 index 0000000..f1d0da2 --- /dev/null +++ b/tests/configs/launch/test_conflict_name.toml @@ -0,0 +1,2 @@ +[Model.TestClass] +$cls = "Backbone" diff --git a/tests/configs/launch/test_hidden_error.toml b/tests/configs/launch/test_hidden_error.toml new file mode 100644 index 0000000..e8f7384 --- /dev/null +++ b/tests/configs/launch/test_hidden_error.toml @@ -0,0 +1,7 @@ +[Model.FCN] +!classifier = "$Head" +@backbone = "$resnnnet" + +[Head.FCNHead] +in_channels = 512 +channels = 10 diff --git a/tests/configs/launch/test_module.toml b/tests/configs/launch/test_module.toml new file mode 100644 index 0000000..c005f43 --- /dev/null +++ b/tests/configs/launch/test_module.toml @@ -0,0 +1,2 @@ +[Model.TestClass] +$cls = "torch.nn.ReLU" diff --git a/tests/configs/launch/test_nest.toml b/tests/configs/launch/test_nest.toml new file mode 100644 index 0000000..8b76ea8 --- /dev/null +++ b/tests/configs/launch/test_nest.toml @@ -0,0 +1,12 @@ +[Model.FCN] +!classifier = "$Head" +@backbone = "$Backbone" + +[Backbone.MockModel] +!block = "TestBlock" + +[Block.TestBlock] + +[Head.FCNHead] +in_channels = 512 +channels = 10 diff --git a/tests/configs/launch/test_nest_hidden.toml b/tests/configs/launch/test_nest_hidden.toml new file mode 100644 index 0000000..2b4858b --- /dev/null +++ b/tests/configs/launch/test_nest_hidden.toml @@ -0,0 +1,11 @@ +[Model.FCN] +!classifier = "$Head" +@backbone = "$Backbone" + +[Backbone.MockModel] +!block = "TestBlock" + + +[Head.FCNHead] +in_channels = 512 +channels = 10 diff --git a/tests/configs/launch/test_ref_field_error.toml b/tests/configs/launch/test_ref_field_error.toml new file mode 100644 index 0000000..c142bfd --- /dev/null +++ b/tests/configs/launch/test_ref_field_error.toml @@ -0,0 +1,7 @@ +[Model.FCN] +!classifier = "$Headdd" +@backbone = "$resnet18" + +[Head.FCNHead] +in_channels = 512 +channels = 10 diff --git a/tests/configs/launch/test_regitered_error.toml b/tests/configs/launch/test_regitered_error.toml new file mode 100644 index 0000000..2be40b3 --- /dev/null +++ b/tests/configs/launch/test_regitered_error.toml @@ -0,0 +1 @@ +[Model.ResNt] diff --git a/tests/source_code/__init__.py b/tests/source_code/__init__.py index 113ef05..0cb1414 100644 --- a/tests/source_code/__init__.py +++ b/tests/source_code/__init__.py @@ -1,5 +1,7 @@ import time +import torch + from excore import Registry MODEL = Registry("Model") @@ -15,3 +17,4 @@ MODULE = Registry("module") MODULE.register_module(time) +MODULE.register_module(torch) diff --git a/tests/source_code/models/nets.py b/tests/source_code/models/nets.py index e234aef..85fc6ad 100644 --- a/tests/source_code/models/nets.py +++ b/tests/source_code/models/nets.py @@ -1,8 +1,10 @@ from torchvision import models from torchvision.models import segmentation +from excore import Registry from source_code import BACKBONE, HEAD, MODEL +BLOCK = Registry("Block") MODEL.match(segmentation) @@ -14,3 +16,20 @@ def _match(name: str, module): BACKBONE.match(models, force=True, match_func=_match) HEAD.register_module(segmentation.fcn.FCNHead) + + +@MODEL.register() +class TestClass: + def __init__(self, cls): + self.cls = cls + + +@BACKBONE.register() +class MockModel: + def __init__(self, block): + self.block = block + + +@BLOCK.register() +class TestBlock: + pass diff --git a/tests/test_config.py b/tests/test_config.py index e3d7806..0685b2a 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,6 +1,6 @@ import pytest import torch -from torchvision.models import ResNet +from torchvision.models import VGG, ResNet from excore import config from excore._exceptions import CoreConfigParseError, CoreConfigSupportError, ModuleBuildError @@ -62,3 +62,33 @@ def test_argument_error(self): def test_wrong_type(self): with pytest.raises(CoreConfigSupportError): self._load("./configs/launch/error.yaml") + + def test_class(self): + modules, _ = self._load("./configs/launch/test_class.toml", False) + assert modules.Model.cls == VGG + + def test_module(self): + modules, _ = self._load("./configs/launch/test_module.toml", False) + assert modules.Model.cls == torch.nn.ReLU + + def test_regitered_error(self): + with pytest.raises(ModuleBuildError): + self._load("./configs/launch/test_regitered_error.toml", False) + + def test_hidden_error(self): + with pytest.raises(CoreConfigParseError): + self._load("./configs/launch/test_hidden_error.toml", False) + + def test_ref_field_error(self): + with pytest.raises(CoreConfigParseError): + self._load("./configs/launch/test_ref_field_error.toml", False) + + def test_nest(self): + self._load("./configs/launch/test_nest.toml", False) + + def test_nest_hidden(self): + self._load("./configs/launch/test_nest_hidden.toml", False) + + def test_conflict_name_error(self): + with pytest.raises(CoreConfigParseError): + self._load("./configs/launch/test_conflict_name.toml", False)