From a08b676de44721c82355a547193494b80f350789 Mon Sep 17 00:00:00 2001 From: Omry Yadan Date: Fri, 29 Jan 2021 15:23:35 -0800 Subject: [PATCH] changing default reftype to Any This effects assignment to fields that are Structured Config: before such assignment was only allowed for objects of compatible type. Now anything can be assigned to such fields. --- omegaconf/_utils.py | 69 ++------ omegaconf/base.py | 8 +- omegaconf/basecontainer.py | 44 ++--- omegaconf/dictconfig.py | 43 ++--- omegaconf/listconfig.py | 2 +- omegaconf/nodes.py | 2 +- omegaconf/omegaconf.py | 24 +-- .../structured_conf/test_structured_basic.py | 21 ++- .../structured_conf/test_structured_config.py | 18 +- tests/test_basic_ops_dict.py | 156 ++++++++++++------ tests/test_create.py | 7 +- tests/test_errors.py | 19 +-- tests/test_merge.py | 53 ++++-- tests/test_serialization.py | 17 +- tests/test_utils.py | 65 +++++--- 15 files changed, 288 insertions(+), 260 deletions(-) diff --git a/omegaconf/_utils.py b/omegaconf/_utils.py index e5b3adea7..2aa0b3e09 100644 --- a/omegaconf/_utils.py +++ b/omegaconf/_utils.py @@ -527,57 +527,23 @@ def _get_value(value: Any) -> Any: def get_ref_type(obj: Any, key: Any = None) -> Optional[Type[Any]]: - from omegaconf import DictConfig, ListConfig - from omegaconf.base import Container, Node - from omegaconf.nodes import ValueNode + from omegaconf import Container, Node - def none_as_any(t: Optional[Type[Any]]) -> Union[Type[Any], Any]: - if t is None: - return Any - else: - return t - - if isinstance(obj, Container) and key is not None: - obj = obj._get_node(key) + if isinstance(obj, Container): + if key is not None: + obj = obj._get_node(key) + else: + if key is not None: + raise ValueError("Key must only be provided when obj is a container") - is_optional = True - ref_type = None - if isinstance(obj, ValueNode): - is_optional = obj._is_optional() + if isinstance(obj, Node): ref_type = obj._metadata.ref_type - elif isinstance(obj, Container): - if isinstance(obj, Node): - ref_type = obj._metadata.ref_type - is_optional = obj._is_optional() - kt = none_as_any(obj._metadata.key_type) - vt = none_as_any(obj._metadata.element_type) - if ( - ref_type is Any - and kt is Any - and vt is Any - and not obj._is_missing() - and not obj._is_none() - ): - ref_type = Any # type: ignore - elif not is_structured_config(ref_type): - if kt is Any: - kt = Union[str, Enum] - if isinstance(obj, DictConfig): - ref_type = Dict[kt, vt] # type: ignore - elif isinstance(obj, ListConfig): - ref_type = List[vt] # type: ignore - else: - if isinstance(obj, dict): - ref_type = Dict[Union[str, Enum], Any] - elif isinstance(obj, (list, tuple)): - ref_type = List[Any] + if obj._is_optional() and ref_type is not Any: + return Optional[ref_type] # type: ignore else: - ref_type = get_type_of(obj) - - ref_type = none_as_any(ref_type) - if is_optional and ref_type is not Any: - ref_type = Optional[ref_type] # type: ignore - return ref_type + return ref_type + else: + return Any # type: ignore def _raise(ex: Exception, cause: Exception) -> None: @@ -590,7 +556,7 @@ def _raise(ex: Exception, cause: Exception) -> None: ex.__cause__ = cause else: ex.__cause__ = None - raise ex # set end OC_CAUSE=1 for full backtrace + raise ex.with_traceback(sys.exc_info()[2]) # set end OC_CAUSE=1 for full backtrace def format_and_raise( @@ -604,9 +570,6 @@ def format_and_raise( from omegaconf import OmegaConf from omegaconf.base import Node - # Uncomment to make debugging easier. Note that this will cause some tests to fail - # raise cause - if isinstance(cause, AssertionError): raise @@ -766,7 +729,3 @@ def is_generic_dict(type_: Any) -> bool: def is_container_annotation(type_: Any) -> bool: return is_list_annotation(type_) or is_dict_annotation(type_) - - -def is_generic_container(type_: Any) -> bool: - return is_generic_dict(type_) or is_generic_list(type_) diff --git a/omegaconf/base.py b/omegaconf/base.py index 1092cb6e1..f58cd0d5a 100644 --- a/omegaconf/base.py +++ b/omegaconf/base.py @@ -21,9 +21,9 @@ @dataclass class Metadata: - ref_type: Optional[Type[Any]] + ref_type: Union[Type[Any], Any] - object_type: Optional[Type[Any]] + object_type: Union[Type[Any], Any] optional: bool @@ -47,6 +47,8 @@ class ContainerMetadata(Metadata): element_type: Any = None def __post_init__(self) -> None: + if self.ref_type is None: + self.ref_type = Any assert self.key_type is Any or isinstance(self.key_type, type) if self.element_type is not None: assert self.element_type is Any or isinstance(self.element_type, type) @@ -438,7 +440,7 @@ def _resolve_simple_interpolation( value=value, parent=self, metadata=Metadata( - ref_type=None, object_type=None, key=key, optional=True + ref_type=Any, object_type=Any, key=key, optional=True ), ) except Exception as e: diff --git a/omegaconf/basecontainer.py b/omegaconf/basecontainer.py index 0767533c3..ec8a0b4f8 100644 --- a/omegaconf/basecontainer.py +++ b/omegaconf/basecontainer.py @@ -41,11 +41,6 @@ class BaseContainer(Container, ABC): def __init__(self, parent: Optional["Container"], metadata: ContainerMetadata): super().__init__(parent=parent, metadata=metadata) self.__dict__["_content"] = None - self._normalize_ref_type() - - def _normalize_ref_type(self) -> None: - if self._metadata.ref_type is None: - self._metadata.ref_type = Any # type: ignore def _resolve_with_default( self, @@ -300,18 +295,24 @@ def _map_merge(dest: "BaseContainer", src: "BaseContainer") -> None: _update_types(node=dest, ref_type=src_ref_type, object_type=src_type) return - dest._validate_merge(key=None, value=src) + dest._validate_merge(value=src) def expand(node: Container) -> None: - type_ = get_ref_type(node) - if type_ is not None: - _is_optional, type_ = _resolve_optional(type_) - if is_dict_annotation(type_): - node._set_value({}) - elif is_list_annotation(type_): - node._set_value([]) + rt = node._metadata.ref_type + val: Any + if rt is not Any: + if is_dict_annotation(rt): + val = {} + elif is_list_annotation(rt): + val = [] else: - node._set_value(type_) + val = rt + elif isinstance(node, DictConfig): + val = {} + else: + assert False + + node._set_value(val) if ( src._is_missing() @@ -330,6 +331,10 @@ def expand(node: Container) -> None: for key, src_value in src.items_ex(resolve=False): src_node = src._get_node(key, validate_access=False) dest_node = dest._get_node(key, validate_access=False) + + if isinstance(dest_node, DictConfig): + dest_node._validate_merge(value=src_node) + missing_src_value = _is_missing_value(src_value) if ( @@ -360,7 +365,6 @@ def expand(node: Container) -> None: if dest_node is not None: if isinstance(dest_node, BaseContainer): if isinstance(src_value, BaseContainer): - dest._validate_merge(key=key, value=src_value) dest_node._merge_with(src_value) elif not missing_src_value: dest.__setitem__(key, src_value) @@ -392,7 +396,7 @@ def expand(node: Container) -> None: from omegaconf import open_dict if is_structured_config(src_type): - # verified to be compatible above in _validate_set_merge_impl + # verified to be compatible above in _validate_merge with open_dict(dest): dest[key] = src._get_node(key) else: @@ -489,7 +493,7 @@ def _set_item_impl(self, key: Any, value: Any) -> None: Changes the value of the node key with the desired value. If the node key doesn't exist it creates a new one. """ - from omegaconf.omegaconf import OmegaConf, _maybe_wrap + from omegaconf.omegaconf import _maybe_wrap from .nodes import AnyNode, ValueNode @@ -552,11 +556,7 @@ def wrap(key: Any, val: Any) -> Node: target = self._get_node(key) if target is None: if is_structured_config(val): - element_type = self._metadata.element_type - if element_type is Any: - ref_type = OmegaConf.get_type(val) - else: - ref_type = element_type + ref_type = self._metadata.element_type else: is_optional = target._is_optional() ref_type = target._metadata.ref_type diff --git a/omegaconf/dictconfig.py b/omegaconf/dictconfig.py index b5b5834ff..9dd41f643 100644 --- a/omegaconf/dictconfig.py +++ b/omegaconf/dictconfig.py @@ -17,6 +17,7 @@ from ._utils import ( ValueKind, + _get_value, _is_interpolation, _valid_dict_key_annotation_type, format_and_raise, @@ -25,7 +26,6 @@ get_value_kind, is_container_annotation, is_dict, - is_generic_container, is_primitive_dict, is_structured_config, is_structured_config_frozen, @@ -201,34 +201,37 @@ def _validate_set(self, key: Any, value: Any) -> None: if validation_error: self._raise_invalid_value(value, value_type, target_type) - def _validate_merge(self, key: Any, value: Any) -> None: + def _validate_merge(self, value: Any) -> None: from omegaconf import OmegaConf - self._validate_non_optional(key, value) + dest = self + src = value - target = self._get_node(key) if key is not None else self + self._validate_non_optional(None, src) - target_has_ref_type = isinstance( - target, DictConfig - ) and target._metadata.ref_type not in (Any, dict) - is_valid_value = target is None or not target_has_ref_type - if is_valid_value: - return + dest_obj_type = OmegaConf.get_type(dest) + src_obj_type = OmegaConf.get_type(src) - target_type = target._metadata.ref_type # type: ignore - value_type = OmegaConf.get_type(value) - if is_generic_container(target_type): + if dest._is_missing() and src._metadata.object_type is not None: + self._validate_set(key=None, value=_get_value(src)) + + if src._is_missing(): return - # Merging of a dictionary is allowed even if assignment is illegal (merge would do deeper checks) + validation_error = ( - target_type is not None - and value_type is not None - and not issubclass(value_type, target_type) - and not is_dict(value_type) + dest_obj_type is not None + and src_obj_type is not None + and is_structured_config(dest_obj_type) + and not OmegaConf.is_none(src) + and not is_dict(src_obj_type) + and not issubclass(src_obj_type, dest_obj_type) ) - if validation_error: - self._raise_invalid_value(value, value_type, target_type) + msg = ( + f"Merge error : {type_str(src_obj_type)} is not a " + f"subclass of {type_str(dest_obj_type)}. value: {src}" + ) + raise ValidationError(msg) def _validate_non_optional(self, key: Any, value: Any) -> None: from omegaconf import OmegaConf diff --git a/omegaconf/listconfig.py b/omegaconf/listconfig.py index 420cc9a45..121d1352e 100644 --- a/omegaconf/listconfig.py +++ b/omegaconf/listconfig.py @@ -47,7 +47,7 @@ def __init__( content: Union[List[Any], Tuple[Any, ...], str, None], key: Any = None, parent: Optional[Container] = None, - element_type: Optional[Type[Any]] = None, + element_type: Union[Type[Any], Any] = Any, is_optional: bool = True, ref_type: Union[Type[Any], Any] = Any, flags: Optional[Dict[str, bool]] = None, diff --git a/omegaconf/nodes.py b/omegaconf/nodes.py index 2b5f14a9e..d0d6bcc21 100644 --- a/omegaconf/nodes.py +++ b/omegaconf/nodes.py @@ -142,7 +142,7 @@ def __init__( parent=parent, value=value, metadata=Metadata( - ref_type=Any, object_type=None, key=key, optional=is_optional # type: ignore + ref_type=Any, object_type=None, key=key, optional=is_optional ), ) diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index 0c0ac6f29..c085a6617 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -231,34 +231,26 @@ def _create_impl( # noqa F811 or is_structured_config(obj) or obj is None ): - ref_type = None - if is_structured_config(obj): - ref_type = get_type_of(obj) - elif OmegaConf.is_dict(obj): - ref_type = obj._metadata.ref_type - - if ref_type is None: - ref_type = OmegaConf.get_type(obj) - if isinstance(obj, DictConfig): key_type = obj._metadata.key_type element_type = obj._metadata.element_type else: - key_type, element_type = get_dict_key_value_types(ref_type) + obj_type = OmegaConf.get_type(obj) + key_type, element_type = get_dict_key_value_types(obj_type) return DictConfig( content=obj, parent=parent, - ref_type=ref_type, + ref_type=Any, key_type=key_type, element_type=element_type, flags=flags, ) elif is_primitive_list(obj) or OmegaConf.is_list(obj): - ref_type = OmegaConf.get_type(obj) - element_type = get_list_element_type(ref_type) + obj_type = OmegaConf.get_type(obj) + element_type = get_list_element_type(obj_type) return ListConfig( element_type=element_type, - ref_type=ref_type, + ref_type=Any, content=obj, parent=parent, flags=flags, @@ -805,7 +797,7 @@ def _node_wrap( is_optional: bool, value: Any, key: Any, - ref_type: Any = None, + ref_type: Any = Any, ) -> Node: node: Node is_dict = is_primitive_dict(value) or is_dict_annotation(type_) @@ -836,7 +828,7 @@ def _node_wrap( ref_type=ref_type, ) elif is_structured_config(type_) or is_structured_config(value): - key_type, element_type = get_dict_key_value_types(type_) + key_type, element_type = get_dict_key_value_types(value) node = DictConfig( ref_type=type_, is_optional=is_optional, diff --git a/tests/structured_conf/test_structured_basic.py b/tests/structured_conf/test_structured_basic.py index 716624f64..440cb180a 100644 --- a/tests/structured_conf/test_structured_basic.py +++ b/tests/structured_conf/test_structured_basic.py @@ -59,7 +59,9 @@ def test_assignment_of_subclass(self, class_type: str) -> None: def test_assignment_of_non_subclass_1(self, class_type: str) -> None: module: Any = import_module(class_type) - cfg = OmegaConf.create({"plugin": module.Plugin}) + cfg = OmegaConf.create( + {"plugin": DictConfig(module.Plugin, ref_type=module.Plugin)} + ) with pytest.raises(ValidationError): cfg.plugin = OmegaConf.structured(module.FaultyPlugin) @@ -120,7 +122,9 @@ def test_none_assignment(self, class_type: str) -> None: class TestFailedAssignmentOrMerges: def test_assignment_of_non_subclass_2(self, class_type: str, rhs: Any) -> None: module: Any = import_module(class_type) - cfg = OmegaConf.create({"plugin": module.Plugin}) + cfg = OmegaConf.create( + {"plugin": DictConfig(module.Plugin, ref_type=module.Plugin)} + ) with pytest.raises(ValidationError): cfg.plugin = rhs @@ -191,6 +195,12 @@ def test_merge_missing_object_onto_typed_dictconfig(self, class_type: str) -> No assert isinstance(c2, DictConfig) assert OmegaConf.is_missing(c2.users, "bob") + def test_merge_into_missing_sc(self, class_type: str) -> None: + module: Any = import_module(class_type) + c1 = OmegaConf.structured(module.PluginHolder) + c2 = OmegaConf.merge(c1, {"plugin": "???"}) + assert c2.plugin == module.Plugin() + def test_merge_missing_key_onto_structured_none(self, class_type: str) -> None: module: Any = import_module(class_type) c1 = OmegaConf.create({"foo": OmegaConf.structured(module.OptionalUser)}) @@ -228,9 +238,9 @@ def test_merge_structured_interpolation_onto_dict(self, class_type: str) -> None src.user_3 = None c2 = OmegaConf.merge(c1, src) assert c2.user_2.name == "bob" - assert get_ref_type(c2, "user_2") == Optional[module.User] + assert get_ref_type(c2, "user_2") == Any assert c2.user_3 is None - assert get_ref_type(c2, "user_3") == Optional[module.User] + assert get_ref_type(c2, "user_3") == Any class TestMissing: def test_missing1(self, class_type: str) -> None: @@ -285,6 +295,9 @@ def test_plugin_merge(self, class_type: str) -> None: assert ret == concrete assert OmegaConf.get_type(ret) == module.ConcretePlugin + def test_plugin_merge_2(self, class_type: str) -> None: + module: Any = import_module(class_type) + plugin = OmegaConf.structured(module.Plugin) more_fields = OmegaConf.structured(module.PluginWithAdditionalField) ret = OmegaConf.merge(plugin, more_fields) assert ret == more_fields diff --git a/tests/structured_conf/test_structured_config.py b/tests/structured_conf/test_structured_config.py index 3d0f3e2a4..6c0109b39 100644 --- a/tests/structured_conf/test_structured_config.py +++ b/tests/structured_conf/test_structured_config.py @@ -1,7 +1,6 @@ import sys -from enum import Enum from importlib import import_module -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional import pytest @@ -221,7 +220,9 @@ def test_assignment_to_nested_structured_config(self, class_type: str) -> None: def test_assignment_to_structured_inside_dict_config(self, class_type: str) -> None: module: Any = import_module(class_type) - conf = OmegaConf.create({"val": module.Nested}) + conf = OmegaConf.create( + {"val": DictConfig(module.Nested, ref_type=module.Nested)} + ) with pytest.raises(ValidationError): conf.val = 10 @@ -850,9 +851,8 @@ def test_recursive_list(self, class_type: str) -> None: def test_create_untyped_dict(self, class_type: str) -> None: module: Any = import_module(class_type) cfg = OmegaConf.structured(module.UntypedDict) - dt = Dict[Union[str, Enum], Any] - assert _utils.get_ref_type(cfg, "dict") == dt - assert _utils.get_ref_type(cfg, "opt_dict") == Optional[dt] + assert _utils.get_ref_type(cfg, "dict") == Dict[Any, Any] + assert _utils.get_ref_type(cfg, "opt_dict") == Optional[Dict[Any, Any]] assert cfg.dict == {"foo": "var"} assert cfg.opt_dict is None @@ -915,13 +915,13 @@ def test_str2str(self, class_type: str) -> None: assert cfg.hello == "world" with pytest.raises(KeyValidationError): - cfg[Color.RED] = "fail" + cfg[Color.RED] def test_str2str_as_sub_node(self, class_type: str) -> None: module: Any = import_module(class_type) cfg = OmegaConf.create({"foo": module.DictSubclass.Str2Str}) assert OmegaConf.get_type(cfg.foo) == module.DictSubclass.Str2Str - assert _utils.get_ref_type(cfg.foo) == Optional[module.DictSubclass.Str2Str] + assert _utils.get_ref_type(cfg.foo) == Any cfg.foo.hello = "world" assert cfg.foo.hello == "world" @@ -955,7 +955,7 @@ def test_int2str_as_sub_node(self, class_type: str) -> None: module: Any = import_module(class_type) cfg = OmegaConf.create({"foo": module.DictSubclass.Int2Str}) assert OmegaConf.get_type(cfg.foo) == module.DictSubclass.Int2Str - assert _utils.get_ref_type(cfg.foo) == Optional[module.DictSubclass.Int2Str] + assert _utils.get_ref_type(cfg.foo) == Any cfg.foo[10] = "ten" assert cfg.foo[10] == "ten" diff --git a/tests/test_basic_ops_dict.py b/tests/test_basic_ops_dict.py index f9beebae9..02d021179 100644 --- a/tests/test_basic_ops_dict.py +++ b/tests/test_basic_ops_dict.py @@ -22,15 +22,7 @@ from omegaconf.basecontainer import BaseContainer from omegaconf.errors import ConfigKeyError, ConfigTypeError, KeyValidationError -from . import ( - ConcretePlugin, - Enum1, - IllegalType, - Plugin, - StructuredWithMissing, - User, - does_not_raise, -) +from . import ConcretePlugin, Enum1, IllegalType, Plugin, StructuredWithMissing, User def test_setattr_deep_value() -> None: @@ -689,11 +681,11 @@ def test_get_ref_type_with_conflict() -> None: ) assert OmegaConf.get_type(cfg.user) == User - assert _utils.get_ref_type(cfg.user) == Optional[User] + assert _utils.get_ref_type(cfg.user) == Any # Interpolation inherits both type and ref type from the target assert OmegaConf.get_type(cfg.inter) == User - assert _utils.get_ref_type(cfg.inter) == Optional[User] + assert _utils.get_ref_type(cfg.inter) == Any def test_is_missing() -> None: @@ -722,61 +714,115 @@ def test_assign_to_reftype_none_or_any(ref_type: Any, assign: Any) -> None: @pytest.mark.parametrize( - "ref_type,values,assign,expectation", + "ref_type,assign", [ - (Plugin, [None, "???", Plugin], None, does_not_raise), - (Plugin, [None, "???", Plugin], Plugin, does_not_raise), - (Plugin, [None, "???", Plugin], Plugin(), does_not_raise), - (Plugin, [None, "???", Plugin], ConcretePlugin, does_not_raise), - (Plugin, [None, "???", Plugin], ConcretePlugin(), does_not_raise), - (Plugin, [None, "???", Plugin], 10, lambda: pytest.raises(ValidationError)), - (ConcretePlugin, [None, "???", ConcretePlugin], None, does_not_raise), - ( - ConcretePlugin, - [None, "???", ConcretePlugin], + (Plugin, None), + (Plugin, Plugin), + (Plugin, Plugin()), + (Plugin, ConcretePlugin), + (Plugin, ConcretePlugin()), + (ConcretePlugin, None), + pytest.param(ConcretePlugin, ConcretePlugin, id="subclass=subclass_obj"), + pytest.param(ConcretePlugin, ConcretePlugin(), id="subclass=subclass_obj"), + ], +) +class TestAssignAndMergeIntoReftypePlugin: + def _test_assign(self, ref_type: Any, value: Any, assign: Any) -> None: + cfg = OmegaConf.create({"foo": DictConfig(ref_type=ref_type, content=value)}) + assert _utils.get_ref_type(cfg, "foo") == Optional[ref_type] + cfg.foo = assign + assert cfg.foo == assign + assert _utils.get_ref_type(cfg, "foo") == Optional[ref_type] + + def _test_merge(self, ref_type: Any, value: Any, assign: Any) -> None: + cfg = OmegaConf.create({"foo": DictConfig(ref_type=ref_type, content=value)}) + cfg2 = OmegaConf.merge(cfg, {"foo": assign}) + assert isinstance(cfg2, DictConfig) + assert cfg2.foo == assign + assert _utils.get_ref_type(cfg2, "foo") == Optional[ref_type] + + def test_assign_to_reftype_plugin1(self, ref_type: Any, assign: Any) -> None: + self._test_assign(ref_type, ref_type, assign) + self._test_assign(ref_type, ref_type(), assign) + + @pytest.mark.parametrize("value", [None, "???"]) + def test_assign_to_reftype_plugin( + self, ref_type: Any, value: Any, assign: Any + ) -> None: + self._test_assign(ref_type, value, assign) + + def test_merge_into_reftype_plugin_(self, ref_type: Any, assign: Any) -> None: + self._test_merge(ref_type, ref_type, assign) + self._test_merge(ref_type, ref_type(), assign) + + @pytest.mark.parametrize("value", [None, "???"]) + def test_merge_into_reftype_plugin( + self, ref_type: Any, value: Any, assign: Any + ) -> None: + self._test_merge(ref_type, value, assign) + + +@pytest.mark.parametrize( + "ref_type,assign,expectation", + [ + pytest.param( Plugin, - lambda: pytest.raises(ValidationError), - ), - ( - ConcretePlugin, - [None, "???", ConcretePlugin], - Plugin(), - lambda: pytest.raises(ValidationError), + 10, + pytest.raises(ValidationError), + id="assign_primitive_to_typed", ), - ( - ConcretePlugin, - [None, "???", ConcretePlugin], + pytest.param( ConcretePlugin, - does_not_raise, + Plugin, + pytest.raises(ValidationError), + id="assign_base_type_to_subclass", ), - ( + pytest.param( ConcretePlugin, - [None, "???", ConcretePlugin], - ConcretePlugin(), - does_not_raise, + Plugin(), + pytest.raises(ValidationError), + id="assign_base_instance_to_subclass", ), ], ) -def test_assign_to_reftype_plugin( - ref_type: Any, values: List[Any], assign: Any, expectation: Any -) -> None: - for value in values: +class TestAssignAndMergeIntoReftypePlugin_Errors: + def _test_assign( + self, ref_type: Any, value: Any, assign: Any, expectation: Any + ) -> None: cfg = OmegaConf.create({"foo": DictConfig(ref_type=ref_type, content=value)}) - with expectation(): - assert _utils.get_ref_type(cfg, "foo") == Optional[ref_type] + with expectation: cfg.foo = assign - assert cfg.foo == assign - # validate assignment does not change ref type. - assert _utils.get_ref_type(cfg, "foo") == Optional[ref_type] - if value is not None: - cfg = OmegaConf.create( - {"foo": DictConfig(ref_type=ref_type, content=value)} - ) - with expectation(): - cfg2 = OmegaConf.merge(cfg, {"foo": assign}) - assert isinstance(cfg2, DictConfig) - assert cfg2.foo == assign + def _test_merge( + self, ref_type: Any, value: Any, assign: Any, expectation: Any + ) -> None: + cfg = OmegaConf.create({"foo": DictConfig(ref_type=ref_type, content=value)}) + with expectation: + OmegaConf.merge(cfg, {"foo": assign}) + + def test_assign_to_reftype_plugin_( + self, ref_type: Any, assign: Any, expectation: Any + ) -> None: + self._test_assign(ref_type, ref_type, assign, expectation) + self._test_assign(ref_type, ref_type(), assign, expectation) + + @pytest.mark.parametrize("value", [None, "???"]) + def test_assign_to_reftype_plugin( + self, ref_type: Any, value: Any, assign: Any, expectation: Any + ) -> None: + self._test_assign(ref_type, value, assign, expectation) + + def test_merge_into_reftype_plugin1( + self, ref_type: Any, assign: Any, expectation: Any + ) -> None: + self._test_merge(ref_type, ref_type, assign, expectation) + self._test_merge(ref_type, ref_type(), assign, expectation) + + @pytest.mark.parametrize("value", [None, "???"]) + def test_merge_into_reftype_plugin( + self, ref_type: Any, value: Any, assign: Any, expectation: Any + ) -> None: + self._test_merge(ref_type, value, assign, expectation) def test_setdefault() -> None: @@ -813,7 +859,7 @@ def test_self_assign_list_value_with_ref_type(c: Any) -> None: assert cfg == c -def test_assign_to_sc_field_without_ref_type(): +def test_assign_to_sc_field_without_ref_type() -> None: cfg = OmegaConf.create({"plugin": ConcretePlugin}) with pytest.raises(ValidationError): cfg.plugin.params.foo = "bar" diff --git a/tests/test_create.py b/tests/test_create.py index a3cf39257..ed1d61de2 100644 --- a/tests/test_create.py +++ b/tests/test_create.py @@ -1,8 +1,7 @@ """Testing for OmegaConf""" import re import sys -from enum import Enum -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional import pytest import yaml @@ -286,11 +285,11 @@ def test_create_untyped_list() -> None: from omegaconf._utils import get_ref_type cfg = ListConfig(ref_type=List, content=[]) - assert get_ref_type(cfg) == Optional[List[Any]] + assert get_ref_type(cfg) == Optional[List] def test_create_untyped_dict() -> None: from omegaconf._utils import get_ref_type cfg = DictConfig(ref_type=Dict, content={}) - assert get_ref_type(cfg) == Optional[Dict[Union[str, Enum], Any]] + assert get_ref_type(cfg) == Optional[Dict] diff --git a/tests/test_errors.py b/tests/test_errors.py index a00cc7da6..8b63aa801 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from enum import Enum from textwrap import dedent -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Dict, List, Optional, Type import pytest @@ -102,7 +102,6 @@ def finalize(self, cfg: Any) -> None: parent_node=lambda cfg: cfg, child_node=lambda cfg: cfg._get_node("num"), object_type=StructuredWithMissing, - ref_type=Optional[StructuredWithMissing], key="num", ), id="structured:update_with_invalid_value", @@ -116,7 +115,6 @@ def finalize(self, cfg: Any) -> None: parent_node=lambda cfg: cfg, child_node=lambda cfg: cfg._get_node("num"), object_type=StructuredWithMissing, - ref_type=Optional[StructuredWithMissing], key="num", ), id="structured:update:none_to_non_optional", @@ -127,7 +125,6 @@ def finalize(self, cfg: Any) -> None: op=lambda cfg: OmegaConf.update(cfg, "a", IllegalType(), merge=True), key="a", exception_type=UnsupportedValueType, - ref_type=Optional[Dict[Union[str, Enum], Any]], msg="Value 'IllegalType' is not a supported primitive type", ), id="dict:update:object_of_illegal_type", @@ -140,7 +137,6 @@ def finalize(self, cfg: Any) -> None: key="foo", child_node=lambda cfg: cfg._get_node("foo"), exception_type=ReadonlyConfigError, - ref_type=Optional[Dict[Union[str, Enum], Any]], msg="Cannot pop from read-only node", ), id="dict,readonly:pop", @@ -175,7 +171,6 @@ def finalize(self, cfg: Any) -> None: msg="Key 'fail' not in 'ConcretePlugin'", key="fail", object_type=ConcretePlugin, - ref_type=Optional[ConcretePlugin], ), id="structured:access_invalid_attribute", ), @@ -231,7 +226,6 @@ def finalize(self, cfg: Any) -> None: msg="Invalid type assigned : int is not a subclass of FoobarParams. value: 20", key="params", object_type=ConcretePlugin, - ref_type=Optional[ConcretePlugin], child_node=lambda cfg: cfg.params, ), id="structured:setattr,invalid_type_assigned_to_structured", @@ -445,7 +439,7 @@ def finalize(self, cfg: Any) -> None: """\ Key foo (str) is incompatible with (int) \tfull_key: foo - \treference_type=Optional[Dict[int, Any]] + \treference_type=Any \tobject_type=dict""" ), key="foo", @@ -546,7 +540,6 @@ def finalize(self, cfg: Any) -> None: exception_type=ValidationError, msg="Value 'fail' could not be converted to Integer", key="baz", - ref_type=Optional[Dict[str, int]], ), id="DictConfig[str,int]:assigned_str_value", ), @@ -645,7 +638,6 @@ def finalize(self, cfg: Any) -> None: key="foo", full_key="[foo]", msg="ListConfig indices must be integers or slices, not str", - ref_type=Optional[List[Any]], ), id="list:get_nox_ex:invalid_index_type", ), @@ -668,7 +660,6 @@ def finalize(self, cfg: Any) -> None: msg="Cannot get_node from a ListConfig object representing None", key=20, full_key="[20]", - ref_type=Optional[List[Any]], ), id="list:get_node_none", ), @@ -877,7 +868,6 @@ def finalize(self, cfg: Any) -> None: exception_type=ValidationError, object_type=None, msg="Non optional ListConfig cannot be constructed from None", - ref_type=List[int], low_level=True, ), id="list:create_not_optional:_set_value(None)", @@ -904,7 +894,6 @@ def finalize(self, cfg: Any) -> None: key=0, full_key="[0]", child_node=lambda cfg: cfg[0], - ref_type=Optional[List[int]], ), id="list,int_elements:assigned_str_element", ), @@ -920,7 +909,6 @@ def finalize(self, cfg: Any) -> None: key=0, full_key="[0]", child_node=lambda cfg: cfg[0], - ref_type=Optional[List[int]], ), id="list,int_elements:assigned_str_element", ), @@ -935,7 +923,6 @@ def finalize(self, cfg: Any) -> None: key=0, full_key="[0]", child_node=lambda cfg: cfg[0], - ref_type=Optional[List[Any]], ), id="list,not_optional:null_assignment", ), @@ -982,7 +969,6 @@ def finalize(self, cfg: Any) -> None: key=1, full_key="[1]", child_node=lambda _cfg: None, - ref_type=Optional[List[Any]], ), id="list:insert_into_missing", ), @@ -995,7 +981,6 @@ def finalize(self, cfg: Any) -> None: msg="Cannot get from a ListConfig object representing None", key=0, full_key="[0]", - ref_type=Optional[List[Any]], ), id="list:get_from_none", ), diff --git a/tests/test_merge.py b/tests/test_merge.py index 68aad7b85..e865b7919 100644 --- a/tests/test_merge.py +++ b/tests/test_merge.py @@ -11,10 +11,10 @@ OmegaConf, ReadonlyConfigError, ValidationError, - nodes, ) from omegaconf._utils import is_structured_config from omegaconf.errors import ConfigKeyError, UnsupportedValueType +from omegaconf.nodes import IntegerNode from . import ( A, @@ -46,18 +46,29 @@ "inputs, expected", [ # dictionaries - ([{}, {"a": 1}], {"a": 1}), - ([{"a": None}, {"b": None}], {"a": None, "b": None}), - ([{"a": 1}, {"b": 2}], {"a": 1, "b": 2}), - ([{"a": {"a1": 1, "a2": 2}}, {"a": {"a1": 2}}], {"a": {"a1": 2, "a2": 2}}), - ([{"a": 1, "b": 2}, {"b": 3}], {"a": 1, "b": 3}), - (({"a": 1, "b": 2}, {"b": {"c": 3}}), {"a": 1, "b": {"c": 3}}), - (({"b": {"c": 1}}, {"b": 1}), {"b": 1}), - (({"list": [1, 2, 3]}, {"list": [4, 5, 6]}), {"list": [4, 5, 6]}), - (({"a": 1}, {"a": nodes.IntegerNode(10)}), {"a": 10}), - (({"a": 1}, {"a": nodes.IntegerNode(10)}), {"a": nodes.IntegerNode(10)}), - (({"a": nodes.IntegerNode(10)}, {"a": 1}), {"a": 1}), - (({"a": nodes.IntegerNode(10)}, {"a": 1}), {"a": nodes.IntegerNode(1)}), + pytest.param([{}, {"a": 1}], {"a": 1}, id="dict"), + pytest.param( + [{"a": None}, {"b": None}], {"a": None, "b": None}, id="dict:none" + ), + pytest.param([{"a": 1}, {"b": 2}], {"a": 1, "b": 2}, id="dict"), + pytest.param( + [ + {"a": {"a1": 1, "a2": 2}}, + {"a": {"a1": 2}}, + ], + {"a": {"a1": 2, "a2": 2}}, + id="dict", + ), + pytest.param([{"a": 1, "b": 2}, {"b": 3}], {"a": 1, "b": 3}, id="dict"), + pytest.param( + ({"a": 1}, {"a": {"b": 3}}), {"a": {"b": 3}}, id="dict:merge_dict_into_int" + ), + pytest.param(({"b": {"c": 1}}, {"b": 1}), {"b": 1}, id="dict:merge_int_dict"), + pytest.param(({"list": [1, 2, 3]}, {"list": [4, 5, 6]}), {"list": [4, 5, 6]}), + pytest.param(({"a": 1}, {"a": IntegerNode(10)}), {"a": 10}), + pytest.param(({"a": 1}, {"a": IntegerNode(10)}), {"a": IntegerNode(10)}), + pytest.param(({"a": IntegerNode(10)}, {"a": 1}), {"a": 1}), + pytest.param(({"a": IntegerNode(10)}, {"a": 1}), {"a": IntegerNode(1)}), pytest.param( ({"a": "???"}, {"a": {}}), {"a": {}}, id="dict_merge_into_missing" ), @@ -217,12 +228,14 @@ {"dict": "???"}, id="merge_missing_dict_into_missing_dict", ), - ([{"user": User}, {"user": Group}], pytest.raises(ValidationError)), - ( - [{"user": DictConfig(ref_type=User, content=User)}, {"user": Group}], + pytest.param( + [{"user": User}, {"user": Group}], pytest.raises(ValidationError), + id="merge_group_onto_user_error", + ), + pytest.param( + [Plugin, ConcretePlugin], ConcretePlugin, id="merge_subclass_on_superclass" ), - ([Plugin, ConcretePlugin], ConcretePlugin), pytest.param( [{"user": "???"}, {"user": Group}], {"user": Group}, @@ -233,7 +246,11 @@ {"admin": None}, id="merge_none_into_existing_node", ), - ([{"user": User()}, {"user": {"foo": "bar"}}], pytest.raises(ConfigKeyError)), + pytest.param( + [{"user": User()}, {"user": {"foo": "bar"}}], + pytest.raises(ConfigKeyError), + id="merge_unknown_key_into_structured_node", + ), # DictConfig with element_type of Structured Config pytest.param( ( diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 8a84d0964..dab9a468f 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -4,10 +4,9 @@ import pathlib import pickle import tempfile -from enum import Enum from pathlib import Path from textwrap import dedent -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Dict, List, Optional, Type import pytest @@ -127,13 +126,9 @@ def test_save_illegal_type() -> None: @pytest.mark.parametrize( - "obj,ref_type", - [ - ({"a": "b"}, Dict[Union[str, Enum], Any]), - ([1, 2, 3], List[Any]), - ], + "obj", [pytest.param({"a": "b"}, id="dict"), pytest.param([1, 2, 3], id="list")] ) -def test_pickle(obj: Any, ref_type: Any) -> None: +def test_pickle(obj: Any) -> None: with tempfile.TemporaryFile() as fp: c = OmegaConf.create(obj) pickle.dump(c, fp) @@ -141,7 +136,7 @@ def test_pickle(obj: Any, ref_type: Any) -> None: fp.seek(0) c1 = pickle.load(fp) assert c == c1 - assert get_ref_type(c1) == Optional[ref_type] + assert get_ref_type(c1) == Any assert c1._metadata.element_type is Any assert c1._metadata.optional is True if isinstance(c, DictConfig): @@ -203,14 +198,14 @@ def test_load_empty_file(tmpdir: str) -> None: [ (UntypedList, "list", Any, Any, False, List[Any]), (UntypedList, "opt_list", Any, Any, True, Optional[List[Any]]), - (UntypedDict, "dict", Any, Any, False, Dict[Union[str, Enum], Any]), + (UntypedDict, "dict", Any, Any, False, Dict[Any, Any]), ( UntypedDict, "opt_dict", Any, Any, True, - Optional[Dict[Union[str, Enum], Any]], + Optional[Dict[Any, Any]], ), (SubscriptedDict, "dict", int, str, False, Dict[str, int]), (SubscriptedList, "list", int, Any, False, List[int]), diff --git a/tests/test_utils.py b/tests/test_utils.py index 59b99c267..7f959c692 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union import attr -from pytest import param, raises, mark +from pytest import mark, param, raises from omegaconf import DictConfig, ListConfig, Node, OmegaConf, _utils from omegaconf._utils import is_dict_annotation, is_list_annotation @@ -18,7 +18,7 @@ ) from omegaconf.omegaconf import _node_wrap -from . import Color, ConcretePlugin, IllegalType, Plugin, does_not_raise +from . import Color, ConcretePlugin, IllegalType, Plugin, User, does_not_raise @mark.parametrize( @@ -440,16 +440,16 @@ def test_is_list_annotation(type_: Any, expected: Any) -> Any: "obj, expected", [ # Unwrapped values - param(10, Optional[int], id="int"), - param(10.0, Optional[float], id="float"), - param(True, Optional[bool], id="bool"), - param("bar", Optional[str], id="str"), - param(None, type(None), id="NoneType"), - param({}, Optional[Dict[Union[str, Enum], Any]], id="dict"), - param([], Optional[List[Any]], id="List[Any]"), - param(tuple(), Optional[List[Any]], id="List[Any]"), - param(ConcretePlugin(), Optional[ConcretePlugin], id="ConcretePlugin"), - param(ConcretePlugin, Optional[ConcretePlugin], id="ConcretePlugin"), + param(10, Any, id="int"), + param(10.0, Any, id="float"), + param(True, Any, id="bool"), + param("bar", Any, id="str"), + param(None, Any, id="NoneType"), + param({}, Any, id="dict"), + param([], Any, id="List[Any]"), + param(tuple(), Any, id="List[Any]"), + param(ConcretePlugin(), Any, id="ConcretePlugin"), + param(ConcretePlugin, Any, id="ConcretePlugin"), # Optional value nodes param(IntegerNode(10), Optional[int], id="IntegerNode"), param(FloatNode(10.0), Optional[float], id="FloatNode"), @@ -474,12 +474,12 @@ def test_is_list_annotation(type_: Any, expected: Any) -> Any: param(DictConfig(content={}), Any, id="DictConfig"), param( DictConfig(key_type=str, element_type=Color, content={}), - Optional[Dict[str, Color]], + Any, id="DictConfig[str,Color]", ), param( DictConfig(key_type=Color, element_type=int, content={}), - Optional[Dict[Color, int]], + Any, id="DictConfig[Color,int]", ), param( @@ -489,12 +489,12 @@ def test_is_list_annotation(type_: Any, expected: Any) -> Any: ), param( DictConfig(content="???"), - Optional[Dict[Union[str, Enum], Any]], + Any, id="DictConfig[Union[str, Enum], Any]_missing", ), param( DictConfig(content="???", element_type=int, key_type=str), - Optional[Dict[str, int]], + Any, id="DictConfig[str, int]_missing", ), param( @@ -514,23 +514,40 @@ def test_is_list_annotation(type_: Any, expected: Any) -> Any: id="Plugin", ), # ListConfig - param(ListConfig([]), Optional[List[Any]], id="ListConfig[Any]"), - param( - ListConfig([], element_type=int), Optional[List[int]], id="ListConfig[int]" - ), - param(ListConfig(content="???"), Optional[List[Any]], id="ListConfig_missing"), + param(ListConfig([]), Any, id="ListConfig[Any]"), + param(ListConfig([], element_type=int), Any, id="ListConfig[int]"), + param(ListConfig(content="???"), Any, id="ListConfig_missing"), param( ListConfig(content="???", element_type=int), - Optional[List[int]], + Any, id="ListConfig[int]_missing", ), - param(ListConfig(content=None), Optional[List[Any]], id="ListConfig_none"), + param(ListConfig(content=None), Any, id="ListConfig_none"), param( ListConfig(content=None, element_type=int), - Optional[List[int]], + Any, id="ListConfig[int]_none", ), ], ) def test_get_ref_type(obj: Any, expected: Any) -> None: assert _utils.get_ref_type(obj) == expected + + +@mark.parametrize( + "obj, key, expected", + [ + param({"foo": 10}, "foo", Any, id="dict"), + param(User, "name", str, id="User.name"), + param(User, "age", int, id="User.age"), + param({"user": User}, "user", Any, id="user"), + ], +) +def test_get_node_ref_type(obj: Any, key: str, expected: Any) -> None: + cfg = OmegaConf.create(obj) + assert _utils.get_ref_type(cfg, key) == expected + + +def test_get_ref_type_error() -> None: + with raises(ValueError): + _utils.get_ref_type(AnyNode(), "foo")