Skip to content

Commit

Permalink
🐛 Fix __no_call__ with ReusedNode
Browse files Browse the repository at this point in the history
  • Loading branch information
Asthestarsfalll committed Jul 10, 2024
1 parent c413314 commit 8607f23
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 9 deletions.
5 changes: 4 additions & 1 deletion excore/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ def __call__(self, func):
@functools.wraps(func)
def _cache(self):
if not hasattr(self, "cached_elem"):
self.cached_elem = func(self)
cached_elem = func(self)
if cached_elem != self:
self.cached_elem = cached_elem
return cached_elem
return self.cached_elem

return _cache
Expand Down
12 changes: 5 additions & 7 deletions excore/config/model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import importlib
import os
import re
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

from .._exceptions import EnvVarParseError, ModuleBuildError, StrToClassError
Expand Down Expand Up @@ -73,7 +73,7 @@ def _str_to_target(module_name):
@dataclass
class ModuleNode(dict):
cls: Any
_no_call: bool = False
_no_call: bool = field(default=False, repr=False)

def _get_params(self, **kwargs):
params = {}
Expand Down Expand Up @@ -120,7 +120,7 @@ def __lshift__(self, kwargs):

def __rshift__(self, __other):
if not isinstance(__other, ModuleNode):
raise TypeError(f"Expect type is dict, but got {type(__other)}")
raise TypeError(f"Expect type is `ModuleNode`, but got {type(__other)}")
__other.update(self)
return self

Expand Down Expand Up @@ -148,8 +148,7 @@ def from_node(cls, _other: "ModuleNode") -> "ModuleNode":
if _other.__class__.__name__ == cls.__name__:
return _other
node = cls(_other.cls) << _other
if hasattr(_other, "__no_call"):
node._no_call = True
node._no_call = _other._no_call
return node


Expand Down Expand Up @@ -212,7 +211,6 @@ def __call__(self):
return self.value


# FIXME: Refactor ModuleWrapper
class ModuleWrapper(dict):
def __init__(
self,
Expand Down Expand Up @@ -259,7 +257,7 @@ def first(self):
def __getattr__(self, __name: str) -> Any:
if __name in self.keys():
return self[__name]
raise KeyError(__name)
raise KeyError(f"Invalid key `{__name}`, must be one of `{list(self.keys())}`")

def __call__(self):
if self.is_dict:
Expand Down
6 changes: 6 additions & 0 deletions tests/configs/launch/test_no_call_reused.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[Model.TestClass]
__no_call__ = true
$cls = "torch.nn.ReLU"

[Backbone.VGG]
@x = "TestClass"
7 changes: 6 additions & 1 deletion tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
ImplicitModuleParseError,
ModuleBuildError,
)
from excore.config.model import ModuleNode
from excore.config.model import ModuleNode, ReusedNode


class TestConfig:
Expand Down Expand Up @@ -117,6 +117,11 @@ def test_no_call(self):
modules, _ = self._load("./configs/launch/test_no_call.toml", False)
assert isinstance(modules.Model, ModuleNode)

def test_no_call_with_reused_node(self):
modules, _ = self._load("./configs/launch/test_no_call_reused.toml", False)
assert isinstance(modules.Backbone.x, ReusedNode)
assert id(modules.Backbone.x) == id(modules.Model)

def test_dict_action(self):
from init import excute

Expand Down

0 comments on commit 8607f23

Please sign in to comment.