Skip to content

Commit

Permalink
✅ Add more tests of config
Browse files Browse the repository at this point in the history
  • Loading branch information
Asthestarsfalll committed Feb 10, 2024
1 parent 83111cd commit 7d65b32
Show file tree
Hide file tree
Showing 14 changed files with 122 additions and 12 deletions.
3 changes: 2 additions & 1 deletion excore/cli/_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,8 @@ def auto_register():
Automatically import all modules in `src_dir` and register all modules, then dump to files.
"""
if not os.path.exists(_workspace_config_file):
logger.warning("Please run `excore init` in your command line first!")
logger.critical("Please run `excore init` in your command line first!")
sys.exit(0)
target_dir = osp.abspath(_workspace_cfg["src_dir"])
module_name = _get_default_module_name(target_dir)
sys.path.append(os.getcwd())
Expand Down
1 change: 1 addition & 0 deletions excore/config/lazy_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(self, config: ConfigDict) -> None:
self.modules_dict, self.isolated_dict = {}, {}
self.target_modules = config.target_fields
config.registered_fields = list(Registry._registry_pool.keys())
config.all_fields = set([*config.registered_fields, *config.target_fields])
self._config = deepcopy(config)
self.build_config_hooks()
self._config.parse()
Expand Down
30 changes: 20 additions & 10 deletions excore/config/parse.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Type
from typing import Dict, List, Set, Type

from .._exceptions import CoreConfigParseError
from ..engine import Registry, logger
Expand Down Expand Up @@ -35,6 +35,7 @@ class ConfigDict(dict):
target_fields: List
target_to_registry: Dict[str, str]
registered_fields: List
all_fields: Set[str]

def __new__(cls):
if not hasattr(cls, "target_fields"):
Expand Down Expand Up @@ -138,8 +139,7 @@ def _parse_isolated_obj(self):
self._parse_isolated_module(name, ModuleNode)

def _contain_module(self, name):
fileds = set([*self.target_fields, *self.registered_fields])
for k in fileds:
for k in self.all_fields:
if k not in self:
continue
for node in self[k].values():
Expand Down Expand Up @@ -175,23 +175,33 @@ def _parse_single_param(self, ori_name, module_type):
target_type = _dispatch_module_node[module_type]
name, attrs, hooks = _parse_param_name(ori_name)
name = self._get_name(name, ori_name)
if name in self.all_fields:
raise CoreConfigParseError(
f"Conflict name: `{name}`, the class name cannot be same with registry"
)
ori_type = None
if name in self:
ori_type = self[name].__class__
if ori_type == ModuleNode:
self[name] = target_type.from_node(self[name])
node = target_type.from_node(self[name])
self._parse_module(node)
self[name] = node
node = self[name]
elif self._contain_module(name):
ori_type = self[self.__base__][name].__class__
if ori_type == ModuleNode:
self[self.__base__][name] = target_type.from_node(self[self.__base__][name])
node = target_type.from_node(self[self.__base__][name])
base = self.__base__
self._parse_module(node)
self.__base__ = base
self[self.__base__][name] = node
node = self[self.__base__][name]
else:
node = self._parse_implicit_module(name, target_type)
if ori_type and ori_type not in (ModuleNode, target_type):
raise CoreConfigParseError(
f"Error when parsing params {ori_name}, \
target_type is {target_type}, but got {ori_type}"
f"Error when parsing param {ori_name}, "
f"target_type is {target_type}, but got {ori_type}"
)
name = node.name # for ModuleWrapper
if attrs:
Expand All @@ -204,7 +214,7 @@ def _parse_single_param(self, ori_name, module_type):
delattr(self, "__base__")
return node

def _parse_modules(self, node: ModuleNode):
def _parse_module(self, node: ModuleNode):
for param_name in list(node.keys()):
true_name, module_type = _is_special(param_name)
if not module_type:
Expand All @@ -229,9 +239,9 @@ def _parse_inter_modules(self):
module = self[name]
if name in self.target_fields and isinstance(module, dict):
for m in module.values():
self._parse_modules(m)
self._parse_module(m)
elif isinstance(module, ModuleNode):
self._parse_modules(module)
self._parse_module(module)

def __str__(self):
_dict = {}
Expand Down
4 changes: 4 additions & 0 deletions tests/configs/launch/test_class.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[Model.TestClass]
$cls = "VGG"

[Backbone.VGG]
2 changes: 2 additions & 0 deletions tests/configs/launch/test_conflict_name.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[Model.TestClass]
$cls = "Backbone"
7 changes: 7 additions & 0 deletions tests/configs/launch/test_hidden_error.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[Model.FCN]
!classifier = "$Head"
@backbone = "$resnnnet"

[Head.FCNHead]
in_channels = 512
channels = 10
2 changes: 2 additions & 0 deletions tests/configs/launch/test_module.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[Model.TestClass]
$cls = "torch.nn.ReLU"
12 changes: 12 additions & 0 deletions tests/configs/launch/test_nest.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[Model.FCN]
!classifier = "$Head"
@backbone = "$Backbone"

[Backbone.MockModel]
!block = "TestBlock"

[Block.TestBlock]

[Head.FCNHead]
in_channels = 512
channels = 10
11 changes: 11 additions & 0 deletions tests/configs/launch/test_nest_hidden.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
[Model.FCN]
!classifier = "$Head"
@backbone = "$Backbone"

[Backbone.MockModel]
!block = "TestBlock"


[Head.FCNHead]
in_channels = 512
channels = 10
7 changes: 7 additions & 0 deletions tests/configs/launch/test_ref_field_error.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[Model.FCN]
!classifier = "$Headdd"
@backbone = "$resnet18"

[Head.FCNHead]
in_channels = 512
channels = 10
1 change: 1 addition & 0 deletions tests/configs/launch/test_regitered_error.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[Model.ResNt]
3 changes: 3 additions & 0 deletions tests/source_code/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import time

import torch

from excore import Registry

MODEL = Registry("Model")
Expand All @@ -15,3 +17,4 @@
MODULE = Registry("module")

MODULE.register_module(time)
MODULE.register_module(torch)
19 changes: 19 additions & 0 deletions tests/source_code/models/nets.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from torchvision import models
from torchvision.models import segmentation

from excore import Registry
from source_code import BACKBONE, HEAD, MODEL

BLOCK = Registry("Block")
MODEL.match(segmentation)


Expand All @@ -14,3 +16,20 @@ def _match(name: str, module):

BACKBONE.match(models, force=True, match_func=_match)
HEAD.register_module(segmentation.fcn.FCNHead)


@MODEL.register()
class TestClass:
def __init__(self, cls):
self.cls = cls


@BACKBONE.register()
class MockModel:
def __init__(self, block):
self.block = block


@BLOCK.register()
class TestBlock:
pass
32 changes: 31 additions & 1 deletion tests/test_config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import pytest
import torch
from torchvision.models import ResNet
from torchvision.models import VGG, ResNet

from excore import config
from excore._exceptions import CoreConfigParseError, CoreConfigSupportError, ModuleBuildError
Expand Down Expand Up @@ -62,3 +62,33 @@ def test_argument_error(self):
def test_wrong_type(self):
with pytest.raises(CoreConfigSupportError):
self._load("./configs/launch/error.yaml")

def test_class(self):
modules, _ = self._load("./configs/launch/test_class.toml", False)
assert modules.Model.cls == VGG

def test_module(self):
modules, _ = self._load("./configs/launch/test_module.toml", False)
assert modules.Model.cls == torch.nn.ReLU

def test_regitered_error(self):
with pytest.raises(ModuleBuildError):
self._load("./configs/launch/test_regitered_error.toml", False)

def test_hidden_error(self):
with pytest.raises(CoreConfigParseError):
self._load("./configs/launch/test_hidden_error.toml", False)

def test_ref_field_error(self):
with pytest.raises(CoreConfigParseError):
self._load("./configs/launch/test_ref_field_error.toml", False)

def test_nest(self):
self._load("./configs/launch/test_nest.toml", False)

def test_nest_hidden(self):
self._load("./configs/launch/test_nest_hidden.toml", False)

def test_conflict_name_error(self):
with pytest.raises(CoreConfigParseError):
self._load("./configs/launch/test_conflict_name.toml", False)

0 comments on commit 7d65b32

Please sign in to comment.