diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6b8d43e44..b496e633d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,7 @@ repos: - id: isort - repo: https://github.com/psf/black - rev: 20.8b1 + rev: 22.1.0 hooks: - id: black language_version: python3.8 @@ -21,4 +21,4 @@ repos: hooks: - id: mypy args: [--strict] - additional_dependencies: ['pytest'] + additional_dependencies: ['pytest', 'types-dataclasses'] diff --git a/omegaconf/__init__.py b/omegaconf/__init__.py index 0be34f497..fe8ec8337 100644 --- a/omegaconf/__init__.py +++ b/omegaconf/__init__.py @@ -15,6 +15,7 @@ EnumNode, FloatNode, IntegerNode, + LiteralNode, StringNode, ValueNode, ) @@ -55,6 +56,7 @@ "BytesNode", "BooleanNode", "EnumNode", + "LiteralNode", "FloatNode", "MISSING", "SI", diff --git a/omegaconf/_utils.py b/omegaconf/_utils.py index 891e72b49..e7956f80a 100644 --- a/omegaconf/_utils.py +++ b/omegaconf/_utils.py @@ -43,6 +43,13 @@ except ImportError: # pragma: no cover attr = None # type: ignore # pragma: no cover +if sys.version_info >= (3, 8): + from typing import Literal # pragma: no cover +else: + from typing_extensions import Literal # pragma: no cover +if sys.version_info < (3, 7): + from typing_extensions import _Literal # type: ignore # pragma: no cover + # Regexprs to match key paths like: a.b, a[b], ..a[c].d, etc. # We begin by matching the head (in these examples: a, a, ..a). @@ -593,6 +600,15 @@ def is_tuple_annotation(type_: Any) -> bool: return origin is tuple # pragma: no cover +def is_literal_annotation(type_: Any) -> bool: + origin = getattr(type_, "__origin__", None) + # For python 3.6 and earllier typing_extensions.Literal does not have an origin attribute, and + # Literal is an instance of an internal _Literal class that we can check against. + if sys.version_info < (3, 7): + return type(type_) is _Literal # pragma: no cover + return origin is Literal # pragma: no cover + + def is_dict_subclass(type_: Any) -> bool: return type_ is not None and isinstance(type_, type) and issubclass(type_, Dict) @@ -829,13 +845,21 @@ def type_str(t: Any, include_module_name: bool = False) -> str: return "Any" if t is ...: return "..." + if ( + isinstance(t, int) + or isinstance(t, str) + or isinstance(t, bytes) + or isinstance(t, Enum) + ): + # only occurs when using typing.Literal after 3.8 + return str(t) # pragma: no cover if sys.version_info < (3, 7, 0): # pragma: no cover # Python 3.6 if hasattr(t, "__name__"): name = str(t.__name__) else: - if t.__origin__ is not None: + if getattr(t, "__origin__", None) is not None: name = type_str(t.__origin__) else: name = str(t) diff --git a/omegaconf/nodes.py b/omegaconf/nodes.py index 9a764f552..a904f4919 100644 --- a/omegaconf/nodes.py +++ b/omegaconf/nodes.py @@ -10,6 +10,7 @@ _is_interpolation, get_type_of, get_value_kind, + is_literal_annotation, is_primitive_container, type_str, ) @@ -438,6 +439,70 @@ def __deepcopy__(self, memo: Dict[int, Any]) -> "EnumNode": return res +class LiteralNode(ValueNode): + def __init__( + self, + literal_type: Any, # cannot Type[Literal] because Literal requires an argument + value: Optional[Union[Enum, str, int, bool]] = None, + key: Any = None, + parent: Optional[Container] = None, + is_optional: bool = True, + flags: Optional[Dict[str, bool]] = None, + ): + if not is_literal_annotation(literal_type): + raise ValidationError( + f"LiteralNode can only operate on Literal annotation ({literal_type})" + ) + self.literal_type = literal_type + if hasattr(self.literal_type, "__args__"): # pragma: no cover + # python 3.7 and above + args = self.literal_type.__args__ + self.fields = list(args) if args is not None else [] + elif hasattr(self.literal_type, "__values__"): # pragma: no cover + # python 3.6 and below + values = self.literal_type.__values__ + self.fields = list(values) if values is not None else [] + else: # pragma: no cover + raise ValidationError( + f"literal_type={literal_type} is a literal but has no __args__ or __values__" + ) + super().__init__( + parent=parent, + value=value, + metadata=Metadata( + key=key, + optional=is_optional, + ref_type=literal_type, + object_type=literal_type, + flags=flags, + ), + ) + + def _validate_and_convert_impl(self, value: Any) -> Any: + return self.validate_and_convert_to_literal( + enum_type=self.literal_type, value=value + ) + + def validate_and_convert_to_literal(self, enum_type: Type[Enum], value: Any) -> Any: + if value not in self.fields: + raise ValidationError( + f"Value $VALUE ($VALUE_TYPE) is not a valid input for {enum_type}" + ) + index = self.fields.index(value) + if not isinstance(value, type(self.fields[index])): + raise ValidationError( + f"Value $VALUE ($VALUE_TYPE) is not a valid input for {enum_type} because " + f"type(value)={type(value)} but the matching literal value's type={type(self.fields[index])}" + ) + + return value + + def __deepcopy__(self, memo: Dict[int, Any]) -> "LiteralNode": + res = LiteralNode(literal_type=self.literal_type) + self._deepcopy_impl(res, memo) + return res + + class InterpolationResultNode(ValueNode): """ Special node type, used to wrap interpolation results. diff --git a/omegaconf/omegaconf.py b/omegaconf/omegaconf.py index 2d928e3eb..ed67fc09d 100644 --- a/omegaconf/omegaconf.py +++ b/omegaconf/omegaconf.py @@ -43,6 +43,7 @@ is_dict_annotation, is_int, is_list_annotation, + is_literal_annotation, is_primitive_container, is_primitive_dict, is_primitive_list, @@ -67,6 +68,7 @@ EnumNode, FloatNode, IntegerNode, + LiteralNode, StringNode, ValueNode, ) @@ -1047,6 +1049,14 @@ def _node_wrap( ) elif type_ == Any or type_ is None: node = AnyNode(value=value, key=key, parent=parent) + elif is_literal_annotation(type_): + node = LiteralNode( + literal_type=type_, + value=value, + key=key, + parent=parent, + is_optional=is_optional, + ) elif issubclass(type_, Enum): node = EnumNode( enum_type=type_, diff --git a/requirements/base.txt b/requirements/base.txt index dab8a132e..660d96ab1 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -2,3 +2,4 @@ antlr4-python3-runtime==4.8 PyYAML>=5.1.0 # Use dataclasses backport for Python 3.6. dataclasses;python_version=='3.6' +typing_extensions;python_version<='3.7' \ No newline at end of file diff --git a/tests/structured_conf/data/attr_classes.py b/tests/structured_conf/data/attr_classes.py index a834e4e02..83c93a41a 100644 --- a/tests/structured_conf/data/attr_classes.py +++ b/tests/structured_conf/data/attr_classes.py @@ -8,7 +8,9 @@ from tests import Color if sys.version_info >= (3, 8): # pragma: no cover - from typing import TypedDict + from typing import Literal, TypedDict +else: + from typing_extensions import Literal # attr is a dependency of pytest which means it's always available when testing with pytest. importorskip("attr") @@ -184,6 +186,48 @@ class EnumConfig: interpolation: Color = II("with_default") +if sys.version_info >= (3, 7): # pragma: no cover + + @attr.s(auto_attribs=True) + class LiteralConfig: + # with default value + with_default: Literal["foo", "bar", True, b"baz", 5, Color.GREEN] = "foo" + + # default is None + null_default: Optional[ + Literal["foo", "bar", True, b"baz", 5, Color.GREEN] + ] = None + + # explicit no default + mandatory_missing: Literal["foo", "bar", True, b"baz", 5, Color.GREEN] = MISSING + + # interpolation, will inherit the type and value of `with_default' + interpolation: Literal["foo", "bar", True, b"baz", 5, Color.GREEN] = II( + "with_default" + ) + +else: # pragma: no cover + # bare literals throw errors for python 3.7+. They're against spec for python 3.6 and earlier, + # but we should test that they fail to validate anyway. + @attr.s(auto_attribs=True) + class LiteralConfig: + # with default value + with_default: Literal["foo", "bar", True, b"baz", 5, Color.GREEN] = "foo" + + # default is None + null_default: Optional[ + Literal["foo", "bar", True, b"baz", 5, Color.GREEN] + ] = None + # explicit no default + mandatory_missing: Literal["foo", "bar", True, b"baz", 5, Color.GREEN] = MISSING + + # interpolation, will inherit the type and value of `with_default' + interpolation: Literal["foo", "bar", True, b"baz", 5, Color.GREEN] = II( + "with_default" + ) + no_args: Optional[Literal] = None # type: ignore + + @attr.s(auto_attribs=True) class ConfigWithList: list1: List[int] = [1, 2, 3] diff --git a/tests/structured_conf/data/dataclasses.py b/tests/structured_conf/data/dataclasses.py index 146a82442..748a95170 100644 --- a/tests/structured_conf/data/dataclasses.py +++ b/tests/structured_conf/data/dataclasses.py @@ -9,7 +9,9 @@ from tests import Color if sys.version_info >= (3, 8): # pragma: no cover - from typing import TypedDict + from typing import Literal, TypedDict +else: + from typing_extensions import Literal # skip test if dataclasses are not available importorskip("dataclasses") @@ -185,6 +187,48 @@ class EnumConfig: interpolation: Color = II("with_default") +if sys.version_info >= (3, 7): # pragma: no cover + + @dataclass + class LiteralConfig: + # with default value + with_default: Literal["foo", "bar", True, b"baz", 5, Color.GREEN] = "foo" + + # default is None + null_default: Optional[ + Literal["foo", "bar", True, b"baz", 5, Color.GREEN] + ] = None + + # explicit no default + mandatory_missing: Literal["foo", "bar", True, b"baz", 5, Color.GREEN] = MISSING + + # interpolation, will inherit the type and value of `with_default' + interpolation: Literal["foo", "bar", True, b"baz", 5, Color.GREEN] = II( + "with_default" + ) + +else: # pragma: no cover + # bare literals throw errors for python 3.7+. They're against spec for python 3.6 and earlier, + # but we should test that they fail to validate anyway. + @dataclass + class LiteralConfig: + # with default value + with_default: Literal["foo", "bar", True, b"baz", 5, Color.GREEN] = "foo" + + # default is None + null_default: Optional[ + Literal["foo", "bar", True, b"baz", 5, Color.GREEN] + ] = None + # explicit no default + mandatory_missing: Literal["foo", "bar", True, b"baz", 5, Color.GREEN] = MISSING + + # interpolation, will inherit the type and value of `with_default' + interpolation: Literal["foo", "bar", True, b"baz", 5, Color.GREEN] = II( + "with_default" + ) + no_args: Optional[Literal] = None # type: ignore + + @dataclass class ConfigWithList: list1: List[int] = field(default_factory=lambda: [1, 2, 3]) diff --git a/tests/structured_conf/test_structured_config.py b/tests/structured_conf/test_structured_config.py index b23fbc619..81a345fb4 100644 --- a/tests/structured_conf/test_structured_config.py +++ b/tests/structured_conf/test_structured_config.py @@ -49,6 +49,11 @@ class EnumConfigAssignments: illegal = ["foo", True, b"RED", False, 4, 1.0] +class LiteralConfigAssignments: + legal = ["foo", "bar", True, b"baz", 5, Color.GREEN] + illegal = ["fuh", False, 4, 1.0, b"feh", Color.RED] + + class IntegersConfigAssignments: legal = [("10", 10), ("-10", -10), 100, 0, 1] illegal = ["foo", 1.0, float("inf"), b"123", float("nan"), Color.BLUE, True] @@ -265,6 +270,7 @@ def validate(cfg: DictConfig) -> None: ("BytesConfig", BytesConfigAssignments, {}), ("StringConfig", StringConfigAssignments, {}), ("EnumConfig", EnumConfigAssignments, {}), + ("LiteralConfig", LiteralConfigAssignments, {}), # Use instance to build config ("BoolConfig", BoolConfigAssignments, {"with_default": False}), ("IntegersConfig", IntegersConfigAssignments, {"with_default": 42}), @@ -272,6 +278,7 @@ def validate(cfg: DictConfig) -> None: ("BytesConfig", BytesConfigAssignments, {"with_default": b"bin"}), ("StringConfig", StringConfigAssignments, {"with_default": "fooooooo"}), ("EnumConfig", EnumConfigAssignments, {"with_default": Color.BLUE}), + ("LiteralConfig", LiteralConfigAssignments, {"with_default": "foo"}), ("AnyTypeConfig", AnyTypeConfigAssignments, {}), ], ) @@ -310,6 +317,10 @@ def validate(input_: Any, expected: Any) -> None: with raises(ValidationError): conf.mandatory_missing = illegal_value + if hasattr(conf, "no_args"): + with raises(ValidationError): + conf.no_args = illegal_value + # Test assignment of legal values for legal_value in assignment_data.legal: expected_data = legal_value diff --git a/tests/test_matrix.py b/tests/test_matrix.py index 0c6924811..875f35ac9 100644 --- a/tests/test_matrix.py +++ b/tests/test_matrix.py @@ -1,5 +1,6 @@ import copy import re +import sys from typing import Any, Optional from pytest import mark, raises, warns @@ -12,6 +13,7 @@ FloatNode, IntegerNode, ListConfig, + LiteralNode, OmegaConf, StringNode, ValidationError, @@ -20,6 +22,11 @@ from omegaconf._utils import _is_optional from tests import Color, Group +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + SKIP = object() @@ -70,6 +77,16 @@ def verify( ), [Color.RED], ), + # LiteralNode + ( + lambda value, is_optional, key=None: LiteralNode( + literal_type=Literal["foo", 5, b"bar", True, Color.GREEN], + value=value, + is_optional=is_optional, + key=key, + ), + ["foo", 5, b"bar", True, Color.GREEN], + ), # DictConfig ( lambda value, is_optional, key=None: DictConfig( @@ -99,6 +116,7 @@ def verify( "IntegerNode", "StringNode", "EnumNode", + "LiteralNode", "DictConfig", "ListConfig", "dataclass", diff --git a/tests/test_nodes.py b/tests/test_nodes.py index 4b14175a1..52cb0fff8 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -1,6 +1,7 @@ import copy import functools import re +import sys from enum import Enum from functools import partial from typing import Any, Dict, Tuple, Type @@ -16,6 +17,7 @@ FloatNode, IntegerNode, ListConfig, + LiteralNode, Node, OmegaConf, StringNode, @@ -30,6 +32,11 @@ from omegaconf.nodes import InterpolationResultNode from tests import Color, Enum1, IllegalType, User +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + # testing valid conversions @mark.parametrize( @@ -156,6 +163,25 @@ def test_valid_inputs(type_: type, input_: Any, output_: Any) -> None: (partial(EnumNode, Color), {"foo": "bar"}), (partial(EnumNode, Color), ListConfig([1, 2])), (partial(EnumNode, Color), DictConfig({"foo": "bar"})), + (partial(LiteralNode, Literal["foo", b"bar", 5, Color.GREEN, True]), "baz"), + (partial(LiteralNode, Literal["foo", b"bar", 5, Color.GREEN, True]), 4), + (partial(LiteralNode, Literal["foo", b"bar", 5, Color.GREEN, True]), Color.RED), + (partial(LiteralNode, Literal["foo", b"bar", 5, Color.GREEN, True]), False), + (partial(LiteralNode, Literal["foo", b"bar", 5, Color.GREEN, True]), b"bez"), + (partial(LiteralNode, Literal["foo", b"bar", 5, Color.GREEN, True]), 1.0), + (partial(LiteralNode, Literal["foo", b"bar", 5, Color.GREEN, True]), [1, 2]), + ( + partial(LiteralNode, Literal["foo", b"bar", 5, Color.GREEN, True]), + {"foo": "bar"}, + ), + ( + partial(LiteralNode, Literal["foo", b"bar", 5, Color.GREEN, True]), + ListConfig([1, 2]), + ), + ( + partial(LiteralNode, Literal["foo", b"bar", 5, Color.GREEN, True]), + DictConfig({"foo": "bar"}), + ), ], ) def test_invalid_inputs(type_: type, input_: Any) -> None: @@ -384,6 +410,13 @@ def test_legal_assignment( type_(value) +def test_literal_node_bad_literal_type() -> None: + with raises(ValidationError): + LiteralNode(literal_type=5, value=5) + with raises(ValidationError): + LiteralNode(literal_type=int, value=5) + + @mark.parametrize( "node,value", [ @@ -449,6 +482,7 @@ class DummyEnum(Enum): (float, float, 3.1415, FloatNode), (bool, bool, True, BooleanNode), (str, str, "foo", StringNode), + (Literal["foo"], Literal["foo"], "foo", LiteralNode), ], ) def test_node_wrap( @@ -552,6 +586,18 @@ def test_deepcopy(obj: Any) -> None: True, ), (EnumNode(enum_type=Enum1, value=Enum1.BAR), Enum1.BAR, True), + (LiteralNode(literal_type=Literal["foo"], value="foo"), "foo", True), + ( + LiteralNode(literal_type=Literal["foo"], value="foo"), + LiteralNode(literal_type=Literal["foo"], value="foo"), + True, + ), + ( + LiteralNode(literal_type=Literal["foo"], value="foo"), + StringNode(value="foo"), + True, + ), + (LiteralNode(literal_type=Literal["foo"], value="foo"), 5, False), (InterpolationResultNode("foo"), "foo", True), (InterpolationResultNode("${foo}"), "${foo}", True), (InterpolationResultNode("${foo"), "${foo", True), @@ -649,6 +695,9 @@ def test_dereference_missing() -> None: lambda val, is_optional: EnumNode( enum_type=Color, value=val, is_optional=is_optional ), + lambda val, is_optional: LiteralNode( + literal_type=Literal["foo"], value=val, is_optional=is_optional + ), ], ) def test_validate_and_convert_none(make_func: Any) -> None: diff --git a/tests/test_omegaconf.py b/tests/test_omegaconf.py index 4ddd66ff7..0e6784b08 100644 --- a/tests/test_omegaconf.py +++ b/tests/test_omegaconf.py @@ -1,4 +1,5 @@ import re +import sys from typing import Any from pytest import mark, param, raises, warns @@ -23,8 +24,14 @@ InterpolationToMissingValueError, UnsupportedInterpolationType, ) +from omegaconf.nodes import LiteralNode from tests import Color, ConcretePlugin, IllegalType, StructuredWithMissing +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + @mark.parametrize( "cfg, key, expected_is_missing, expectation", @@ -248,6 +255,13 @@ def test_coverage_for_deprecated_OmegaConf_is_optional() -> None: is_optional=True, ) ), + ( + lambda none: LiteralNode( + literal_type=Literal["foo"], + value="foo" if not none else None, + is_optional=True, + ) + ), ( lambda none: ListConfig( content=[1, 2, 3] if not none else None, is_optional=True @@ -335,6 +349,13 @@ def test_is_none_invalid_node() -> None: is_optional=True, ) ), + ( + lambda inter: LiteralNode( + literal_type=Literal["foo"], + value="foo" if inter is None else inter, + is_optional=True, + ) + ), ( lambda inter: ListConfig( content=[1, 2, 3] if inter is None else inter, is_optional=True @@ -360,6 +381,7 @@ def test_is_none_invalid_node() -> None: "BooleanNode", "BytesNode", "EnumNode", + "LiteralNode", "ListConfig", "DictConfig", "ConcretePlugin", diff --git a/tests/test_pydev_resolver_plugin.py b/tests/test_pydev_resolver_plugin.py index e0e7280dc..b24e2a606 100644 --- a/tests/test_pydev_resolver_plugin.py +++ b/tests/test_pydev_resolver_plugin.py @@ -1,4 +1,5 @@ import builtins +import sys from typing import Any from pytest import fixture, mark, param @@ -19,12 +20,18 @@ ValueNode, ) from omegaconf._utils import type_str +from omegaconf.nodes import LiteralNode from pydevd_plugins.extensions.pydevd_plugin_omegaconf import ( OmegaConfDeveloperResolver, OmegaConfUserResolver, ) from tests import Color +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + @fixture def resolver() -> Any: @@ -42,6 +49,7 @@ def resolver() -> Any: param(BooleanNode(True), {}, id="bool:True"), param(BytesNode(b"binary"), {}, id="bytes:binary"), param(EnumNode(enum_type=Color, value=Color.RED), {}, id="Color:Color.RED"), + param(LiteralNode(literal_type=Literal["foo"], value="foo"), {}, id="str:foo"), # nodes are never returning a dictionary param(AnyNode("${foo}", parent=DictConfig({"foo": 10})), {}, id="any:inter_10"), # DictConfig @@ -234,6 +242,7 @@ def test_get_dictionary_listconfig( (BooleanNode, True), (BytesNode, True), (EnumNode, True), + (LiteralNode, True), # not covering some other things. (builtins.int, False), (dict, False), diff --git a/tests/test_utils.py b/tests/test_utils.py index be266f3e5..dc8fe4d9b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,4 +1,5 @@ import re +import sys from dataclasses import dataclass, field from enum import Enum from typing import Any, Dict, List, Optional, Tuple, Union @@ -28,11 +29,17 @@ EnumNode, FloatNode, IntegerNode, + LiteralNode, StringNode, ) from omegaconf.omegaconf import _node_wrap from tests import Color, ConcretePlugin, Dataframe, IllegalType, Plugin, User +if sys.version_info >= (3, 8): + from typing import Literal +else: + from typing_extensions import Literal + @mark.parametrize( "target_type, value, expected", @@ -100,9 +107,30 @@ param( Color, "Color.RED", EnumNode(enum_type=Color, value=Color.RED), id="Color" ), + param( + Literal["foo"], + "foo", + LiteralNode(literal_type=Literal["foo"], value="foo"), + id='Literal["foo"]', + ), + param( + Literal["foo"], + 5, + ValidationError, + id='Literal["foo"]', + ), param(Color, b"123", ValidationError, id="Color"), param( - Color, "Color.RED", EnumNode(enum_type=Color, value=Color.RED), id="Color" + Literal["foo"], + "foo", + LiteralNode(literal_type=Literal["foo"], value="foo"), + id='Literal["foo"]', + ), + param( + Literal["foo"], + 5, + ValidationError, + id='Literal["foo"]', ), # bad type param(IllegalType, "nope", ValidationError, id="bad_type"), @@ -564,6 +592,11 @@ def test_is_tuple_annotation(type_: Any, expected: Any) -> Any: Optional[Color], id="EnumNode[Color]", ), + param( + LiteralNode(literal_type=Literal["foo"], value="foo"), + Optional[Literal["foo"]], + id="Literal[foo]", + ), # Non-optional value nodes: param(IntegerNode(10, is_optional=False), int, id="IntegerNode"), param(FloatNode(10.0, is_optional=False), float, id="FloatNode"), @@ -575,6 +608,11 @@ def test_is_tuple_annotation(type_: Any, expected: Any) -> Any: Color, id="EnumNode[Color]", ), + param( + LiteralNode(literal_type=Literal["foo"], value="foo", is_optional=False), + Literal["foo"], + id="Literal[foo]", + ), # DictConfig param(DictConfig(content={}), Any, id="DictConfig"), param(