Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for 4 merging modes of DictConfig #917

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 64 additions & 14 deletions omegaconf/basecontainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,9 +274,50 @@ def get_node_value(key: Union[DictKeyType, int]) -> Any:

return retlist
assert False

@staticmethod
def _map_val_merge(dest: "BaseContainer", src_value: Any, key: str, how: str) -> None:
"""merges a value into a destination container conditional on the type of merge (left, inner, outer-left or outer-right) """
if how == "left":
pass
elif how == "inner":
pass
elif how == "outer-left":
if key in dest:
pass
else:
dest.__setitem__(key, src_value)
elif how == "outer-right":
dest.__setitem__(key, src_value)

@staticmethod
def _none_map_merge(dest: "BaseContainer", src: "BaseContainer", key: str, how: str) -> None:
"""merges src container value at given key into a destination container (without said key) conditional on the type of merge (left, inner, outer-left or outer-right) """
if how == "left":
pass
elif how == "inner":
pass
elif how == "outer-left":
BaseContainer._set_item(dest, src, key)
elif how in "outer-right":
BaseContainer._set_item(dest, src, key)

@staticmethod
def _set_item(dest: "BaseContainer", src: "BaseContainer", key: str) -> None:
from omegaconf import DictConfig
assert isinstance(src, DictConfig)
src_type = src._metadata.object_type
from omegaconf import open_dict

if is_structured_config(src_type):
# verified to be compatible above in _validate_merge
with open_dict(dest):
dest[key] = src._get_node(key)
else:
dest[key] = src._get_node(key)

@staticmethod
def _map_merge(dest: "BaseContainer", src: "BaseContainer") -> None:
def _map_merge(dest: "BaseContainer", src: "BaseContainer", how: str = "outer-right") -> None:
"""merge src into dest and return a new copy, does not modified input"""
from omegaconf import AnyNode, DictConfig, ValueNode

Expand Down Expand Up @@ -364,12 +405,12 @@ def expand(node: Container) -> None:
if dest_node is not None:
if isinstance(dest_node, BaseContainer):
if isinstance(src_value, BaseContainer):
dest_node._merge_with(src_value)
dest_node._merge_with(src_value, how=how)
elif not missing_src_value:
dest.__setitem__(key, src_value)
BaseContainer._map_val_merge(dest, src_value, key, how)
else:
if isinstance(src_value, BaseContainer):
dest.__setitem__(key, src_value)
BaseContainer._map_val_merge(dest, src_value, key, how)
else:
assert isinstance(dest_node, ValueNode)
assert isinstance(src_node, ValueNode)
Expand All @@ -392,14 +433,21 @@ def expand(node: Container) -> None:
except (ValidationError, ReadonlyConfigError) as e:
dest._format_and_raise(key=key, value=src_value, cause=e)
else:
from omegaconf import open_dict

if is_structured_config(src_type):
# verified to be compatible above in _validate_merge
with open_dict(dest):
dest[key] = src._get_node(key)
else:
dest[key] = src._get_node(key)
BaseContainer._none_map_merge(dest, src, key, how)
if how == "inner":
# Remove non-overlapping keys from destination
dest_items = dest.items_ex(resolve=False) if not dest._is_missing() else []
for key, dest_value in dest_items:
src_node = src._get_node(key, validate_access=False)
dest_node = dest._get_node(key, validate_access=False)
assert src_node is None or isinstance(src_node, Node)
assert dest_node is None or isinstance(dest_node, Node)
if (
key not in src
or (isinstance(dest_node, BaseContainer) and not isinstance(src_node, BaseContainer))
or (not isinstance(dest_node, BaseContainer) and isinstance(src_node, BaseContainer))
):
dest.__delitem__(key)

_update_types(node=dest, ref_type=src_ref_type, object_type=src_type)

Expand Down Expand Up @@ -455,9 +503,10 @@ def merge_with(
*others: Union[
"BaseContainer", Dict[str, Any], List[Any], Tuple[Any, ...], Any
],
how: str = "outer-right"
) -> None:
try:
self._merge_with(*others)
self._merge_with(*others, how=how)
except Exception as e:
self._format_and_raise(key=None, value=None, cause=e)

Expand All @@ -466,6 +515,7 @@ def _merge_with(
*others: Union[
"BaseContainer", Dict[str, Any], List[Any], Tuple[Any, ...], Any
],
how: str = "outer-right"
) -> None:
from .dictconfig import DictConfig
from .listconfig import ListConfig
Expand All @@ -481,7 +531,7 @@ def _merge_with(
other = _ensure_container(other, flags=my_flags)

if isinstance(self, DictConfig) and isinstance(other, DictConfig):
BaseContainer._map_merge(self, other)
BaseContainer._map_merge(self, other, how=how)
elif isinstance(self, ListConfig) and isinstance(other, ListConfig):
BaseContainer._list_merge(self, other)
else:
Expand Down
26 changes: 25 additions & 1 deletion omegaconf/omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,10 +252,34 @@ def merge(
Tuple[Any, ...],
Any,
],
how: str = "outer-right"
) -> Union[ListConfig, DictConfig]:
"""
Merge a list of previously created configs into a single one
:param configs: Input configs
:how: How to merge DictConfigs. One of left, inner, outer-left or outer-right.
left: Only merge on keys from the left DictConfig. E.g.
{a1: 1, a2: {b1: 2, b2: 3}},
{a2: {b2: 11, b2: 12}, a3: 13}
-> {a1: 1, a2: {b1: 2, b2: 11}}
inner: Only merge on keys from both DictConfigs. E.g.
{a1: 1, a2: {b1: 2, b2: 3}},
{a2: {b2: 11, b2: 12}, a3: 13}
-> {a2: {b2: 11}}
outer-left: Merge on keys from both DictConfigs. E.g.
{a1: 1, a2: {b1: 2, b2: 3}},
{a2: {b2: 11, b2: 12}, a3: 13}
-> {a1: 1, a2: {b1: 2, b2: 11, b3: }}

If nesting levels conflict, use keys on the left. E.g.
{a1: {b1: 9}}, {a1: {b1: {c1: 21}}, a2: 22}
-> {a1: {b1: 9}, a2: 22}
outer-right: Merge on keys from both DictConfigs (see outer-left).

If nesting levels conflict, use keys on the right. E.g.
{a1: {b1: 9}}, {a1: {b1: {c1: 21}}, a2: 22}
-> {a1: {b1: {c1: 21}}, a2: 22}
Currently defaults to outer-right.
:return: the merged config object.
"""
assert len(configs) > 0
Expand All @@ -264,7 +288,7 @@ def merge(
assert isinstance(target, (DictConfig, ListConfig))

with flag_override(target, "readonly", False):
target.merge_with(*configs[1:])
target.merge_with(*configs[1:], how=how)
turned_readonly = target._get_flag("readonly") is True

if turned_readonly:
Expand Down
124 changes: 123 additions & 1 deletion tests/test_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,6 +388,128 @@ def test_merge(
with expected:
merge_function(*configs)

@mark.parametrize(
("config1", "config2", "how", "expected"),
[
param(
{"a": 1, "b": {"b1": 1, "b2": 2}, "d": 5},
{"a":2, "b": 3, "c": 4},
"left",
{"a": 2, "b": {"b1": 1, "b2": 2}, "d": 5},
id="left"
),
param(
{"a":2, "b": 3, "c": 4},
{"a": 1, "b": {"b1": 1, "b2": 2}, "d": 5},
"left-reverse",
{"a": 1, "b": 3, "c": 4},
id="left-reverse"
),
param(
{"a": 1, "b": {"b1": 1, "b2": 2}, "d": 5},
{"a":2, "b": 3, "c": 4},
"inner",
{"a": 2},
id="inner"
),
param(
{"a":2, "b": 3, "c": 4},
{"a": 1, "b": {"b1": 1, "b2": 2}, "d": 5},
"inner",
{"a": 1},
id="inner-reverse"
),
param(
{"a": 1, "b": {"b1": 1, "b2": 2}, "d": 5},
{"a":2, "b": 3, "c": 4},
"outer-left",
{"a": 2, "b": {"b1": 1, "b2": 2}, "c": 4, "d": 5},
id="outer-left"
),
param(
{"a":2, "b": 3, "c": 4},
{"a": 1, "b": {"b1": 1, "b2": 2}, "d": 5},
"outer-left",
{"a": 1, "b": 3, "c": 4, "d": 5},
id="outer-left-reverse"
),
param(
{"a": 1, "b": {"b1": 1, "b2": 2}, "d": 5},
{"a":2, "b": 3, "c": 4},
"outer-right",
{"a": 2, "b": 3, "c": 4, "d": 5},
id="outer-right"
),
param(
{"a":2, "b": 3, "c": 4},
{"a": 1, "b": {"b1": 1, "b2": 2}, "d": 5},
"outer-right",
{"a": 1, "b": {"b1": 1, "b2": 2}, "c": 4, "d": 5},
id="outer-right-reverse"
),
param(
{"a1" : 1, "a2": {"b1": 2, "b2": {"c1": 3, "c2": 4}}},
{"a2": {"b2": {"c2": 14, "c3": 15}}},
"left",
{"a1" : 1, "a2": {"b1": 2, "b2": {"c1": 3, "c2": 14}}},
id="nested-left"
),
param(
{"a1": {"b1": 1}},
{"a1": {"b2": 2}},
"left",
{"a1": {"b1": 1}},
id="nested-left"
),
param(
{"a1": {"b1": 1}},
{"a2": {"b1": 2}},
"left",
{"a1": {"b1": 1}},
id="nested-left"
),
param(
{"a1" : 1, "a2": {"b1": 2, "b2": {"c1": 3, "c2": 4}}},
{"a2": {"b2": {"c2": 14, "c3": 15}}},
"inner",
{"a2": {"b2": {"c2": 14}}},
id="nested-inner"
),
param(
{"a1": {"b1": 1}},
{"a1": {"b2": 2}},
"inner",
{"a1": {}},
id="nested-inner"
),
param(
{"a1": {"b1": 1}},
{"a2": {"b1": 2}},
"inner",
{},
id="nested-inner"
),
param(
{"a1" : 1, "a2": {"b1": 2, "b2": {"c1": 3, "c2": 4}}},
{"a2": {"b2": {"c2": 14, "c3": 15}}},
"outer-left",
{"a1" : 1, "a2": {"b1": 2, "b2": {"c1": 3, "c2": 14, "c3": 15}}},
id="nested-outer-left"
),
param(
{"a1" : 1, "a2": {"b1": 2, "b2": {"c1": 3, "c2": 4}}},
{"a2": {"b2": {"c2": 14, "c3": 15}}},
"outer-left",
{"a1" : 1, "a2": {"b1": 2, "b2": {"c1": 3, "c2": 14, "c3": 15}}},
id="nested-outer-right"
),
]
)
def test_merge_how(config1, config2, how, expected) -> None:
config1 = OmegaConf.create(config1)
config2 = OmegaConf.create(config2)
merge_config = OmegaConf.merge(config1, config2, how=how)
assert merge_config == expected

@mark.parametrize(
"inputs,expected,ref_type,is_optional",
Expand Down Expand Up @@ -821,7 +943,7 @@ def test_merge_list_list() -> None:
b = OmegaConf.create([4, 5, 6])
a.merge_with(b)
assert a == b


@mark.parametrize("merge_func", [OmegaConf.merge, OmegaConf.unsafe_merge])
@mark.parametrize(
Expand Down