diff --git a/excore/config/model.py b/excore/config/model.py index d5618ed..55b4751 100644 --- a/excore/config/model.py +++ b/excore/config/model.py @@ -175,15 +175,12 @@ def __call__(self): return self.cls -@dataclass -class ChainedInvocationWrapper: - node: ModuleNode - attrs: Sequence[str] - - def __getattr__(self, __name): - return getattr(self.node, __name) +class ChainedInvocationWrapper(ConfigArgumentHook): + def __init__(self, node: ModuleNode, attrs: Sequence[str]) -> None: + super().__init__(node) + self.attrs = attrs - def __call__(self, **kwargs): + def hook(self, **kwargs): target = self.node(**kwargs) if isinstance(target, ModuleNode): raise ModuleBuildError(f"Do not support `{DO_NOT_CALL_KEY}`") @@ -221,7 +218,7 @@ def __init__( if modules is None: return self.is_dict = is_dict - if isinstance(modules, (ModuleNode, ConfigArgumentHook, ChainedInvocationWrapper)): + if isinstance(modules, (ModuleNode, ConfigArgumentHook)): self[modules.name] = modules elif isinstance(modules, dict): for k, m in modules.items(): diff --git a/excore/config/parse.py b/excore/config/parse.py index 2c8ea22..bef133c 100644 --- a/excore/config/parse.py +++ b/excore/config/parse.py @@ -240,7 +240,11 @@ def _get_name_and_field(self, name, ori_name): ) return list(modules.keys())[0], base - def _apply_hooks(self, node, hooks): + def _apply_hooks(self, node, hooks, attrs): + if attrs: + node = ChainedInvocationWrapper(node, attrs) + if not hooks: + return node for hook in hooks: if hook not in self: raise CoreConfigParseError(f"Unregistered hook {hook}") @@ -277,13 +281,7 @@ def _parse_single_param(self, name, ori_name, field, target_type, attrs, hooks): f"Error when parsing param `{ori_name}`, " f"target_type is `{target_type}`, but got `{ori_type}`" ) - name = node.name # for ModuleWrapper - if attrs: - node = ChainedInvocationWrapper(node, attrs) - node.name = name - if hooks: - node = self._apply_hooks(node, hooks) - node.name = name + node = self._apply_hooks(node, hooks, attrs) return node def _parse_param(self, ori_name, module_type): diff --git a/excore/engine/hook.py b/excore/engine/hook.py index e57a23b..b147b1c 100644 --- a/excore/engine/hook.py +++ b/excore/engine/hook.py @@ -218,17 +218,18 @@ def __init__( ): self.node = node self.enabled = enabled + self.name = node.name self._is_initialized = True - def hook(self): + def hook(self, **kwargs): raise NotImplementedError(f"`{self.__class__.__name__}` do not implement `hook` method.") @final - def __call__(self): + def __call__(self, **kwargs): if not getattr(self, "_is_initialized", False): raise CoreConfigSupportError( f"Call super().__init__() in class `{self.__class__.__name__}`" ) if self.enabled: - return self.hook() - return self.node() + return self.hook(**kwargs) + return self.node(**kwargs)