Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adds literal node #865

Open
wants to merge 23 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
3d3ae1d
Adds literal node
jordan-schneider Feb 26, 2022
9338386
Missed a not
jordan-schneider Feb 26, 2022
b78d4dc
Adds additional LiteralNode tests. Imports Literal based on python ve…
jordan-schneider Feb 26, 2022
6d302ab
Adds no cover pragma to typing import branch
jordan-schneider Feb 26, 2022
3c115d7
Adds additional fixes for earlier python versions.
jordan-schneider Feb 26, 2022
1592df4
Used the wrong type
jordan-schneider Feb 26, 2022
3b8a82f
Update omegaconf/_utils.py
Mar 13, 2022
80f9861
Update omegaconf/_utils.py
Mar 13, 2022
0737ec5
Throws validation error when bare Literal annotation is given for pyt…
jordan-schneider Mar 14, 2022
eb4b20c
Merge remote-tracking branch 'upstream/master'
jordan-schneider Mar 14, 2022
a0e2b35
Fixes typo in typing_extensions requirement.
jordan-schneider Mar 14, 2022
f3b3f75
Adds literal node
jordan-schneider Feb 26, 2022
694a032
Missed a not
jordan-schneider Feb 26, 2022
1bd9c93
Adds additional LiteralNode tests. Imports Literal based on python ve…
jordan-schneider Mar 26, 2022
fb32fa5
Adds no cover pragma to typing import branch
jordan-schneider Feb 26, 2022
89414a8
Adds additional fixes for earlier python versions.
jordan-schneider Feb 26, 2022
c86cec5
Used the wrong type
jordan-schneider Feb 26, 2022
59eb792
Update omegaconf/_utils.py
Mar 13, 2022
5d1b801
Update omegaconf/_utils.py
Mar 13, 2022
e2d950f
Throws validation error when bare Literal annotation is given for pyt…
jordan-schneider Mar 14, 2022
4fa9593
Fixes typo in typing_extensions requirement.
jordan-schneider Mar 14, 2022
d321a4a
Merge branch 'master' of github.com:jordan-schneider/omegaconf
jordan-schneider Mar 26, 2022
b69fd38
Formats according to new black version
jordan-schneider Mar 26, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@ repos:
hooks:
- id: mypy
args: [--strict]
additional_dependencies: ['pytest']
additional_dependencies: ['pytest', 'types-dataclasses']
2 changes: 2 additions & 0 deletions omegaconf/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
EnumNode,
FloatNode,
IntegerNode,
LiteralNode,
StringNode,
ValueNode,
)
Expand Down Expand Up @@ -53,6 +54,7 @@
"StringNode",
"BooleanNode",
"EnumNode",
"LiteralNode",
"FloatNode",
"MISSING",
"SI",
Expand Down
29 changes: 28 additions & 1 deletion omegaconf/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,11 @@
except ImportError: # pragma: no cover
attr = None # type: ignore # pragma: no cover

if sys.version_info >= (3, 8):
from typing import Literal # pragma: no cover
else:
from typing_extensions import Literal # pragma: no cover


# Regexprs to match key paths like: a.b, a[b], ..a[c].d, etc.
# We begin by matching the head (in these examples: a, a, ..a).
Expand Down Expand Up @@ -593,6 +598,20 @@ def is_tuple_annotation(type_: Any) -> bool:
return origin is tuple # pragma: no cover


def is_literal_annotation(type_: Any) -> bool:
origin = getattr(type_, "__origin__", None)
if sys.version_info >= (3, 8, 0):
jordan-schneider marked this conversation as resolved.
Show resolved Hide resolved
return origin is Literal # pragma: no cover
else:
return (
origin is Literal
or type_ is Literal
or (
"typing_extensions.Literal" in str(type_) and not isinstance(type_, str)
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, why is this needed? Isn't Literal always from typing_extensions in this else branch?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So what's happening here is that:

python 3.8 and later has typing.Literal which has an __origin__ attribute
python 3.7 typing_extensions.Literal has an __origin__ attribute
python 3.6 and earlier typing_extensions.Literal does not have an __origin__ attribute or any equivalent attribute

AND

A = Literal[5]
A is Literal  # False
isinstance(A, Literal)  # TypeError: isinstance() arg 2 must be a type or tuple of types
issubclass(type(A), Literal)  # TypeError: issubclass() arg 2 must be a class or tuple of classes

so as far as I can tell, for 3.6 and earlier the only way to check for Literalness is using the string. This is gross but I don't know of a better way.

This logic is unnecessarily complicated though. How do you feel about

def is_literal_annotation(type_: Any) -> bool:
    origin = getattr(type_, "__origin__", None)
    return origin is Literal or ("typing_extensions.Literal" in str(type_) and not isinstance(type_, str))

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is an is_literal_type() function here which also work with Python 3.6. The solution to the problem seems very stupid, you have to import a hidden class _Literal:

Python 3.6.13 |Anaconda, Inc.| (default, Jun  4 2021, 14:25:59)
[GCC 7.5.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from typing_extensions import Literal, _Literal
>>> a = Literal[0]
>>> type(a) is Literal
False
>>> type(a) is _Literal
True

I have no idea why.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Before 3.6 typing_extensions creates the internal _Literal class using a typing._FinalTypingBase mixin https://github.com/python/typing/blob/master/typing_extensions/src/typing_extensions.py#L344 , and then Literal is an instantiation of that class. So type(a) returns the outer _Literal class instead of the instantiation or something. Anyway this seems less bad than the string hack, so I'll go with this.

) # pragma: no cover


def is_dict_subclass(type_: Any) -> bool:
return type_ is not None and isinstance(type_, type) and issubclass(type_, Dict)

Expand Down Expand Up @@ -822,13 +841,21 @@ def type_str(t: Any, include_module_name: bool = False) -> str:
return "Any"
if t is ...:
return "..."
if (
isinstance(t, int)
or isinstance(t, str)
or isinstance(t, bytes)
or isinstance(t, Enum)
):
# only occurs when using typing.Literal after 3.8
return str(t) # pragma: no cover

if sys.version_info < (3, 7, 0): # pragma: no cover
# Python 3.6
if hasattr(t, "__name__"):
name = str(t.__name__)
else:
if t.__origin__ is not None:
if hasattr(t, "__origin__") and t.__origin__ is not None:
jordan-schneider marked this conversation as resolved.
Show resolved Hide resolved
name = type_str(t.__origin__)
else:
name = str(t)
Expand Down
57 changes: 57 additions & 0 deletions omegaconf/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
_is_interpolation,
get_type_of,
get_value_kind,
is_literal_annotation,
is_primitive_container,
type_str,
)
Expand Down Expand Up @@ -400,6 +401,62 @@ def __deepcopy__(self, memo: Dict[int, Any]) -> "EnumNode":
return res


class LiteralNode(ValueNode):
def __init__(
self,
literal_type: Any, # cannot Type[Literal] because Literal requires an argument
value: Optional[Union[Enum, str, int, bool]] = None,
key: Any = None,
parent: Optional[Container] = None,
is_optional: bool = True,
flags: Optional[Dict[str, bool]] = None,
):
if not is_literal_annotation(literal_type):
raise ValidationError(
f"LiteralNode can only operate on Literal annotation ({literal_type})"
)
self.literal_type = literal_type
if hasattr(self.literal_type, "__args__"):
self.fields = list(self.literal_type.__args__) # pragma: no cover
elif hasattr(self.literal_type, "__values__"): # pragma: no cover
self.fields = list(self.literal_type.__values__) # pragma: no cover
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe make these branches explicitly correspond to python versions. So I think

Suggested change
if hasattr(self.literal_type, "__args__"):
self.fields = list(self.literal_type.__args__) # pragma: no cover
elif hasattr(self.literal_type, "__values__"): # pragma: no cover
self.fields = list(self.literal_type.__values__) # pragma: no cover
if sys.version_info >= (3, 7):
self.fields = list(self.literal_type.__args__) # pragma: no cover
else:
self.fields = list(self.literal_type.__values__) # pragma: no cover

but I'm not really sure about this.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do you feel this is better? Philosophically I tend to think that directly checking is the right way to do things, in case e.g. the attribute name every switches back or something equally cursed. This also has the benefit that mypy/pylance/other syntax parsers can correctly infer that the attribute exists, which is left implicit if you're going by version.

If this is about readability, I'd be happy to add comments saying why the check is necessary.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, I was just thinking that at some point in the future, omegaconf might want to drop support for Python 3.6 and then it would be useful that you can directly see that this if-else check has become unnecessary.

But, yes, a comment is likely already sufficient.

super().__init__(
parent=parent,
value=value,
metadata=Metadata(
key=key,
optional=is_optional,
ref_type=literal_type,
object_type=literal_type,
flags=flags,
),
)

def _validate_and_convert_impl(self, value: Any) -> Any:
return self.validate_and_convert_to_literal(
enum_type=self.literal_type, value=value
)

def validate_and_convert_to_literal(self, enum_type: Type[Enum], value: Any) -> Any:
if value not in self.fields:
raise ValidationError(
f"Value $VALUE ($VALUE_TYPE) is not a valid input for {enum_type}"
)
index = self.fields.index(value)
if not isinstance(value, type(self.fields[index])):
raise ValidationError(
f"Value $VALUE ($VALUE_TYPE) is not a valid input for {enum_type} because "
f"type(value)={type(value)} but the matching literal value's type={type(self.fields[index])}"
)

return value

def __deepcopy__(self, memo: Dict[int, Any]) -> "LiteralNode":
res = LiteralNode(literal_type=self.literal_type)
self._deepcopy_impl(res, memo)
return res


class InterpolationResultNode(ValueNode):
"""
Special node type, used to wrap interpolation results.
Expand Down
10 changes: 10 additions & 0 deletions omegaconf/omegaconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
is_dict_annotation,
is_int,
is_list_annotation,
is_literal_annotation,
is_primitive_container,
is_primitive_dict,
is_primitive_list,
Expand All @@ -66,6 +67,7 @@
EnumNode,
FloatNode,
IntegerNode,
LiteralNode,
StringNode,
ValueNode,
)
Expand Down Expand Up @@ -1046,6 +1048,14 @@ def _node_wrap(
)
elif type_ == Any or type_ is None:
node = AnyNode(value=value, key=key, parent=parent)
elif is_literal_annotation(type_):
node = LiteralNode(
literal_type=type_,
value=value,
key=key,
parent=parent,
is_optional=is_optional,
)
elif issubclass(type_, Enum):
node = EnumNode(
enum_type=type_,
Expand Down
21 changes: 20 additions & 1 deletion tests/structured_conf/data/attr_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
from tests import Color

if sys.version_info >= (3, 8): # pragma: no cover
from typing import TypedDict
from typing import Literal, TypedDict
else:
from typing_extensions import Literal

# attr is a dependency of pytest which means it's always available when testing with pytest.
importorskip("attr")
Expand Down Expand Up @@ -169,6 +171,23 @@ class EnumConfig:
interpolation: Color = II("with_default")


@attr.s(auto_attribs=True)
class LiteralConfig:
# with default value
with_default: Literal["foo", "bar", True, b"baz", 5, Color.GREEN] = "foo"

# default is None
null_default: Optional[Literal["foo", "bar", True, b"baz", 5, Color.GREEN]] = None

# explicit no default
mandatory_missing: Literal["foo", "bar", True, b"baz", 5, Color.GREEN] = MISSING

# interpolation, will inherit the type and value of `with_default'
interpolation: Literal["foo", "bar", True, b"baz", 5, Color.GREEN] = II(
"with_default"
)


@attr.s(auto_attribs=True)
class ConfigWithList:
list1: List[int] = [1, 2, 3]
Expand Down
21 changes: 20 additions & 1 deletion tests/structured_conf/data/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
from tests import Color

if sys.version_info >= (3, 8): # pragma: no cover
from typing import TypedDict
from typing import Literal, TypedDict
else:
from typing_extensions import Literal

# skip test if dataclasses are not available
importorskip("dataclasses")
Expand Down Expand Up @@ -170,6 +172,23 @@ class EnumConfig:
interpolation: Color = II("with_default")


@dataclass
class LiteralConfig:
# with default value
with_default: Literal["foo", "bar", True, b"baz", 5, Color.GREEN] = "foo"

# default is None
null_default: Optional[Literal["foo", "bar", True, b"baz", 5, Color.GREEN]] = None

# explicit no default
mandatory_missing: Literal["foo", "bar", True, b"baz", 5, Color.GREEN] = MISSING

# interpolation, will inherit the type and value of `with_default'
interpolation: Literal["foo", "bar", True, b"baz", 5, Color.GREEN] = II(
"with_default"
)


@dataclass
class ConfigWithList:
list1: List[int] = field(default_factory=lambda: [1, 2, 3])
Expand Down
7 changes: 7 additions & 0 deletions tests/structured_conf/test_structured_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ class EnumConfigAssignments:
illegal = ["foo", True, False, 4, 1.0]


class LiteralConfigAssignments:
legal = ["foo", "bar", True, b"baz", 5, Color.GREEN]
illegal = ["fuh", False, 4, 1.0, b"feh", Color.RED]


class IntegersConfigAssignments:
legal = [("10", 10), ("-10", -10), 100, 0, 1, 1]
illegal = ["foo", 1.0, float("inf"), float("nan"), Color.BLUE]
Expand Down Expand Up @@ -259,12 +264,14 @@ def validate(cfg: DictConfig) -> None:
("FloatConfig", FloatConfigAssignments, {}),
("StringConfig", StringConfigAssignments, {}),
("EnumConfig", EnumConfigAssignments, {}),
("LiteralConfig", LiteralConfigAssignments, {}),
# Use instance to build config
("BoolConfig", BoolConfigAssignments, {"with_default": False}),
("IntegersConfig", IntegersConfigAssignments, {"with_default": 42}),
("FloatConfig", FloatConfigAssignments, {"with_default": 42.0}),
("StringConfig", StringConfigAssignments, {"with_default": "fooooooo"}),
("EnumConfig", EnumConfigAssignments, {"with_default": Color.BLUE}),
("LiteralConfig", LiteralConfigAssignments, {"with_default": "foo"}),
("AnyTypeConfig", AnyTypeConfigAssignments, {}),
],
)
Expand Down
18 changes: 18 additions & 0 deletions tests/test_matrix.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import re
import sys
from typing import Any, Optional

from pytest import mark, raises, warns
Expand All @@ -11,6 +12,7 @@
FloatNode,
IntegerNode,
ListConfig,
LiteralNode,
OmegaConf,
StringNode,
ValidationError,
Expand All @@ -19,6 +21,11 @@
from omegaconf._utils import _is_optional
from tests import Color, Group

if sys.version_info >= (3, 8):
from typing import Literal
else:
from typing_extensions import Literal

SKIP = object()


Expand Down Expand Up @@ -68,6 +75,16 @@ def verify(
),
[Color.RED],
),
# LiteralNode
(
lambda value, is_optional, key=None: LiteralNode(
literal_type=Literal["foo", 5, b"bar", True, Color.GREEN],
value=value,
is_optional=is_optional,
key=key,
),
["foo", 5, b"bar", True, Color.GREEN],
),
# DictConfig
(
lambda value, is_optional, key=None: DictConfig(
Expand Down Expand Up @@ -96,6 +113,7 @@ def verify(
"IntegerNode",
"StringNode",
"EnumNode",
"LiteralNode",
"DictConfig",
"ListConfig",
"dataclass",
Expand Down
Loading