diff --git a/excore/_misc.py b/excore/_misc.py index 961de77..03812ac 100644 --- a/excore/_misc.py +++ b/excore/_misc.py @@ -10,7 +10,10 @@ def __call__(self, func): @functools.wraps(func) def _cache(self): if not hasattr(self, "cached_elem"): - self.cached_elem = func(self) + cached_elem = func(self) + if cached_elem != self: + self.cached_elem = cached_elem + return cached_elem return self.cached_elem return _cache diff --git a/excore/config/model.py b/excore/config/model.py index 08fef64..d5618ed 100644 --- a/excore/config/model.py +++ b/excore/config/model.py @@ -1,7 +1,7 @@ import importlib import os import re -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Any, Dict, List, Optional, Sequence, Tuple, Union from .._exceptions import EnvVarParseError, ModuleBuildError, StrToClassError @@ -73,7 +73,7 @@ def _str_to_target(module_name): @dataclass class ModuleNode(dict): cls: Any - _no_call: bool = False + _no_call: bool = field(default=False, repr=False) def _get_params(self, **kwargs): params = {} @@ -120,7 +120,7 @@ def __lshift__(self, kwargs): def __rshift__(self, __other): if not isinstance(__other, ModuleNode): - raise TypeError(f"Expect type is dict, but got {type(__other)}") + raise TypeError(f"Expect type is `ModuleNode`, but got {type(__other)}") __other.update(self) return self @@ -148,8 +148,7 @@ def from_node(cls, _other: "ModuleNode") -> "ModuleNode": if _other.__class__.__name__ == cls.__name__: return _other node = cls(_other.cls) << _other - if hasattr(_other, "__no_call"): - node._no_call = True + node._no_call = _other._no_call return node @@ -212,7 +211,6 @@ def __call__(self): return self.value -# FIXME: Refactor ModuleWrapper class ModuleWrapper(dict): def __init__( self, @@ -259,7 +257,7 @@ def first(self): def __getattr__(self, __name: str) -> Any: if __name in self.keys(): return self[__name] - raise KeyError(__name) + raise KeyError(f"Invalid key `{__name}`, must be one of `{list(self.keys())}`") def __call__(self): if self.is_dict: diff --git a/tests/configs/launch/test_no_call_reused.toml b/tests/configs/launch/test_no_call_reused.toml new file mode 100644 index 0000000..4dec47e --- /dev/null +++ b/tests/configs/launch/test_no_call_reused.toml @@ -0,0 +1,6 @@ +[Model.TestClass] +__no_call__ = true +$cls = "torch.nn.ReLU" + +[Backbone.VGG] +@x = "TestClass" diff --git a/tests/test_config.py b/tests/test_config.py index 285e5d8..328ca59 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -12,7 +12,7 @@ ImplicitModuleParseError, ModuleBuildError, ) -from excore.config.model import ModuleNode +from excore.config.model import ModuleNode, ReusedNode class TestConfig: @@ -117,6 +117,11 @@ def test_no_call(self): modules, _ = self._load("./configs/launch/test_no_call.toml", False) assert isinstance(modules.Model, ModuleNode) + def test_no_call_with_reused_node(self): + modules, _ = self._load("./configs/launch/test_no_call_reused.toml", False) + assert isinstance(modules.Backbone.x, ReusedNode) + assert id(modules.Backbone.x) == id(modules.Model) + def test_dict_action(self): from init import excute