Skip to content

Commit

Permalink
✅ Add test of module node registry
Browse files Browse the repository at this point in the history
  • Loading branch information
Asthestarsfalll committed Dec 3, 2024
1 parent 97c8340 commit 5288b64
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 1 deletion.
2 changes: 1 addition & 1 deletion excore/config/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def __repr__(self) -> str:
}


def register_special_flag(flag: str, target_module: ModuleNode, force: bool = False) -> None:
def register_special_flag(flag: str, target_module: NodeType, force: bool = False) -> None:
if not force and flag in SPECIAL_FLAGS:
raise ValueError(f"Special flag `{flag}` already exist.")
SPECIAL_FLAGS.append(flag)
Expand Down
7 changes: 7 additions & 0 deletions tests/configs/launch/test_module_registry.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
[Model.FCN]
*classifier = "$Head"
backbone = "$resnet"

[Head.FCNHead]
in_channels = 512
channels = 10
22 changes: 22 additions & 0 deletions tests/test_module_node_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from excore.config import ModuleNode, load, register_special_flag


class FoolModuleNode(ModuleNode):
priority = 2

def __call__(self, **params):
p = {}
for k, v in self.items():
if isinstance(v, int):
v += 1
p[k] = v
return super().__call__(**p)


register_special_flag("*", FoolModuleNode)


def test_module_node_registry():
cfg = load("./configs/launch/test_module_registry.toml")
module, info = cfg.build_all()
assert module.Model.classifier[0].in_channels == 513

0 comments on commit 5288b64

Please sign in to comment.