diff --git a/excore/config/models.py b/excore/config/models.py index bf46b35..0b0b58f 100644 --- a/excore/config/models.py +++ b/excore/config/models.py @@ -322,7 +322,7 @@ def __repr__(self) -> str: } -def register_special_flag(flag: str, target_module: ModuleNode, force: bool = False) -> None: +def register_special_flag(flag: str, target_module: NodeType, force: bool = False) -> None: if not force and flag in SPECIAL_FLAGS: raise ValueError(f"Special flag `{flag}` already exist.") SPECIAL_FLAGS.append(flag) diff --git a/tests/configs/launch/test_module_registry.toml b/tests/configs/launch/test_module_registry.toml new file mode 100644 index 0000000..83365ed --- /dev/null +++ b/tests/configs/launch/test_module_registry.toml @@ -0,0 +1,7 @@ +[Model.FCN] +*classifier = "$Head" +backbone = "$resnet" + +[Head.FCNHead] +in_channels = 512 +channels = 10 diff --git a/tests/test_module_node_registry.py b/tests/test_module_node_registry.py new file mode 100644 index 0000000..e3597af --- /dev/null +++ b/tests/test_module_node_registry.py @@ -0,0 +1,22 @@ +from excore.config import ModuleNode, load, register_special_flag + + +class FoolModuleNode(ModuleNode): + priority = 2 + + def __call__(self, **params): + p = {} + for k, v in self.items(): + if isinstance(v, int): + v += 1 + p[k] = v + return super().__call__(**p) + + +register_special_flag("*", FoolModuleNode) + + +def test_module_node_registry(): + cfg = load("./configs/launch/test_module_registry.toml") + module, info = cfg.build_all() + assert module.Model.classifier[0].in_channels == 513