Skip to content

Commit

Permalink
update ast usage in code for python 3.12 changes (#306)
Browse files Browse the repository at this point in the history
update-typing-to-satisfy-mypy
  • Loading branch information
DamianCzajkowski authored Jul 30, 2024
1 parent a064d22 commit d311712
Show file tree
Hide file tree
Showing 21 changed files with 264 additions and 172 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# CHANGELOG

## 0.14.1 (UNRELEASED)

- Changed code typing to satisfy MyPy 1.11.0 version


## 0.14.0 (2024-07-17)

- Added `ClientForwardRefsPlugin` to standard plugins.
Expand Down
2 changes: 1 addition & 1 deletion ariadne_codegen/client_generators/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -850,7 +850,7 @@ def _generate_variables_assign(
self, variable_names: Dict[str, str], arguments_dict: ast.Dict, lineno: int = 1
) -> ast.AnnAssign:
return generate_ann_assign(
target=variable_names[self._variables_dict_variable],
target=generate_name(variable_names[self._variables_dict_variable]),
annotation=generate_subscript(
generate_name(DICT),
generate_tuple([generate_name("str"), generate_name("object")]),
Expand Down
6 changes: 3 additions & 3 deletions ariadne_codegen/client_generators/custom_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def generate_clear_arguments_section(
) -> Tuple[List[ast.stmt], List[ast.keyword]]:
arguments_body = [
generate_ann_assign(
"arguments",
generate_name("arguments"),
generate_subscript(
generate_name(DICT),
generate_tuple(
Expand All @@ -240,8 +240,8 @@ def generate_clear_arguments_section(
),
),
generate_dict(
return_arguments_keys,
return_arguments_values, # type: ignore
return_arguments_keys, # type: ignore
return_arguments_values,
),
),
generate_assign(
Expand Down
2 changes: 1 addition & 1 deletion ariadne_codegen/client_generators/custom_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def _generate_class_field(
name, field_name, getattr(field, "args")
)
return generate_ann_assign(
target=name,
target=generate_name(name),
annotation=generate_name(f'"{field_name}"'),
value=generate_call(
func=generate_name(field_name), args=[generate_constant(org_name)]
Expand Down
44 changes: 25 additions & 19 deletions ariadne_codegen/client_generators/input_fields.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import ast
from typing import Dict, Optional, Tuple
from typing import Dict, List, Optional, Tuple, cast

from graphql import (
BooleanValueNode,
Expand Down Expand Up @@ -142,15 +142,18 @@ def parse_input_const_value_node(

if isinstance(node, ListValueNode):
list_ = generate_list(
[
parse_input_const_value_node(
node=v,
field_type=field_type,
nested_object=nested_object,
nested_list=True,
)
for v in node.values
]
cast(
List[ast.expr],
[
parse_input_const_value_node(
node=v,
field_type=field_type,
nested_object=nested_object,
nested_list=True,
)
for v in node.values
],
)
)
if not nested_list:
return generate_call(
Expand All @@ -166,15 +169,18 @@ def parse_input_const_value_node(
if isinstance(node, ObjectValueNode):
dict_ = generate_dict(
keys=[generate_constant(f.name.value) for f in node.fields],
values=[
parse_input_const_value_node(
node=f.value,
field_type=field_type,
nested_object=True,
nested_list=True,
)
for f in node.fields
],
values=cast(
List[ast.expr],
[
parse_input_const_value_node(
node=f.value,
field_type=field_type,
nested_object=True,
nested_list=True,
)
for f in node.fields
],
),
)
if not nested_object:
return generate_call(
Expand Down
3 changes: 2 additions & 1 deletion ariadne_codegen/client_generators/input_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
generate_keyword,
generate_method_call,
generate_module,
generate_name,
generate_pydantic_field,
model_has_forward_refs,
)
Expand Down Expand Up @@ -172,7 +173,7 @@ def _parse_input_definition(
field.type, custom_scalars=self.custom_scalars
)
field_implementation = generate_ann_assign(
target=name,
target=generate_name(name),
annotation=annotation,
value=parse_input_field_default_value(
node=field.ast_node, annotation=annotation, field_type=field_type
Expand Down
6 changes: 4 additions & 2 deletions ariadne_codegen/client_generators/result_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,9 @@ def parse_interface_type(
)
context.abstract_type = True
if inline_fragments or fragments_on_subtypes:
types = [generate_annotation_name('"' + class_name + type_.name + '"', False)]
types: List[ast.expr] = [
generate_annotation_name('"' + class_name + type_.name + '"', False)
]
context.related_classes.append(
RelatedClassData(class_name=class_name + type_.name, type_name=type_.name)
)
Expand Down Expand Up @@ -275,7 +277,7 @@ def parse_union_type(
class_name: str,
) -> Annotation:
context.abstract_type = True
sub_annotations = [
sub_annotations: List[ast.expr] = [
parse_operation_field_type(
type_=subtype,
context=context,
Expand Down
3 changes: 2 additions & 1 deletion ariadne_codegen/client_generators/result_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
generate_import_from,
generate_method_call,
generate_module,
generate_name,
generate_pass,
generate_pydantic_field,
model_has_forward_refs,
Expand Down Expand Up @@ -264,7 +265,7 @@ def _parse_type_definition(
)

field_implementation = generate_ann_assign(
target=name,
target=generate_name(name),
annotation=annotation,
lineno=lineno,
value=default_value,
Expand Down
101 changes: 59 additions & 42 deletions ariadne_codegen/codegen.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import ast
from typing import Any, Dict, List, Optional, Union
import sys
from typing import Any, Dict, List, Optional, Union, cast

from graphql import (
GraphQLEnumType,
Expand All @@ -24,11 +25,6 @@
from .exceptions import ParsingError


def generate_import(names: List[str], level: int = 0) -> ast.Import:
"""Generate import statement."""
return ast.Import(names=[ast.alias(n) for n in names], level=level)


def generate_import_from(
names: List[str], from_: Optional[str] = None, level: int = 0
) -> ast.ImportFrom:
Expand Down Expand Up @@ -94,17 +90,21 @@ def generate_async_method_definition(
return_type: Union[ast.Name, ast.Subscript],
body: Optional[List[ast.stmt]] = None,
lineno: int = 1,
decorator_list: Optional[List[ast.Name]] = None,
decorator_list: Optional[List[ast.expr]] = None,
) -> ast.AsyncFunctionDef:
"""Generate async function."""
return ast.AsyncFunctionDef(
name=name,
args=arguments,
body=body if body else [ast.Pass()],
decorator_list=decorator_list if decorator_list else [],
returns=return_type,
lineno=lineno,
)
params: Dict[str, Any] = {
"name": name,
"args": arguments,
"body": body if body else [ast.Pass()],
"decorator_list": decorator_list if decorator_list else [],
"returns": return_type,
"lineno": lineno,
}
if sys.version_info >= (3, 12):
params["type_params"] = []

return ast.AsyncFunctionDef(**params)


def generate_class_def(
Expand All @@ -113,14 +113,20 @@ def generate_class_def(
body: Optional[List[ast.stmt]] = None,
) -> ast.ClassDef:
"""Generate class definition."""
bases = [ast.Name(id=name) for name in base_names] if base_names else []
return ast.ClassDef(
name=name,
bases=bases,
keywords=[],
body=body if body else [],
decorator_list=[],
bases = cast(
List[ast.expr], [ast.Name(id=name) for name in base_names] if base_names else []
)
params: Dict[str, Any] = {
"name": name,
"bases": bases,
"keywords": [],
"body": body if body else [],
"decorator_list": [],
}
if sys.version_info >= (3, 12):
params["type_params"] = []

return ast.ClassDef(**params)


def generate_name(name: str) -> ast.Name:
Expand Down Expand Up @@ -153,37 +159,39 @@ def generate_assign(
) -> ast.Assign:
"""Generate assign object."""
return ast.Assign(
targets=[ast.Name(t) for t in targets], value=value, lineno=lineno
targets=[ast.Name(t) for t in targets],
value=value, # type:ignore
lineno=lineno,
)


def generate_ann_assign(
target: Union[str, ast.expr],
target: Union[ast.Name, ast.Attribute, ast.Subscript],
annotation: Annotation,
value: Optional[ast.expr] = None,
lineno: int = 1,
) -> ast.AnnAssign:
"""Generate ann assign object."""
return ast.AnnAssign(
target=target if isinstance(target, ast.expr) else ast.Name(id=target),
target=target,
annotation=annotation,
simple=1,
value=value,
simple=1,
lineno=lineno,
)


def generate_union_annotation(
types: List[Union[ast.Name, ast.Subscript]], nullable: bool = True
types: List[ast.expr], nullable: bool = True
) -> ast.Subscript:
"""Generate union annotation."""
result = ast.Subscript(value=ast.Name(id=UNION), slice=ast.Tuple(elts=types))
return result if not nullable else generate_nullable_annotation(result)


def generate_dict(
keys: Optional[List[ast.expr]] = None,
values: Optional[List[Optional[ast.expr]]] = None,
keys: Optional[List[Optional[ast.expr]]] = None,
values: Optional[List[ast.expr]] = None,
) -> ast.Dict:
"""Generate dict object."""
return ast.Dict(keys=keys if keys else [], values=values if values else [])
Expand All @@ -201,7 +209,9 @@ def generate_call(
) -> ast.Call:
"""Generate call object."""
return ast.Call(
func=func, args=args if args else [], keywords=keywords if keywords else []
func=func,
args=args if args else [], # type:ignore
keywords=keywords if keywords else [],
)


Expand Down Expand Up @@ -240,7 +250,10 @@ def parse_field_type(
return generate_annotation_name('"' + type_.name + '"', nullable)

if isinstance(type_, GraphQLUnionType):
subtypes = [parse_field_type(subtype, False) for subtype in type_.types]
subtypes = cast(
List[ast.expr],
[parse_field_type(subtype, False) for subtype in type_.types],
)
return generate_union_annotation(subtypes, nullable)

if isinstance(type_, GraphQLList):
Expand All @@ -255,7 +268,7 @@ def parse_field_type(


def generate_method_call(
object_name: str, method_name: str, args: Optional[List[Optional[ast.expr]]] = None
object_name: str, method_name: str, args: Optional[List[ast.expr]] = None
) -> ast.Call:
"""Generate object`s method call."""
return ast.Call(
Expand Down Expand Up @@ -287,7 +300,7 @@ def generate_trivial_lambda(name: str, argument_name: str) -> ast.Assign:
)


def generate_list(elements: List[Optional[ast.expr]]) -> ast.List:
def generate_list(elements: List[ast.expr]) -> ast.List:
"""Generate list object."""
return ast.List(elts=elements)

Expand Down Expand Up @@ -343,16 +356,20 @@ def generate_method_definition(
return_type: Union[ast.Name, ast.Subscript],
body: Optional[List[ast.stmt]] = None,
lineno: int = 1,
decorator_list: Optional[List[ast.Name]] = None,
decorator_list: Optional[List[ast.expr]] = None,
) -> ast.FunctionDef:
return ast.FunctionDef(
name=name,
args=arguments,
body=body if body else [ast.Pass()],
decorator_list=decorator_list if decorator_list else [],
returns=return_type,
lineno=lineno,
)
params: Dict[str, Any] = {
"name": name,
"args": arguments,
"body": body if body else [ast.Pass()],
"decorator_list": decorator_list if decorator_list else [],
"returns": return_type,
"lineno": lineno,
}
if sys.version_info >= (3, 12):
params["type_params"] = []

return ast.FunctionDef(**params)


def generate_async_for(
Expand Down
Loading

0 comments on commit d311712

Please sign in to comment.