From a68b149b8e1a9e9a0cabc83e8691df8c6620909a Mon Sep 17 00:00:00 2001 From: Jasha Sommer-Simpson <8935917+Jasha10@users.noreply.github.com> Date: Sat, 24 Jun 2023 18:01:43 -0500 Subject: [PATCH] OmegaConf.resolve: handle custom resolvers returning dict/list (#1093) --- omegaconf/_impl.py | 15 ++++++++++++++- tests/interpolation/test_custom_resolvers.py | 19 +++++++++++++++++++ 2 files changed, 33 insertions(+), 1 deletion(-) diff --git a/omegaconf/_impl.py b/omegaconf/_impl.py index 0c30ef6a2..49be30329 100644 --- a/omegaconf/_impl.py +++ b/omegaconf/_impl.py @@ -2,8 +2,15 @@ from omegaconf import MISSING, Container, DictConfig, ListConfig, Node, ValueNode from omegaconf.errors import ConfigTypeError, InterpolationToMissingValueError +from omegaconf.nodes import InterpolationResultNode -from ._utils import _DEFAULT_MARKER_, _get_value +from ._utils import ( + _DEFAULT_MARKER_, + _ensure_container, + _get_value, + is_primitive_container, + is_structured_config, +) def _resolve_container_value(cfg: Container, key: Any) -> None: @@ -17,6 +24,12 @@ def _resolve_container_value(cfg: Container, key: Any) -> None: else: if isinstance(resolved, Container): _resolve(resolved) + if isinstance(resolved, InterpolationResultNode): + resolved_value = _get_value(resolved) + if is_primitive_container(resolved_value) or is_structured_config( + resolved_value + ): + resolved = _ensure_container(resolved_value) if isinstance(resolved, Container) and isinstance(node, ValueNode): cfg[key] = resolved else: diff --git a/tests/interpolation/test_custom_resolvers.py b/tests/interpolation/test_custom_resolvers.py index 216b4da04..e645d0ace 100644 --- a/tests/interpolation/test_custom_resolvers.py +++ b/tests/interpolation/test_custom_resolvers.py @@ -6,6 +6,7 @@ from omegaconf import OmegaConf, Resolver from omegaconf.nodes import InterpolationResultNode +from tests import User from tests.interpolation import dereference_node @@ -478,3 +479,21 @@ def test_merge_into_resolver_output( cfg = OmegaConf.create({"foo": "${make:}"}) assert OmegaConf.merge(cfg, cfg2) == expected + + +@mark.parametrize( + "primitive_container", + [ + param({"first": 1, "second": 2}, id="dict"), + param(["first", "second"], id="list"), + param(User(name="Bond", age=7), id="user"), + ], +) +def test_resolve_resolver_returning_primitive_container( + restore_resolvers: Any, primitive_container: Any +) -> None: + OmegaConf.register_new_resolver("returns_container", lambda: primitive_container) + cfg = OmegaConf.create({"foo": "${returns_container:}"}) + assert cfg.foo == primitive_container + OmegaConf.resolve(cfg) + assert cfg.foo == primitive_container