Skip to content

Commit

Permalink
changing default reftype to Any
Browse files Browse the repository at this point in the history
This effects assignment to fields that are Structured Config:
before such assignment was only allowed for objects of compatible type.
Now anything can be assigned to such fields.
  • Loading branch information
omry committed Feb 1, 2021
1 parent 6850ccf commit a08b676
Show file tree
Hide file tree
Showing 15 changed files with 288 additions and 260 deletions.
69 changes: 14 additions & 55 deletions omegaconf/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,57 +527,23 @@ def _get_value(value: Any) -> Any:


def get_ref_type(obj: Any, key: Any = None) -> Optional[Type[Any]]:
from omegaconf import DictConfig, ListConfig
from omegaconf.base import Container, Node
from omegaconf.nodes import ValueNode
from omegaconf import Container, Node

def none_as_any(t: Optional[Type[Any]]) -> Union[Type[Any], Any]:
if t is None:
return Any
else:
return t

if isinstance(obj, Container) and key is not None:
obj = obj._get_node(key)
if isinstance(obj, Container):
if key is not None:
obj = obj._get_node(key)
else:
if key is not None:
raise ValueError("Key must only be provided when obj is a container")

is_optional = True
ref_type = None
if isinstance(obj, ValueNode):
is_optional = obj._is_optional()
if isinstance(obj, Node):
ref_type = obj._metadata.ref_type
elif isinstance(obj, Container):
if isinstance(obj, Node):
ref_type = obj._metadata.ref_type
is_optional = obj._is_optional()
kt = none_as_any(obj._metadata.key_type)
vt = none_as_any(obj._metadata.element_type)
if (
ref_type is Any
and kt is Any
and vt is Any
and not obj._is_missing()
and not obj._is_none()
):
ref_type = Any # type: ignore
elif not is_structured_config(ref_type):
if kt is Any:
kt = Union[str, Enum]
if isinstance(obj, DictConfig):
ref_type = Dict[kt, vt] # type: ignore
elif isinstance(obj, ListConfig):
ref_type = List[vt] # type: ignore
else:
if isinstance(obj, dict):
ref_type = Dict[Union[str, Enum], Any]
elif isinstance(obj, (list, tuple)):
ref_type = List[Any]
if obj._is_optional() and ref_type is not Any:
return Optional[ref_type] # type: ignore
else:
ref_type = get_type_of(obj)

ref_type = none_as_any(ref_type)
if is_optional and ref_type is not Any:
ref_type = Optional[ref_type] # type: ignore
return ref_type
return ref_type
else:
return Any # type: ignore


def _raise(ex: Exception, cause: Exception) -> None:
Expand All @@ -590,7 +556,7 @@ def _raise(ex: Exception, cause: Exception) -> None:
ex.__cause__ = cause
else:
ex.__cause__ = None
raise ex # set end OC_CAUSE=1 for full backtrace
raise ex.with_traceback(sys.exc_info()[2]) # set end OC_CAUSE=1 for full backtrace


def format_and_raise(
Expand All @@ -604,9 +570,6 @@ def format_and_raise(
from omegaconf import OmegaConf
from omegaconf.base import Node

# Uncomment to make debugging easier. Note that this will cause some tests to fail
# raise cause

if isinstance(cause, AssertionError):
raise

Expand Down Expand Up @@ -766,7 +729,3 @@ def is_generic_dict(type_: Any) -> bool:

def is_container_annotation(type_: Any) -> bool:
return is_list_annotation(type_) or is_dict_annotation(type_)


def is_generic_container(type_: Any) -> bool:
return is_generic_dict(type_) or is_generic_list(type_)
8 changes: 5 additions & 3 deletions omegaconf/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@
@dataclass
class Metadata:

ref_type: Optional[Type[Any]]
ref_type: Union[Type[Any], Any]

object_type: Optional[Type[Any]]
object_type: Union[Type[Any], Any]

optional: bool

Expand All @@ -47,6 +47,8 @@ class ContainerMetadata(Metadata):
element_type: Any = None

def __post_init__(self) -> None:
if self.ref_type is None:
self.ref_type = Any
assert self.key_type is Any or isinstance(self.key_type, type)
if self.element_type is not None:
assert self.element_type is Any or isinstance(self.element_type, type)
Expand Down Expand Up @@ -438,7 +440,7 @@ def _resolve_simple_interpolation(
value=value,
parent=self,
metadata=Metadata(
ref_type=None, object_type=None, key=key, optional=True
ref_type=Any, object_type=Any, key=key, optional=True
),
)
except Exception as e:
Expand Down
44 changes: 22 additions & 22 deletions omegaconf/basecontainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,6 @@ class BaseContainer(Container, ABC):
def __init__(self, parent: Optional["Container"], metadata: ContainerMetadata):
super().__init__(parent=parent, metadata=metadata)
self.__dict__["_content"] = None
self._normalize_ref_type()

def _normalize_ref_type(self) -> None:
if self._metadata.ref_type is None:
self._metadata.ref_type = Any # type: ignore

def _resolve_with_default(
self,
Expand Down Expand Up @@ -300,18 +295,24 @@ def _map_merge(dest: "BaseContainer", src: "BaseContainer") -> None:
_update_types(node=dest, ref_type=src_ref_type, object_type=src_type)
return

dest._validate_merge(key=None, value=src)
dest._validate_merge(value=src)

def expand(node: Container) -> None:
type_ = get_ref_type(node)
if type_ is not None:
_is_optional, type_ = _resolve_optional(type_)
if is_dict_annotation(type_):
node._set_value({})
elif is_list_annotation(type_):
node._set_value([])
rt = node._metadata.ref_type
val: Any
if rt is not Any:
if is_dict_annotation(rt):
val = {}
elif is_list_annotation(rt):
val = []
else:
node._set_value(type_)
val = rt
elif isinstance(node, DictConfig):
val = {}
else:
assert False

node._set_value(val)

if (
src._is_missing()
Expand All @@ -330,6 +331,10 @@ def expand(node: Container) -> None:
for key, src_value in src.items_ex(resolve=False):
src_node = src._get_node(key, validate_access=False)
dest_node = dest._get_node(key, validate_access=False)

if isinstance(dest_node, DictConfig):
dest_node._validate_merge(value=src_node)

missing_src_value = _is_missing_value(src_value)

if (
Expand Down Expand Up @@ -360,7 +365,6 @@ def expand(node: Container) -> None:
if dest_node is not None:
if isinstance(dest_node, BaseContainer):
if isinstance(src_value, BaseContainer):
dest._validate_merge(key=key, value=src_value)
dest_node._merge_with(src_value)
elif not missing_src_value:
dest.__setitem__(key, src_value)
Expand Down Expand Up @@ -392,7 +396,7 @@ def expand(node: Container) -> None:
from omegaconf import open_dict

if is_structured_config(src_type):
# verified to be compatible above in _validate_set_merge_impl
# verified to be compatible above in _validate_merge
with open_dict(dest):
dest[key] = src._get_node(key)
else:
Expand Down Expand Up @@ -489,7 +493,7 @@ def _set_item_impl(self, key: Any, value: Any) -> None:
Changes the value of the node key with the desired value. If the node key doesn't
exist it creates a new one.
"""
from omegaconf.omegaconf import OmegaConf, _maybe_wrap
from omegaconf.omegaconf import _maybe_wrap

from .nodes import AnyNode, ValueNode

Expand Down Expand Up @@ -552,11 +556,7 @@ def wrap(key: Any, val: Any) -> Node:
target = self._get_node(key)
if target is None:
if is_structured_config(val):
element_type = self._metadata.element_type
if element_type is Any:
ref_type = OmegaConf.get_type(val)
else:
ref_type = element_type
ref_type = self._metadata.element_type
else:
is_optional = target._is_optional()
ref_type = target._metadata.ref_type
Expand Down
43 changes: 23 additions & 20 deletions omegaconf/dictconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from ._utils import (
ValueKind,
_get_value,
_is_interpolation,
_valid_dict_key_annotation_type,
format_and_raise,
Expand All @@ -25,7 +26,6 @@
get_value_kind,
is_container_annotation,
is_dict,
is_generic_container,
is_primitive_dict,
is_structured_config,
is_structured_config_frozen,
Expand Down Expand Up @@ -201,34 +201,37 @@ def _validate_set(self, key: Any, value: Any) -> None:
if validation_error:
self._raise_invalid_value(value, value_type, target_type)

def _validate_merge(self, key: Any, value: Any) -> None:
def _validate_merge(self, value: Any) -> None:
from omegaconf import OmegaConf

self._validate_non_optional(key, value)
dest = self
src = value

target = self._get_node(key) if key is not None else self
self._validate_non_optional(None, src)

target_has_ref_type = isinstance(
target, DictConfig
) and target._metadata.ref_type not in (Any, dict)
is_valid_value = target is None or not target_has_ref_type
if is_valid_value:
return
dest_obj_type = OmegaConf.get_type(dest)
src_obj_type = OmegaConf.get_type(src)

target_type = target._metadata.ref_type # type: ignore
value_type = OmegaConf.get_type(value)
if is_generic_container(target_type):
if dest._is_missing() and src._metadata.object_type is not None:
self._validate_set(key=None, value=_get_value(src))

if src._is_missing():
return
# Merging of a dictionary is allowed even if assignment is illegal (merge would do deeper checks)

validation_error = (
target_type is not None
and value_type is not None
and not issubclass(value_type, target_type)
and not is_dict(value_type)
dest_obj_type is not None
and src_obj_type is not None
and is_structured_config(dest_obj_type)
and not OmegaConf.is_none(src)
and not is_dict(src_obj_type)
and not issubclass(src_obj_type, dest_obj_type)
)

if validation_error:
self._raise_invalid_value(value, value_type, target_type)
msg = (
f"Merge error : {type_str(src_obj_type)} is not a "
f"subclass of {type_str(dest_obj_type)}. value: {src}"
)
raise ValidationError(msg)

def _validate_non_optional(self, key: Any, value: Any) -> None:
from omegaconf import OmegaConf
Expand Down
2 changes: 1 addition & 1 deletion omegaconf/listconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(
content: Union[List[Any], Tuple[Any, ...], str, None],
key: Any = None,
parent: Optional[Container] = None,
element_type: Optional[Type[Any]] = None,
element_type: Union[Type[Any], Any] = Any,
is_optional: bool = True,
ref_type: Union[Type[Any], Any] = Any,
flags: Optional[Dict[str, bool]] = None,
Expand Down
2 changes: 1 addition & 1 deletion omegaconf/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def __init__(
parent=parent,
value=value,
metadata=Metadata(
ref_type=Any, object_type=None, key=key, optional=is_optional # type: ignore
ref_type=Any, object_type=None, key=key, optional=is_optional
),
)

Expand Down
24 changes: 8 additions & 16 deletions omegaconf/omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,34 +231,26 @@ def _create_impl( # noqa F811
or is_structured_config(obj)
or obj is None
):
ref_type = None
if is_structured_config(obj):
ref_type = get_type_of(obj)
elif OmegaConf.is_dict(obj):
ref_type = obj._metadata.ref_type

if ref_type is None:
ref_type = OmegaConf.get_type(obj)

if isinstance(obj, DictConfig):
key_type = obj._metadata.key_type
element_type = obj._metadata.element_type
else:
key_type, element_type = get_dict_key_value_types(ref_type)
obj_type = OmegaConf.get_type(obj)
key_type, element_type = get_dict_key_value_types(obj_type)
return DictConfig(
content=obj,
parent=parent,
ref_type=ref_type,
ref_type=Any,
key_type=key_type,
element_type=element_type,
flags=flags,
)
elif is_primitive_list(obj) or OmegaConf.is_list(obj):
ref_type = OmegaConf.get_type(obj)
element_type = get_list_element_type(ref_type)
obj_type = OmegaConf.get_type(obj)
element_type = get_list_element_type(obj_type)
return ListConfig(
element_type=element_type,
ref_type=ref_type,
ref_type=Any,
content=obj,
parent=parent,
flags=flags,
Expand Down Expand Up @@ -805,7 +797,7 @@ def _node_wrap(
is_optional: bool,
value: Any,
key: Any,
ref_type: Any = None,
ref_type: Any = Any,
) -> Node:
node: Node
is_dict = is_primitive_dict(value) or is_dict_annotation(type_)
Expand Down Expand Up @@ -836,7 +828,7 @@ def _node_wrap(
ref_type=ref_type,
)
elif is_structured_config(type_) or is_structured_config(value):
key_type, element_type = get_dict_key_value_types(type_)
key_type, element_type = get_dict_key_value_types(value)
node = DictConfig(
ref_type=type_,
is_optional=is_optional,
Expand Down
Loading

0 comments on commit a08b676

Please sign in to comment.