diff --git a/docs/notebook/Tutorial.ipynb b/docs/notebook/Tutorial.ipynb index 772ae00e6..d17712c57 100644 --- a/docs/notebook/Tutorial.ipynb +++ b/docs/notebook/Tutorial.ipynb @@ -1,6 +1,7 @@ { "cells": [ { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -14,6 +15,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -45,6 +47,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -79,6 +82,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -112,6 +116,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -149,6 +154,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -190,6 +196,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -227,6 +234,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -265,6 +273,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -303,6 +312,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -334,6 +344,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -365,6 +376,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -396,6 +408,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -416,6 +429,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -436,6 +450,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -456,6 +471,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -487,6 +503,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -523,6 +540,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -534,6 +552,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -594,6 +613,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -627,6 +647,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -661,6 +682,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -694,6 +716,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -714,6 +737,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -771,6 +795,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -811,6 +836,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -818,6 +844,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -876,6 +903,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -883,6 +911,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -917,6 +946,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -946,6 +976,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -1004,6 +1035,7 @@ ] }, { + "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -1112,6 +1144,85 @@ "conf.merge_with_cli()\n", "print(OmegaConf.to_yaml(conf))" ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "By default, `merge()` is replacing the target list with the source list.\n", + "Use `list_merge_mode` to control the merge behavior for lists.\n", + "Currently there are three different merge modes:\n", + "* `REPLACE`: Replaces the target list with the new one (default)\n", + "* `EXTEND`: Extends the target list with the new one\n", + "* `EXTEND_UNIQUE`: Extends the target list items with items not present in it" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`example2.yaml` file:\n", + "```yaml\n", + "server:\n", + " port: 80\n", + "users:\n", + " - user1\n", + " - user2\n", + "```" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`example4.yaml` file:\n", + "```yaml\n", + "users:\n", + " - user3\n", + " - user2\n", + "```" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If you load them and merge them with `list_merge_mode=ListMergeMode.EXTEND_UNIQUE` you will get this:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from omegaconf import OmegaConf, ListMergeMode\n", + "\n", + "cfg_example2 = OmegaConf.load('../source/example2.yaml')\n", + "cfg_example4 = OmegaConf.load('../source/example4.yaml')\n", + "\n", + "conf = OmegaConf.merge(cfg_example2, cfg_example4, list_merge_mode=ListMergeMode.EXTEND_UNIQUE)\n", + "print(OmegaConf.to_yaml(conf))" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "```yaml\n", + "server:\n", + " port: 80\n", + "users:\n", + "- user1\n", + "- user2\n", + "- user3\n", + "```" + ] } ], "metadata": { @@ -1130,7 +1241,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.5" + "version": "3.10.8" }, "pycharm": { "stem_cell": { diff --git a/docs/source/example4.yaml b/docs/source/example4.yaml new file mode 100644 index 000000000..e5ea3256c --- /dev/null +++ b/docs/source/example4.yaml @@ -0,0 +1,3 @@ +users: + - user3 + - user2 diff --git a/docs/source/usage.rst b/docs/source/usage.rst index feed18391..3ead89099 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -492,6 +492,43 @@ Note how the port changes to 82, and how the users lists are combined. file: log.txt +By default, ``merge()`` is replacing the target list with the source list. +Use ``list_merge_mode`` to control the merge behavior for lists. +This Enum is defined in ``omegaconf.ListMergeMode`` and defines the following modes: +* ``REPLACE``: Replaces the target list with the new one (default) +* ``EXTEND``: Extends the target list with the new one +* ``EXTEND_UNIQUE``: Extends the target list items with items not present in it + +**example2.yaml** file: + +.. include:: example2.yaml + :code: yaml + +**example4.yaml** file: + +.. include:: example4.yaml + :code: yaml + +If you load them and merge them with ``list_merge_mode=ListMergeMode.EXTEND_UNIQUE`` you will get this: + +.. doctest:: + + >>> from omegaconf import OmegaConf, ListMergeMode + >>> + >>> cfg_1 = OmegaConf.load('source/example2.yaml') + >>> cfg_2 = OmegaConf.load('source/example4.yaml') + >>> + >>> mode = ListMergeMode.EXTEND_UNIQUE + >>> conf = OmegaConf.merge(cfg_1, cfg_2, list_merge_mode=mode) + >>> print(OmegaConf.to_yaml(conf)) + server: + port: 80 + users: + - user1 + - user2 + - user3 + + OmegaConf.unsafe_merge() ^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/news/1082.feature b/news/1082.feature new file mode 100644 index 000000000..bcdea20de --- /dev/null +++ b/news/1082.feature @@ -0,0 +1 @@ +OmegaConf.merge() can now take a list_merge_mode parameter that controls the strategy for merging lists (replace, extend and more). \ No newline at end of file diff --git a/omegaconf/__init__.py b/omegaconf/__init__.py index 1670cf08f..e8b9e369e 100644 --- a/omegaconf/__init__.py +++ b/omegaconf/__init__.py @@ -1,4 +1,4 @@ -from .base import Container, DictKeyType, Node, SCMode, UnionNode +from .base import Container, DictKeyType, ListMergeMode, Node, SCMode, UnionNode from .dictconfig import DictConfig from .errors import ( KeyValidationError, @@ -46,6 +46,7 @@ "OmegaConf", "Resolver", "SCMode", + "ListMergeMode", "flag_override", "read_write", "open_dict", diff --git a/omegaconf/base.py b/omegaconf/base.py index 381323d3c..77e951058 100644 --- a/omegaconf/base.py +++ b/omegaconf/base.py @@ -795,6 +795,12 @@ class SCMode(Enum): INSTANTIATE = 3 # Create a dataclass or attrs class instance +class ListMergeMode(Enum): + REPLACE = 1 # Replaces the target list with the new one (default) + EXTEND = 2 # Extends the target list with the new one + EXTEND_UNIQUE = 3 # Extends the target list items with items not present in it + + class UnionNode(Box): """ This class handles Union type hints. The `_content` attribute is either a diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 575df3c72..156b1ca30 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -34,6 +34,7 @@ Container, ContainerMetadata, DictKeyType, + ListMergeMode, Node, SCMode, UnionNode, @@ -304,7 +305,11 @@ def get_node_value(key: Union[DictKeyType, int]) -> Any: assert False @staticmethod - def _map_merge(dest: "BaseContainer", src: "BaseContainer") -> None: + def _map_merge( + dest: "BaseContainer", + src: "BaseContainer", + list_merge_mode: ListMergeMode = ListMergeMode.REPLACE, + ) -> None: """merge src into dest and return a new copy, does not modified input""" from omegaconf import AnyNode, DictConfig, ValueNode @@ -396,7 +401,10 @@ def expand(node: Container) -> None: if dest_node is not None: if isinstance(dest_node, BaseContainer): if isinstance(src_node, BaseContainer): - dest_node._merge_with(src_node) + dest_node._merge_with( + src_node, + list_merge_mode=list_merge_mode, + ) elif not src_node_missing: dest.__setitem__(key, src_node) else: @@ -441,7 +449,11 @@ def expand(node: Container) -> None: dest._set_flag(flag, value) @staticmethod - def _list_merge(dest: Any, src: Any) -> None: + def _list_merge( + dest: Any, + src: Any, + list_merge_mode: ListMergeMode = ListMergeMode.REPLACE, + ) -> None: from omegaconf import DictConfig, ListConfig, OmegaConf assert isinstance(dest, ListConfig) @@ -471,7 +483,14 @@ def _list_merge(dest: Any, src: Any) -> None: for item in src._iter_ex(resolve=False): temp_target.append(item) - dest.__dict__["_content"] = temp_target.__dict__["_content"] + if list_merge_mode == ListMergeMode.EXTEND: + dest.__dict__["_content"].extend(temp_target.__dict__["_content"]) + elif list_merge_mode == ListMergeMode.EXTEND_UNIQUE: + for entry in temp_target.__dict__["_content"]: + if entry not in dest.__dict__["_content"]: + dest.__dict__["_content"].append(entry) + else: # REPLACE (default) + dest.__dict__["_content"] = temp_target.__dict__["_content"] # explicit flags on the source config are replacing the flag values in the destination flags = src._metadata.flags @@ -485,9 +504,13 @@ def merge_with( *others: Union[ "BaseContainer", Dict[str, Any], List[Any], Tuple[Any, ...], Any ], + list_merge_mode: ListMergeMode = ListMergeMode.REPLACE, ) -> None: try: - self._merge_with(*others) + self._merge_with( + *others, + list_merge_mode=list_merge_mode, + ) except Exception as e: self._format_and_raise(key=None, value=None, cause=e) @@ -496,6 +519,7 @@ def _merge_with( *others: Union[ "BaseContainer", Dict[str, Any], List[Any], Tuple[Any, ...], Any ], + list_merge_mode: ListMergeMode = ListMergeMode.REPLACE, ) -> None: from .dictconfig import DictConfig from .listconfig import ListConfig @@ -511,9 +535,17 @@ def _merge_with( other = _ensure_container(other, flags=my_flags) if isinstance(self, DictConfig) and isinstance(other, DictConfig): - BaseContainer._map_merge(self, other) + BaseContainer._map_merge( + self, + other, + list_merge_mode=list_merge_mode, + ) elif isinstance(self, ListConfig) and isinstance(other, ListConfig): - BaseContainer._list_merge(self, other) + BaseContainer._list_merge( + self, + other, + list_merge_mode=list_merge_mode, + ) else: raise TypeError("Cannot merge DictConfig with ListConfig") diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index c8c3797b5..041602879 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -53,7 +53,7 @@ split_key, type_str, ) -from .base import Box, Container, Node, SCMode, UnionNode +from .base import Box, Container, ListMergeMode, Node, SCMode, UnionNode from .basecontainer import BaseContainer from .errors import ( MissingMandatoryValue, @@ -257,11 +257,17 @@ def merge( Tuple[Any, ...], Any, ], + list_merge_mode: ListMergeMode = ListMergeMode.REPLACE, ) -> Union[ListConfig, DictConfig]: """ Merge a list of previously created configs into a single one :param configs: Input configs + :param list_merge_mode: Behavior for merging lists + REPLACE: Replaces the target list with the new one (default) + EXTEND: Extends the target list with the new one + EXTEND_UNIQUE: Extends the target list items with items not present in it + hint: use `from omegaconf import ListMergeMode` to access the merge mode :return: the merged config object. """ assert len(configs) > 0 @@ -270,7 +276,10 @@ def merge( assert isinstance(target, (DictConfig, ListConfig)) with flag_override(target, "readonly", False): - target.merge_with(*configs[1:]) + target.merge_with( + *configs[1:], + list_merge_mode=list_merge_mode, + ) turned_readonly = target._get_flag("readonly") is True if turned_readonly: @@ -288,6 +297,7 @@ def unsafe_merge( Tuple[Any, ...], Any, ], + list_merge_mode: ListMergeMode = ListMergeMode.REPLACE, ) -> Union[ListConfig, DictConfig]: """ Merge a list of previously created configs into a single one @@ -295,6 +305,11 @@ def unsafe_merge( However, the input configs must not be used after this operation as will become inconsistent. :param configs: Input configs + :param list_merge_mode: Behavior for merging lists + REPLACE: Replaces the target list with the new one (default) + EXTEND: Extends the target list with the new one + EXTEND_UNIQUE: Extends the target list items with items not present in it + hint: use `from omegaconf import ListMergeMode` to access the merge mode :return: the merged config object. """ assert len(configs) > 0 @@ -305,7 +320,10 @@ def unsafe_merge( with flag_override( target, ["readonly", "no_deepcopy_set_nodes"], [False, True] ): - target.merge_with(*configs[1:]) + target.merge_with( + *configs[1:], + list_merge_mode=list_merge_mode, + ) turned_readonly = target._get_flag("readonly") is True if turned_readonly: diff --git a/tests/test_merge.py b/tests/test_merge.py index 6108ce70e..eb4a6f8b4 100644 --- a/tests/test_merge.py +++ b/tests/test_merge.py @@ -34,7 +34,7 @@ get_value_kind, is_structured_config, ) -from omegaconf.base import Node +from omegaconf.base import ListMergeMode, Node from omegaconf.errors import ConfigKeyError, UnsupportedValueType from omegaconf.nodes import IntegerNode from tests import ( @@ -1190,6 +1190,49 @@ def test_merge_list_list() -> None: assert a == b +@mark.parametrize("merge", [OmegaConf.merge, OmegaConf.unsafe_merge]) +@mark.parametrize( + "list_merge_mode,c1,c2,expected", + [ + (ListMergeMode.REPLACE, [1, 2], [3, 4], [3, 4]), + (ListMergeMode.EXTEND, [{"a": 1}], [{"b": 2}], [{"a": 1}, {"b": 2}]), + ( + ListMergeMode.EXTEND, + {"list": [1, 2]}, + {"list": [3, 4]}, + {"list": [1, 2, 3, 4]}, + ), + ( + ListMergeMode.EXTEND, + {"list1": [1, 2], "list2": [1, 2]}, + {"list1": [3, 4], "list2": [3, 4]}, + {"list1": [1, 2, 3, 4], "list2": [1, 2, 3, 4]}, + ), + (ListMergeMode.EXTEND, [[1, 2], [3, 4]], [[5, 6]], [[1, 2], [3, 4], [5, 6]]), + (ListMergeMode.EXTEND, [1, 2], [1, 2], [1, 2, 1, 2]), + (ListMergeMode.EXTEND, [{"a": 1}], [{"a": 1}], [{"a": 1}, {"a": 1}]), + (ListMergeMode.EXTEND_UNIQUE, [1, 2], [1, 3], [1, 2, 3]), + ( + ListMergeMode.EXTEND_UNIQUE, + [{"a": 1}, {"b": 2}], + [{"a": 1}, {"c": 3}], + [{"a": 1}, {"b": 2}, {"c": 3}], + ), + ], +) +def test_merge_list_modes( + merge: Any, + c1: Any, + c2: Any, + list_merge_mode: ListMergeMode, + expected: Any, +) -> None: + a = OmegaConf.create(c1) + b = OmegaConf.create(c2) + merged = merge(a, b, list_merge_mode=list_merge_mode) + assert merged == expected + + @mark.parametrize("merge_func", [OmegaConf.merge, OmegaConf.unsafe_merge]) @mark.parametrize( "base, merge, exception",