Skip to content

Commit

Permalink
✨ Add validate mechanism of ModuleNode
Browse files Browse the repository at this point in the history
  • Loading branch information
Asthestarsfalll committed Dec 9, 2024
1 parent 66bab3f commit ac1ee17
Show file tree
Hide file tree
Showing 9 changed files with 178 additions and 87 deletions.
2 changes: 2 additions & 0 deletions example/configs/detail.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
[Backbone.ResNet]
$target = []
9 changes: 9 additions & 0 deletions excore/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ class _WorkspaceConfig:
primary_to_registry: dict[str, str] = field(default_factory=dict)
json_schema_fields: dict[str, str | list[str]] = field(default_factory=dict)
props: dict[Any, Any] = field(default_factory=dict)
excore_validate: bool = field(default=True)
excore_manual_set: bool = field(default=True)
excore_log_build_message: bool = field(default=False)

@property
def base_name(self):
Expand All @@ -48,6 +51,12 @@ def __post_init__(self) -> None:
logger.warning("Please use `excore init` in your command line first")
else:
self.update(toml.load(_workspace_config_file))
if os.environ.get("EXCORE_VALIDATE", "1") == "0":
self.excore_validate = False
if os.environ.get("EXCORE_LOG_BUILD_MESSAGE", "0") == "1":
self.excore_log_build_message = True
if os.environ.get("EXCORE_MANUAL_SET", "1") == "0":
self.excore_manual_set = False

def _get_cache_dir(self) -> str:
base_name = osp.basename(osp.normpath(os.getcwd()))
Expand Down
8 changes: 4 additions & 4 deletions excore/_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,6 @@ class EnvVarParseError(BaseException):
pass


class ImplicitModuleParseError(BaseException):
pass


class ModuleBuildError(BaseException):
pass

Expand All @@ -60,3 +56,7 @@ class HookBuildError(BaseException):

class AnnotationsFutureError(Exception):
pass


class ModuleValidateError(Exception):
pass
75 changes: 59 additions & 16 deletions excore/config/models.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
from __future__ import annotations

import importlib
import inspect
import os
import re
from dataclasses import dataclass, field
from inspect import Parameter, isclass
from typing import TYPE_CHECKING, Type, Union

from .._exceptions import EnvVarParseError, ModuleBuildError, StrToClassError
from .._constants import workspace
from .._exceptions import EnvVarParseError, ModuleBuildError, ModuleValidateError, StrToClassError
from .._misc import CacheOut
from ..engine.hook import ConfigArgumentHook, Hook
from ..engine.logging import logger
from ..engine.registry import Registry
from .action import DictAction

if TYPE_CHECKING:
from types import FunctionType, ModuleType
Expand All @@ -19,7 +23,7 @@
from typing_extensions import Self

NodeClassType = Type[Any]
NodeParams = Dict[Any, Any]
NodeParams = Dict[str, Any]
NodeInstance = object

NoCallSkipFlag = Self
Expand All @@ -37,14 +41,12 @@
OTHER_FLAG: Literal[""] = ""

FLAG_PATTERN = re.compile(r"^([@!$&])(.*)$")
LOG_BUILD_MESSAGE = True
DO_NOT_CALL_KEY = "__no_call__"
SPECIAL_FLAGS = [OTHER_FLAG, INTER_FLAG, REUSE_FLAG, CLASS_FLAG, REFER_FLAG]


def silent() -> None:
global LOG_BUILD_MESSAGE # pylint: disable=global-statement
LOG_BUILD_MESSAGE = False
workspace.excore_log_build_message = False


def _is_special(k: str) -> tuple[str, SpecialFlag]:
Expand Down Expand Up @@ -93,14 +95,14 @@ class ModuleNode(dict):
_no_call: bool = field(default=False, repr=False)
priority: int = field(default=0, repr=False)

def _get_params(self, **params: NodeParams) -> NodeParams:
def _update_params(self, **params: NodeParams) -> None:
return_params = {}
for k, v in self.items():
if isinstance(v, (ModuleWrapper, ModuleNode)):
v = v()
return_params[k] = v
return_params.update(params)
return return_params
self.update(params)
self.update(return_params)

@property
def name(self) -> str:
Expand All @@ -110,24 +112,25 @@ def add(self, **params: NodeParams) -> Self:
self.update(params)
return self

def _instantiate(self, params: NodeParams) -> NodeInstance:
def _instantiate(self) -> NodeInstance:
try:
module = self.cls(**params)
module = self.cls(**self)
except Exception as exc:
raise ModuleBuildError(
f"Instantiate Error with module {self.cls} and arguments {params}"
f"Instantiate Error with module {self.cls} and arguments {self}"
) from exc
if LOG_BUILD_MESSAGE:
if workspace.excore_log_build_message:
logger.success(
f"Successfully instantiate module: {self.cls.__name__}, with arguments {params}"
f"Successfully instantiate module: {self.cls.__name__}, with arguments {self}"
)
return module

def __call__(self, **params: NodeParams) -> NoCallSkipFlag | NodeInstance: # type: ignore
if self._no_call:
return self
params = self._get_params(**params)
module = self._instantiate(params)
self._update_params(**params)
self.validate()
module = self._instantiate()
return module

def __lshift__(self, params: NodeParams) -> Self:
Expand Down Expand Up @@ -177,6 +180,39 @@ def from_node(cls, _other: ModuleNode) -> ModuleNode:
node._no_call = _other._no_call
return node

def validate(self) -> None:
if not workspace.excore_validate:
return
signature = inspect.signature(self.cls.__init__)
missing = []
params = list(signature.parameters.values())
if isclass(self.cls): # skip self
params = params[1:]

for param in params:
if (
param.default == param.empty
and param.kind
not in [
Parameter.VAR_POSITIONAL,
Parameter.VAR_KEYWORD,
]
and param.name not in self
):
missing.append(param.name)
message = (
f"Validating `{self.cls.__name__}` , "
f"finding missing parameters: `{missing}` without default values."
)
if not workspace.excore_manual_set and missing:
raise ModuleValidateError(message)
if missing:
logger.info(message)
for param_name in missing:
logger.info(f"Input value of paramter `{param_name}`:")
value = input()
self[param_name] = DictAction._parse_iterable(value)


class InterNode(ModuleNode):
priority: int = 2
Expand All @@ -187,10 +223,17 @@ def __excore_check_target_type__(cls, target_type: type[ModuleNode]) -> bool:


class ConfigHookNode(ModuleNode):
def validate(self) -> None:
if "node" in self:
raise ModuleValidateError(
f"Parameter `node:{self['node']}` should not exist in `ConfigHookNode`."
)
super().validate()

def __call__(self, **params: NodeParams) -> NodeInstance | ConfigHookSkipFlag | Hook:
if issubclass(self.cls, ConfigArgumentHook):
return None
params = self._get_params(**params)
params = self._update_params(**params)

Check failure on line 236 in excore/config/models.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

excore/config/models.py#L236

Assigning result of a function call, where the function has no return
return self._instantiate(params)


Expand Down
Loading

0 comments on commit ac1ee17

Please sign in to comment.