diff --git a/CHANGELOG.md b/CHANGELOG.md index 4630d466..da53c93d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # CHANGELOG +## 0.14.1 (UNRELEASED) + +- Changed code typing to satisfy MyPy 1.11.0 version + + ## 0.14.0 (2024-07-17) - Added `ClientForwardRefsPlugin` to standard plugins. diff --git a/ariadne_codegen/client_generators/client.py b/ariadne_codegen/client_generators/client.py index 84caf3dc..54b0fc0a 100644 --- a/ariadne_codegen/client_generators/client.py +++ b/ariadne_codegen/client_generators/client.py @@ -850,7 +850,7 @@ def _generate_variables_assign( self, variable_names: Dict[str, str], arguments_dict: ast.Dict, lineno: int = 1 ) -> ast.AnnAssign: return generate_ann_assign( - target=variable_names[self._variables_dict_variable], + target=generate_name(variable_names[self._variables_dict_variable]), annotation=generate_subscript( generate_name(DICT), generate_tuple([generate_name("str"), generate_name("object")]), diff --git a/ariadne_codegen/client_generators/custom_arguments.py b/ariadne_codegen/client_generators/custom_arguments.py index cd0700e4..0841a024 100644 --- a/ariadne_codegen/client_generators/custom_arguments.py +++ b/ariadne_codegen/client_generators/custom_arguments.py @@ -221,7 +221,7 @@ def generate_clear_arguments_section( ) -> Tuple[List[ast.stmt], List[ast.keyword]]: arguments_body = [ generate_ann_assign( - "arguments", + generate_name("arguments"), generate_subscript( generate_name(DICT), generate_tuple( @@ -240,8 +240,8 @@ def generate_clear_arguments_section( ), ), generate_dict( - return_arguments_keys, - return_arguments_values, # type: ignore + return_arguments_keys, # type: ignore + return_arguments_values, ), ), generate_assign( diff --git a/ariadne_codegen/client_generators/custom_fields.py b/ariadne_codegen/client_generators/custom_fields.py index 0c0ba369..bfc324c8 100644 --- a/ariadne_codegen/client_generators/custom_fields.py +++ b/ariadne_codegen/client_generators/custom_fields.py @@ -219,7 +219,7 @@ def _generate_class_field( name, field_name, getattr(field, "args") ) return generate_ann_assign( - target=name, + target=generate_name(name), annotation=generate_name(f'"{field_name}"'), value=generate_call( func=generate_name(field_name), args=[generate_constant(org_name)] diff --git a/ariadne_codegen/client_generators/input_fields.py b/ariadne_codegen/client_generators/input_fields.py index bf65bc60..5b2718e6 100644 --- a/ariadne_codegen/client_generators/input_fields.py +++ b/ariadne_codegen/client_generators/input_fields.py @@ -1,5 +1,5 @@ import ast -from typing import Dict, Optional, Tuple +from typing import Dict, List, Optional, Tuple, cast from graphql import ( BooleanValueNode, @@ -142,15 +142,18 @@ def parse_input_const_value_node( if isinstance(node, ListValueNode): list_ = generate_list( - [ - parse_input_const_value_node( - node=v, - field_type=field_type, - nested_object=nested_object, - nested_list=True, - ) - for v in node.values - ] + cast( + List[ast.expr], + [ + parse_input_const_value_node( + node=v, + field_type=field_type, + nested_object=nested_object, + nested_list=True, + ) + for v in node.values + ], + ) ) if not nested_list: return generate_call( @@ -166,15 +169,18 @@ def parse_input_const_value_node( if isinstance(node, ObjectValueNode): dict_ = generate_dict( keys=[generate_constant(f.name.value) for f in node.fields], - values=[ - parse_input_const_value_node( - node=f.value, - field_type=field_type, - nested_object=True, - nested_list=True, - ) - for f in node.fields - ], + values=cast( + List[ast.expr], + [ + parse_input_const_value_node( + node=f.value, + field_type=field_type, + nested_object=True, + nested_list=True, + ) + for f in node.fields + ], + ), ) if not nested_object: return generate_call( diff --git a/ariadne_codegen/client_generators/input_types.py b/ariadne_codegen/client_generators/input_types.py index 3a945413..0a5797d3 100644 --- a/ariadne_codegen/client_generators/input_types.py +++ b/ariadne_codegen/client_generators/input_types.py @@ -18,6 +18,7 @@ generate_keyword, generate_method_call, generate_module, + generate_name, generate_pydantic_field, model_has_forward_refs, ) @@ -172,7 +173,7 @@ def _parse_input_definition( field.type, custom_scalars=self.custom_scalars ) field_implementation = generate_ann_assign( - target=name, + target=generate_name(name), annotation=annotation, value=parse_input_field_default_value( node=field.ast_node, annotation=annotation, field_type=field_type diff --git a/ariadne_codegen/client_generators/result_fields.py b/ariadne_codegen/client_generators/result_fields.py index e0c377e3..f7b5581f 100644 --- a/ariadne_codegen/client_generators/result_fields.py +++ b/ariadne_codegen/client_generators/result_fields.py @@ -216,7 +216,9 @@ def parse_interface_type( ) context.abstract_type = True if inline_fragments or fragments_on_subtypes: - types = [generate_annotation_name('"' + class_name + type_.name + '"', False)] + types: List[ast.expr] = [ + generate_annotation_name('"' + class_name + type_.name + '"', False) + ] context.related_classes.append( RelatedClassData(class_name=class_name + type_.name, type_name=type_.name) ) @@ -275,7 +277,7 @@ def parse_union_type( class_name: str, ) -> Annotation: context.abstract_type = True - sub_annotations = [ + sub_annotations: List[ast.expr] = [ parse_operation_field_type( type_=subtype, context=context, diff --git a/ariadne_codegen/client_generators/result_types.py b/ariadne_codegen/client_generators/result_types.py index e94674b7..bf29ec46 100644 --- a/ariadne_codegen/client_generators/result_types.py +++ b/ariadne_codegen/client_generators/result_types.py @@ -38,6 +38,7 @@ generate_import_from, generate_method_call, generate_module, + generate_name, generate_pass, generate_pydantic_field, model_has_forward_refs, @@ -264,7 +265,7 @@ def _parse_type_definition( ) field_implementation = generate_ann_assign( - target=name, + target=generate_name(name), annotation=annotation, lineno=lineno, value=default_value, diff --git a/ariadne_codegen/codegen.py b/ariadne_codegen/codegen.py index c4893511..6f964ea1 100644 --- a/ariadne_codegen/codegen.py +++ b/ariadne_codegen/codegen.py @@ -1,5 +1,6 @@ import ast -from typing import Any, Dict, List, Optional, Union +import sys +from typing import Any, Dict, List, Optional, Union, cast from graphql import ( GraphQLEnumType, @@ -24,11 +25,6 @@ from .exceptions import ParsingError -def generate_import(names: List[str], level: int = 0) -> ast.Import: - """Generate import statement.""" - return ast.Import(names=[ast.alias(n) for n in names], level=level) - - def generate_import_from( names: List[str], from_: Optional[str] = None, level: int = 0 ) -> ast.ImportFrom: @@ -94,17 +90,21 @@ def generate_async_method_definition( return_type: Union[ast.Name, ast.Subscript], body: Optional[List[ast.stmt]] = None, lineno: int = 1, - decorator_list: Optional[List[ast.Name]] = None, + decorator_list: Optional[List[ast.expr]] = None, ) -> ast.AsyncFunctionDef: """Generate async function.""" - return ast.AsyncFunctionDef( - name=name, - args=arguments, - body=body if body else [ast.Pass()], - decorator_list=decorator_list if decorator_list else [], - returns=return_type, - lineno=lineno, - ) + params: Dict[str, Any] = { + "name": name, + "args": arguments, + "body": body if body else [ast.Pass()], + "decorator_list": decorator_list if decorator_list else [], + "returns": return_type, + "lineno": lineno, + } + if sys.version_info >= (3, 12): + params["type_params"] = [] + + return ast.AsyncFunctionDef(**params) def generate_class_def( @@ -113,14 +113,20 @@ def generate_class_def( body: Optional[List[ast.stmt]] = None, ) -> ast.ClassDef: """Generate class definition.""" - bases = [ast.Name(id=name) for name in base_names] if base_names else [] - return ast.ClassDef( - name=name, - bases=bases, - keywords=[], - body=body if body else [], - decorator_list=[], + bases = cast( + List[ast.expr], [ast.Name(id=name) for name in base_names] if base_names else [] ) + params: Dict[str, Any] = { + "name": name, + "bases": bases, + "keywords": [], + "body": body if body else [], + "decorator_list": [], + } + if sys.version_info >= (3, 12): + params["type_params"] = [] + + return ast.ClassDef(**params) def generate_name(name: str) -> ast.Name: @@ -153,28 +159,30 @@ def generate_assign( ) -> ast.Assign: """Generate assign object.""" return ast.Assign( - targets=[ast.Name(t) for t in targets], value=value, lineno=lineno + targets=[ast.Name(t) for t in targets], + value=value, # type:ignore + lineno=lineno, ) def generate_ann_assign( - target: Union[str, ast.expr], + target: Union[ast.Name, ast.Attribute, ast.Subscript], annotation: Annotation, value: Optional[ast.expr] = None, lineno: int = 1, ) -> ast.AnnAssign: """Generate ann assign object.""" return ast.AnnAssign( - target=target if isinstance(target, ast.expr) else ast.Name(id=target), + target=target, annotation=annotation, - simple=1, value=value, + simple=1, lineno=lineno, ) def generate_union_annotation( - types: List[Union[ast.Name, ast.Subscript]], nullable: bool = True + types: List[ast.expr], nullable: bool = True ) -> ast.Subscript: """Generate union annotation.""" result = ast.Subscript(value=ast.Name(id=UNION), slice=ast.Tuple(elts=types)) @@ -182,8 +190,8 @@ def generate_union_annotation( def generate_dict( - keys: Optional[List[ast.expr]] = None, - values: Optional[List[Optional[ast.expr]]] = None, + keys: Optional[List[Optional[ast.expr]]] = None, + values: Optional[List[ast.expr]] = None, ) -> ast.Dict: """Generate dict object.""" return ast.Dict(keys=keys if keys else [], values=values if values else []) @@ -201,7 +209,9 @@ def generate_call( ) -> ast.Call: """Generate call object.""" return ast.Call( - func=func, args=args if args else [], keywords=keywords if keywords else [] + func=func, + args=args if args else [], # type:ignore + keywords=keywords if keywords else [], ) @@ -240,7 +250,10 @@ def parse_field_type( return generate_annotation_name('"' + type_.name + '"', nullable) if isinstance(type_, GraphQLUnionType): - subtypes = [parse_field_type(subtype, False) for subtype in type_.types] + subtypes = cast( + List[ast.expr], + [parse_field_type(subtype, False) for subtype in type_.types], + ) return generate_union_annotation(subtypes, nullable) if isinstance(type_, GraphQLList): @@ -255,7 +268,7 @@ def parse_field_type( def generate_method_call( - object_name: str, method_name: str, args: Optional[List[Optional[ast.expr]]] = None + object_name: str, method_name: str, args: Optional[List[ast.expr]] = None ) -> ast.Call: """Generate object`s method call.""" return ast.Call( @@ -287,7 +300,7 @@ def generate_trivial_lambda(name: str, argument_name: str) -> ast.Assign: ) -def generate_list(elements: List[Optional[ast.expr]]) -> ast.List: +def generate_list(elements: List[ast.expr]) -> ast.List: """Generate list object.""" return ast.List(elts=elements) @@ -343,16 +356,20 @@ def generate_method_definition( return_type: Union[ast.Name, ast.Subscript], body: Optional[List[ast.stmt]] = None, lineno: int = 1, - decorator_list: Optional[List[ast.Name]] = None, + decorator_list: Optional[List[ast.expr]] = None, ) -> ast.FunctionDef: - return ast.FunctionDef( - name=name, - args=arguments, - body=body if body else [ast.Pass()], - decorator_list=decorator_list if decorator_list else [], - returns=return_type, - lineno=lineno, - ) + params: Dict[str, Any] = { + "name": name, + "args": arguments, + "body": body if body else [ast.Pass()], + "decorator_list": decorator_list if decorator_list else [], + "returns": return_type, + "lineno": lineno, + } + if sys.version_info >= (3, 12): + params["type_params"] = [] + + return ast.FunctionDef(**params) def generate_async_for( diff --git a/ariadne_codegen/contrib/client_forward_refs.py b/ariadne_codegen/contrib/client_forward_refs.py index d3ac9833..009d84a2 100644 --- a/ariadne_codegen/contrib/client_forward_refs.py +++ b/ariadne_codegen/contrib/client_forward_refs.py @@ -173,16 +173,17 @@ def _insert_import_statement_in_method( if import_class is None: return - import_class_id = import_class.id + import_class_name = import_class.name # We add the class to our set of imported in methods - these classes # don't need to be imported at all in the global scope. - self.imported_in_method.add(import_class.id) + self.imported_in_method.add(import_class_name) method_def.body.insert( 0, ast.ImportFrom( - module=self.imported_classes[import_class_id], + module=self.imported_classes[import_class_name], names=[import_class], + level=1, ), ) @@ -237,19 +238,18 @@ def _get_call_arg_from_async_for( return None - def _get_class_from_call(self, call: ast.Call) -> Optional[ast.Name]: + def _get_class_from_call(self, call: ast.Call) -> Optional[ast.alias]: """Get the class from an `ast.Call`. :param call: The `ast.Call` arg - :returns: `ast.Name` or `None` + :returns: `ast.alias` or `None` """ if not isinstance(call.func, ast.Attribute): return None if not isinstance(call.func.value, ast.Name): return None - - return call.func.value + return ast.alias(name=call.func.value.id) def _update_imports(self, module: ast.Module) -> None: """Update all imports. @@ -345,7 +345,7 @@ def _add_forward_ref_imports( module_name = self.imported_classes[cls] if module_name not in type_checking_imports: type_checking_imports[module_name] = ast.ImportFrom( - module=module_name, names=[] + module=module_name, names=[], level=1 ) type_checking_imports[module_name].names.append(ast.alias(cls)) @@ -363,7 +363,8 @@ def _add_forward_ref_imports( len(non_empty_imports), ast.ImportFrom( module=TYPE_CHECKING_MODULE, - names=[ast.Name(TYPE_CHECKING_FLAG)], + names=[ast.alias(TYPE_CHECKING_FLAG)], + level=1, ), ) diff --git a/ariadne_codegen/graphql_schema_generators/schema.py b/ariadne_codegen/graphql_schema_generators/schema.py index 9de87871..68c7fed6 100644 --- a/ariadne_codegen/graphql_schema_generators/schema.py +++ b/ariadne_codegen/graphql_schema_generators/schema.py @@ -1,5 +1,6 @@ import ast from pathlib import Path +from typing import List, cast from graphql import GraphQLSchema, print_schema from graphql.type.schema import TypeMap @@ -46,50 +47,53 @@ def generate_schema_module( schema: GraphQLSchema, type_map_name: str, schema_variable_name: str ) -> ast.Module: return generate_module( - body=[ - generate_import_from( - names=[ - "DirectiveLocation", - "GraphQLArgument", - "GraphQLDirective", - "GraphQLEnumType", - "GraphQLEnumValue", - "GraphQLField", - "GraphQLInputField", - "GraphQLInputObjectType", - "GraphQLInterfaceType", - "GraphQLList", - "GraphQLNamedType", - "GraphQLNonNull", - "GraphQLObjectType", - "GraphQLScalarType", - "GraphQLSchema", - "GraphQLUnionType", - "GraphQLID", - "GraphQLInt", - "GraphQLFloat", - "GraphQLString", - "GraphQLBoolean", - "Undefined", - ], - from_="graphql", - ), - generate_import_from( - names=["TypeMap"], - from_="graphql.type.schema", - ), - generate_import_from(names=["cast", "List"], from_="typing"), - generate_ann_assign( - target=type_map_name, - annotation=generate_name("TypeMap"), - value=generate_type_map(schema.type_map, type_map_name), - ), - generate_ann_assign( - target=schema_variable_name, - annotation=generate_name("GraphQLSchema"), - value=generate_schema(schema, type_map_name), - ), - ] + body=cast( + List[ast.stmt], + [ + generate_import_from( + names=[ + "DirectiveLocation", + "GraphQLArgument", + "GraphQLDirective", + "GraphQLEnumType", + "GraphQLEnumValue", + "GraphQLField", + "GraphQLInputField", + "GraphQLInputObjectType", + "GraphQLInterfaceType", + "GraphQLList", + "GraphQLNamedType", + "GraphQLNonNull", + "GraphQLObjectType", + "GraphQLScalarType", + "GraphQLSchema", + "GraphQLUnionType", + "GraphQLID", + "GraphQLInt", + "GraphQLFloat", + "GraphQLString", + "GraphQLBoolean", + "Undefined", + ], + from_="graphql", + ), + generate_import_from( + names=["TypeMap"], + from_="graphql.type.schema", + ), + generate_import_from(names=["cast", "List"], from_="typing"), + generate_ann_assign( + target=generate_name(type_map_name), + annotation=generate_name("TypeMap"), + value=generate_type_map(schema.type_map, type_map_name), + ), + generate_ann_assign( + target=generate_name(schema_variable_name), + annotation=generate_name("GraphQLSchema"), + value=generate_schema(schema, type_map_name), + ), + ], + ) ) diff --git a/tests/client_generators/input_types_generator/test_default_values.py b/tests/client_generators/input_types_generator/test_default_values.py index 68322917..a578720e 100644 --- a/tests/client_generators/input_types_generator/test_default_values.py +++ b/tests/client_generators/input_types_generator/test_default_values.py @@ -46,6 +46,7 @@ def test_generate_returns_module_with_parsed_inputs_scalar_field_with_default_va simple=1, ) ], + type_params=[], ) module = generator.generate() diff --git a/tests/client_generators/input_types_generator/test_names.py b/tests/client_generators/input_types_generator/test_names.py index c8dbb2ff..d904bc54 100644 --- a/tests/client_generators/input_types_generator/test_names.py +++ b/tests/client_generators/input_types_generator/test_names.py @@ -77,6 +77,7 @@ ), ], decorator_list=[], + type_params=[], ), ), ( @@ -147,6 +148,7 @@ ), ], decorator_list=[], + type_params=[], ), ), ], diff --git a/tests/client_generators/input_types_generator/test_parsing_inputs.py b/tests/client_generators/input_types_generator/test_parsing_inputs.py index eee46667..b5349a99 100644 --- a/tests/client_generators/input_types_generator/test_parsing_inputs.py +++ b/tests/client_generators/input_types_generator/test_parsing_inputs.py @@ -65,6 +65,7 @@ simple=1, ), ], + type_params=[], ), ast.ClassDef( name="CustomInput2", @@ -78,6 +79,7 @@ simple=1, ) ], + type_params=[], ), ], ) diff --git a/tests/client_generators/result_types_generator/test_parsing_fragments.py b/tests/client_generators/result_types_generator/test_parsing_fragments.py index 1e7027db..788afe4a 100644 --- a/tests/client_generators/result_types_generator/test_parsing_fragments.py +++ b/tests/client_generators/result_types_generator/test_parsing_fragments.py @@ -36,6 +36,7 @@ def test_get_classes_returns_list_with_types_generated_from_fragment(): ), ], decorator_list=[], + type_params=[], ), ast.ClassDef( name="TestFragmentField1", @@ -49,6 +50,7 @@ def test_get_classes_returns_list_with_types_generated_from_fragment(): ) ], decorator_list=[], + type_params=[], ), ] generator = ResultTypesGenerator( @@ -96,6 +98,7 @@ def test_get_classes_returns_types_generated_from_fragment_which_uses_other_frag ), ], decorator_list=[], + type_params=[], ), ast.ClassDef( name="TestFragmentField1", @@ -103,6 +106,7 @@ def test_get_classes_returns_types_generated_from_fragment_which_uses_other_frag keywords=[], body=[ast.Pass()], decorator_list=[], + type_params=[], ), ] generator = ResultTypesGenerator( diff --git a/tests/client_generators/result_types_generator/test_parsing_operations.py b/tests/client_generators/result_types_generator/test_parsing_operations.py index 9656f41b..a99f09af 100644 --- a/tests/client_generators/result_types_generator/test_parsing_operations.py +++ b/tests/client_generators/result_types_generator/test_parsing_operations.py @@ -62,6 +62,7 @@ simple=1, ) ], + type_params=[], ), ast.ClassDef( name="CustomQueryQuery2", @@ -75,6 +76,7 @@ simple=1, ) ], + type_params=[], ), ], ), @@ -118,6 +120,7 @@ ) ], decorator_list=[], + type_params=[], ), ast.ClassDef( name="CustomQueryQuery1", @@ -174,6 +177,7 @@ ), ], decorator_list=[], + type_params=[], ), ast.ClassDef( name="CustomQueryQuery1Field1", @@ -187,6 +191,7 @@ ) ], decorator_list=[], + type_params=[], ), ast.ClassDef( name="CustomQueryQuery1Field2", @@ -202,6 +207,7 @@ ) ], decorator_list=[], + type_params=[], ), ], ), @@ -263,6 +269,7 @@ def test_generate_returns_module_with_types_generated_from_mutation(): simple=1, ) ], + type_params=[], ), ast.ClassDef( name="CustomMutationMutation1", @@ -276,6 +283,7 @@ def test_generate_returns_module_with_types_generated_from_mutation(): simple=1, ) ], + type_params=[], ), ] generator = ResultTypesGenerator( @@ -328,6 +336,7 @@ def test_generate_returns_module_with_types_generated_from_subscription(): simple=1, ) ], + type_params=[], ), ast.ClassDef( name="CustomSubscriptionSubscription1", @@ -341,6 +350,7 @@ def test_generate_returns_module_with_types_generated_from_subscription(): simple=1, ) ], + type_params=[], ), ] generator = ResultTypesGenerator( @@ -441,6 +451,7 @@ def test_generate_returns_module_with_types_generated_from_query_that_uses_fragm ), ], decorator_list=[], + type_params=[], ) generator = ResultTypesGenerator( schema=build_schema(SCHEMA_STR), @@ -512,6 +523,7 @@ def test_generate_returns_module_with_class_with_union_from_unpacked_fragment(): ) ], decorator_list=[], + type_params=[], ), ast.ClassDef( name="CustomQueryInterfaceQueryInterfaceI", @@ -545,6 +557,7 @@ def test_generate_returns_module_with_class_with_union_from_unpacked_fragment(): ), ], decorator_list=[], + type_params=[], ), ast.ClassDef( name="CustomQueryInterfaceQueryTypeA", @@ -589,6 +602,7 @@ def test_generate_returns_module_with_class_with_union_from_unpacked_fragment(): ), ], decorator_list=[], + type_params=[], ), ast.ClassDef( name="CustomQueryInterfaceQueryTypeB", @@ -633,6 +647,7 @@ def test_generate_returns_module_with_class_with_union_from_unpacked_fragment(): ), ], decorator_list=[], + type_params=[], ), ] operation_definition, fragment_def = cast( @@ -685,6 +700,7 @@ def test_generate_returns_module_with_class_for_every_appearance_of_type(): ), ], decorator_list=[], + type_params=[], ), ast.ClassDef( name="CustomQueryQuery1", @@ -698,6 +714,7 @@ def test_generate_returns_module_with_class_for_every_appearance_of_type(): ) ], decorator_list=[], + type_params=[], ), ast.ClassDef( name="CustomQueryCamelCaseQuery", @@ -709,6 +726,7 @@ def test_generate_returns_module_with_class_for_every_appearance_of_type(): ) ], decorator_list=[], + type_params=[], ), ] generator = ResultTypesGenerator( @@ -743,6 +761,7 @@ def test_generate_returns_module_with_class_for_every_appearance_of_type(): simple=1, ) ], + type_params=[], ), ), ( @@ -759,6 +778,7 @@ def test_generate_returns_module_with_class_for_every_appearance_of_type(): simple=1, ) ], + type_params=[], ), ), ], diff --git a/tests/client_generators/result_types_generator/test_unions.py b/tests/client_generators/result_types_generator/test_unions.py index a70f436b..b99d2c45 100644 --- a/tests/client_generators/result_types_generator/test_unions.py +++ b/tests/client_generators/result_types_generator/test_unions.py @@ -116,6 +116,7 @@ def test_generate_returns_module_with_classes_for_union_fields(): ) ], decorator_list=[], + type_params=[], ), ast.ClassDef( name="CustomQueryQuery4CustomType1", @@ -146,6 +147,7 @@ def test_generate_returns_module_with_classes_for_union_fields(): ), ], decorator_list=[], + type_params=[], ), ast.ClassDef( name="CustomQueryQuery4CustomType2", @@ -179,6 +181,7 @@ def test_generate_returns_module_with_classes_for_union_fields(): ), ], decorator_list=[], + type_params=[], ), ] @@ -237,6 +240,7 @@ def test_generate_returns_module_with_class_generated_from_union_with_one_member ), ], decorator_list=[], + type_params=[], ) generator = ResultTypesGenerator( schema=build_ast_schema(parse(SCHEMA_STR)), diff --git a/tests/client_generators/test_client_generator.py b/tests/client_generators/test_client_generator.py index 1f59080d..f2500b9c 100644 --- a/tests/client_generators/test_client_generator.py +++ b/tests/client_generators/test_client_generator.py @@ -1,5 +1,5 @@ import ast -from typing import cast +from typing import List, cast import pytest from graphql import GraphQLSchema, OperationDefinitionNode, build_schema, parse @@ -57,6 +57,7 @@ def test_generate_returns_module_with_gql_lambda_definition(async_base_client_im body=[ast.Return(value=ast.Name(id="q"))], returns=ast.Name(id="str"), decorator_list=[], + type_params=[], ) module = generator.generate() @@ -484,61 +485,76 @@ def test_add_method_generates_async_generator_for_subscription_definition( kw_defaults=[], defaults=[], ), - body=[ - ast.Assign( - targets=[ast.Name(id="query")], - value=ast.Call( - func=ast.Name(id="gql"), - args=[ - [ast.Constant(value="subscription GetCounter { counter }\n")] - ], - keywords=[], - ), - ), - ast.AnnAssign( - target=ast.Name(id="variables"), - annotation=ast.Subscript( - value=ast.Name(id=DICT), - slice=ast.Tuple(elts=[ast.Name(id="str"), ast.Name(id="object")]), + body=cast( + List[ast.stmt], + [ + ast.Assign( + targets=[ast.Name(id="query")], + value=ast.Call( + func=ast.Name(id="gql"), + args=[ + [ + ast.Constant( + value="subscription GetCounter { counter }\n" + ) + ] + ], + keywords=[], + ), ), - value=ast.Dict(keys=[], values=[]), - simple=1, - ), - ast.AsyncFor( - target=ast.Name(id="data"), - iter=ast.Call( - func=ast.Attribute(value=ast.Name(id="self"), attr="execute_ws"), - args=[], - keywords=[ - ast.keyword(arg="query", value=ast.Name(id="query")), - ast.keyword( - arg="operation_name", value=ast.Constant(value="GetCounter") + ast.AnnAssign( + target=ast.Name(id="variables"), + annotation=ast.Subscript( + value=ast.Name(id=DICT), + slice=ast.Tuple( + elts=[ast.Name(id="str"), ast.Name(id="object")] ), - ast.keyword(arg="variables", value=ast.Name(id="variables")), - ast.keyword(value=ast.Name(id=KWARGS_NAMES)), - ], + ), + value=ast.Dict(keys=[], values=[]), + simple=1, ), - body=[ - ast.Expr( - value=ast.Yield( - value=ast.Call( - func=ast.Attribute( - value=ast.Name(id="GetCounter"), - attr=MODEL_VALIDATE_METHOD, - ), - args=[ast.Name(id="data")], - keywords=[], + ast.AsyncFor( + target=ast.Name(id="data"), + iter=ast.Call( + func=ast.Attribute( + value=ast.Name(id="self"), attr="execute_ws" + ), + args=[], + keywords=[ + ast.keyword(arg="query", value=ast.Name(id="query")), + ast.keyword( + arg="operation_name", + value=ast.Constant(value="GetCounter"), + ), + ast.keyword( + arg="variables", value=ast.Name(id="variables") + ), + ast.keyword(value=ast.Name(id=KWARGS_NAMES)), + ], + ), + body=[ + ast.Expr( + value=ast.Yield( + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id="GetCounter"), + attr=MODEL_VALIDATE_METHOD, + ), + args=[ast.Name(id="data")], + keywords=[], + ) ) ) - ) - ], - orelse=[], - ), - ], + ], + orelse=[], + ), + ], + ), decorator_list=[], returns=ast.Subscript( value=ast.Name(id="AsyncIterator"), slice=ast.Name(id="GetCounter") ), + type_params=[], ) generator.add_method( diff --git a/tests/client_generators/test_enums_generator.py b/tests/client_generators/test_enums_generator.py index 4026bf89..09f855a7 100644 --- a/tests/client_generators/test_enums_generator.py +++ b/tests/client_generators/test_enums_generator.py @@ -55,6 +55,7 @@ def test_generate_returns_module_with_enum_class_definition(): targets=[ast.Name(id="import_")], value=ast.Constant(value="import") ), ], + type_params=[], ) generator = EnumsGenerator(schema=build_ast_schema(parse(schema_str))) @@ -101,6 +102,7 @@ def test_generate_returns_module_with_enum_class_definition_for_every_enum(): targets=[ast.Name(id="VALUE2")], value=ast.Constant(value="VALUE2") ), ], + type_params=[], ), ast.ClassDef( name="TestEnumB", @@ -112,6 +114,7 @@ def test_generate_returns_module_with_enum_class_definition_for_every_enum(): ast.Assign(targets=[ast.Name(id="B")], value=ast.Constant(value="B")), ast.Assign(targets=[ast.Name(id="C")], value=ast.Constant(value="C")), ], + type_params=[], ), ast.ClassDef( name="TestEnumC", @@ -125,6 +128,7 @@ def test_generate_returns_module_with_enum_class_definition_for_every_enum(): ast.Assign(targets=[ast.Name(id="D4")], value=ast.Constant(value="D4")), ast.Assign(targets=[ast.Name(id="E5")], value=ast.Constant(value="E5")), ], + type_params=[], ), ] generator = EnumsGenerator(schema=build_ast_schema(parse(schema_str))) diff --git a/tests/client_generators/test_fragments_generator.py b/tests/client_generators/test_fragments_generator.py index d43bd692..1bf1c75e 100644 --- a/tests/client_generators/test_fragments_generator.py +++ b/tests/client_generators/test_fragments_generator.py @@ -75,6 +75,7 @@ def test_generate_returns_module_with_class_for_every_fragment( ) ], decorator_list=[], + type_params=[], ), ast.ClassDef( name="FragmentB", @@ -97,6 +98,7 @@ def test_generate_returns_module_with_class_for_every_fragment( ) ], decorator_list=[], + type_params=[], ), ] generator = FragmentsGenerator( diff --git a/tests/codegen/test_generated_assignments.py b/tests/codegen/test_generated_assignments.py index b4b70a60..0855eb6f 100644 --- a/tests/codegen/test_generated_assignments.py +++ b/tests/codegen/test_generated_assignments.py @@ -17,12 +17,12 @@ def test_generate_assign_returns_objects_with_correct_targets_and_value(): def test_generate_ann_assign_returns_object_with_given_annotation_and_tartget(): - target_name = "xyz" + target_name = ast.Name("xyz") annotation = ast.Name(id="Xyz") result = generate_ann_assign(target_name, annotation) assert isinstance(result, ast.AnnAssign) assert isinstance(result.target, ast.Name) - assert result.target.id == target_name + assert result.target == target_name assert result.annotation == annotation