Skip to content

Commit

Permalink
typecheck: Update for Python 3.9
Browse files Browse the repository at this point in the history
  • Loading branch information
david-yz-liu committed Jan 5, 2021
1 parent dbfa865 commit 29419ae
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 7 deletions.
6 changes: 4 additions & 2 deletions python_ta/typecheck/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,7 +927,9 @@ def _ann_node_to_type(node: astroid.Name) -> TypeResult:


def _generic_to_annotation(ann_node_type: type, node: NodeNG) -> TypeResult:
if (isinstance(ann_node_type, _GenericAlias) and
is_generic = isinstance(ann_node_type, _GenericAlias) or \
(sys.version_info >= (3, 9) and hasattr(ann_node_type, '__origin__'))
if (is_generic and
ann_node_type is getattr(typing, getattr(ann_node_type, '_name', '') or '', None)):
if ann_node_type == Dict:
ann_type = wrap_container(ann_node_type, Any, Any)
Expand All @@ -936,7 +938,7 @@ def _generic_to_annotation(ann_node_type: type, node: NodeNG) -> TypeResult:
ann_type = wrap_container(ann_node_type, Any)
else:
ann_type = wrap_container(ann_node_type, Any)
elif isinstance(ann_node_type, _GenericAlias):
elif is_generic:
parsed_args = []
for arg in ann_node_type.__args__:
_generic_to_annotation(arg, node) >> parsed_args.append
Expand Down
13 changes: 10 additions & 3 deletions python_ta/typecheck/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import *
from typing import _GenericAlias
import astroid
from astroid.node_classes import NodeNG
from python_ta.utils import _get_name, _gorg


Expand Down Expand Up @@ -173,16 +174,22 @@ def subscript_error_message(node: astroid.Subscript) -> str:

if subscript_gorg is list:
slice_type = _get_name(node.slice.inf_type.getValue())
slice_val = node.slice.value
slice_str = slice_val.as_string() if isinstance(slice_val, NodeNG) else str(slice_val)
return f'You can only access elements of a list using an int. ' \
f'You used {_correct_article(slice_type)}, {node.slice.value.as_string()}.'
f'You used {_correct_article(slice_type)}, {slice_str}.'
elif subscript_gorg is tuple:
slice_type = _get_name(node.slice.inf_type.getValue())
slice_val = node.slice.value
slice_str = slice_val.as_string() if isinstance(slice_val, NodeNG) else str(slice_val)
return f'You can only access elements of a tuple using an int. ' \
f'You used {_correct_article(slice_type)}, {node.slice.value.as_string()}.'
f'You used {_correct_article(slice_type)}, {slice_str}.'
elif subscript_gorg is dict:
slice_type = _get_name(node.slice.inf_type.getValue())
slice_val = node.slice.value
slice_str = slice_val.as_string() if isinstance(slice_val, NodeNG) else str(slice_val)
return f'You tried to access an element of this dictionary using ' \
f'{_correct_article(slice_type)}, {node.slice.value.as_string()}, ' \
f'{_correct_article(slice_type)}, {slice_str}, ' \
f'but the keys are of type {_get_name(subscript_concrete_type.__args__[0])}.'
else:
return f'You make a type annotation with an incorrect subscript.'
Expand Down
2 changes: 2 additions & 0 deletions tests/test_type_inference/test_compare.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@


@given(hs.integers(), hs.lists(hs.tuples(cs.comparator_operator_equality, hs.integers()), min_size=1))
@settings(deadline=None)
def test_compare_equality(left_value, operator_value_tuples):
"""Test type setting of Compare node representing comparators: ''==', '!=', '>=', '<=', 'is'. """
program = f'{repr(left_value)}'
Expand All @@ -19,6 +20,7 @@ def test_compare_equality(left_value, operator_value_tuples):


@given(hs.lists(cs.comparator_operator, min_size=3), cs.numeric_list(min_size=4))
@settings(deadline=None)
def test_compare_inequality(operators, values):
"""Test type setting of Compare node representing comparators: '<', '>'. """
a = list(zip(operators, values))
Expand Down
1 change: 1 addition & 0 deletions tests/test_type_inference/test_function_annotation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import astroid
import sys
from typing import Any, List, Tuple

from python_ta.typecheck.base import TypeFailAnnotationUnify
Expand Down
7 changes: 6 additions & 1 deletion tests/test_type_inference/test_function_def_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from .. import custom_hypothesis_support as cs
from ..custom_hypothesis_support import lookup_type, types_in_callable
import hypothesis.strategies as hs
from typing import Callable, ForwardRef, Type, _GenericAlias
import sys
from typing import Callable, ForwardRef, Type, _GenericAlias, Generic

from python_ta.typecheck.base import _gorg
from python_ta.transforms.type_inference_visitor import TypeFail
Expand Down Expand Up @@ -79,13 +80,17 @@ def test_functiondef_annotated_simple_return(functiondef_node):
# need to do by name because annotations must be name nodes.
if isinstance(expected_type, _GenericAlias):
assert _gorg(expected_type).__name__ == functiondef_node.args.annotations[i].name
elif sys.version_info >= (3, 9) and hasattr(expected_type, '__origin__'):
assert expected_type.__origin__.__name__ == functiondef_node.args.annotations[i].name
else:
assert expected_type.__name__ == functiondef_node.args.annotations[i].name
# test return type
return_node = functiondef_node.body[0].value
expected_rtype = inferer.type_constraints.resolve(functiondef_node.type_environment.lookup_in_env(return_node.name)).getValue()
if isinstance(expected_rtype, _GenericAlias):
assert _gorg(expected_rtype).__name__ == functiondef_node.returns.name
elif sys.version_info >= (3, 9) and hasattr(expected_rtype, '__origin__'):
assert expected_rtype.__origin__.__name__ == functiondef_node.args.annotations[i].name
else:
assert expected_rtype.__name__ == functiondef_node.returns.name

Expand Down
6 changes: 5 additions & 1 deletion tests/test_type_inference/test_literals.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
import astroid
import sys

from hypothesis import assume, given, settings, HealthCheck
from .. import custom_hypothesis_support as cs
from typing import Any, Dict, List, Set, Tuple
import pytest

settings.load_profile("pyta")


@given(cs.subscript_node())
@settings(suppress_health_check=[HealthCheck.too_slow])
def test_index(node):
if sys.version_info >= (3, 9):
pytest.skip('Index node is deprecated in Python 3.9')
module, _ = cs._parse_text(node)
for index_node in module.nodes_of_class(astroid.Index):
assert index_node.inf_type.getValue() == index_node.value.inf_type.getValue()
Expand Down

0 comments on commit 29419ae

Please sign in to comment.