Skip to content

Commit

Permalink
✨ Support dict modules as parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
Asthestarsfalll committed Jun 6, 2024
1 parent 968f720 commit eb52e78
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 2 deletions.
10 changes: 9 additions & 1 deletion excore/config/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,17 +211,23 @@ def __call__(self):
return self.value


# FIXME: Refactor ModuleWrapper
class ModuleWrapper(dict):
def __init__(
self, modules: Optional[Union[Dict[str, ModuleNode], List[ModuleNode], ModuleNode]] = None
self,
modules: Optional[Union[Dict[str, ModuleNode], List[ModuleNode], ModuleNode]] = None,
is_dict=False,
):
super().__init__()
if modules is None:
return
self.is_dict = is_dict
if isinstance(modules, (ModuleNode, ConfigArgumentHook, ChainedInvocationWrapper)):
self[modules.name] = modules
elif isinstance(modules, dict):
for k, m in modules.items():
if isinstance(m, list):
m = ModuleWrapper(m)
self[k] = m
elif isinstance(modules, list):
for m in modules:
Expand Down Expand Up @@ -255,6 +261,8 @@ def __getattr__(self, __name: str) -> Any:
raise KeyError(__name)

def __call__(self):
if self.is_dict:
return {k: v() for k, v in self.items()}
res = [m() for m in self.values()]
if len(res) == 1:
return res[0]
Expand Down
16 changes: 15 additions & 1 deletion excore/config/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,15 @@ def _flatten_list(lis):
return new_lis


def _flatten_dict(dic):
new_dic = {}
for k, v in dic.items():
if isinstance(v, list):
v = _flatten_list(v)
new_dic[k] = v
return new_dic


class ConfigDict(dict):
primary_fields: List
primary_to_registry: Dict[str, str]
Expand Down Expand Up @@ -296,11 +305,16 @@ def _parse_module(self, node: ModuleNode):
if not module_type:
continue
value = node.pop(param_name)
is_dict = False
if isinstance(value, list):
value = [self._parse_param(v, module_type) for v in value]
value = _flatten_list(value)
elif isinstance(value, str):
value = self._parse_param(value, module_type)
elif isinstance(value, dict):
value = {k: self._parse_param(v, module_type) for k, v in value.items()}
value = _flatten_dict(value)
is_dict = True
else:
raise CoreConfigParseError(f"Wrong type: {param_name, value}")
if isinstance(value, VariableReference):
Expand All @@ -312,7 +326,7 @@ def _parse_module(self, node: ModuleNode):
else:
node[true_name] = ref_name
else:
node[true_name] = ModuleWrapper(value)
node[true_name] = ModuleWrapper(value, is_dict)

def _parse_inter_modules(self):
for name in list(self.keys()):
Expand Down
12 changes: 12 additions & 0 deletions tests/configs/launch/test_dict_param.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
[Model.TestClass."$cls"]
a = "VGG"
b = "$Sc::VGG"

[Model.TestClass."$cls1"]
a = "VGG"
b = "$Sc::*"

[Sc.VGG]
[Sc.TestClass]

[VGG]
11 changes: 11 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,3 +204,14 @@ def test_get_error3(self):
def test_ref_error(self):
with pytest.raises(CoreConfigParseError):
self._load("./configs/launch/test_ref_error.toml", False)

def test_dict_param(self):
modules, _ = self._load("./configs/launch/test_dict_param.toml", False)

from source_code.models.nets import VGG, TestClass

assert modules.Model.cls["a"] == VGG
assert modules.Model.cls["b"] == VGG

assert modules.Model.cls1["a"] == VGG
assert modules.Model.cls1["b"] == [VGG, TestClass]

0 comments on commit eb52e78

Please sign in to comment.