From eb52e78eff70fa34260da911564398cfcc376133 Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Thu, 6 Jun 2024 19:05:29 +0800 Subject: [PATCH] :sparkles: Support dict modules as parameters --- excore/config/model.py | 10 +++++++++- excore/config/parse.py | 16 +++++++++++++++- tests/configs/launch/test_dict_param.toml | 12 ++++++++++++ tests/test_config.py | 11 +++++++++++ 4 files changed, 47 insertions(+), 2 deletions(-) create mode 100644 tests/configs/launch/test_dict_param.toml diff --git a/excore/config/model.py b/excore/config/model.py index 75ef799..8acbb0f 100644 --- a/excore/config/model.py +++ b/excore/config/model.py @@ -211,17 +211,23 @@ def __call__(self): return self.value +# FIXME: Refactor ModuleWrapper class ModuleWrapper(dict): def __init__( - self, modules: Optional[Union[Dict[str, ModuleNode], List[ModuleNode], ModuleNode]] = None + self, + modules: Optional[Union[Dict[str, ModuleNode], List[ModuleNode], ModuleNode]] = None, + is_dict=False, ): super().__init__() if modules is None: return + self.is_dict = is_dict if isinstance(modules, (ModuleNode, ConfigArgumentHook, ChainedInvocationWrapper)): self[modules.name] = modules elif isinstance(modules, dict): for k, m in modules.items(): + if isinstance(m, list): + m = ModuleWrapper(m) self[k] = m elif isinstance(modules, list): for m in modules: @@ -255,6 +261,8 @@ def __getattr__(self, __name: str) -> Any: raise KeyError(__name) def __call__(self): + if self.is_dict: + return {k: v() for k, v in self.items()} res = [m() for m in self.values()] if len(res) == 1: return res[0] diff --git a/excore/config/parse.py b/excore/config/parse.py index abfe92c..2c8ea22 100644 --- a/excore/config/parse.py +++ b/excore/config/parse.py @@ -61,6 +61,15 @@ def _flatten_list(lis): return new_lis +def _flatten_dict(dic): + new_dic = {} + for k, v in dic.items(): + if isinstance(v, list): + v = _flatten_list(v) + new_dic[k] = v + return new_dic + + class ConfigDict(dict): primary_fields: List primary_to_registry: Dict[str, str] @@ -296,11 +305,16 @@ def _parse_module(self, node: ModuleNode): if not module_type: continue value = node.pop(param_name) + is_dict = False if isinstance(value, list): value = [self._parse_param(v, module_type) for v in value] value = _flatten_list(value) elif isinstance(value, str): value = self._parse_param(value, module_type) + elif isinstance(value, dict): + value = {k: self._parse_param(v, module_type) for k, v in value.items()} + value = _flatten_dict(value) + is_dict = True else: raise CoreConfigParseError(f"Wrong type: {param_name, value}") if isinstance(value, VariableReference): @@ -312,7 +326,7 @@ def _parse_module(self, node: ModuleNode): else: node[true_name] = ref_name else: - node[true_name] = ModuleWrapper(value) + node[true_name] = ModuleWrapper(value, is_dict) def _parse_inter_modules(self): for name in list(self.keys()): diff --git a/tests/configs/launch/test_dict_param.toml b/tests/configs/launch/test_dict_param.toml new file mode 100644 index 0000000..9342cc4 --- /dev/null +++ b/tests/configs/launch/test_dict_param.toml @@ -0,0 +1,12 @@ +[Model.TestClass."$cls"] +a = "VGG" +b = "$Sc::VGG" + +[Model.TestClass."$cls1"] +a = "VGG" +b = "$Sc::*" + +[Sc.VGG] +[Sc.TestClass] + +[VGG] diff --git a/tests/test_config.py b/tests/test_config.py index 527c8c9..285e5d8 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -204,3 +204,14 @@ def test_get_error3(self): def test_ref_error(self): with pytest.raises(CoreConfigParseError): self._load("./configs/launch/test_ref_error.toml", False) + + def test_dict_param(self): + modules, _ = self._load("./configs/launch/test_dict_param.toml", False) + + from source_code.models.nets import VGG, TestClass + + assert modules.Model.cls["a"] == VGG + assert modules.Model.cls["b"] == VGG + + assert modules.Model.cls1["a"] == VGG + assert modules.Model.cls1["b"] == [VGG, TestClass]