Skip to content

Commit

Permalink
♻️ Unify ChainedInvocationWrapper and ConfigArgumentHook
Browse files Browse the repository at this point in the history
  • Loading branch information
Asthestarsfalll committed Jul 11, 2024
1 parent 6da1d7f commit 17a80d5
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 21 deletions.
15 changes: 6 additions & 9 deletions excore/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,15 +175,12 @@ def __call__(self):
return self.cls


@dataclass
class ChainedInvocationWrapper:
node: ModuleNode
attrs: Sequence[str]

def __getattr__(self, __name):
return getattr(self.node, __name)
class ChainedInvocationWrapper(ConfigArgumentHook):
def __init__(self, node: ModuleNode, attrs: Sequence[str]) -> None:
super().__init__(node)
self.attrs = attrs

def __call__(self, **kwargs):
def hook(self, **kwargs):
target = self.node(**kwargs)
if isinstance(target, ModuleNode):
raise ModuleBuildError(f"Do not support `{DO_NOT_CALL_KEY}`")
Expand Down Expand Up @@ -221,7 +218,7 @@ def __init__(
if modules is None:
return
self.is_dict = is_dict
if isinstance(modules, (ModuleNode, ConfigArgumentHook, ChainedInvocationWrapper)):
if isinstance(modules, (ModuleNode, ConfigArgumentHook)):
self[modules.name] = modules
elif isinstance(modules, dict):
for k, m in modules.items():
Expand Down
14 changes: 6 additions & 8 deletions excore/config/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,11 @@ def _get_name_and_field(self, name, ori_name):
)
return list(modules.keys())[0], base

def _apply_hooks(self, node, hooks):
def _apply_hooks(self, node, hooks, attrs):
if attrs:
node = ChainedInvocationWrapper(node, attrs)
if not hooks:
return node
for hook in hooks:
if hook not in self:
raise CoreConfigParseError(f"Unregistered hook {hook}")
Expand Down Expand Up @@ -277,13 +281,7 @@ def _parse_single_param(self, name, ori_name, field, target_type, attrs, hooks):
f"Error when parsing param `{ori_name}`, "
f"target_type is `{target_type}`, but got `{ori_type}`"
)
name = node.name # for ModuleWrapper
if attrs:
node = ChainedInvocationWrapper(node, attrs)
node.name = name
if hooks:
node = self._apply_hooks(node, hooks)
node.name = name
node = self._apply_hooks(node, hooks, attrs)
return node

def _parse_param(self, ori_name, module_type):
Expand Down
9 changes: 5 additions & 4 deletions excore/engine/hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,17 +218,18 @@ def __init__(
):
self.node = node
self.enabled = enabled
self.name = node.name
self._is_initialized = True

def hook(self):
def hook(self, **kwargs):
raise NotImplementedError(f"`{self.__class__.__name__}` do not implement `hook` method.")

@final
def __call__(self):
def __call__(self, **kwargs):
if not getattr(self, "_is_initialized", False):
raise CoreConfigSupportError(
f"Call super().__init__() in class `{self.__class__.__name__}`"
)
if self.enabled:
return self.hook()
return self.node()
return self.hook(**kwargs)
return self.node(**kwargs)

0 comments on commit 17a80d5

Please sign in to comment.